├── .coveragerc ├── .flake8 ├── .github ├── ISSUE_TEMPLATE │ └── custom.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── algo_test.yml │ ├── badge.yml │ ├── deploy.yml │ ├── doc.yml │ ├── envpool_test.yml │ ├── platform_test.yml │ ├── release.yml │ ├── release_conda.yml │ ├── style.yml │ └── unit_test.yml ├── .gitignore ├── .style.yapf ├── CHANGELOG ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── assets └── wechat.jpeg ├── cloc.sh ├── codecov.yml ├── conda ├── conda_build_config.yaml └── meta.yaml ├── ding ├── __init__.py ├── bonus │ ├── __init__.py │ ├── a2c.py │ ├── c51.py │ ├── common.py │ ├── config.py │ ├── ddpg.py │ ├── dqn.py │ ├── model.py │ ├── pg.py │ ├── ppo_offpolicy.py │ ├── ppof.py │ ├── sac.py │ ├── sql.py │ └── td3.py ├── compatibility.py ├── config │ ├── __init__.py │ ├── config.py │ ├── example │ │ ├── A2C │ │ │ ├── __init__.py │ │ │ ├── gym_bipedalwalker_v3.py │ │ │ └── gym_lunarlander_v2.py │ │ ├── C51 │ │ │ ├── __init__.py │ │ │ ├── gym_lunarlander_v2.py │ │ │ ├── gym_pongnoframeskip_v4.py │ │ │ ├── gym_qbertnoframeskip_v4.py │ │ │ └── gym_spaceInvadersnoframeskip_v4.py │ │ ├── DDPG │ │ │ ├── __init__.py │ │ │ ├── gym_bipedalwalker_v3.py │ │ │ ├── gym_halfcheetah_v3.py │ │ │ ├── gym_hopper_v3.py │ │ │ ├── gym_lunarlandercontinuous_v2.py │ │ │ ├── gym_pendulum_v1.py │ │ │ └── gym_walker2d_v3.py │ │ ├── DQN │ │ │ ├── __init__.py │ │ │ ├── gym_lunarlander_v2.py │ │ │ ├── gym_pongnoframeskip_v4.py │ │ │ ├── gym_qbertnoframeskip_v4.py │ │ │ └── gym_spaceInvadersnoframeskip_v4.py │ │ ├── PG │ │ │ ├── __init__.py │ │ │ └── gym_pendulum_v1.py │ │ ├── PPOF │ │ │ ├── __init__.py │ │ │ ├── gym_lunarlander_v2.py │ │ │ └── gym_lunarlandercontinuous_v2.py │ │ ├── PPOOffPolicy │ │ │ ├── __init__.py │ │ │ ├── gym_lunarlander_v2.py │ │ │ ├── gym_pongnoframeskip_v4.py │ │ │ ├── gym_qbertnoframeskip_v4.py │ │ │ └── gym_spaceInvadersnoframeskip_v4.py │ │ ├── SAC │ │ │ ├── __init__.py │ │ │ ├── gym_bipedalwalker_v3.py │ │ │ 
├── gym_halfcheetah_v3.py │ │ │ ├── gym_hopper_v3.py │ │ │ ├── gym_lunarlandercontinuous_v2.py │ │ │ ├── gym_pendulum_v1.py │ │ │ └── gym_walker2d_v3.py │ │ ├── SQL │ │ │ ├── __init__.py │ │ │ └── gym_lunarlander_v2.py │ │ ├── TD3 │ │ │ ├── __init__.py │ │ │ ├── gym_bipedalwalker_v3.py │ │ │ ├── gym_halfcheetah_v3.py │ │ │ ├── gym_hopper_v3.py │ │ │ ├── gym_lunarlandercontinuous_v2.py │ │ │ ├── gym_pendulum_v1.py │ │ │ └── gym_walker2d_v3.py │ │ └── __init__.py │ ├── tests │ │ └── test_config_formatted.py │ └── utils.py ├── data │ ├── __init__.py │ ├── buffer │ │ ├── __init__.py │ │ ├── buffer.py │ │ ├── deque_buffer.py │ │ ├── deque_buffer_wrapper.py │ │ ├── middleware │ │ │ ├── __init__.py │ │ │ ├── clone_object.py │ │ │ ├── group_sample.py │ │ │ ├── padding.py │ │ │ ├── priority.py │ │ │ ├── sample_range_view.py │ │ │ ├── staleness_check.py │ │ │ └── use_time_check.py │ │ └── tests │ │ │ ├── test_buffer.py │ │ │ ├── test_buffer_benchmark.py │ │ │ └── test_middleware.py │ ├── level_replay │ │ ├── __init__.py │ │ ├── level_sampler.py │ │ └── tests │ │ │ └── test_level_sampler.py │ ├── model_loader.py │ ├── shm_buffer.py │ ├── storage │ │ ├── __init__.py │ │ ├── file.py │ │ ├── storage.py │ │ └── tests │ │ │ └── test_storage.py │ ├── storage_loader.py │ └── tests │ │ ├── test_model_loader.py │ │ ├── test_shm_buffer.py │ │ └── test_storage_loader.py ├── design │ ├── dataloader-sequence.png │ ├── dataloader-sequence.puml │ ├── env_state.png │ ├── parallel_main-sequence.png │ ├── parallel_main-sequence.puml │ ├── serial_collector-activity.png │ ├── serial_collector-activity.puml │ ├── serial_evaluator-activity.png │ ├── serial_evaluator-activity.puml │ ├── serial_learner-activity.png │ ├── serial_learner-activity.puml │ ├── serial_main-sequence.png │ └── serial_main.puml ├── entry │ ├── __init__.py │ ├── application_entry.py │ ├── application_entry_trex_collect_data.py │ ├── cli.py │ ├── cli_ditask.py │ ├── cli_parsers │ │ ├── __init__.py │ │ ├── k8s_parser.py │ │ ├── 
slurm_parser.py │ │ └── tests │ │ │ ├── test_k8s_parser.py │ │ │ └── test_slurm_parser.py │ ├── dist_entry.py │ ├── parallel_entry.py │ ├── predefined_config.py │ ├── serial_entry.py │ ├── serial_entry_bc.py │ ├── serial_entry_bco.py │ ├── serial_entry_dqfd.py │ ├── serial_entry_gail.py │ ├── serial_entry_guided_cost.py │ ├── serial_entry_mbrl.py │ ├── serial_entry_ngu.py │ ├── serial_entry_offline.py │ ├── serial_entry_onpolicy.py │ ├── serial_entry_onpolicy_ppg.py │ ├── serial_entry_pc.py │ ├── serial_entry_plr.py │ ├── serial_entry_preference_based_irl.py │ ├── serial_entry_preference_based_irl_onpolicy.py │ ├── serial_entry_r2d3.py │ ├── serial_entry_reward_model_offpolicy.py │ ├── serial_entry_reward_model_onpolicy.py │ ├── serial_entry_sqil.py │ ├── serial_entry_td3_vae.py │ ├── tests │ │ ├── config │ │ │ ├── agconfig.yaml │ │ │ ├── dijob-cartpole.yaml │ │ │ └── k8s-config.yaml │ │ ├── test_application_entry.py │ │ ├── test_application_entry_trex_collect_data.py │ │ ├── test_cli_ditask.py │ │ ├── test_parallel_entry.py │ │ ├── test_random_collect.py │ │ ├── test_serial_entry.py │ │ ├── test_serial_entry_algo.py │ │ ├── test_serial_entry_bc.py │ │ ├── test_serial_entry_bco.py │ │ ├── test_serial_entry_dqfd.py │ │ ├── test_serial_entry_for_anytrading.py │ │ ├── test_serial_entry_guided_cost.py │ │ ├── test_serial_entry_mbrl.py │ │ ├── test_serial_entry_onpolicy.py │ │ ├── test_serial_entry_preference_based_irl.py │ │ ├── test_serial_entry_preference_based_irl_onpolicy.py │ │ ├── test_serial_entry_reward_model.py │ │ └── test_serial_entry_sqil.py │ └── utils.py ├── envs │ ├── __init__.py │ ├── common │ │ ├── __init__.py │ │ ├── common_function.py │ │ ├── env_element.py │ │ ├── env_element_runner.py │ │ └── tests │ │ │ └── test_common_function.py │ ├── env │ │ ├── __init__.py │ │ ├── base_env.py │ │ ├── default_wrapper.py │ │ ├── ding_env_wrapper.py │ │ ├── env_implementation_check.py │ │ └── tests │ │ │ ├── __init__.py │ │ │ ├── demo_env.py │ │ │ ├── 
test_ding_env_wrapper.py │ │ │ └── test_env_implementation_check.py │ ├── env_manager │ │ ├── __init__.py │ │ ├── base_env_manager.py │ │ ├── ding_env_manager.py │ │ ├── env_supervisor.py │ │ ├── envpool_env_manager.py │ │ ├── gym_vector_env_manager.py │ │ ├── subprocess_env_manager.py │ │ └── tests │ │ │ ├── __init__.py │ │ │ ├── conftest.py │ │ │ ├── test_base_env_manager.py │ │ │ ├── test_env_supervisor.py │ │ │ ├── test_envpool_env_manager.py │ │ │ ├── test_gym_vector_env_manager.py │ │ │ ├── test_shm.py │ │ │ └── test_subprocess_env_manager.py │ ├── env_wrappers │ │ ├── __init__.py │ │ └── env_wrappers.py │ └── gym_env.py ├── example │ ├── __init__.py │ ├── bcq.py │ ├── c51_nstep.py │ ├── collect_demo_data.py │ ├── cql.py │ ├── d4pg.py │ ├── ddpg.py │ ├── dqn.py │ ├── dqn_eval.py │ ├── dqn_frozen_lake.py │ ├── dqn_her.py │ ├── dqn_new_env.py │ ├── dqn_nstep.py │ ├── dqn_nstep_gymnasium.py │ ├── dqn_per.py │ ├── dqn_rnd.py │ ├── dt.py │ ├── edac.py │ ├── impala.py │ ├── iqn_nstep.py │ ├── mappo.py │ ├── masac.py │ ├── pdqn.py │ ├── ppg_offpolicy.py │ ├── ppo.py │ ├── ppo_lunarlander.py │ ├── ppo_lunarlander_continuous.py │ ├── ppo_offpolicy.py │ ├── ppo_with_complex_obs.py │ ├── qgpo.py │ ├── qrdqn_nstep.py │ ├── r2d2.py │ ├── sac.py │ ├── sqil.py │ ├── sqil_continuous.py │ ├── sql.py │ ├── td3.py │ └── trex.py ├── framework │ ├── __init__.py │ ├── context.py │ ├── event_loop.py │ ├── message_queue │ │ ├── __init__.py │ │ ├── mq.py │ │ ├── nng.py │ │ ├── redis.py │ │ └── tests │ │ │ ├── test_nng.py │ │ │ └── test_redis.py │ ├── middleware │ │ ├── __init__.py │ │ ├── barrier.py │ │ ├── ckpt_handler.py │ │ ├── collector.py │ │ ├── data_fetcher.py │ │ ├── distributer.py │ │ ├── functional │ │ │ ├── __init__.py │ │ │ ├── advantage_estimator.py │ │ │ ├── collector.py │ │ │ ├── ctx_helper.py │ │ │ ├── data_processor.py │ │ │ ├── enhancer.py │ │ │ ├── evaluator.py │ │ │ ├── explorer.py │ │ │ ├── logger.py │ │ │ ├── priority.py │ │ │ ├── termination_checker.py │ │ │ 
├── timer.py │ │ │ └── trainer.py │ │ ├── learner.py │ │ └── tests │ │ │ ├── __init__.py │ │ │ ├── mock_for_test.py │ │ │ ├── test_advantage_estimator.py │ │ │ ├── test_barrier.py │ │ │ ├── test_ckpt_handler.py │ │ │ ├── test_collector.py │ │ │ ├── test_data_processor.py │ │ │ ├── test_distributer.py │ │ │ ├── test_enhancer.py │ │ │ ├── test_evaluator.py │ │ │ ├── test_explorer.py │ │ │ ├── test_logger.py │ │ │ ├── test_priority.py │ │ │ └── test_trainer.py │ ├── parallel.py │ ├── supervisor.py │ ├── task.py │ ├── tests │ │ ├── context_fake_data.py │ │ ├── test_context.py │ │ ├── test_event_loop.py │ │ ├── test_parallel.py │ │ ├── test_supervisor.py │ │ ├── test_task.py │ │ └── test_wrapper.py │ └── wrapper │ │ ├── __init__.py │ │ └── step_timer.py ├── hpc_rl │ ├── README.md │ ├── __init__.py │ ├── tests │ │ ├── test_dntd.py │ │ ├── test_gae.py │ │ ├── test_lstm.py │ │ ├── test_ppo.py │ │ ├── test_qntd.py │ │ ├── test_qntd_rescale.py │ │ ├── test_scatter.py │ │ ├── test_tdlambda.py │ │ ├── test_upgo.py │ │ ├── test_vtrace.py │ │ └── testbase.py │ └── wrapper.py ├── interaction │ ├── __init__.py │ ├── base │ │ ├── __init__.py │ │ ├── app.py │ │ ├── common.py │ │ ├── network.py │ │ └── threading.py │ ├── config │ │ ├── __init__.py │ │ └── base.py │ ├── exception │ │ ├── __init__.py │ │ ├── base.py │ │ ├── master.py │ │ └── slave.py │ ├── master │ │ ├── __init__.py │ │ ├── base.py │ │ ├── connection.py │ │ ├── master.py │ │ └── task.py │ ├── slave │ │ ├── __init__.py │ │ ├── action.py │ │ └── slave.py │ └── tests │ │ ├── __init__.py │ │ ├── base │ │ ├── __init__.py │ │ ├── test_app.py │ │ ├── test_common.py │ │ ├── test_network.py │ │ └── test_threading.py │ │ ├── config │ │ ├── __init__.py │ │ └── test_base.py │ │ ├── exception │ │ ├── __init__.py │ │ ├── test_base.py │ │ ├── test_master.py │ │ └── test_slave.py │ │ ├── interaction │ │ ├── __init__.py │ │ ├── bases.py │ │ ├── test_errors.py │ │ └── test_simple.py │ │ └── test_utils │ │ ├── __init__.py │ │ ├── 
random.py │ │ └── stream.py ├── league │ ├── __init__.py │ ├── algorithm.py │ ├── base_league.py │ ├── metric.py │ ├── one_vs_one_league.py │ ├── player.py │ ├── shared_payoff.py │ ├── starcraft_player.py │ └── tests │ │ ├── conftest.py │ │ ├── league_test_default_config.py │ │ ├── test_league_metric.py │ │ ├── test_one_vs_one_league.py │ │ ├── test_payoff.py │ │ └── test_player.py ├── model │ ├── __init__.py │ ├── common │ │ ├── __init__.py │ │ ├── encoder.py │ │ ├── head.py │ │ ├── tests │ │ │ ├── test_encoder.py │ │ │ └── test_head.py │ │ └── utils.py │ ├── template │ │ ├── __init__.py │ │ ├── acer.py │ │ ├── atoc.py │ │ ├── bc.py │ │ ├── bcq.py │ │ ├── collaq.py │ │ ├── coma.py │ │ ├── decision_transformer.py │ │ ├── diffusion.py │ │ ├── ebm.py │ │ ├── edac.py │ │ ├── havac.py │ │ ├── hpt.py │ │ ├── language_transformer.py │ │ ├── madqn.py │ │ ├── maqac.py │ │ ├── mavac.py │ │ ├── ngu.py │ │ ├── pdqn.py │ │ ├── pg.py │ │ ├── ppg.py │ │ ├── procedure_cloning.py │ │ ├── q_learning.py │ │ ├── qac.py │ │ ├── qac_dist.py │ │ ├── qgpo.py │ │ ├── qmix.py │ │ ├── qtran.py │ │ ├── qvac.py │ │ ├── sqn.py │ │ ├── tests │ │ │ ├── test_acer.py │ │ │ ├── test_atoc.py │ │ │ ├── test_bc.py │ │ │ ├── test_bcq.py │ │ │ ├── test_collaq.py │ │ │ ├── test_coma_nn.py │ │ │ ├── test_decision_transformer.py │ │ │ ├── test_ebm.py │ │ │ ├── test_edac.py │ │ │ ├── test_havac.py │ │ │ ├── test_hpt.py │ │ │ ├── test_hybrid_qac.py │ │ │ ├── test_language_transformer.py │ │ │ ├── test_madqn.py │ │ │ ├── test_maqac.py │ │ │ ├── test_mavac.py │ │ │ ├── test_ngu.py │ │ │ ├── test_pdqn.py │ │ │ ├── test_pg.py │ │ │ ├── test_procedure_cloning.py │ │ │ ├── test_q_learning.py │ │ │ ├── test_qac.py │ │ │ ├── test_qac_dist.py │ │ │ ├── test_qmix.py │ │ │ ├── test_qtran.py │ │ │ ├── test_vac.py │ │ │ ├── test_vae.py │ │ │ └── test_wqmix.py │ │ ├── vac.py │ │ ├── vae.py │ │ └── wqmix.py │ └── wrapper │ │ ├── __init__.py │ │ ├── model_wrappers.py │ │ └── test_model_wrappers.py ├── policy │ ├── 
__init__.py │ ├── a2c.py │ ├── acer.py │ ├── atoc.py │ ├── base_policy.py │ ├── bc.py │ ├── bcq.py │ ├── bdq.py │ ├── c51.py │ ├── collaq.py │ ├── coma.py │ ├── command_mode_policy_instance.py │ ├── common_utils.py │ ├── cql.py │ ├── d4pg.py │ ├── ddpg.py │ ├── dqfd.py │ ├── dqn.py │ ├── dt.py │ ├── edac.py │ ├── fqf.py │ ├── happo.py │ ├── ibc.py │ ├── il.py │ ├── impala.py │ ├── iql.py │ ├── iqn.py │ ├── madqn.py │ ├── mbpolicy │ │ ├── __init__.py │ │ ├── dreamer.py │ │ ├── mbsac.py │ │ ├── tests │ │ │ └── test_mbpolicy_utils.py │ │ └── utils.py │ ├── mdqn.py │ ├── ngu.py │ ├── offppo_collect_traj.py │ ├── pc.py │ ├── pdqn.py │ ├── pg.py │ ├── plan_diffuser.py │ ├── policy_factory.py │ ├── ppg.py │ ├── ppo.py │ ├── ppof.py │ ├── prompt_awr.py │ ├── prompt_pg.py │ ├── qgpo.py │ ├── qmix.py │ ├── qrdqn.py │ ├── qtran.py │ ├── r2d2.py │ ├── r2d2_collect_traj.py │ ├── r2d2_gtrxl.py │ ├── r2d3.py │ ├── rainbow.py │ ├── sac.py │ ├── sql.py │ ├── sqn.py │ ├── td3.py │ ├── td3_bc.py │ ├── td3_vae.py │ ├── tests │ │ ├── test_common_utils.py │ │ ├── test_cql.py │ │ ├── test_r2d3.py │ │ └── test_stdim.py │ └── wqmix.py ├── reward_model │ ├── __init__.py │ ├── base_reward_model.py │ ├── drex_reward_model.py │ ├── gail_irl_model.py │ ├── guided_cost_reward_model.py │ ├── her_reward_model.py │ ├── icm_reward_model.py │ ├── ngu_reward_model.py │ ├── pdeil_irl_model.py │ ├── pwil_irl_model.py │ ├── red_irl_model.py │ ├── rnd_reward_model.py │ ├── tests │ │ └── test_gail_irl_model.py │ └── trex_reward_model.py ├── rl_utils │ ├── README.md │ ├── __init__.py │ ├── a2c.py │ ├── acer.py │ ├── adder.py │ ├── beta_function.py │ ├── coma.py │ ├── exploration.py │ ├── gae.py │ ├── grpo.py │ ├── happo.py │ ├── isw.py │ ├── log_prob_utils.py │ ├── ppg.py │ ├── ppo.py │ ├── retrace.py │ ├── rloo.py │ ├── sampler.py │ ├── td.py │ ├── tests │ │ ├── test_a2c.py │ │ ├── test_adder.py │ │ ├── test_coma.py │ │ ├── test_exploration.py │ │ ├── test_gae.py │ │ ├── test_grpo_rlhf.py │ │ ├── 
test_happo.py │ │ ├── test_log_prob_fn.py │ │ ├── test_log_prob_utils.py │ │ ├── test_ppg.py │ │ ├── test_ppo.py │ │ ├── test_ppo_rlhf.py │ │ ├── test_retrace.py │ │ ├── test_rloo_rlhf.py │ │ ├── test_td.py │ │ ├── test_upgo.py │ │ ├── test_value_rescale.py │ │ └── test_vtrace.py │ ├── upgo.py │ ├── value_rescale.py │ └── vtrace.py ├── scripts │ ├── dijob-qbert.yaml │ ├── docker-test-entry.sh │ ├── docker-test.sh │ ├── install-k8s-tools.sh │ ├── kill.sh │ ├── local_parallel.sh │ ├── local_serial.sh │ ├── main_league.sh │ ├── main_league_slurm.sh │ └── tests │ │ ├── test_parallel_socket.py │ │ └── test_parallel_socket.sh ├── torch_utils │ ├── __init__.py │ ├── backend_helper.py │ ├── checkpoint_helper.py │ ├── data_helper.py │ ├── dataparallel.py │ ├── diffusion_SDE │ │ ├── __init__.py │ │ └── dpm_solver_pytorch.py │ ├── distribution.py │ ├── loss │ │ ├── __init__.py │ │ ├── contrastive_loss.py │ │ ├── cross_entropy_loss.py │ │ ├── multi_logits_loss.py │ │ └── tests │ │ │ ├── test_contrastive_loss.py │ │ │ ├── test_cross_entropy_loss.py │ │ │ └── test_multi_logits_loss.py │ ├── lr_scheduler.py │ ├── math_helper.py │ ├── metric.py │ ├── model_helper.py │ ├── network │ │ ├── __init__.py │ │ ├── activation.py │ │ ├── diffusion.py │ │ ├── dreamer.py │ │ ├── gtrxl.py │ │ ├── gumbel_softmax.py │ │ ├── merge.py │ │ ├── nn_module.py │ │ ├── normalization.py │ │ ├── popart.py │ │ ├── res_block.py │ │ ├── resnet.py │ │ ├── rnn.py │ │ ├── scatter_connection.py │ │ ├── soft_argmax.py │ │ ├── tests │ │ │ ├── test_activation.py │ │ │ ├── test_diffusion.py │ │ │ ├── test_dreamer.py │ │ │ ├── test_gtrxl.py │ │ │ ├── test_gumbel_softmax.py │ │ │ ├── test_merge.py │ │ │ ├── test_nn_module.py │ │ │ ├── test_normalization.py │ │ │ ├── test_popart.py │ │ │ ├── test_res_block.py │ │ │ ├── test_resnet.py │ │ │ ├── test_rnn.py │ │ │ ├── test_scatter.py │ │ │ ├── test_soft_argmax.py │ │ │ └── test_transformer.py │ │ └── transformer.py │ ├── nn_test_helper.py │ ├── optimizer_helper.py │ ├── 
parameter.py │ ├── reshape_helper.py │ └── tests │ │ ├── test_backend_helper.py │ │ ├── test_ckpt_helper.py │ │ ├── test_data_helper.py │ │ ├── test_distribution.py │ │ ├── test_feature_merge.py │ │ ├── test_lr_scheduler.py │ │ ├── test_math_helper.py │ │ ├── test_metric.py │ │ ├── test_model_helper.py │ │ ├── test_nn_test_helper.py │ │ ├── test_optimizer.py │ │ ├── test_parameter.py │ │ └── test_reshape_helper.py ├── utils │ ├── __init__.py │ ├── autolog │ │ ├── __init__.py │ │ ├── base.py │ │ ├── data.py │ │ ├── model.py │ │ ├── tests │ │ │ ├── __init__.py │ │ │ ├── test_data.py │ │ │ ├── test_model.py │ │ │ └── test_time.py │ │ ├── time_ctl.py │ │ └── value.py │ ├── bfs_helper.py │ ├── collection_helper.py │ ├── compression_helper.py │ ├── data │ │ ├── __init__.py │ │ ├── base_dataloader.py │ │ ├── collate_fn.py │ │ ├── dataloader.py │ │ ├── dataset.py │ │ ├── rlhf_offline_dataset.py │ │ ├── rlhf_online_dataset.py │ │ ├── structure │ │ │ ├── __init__.py │ │ │ ├── cache.py │ │ │ └── lifo_deque.py │ │ └── tests │ │ │ ├── dataloader_speed │ │ │ └── experiment_dataloader_speed.py │ │ │ ├── test_cache.py │ │ │ ├── test_collate_fn.py │ │ │ ├── test_dataloader.py │ │ │ ├── test_dataset.py │ │ │ ├── test_rlhf_offline_dataset.py │ │ │ └── test_rlhf_online_dataset.py │ ├── default_helper.py │ ├── deprecation.py │ ├── design_helper.py │ ├── dict_helper.py │ ├── fake_linklink.py │ ├── fast_copy.py │ ├── file_helper.py │ ├── import_helper.py │ ├── k8s_helper.py │ ├── linklink_dist_helper.py │ ├── loader │ │ ├── __init__.py │ │ ├── base.py │ │ ├── collection.py │ │ ├── dict.py │ │ ├── exception.py │ │ ├── mapping.py │ │ ├── norm.py │ │ ├── number.py │ │ ├── string.py │ │ ├── tests │ │ │ ├── __init__.py │ │ │ ├── loader │ │ │ │ ├── __init__.py │ │ │ │ ├── test_base.py │ │ │ │ ├── test_collection.py │ │ │ │ ├── test_dict.py │ │ │ │ ├── test_mapping.py │ │ │ │ ├── test_norm.py │ │ │ │ ├── test_number.py │ │ │ │ ├── test_string.py │ │ │ │ ├── test_types.py │ │ │ │ └── 
test_utils.py │ │ │ └── test_cartpole_dqn_serial_config_loader.py │ │ ├── types.py │ │ └── utils.py │ ├── lock_helper.py │ ├── log_helper.py │ ├── log_writer_helper.py │ ├── memory_helper.py │ ├── normalizer_helper.py │ ├── orchestrator_launcher.py │ ├── profiler_helper.py │ ├── pytorch_ddp_dist_helper.py │ ├── registry.py │ ├── registry_factory.py │ ├── render_helper.py │ ├── scheduler_helper.py │ ├── segment_tree.py │ ├── slurm_helper.py │ ├── system_helper.py │ ├── tests │ │ ├── config │ │ │ └── k8s-config.yaml │ │ ├── test_bfs_helper.py │ │ ├── test_collection_helper.py │ │ ├── test_compression_helper.py │ │ ├── test_config_helper.py │ │ ├── test_default_helper.py │ │ ├── test_deprecation.py │ │ ├── test_design_helper.py │ │ ├── test_file_helper.py │ │ ├── test_import_helper.py │ │ ├── test_k8s_launcher.py │ │ ├── test_lock.py │ │ ├── test_log_helper.py │ │ ├── test_log_writer_helper.py │ │ ├── test_memory_helper.py │ │ ├── test_normalizer_helper.py │ │ ├── test_profiler_helper.py │ │ ├── test_registry.py │ │ ├── test_scheduler_helper.py │ │ ├── test_segment_tree.py │ │ ├── test_system_helper.py │ │ └── test_time_helper.py │ ├── time_helper.py │ ├── time_helper_base.py │ ├── time_helper_cuda.py │ └── type_helper.py ├── worker │ ├── __init__.py │ ├── adapter │ │ ├── __init__.py │ │ ├── learner_aggregator.py │ │ └── tests │ │ │ └── test_learner_aggregator.py │ ├── collector │ │ ├── __init__.py │ │ ├── base_parallel_collector.py │ │ ├── base_serial_collector.py │ │ ├── base_serial_evaluator.py │ │ ├── battle_episode_serial_collector.py │ │ ├── battle_interaction_serial_evaluator.py │ │ ├── battle_sample_serial_collector.py │ │ ├── comm │ │ │ ├── __init__.py │ │ │ ├── base_comm_collector.py │ │ │ ├── flask_fs_collector.py │ │ │ ├── tests │ │ │ │ └── test_collector_with_coordinator.py │ │ │ └── utils.py │ │ ├── episode_serial_collector.py │ │ ├── interaction_serial_evaluator.py │ │ ├── marine_parallel_collector.py │ │ ├── metric_serial_evaluator.py │ │ ├── 
sample_serial_collector.py │ │ ├── tests │ │ │ ├── __init__.py │ │ │ ├── fake_cls_policy.py │ │ │ ├── fake_cpong_dqn_config.py │ │ │ ├── speed_test │ │ │ │ ├── __init__.py │ │ │ │ ├── fake_env.py │ │ │ │ ├── fake_policy.py │ │ │ │ ├── test_collector_profile.py │ │ │ │ └── utils.py │ │ │ ├── test_base_serial_collector.py │ │ │ ├── test_episode_serial_collector.py │ │ │ ├── test_marine_parallel_collector.py │ │ │ ├── test_metric_serial_evaluator.py │ │ │ └── test_sample_serial_collector.py │ │ └── zergling_parallel_collector.py │ ├── coordinator │ │ ├── __init__.py │ │ ├── base_parallel_commander.py │ │ ├── base_serial_commander.py │ │ ├── comm_coordinator.py │ │ ├── coordinator.py │ │ ├── one_vs_one_parallel_commander.py │ │ ├── operator_server.py │ │ ├── resource_manager.py │ │ ├── solo_parallel_commander.py │ │ └── tests │ │ │ ├── conftest.py │ │ │ ├── test_coordinator.py │ │ │ ├── test_fake_operator_server.py │ │ │ └── test_one_vs_one_commander.py │ ├── learner │ │ ├── __init__.py │ │ ├── base_learner.py │ │ ├── comm │ │ │ ├── __init__.py │ │ │ ├── base_comm_learner.py │ │ │ ├── flask_fs_learner.py │ │ │ ├── tests │ │ │ │ └── test_learner_with_coordinator.py │ │ │ └── utils.py │ │ ├── learner_hook.py │ │ └── tests │ │ │ ├── test_base_learner.py │ │ │ └── test_learner_hook.py │ └── replay_buffer │ │ ├── __init__.py │ │ ├── advanced_buffer.py │ │ ├── base_buffer.py │ │ ├── episode_buffer.py │ │ ├── naive_buffer.py │ │ ├── tests │ │ ├── conftest.py │ │ ├── test_advanced_buffer.py │ │ └── test_naive_buffer.py │ │ └── utils.py └── world_model │ ├── __init__.py │ ├── base_world_model.py │ ├── ddppo.py │ ├── dreamer.py │ ├── idm.py │ ├── mbpo.py │ ├── model │ ├── __init__.py │ ├── ensemble.py │ ├── networks.py │ └── tests │ │ ├── test_ensemble.py │ │ └── test_networks.py │ ├── tests │ ├── test_ddppo.py │ ├── test_dreamerv3.py │ ├── test_idm.py │ ├── test_mbpo.py │ ├── test_world_model.py │ └── test_world_model_utils.py │ └── utils.py ├── dizoo ├── __init__.py ├── atari 
│ ├── __init__.py │ ├── atari.gif │ ├── config │ │ ├── __init__.py │ │ └── serial │ │ │ ├── __init__.py │ │ │ ├── asterix │ │ │ ├── __init__.py │ │ │ └── asterix_mdqn_config.py │ │ │ ├── demon_attack │ │ │ └── demon_attack_dqn_config.py │ │ │ ├── enduro │ │ │ ├── __init__.py │ │ │ ├── enduro_dqn_config.py │ │ │ ├── enduro_impala_config.py │ │ │ ├── enduro_mdqn_config.py │ │ │ ├── enduro_onppo_config.py │ │ │ ├── enduro_qrdqn_config.py │ │ │ └── enduro_rainbow_config.py │ │ │ ├── montezuma │ │ │ └── montezuma_ngu_config.py │ │ │ ├── phoenix │ │ │ ├── phoenix_fqf_config.py │ │ │ └── phoenix_iqn_config.py │ │ │ ├── pitfall │ │ │ └── pitfall_ngu_config.py │ │ │ ├── pong │ │ │ ├── __init__.py │ │ │ ├── pong_a2c_config.py │ │ │ ├── pong_acer_config.py │ │ │ ├── pong_c51_config.py │ │ │ ├── pong_cql_config.py │ │ │ ├── pong_dqfd_config.py │ │ │ ├── pong_dqn_config.py │ │ │ ├── pong_dqn_ddp_config.py │ │ │ ├── pong_dqn_envpool_config.py │ │ │ ├── pong_dqn_multi_gpu_config.py │ │ │ ├── pong_dqn_render_config.py │ │ │ ├── pong_dqn_stdim_config.py │ │ │ ├── pong_dt_config.py │ │ │ ├── pong_fqf_config.py │ │ │ ├── pong_gail_dqn_config.py │ │ │ ├── pong_impala_config.py │ │ │ ├── pong_iqn_config.py │ │ │ ├── pong_ngu_config.py │ │ │ ├── pong_ppg_config.py │ │ │ ├── pong_ppo_config.py │ │ │ ├── pong_ppo_ddp_config.py │ │ │ ├── pong_qrdqn_config.py │ │ │ ├── pong_qrdqn_generation_data_config.py │ │ │ ├── pong_r2d2_config.py │ │ │ ├── pong_r2d2_gtrxl_config.py │ │ │ ├── pong_r2d2_residual_config.py │ │ │ ├── pong_r2d3_offppoexpert_config.py │ │ │ ├── pong_r2d3_r2d2expert_config.py │ │ │ ├── pong_rainbow_config.py │ │ │ ├── pong_sqil_config.py │ │ │ ├── pong_sql_config.py │ │ │ ├── pong_trex_offppo_config.py │ │ │ └── pong_trex_sql_config.py │ │ │ ├── qbert │ │ │ ├── __init__.py │ │ │ ├── qbert_a2c_config.py │ │ │ ├── qbert_acer_config.py │ │ │ ├── qbert_c51_config.py │ │ │ ├── qbert_cql_config.py │ │ │ ├── qbert_dqfd_config.py │ │ │ ├── qbert_dqn_config.py │ │ │ ├── 
qbert_fqf_config.py │ │ │ ├── qbert_impala_config.py │ │ │ ├── qbert_iqn_config.py │ │ │ ├── qbert_ngu_config.py │ │ │ ├── qbert_offppo_config.py │ │ │ ├── qbert_onppo_config.py │ │ │ ├── qbert_ppg_config.py │ │ │ ├── qbert_qrdqn_config.py │ │ │ ├── qbert_qrdqn_generation_data_config.py │ │ │ ├── qbert_r2d2_config.py │ │ │ ├── qbert_r2d2_gtrxl_config.py │ │ │ ├── qbert_rainbow_config.py │ │ │ ├── qbert_sqil_config.py │ │ │ ├── qbert_sql_config.py │ │ │ ├── qbert_trex_dqn_config.py │ │ │ └── qbert_trex_offppo_config.py │ │ │ └── spaceinvaders │ │ │ ├── __init__.py │ │ │ ├── spaceinvaders_a2c_config.py │ │ │ ├── spaceinvaders_acer_config.py │ │ │ ├── spaceinvaders_c51_config.py │ │ │ ├── spaceinvaders_dqfd_config.py │ │ │ ├── spaceinvaders_dqn_config.py │ │ │ ├── spaceinvaders_dqn_config_multi_gpu_ddp.py │ │ │ ├── spaceinvaders_dqn_config_multi_gpu_dp.py │ │ │ ├── spaceinvaders_fqf_config.py │ │ │ ├── spaceinvaders_impala_config.py │ │ │ ├── spaceinvaders_iqn_config.py │ │ │ ├── spaceinvaders_mdqn_config.py │ │ │ ├── spaceinvaders_ngu_config.py │ │ │ ├── spaceinvaders_offppo_config.py │ │ │ ├── spaceinvaders_onppo_config.py │ │ │ ├── spaceinvaders_ppg_config.py │ │ │ ├── spaceinvaders_qrdqn_config.py │ │ │ ├── spaceinvaders_r2d2_config.py │ │ │ ├── spaceinvaders_r2d2_gtrxl_config.py │ │ │ ├── spaceinvaders_r2d2_residual_config.py │ │ │ ├── spaceinvaders_rainbow_config.py │ │ │ ├── spaceinvaders_sqil_config.py │ │ │ ├── spaceinvaders_sql_config.py │ │ │ ├── spaceinvaders_trex_dqn_config.py │ │ │ └── spaceinvaders_trex_offppo_config.py │ ├── entry │ │ ├── __init__.py │ │ ├── atari_dqn_main.py │ │ ├── atari_dt_main.py │ │ ├── atari_impala_main.py │ │ ├── atari_ppg_main.py │ │ ├── phoenix_fqf_main.py │ │ ├── phoenix_iqn_main.py │ │ ├── pong_cql_main.py │ │ ├── pong_dqn_envpool_main.py │ │ ├── pong_fqf_main.py │ │ ├── qbert_cql_main.py │ │ ├── qbert_fqf_main.py │ │ ├── spaceinvaders_dqn_eval.py │ │ ├── spaceinvaders_dqn_main_multi_gpu_ddp.py │ │ ├── 
spaceinvaders_dqn_main_multi_gpu_dp.py │ │ └── spaceinvaders_fqf_main.py │ ├── envs │ │ ├── __init__.py │ │ ├── atari_env.py │ │ ├── atari_wrappers.py │ │ └── test_atari_env.py │ └── example │ │ ├── atari_dqn.py │ │ ├── atari_dqn_ddp.py │ │ ├── atari_dqn_dist.py │ │ ├── atari_dqn_dist_ddp.py │ │ ├── atari_dqn_dist_rdma.py │ │ ├── atari_dqn_dp.py │ │ ├── atari_ppo.py │ │ └── atari_ppo_ddp.py ├── beergame │ ├── __init__.py │ ├── beergame.png │ ├── config │ │ └── beergame_onppo_config.py │ ├── entry │ │ └── beergame_eval.py │ └── envs │ │ ├── BGAgent.py │ │ ├── __init__.py │ │ ├── beergame_core.py │ │ ├── beergame_env.py │ │ ├── clBeergame.py │ │ ├── plotting.py │ │ └── utils.py ├── bitflip │ ├── README.md │ ├── __init__.py │ ├── bitflip.gif │ ├── config │ │ ├── __init__.py │ │ ├── bitflip_her_dqn_config.py │ │ └── bitflip_pure_dqn_config.py │ ├── entry │ │ ├── __init__.py │ │ └── bitflip_dqn_main.py │ └── envs │ │ ├── __init__.py │ │ ├── bitflip_env.py │ │ └── test_bitfilp_env.py ├── box2d │ ├── __init__.py │ ├── bipedalwalker │ │ ├── __init__.py │ │ ├── config │ │ │ ├── __init__.py │ │ │ ├── bipedalwalker_a2c_config.py │ │ │ ├── bipedalwalker_bco_config.py │ │ │ ├── bipedalwalker_ddpg_config.py │ │ │ ├── bipedalwalker_dt_config.py │ │ │ ├── bipedalwalker_gail_sac_config.py │ │ │ ├── bipedalwalker_impala_config.py │ │ │ ├── bipedalwalker_pg_config.py │ │ │ ├── bipedalwalker_ppo_config.py │ │ │ ├── bipedalwalker_ppopg_config.py │ │ │ ├── bipedalwalker_sac_config.py │ │ │ └── bipedalwalker_td3_config.py │ │ ├── entry │ │ │ ├── __init__.py │ │ │ └── bipedalwalker_ppo_eval.py │ │ ├── envs │ │ │ ├── __init__.py │ │ │ ├── bipedalwalker_env.py │ │ │ └── test_bipedalwalker.py │ │ └── original.gif │ ├── carracing │ │ ├── __init__.py │ │ ├── car_racing.gif │ │ ├── config │ │ │ ├── __init__.py │ │ │ └── carracing_dqn_config.py │ │ └── envs │ │ │ ├── __init__.py │ │ │ ├── carracing_env.py │ │ │ └── test_carracing_env.py │ └── lunarlander │ │ ├── __init__.py │ │ ├── config │ │ 
├── __init__.py │ │ ├── lunarlander_a2c_config.py │ │ ├── lunarlander_acer_config.py │ │ ├── lunarlander_bco_config.py │ │ ├── lunarlander_c51_config.py │ │ ├── lunarlander_cont_ddpg_config.py │ │ ├── lunarlander_cont_sac_config.py │ │ ├── lunarlander_cont_td3_config.py │ │ ├── lunarlander_cont_td3_vae_config.py │ │ ├── lunarlander_discrete_sac_config.py │ │ ├── lunarlander_dqfd_config.py │ │ ├── lunarlander_dqn_config.py │ │ ├── lunarlander_dqn_deque_config.py │ │ ├── lunarlander_dt_config.py │ │ ├── lunarlander_gail_dqn_config.py │ │ ├── lunarlander_gcl_config.py │ │ ├── lunarlander_hpt_config.py │ │ ├── lunarlander_impala_config.py │ │ ├── lunarlander_ngu_config.py │ │ ├── lunarlander_offppo_config.py │ │ ├── lunarlander_pg_config.py │ │ ├── lunarlander_ppo_config.py │ │ ├── lunarlander_ppo_continuous_config.py │ │ ├── lunarlander_qrdqn_config.py │ │ ├── lunarlander_r2d2_config.py │ │ ├── lunarlander_r2d2_gtrxl_config.py │ │ ├── lunarlander_r2d3_ppoexpert_config.py │ │ ├── lunarlander_r2d3_r2d2expert_config.py │ │ ├── lunarlander_rnd_onppo_config.py │ │ ├── lunarlander_sqil_config.py │ │ ├── lunarlander_sql_config.py │ │ ├── lunarlander_trex_dqn_config.py │ │ └── lunarlander_trex_offppo_config.py │ │ ├── entry │ │ ├── __init__.py │ │ ├── lunarlander_dqn_eval.py │ │ ├── lunarlander_dqn_example.py │ │ └── lunarlander_hpt_example.py │ │ ├── envs │ │ ├── __init__.py │ │ ├── lunarlander_env.py │ │ └── test_lunarlander_env.py │ │ ├── lunarlander.gif │ │ └── offline_data │ │ ├── collect_dqn_data_config.py │ │ ├── lunarlander_collect_data.py │ │ └── lunarlander_show_data.py ├── bsuite │ ├── __init__.py │ ├── bsuite.png │ ├── config │ │ ├── __init__.py │ │ └── serial │ │ │ ├── bandit_noise │ │ │ └── bandit_noise_0_dqn_config.py │ │ │ ├── cartpole_swingup │ │ │ └── cartpole_swingup_0_dqn_config.py │ │ │ └── memory_len │ │ │ ├── memory_len_0_a2c_config.py │ │ │ ├── memory_len_0_dqn_config.py │ │ │ ├── memory_len_15_r2d2_config.py │ │ │ └── 
memory_len_15_r2d2_gtrxl_config.py │ └── envs │ │ ├── __init__.py │ │ ├── bsuite_env.py │ │ └── test_bsuite_env.py ├── classic_control │ ├── __init__.py │ ├── acrobot │ │ ├── __init__.py │ │ ├── acrobot.gif │ │ ├── config │ │ │ ├── __init__.py │ │ │ └── acrobot_dqn_config.py │ │ └── envs │ │ │ ├── __init__.py │ │ │ ├── acrobot_env.py │ │ │ └── test_acrobot_env.py │ ├── cartpole │ │ ├── __init__.py │ │ ├── cartpole.gif │ │ ├── config │ │ │ ├── __init__.py │ │ │ ├── cartpole_a2c_config.py │ │ │ ├── cartpole_acer_config.py │ │ │ ├── cartpole_bc_config.py │ │ │ ├── cartpole_bco_config.py │ │ │ ├── cartpole_c51_config.py │ │ │ ├── cartpole_cql_config.py │ │ │ ├── cartpole_decision_transformer.py │ │ │ ├── cartpole_dqfd_config.py │ │ │ ├── cartpole_dqn_config.py │ │ │ ├── cartpole_dqn_ddp_config.py │ │ │ ├── cartpole_dqn_gail_config.py │ │ │ ├── cartpole_dqn_rnd_config.py │ │ │ ├── cartpole_dqn_stdim_config.py │ │ │ ├── cartpole_drex_dqn_config.py │ │ │ ├── cartpole_dt_config.py │ │ │ ├── cartpole_fqf_config.py │ │ │ ├── cartpole_gcl_config.py │ │ │ ├── cartpole_impala_config.py │ │ │ ├── cartpole_iqn_config.py │ │ │ ├── cartpole_mdqn_config.py │ │ │ ├── cartpole_ngu_config.py │ │ │ ├── cartpole_pg_config.py │ │ │ ├── cartpole_ppg_config.py │ │ │ ├── cartpole_ppo_config.py │ │ │ ├── cartpole_ppo_ddp_config.py │ │ │ ├── cartpole_ppo_icm_config.py │ │ │ ├── cartpole_ppo_offpolicy_config.py │ │ │ ├── cartpole_ppo_stdim_config.py │ │ │ ├── cartpole_ppopg_config.py │ │ │ ├── cartpole_qrdqn_config.py │ │ │ ├── cartpole_qrdqn_generation_data_config.py │ │ │ ├── cartpole_r2d2_config.py │ │ │ ├── cartpole_r2d2_gtrxl_config.py │ │ │ ├── cartpole_r2d2_residual_config.py │ │ │ ├── cartpole_rainbow_config.py │ │ │ ├── cartpole_rnd_onppo_config.py │ │ │ ├── cartpole_sac_config.py │ │ │ ├── cartpole_sqil_config.py │ │ │ ├── cartpole_sql_config.py │ │ │ ├── cartpole_sqn_config.py │ │ │ ├── cartpole_trex_dqn_config.py │ │ │ ├── cartpole_trex_offppo_config.py │ │ │ ├── 
cartpole_trex_onppo_config.py │ │ │ └── parallel │ │ │ │ ├── __init__.py │ │ │ │ ├── cartpole_dqn_config.py │ │ │ │ ├── cartpole_dqn_config_k8s.py │ │ │ │ └── cartpole_dqn_dist.sh │ │ ├── entry │ │ │ ├── __init__.py │ │ │ ├── cartpole_c51_deploy.py │ │ │ ├── cartpole_c51_main.py │ │ │ ├── cartpole_cql_main.py │ │ │ ├── cartpole_dqn_buffer_main.py │ │ │ ├── cartpole_dqn_eval.py │ │ │ ├── cartpole_dqn_main.py │ │ │ ├── cartpole_dqn_pwil_main.py │ │ │ ├── cartpole_fqf_main.py │ │ │ ├── cartpole_ppg_main.py │ │ │ ├── cartpole_ppo_main.py │ │ │ └── cartpole_ppo_offpolicy_main.py │ │ └── envs │ │ │ ├── __init__.py │ │ │ ├── cartpole_env.py │ │ │ ├── test_cartpole_env.py │ │ │ └── test_cartpole_env_manager.py │ ├── mountain_car │ │ ├── __init__.py │ │ ├── config │ │ │ └── mtcar_rainbow_config.py │ │ └── envs │ │ │ ├── __init__.py │ │ │ ├── mtcar_env.py │ │ │ └── test_mtcar_env.py │ └── pendulum │ │ ├── __init__.py │ │ ├── config │ │ ├── __init__.py │ │ ├── mbrl │ │ │ ├── pendulum_mbsac_ddppo_config.py │ │ │ ├── pendulum_mbsac_mbpo_config.py │ │ │ ├── pendulum_sac_ddppo_config.py │ │ │ ├── pendulum_sac_mbpo_config.py │ │ │ └── pendulum_stevesac_mbpo_config.py │ │ ├── pendulum_a2c_config.py │ │ ├── pendulum_bdq_config.py │ │ ├── pendulum_cql_config.py │ │ ├── pendulum_d4pg_config.py │ │ ├── pendulum_ddpg_config.py │ │ ├── pendulum_dqn_config.py │ │ ├── pendulum_ibc_config.py │ │ ├── pendulum_pg_config.py │ │ ├── pendulum_ppo_config.py │ │ ├── pendulum_sac_config.py │ │ ├── pendulum_sac_data_generation_config.py │ │ ├── pendulum_sqil_sac_config.py │ │ ├── pendulum_td3_bc_config.py │ │ ├── pendulum_td3_config.py │ │ └── pendulum_td3_data_generation_config.py │ │ ├── entry │ │ ├── __init__.py │ │ ├── pendulum_cql_ddpg_main.py │ │ ├── pendulum_cql_main.py │ │ ├── pendulum_d4pg_main.py │ │ ├── pendulum_ddpg_main.py │ │ ├── pendulum_dqn_eval.py │ │ ├── pendulum_ppo_main.py │ │ ├── pendulum_td3_bc_main.py │ │ └── pendulum_td3_main.py │ │ ├── envs │ │ ├── __init__.py │ │ ├── 
pendulum_env.py │ │ └── test_pendulum_env.py │ │ └── pendulum.gif ├── cliffwalking │ ├── __init__.py │ ├── cliff_walking.gif │ ├── config │ │ └── cliffwalking_dqn_config.py │ ├── entry │ │ ├── cliffwalking_dqn_deploy.py │ │ └── cliffwalking_dqn_main.py │ └── envs │ │ ├── __init__.py │ │ ├── cliffwalking_env.py │ │ └── test_cliffwalking_env.py ├── common │ ├── __init__.py │ └── policy │ │ ├── __init__.py │ │ ├── md_dqn.py │ │ ├── md_ppo.py │ │ └── md_rainbow_dqn.py ├── competitive_rl │ ├── README.md │ ├── __init__.py │ ├── competitive_rl.gif │ ├── config │ │ └── cpong_dqn_config.py │ └── envs │ │ ├── __init__.py │ │ ├── competitive_rl_env.py │ │ ├── competitive_rl_env_wrapper.py │ │ ├── resources │ │ └── pong │ │ │ ├── checkpoint-alphapong.pkl │ │ │ ├── checkpoint-medium.pkl │ │ │ ├── checkpoint-strong.pkl │ │ │ └── checkpoint-weak.pkl │ │ └── test_competitive_rl.py ├── d4rl │ ├── __init__.py │ ├── config │ │ ├── __init__.py │ │ ├── antmaze_umaze_pd_config.py │ │ ├── halfcheetah_expert_cql_config.py │ │ ├── halfcheetah_expert_dt_config.py │ │ ├── halfcheetah_expert_td3bc_config.py │ │ ├── halfcheetah_medium_bcq_config.py │ │ ├── halfcheetah_medium_cql_config.py │ │ ├── halfcheetah_medium_dt_config.py │ │ ├── halfcheetah_medium_edac_config.py │ │ ├── halfcheetah_medium_expert_bcq_config.py │ │ ├── halfcheetah_medium_expert_cql_config.py │ │ ├── halfcheetah_medium_expert_dt_config.py │ │ ├── halfcheetah_medium_expert_edac_config.py │ │ ├── halfcheetah_medium_expert_iql_config.py │ │ ├── halfcheetah_medium_expert_pd_config.py │ │ ├── halfcheetah_medium_expert_qgpo_config.py │ │ ├── halfcheetah_medium_expert_td3bc_config.py │ │ ├── halfcheetah_medium_iql_config.py │ │ ├── halfcheetah_medium_pd_config.py │ │ ├── halfcheetah_medium_replay_cql_config.py │ │ ├── halfcheetah_medium_replay_dt_config.py │ │ ├── halfcheetah_medium_replay_iql_config.py │ │ ├── halfcheetah_medium_replay_td3bc_config.py │ │ ├── halfcheetah_medium_td3bc_config.py │ │ ├── 
halfcheetah_random_cql_config.py │ │ ├── halfcheetah_random_dt_config.py │ │ ├── halfcheetah_random_td3bc_config.py │ │ ├── hopper_expert_cql_config.py │ │ ├── hopper_expert_dt_config.py │ │ ├── hopper_expert_td3bc_config.py │ │ ├── hopper_medium_bcq_config.py │ │ ├── hopper_medium_cql_config.py │ │ ├── hopper_medium_dt_config.py │ │ ├── hopper_medium_edac_config.py │ │ ├── hopper_medium_expert_bc_config.py │ │ ├── hopper_medium_expert_bcq_config.py │ │ ├── hopper_medium_expert_cql_config.py │ │ ├── hopper_medium_expert_dt_config.py │ │ ├── hopper_medium_expert_edac_config.py │ │ ├── hopper_medium_expert_ibc_ar_config.py │ │ ├── hopper_medium_expert_ibc_config.py │ │ ├── hopper_medium_expert_ibc_mcmc_config.py │ │ ├── hopper_medium_expert_iql_config.py │ │ ├── hopper_medium_expert_pd_config.py │ │ ├── hopper_medium_expert_qgpo_config.py │ │ ├── hopper_medium_expert_td3bc_config.py │ │ ├── hopper_medium_iql_config.py │ │ ├── hopper_medium_pd_config.py │ │ ├── hopper_medium_replay_cql_config.py │ │ ├── hopper_medium_replay_dt_config.py │ │ ├── hopper_medium_replay_iql_config.py │ │ ├── hopper_medium_replay_td3bc_config.py │ │ ├── hopper_medium_td3bc_config.py │ │ ├── hopper_random_cql_config.py │ │ ├── hopper_random_dt_config.py │ │ ├── hopper_random_td3bc_config.py │ │ ├── kitchen_complete_bc_config.py │ │ ├── kitchen_complete_ibc_ar_config.py │ │ ├── kitchen_complete_ibc_config.py │ │ ├── kitchen_complete_ibc_mcmc_config.py │ │ ├── maze2d_large_pd_config.py │ │ ├── maze2d_medium_pd_config.py │ │ ├── maze2d_umaze_pd_config.py │ │ ├── pen_human_bc_config.py │ │ ├── pen_human_ibc_ar_config.py │ │ ├── pen_human_ibc_config.py │ │ ├── pen_human_ibc_mcmc_config.py │ │ ├── walker2d_expert_cql_config.py │ │ ├── walker2d_expert_dt_config.py │ │ ├── walker2d_expert_td3bc_config.py │ │ ├── walker2d_medium_cql_config.py │ │ ├── walker2d_medium_dt_config.py │ │ ├── walker2d_medium_expert_cql_config.py │ │ ├── walker2d_medium_expert_dt_config.py │ │ ├── 
walker2d_medium_expert_iql_config.py │ │ ├── walker2d_medium_expert_pd_config.py │ │ ├── walker2d_medium_expert_qgpo_config.py │ │ ├── walker2d_medium_expert_td3bc_config.py │ │ ├── walker2d_medium_iql_config.py │ │ ├── walker2d_medium_pd_config.py │ │ ├── walker2d_medium_replay_cql_config.py │ │ ├── walker2d_medium_replay_dt_config.py │ │ ├── walker2d_medium_replay_iql_config.py │ │ ├── walker2d_medium_replay_td3bc_config.py │ │ ├── walker2d_medium_td3bc_config.py │ │ ├── walker2d_random_cql_config.py │ │ ├── walker2d_random_dt_config.py │ │ └── walker2d_random_td3bc_config.py │ ├── d4rl.gif │ ├── entry │ │ ├── __init__.py │ │ ├── d4rl_bcq_main.py │ │ ├── d4rl_cql_main.py │ │ ├── d4rl_dt_mujoco.py │ │ ├── d4rl_edac_main.py │ │ ├── d4rl_ibc_main.py │ │ ├── d4rl_iql_main.py │ │ ├── d4rl_pd_main.py │ │ └── d4rl_td3_bc_main.py │ └── envs │ │ ├── __init__.py │ │ ├── d4rl_env.py │ │ └── d4rl_wrappers.py ├── dmc2gym │ ├── __init__.py │ ├── config │ │ ├── cartpole_balance │ │ │ └── cartpole_balance_dreamer_config.py │ │ ├── cheetah_run │ │ │ └── cheetah_run_dreamer_config.py │ │ ├── dmc2gym_dreamer_config.py │ │ ├── dmc2gym_ppo_config.py │ │ ├── dmc2gym_sac_pixel_config.py │ │ ├── dmc2gym_sac_state_config.py │ │ └── walker_walk │ │ │ └── walker_walk_dreamer_config.py │ ├── dmc2gym_cheetah.png │ ├── entry │ │ ├── dmc2gym_onppo_main.py │ │ ├── dmc2gym_sac_pixel_main.py │ │ ├── dmc2gym_sac_state_main.py │ │ └── dmc2gym_save_replay_example.py │ └── envs │ │ ├── __init__.py │ │ ├── dmc2gym_env.py │ │ └── test_dmc2gym_env.py ├── evogym │ ├── __init__.py │ ├── config │ │ ├── bridgewalker_ddpg_config.py │ │ ├── carrier_ppo_config.py │ │ ├── walker_ddpg_config.py │ │ └── walker_ppo_config.py │ ├── entry │ │ └── walker_ppo_eval.py │ ├── envs │ │ ├── __init__.py │ │ ├── evogym_env.py │ │ ├── test │ │ │ ├── test_evogym_env.py │ │ │ └── visualize_simple_env.py │ │ └── world_data │ │ │ ├── carry_bot.json │ │ │ ├── simple_evironment.json │ │ │ └── speed_bot.json │ └── evogym.gif ├── 
frozen_lake │ ├── FrozenLake.gif │ ├── __init__.py │ ├── config │ │ ├── __init__.py │ │ └── frozen_lake_dqn_config.py │ └── envs │ │ ├── __init__.py │ │ ├── frozen_lake_env.py │ │ └── test_frozen_lake_env.py ├── gfootball │ ├── README.md │ ├── __init__.py │ ├── config │ │ ├── gfootball_counter_mappo_config.py │ │ ├── gfootball_counter_masac_config.py │ │ ├── gfootball_keeper_mappo_config.py │ │ └── gfootball_keeper_masac_config.py │ ├── entry │ │ ├── __init__.py │ │ ├── gfootball_bc_config.py │ │ ├── gfootball_bc_kaggle5th_main.py │ │ ├── gfootball_bc_rule_lt0_main.py │ │ ├── gfootball_bc_rule_main.py │ │ ├── gfootball_dqn_config.py │ │ ├── parallel │ │ │ ├── gfootball_il_parallel_config.py │ │ │ └── gfootball_ppo_parallel_config.py │ │ ├── show_dataset.py │ │ └── test_accuracy.py │ ├── envs │ │ ├── __init__.py │ │ ├── action │ │ │ ├── gfootball_action.py │ │ │ └── gfootball_action_runner.py │ │ ├── fake_dataset.py │ │ ├── gfootball_academy_env.py │ │ ├── gfootball_env.py │ │ ├── gfootballsp_env.py │ │ ├── obs │ │ │ ├── encoder.py │ │ │ ├── gfootball_obs.py │ │ │ └── gfootball_obs_runner.py │ │ ├── reward │ │ │ ├── gfootball_reward.py │ │ │ └── gfootball_reward_runner.py │ │ └── tests │ │ │ ├── test_env_gfootball.py │ │ │ └── test_env_gfootball_academy.py │ ├── gfootball.gif │ ├── model │ │ ├── __init__.py │ │ ├── bots │ │ │ ├── TamakEriFever │ │ │ │ ├── config.yaml │ │ │ │ ├── football │ │ │ │ │ └── util.py │ │ │ │ ├── football_ikki.py │ │ │ │ ├── handyrl_core │ │ │ │ │ ├── model.py │ │ │ │ │ └── util.py │ │ │ │ ├── readme.md │ │ │ │ ├── submission.py │ │ │ │ └── view_test.py │ │ │ ├── __init__.py │ │ │ ├── kaggle_5th_place_model.py │ │ │ └── rule_based_bot_model.py │ │ ├── conv1d │ │ │ ├── conv1d.py │ │ │ └── conv1d_default_config.py │ │ └── q_network │ │ │ ├── football_q_network.py │ │ │ ├── football_q_network_default_config.py │ │ │ └── tests │ │ │ └── test_football_model.py │ ├── policy │ │ ├── __init__.py │ │ └── ppo_lstm.py │ └── replay.py ├── 
gobigger_overview.gif ├── gym_anytrading │ ├── __init__.py │ ├── config │ │ ├── __init__.py │ │ └── stocks_dqn_config.py │ ├── envs │ │ ├── README.md │ │ ├── __init__.py │ │ ├── data │ │ │ └── README.md │ │ ├── position.png │ │ ├── profit.png │ │ ├── statemachine.png │ │ ├── stocks_env.py │ │ ├── test_stocks_env.py │ │ └── trading_env.py │ └── worker │ │ ├── __init__.py │ │ └── trading_serial_evaluator.py ├── gym_hybrid │ ├── __init__.py │ ├── config │ │ ├── __init__.py │ │ ├── gym_hybrid_ddpg_config.py │ │ ├── gym_hybrid_hppo_config.py │ │ ├── gym_hybrid_mpdqn_config.py │ │ └── gym_hybrid_pdqn_config.py │ ├── entry │ │ ├── __init__.py │ │ ├── gym_hybrid_ddpg_eval.py │ │ └── gym_hybrid_ddpg_main.py │ ├── envs │ │ ├── README.md │ │ ├── __init__.py │ │ ├── gym-hybrid │ │ │ ├── README.md │ │ │ ├── gym_hybrid │ │ │ │ ├── __init__.py │ │ │ │ ├── agents.py │ │ │ │ ├── bg.jpg │ │ │ │ ├── environments.py │ │ │ │ └── target.png │ │ │ ├── setup.py │ │ │ └── tests │ │ │ │ ├── hardmove.py │ │ │ │ ├── moving.py │ │ │ │ ├── record.py │ │ │ │ ├── render.py │ │ │ │ └── sliding.py │ │ ├── gym_hybrid_env.py │ │ └── test_gym_hybrid_env.py │ └── moving_v0.gif ├── gym_pybullet_drones │ ├── __init__.py │ ├── config │ │ ├── flythrugate_onppo_config.py │ │ └── takeoffaviary_onppo_config.py │ ├── entry │ │ ├── flythrugate_onppo_eval.py │ │ └── takeoffaviary_onppo_eval.py │ ├── envs │ │ ├── __init__.py │ │ ├── gym_pybullet_drones_env.py │ │ ├── test_ding_env.py │ │ └── test_ori_env.py │ └── gym_pybullet_drones.gif ├── gym_soccer │ ├── __init__.py │ ├── config │ │ └── gym_soccer_pdqn_config.py │ ├── envs │ │ ├── README.md │ │ ├── __init__.py │ │ ├── gym_soccer_env.py │ │ └── test_gym_soccer_env.py │ └── half_offensive.gif ├── image_classification │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── dataset.py │ │ └── sampler.py │ ├── entry │ │ ├── imagenet_res18_config.py │ │ └── imagenet_res18_main.py │ ├── imagenet.png │ └── policy │ │ ├── __init__.py │ │ └── policy.py ├── ising_env │ 
├── __init__.py │ ├── config │ │ └── ising_mfq_config.py │ ├── entry │ │ └── ising_mfq_eval.py │ ├── envs │ │ ├── __init__.py │ │ ├── ising_model │ │ │ ├── Ising.py │ │ │ ├── __init__.py │ │ │ └── multiagent │ │ │ │ ├── __init__.py │ │ │ │ ├── core.py │ │ │ │ └── environment.py │ │ ├── ising_model_env.py │ │ └── test_ising_model_env.py │ └── ising_env.gif ├── league_demo │ ├── __init__.py │ ├── demo_league.py │ ├── game_env.py │ ├── league_demo.png │ ├── league_demo_collector.py │ ├── league_demo_ppo_config.py │ ├── league_demo_ppo_main.py │ ├── selfplay_demo_ppo_config.py │ └── selfplay_demo_ppo_main.py ├── mario │ ├── __init__.py │ ├── mario.gif │ ├── mario_dqn_config.py │ ├── mario_dqn_example.py │ └── mario_dqn_main.py ├── maze │ ├── __init__.py │ ├── config │ │ ├── maze_bc_config.py │ │ └── maze_pc_config.py │ ├── entry │ │ └── maze_bc_main.py │ └── envs │ │ ├── __init__.py │ │ ├── maze_env.py │ │ └── test_maze_env.py ├── metadrive │ ├── __init__.py │ ├── config │ │ ├── __init__.py │ │ ├── metadrive_onppo_config.py │ │ └── metadrive_onppo_eval_config.py │ ├── env │ │ ├── __init__.py │ │ ├── drive_env.py │ │ ├── drive_utils.py │ │ └── drive_wrapper.py │ └── metadrive_env.gif ├── minigrid │ ├── __init__.py │ ├── config │ │ ├── minigrid_dreamer_config.py │ │ ├── minigrid_icm_offppo_config.py │ │ ├── minigrid_icm_onppo_config.py │ │ ├── minigrid_ngu_config.py │ │ ├── minigrid_offppo_config.py │ │ ├── minigrid_onppo_config.py │ │ ├── minigrid_onppo_stdim_config.py │ │ ├── minigrid_r2d2_config.py │ │ └── minigrid_rnd_onppo_config.py │ ├── entry │ │ └── minigrid_onppo_main.py │ ├── envs │ │ ├── __init__.py │ │ ├── app_key_to_door_treasure.py │ │ ├── minigrid_env.py │ │ ├── minigrid_wrapper.py │ │ ├── noisy_tv.py │ │ └── test_minigrid_env.py │ ├── minigrid.gif │ └── utils │ │ └── eval.py ├── mujoco │ ├── __init__.py │ ├── addition │ │ └── install_mesa.sh │ ├── config │ │ ├── __init__.py │ │ ├── ant_ddpg_config.py │ │ ├── ant_gail_sac_config.py │ │ ├── 
ant_onppo_config.py │ │ ├── ant_ppo_config.py │ │ ├── ant_sac_config.py │ │ ├── ant_td3_config.py │ │ ├── ant_trex_onppo_config.py │ │ ├── ant_trex_sac_config.py │ │ ├── halfcheetah_bco_config.py │ │ ├── halfcheetah_bdq_config.py │ │ ├── halfcheetah_d4pg_config.py │ │ ├── halfcheetah_ddpg_config.py │ │ ├── halfcheetah_gail_sac_config.py │ │ ├── halfcheetah_gcl_sac_config.py │ │ ├── halfcheetah_onppo_config.py │ │ ├── halfcheetah_sac_config.py │ │ ├── halfcheetah_sqil_sac_config.py │ │ ├── halfcheetah_td3_config.py │ │ ├── halfcheetah_trex_onppo_config.py │ │ ├── halfcheetah_trex_sac_config.py │ │ ├── hopper_bco_config.py │ │ ├── hopper_bdq_config.py │ │ ├── hopper_cql_config.py │ │ ├── hopper_d4pg_config.py │ │ ├── hopper_ddpg_config.py │ │ ├── hopper_gail_sac_config.py │ │ ├── hopper_gcl_config.py │ │ ├── hopper_onppo_config.py │ │ ├── hopper_sac_config.py │ │ ├── hopper_sac_data_generation_config.py │ │ ├── hopper_sqil_sac_config.py │ │ ├── hopper_td3_bc_config.py │ │ ├── hopper_td3_config.py │ │ ├── hopper_td3_data_generation_config.py │ │ ├── hopper_trex_onppo_config.py │ │ ├── hopper_trex_sac_config.py │ │ ├── mbrl │ │ │ ├── halfcheetah_mbsac_mbpo_config.py │ │ │ ├── halfcheetah_sac_mbpo_config.py │ │ │ ├── halfcheetah_stevesac_mbpo_config.py │ │ │ ├── hopper_mbsac_mbpo_config.py │ │ │ ├── hopper_sac_mbpo_config.py │ │ │ ├── hopper_stevesac_mbpo_config.py │ │ │ ├── walker2d_mbsac_mbpo_config.py │ │ │ ├── walker2d_sac_mbpo_config.py │ │ │ └── walker2d_stevesac_mbpo_config.py │ │ ├── walker2d_d4pg_config.py │ │ ├── walker2d_ddpg_config.py │ │ ├── walker2d_gail_ddpg_config.py │ │ ├── walker2d_gail_sac_config.py │ │ ├── walker2d_gcl_config.py │ │ ├── walker2d_onppo_config.py │ │ ├── walker2d_sac_config.py │ │ ├── walker2d_sqil_sac_config.py │ │ ├── walker2d_td3_config.py │ │ ├── walker2d_trex_onppo_config.py │ │ └── walker2d_trex_sac_config.py │ ├── entry │ │ ├── __init__.py │ │ ├── mujoco_cql_generation_main.py │ │ ├── mujoco_cql_main.py │ │ ├── 
mujoco_d4pg_main.py │ │ ├── mujoco_ddpg_eval.py │ │ ├── mujoco_ddpg_main.py │ │ ├── mujoco_ppo_main.py │ │ └── mujoco_td3_bc_main.py │ ├── envs │ │ ├── __init__.py │ │ ├── mujoco_disc_env.py │ │ ├── mujoco_env.py │ │ ├── mujoco_gym_env.py │ │ ├── mujoco_wrappers.py │ │ └── test │ │ │ ├── test_mujoco_disc_env.py │ │ │ ├── test_mujoco_env.py │ │ │ └── test_mujoco_gym_env.py │ ├── example │ │ ├── mujoco_bc_main.py │ │ └── mujoco_sac.py │ └── mujoco.gif ├── multiagent_mujoco │ ├── README.md │ ├── __init__.py │ ├── config │ │ ├── ant_maddpg_config.py │ │ ├── ant_mappo_config.py │ │ ├── ant_masac_config.py │ │ ├── ant_matd3_config.py │ │ ├── halfcheetah_happo_config.py │ │ ├── halfcheetah_mappo_config.py │ │ └── walker2d_happo_config.py │ └── envs │ │ ├── __init__.py │ │ ├── assets │ │ ├── .gitignore │ │ ├── __init__.py │ │ ├── coupled_half_cheetah.xml │ │ ├── manyagent_ant.xml │ │ ├── manyagent_ant.xml.template │ │ ├── manyagent_ant__stage1.xml │ │ ├── manyagent_swimmer.xml.template │ │ ├── manyagent_swimmer__bckp2.xml │ │ └── manyagent_swimmer_bckp.xml │ │ ├── coupled_half_cheetah.py │ │ ├── manyagent_ant.py │ │ ├── manyagent_swimmer.py │ │ ├── mujoco_multi.py │ │ ├── multi_mujoco_env.py │ │ ├── multiagentenv.py │ │ └── obsk.py ├── overcooked │ ├── README.md │ ├── __init__.py │ ├── config │ │ ├── __init__.py │ │ └── overcooked_ppo_config.py │ ├── envs │ │ ├── __init__.py │ │ ├── overcooked_env.py │ │ └── test_overcooked_env.py │ └── overcooked.gif ├── petting_zoo │ ├── __init__.py │ ├── config │ │ ├── __init__.py │ │ ├── ptz_pistonball_qmix_config.py │ │ ├── ptz_simple_spread_atoc_config.py │ │ ├── ptz_simple_spread_collaq_config.py │ │ ├── ptz_simple_spread_coma_config.py │ │ ├── ptz_simple_spread_happo_config.py │ │ ├── ptz_simple_spread_maddpg_config.py │ │ ├── ptz_simple_spread_madqn_config.py │ │ ├── ptz_simple_spread_mappo_config.py │ │ ├── ptz_simple_spread_masac_config.py │ │ ├── ptz_simple_spread_qmix_config.py │ │ ├── ptz_simple_spread_qtran_config.py │ │ ├── 
ptz_simple_spread_vdn_config.py │ │ └── ptz_simple_spread_wqmix_config.py │ ├── entry │ │ └── ptz_simple_spread_eval.py │ ├── envs │ │ ├── __init__.py │ │ ├── petting_zoo_pistonball_env.py │ │ ├── petting_zoo_simple_spread_env.py │ │ ├── test_petting_zoo_pistonball_env.py │ │ └── test_petting_zoo_simple_spread_env.py │ └── petting_zoo_mpe_simple_spread.gif ├── pomdp │ ├── __init__.py │ ├── config │ │ ├── pomdp_dqn_config.py │ │ └── pomdp_ppo_config.py │ └── envs │ │ ├── __init__.py │ │ ├── atari_env.py │ │ ├── atari_wrappers.py │ │ └── test_atari_env.py ├── procgen │ ├── README.md │ ├── __init__.py │ ├── coinrun.gif │ ├── coinrun.png │ ├── coinrun_dqn.svg │ ├── coinrun_ppo.svg │ ├── config │ │ ├── __init__.py │ │ ├── bigfish_plr_config.py │ │ ├── bigfish_ppg_config.py │ │ ├── coinrun_dqn_config.py │ │ ├── coinrun_ppg_config.py │ │ ├── coinrun_ppo_config.py │ │ ├── maze_dqn_config.py │ │ ├── maze_ppg_config.py │ │ └── maze_ppo_config.py │ ├── entry │ │ └── coinrun_onppo_main.py │ ├── envs │ │ ├── __init__.py │ │ ├── procgen_env.py │ │ └── test_coinrun_env.py │ ├── maze.gif │ ├── maze.png │ └── maze_dqn.svg ├── pybullet │ ├── __init__.py │ ├── envs │ │ ├── __init__.py │ │ ├── pybullet_env.py │ │ └── pybullet_wrappers.py │ └── pybullet.gif ├── rocket │ ├── README.md │ ├── __init__.py │ ├── config │ │ ├── __init__.py │ │ ├── rocket_hover_ppo_config.py │ │ └── rocket_landing_ppo_config.py │ ├── entry │ │ ├── __init__.py │ │ ├── rocket_hover_onppo_main_v2.py │ │ ├── rocket_hover_ppo_main.py │ │ ├── rocket_landing_onppo_main_v2.py │ │ └── rocket_landing_ppo_main.py │ └── envs │ │ ├── __init__.py │ │ ├── rocket_env.py │ │ └── test_rocket_env.py ├── slime_volley │ ├── __init__.py │ ├── config │ │ ├── slime_volley_league_ppo_config.py │ │ └── slime_volley_ppo_config.py │ ├── entry │ │ ├── slime_volley_league_ppo_main.py │ │ └── slime_volley_selfplay_ppo_main.py │ ├── envs │ │ ├── __init__.py │ │ ├── slime_volley_env.py │ │ └── test_slime_volley_env.py │ └── slime_volley.gif 
├── smac │ ├── README.md │ ├── __init__.py │ ├── config │ │ ├── smac_10m11m_mappo_config.py │ │ ├── smac_10m11m_masac_config.py │ │ ├── smac_25m_mappo_config.py │ │ ├── smac_25m_masac_config.py │ │ ├── smac_27m30m_mappo_config.py │ │ ├── smac_2c64zg_mappo_config.py │ │ ├── smac_2c64zg_masac_config.py │ │ ├── smac_2c64zg_qmix_config.py │ │ ├── smac_2s3z_qmix_config.py │ │ ├── smac_2s3z_qtran_config.py │ │ ├── smac_3m_masac_config.py │ │ ├── smac_3s5z_collaq_config.py │ │ ├── smac_3s5z_collaq_per_config.py │ │ ├── smac_3s5z_coma_config.py │ │ ├── smac_3s5z_madqn_config.py │ │ ├── smac_3s5z_mappo_config.py │ │ ├── smac_3s5z_masac_config.py │ │ ├── smac_3s5z_qmix_config.py │ │ ├── smac_3s5z_qtran_config.py │ │ ├── smac_3s5z_wqmix_config.py │ │ ├── smac_3s5zvs3s6z_madqn_config.py │ │ ├── smac_3s5zvs3s6z_mappo_config.py │ │ ├── smac_3s5zvs3s6z_masac_config.py │ │ ├── smac_5m6m_collaq_config.py │ │ ├── smac_5m6m_madqn_config.py │ │ ├── smac_5m6m_mappo_config.py │ │ ├── smac_5m6m_masac_config.py │ │ ├── smac_5m6m_qmix_config.py │ │ ├── smac_5m6m_qtran_config.py │ │ ├── smac_5m6m_wqmix_config.py │ │ ├── smac_8m9m_madqn_config.py │ │ ├── smac_8m9m_mappo_config.py │ │ ├── smac_8m9m_masac_config.py │ │ ├── smac_MMM2_collaq_config.py │ │ ├── smac_MMM2_coma_config.py │ │ ├── smac_MMM2_madqn_config.py │ │ ├── smac_MMM2_mappo_config.py │ │ ├── smac_MMM2_masac_config.py │ │ ├── smac_MMM2_qmix_config.py │ │ ├── smac_MMM2_wqmix_config.py │ │ ├── smac_MMM_collaq_config.py │ │ ├── smac_MMM_coma_config.py │ │ ├── smac_MMM_madqn_config.py │ │ ├── smac_MMM_mappo_config.py │ │ ├── smac_MMM_masac_config.py │ │ ├── smac_MMM_qmix_config.py │ │ ├── smac_MMM_qtran_config.py │ │ ├── smac_MMM_wqmix_config.py │ │ ├── smac_corridor_mappo_config.py │ │ └── smac_corridor_masac_config.py │ ├── envs │ │ ├── __init__.py │ │ ├── fake_smac_env.py │ │ ├── maps │ │ │ ├── README.md │ │ │ ├── SMAC_Maps │ │ │ │ ├── 10m_vs_11m.SC2Map │ │ │ │ ├── 1c3s5z.SC2Map │ │ │ │ ├── 25m.SC2Map │ │ │ │ ├── 27m_vs_30m.SC2Map 
│ │ │ │ ├── 2c_vs_64zg.SC2Map │ │ │ │ ├── 2m_vs_1z.SC2Map │ │ │ │ ├── 2s3z.SC2Map │ │ │ │ ├── 2s_vs_1sc.SC2Map │ │ │ │ ├── 3m.SC2Map │ │ │ │ ├── 3s5z.SC2Map │ │ │ │ ├── 3s5z_vs_3s6z.SC2Map │ │ │ │ ├── 3s_vs_3z.SC2Map │ │ │ │ ├── 3s_vs_4z.SC2Map │ │ │ │ ├── 3s_vs_5z.SC2Map │ │ │ │ ├── 5m_vs_6m.SC2Map │ │ │ │ ├── 6h_vs_8z.SC2Map │ │ │ │ ├── 8m.SC2Map │ │ │ │ ├── 8m_vs_9m.SC2Map │ │ │ │ ├── MMM.SC2Map │ │ │ │ ├── MMM2.SC2Map │ │ │ │ ├── __init__.py │ │ │ │ ├── bane_vs_bane.SC2Map │ │ │ │ ├── corridor.SC2Map │ │ │ │ ├── infestor_viper.SC2Map │ │ │ │ └── so_many_baneling.SC2Map │ │ │ ├── SMAC_Maps_two_player │ │ │ │ ├── 3m.SC2Map │ │ │ │ ├── 3s5z.SC2Map │ │ │ │ └── __init__.py │ │ │ └── __init__.py │ │ ├── smac_action.py │ │ ├── smac_env.py │ │ ├── smac_map.py │ │ ├── smac_reward.py │ │ └── test_smac_env.py │ ├── smac.gif │ └── utils │ │ └── eval.py ├── sokoban │ ├── __init__.py │ └── envs │ │ ├── __init__.py │ │ ├── sokoban_env.py │ │ ├── sokoban_wrappers.py │ │ └── test_sokoban_env.py ├── tabmwp │ ├── README.md │ ├── __init__.py │ ├── benchmark.png │ ├── config │ │ ├── tabmwp_awr_config.py │ │ └── tabmwp_pg_config.py │ ├── envs │ │ ├── __init__.py │ │ ├── tabmwp_env.py │ │ ├── test_tabmwp_env.py │ │ └── utils.py │ └── tabmwp.jpeg └── taxi │ ├── Taxi-v3_episode_0.gif │ ├── __init__.py │ ├── config │ ├── __init__.py │ └── taxi_dqn_config.py │ ├── entry │ └── taxi_dqn_deploy.py │ └── envs │ ├── __init__.py │ ├── taxi_env.py │ └── test_taxi_env.py ├── docker ├── Dockerfile.base ├── Dockerfile.env ├── Dockerfile.hpc └── Dockerfile.rpc ├── format.sh ├── pytest.ini └── setup.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | concurrency = multiprocessing,thread 3 | omit = 4 | ding/utils/slurm_helper.py 5 | ding/utils/file_helper.py 6 | ding/utils/linklink_dist_helper.py 7 | ding/utils/pytorch_ddp_dist_helper.py 8 | ding/utils/k8s_helper.py 9 | ding/utils/tests/test_k8s_launcher.py 10 | 
ding/utils/time_helper_cuda.py 11 | ding/utils/time_helper_base.py 12 | ding/utils/data/tests/test_dataloader.py 13 | ding/config/utils.py 14 | ding/entry/tests/test_serial_entry_algo.py 15 | ding/entry/tests/test_serial_entry.py 16 | ding/entry/dist_entry.py 17 | ding/entry/cli.py 18 | ding/entry/predefined_config.py 19 | ding/hpc_rl/* 20 | ding/worker/collector/tests/speed_test/* 21 | ding/envs/env_wrappers/env_wrappers.py 22 | ding/envs/env_manager/tests/test_env_supervisor.py 23 | ding/rl_utils/tests/test_ppg.py 24 | ding/scripts/tests/test_parallel_socket.py 25 | ding/data/buffer/tests/test_benchmark.py 26 | ding/example/* 27 | ding/torch_utils/loss/tests/* 28 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore=F401,F841,F403,E226,E126,W504,E265,E722,W503,W605,E741,E122,E731 3 | max-line-length=120 4 | statistics 5 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/custom.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Custom issue template 3 | about: Describe this issue template's purpose here. 
4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | - [ ] I have marked all applicable categories: 11 | + [ ] exception-raising bug 12 | + [ ] RL algorithm bug 13 | + [ ] system worker bug 14 | + [ ] system utils bug 15 | + [ ] code design/refactor 16 | + [ ] documentation request 17 | + [ ] new feature request 18 | - [ ] I have visited the [readme](https://github.com/opendilab/DI-engine/blob/github-dev/README.md) and [doc](https://opendilab.github.io/DI-engine/) 19 | - [ ] I have searched through the [issue tracker](https://github.com/opendilab/DI-engine/issues) and [pr tracker](https://github.com/opendilab/DI-engine/pulls) 20 | - [ ] I have mentioned version numbers, operating system and environment, where applicable: 21 | ```python 22 | import ding, torch, sys 23 | print(ding.__version__, torch.__version__, sys.version, sys.platform) 24 | ``` 25 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | 4 | ## Related Issue 5 | 6 | 7 | ## TODO 8 | 9 | 10 | ## Check List 11 | 12 | - [ ] merge the latest version source branch/repo, and resolve all the conflicts 13 | - [ ] pass style check 14 | - [ ] pass all the tests 15 | -------------------------------------------------------------------------------- /.github/workflows/algo_test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will check pytest 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: algo_test 5 | 6 | on: 7 | push: 8 | paths: 9 | - "ding/policy/*" 10 | - "ding/model/*" 11 | - "ding/rl_utils/*" 12 | 13 | jobs: 14 | test_algotest: 15 | runs-on: ubuntu-latest 16 | if: "!contains(github.event.head_commit.message, 'ci skip')" 17 | strategy: 18 | matrix: 19 | python-version: [3.8, 
3.9] 20 | 21 | steps: 22 | - uses: actions/checkout@v4 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v5 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: do_algotest 28 | env: 29 | WORKERS: 4 30 | DURATIONS: 600 31 | run: | 32 | python -m pip install . 33 | python -m pip install ".[test,k8s]" 34 | python -m pip install transformers 35 | ./ding/scripts/install-k8s-tools.sh 36 | make algotest 37 | -------------------------------------------------------------------------------- /.github/workflows/envpool_test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will check pytest 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: envpool_test 5 | 6 | on: [push, pull_request] 7 | 8 | jobs: 9 | test_envpooltest: 10 | runs-on: ubuntu-latest 11 | if: "!contains(github.event.head_commit.message, 'ci skip')" 12 | strategy: 13 | matrix: 14 | python-version: [3.8] # Envpool only supports python>=3.7 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: do_envpool_test 23 | run: | 24 | python -m pip install . 
25 | python -m pip install ".[test,k8s]" 26 | python -m pip install ".[envpool]" 27 | python -m pip install transformers 28 | ./ding/scripts/install-k8s-tools.sh 29 | make envpooltest 30 | -------------------------------------------------------------------------------- /.github/workflows/platform_test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will check pytest 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: platform_test 5 | 6 | on: [push, pull_request] 7 | 8 | jobs: 9 | test_unittest: 10 | runs-on: ${{ matrix.os }} 11 | if: "!contains(github.event.head_commit.message, 'ci skip')" 12 | strategy: 13 | matrix: 14 | os: [macos-13, windows-latest] 15 | python-version: [3.8, 3.9] 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: do_platform_test 24 | timeout-minutes: 30 25 | run: | 26 | python -m pip install . 
27 | python -m pip install ".[test,k8s]" 28 | python -m pip install transformers 29 | python -m pip uninstall pytest-timeouts -y 30 | make platformtest 31 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: package_release 2 | 3 | on: [push] 4 | 5 | jobs: 6 | release: 7 | name: Publish to official pypi 8 | runs-on: ${{ matrix.os }} 9 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') 10 | strategy: 11 | matrix: 12 | os: 13 | - ubuntu-latest 14 | python-version: [3.8] 15 | 16 | steps: 17 | - name: Checkout code 18 | uses: actions/checkout@v4 19 | - name: Set up python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: Set up python dependences 24 | run: | 25 | pip install --upgrade pip 26 | pip install --upgrade flake8 setuptools wheel twine 27 | pip install . 
28 | pip install --upgrade build 29 | - name: Build packages 30 | run: | 31 | python -m build --sdist --wheel --outdir dist/ 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@release/v1 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.github/workflows/release_conda.yml: -------------------------------------------------------------------------------- 1 | name: package_release_conda 2 | 3 | on: [push] 4 | 5 | jobs: 6 | release: 7 | runs-on: ubuntu-latest 8 | if: github.event_name == 'push' && (startsWith(github.ref, 'refs/tags') || contains(github.event.head_commit.message, 'conda')) 9 | steps: 10 | - uses: actions/checkout@v4 11 | - name: publish-to-conda 12 | uses: fcakyon/conda-publish-action@v1.3 13 | with: 14 | subdir: 'conda' 15 | anacondatoken: ${{ secrets.ANACONDA_TOKEN }} 16 | platforms: 'win osx linux' 17 | -------------------------------------------------------------------------------- /.github/workflows/style.yml: -------------------------------------------------------------------------------- 1 | # This workflow will check flake style 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: style 5 | 6 | on: [push, pull_request] 7 | 8 | jobs: 9 | style: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: ['3.8', '3.9', '3.10'] 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | - name: code style 22 | run: | 23 | python -m pip install "yapf==0.29.0" 24 | python -m pip install "flake8<=3.9.2" 25 | python -m pip install "importlib-metadata<5.0.0" 26 | make format_test flake_check 27 | -------------------------------------------------------------------------------- 
/.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | # For explanation and more information: https://github.com/google/yapf 3 | BASED_ON_STYLE=pep8 4 | DEDENT_CLOSING_BRACKETS=True 5 | SPLIT_BEFORE_FIRST_ARGUMENT=True 6 | ALLOW_SPLIT_BEFORE_DICT_VALUE=False 7 | JOIN_MULTIPLE_LINES=False 8 | COLUMN_LIMIT=120 9 | BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF=True 10 | BLANK_LINES_AROUND_TOP_LEVEL_DEFINITION=2 11 | SPACES_AROUND_POWER_OPERATOR=True 12 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | [Git Guide](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/git_guide.html) 2 | 3 | [GitHub Cooperation Guide](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/issue_pr.html) 4 | 5 | - [Code Style](https://di-engine-docs.readthedocs.io/en/latest/21_code_style/index.html) 6 | - [Unit Test](https://di-engine-docs.readthedocs.io/en/latest/22_test/index.html) 7 | - [Code Review](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/issue_pr.html#pr-s-code-review) 8 | -------------------------------------------------------------------------------- /assets/wechat.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/assets/wechat.jpeg -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | # basic 6 | target: auto 7 | threshold: 0.5% 8 | if_ci_failed: success #success, failure, error, ignore 9 | -------------------------------------------------------------------------------- /conda/conda_build_config.yaml: 
-------------------------------------------------------------------------------- 1 | python: 2 | - 3.7 3 | -------------------------------------------------------------------------------- /conda/meta.yaml: -------------------------------------------------------------------------------- 1 | {% set data = load_setup_py_data() %} 2 | package: 3 | name: di-engine 4 | version: v0.5.3 5 | 6 | source: 7 | path: .. 8 | 9 | build: 10 | number: 0 11 | script: python -m pip install . -vv 12 | entry_points: 13 | - ding = ding.entry.cli:cli 14 | 15 | requirements: 16 | build: 17 | - python 18 | - setuptools 19 | run: 20 | - python 21 | 22 | test: 23 | imports: 24 | - ding 25 | - dizoo 26 | 27 | about: 28 | home: https://github.com/opendilab/DI-engine 29 | license: Apache-2.0 30 | license_file: LICENSE 31 | summary: DI-engine is a generalized Decision Intelligence engine (https://github.com/opendilab/DI-engine). 32 | description: Please refer to https://di-engine-docs.readthedocs.io/en/latest/00_intro/index.html#what-is-di-engine 33 | dev_url: https://github.com/opendilab/DI-engine 34 | doc_url: https://di-engine-docs.readthedocs.io/en/latest/index.html 35 | doc_source_url: https://github.com/opendilab/DI-engine-docs 36 | -------------------------------------------------------------------------------- /ding/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | __TITLE__ = 'DI-engine' 4 | __VERSION__ = 'v0.5.3' 5 | __DESCRIPTION__ = 'Decision AI Engine' 6 | __AUTHOR__ = "OpenDILab Contributors" 7 | __AUTHOR_EMAIL__ = "opendilab@pjlab.org.cn" 8 | __version__ = __VERSION__ 9 | 10 | enable_hpc_rl = os.environ.get('ENABLE_DI_HPC', 'false').lower() == 'true' 11 | enable_linklink = os.environ.get('ENABLE_LINKLINK', 'false').lower() == 'true' 12 | enable_numba = True 13 | -------------------------------------------------------------------------------- /ding/bonus/common.py: 
-------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import numpy as np 3 | 4 | 5 | @dataclass 6 | class TrainingReturn: 7 | ''' 8 | Attributes 9 | wandb_url: The Weights & Biases (wandb) project URL of the training experiment. 10 | ''' 11 | wandb_url: str 12 | 13 | 14 | @dataclass 15 | class EvalReturn: 16 | ''' 17 | Attributes 18 | eval_value: The mean of evaluation return. 19 | eval_value_std: The standard deviation of evaluation return. 20 | ''' 21 | eval_value: np.float32 22 | eval_value_std: np.float32 23 | -------------------------------------------------------------------------------- /ding/compatibility.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def torch_ge_131(): 5 | return int("".join(list(filter(str.isdigit, torch.__version__)))) >= 131 6 | 7 | 8 | def torch_ge_180(): 9 | return int("".join(list(filter(str.isdigit, torch.__version__)))) >= 180 10 | -------------------------------------------------------------------------------- /ding/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import Config, read_config, save_config, compile_config, compile_config_parallel, read_config_directly, \ 2 | read_config_with_system, save_config_py 3 | from .utils import parallel_transform, parallel_transform_slurm 4 | from .example import A2C, C51, DDPG, DQN, PG, PPOF, PPOOffPolicy, SAC, SQL, TD3 5 | -------------------------------------------------------------------------------- /ding/config/example/A2C/__init__.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | from . import gym_bipedalwalker_v3 3 | from . 
import gym_lunarlander_v2 4 | 5 | supported_env_cfg = { 6 | gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.cfg, 7 | gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg, 8 | } 9 | 10 | supported_env_cfg = EasyDict(supported_env_cfg) 11 | 12 | supported_env = { 13 | gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.env, 14 | gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env, 15 | } 16 | 17 | supported_env = EasyDict(supported_env) 18 | -------------------------------------------------------------------------------- /ding/config/example/A2C/gym_lunarlander_v2.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | import ding.envs.gym_env 3 | 4 | cfg = dict( 5 | exp_name='LunarLander-v2-A2C', 6 | env=dict( 7 | collector_env_num=8, 8 | evaluator_env_num=8, 9 | env_id='LunarLander-v2', 10 | n_evaluator_episode=8, 11 | stop_value=260, 12 | ), 13 | policy=dict( 14 | cuda=True, 15 | model=dict( 16 | obs_shape=8, 17 | action_shape=4, 18 | ), 19 | learn=dict( 20 | batch_size=64, 21 | learning_rate=3e-4, 22 | entropy_weight=0.001, 23 | adv_norm=True, 24 | ), 25 | collect=dict( 26 | n_sample=64, 27 | discount_factor=0.99, 28 | gae_lambda=0.95, 29 | ), 30 | ), 31 | wandb_logger=dict( 32 | gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False 33 | ), 34 | ) 35 | 36 | cfg = EasyDict(cfg) 37 | 38 | env = ding.envs.gym_env.env 39 | -------------------------------------------------------------------------------- /ding/config/example/C51/__init__.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | from . import gym_lunarlander_v2 3 | from . import gym_pongnoframeskip_v4 4 | from . import gym_qbertnoframeskip_v4 5 | from . 
import gym_spaceInvadersnoframeskip_v4 6 | 7 | supported_env_cfg = { 8 | gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg, 9 | gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.cfg, 10 | gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.cfg, 11 | gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.cfg, 12 | } 13 | 14 | supported_env_cfg = EasyDict(supported_env_cfg) 15 | 16 | supported_env = { 17 | gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env, 18 | gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.env, 19 | gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.env, 20 | gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.env, 21 | } 22 | 23 | supported_env = EasyDict(supported_env) 24 | -------------------------------------------------------------------------------- /ding/config/example/DQN/__init__.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | from . import gym_lunarlander_v2 3 | from . import gym_pongnoframeskip_v4 4 | from . import gym_qbertnoframeskip_v4 5 | from . 
import gym_spaceInvadersnoframeskip_v4 6 | 7 | supported_env_cfg = { 8 | gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg, 9 | gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.cfg, 10 | gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.cfg, 11 | gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.cfg, 12 | } 13 | 14 | supported_env_cfg = EasyDict(supported_env_cfg) 15 | 16 | supported_env = { 17 | gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env, 18 | gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.env, 19 | gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.env, 20 | gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.env, 21 | } 22 | 23 | supported_env = EasyDict(supported_env) 24 | -------------------------------------------------------------------------------- /ding/config/example/PG/__init__.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | from . 
import gym_pendulum_v1 3 | 4 | supported_env_cfg = { 5 | gym_pendulum_v1.cfg.env.env_id: gym_pendulum_v1.cfg, 6 | } 7 | 8 | supported_env_cfg = EasyDict(supported_env_cfg) 9 | 10 | supported_env = { 11 | gym_pendulum_v1.cfg.env.env_id: gym_pendulum_v1.env, 12 | } 13 | 14 | supported_env = EasyDict(supported_env) 15 | -------------------------------------------------------------------------------- /ding/config/example/PG/gym_pendulum_v1.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | import ding.envs.gym_env 3 | 4 | cfg = dict( 5 | exp_name='Pendulum-v1-PG', 6 | seed=0, 7 | env=dict( 8 | env_id='Pendulum-v1', 9 | collector_env_num=8, 10 | evaluator_env_num=5, 11 | n_evaluator_episode=5, 12 | stop_value=-200, 13 | act_scale=True, 14 | ), 15 | policy=dict( 16 | cuda=False, 17 | action_space='continuous', 18 | model=dict( 19 | action_space='continuous', 20 | obs_shape=3, 21 | action_shape=1, 22 | ), 23 | learn=dict( 24 | batch_size=4000, 25 | learning_rate=0.001, 26 | entropy_weight=0.001, 27 | ), 28 | collect=dict( 29 | n_episode=20, 30 | unroll_len=1, 31 | discount_factor=0.99, 32 | ), 33 | eval=dict(evaluator=dict(eval_freq=1, )) 34 | ), 35 | wandb_logger=dict( 36 | gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False 37 | ), 38 | ) 39 | 40 | cfg = EasyDict(cfg) 41 | 42 | env = ding.envs.gym_env.env 43 | -------------------------------------------------------------------------------- /ding/config/example/PPOF/__init__.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | from . import gym_lunarlander_v2 3 | from . 
import gym_lunarlandercontinuous_v2 4 | 5 | supported_env_cfg = { 6 | gym_lunarlander_v2.cfg.env_id: gym_lunarlander_v2.cfg, 7 | gym_lunarlandercontinuous_v2.cfg.env_id: gym_lunarlandercontinuous_v2.cfg, 8 | } 9 | 10 | supported_env_cfg = EasyDict(supported_env_cfg) 11 | 12 | supported_env = { 13 | gym_lunarlander_v2.cfg.env_id: gym_lunarlander_v2.env, 14 | gym_lunarlandercontinuous_v2.cfg.env_id: gym_lunarlandercontinuous_v2.env, 15 | } 16 | 17 | supported_env = EasyDict(supported_env) 18 | -------------------------------------------------------------------------------- /ding/config/example/PPOF/gym_lunarlander_v2.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | import ding.envs.gym_env 3 | 4 | cfg = dict( 5 | exp_name='LunarLander-v2-PPO', 6 | env_id='LunarLander-v2', 7 | n_sample=400, 8 | value_norm='popart', 9 | ) 10 | 11 | cfg = EasyDict(cfg) 12 | 13 | env = ding.envs.gym_env.env 14 | -------------------------------------------------------------------------------- /ding/config/example/PPOF/gym_lunarlandercontinuous_v2.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | from functools import partial 3 | import ding.envs.gym_env 4 | 5 | cfg = dict( 6 | exp_name='LunarLanderContinuous-V2-PPO', 7 | env_id='LunarLanderContinuous-v2', 8 | action_space='continuous', 9 | n_sample=400, 10 | act_scale=True, 11 | ) 12 | 13 | cfg = EasyDict(cfg) 14 | 15 | env = partial(ding.envs.gym_env.env, continuous=True) 16 | -------------------------------------------------------------------------------- /ding/config/example/PPOOffPolicy/__init__.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | from . import gym_lunarlander_v2 3 | from . import gym_pongnoframeskip_v4 4 | from . import gym_qbertnoframeskip_v4 5 | from . 
import gym_spaceInvadersnoframeskip_v4 6 | 7 | supported_env_cfg = { 8 | gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg, 9 | gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.cfg, 10 | gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.cfg, 11 | gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.cfg, 12 | } 13 | 14 | supported_env_cfg = EasyDict(supported_env_cfg) 15 | 16 | supported_env = { 17 | gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env, 18 | gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.env, 19 | gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.env, 20 | gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.env, 21 | } 22 | 23 | supported_env = EasyDict(supported_env) 24 | -------------------------------------------------------------------------------- /ding/config/example/SQL/__init__.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | from . 
import gym_lunarlander_v2 3 | 4 | supported_env_cfg = { 5 | gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg, 6 | } 7 | 8 | supported_env_cfg = EasyDict(supported_env_cfg) 9 | 10 | supported_env = { 11 | gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env, 12 | } 13 | 14 | supported_env = EasyDict(supported_env) 15 | -------------------------------------------------------------------------------- /ding/config/example/TD3/gym_hopper_v3.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | import ding.envs.gym_env 3 | 4 | cfg = dict( 5 | exp_name='Hopper-v3-TD3', 6 | seed=0, 7 | env=dict( 8 | env_id='Hopper-v3', 9 | collector_env_num=8, 10 | evaluator_env_num=8, 11 | n_evaluator_episode=8, 12 | stop_value=6000, 13 | env_wrapper='mujoco_default', 14 | ), 15 | policy=dict( 16 | cuda=True, 17 | random_collect_size=25000, 18 | model=dict( 19 | obs_shape=11, 20 | action_shape=3, 21 | actor_head_hidden_size=256, 22 | critic_head_hidden_size=256, 23 | action_space='regression', 24 | ), 25 | collect=dict(n_sample=1, ), 26 | other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ), 27 | ), 28 | wandb_logger=dict( 29 | gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False 30 | ), 31 | ) 32 | 33 | cfg = EasyDict(cfg) 34 | 35 | env = ding.envs.gym_env.env 36 | -------------------------------------------------------------------------------- /ding/config/example/__init__.py: -------------------------------------------------------------------------------- 1 | from . import A2C 2 | from . import C51 3 | from . import DDPG 4 | from . import DQN 5 | from . import PG 6 | from . import PPOF 7 | from . import PPOOffPolicy 8 | from . import SAC 9 | from . import SQL 10 | from . 
import TD3 11 | -------------------------------------------------------------------------------- /ding/data/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | from ding.utils.data import create_dataset, offline_data_save_type # for compatibility 3 | from .buffer import * 4 | from .storage import * 5 | from .storage_loader import StorageLoader, FileStorageLoader 6 | from .shm_buffer import ShmBufferContainer, ShmBuffer 7 | from .model_loader import ModelLoader, FileModelLoader 8 | -------------------------------------------------------------------------------- /ding/data/buffer/__init__.py: -------------------------------------------------------------------------------- 1 | from .buffer import Buffer, apply_middleware, BufferedData 2 | from .deque_buffer import DequeBuffer 3 | from .deque_buffer_wrapper import DequeBufferWrapper 4 | -------------------------------------------------------------------------------- /ding/data/buffer/middleware/__init__.py: -------------------------------------------------------------------------------- 1 | from .clone_object import clone_object 2 | from .use_time_check import use_time_check 3 | from .staleness_check import staleness_check 4 | from .priority import PriorityExperienceReplay 5 | from .padding import padding 6 | from .group_sample import group_sample 7 | from .sample_range_view import sample_range_view 8 | -------------------------------------------------------------------------------- /ding/data/buffer/middleware/sample_range_view.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Any, List, Optional, Union, TYPE_CHECKING 2 | from ding.data.buffer import BufferedData 3 | if TYPE_CHECKING: 4 | from ding.data.buffer.buffer import Buffer 5 | 6 | 7 | def sample_range_view(buffer_: 'Buffer', start: Optional[int] = None, end: Optional[int] = None) -> Callable: 8 | 
""" 9 | Overview: 10 | The middleware that places restrictions on the range of indices during sampling. 11 | Arguments: 12 | - start (:obj:`int`): The starting index. 13 | - end (:obj:`int`): One above the ending index. 14 | """ 15 | assert start is not None or end is not None 16 | if start and start < 0: 17 | start = buffer_.size + start 18 | if end and end < 0: 19 | end = buffer_.size + end 20 | sample_range = slice(start, end) 21 | 22 | def _sample_range_view(action: str, chain: Callable, *args, **kwargs) -> Any: 23 | if action == "sample": 24 | return chain(*args, sample_range=sample_range) 25 | return chain(*args, **kwargs) 26 | 27 | return _sample_range_view 28 | -------------------------------------------------------------------------------- /ding/data/level_replay/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/data/level_replay/__init__.py -------------------------------------------------------------------------------- /ding/data/storage/__init__.py: -------------------------------------------------------------------------------- 1 | from .storage import Storage 2 | from .file import FileStorage, FileModelStorage 3 | -------------------------------------------------------------------------------- /ding/data/storage/file.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from ding.data.storage import Storage 3 | import pickle 4 | 5 | from ding.utils.file_helper import read_file, save_file 6 | 7 | 8 | class FileStorage(Storage): 9 | 10 | def save(self, data: Any) -> None: 11 | with open(self.path, "wb") as f: 12 | pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) 13 | 14 | def load(self) -> Any: 15 | with open(self.path, "rb") as f: 16 | return pickle.load(f) 17 | 18 | 19 | class FileModelStorage(Storage): 20 | 21 | def save(self, state_dict: 
object) -> None: 22 | save_file(self.path, state_dict) 23 | 24 | def load(self) -> object: 25 | return read_file(self.path) 26 | -------------------------------------------------------------------------------- /ding/data/storage/storage.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any 3 | 4 | 5 | class Storage(ABC): 6 | 7 | def __init__(self, path: str) -> None: 8 | self.path = path 9 | 10 | @abstractmethod 11 | def save(self, data: Any) -> None: 12 | raise NotImplementedError 13 | 14 | @abstractmethod 15 | def load(self) -> Any: 16 | raise NotImplementedError 17 | -------------------------------------------------------------------------------- /ding/data/storage/tests/test_storage.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import pytest 3 | import os 4 | from os import path 5 | from ding.data.storage import FileStorage 6 | 7 | 8 | @pytest.mark.unittest 9 | def test_file_storage(): 10 | path_ = path.join(tempfile.gettempdir(), "test_storage.txt") 11 | try: 12 | storage = FileStorage(path=path_) 13 | storage.save("test") 14 | content = storage.load() 15 | assert content == "test" 16 | finally: 17 | if path.exists(path_): 18 | os.remove(path_) 19 | -------------------------------------------------------------------------------- /ding/data/tests/test_shm_buffer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import timeit 4 | from ding.data.shm_buffer import ShmBuffer 5 | import multiprocessing as mp 6 | 7 | 8 | def subprocess(shm_buf): 9 | data = np.random.rand(1024, 1024).astype(np.float32) 10 | res = timeit.repeat(lambda: shm_buf.fill(data), repeat=5, number=1000) 11 | print("Mean: {:.4f}s, STD: {:.4f}s, Mean each call: {:.4f}ms".format(np.mean(res), np.std(res), np.mean(res))) 12 | 13 | 14 | @pytest.mark.benchmark 15 
| def test_shm_buffer(): 16 | data = np.random.rand(1024, 1024).astype(np.float32) 17 | shm_buf = ShmBuffer(data.dtype, data.shape, copy_on_get=False) 18 | proc = mp.Process(target=subprocess, args=[shm_buf]) 19 | proc.start() 20 | proc.join() 21 | -------------------------------------------------------------------------------- /ding/design/dataloader-sequence.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/design/dataloader-sequence.png -------------------------------------------------------------------------------- /ding/design/env_state.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/design/env_state.png -------------------------------------------------------------------------------- /ding/design/parallel_main-sequence.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/design/parallel_main-sequence.png -------------------------------------------------------------------------------- /ding/design/serial_collector-activity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/design/serial_collector-activity.png -------------------------------------------------------------------------------- /ding/design/serial_evaluator-activity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/design/serial_evaluator-activity.png -------------------------------------------------------------------------------- 
/ding/design/serial_evaluator-activity.puml: -------------------------------------------------------------------------------- 1 | @startuml serial_evaluator 2 | header Serial Pipeline 3 | title Serial Evaluator 4 | 5 | |#99CCCC|serial_controller| 6 | |#99CCFF|env_manager| 7 | |#CCCCFF|policy| 8 | |#FFCCCC|evaluator| 9 | 10 | |#99CCCC|serial_controller| 11 | start 12 | :init evaluator, set its \nenv_manager and \neval_mode policy; 13 | |#99CCFF|env_manager| 14 | repeat 15 | :return current obs; 16 | |#CCCCFF|policy| 17 | :<b>[network]</b> forward with obs; 18 | |#99CCFF|env_manager| 19 | :env step with action; 20 | |#FFCCCC|evaluator| 21 | if (for every env: env i is done?) then (yes) 22 | |#99CCFF|env_manager| 23 | :env i reset; 24 | |#FFCCCC|evaluator| 25 | :log eval_episode_info; 26 | endif 27 | repeat while (evaluate episodes are not enough?) 28 | |#FFCCCC|evaluator| 29 | :return eval_episode_return; 30 | stop 31 | @enduml 32 | -------------------------------------------------------------------------------- /ding/design/serial_learner-activity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/design/serial_learner-activity.png -------------------------------------------------------------------------------- /ding/design/serial_learner-activity.puml: -------------------------------------------------------------------------------- 1 | @startuml serial_learner 2 | header Serial Pipeline 3 | title Serial Learner 4 | 5 | |#99CCCC|serial_controller| 6 | |#CCCCFF|policy| 7 | |#99CCFF|learner| 8 | 9 | |#99CCCC|serial_controller| 10 | start 11 | :init learner, \nset its learn_mode policy; 12 | |#99CCFF|learner| 13 | :get data from buffer; 14 | |#CCCCFF|policy| 15 | :data forward; 16 | :loss backward; 17 | :optimizer step, gradient update; 18 | |#99CCFF|learner| 19 | :update train info(loss, value) and log; 20 | :update learn 
info(iteration, priority); 21 | stop 22 | @enduml 23 | -------------------------------------------------------------------------------- /ding/design/serial_main-sequence.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/design/serial_main-sequence.png -------------------------------------------------------------------------------- /ding/entry/cli_parsers/__init__.py: -------------------------------------------------------------------------------- 1 | from .slurm_parser import slurm_parser 2 | from .k8s_parser import k8s_parser 3 | PLATFORM_PARSERS = {"slurm": slurm_parser, "k8s": k8s_parser} 4 | -------------------------------------------------------------------------------- /ding/entry/tests/config/agconfig.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: diengine.opendilab.org/v1alpha1 2 | kind: AggregatorConfig 3 | metadata: 4 | name: aggregator-config 5 | namespace: di-system 6 | spec: 7 | aggregator: 8 | template: 9 | spec: 10 | containers: 11 | - name: di-container 12 | image: diorchestrator/ding:v0.1.1 13 | imagePullPolicy: IfNotPresent 14 | env: 15 | - name: PYTHONUNBUFFERED 16 | value: "1" 17 | command: ["/bin/bash", "-c",] 18 | args: 19 | - | 20 | # if code has been changed in the mount path, we have to reinstall cli 21 | # pip install --no-cache-dir -e .; 22 | # pip install --no-cache-dir -e .[common_env] 23 | 24 | ding -m dist --module learner_aggregator 25 | ports: 26 | - name: di-port 27 | containerPort: 22270 28 | -------------------------------------------------------------------------------- /ding/entry/tests/config/k8s-config.yaml: -------------------------------------------------------------------------------- 1 | type: k3s # k3s or local 2 | name: di-cluster 3 | servers: 1 # # of k8s masters 4 | agents: 0 # # of k8s nodes 5 | preload_images: 6 | - 
diorchestrator/ding:v0.1.1 # di-engine image for training should be preloaded 7 | -------------------------------------------------------------------------------- /ding/entry/tests/test_cli_ditask.py: -------------------------------------------------------------------------------- 1 | from time import sleep 2 | import pytest 3 | import pathlib 4 | import os 5 | from ding.entry.cli_ditask import _cli_ditask 6 | 7 | 8 | def cli_ditask_main(): 9 | sleep(0.1) 10 | 11 | 12 | @pytest.mark.unittest 13 | def test_cli_ditask(): 14 | kwargs = { 15 | "package": os.path.dirname(pathlib.Path(__file__)), 16 | "main": "test_cli_ditask.cli_ditask_main", 17 | "parallel_workers": 1, 18 | "topology": "mesh", 19 | "platform": "k8s", 20 | "protocol": "tcp", 21 | "ports": 50501, 22 | "attach_to": "", 23 | "address": "127.0.0.1", 24 | "labels": "", 25 | "node_ids": 0, 26 | "mq_type": "nng", 27 | "redis_host": "", 28 | "redis_port": "", 29 | "startup_interval": 1 30 | } 31 | os.environ["DI_NODES"] = '127.0.0.1' 32 | os.environ["DI_RANK"] = '0' 33 | try: 34 | _cli_ditask(**kwargs) 35 | finally: 36 | del os.environ["DI_NODES"] 37 | del os.environ["DI_RANK"] 38 | -------------------------------------------------------------------------------- /ding/entry/tests/test_parallel_entry.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from copy import deepcopy 3 | from ding.entry import parallel_pipeline 4 | from dizoo.classic_control.cartpole.config.parallel.cartpole_dqn_config import main_config, create_config,\ 5 | system_config 6 | 7 | 8 | # @pytest.mark.unittest 9 | @pytest.mark.execution_timeout(120.0, method='thread') 10 | def test_dqn(): 11 | config = tuple([deepcopy(main_config), deepcopy(create_config), deepcopy(system_config)]) 12 | config[0].env.stop_value = 9 13 | try: 14 | parallel_pipeline(config, seed=0) 15 | except Exception: 16 | assert False, "pipeline fail" 17 | 
-------------------------------------------------------------------------------- /ding/entry/tests/test_serial_entry_guided_cost.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from copy import deepcopy 4 | from ding.entry import serial_pipeline_onpolicy, serial_pipeline_guided_cost 5 | from dizoo.classic_control.cartpole.config import cartpole_ppo_config, cartpole_ppo_create_config 6 | from dizoo.classic_control.cartpole.config import cartpole_gcl_ppo_onpolicy_config, \ 7 | cartpole_gcl_ppo_onpolicy_create_config 8 | 9 | 10 | @pytest.mark.unittest 11 | def test_guided_cost(): 12 | expert_policy_state_dict_path = './expert_policy.pth' 13 | config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)] 14 | expert_policy = serial_pipeline_onpolicy(config, seed=0) 15 | torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path) 16 | 17 | config = [deepcopy(cartpole_gcl_ppo_onpolicy_config), deepcopy(cartpole_gcl_ppo_onpolicy_create_config)] 18 | config[0].policy.collect.model_path = expert_policy_state_dict_path 19 | config[0].policy.learn.update_per_collect = 1 20 | try: 21 | serial_pipeline_guided_cost(config, seed=0, max_train_iter=1) 22 | except Exception: 23 | assert False, "pipeline fail" 24 | -------------------------------------------------------------------------------- /ding/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .env import * 2 | from .env_wrappers import * 3 | from .env_manager import * 4 | from .env_manager.ding_env_manager import setup_ding_env_manager 5 | from . 
import gym_env 6 | -------------------------------------------------------------------------------- /ding/envs/common/__init__.py: -------------------------------------------------------------------------------- 1 | from .common_function import num_first_one_hot, sqrt_one_hot, div_one_hot, div_func, clip_one_hot, \ 2 | reorder_one_hot, reorder_one_hot_array, reorder_boolean_vector, affine_transform, \ 3 | batch_binary_encode, get_postion_vector, save_frames_as_gif 4 | from .env_element import EnvElement, EnvElementInfo 5 | from .env_element_runner import EnvElementRunner 6 | -------------------------------------------------------------------------------- /ding/envs/common/env_element_runner.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any 3 | 4 | from .env_element import EnvElement, IEnvElement, EnvElementInfo 5 | from ..env.base_env import BaseEnv 6 | 7 | 8 | class IEnvElementRunner(IEnvElement): 9 | 10 | @abstractmethod 11 | def get(self, engine: BaseEnv) -> Any: 12 | raise NotImplementedError 13 | 14 | @abstractmethod 15 | def reset(self, *args, **kwargs) -> None: 16 | raise NotImplementedError 17 | 18 | 19 | class EnvElementRunner(IEnvElementRunner): 20 | 21 | def __init__(self, *args, **kwargs) -> None: 22 | self._init(*args, **kwargs) 23 | self._check() 24 | 25 | @abstractmethod 26 | def _init(self, *args, **kwargs) -> None: 27 | # set self._core and other state variable 28 | raise NotImplementedError 29 | 30 | def _check(self) -> None: 31 | flag = [hasattr(self, '_core'), isinstance(self._core, EnvElement)] 32 | assert all(flag), flag 33 | 34 | def __repr__(self) -> str: 35 | return repr(self._core) 36 | 37 | @property 38 | def info(self) -> 'EnvElementInfo': 39 | return self._core.info 40 | -------------------------------------------------------------------------------- /ding/envs/env/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .base_env import BaseEnv, get_vec_env_setting, BaseEnvTimestep, get_env_cls, create_model_env 2 | from .ding_env_wrapper import DingEnvWrapper 3 | from .default_wrapper import get_default_wrappers 4 | from .env_implementation_check import check_space_dtype, check_array_space, check_reset, check_step, \ 5 | check_different_memory, check_obs_deepcopy, check_all, demonstrate_correct_procedure 6 | -------------------------------------------------------------------------------- /ding/envs/env/tests/__init__.py: -------------------------------------------------------------------------------- 1 | from .demo_env import DemoEnv 2 | -------------------------------------------------------------------------------- /ding/envs/env_manager/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_env_manager import BaseEnvManager, BaseEnvManagerV2, create_env_manager, get_env_manager_cls 2 | from .subprocess_env_manager import AsyncSubprocessEnvManager, SyncSubprocessEnvManager, SubprocessEnvManagerV2 3 | from .gym_vector_env_manager import GymVectorEnvManager 4 | # Do not import PoolEnvManager here, because it depends on installation of `envpool` 5 | from .env_supervisor import EnvSupervisor 6 | -------------------------------------------------------------------------------- /ding/envs/env_manager/ding_env_manager.py: -------------------------------------------------------------------------------- 1 | from . 
def setup_ding_env_manager(
        env: DingEnvWrapper,
        env_num: int,
        context: Optional[str] = None,
        debug: bool = False,
        caller: str = 'collector'
) -> BaseEnvManagerV2:
    """
    Overview:
        Build an env manager that holds ``env_num`` clones of ``env``.
    Arguments:
        - env (:obj:`DingEnvWrapper`): The environment whose ``clone`` method produces each sub-environment.
        - env_num (:obj:`int`): Number of sub-environments.
        - context (:obj:`Optional[str]`): Written to ``manager_cfg.context`` when not ``None``; \
            only used when ``debug`` is ``False``.
        - debug (:obj:`bool`): Use ``BaseEnvManagerV2`` if ``True``, otherwise ``SubprocessEnvManagerV2``.
        - caller (:obj:`str`): Either ``'collector'`` or ``'evaluator'``; forwarded to ``env.clone``.
    Returns:
        - manager (:obj:`BaseEnvManagerV2`): The constructed env manager.
    """
    assert caller in ['evaluator', 'collector']
    manager_cls = BaseEnvManagerV2 if debug else SubprocessEnvManagerV2
    manager_cfg = manager_cls.default_config()
    # The context override is only applied on the subprocess (non-debug) path.
    if not debug and context is not None:
        manager_cfg.context = context
    env_fns = [partial(env.clone, caller) for _ in range(env_num)]
    return manager_cls(env_fns, manager_cfg)
def env(cfg, seed_api=True, caller='collector', **kwargs) -> BaseEnv:
    """
    Overview:
        Create the gym environment named by ``cfg.env_id`` and wrap it in ``DingEnvWrapper``.
    Arguments:
        - cfg: Config object; ``cfg.env_id`` is passed to ``gym.make`` and ``cfg`` itself to the wrapper.
        - seed_api: Forwarded to ``DingEnvWrapper``.
        - caller: Forwarded to ``DingEnvWrapper``.
        - kwargs: Extra keyword arguments forwarded to ``gym.make``.
    Returns:
        - wrapped_env (:obj:`BaseEnv`): The wrapped environment.
    """
    # gym is imported at call time, not at module import time.
    import gym
    raw_env = gym.make(cfg.env_id, **kwargs)
    return DingEnvWrapper(raw_env, cfg=cfg, seed_api=seed_api, caller=caller)
def nng_main(i):
    """
    Overview:
        Worker body for the two-process NNG test: worker 0 publishes, any other worker receives one message.
    Arguments:
        - i (:obj:`int`): Index of the worker.
    """
    if i != 0:
        # Receiver side: listen on its own address, attach to the publisher and check one message.
        mq = NNGMQ(listen_to="tcp://127.0.0.1:50516", attach_to=["tcp://127.0.0.1:50515"])
        mq.listen()
        topic, msg = mq.recv()
        assert topic == "t"
        assert msg == b"data"
        return

    # Publisher side: no peers to attach to, publish the same payload ten times.
    mq = NNGMQ(listen_to="tcp://127.0.0.1:50515", attach_to=None)
    mq.listen()
    for _ in range(10):
        mq.publish("t", b"data")
        sleep(0.1)
def final_ctx_saver(name: str) -> Callable:
    """
    Overview:
        The middleware that pickles the final statistics kept in the context to ``<name>/result.pkl`` \
        once ``task.finish`` is set.
    Arguments:
        - name (:obj:`str`): Directory in which ``result.pkl`` is written.
    Returns:
        - _save (:obj:`Callable`): The middleware, called with the context.
    """

    def _save(ctx: "Context"):
        if not task.finish:
            return
        # make sure the items to be recorded are all kept in the context
        # Gather everything before opening the file: mode 'wb' truncates on open, so reading a missing
        # context attribute afterwards would leave an empty result.pkl behind.
        final_data = {
            'total_step': ctx.total_step,
            'train_iter': ctx.train_iter,
            'last_eval_iter': ctx.last_eval_iter,
            'eval_value': ctx.last_eval_value,
        }
        if ctx.has_attr('env_step'):
            final_data['env_step'] = ctx.env_step
            final_data['env_episode'] = ctx.env_episode
        if ctx.has_attr('trained_env_step'):
            final_data['trained_env_step'] = ctx.trained_env_step
        with open(os.path.join(name, 'result.pkl'), 'wb') as f:
            pickle.dump(final_data, f)

    return _save
def priority_calculator(priority_calculation_fn: Callable) -> Callable:
    """
    Overview:
        The middleware that computes a priority for every collected trajectory and stores it \
        under the ``'priority'`` key of that trajectory.
    Arguments:
        - priority_calculation_fn (:obj:`Callable`): Called with ``ctx.trajectories``; its result is \
            indexed by position to obtain the priority of each trajectory.
    Returns:
        - middleware (:obj:`Callable`): The priority middleware, or ``task.void()`` when the router is \
            active and this task does not have the collector role.
    """

    if task.router.is_active:
        if not task.has_role(task.role.COLLECTOR):
            return task.void()

    def _priority_calculator(ctx: "OnlineRLContext") -> None:
        priorities = priority_calculation_fn(ctx.trajectories)
        # Positional write-back: the i-th priority belongs to the i-th trajectory.
        for idx in range(len(priorities)):
            ctx.trajectories[idx]['priority'] = priorities[idx]

    return _priority_calculator
@pytest.mark.unittest
def test_eps_greedy_handler():
    """
    Overview:
        Check the ``eps`` value that ``eps_greedy_handler`` writes into ``ctx.collect_kwargs`` \
        at env step 0 and at env step 1000000.
    """
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()

    # (env_step, expected eps) pairs, checked in this order on the same context.
    cases = ((0, 0.95), (1000000, 0.1))
    for env_step, expected_eps in cases:
        ctx.env_step = env_step
        # A fresh handler is built for every case and advanced once with ``next``.
        next(eps_greedy_handler(cfg)(ctx))
        assert ctx.collect_kwargs['eps'] == expected_eps
class MockPolicy(Mock):
    """
    Overview:
        Mock policy whose ``priority_fun`` returns one random priority per data item.
    """

    def priority_fun(self, data):
        """
        Overview:
            Draw one sample from ``np.random.rand`` for each element of ``data``.
        Arguments:
            - data: Any object supporting ``len``.
        Returns:
            - priority (:obj:`np.ndarray`): Array of ``len(data)`` random values.
        """
        num_items = len(data)
        return np.random.rand(num_items)
def mean_relative_error(y_true, y_pred):
    """
    Overview:
        Mean relative error of ``y_pred`` with respect to ``y_true``, averaged over axis 0.
    Arguments:
        - y_true: Reference values.
        - y_pred: Values compared against the reference; must broadcast with ``y_true``.
    Returns:
        - relative_error: ``average(|y_true - y_pred| / (|y_true| + eps), axis=0)`` with ``eps = 1e-5``.
    """
    eps = 1e-5
    # Normalize by the magnitude of the reference. Dividing by the signed value would produce negative
    # "errors" for negative references and a division by zero when y_true equals -eps.
    relative_error = np.average(np.abs(y_true - y_pred) / (np.abs(y_true) + eps), axis=0)
    return relative_error
DEFAULT_HEARTBEAT_SPAN, LOCAL_HOST, \ 3 | GLOBAL_HOST, DEFAULT_REQUEST_RETRIES, DEFAULT_REQUEST_RETRY_WAITING 4 | -------------------------------------------------------------------------------- /ding/interaction/config/base.py: -------------------------------------------------------------------------------- 1 | # System configs 2 | GLOBAL_HOST = '0.0.0.0' 3 | LOCAL_HOST = '127.0.0.1' 4 | 5 | # General request 6 | DEFAULT_REQUEST_RETRIES = 5 7 | DEFAULT_REQUEST_RETRY_WAITING = 1.0 8 | 9 | # Slave configs 10 | MIN_HEARTBEAT_SPAN = 0.2 11 | DEFAULT_HEARTBEAT_SPAN = 3.0 12 | DEFAULT_SLAVE_PORT = 7236 13 | 14 | # Master configs 15 | MIN_HEARTBEAT_CHECK_SPAN = 0.1 16 | DEFAULT_HEARTBEAT_CHECK_SPAN = 1.0 17 | DEFAULT_HEARTBEAT_TOLERANCE = 17.0 18 | DEFAULT_MASTER_PORT = 7235 19 | 20 | # Two-side configs 21 | DEFAULT_CHANNEL = 0 22 | -------------------------------------------------------------------------------- /ding/interaction/exception/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ResponseException 2 | from .master import MasterErrorCode, get_master_exception_by_error, MasterResponseException, MasterSuccess, \ 3 | MasterChannelInvalid, MasterChannelNotGiven, MasterMasterTokenInvalid, MasterMasterTokenNotGiven, \ 4 | MasterSelfTokenInvalid, MasterSelfTokenNotGiven, MasterSlaveTokenInvalid, MasterSlaveTokenNotGiven, \ 5 | MasterSystemShuttingDown, MasterTaskDataInvalid 6 | from .slave import SlaveErrorCode, get_slave_exception_by_error, SlaveResponseException, SlaveSuccess, \ 7 | SlaveChannelInvalid, SlaveChannelNotFound, SlaveSelfTokenInvalid, SlaveTaskAlreadyExist, SlaveTaskRefused, \ 8 | SlaveMasterTokenInvalid, SlaveMasterTokenNotFound, SlaveSelfTokenNotFound, SlaveSlaveAlreadyConnected, \ 9 | SlaveSlaveConnectionRefused, SlaveSlaveDisconnectionRefused, SlaveSlaveNotConnected, SlaveSystemShuttingDown 10 | -------------------------------------------------------------------------------- 
/ding/interaction/master/__init__.py: -------------------------------------------------------------------------------- 1 | from .master import Master 2 | -------------------------------------------------------------------------------- /ding/interaction/master/base.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Mapping, Any, Optional 2 | 3 | from requests import RequestException 4 | 5 | _BEFORE_HOOK_TYPE = Callable[..., Mapping[str, Any]] 6 | _AFTER_HOOK_TYPE = Callable[[int, bool, int, Optional[str], Optional[Mapping[str, Any]]], Any] 7 | _ERROR_HOOK_TYPE = Callable[[RequestException], Any] 8 | -------------------------------------------------------------------------------- /ding/interaction/slave/__init__.py: -------------------------------------------------------------------------------- 1 | from .action import TaskRefuse, DisconnectionRefuse, ConnectionRefuse, TaskFail 2 | from .slave import Slave 3 | -------------------------------------------------------------------------------- /ding/interaction/tests/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .config import * 3 | from .exception import * 4 | from .interaction import * 5 | -------------------------------------------------------------------------------- /ding/interaction/tests/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .test_app import TestInteractionBaseApp, TestInteractionBaseResponsibleException 2 | from .test_common import TestInteractionBaseCommon, TestInteractionBaseControllableService 3 | from .test_network import TestInteractionBaseHttpEngine, TestInteractionBaseNetwork 4 | from .test_threading import TestInteractionBaseThreading 5 | -------------------------------------------------------------------------------- /ding/interaction/tests/config/__init__.py: 
def random_port(excludes: Iterable[int] = None) -> int:
    """
    Overview:
        Pick a random port from ``range(10000, 20000)`` that is not listed in ``excludes``.
    Arguments:
        - excludes (:obj:`Iterable[int]`): Ports that must not be returned; ``None`` excludes nothing.
    Returns:
        - port (:obj:`int`): The chosen port.
    """
    candidates = set(range(10000, 20000))
    blocked = set(excludes or [])
    return random.choice(list(candidates - blocked))
@pytest.fixture(scope='session')
def random_job_result():
    """
    Overview:
        Session-scoped fixture providing a function that returns a random job result.
    Returns:
        - fn (:obj:`Callable`): Function without arguments returning ``"wins"``, ``"draws"`` or ``"losses"``.
    """

    def fn():
        # One uniform draw; the two thresholds split it into three outcomes.
        sample = np.random.uniform()
        if sample >= 2. / 3:
            return "losses"
        return "wins" if sample < 1. / 3 else "draws"

    return fn
@MODEL_REGISTRY.register('sqn')
class SQN(nn.Module):
    """
    Overview:
        Model holding two ``DQN`` networks (``q0`` and ``q1``) built from the same constructor arguments.
    """

    def __init__(self, *args, **kwargs) -> None:
        """
        Overview:
            Create both ``DQN`` networks, forwarding all arguments unchanged to each of them.
        """
        super().__init__()
        self.q0 = DQN(*args, **kwargs)
        self.q1 = DQN(*args, **kwargs)

    def forward(self, data: torch.Tensor) -> Dict:
        """
        Overview:
            Run both networks on ``data``.
        Arguments:
            - data (:obj:`torch.Tensor`): Input passed to both networks.
        Returns:
            - output (:obj:`Dict`): ``'q_value'`` holds the ``'logit'`` outputs of ``q0`` and ``q1`` \
                (in this order), ``'logit'`` holds the one of ``q0``.
        """
        out0, out1 = self.q0(data), self.q1(data)
        logit0 = out0['logit']
        return {
            'q_value': [logit0, out1['logit']],
            'logit': logit0,
        }
@pytest.mark.unittest
def test_q_evaluation():
    """
    Overview:
        Check that ``q_evaluation`` returns a ``(T, B)`` tensor for ``(T, B, O)`` observations \
        and ``(T, B, A)`` actions.
    """
    horizon, batch_size, obs_dim, action_dim = 10, 20, 100, 30
    obs_seq = torch.randn(horizon, batch_size, obs_dim)
    action_seq = torch.randn(horizon, batch_size, action_dim)

    def fake_q_fn(obss, actions):
        # obss: flatten_B * O
        # actions: flatten_B * A
        # return: flatten_B
        return torch.sum(obss, dim=-1)

    result = q_evaluation(obs_seq, action_seq, fake_q_fn)
    assert result.shape == (horizon, batch_size)
# Marked like the other tests in this module so that marker-based test selection picks it up as well.
@pytest.mark.unittest
def test_gae_multi_agent():
    """
    Overview:
        Check that ``gae`` keeps the trailing agent dimension: values of shape ``(T, B, A)`` combined \
        with rewards and done flags of shape ``(T, B)`` give an advantage of shape ``(T, B, A)``.
    """
    T, B, A = 32, 4, 8
    value = torch.randn(T, B, A)
    next_value = torch.randn(T, B, A)
    # reward and done carry no agent dimension
    reward = torch.randn(T, B)
    done = torch.zeros(T, B)
    data = gae_data(value, next_value, reward, done, None)
    adv = gae(data)
    assert adv.shape == (T, B, A)
# runs the given command as root (detects if we are root already)
# Arguments: the command and its arguments, one word per argument.
# Uses sudo only when not already root and USE_SUDO is "true" (the default).
runAsRoot() {
  # "$@" keeps every argument intact. Flattening the command into a single string and
  # re-splitting it would break arguments that contain spaces or glob characters.
  if [ "$EUID" -ne 0 ] && [ "${USE_SUDO:-true}" = "true" ]; then
    sudo "$@"
  else
    "$@"
  fi
}
25 | curl -LO https://dl.k8s.io/release/v1.21.3/bin/linux/amd64/kubectl 26 | chmod +x kubectl 27 | runAsRoot mv kubectl /usr/local/bin/kubectl 28 | fi 29 | -------------------------------------------------------------------------------- /ding/scripts/kill.sh: -------------------------------------------------------------------------------- 1 | ps -ef | grep 'ding' | grep -v grep | awk '{print $2}'|xargs kill -9 2 | -------------------------------------------------------------------------------- /ding/scripts/local_parallel.sh: -------------------------------------------------------------------------------- 1 | ding -m parallel -c $1 -s $2 2 | -------------------------------------------------------------------------------- /ding/scripts/local_serial.sh: -------------------------------------------------------------------------------- 1 | ding -m serial -c $1 -s $2 2 | -------------------------------------------------------------------------------- /ding/scripts/main_league_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export LC_ALL=en_US.utf-8 4 | export LANG=en_US.utf-8 5 | BASEDIR=$(dirname "$0") 6 | # srun -p partition_name --quotatype=reserved --mpi=pmi2 -n6 --ntasks-per-node=3 bash ding/scripts/main_league_slurm.sh 7 | ditask --package $BASEDIR/../entry --main main_league.main --platform slurm --platform-spec '{"tasks":[{"labels":"league,collect","node_ids":10},{"labels":"league,collect","node_ids":11},{"labels":"evaluate","node_ids":20,"attach_to":"$node.10,$node.11"},{"labels":"learn","node_ids":31,"attach_to":"$node.10,$node.11,$node.20"},{"labels":"learn","node_ids":32,"attach_to":"$node.10,$node.11,$node.20"},{"labels":"learn","node_ids":33,"attach_to":"$node.10,$node.11,$node.20"}]}' 8 | -------------------------------------------------------------------------------- /ding/scripts/tests/test_parallel_socket.sh: 
#!/usr/bin/env bash
# Fault-injection driver for test_parallel_socket.py:
#   1. start the python test in the background,
#   2. wait until a file whose name contains "pytmp" appears and read an IP from it,
#   3. drop inbound TCP traffic from that IP on the test port for a while,
#   4. remove the DROP rule again (also when the script is interrupted).

total_epoch=1200            # the total num of msg
interval=0.1                # msg send interval
size=16                     # data size (MB)
test_start_time=20          # network fail time (s)
test_duration=40            # network fail duration (s)
output_file="my_test.log"   # the python script will write its output into this file
ip="0.0.0.0"
port=50516                  # destination port matched by the DROP rule

rm -f pytmp_*

nohup python test_parallel_socket.py -t "$total_epoch" -i "$interval" -s "$size" 1>"$output_file" 2>&1 &

# Poll the working directory until a "pytmp" file shows up; its content is used as the IP.
found=false
until $found; do
    for file in *pytmp*; do
        # When nothing matches, the glob stays literal and does not exist: keep polling.
        [[ -e $file ]] || continue
        ip=$(cat "$file")
        found=true
        break
    done
    sleep 0.1
done
echo "get ip: $ip"

# The same rule specification is used for both insertion (-A) and deletion (-D).
drop_rule=(INPUT -p tcp -s "$ip" --dport "$port" -j DROP)
rule_active=false

# Remove the DROP rule if (and only if) it is currently installed.
restore_network() {
    if $rule_active; then
        sudo iptables -D "${drop_rule[@]}"
        rule_active=false
    fi
}

# Never leave the firewall rule behind, even on Ctrl-C or kill.
trap restore_network EXIT
trap 'exit 1' INT TERM

sleep "$test_start_time"
echo "Network shutdown . . ."
sudo iptables -A "${drop_rule[@]}" && rule_active=true

sleep "$test_duration"
restore_network
echo "Network recovered . . ."
def enable_tf32() -> None:
    """
    Overview:
        Switch on the ``allow_tf32`` flag of both the cudnn backend and the CUDA matmul backend of PyTorch, \
        trading precision for speed. According to the PyTorch notes this only takes effect on Ampere GPU \
        devices, see: \
        https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices.
    """
    # Both backends expose the same flag name, so set it on each of them in turn.
    for backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
        backend.allow_tf32 = True
def get_num_params(model: torch.nn.Module) -> int:
    """
    Overview:
        Return the number of parameters in the model, i.e. the summed element count of every tensor \
        yielded by ``model.parameters()``. No filtering is applied, so bias terms and parameters with \
        ``requires_grad=False`` are included.
    Arguments:
        - model (:obj:`torch.nn.Module`): The model object to calculate the parameter number.
    Returns:
        - n_params (:obj:`int`): The calculated number of parameters.
    Examples:
        >>> model = torch.nn.Linear(3, 5)
        >>> num = get_num_params(model)
        >>> assert num == 20  # 3 * 5 weights + 5 biases
        >>> assert get_num_params(torch.nn.Linear(3, 5, bias=False)) == 15
    """
    return sum(p.numel() for p in model.parameters())
23 | data = torch.rand((4, 10)) 24 | data = torch.log(data) 25 | gumbelsoftmax = model(data, hard=True) 26 | assert gumbelsoftmax.shape == (B, N) 27 | -------------------------------------------------------------------------------- /ding/torch_utils/network/tests/test_transformer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from ding.torch_utils import Transformer 5 | 6 | 7 | @pytest.mark.unittest 8 | class TestTransformer: 9 | 10 | def test(self): 11 | batch_size = 2 12 | num_entries = 2 13 | C = 2 14 | masks = [None, torch.rand(batch_size, num_entries).round().bool()] 15 | for mask in masks: 16 | output_dim = 4 17 | model = Transformer( 18 | input_dim=C, 19 | head_dim=2, 20 | hidden_dim=3, 21 | output_dim=output_dim, 22 | head_num=2, 23 | mlp_num=2, 24 | layer_num=2, 25 | ) 26 | input = torch.rand(batch_size, num_entries, C).requires_grad_(True) 27 | output = model(input, mask) 28 | loss = output.mean() 29 | loss.backward() 30 | assert isinstance(input.grad, torch.Tensor) 31 | assert output.shape == (batch_size, num_entries, output_dim) 32 | -------------------------------------------------------------------------------- /ding/torch_utils/tests/test_backend_helper.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from ding.torch_utils.backend_helper import enable_tf32 5 | 6 | 7 | @pytest.mark.cudatest 8 | class TestBackendHelper: 9 | 10 | def test_tf32(self): 11 | r""" 12 | Overview: 13 | Test the tf32. 
14 | """ 15 | enable_tf32() 16 | net = torch.nn.Linear(3, 4) 17 | x = torch.randn(1, 3) 18 | y = torch.sum(net(x)) 19 | net.zero_grad() 20 | y.backward() 21 | assert net.weight.grad is not None 22 | -------------------------------------------------------------------------------- /ding/torch_utils/tests/test_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.optim import Adam 4 | 5 | from ding.torch_utils.lr_scheduler import cos_lr_scheduler 6 | 7 | 8 | @pytest.mark.unittest 9 | class TestLRSchedulerHelper: 10 | 11 | def test_cos_lr_scheduler(self): 12 | r""" 13 | Overview: 14 | Test the cos lr scheduler. 15 | """ 16 | net = torch.nn.Linear(3, 4) 17 | opt = Adam(net.parameters(), lr=1e-2) 18 | scheduler = cos_lr_scheduler(opt, learning_rate=1e-2, min_lr=6e-5) 19 | scheduler.step(101) 20 | assert opt.param_groups[0]['lr'] == 6e-5 21 | -------------------------------------------------------------------------------- /ding/torch_utils/tests/test_model_helper.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from ding.torch_utils.model_helper import get_num_params 5 | 6 | 7 | @pytest.mark.unittest 8 | class TestModelHelper: 9 | 10 | def test_model_helper(self): 11 | r""" 12 | Overview: 13 | Test the model helper. 
@pytest.mark.unittest
def test_tanh_parameter():
    """
    Overview:
        Check that ``TanhParameter`` returns (up to a small tolerance) the value it was built with, \
        and the value given to ``set_data`` afterwards.
    """
    tanh_parameter = TanhParameter(torch.tensor([0.5, -0.2]))
    assert torch.isclose(tanh_parameter() - torch.tensor([0.5, -0.2]), torch.zeros(2), atol=1e-6).all()
    tanh_parameter.set_data(torch.tensor(0.3))
    # Use the same tolerance-based comparison as above rather than exact float equality,
    # so the check does not depend on bit-exact floating point results.
    output = tanh_parameter()
    assert torch.isclose(output, torch.full_like(output, 0.3), atol=1e-6).all()
@unique
class TimeMode(IntEnum):
    """
    Overview:
        Mode that is used to decide the time format of the ``range_values`` function. \
        The members are unique integer constants (enforced by ``@unique``).

        ABSOLUTE: use absolute time
        RELATIVE_LIFECYCLE: use relative time based on property's lifecycle
        RELATIVE_CURRENT_TIME: use relative time based on current time
    """
    ABSOLUTE = 0
    RELATIVE_LIFECYCLE = 1
    RELATIVE_CURRENT_TIME = 2
class LifoDeque(LifoQueue):
    """
    Overview:
        Like LifoQueue, but automatically replaces the oldest data when the queue is full, \
        so ``put`` never blocks because of capacity. As with the standard queues, a ``maxsize`` \
        that is less than or equal to zero means the queue is unbounded.
    Interfaces:
        ``_init`` (``_put`` and ``_get`` are inherited from ``LifoQueue``)
    """

    def _init(self, maxsize):
        """
        Overview:
            Create the underlying storage. Called by ``Queue.__init__``.
        Arguments:
            - maxsize (:obj:`int`): Maximum number of kept items; ``<= 0`` means unbounded.
        """
        if maxsize > 0:
            # The deque drops the oldest item by itself once ``maxsize`` items are stored, so the
            # queue-level limit is set one above it: the "full" condition that would block ``put``
            # can then never be reached.
            self.maxsize = maxsize + 1
            self.queue = deque(maxlen=maxsize)
        else:
            # ``self.maxsize`` keeps the non-positive value set by ``Queue.__init__`` (unbounded).
            # A ``deque(maxlen=0)`` would silently discard every item and make ``get`` block forever.
            self.queue = deque()
class FakeClass:
    """
    Overview:
        Placeholder class whose constructor accepts and ignores any arguments.
    """

    def __init__(self, *args, **kwargs):
        """Accept arbitrary positional and keyword arguments and do nothing with them."""


class FakeNN:
    """
    Overview:
        Placeholder ``nn`` namespace; ``SyncBatchNorm2d`` is an alias of ``FakeClass``.
    """

    SyncBatchNorm2d = FakeClass


class FakeLink:
    """
    Overview:
        Placeholder link namespace exposing ``nn``, ``syncbnVarMode_t`` and ``allreduceOp_t``.
    """

    # A shared FakeNN instance stands in for the ``nn`` sub-namespace.
    nn = FakeNN()
    # Namedtuple *instance* with a single field ``L2`` whose value is None.
    syncbnVarMode_t = namedtuple("syncbnVarMode_t", "L2")(None)
    # Namedtuple *class* (not instantiated) with fields ``Sum`` and ``Max``.
    allreduceOp_t = namedtuple("allreduceOp_t", ['Sum', 'Max'])


link = FakeLink()
class CompositeStructureError(ValueError, metaclass=ABCMeta):
    """
    Overview:
        Abstract base for errors raised on composite structures. It is a ``ValueError`` whose \
        subclasses are expected to provide the collected item errors through ``errors``.
    Interfaces:
        ``__init__``, ``errors``
    Properties:
        ``errors``
    """

    @property
    @abstractmethod
    def errors(self) -> ERROR_ITEMS:
        """
        Overview:
            Get the errors.
        Returns:
            - errors (:obj:`ERROR_ITEMS`): List of ``(index, exception)`` pairs, where the index is \
                an ``int`` or a ``str``. This base implementation always raises ``NotImplementedError``.
        """

        raise NotImplementedError
5 | """ 6 | Overview: 7 | Create a keep loader. 8 | """ 9 | 10 | return Loader(lambda v: v) 11 | 12 | 13 | def raw(value) -> ILoaderClass: 14 | """ 15 | Overview: 16 | Create a raw loader. 17 | """ 18 | 19 | return Loader(lambda v: value) 20 | 21 | 22 | def optional(loader) -> ILoaderClass: 23 | """ 24 | Overview: 25 | Create a optional loader. 26 | Arguments: 27 | - loader (:obj:`ILoaderClass`): The loader. 28 | """ 29 | 30 | return Loader(loader) | None 31 | 32 | 33 | def check_only(loader) -> ILoaderClass: 34 | """ 35 | Overview: 36 | Create a check only loader. 37 | Arguments: 38 | - loader (:obj:`ILoaderClass`): The loader. 39 | """ 40 | 41 | return Loader(loader) & keep() 42 | 43 | 44 | def check(loader) -> ILoaderClass: 45 | """ 46 | Overview: 47 | Create a check loader. 48 | Arguments: 49 | - loader (:obj:`ILoaderClass`): The loader. 50 | """ 51 | 52 | return Loader(lambda x: Loader(loader).check(x)) 53 | -------------------------------------------------------------------------------- /ding/utils/tests/config/k8s-config.yaml: -------------------------------------------------------------------------------- 1 | type: k3s # k3s or local 2 | name: di-cluster 3 | servers: 1 # # of k8s masters 4 | agents: 0 # # of k8s nodes 5 | preload_images: 6 | - busybox:latest 7 | - hello-world:latest 8 | -------------------------------------------------------------------------------- /ding/utils/tests/test_bfs_helper.py: -------------------------------------------------------------------------------- 1 | import easydict 2 | import numpy 3 | import pytest 4 | 5 | from ding.utils import get_vi_sequence 6 | from dizoo.maze.envs.maze_env import Maze 7 | 8 | 9 | @pytest.mark.unittest 10 | class TestBFSHelper: 11 | 12 | def test_bfs(self): 13 | 14 | def load_env(seed): 15 | ccc = easydict.EasyDict({'size': 16}) 16 | e = Maze(ccc) 17 | e.seed(seed) 18 | e.reset() 19 | return e 20 | 21 | env = load_env(314) 22 | start_obs = env.process_states(env._get_obs(), env.get_maze_map()) 23 | 
vi_sequence, track_back = get_vi_sequence(env, start_obs) 24 | assert vi_sequence.shape[1:] == (16, 16) 25 | assert track_back[0][0].shape == (16, 16, 3) 26 | assert isinstance(track_back[0][1], numpy.int32) 27 | -------------------------------------------------------------------------------- /ding/utils/tests/test_collection_helper.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from ding.utils.collection_helper import iter_mapping 4 | 5 | 6 | @pytest.mark.unittest 7 | class TestCollectionHelper: 8 | 9 | def test_iter_mapping(self): 10 | _iter = iter_mapping([1, 2, 3, 4, 5], lambda x: x ** 2) 11 | 12 | assert not isinstance(_iter, list) 13 | assert list(_iter) == [1, 4, 9, 16, 25] 14 | -------------------------------------------------------------------------------- /ding/utils/tests/test_compression_helper.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from ding.utils.compression_helper import get_data_compressor, get_data_decompressor 4 | 5 | import pytest 6 | 7 | 8 | @pytest.mark.unittest 9 | class TestCompression(): 10 | 11 | def get_step_data(self): 12 | return {'input': [random.randint(10, 100) for i in range(100)]} 13 | 14 | def testnaive(self): 15 | compress_names = ['lz4', 'zlib', 'none'] 16 | for s in compress_names: 17 | compressor = get_data_compressor(s) 18 | decompressor = get_data_decompressor(s) 19 | data = self.get_step_data() 20 | assert data == decompressor(compressor(data)) 21 | 22 | def test_arr_to_st(self): 23 | data = np.random.randint(0, 255, (96, 96, 3), dtype=np.uint8) 24 | compress_names = ['jpeg'] 25 | for s in compress_names: 26 | compressor = get_data_compressor(s) 27 | decompressor = get_data_decompressor(s) 28 | assert data.shape == decompressor(compressor(data)).shape 29 | -------------------------------------------------------------------------------- 
@pytest.mark.unittest
def test_deprecated():
    """
    Overview:
        Check the warning message emitted by functions decorated with ``deprecated``, both without \
        and with a suggested replacement API.
    """

    @deprecated('0.4.1', '0.5.1')
    def deprecated_func1():
        pass

    @deprecated('0.4.1', '0.5.1', 'deprecated_func3')
    def deprecated_func2():
        pass

    with warnings.catch_warnings(record=True) as w:
        # Record every warning regardless of the active filters; otherwise a filtered or
        # already-seen warning leaves ``w`` empty and ``w[-1]`` fails with IndexError.
        warnings.simplefilter("always")
        deprecated_func1()
        assert w, "deprecated_func1 did not emit any warning"
        assert (
            'API `test_deprecation.deprecated_func1` is deprecated '
            'since version 0.4.1 and will be removed in version 0.5.1.'
        ) == str(w[-1].message)
        deprecated_func2()
        assert (
            'API `test_deprecation.deprecated_func2` is deprecated '
            'since version 0.4.1 and will be removed in version 0.5.1, '
            'please use `deprecated_func3` instead.'
        ) == str(w[-1].message)
@pytest.mark.unittest
def test_try_import():
    """
    Overview:
        Smoke test of the optional-import helpers: each call only has to return without raising. \
        ``try_import_link`` is exercised twice, the second time with ``ding.enable_linklink`` set to True.
    """
    try_import_ceph()
    try_import_mc()
    try_import_redis()
    try_import_rediscluster()
    try_import_link()
    import_module(['ding.utils'])

    # ``ding.enable_linklink`` is process-wide state: put it back afterwards so that the
    # tests running later in the same session are not affected by this one.
    missing = object()
    previous = getattr(ding, 'enable_linklink', missing)
    ding.enable_linklink = True
    try:
        try_import_link()
    finally:
        if previous is missing:
            del ding.enable_linklink
        else:
            ding.enable_linklink = previous
@pytest.mark.unittest
def test_get_rw_file_lock():
    """
    Overview:
        Write an array while holding the 'write' lock of a path and load it back while holding the \
        'read' lock of the same path.
    """
    # Local imports: only this test needs them.
    import os
    import tempfile

    # TODO real read-write case
    # Work inside a temporary directory so that the data file (and anything created next to it)
    # is removed afterwards instead of being left in the current working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, 'tmp.npy')
        read_lock = get_rw_file_lock(path, 'read')
        write_lock = get_rw_file_lock(path, 'write')
        with write_lock:
            np.save(path, np.random.randint(0, 1, size=(3, 4)))
        with read_lock:
            data = np.load(path)
        assert data.shape == (3, 4)
class TimeWrapper(object):
    """
    Overview:
        Base class for timers. Subclasses implement ``start_time`` and ``end_time``; ``wrapper`` \
        uses that pair to time a single function call.

    Interfaces:
        ``wrapper``, ``start_time``, ``end_time``
    """

    @classmethod
    def wrapper(cls, fn):
        """
        Overview:
            Classmethod wrapper, wrap a function so that every call also returns its running time
        Arguments:
            - fn (:obj:`function`): The function to be wrapped and timed
        Returns:
            - time_func (:obj:`function`): A function with the same call signature as ``fn`` that \
                returns the tuple ``(ret, t)``: ``fn``'s own return value and whatever \
                ``cls.end_time()`` returned.
        """
        # Imported locally because this module has no import section of its own.
        from functools import wraps

        # ``wraps`` keeps fn's __name__/__doc__/__module__ on the timed function and exposes the
        # original through ``__wrapped__``, so logs and introspection still show the real function.
        @wraps(fn)
        def time_func(*args, **kwargs):
            cls.start_time()
            ret = fn(*args, **kwargs)
            t = cls.end_time()
            return ret, t

        return time_func

    @classmethod
    def start_time(cls):
        """
        Overview:
            Abstract classmethod, start timing
        """
        raise NotImplementedError

    @classmethod
    def end_time(cls):
        """
        Overview:
            Abstract classmethod, stop timing
        """
        raise NotImplementedError
TypeVar('torch.Tensor') 6 | -------------------------------------------------------------------------------- /ding/worker/__init__.py: -------------------------------------------------------------------------------- 1 | from .collector import * 2 | from .learner import * 3 | from .replay_buffer import * 4 | from .coordinator import * 5 | from .adapter import * 6 | -------------------------------------------------------------------------------- /ding/worker/adapter/__init__.py: -------------------------------------------------------------------------------- 1 | from .learner_aggregator import LearnerAggregator 2 | -------------------------------------------------------------------------------- /ding/worker/collector/__init__.py: -------------------------------------------------------------------------------- 1 | # serial 2 | from .base_serial_collector import ISerialCollector, create_serial_collector, get_serial_collector_cls, \ 3 | to_tensor_transitions 4 | 5 | from .sample_serial_collector import SampleSerialCollector 6 | from .episode_serial_collector import EpisodeSerialCollector 7 | from .battle_episode_serial_collector import BattleEpisodeSerialCollector 8 | from .battle_sample_serial_collector import BattleSampleSerialCollector 9 | 10 | from .base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor, create_serial_evaluator 11 | from .interaction_serial_evaluator import InteractionSerialEvaluator 12 | from .battle_interaction_serial_evaluator import BattleInteractionSerialEvaluator 13 | from .metric_serial_evaluator import MetricSerialEvaluator, IMetric 14 | # parallel 15 | from .base_parallel_collector import BaseParallelCollector, create_parallel_collector, get_parallel_collector_cls 16 | from .zergling_parallel_collector import ZerglingParallelCollector 17 | from .marine_parallel_collector import MarineParallelCollector 18 | from .comm import BaseCommCollector, FlaskFileSystemCollector, create_comm_collector, NaiveCollector 19 | 
class fake_policy(Policy):
    """Stub policy used by the collector tests: only the eval hooks do any work, every other hook is a no-op."""

    def _init_learn(self):
        """No learn-mode setup."""

    def _forward_learn(self, data):
        """No-op learn step; returns None."""

    def _init_eval(self):
        """Wrap the policy's model with the 'base' wrapper for evaluation."""
        self._eval_model = model_wrap(self._model, 'base')

    def _forward_eval(self, data):
        """Switch the eval model to eval mode and return its forward output for ``data``."""
        eval_model = self._eval_model
        eval_model.eval()
        return eval_model.forward(data)

    def _monitor_vars_learn(self):
        """Names of the learn-mode variables to monitor."""
        monitored = ['forward_time', 'backward_time', 'sync_time']
        return monitored

    def _init_collect(self):
        """No collect-mode setup."""

    def _forward_collect(self, data):
        """No-op collect step; returns None."""

    def _process_transition(self):
        """No-op; returns None."""

    def _get_train_sample(self):
        """No-op; returns None."""
def random_change(number):
    """
    Overview:
        Scale ``number`` by a random factor ``1 + d``, where ``d = (np.random.random() - 0.5) * 0.6``, \
        i.e. perturb it by at most 30% in either direction.
    Arguments:
        - number: The value to perturb.
    Returns:
        - The perturbed value ``number * (1 + d)``.
    """
    relative_offset = (np.random.random() - 0.5) * 0.6
    return number * (1 + relative_offset)
# Running counter used to hand out a unique ``data_id`` per generated sample.
ID_COUNT = 0
# Fixed seed so the generated fixtures are reproducible.
np.random.seed(1)


def generate_data(meta: bool = False) -> dict:
    """
    Overview:
        Build one fake replay-buffer sample with a fresh ``data_id`` and a randomly chosen \
        ``priority`` layout.
    Arguments:
        - meta (:obj:`bool`): If True, the observation is removed from the sample and handed to \
            ``save_file`` under the sample's ``data_id``; only the remaining keys are returned.
    Returns:
        - sample (:obj:`dict`): Always holds ``data_id``; holds ``obs`` unless ``meta`` is True; \
            ``priority`` is absent, ``None`` or a positive float, each with probability 1/3.
    """
    global ID_COUNT
    sample = {'obs': np.random.randn(4), 'data_id': str(ID_COUNT)}
    ID_COUNT += 1

    draw = np.random.uniform()
    if draw >= 2 / 3:
        sample['priority'] = np.random.uniform() + 1e-3
    elif draw >= 1 / 3:
        sample['priority'] = None
    # draw < 1/3: leave the sample without a 'priority' key at all

    if meta:
        observation = sample.pop('obs')
        save_file(sample['data_id'], observation)
    return sample


def generate_data_list(count: int, meta: bool = False) -> List[dict]:
    """
    Overview:
        Build ``count`` samples by calling ``generate_data(meta)`` repeatedly.
    """
    samples = []
    for _ in range(count):
        samples.append(generate_data(meta))
    return samples
def get_rollout_length_scheduler(cfg: "EasyDict") -> Callable[[int], int]:
    """
    Overview:
        Get the rollout length scheduler that adapts rollout length based \
        on the current environment steps.
    Arguments:
        - cfg (:obj:`EasyDict`): Scheduler config. ``cfg.type`` selects the schedule. \
            ``'linear'`` reads ``rollout_start_step``, ``rollout_end_step``, \
            ``rollout_length_min`` and ``rollout_length_max``; ``'constant'`` reads \
            ``rollout_length`` each time the scheduler is called.
    Returns:
        - scheduler (:obj:`Callable`): The function that takes envstep and \
            return the current rollout length.
    Raises:
        - KeyError: If ``cfg.type`` is neither ``'linear'`` nor ``'constant'``.
    """
    if cfg.type == 'linear':
        x0 = cfg.rollout_start_step
        x1 = cfg.rollout_end_step
        y0 = cfg.rollout_length_min
        y1 = cfg.rollout_length_max
        if x1 == x0:
            # Zero-width ramp: the slope is undefined (this used to raise ZeroDivisionError),
            # so switch from the minimum to the maximum length exactly at the start step.
            return lambda x: int(y0 if x < x0 else y1)
        step_span = x1 - x0
        length_span = y1 - y0
        # Multiply before dividing. Pre-computing the slope ``length_span / step_span`` and then
        # multiplying by ``x - x0`` can land a hair below ``y1`` at ``x == x1`` (e.g. 1 / 49 * 49),
        # which ``int`` would truncate to ``y1 - 1``.
        return lambda x: int(min(max(y0 + length_span * (x - x0) / step_span, y0), y1))
    elif cfg.type == 'constant':
        return lambda x: cfg.rollout_length
    else:
        raise KeyError("not implemented key: {}".format(cfg.type))
import * 4 | from dizoo.atari.config.serial.spaceinvaders import * 5 | from dizoo.atari.config.serial.asterix import * 6 | -------------------------------------------------------------------------------- /dizoo/atari/config/serial/asterix/__init__.py: -------------------------------------------------------------------------------- 1 | from .asterix_mdqn_config import asterix_mdqn_config, asterix_mdqn_create_config 2 | -------------------------------------------------------------------------------- /dizoo/atari/config/serial/enduro/__init__.py: -------------------------------------------------------------------------------- 1 | from .enduro_dqn_config import enduro_dqn_config, enduro_dqn_create_config 2 | -------------------------------------------------------------------------------- /dizoo/atari/config/serial/pong/__init__.py: -------------------------------------------------------------------------------- 1 | from .pong_dqn_config import pong_dqn_config, pong_dqn_create_config 2 | from .pong_dqn_envpool_config import pong_dqn_envpool_config, pong_dqn_envpool_create_config 3 | from .pong_dqfd_config import pong_dqfd_config, pong_dqfd_create_config 4 | -------------------------------------------------------------------------------- /dizoo/atari/config/serial/qbert/__init__.py: -------------------------------------------------------------------------------- 1 | from .qbert_dqn_config import qbert_dqn_config, qbert_dqn_create_config 2 | from .qbert_dqfd_config import qbert_dqfd_config, qbert_dqfd_create_config 3 | -------------------------------------------------------------------------------- /dizoo/atari/config/serial/spaceinvaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .spaceinvaders_dqn_config import spaceinvaders_dqn_config, spaceinvaders_dqn_create_config 2 | from .spaceinvaders_dqfd_config import spaceinvaders_dqfd_config, spaceinvaders_dqfd_create_config 3 | 
-------------------------------------------------------------------------------- /dizoo/atari/entry/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/atari/entry/__init__.py -------------------------------------------------------------------------------- /dizoo/atari/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .atari_env import AtariEnv, AtariEnvMR 2 | -------------------------------------------------------------------------------- /dizoo/beergame/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/beergame/__init__.py -------------------------------------------------------------------------------- /dizoo/beergame/beergame.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/beergame/beergame.png -------------------------------------------------------------------------------- /dizoo/beergame/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .clBeergame import clBeerGame 2 | from .beergame_core import BeerGame 3 | -------------------------------------------------------------------------------- /dizoo/bitflip/README.md: -------------------------------------------------------------------------------- 1 | ## BitFlip Environment 2 | A simple environment to flip a 01 sequence into a specific state. With the bits number increasing, the task becomes harder. 3 | Well suited for testing Hindsight Experience Replay. 
4 | 5 | ## DI-engine's HER on BitFlip 6 | 7 | The table shows the minimum number of envsteps needed to converge for PureDQN and HER-DQN as implemented in DI-engine. '-' means no convergence within 20M envsteps. 8 | 9 | | n_bit | PureDQN | HER-DQN | 10 | | ------ | ------- | ------- | 11 | | 15 | - | 150K | 12 | | 20 | - | 1.5M | 13 | DI-engine's HER-DQN can converge for both n_bit settings, whereas PureDQN does not converge within 20M envsteps. 14 | 15 | You can refer to the RL algorithm doc for implementation and experiment details. 16 | -------------------------------------------------------------------------------- /dizoo/bitflip/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/bitflip/__init__.py -------------------------------------------------------------------------------- /dizoo/bitflip/bitflip.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/bitflip/bitflip.gif -------------------------------------------------------------------------------- /dizoo/bitflip/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .bitflip_her_dqn_config import bitflip_her_dqn_config, bitflip_her_dqn_create_config 2 | from .bitflip_pure_dqn_config import bitflip_pure_dqn_config, bitflip_pure_dqn_create_config 3 | -------------------------------------------------------------------------------- /dizoo/bitflip/entry/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/bitflip/entry/__init__.py -------------------------------------------------------------------------------- /dizoo/bitflip/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from
@pytest.mark.envtest
def test_bitfilp_env():
    """Seed, reset and step a 10-bit ``BitFlipEnv``, checking the observation and reward shapes on every step."""
    n_bits = 10
    expected_obs_shape = (2 * n_bits, )

    env = BitFlipEnv(EasyDict({'n_bits': n_bits}))
    env.seed(314)
    assert env._seed == 314

    first_obs = env.reset()
    assert first_obs.shape == expected_obs_shape

    for step_idx in range(10):
        # Both ``env.random_action()``, and utilizing ``np.random`` as well as action space,
        # can generate legal random action.
        use_numpy_action = step_idx < 5
        action = np.random.randint(0, n_bits, size=(1, )) if use_numpy_action else env.random_action()
        timestep = env.step(action)
        assert timestep.obs.shape == expected_obs_shape
        assert timestep.reward.shape == (1, )
@pytest.mark.envtest
class TestBipedalWalkerEnv:
    """Smoke test for ``BipedalWalkerEnv``: seed, reset, ten random steps, close."""

    def test_naive(self):
        """Check observation/reward shapes and types and that rewards stay inside ``reward_space``."""
        cfg = EasyDict({'act_scale': True, 'rew_clip': True, 'replay_path': None})
        env = BipedalWalkerEnv(cfg)
        env.seed(123)
        assert env._seed == 123

        initial_obs = env.reset()
        assert initial_obs.shape == (24, )

        for _ in range(10):
            action = env.random_action()
            timestep = env.step(action)
            print(timestep)
            assert isinstance(timestep.obs, np.ndarray)
            assert isinstance(timestep.done, bool)
            assert timestep.obs.shape == (24, )
            assert timestep.reward.shape == (1, )
            # The reward must lie within the bounds advertised by the env.
            assert timestep.reward >= env.reward_space.low
            assert timestep.reward <= env.reward_space.high

        print(env.observation_space, env.action_space, env.reward_space)
        env.close()
@pytest.mark.envtest
@pytest.mark.parametrize('cfg', [EasyDict({'env_id': 'CarRacing-v2', 'continuous': False, 'act_scale': False})])
class TestCarRacing:
    """Smoke test for ``CarRacingEnv`` with the parametrized config: seed, reset, ten random steps, close."""

    def test_naive(self, cfg):
        """Check observation/reward shapes and types and that rewards stay inside ``reward_space``."""
        expected_obs_shape = (3, 96, 96)

        env = CarRacingEnv(cfg)
        env.seed(314)
        assert env._seed == 314

        initial_obs = env.reset()
        assert initial_obs.shape == expected_obs_shape

        for _ in range(10):
            action = env.random_action()
            timestep = env.step(action)
            print(timestep)
            assert isinstance(timestep.obs, np.ndarray)
            assert isinstance(timestep.done, bool)
            assert timestep.obs.shape == expected_obs_shape
            assert timestep.reward.shape == (1, )
            # The reward must lie within the bounds advertised by the env.
            assert timestep.reward >= env.reward_space.low
            assert timestep.reward <= env.reward_space.high

        print(env.observation_space, env.action_space, env.reward_space)
        env.close()
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/box2d/lunarlander/entry/__init__.py -------------------------------------------------------------------------------- /dizoo/box2d/lunarlander/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .lunarlander_env import LunarLanderEnv 2 | -------------------------------------------------------------------------------- /dizoo/box2d/lunarlander/lunarlander.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/box2d/lunarlander/lunarlander.gif -------------------------------------------------------------------------------- /dizoo/bsuite/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/bsuite/__init__.py -------------------------------------------------------------------------------- /dizoo/bsuite/bsuite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/bsuite/bsuite.png -------------------------------------------------------------------------------- /dizoo/bsuite/config/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /dizoo/bsuite/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .bsuite_env import BSuiteEnv 2 | -------------------------------------------------------------------------------- /dizoo/classic_control/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/classic_control/__init__.py -------------------------------------------------------------------------------- /dizoo/classic_control/acrobot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/classic_control/acrobot/__init__.py -------------------------------------------------------------------------------- /dizoo/classic_control/acrobot/acrobot.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/classic_control/acrobot/acrobot.gif -------------------------------------------------------------------------------- /dizoo/classic_control/acrobot/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .acrobot_dqn_config import acrobot_dqn_config, acrobot_dqn_create_config 2 | -------------------------------------------------------------------------------- /dizoo/classic_control/acrobot/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .acrobot_env import AcroBotEnv 2 | -------------------------------------------------------------------------------- /dizoo/classic_control/cartpole/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/classic_control/cartpole/__init__.py -------------------------------------------------------------------------------- /dizoo/classic_control/cartpole/cartpole.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/classic_control/cartpole/cartpole.gif -------------------------------------------------------------------------------- /dizoo/classic_control/cartpole/config/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config, cartpole_dqn_system_config 2 | -------------------------------------------------------------------------------- /dizoo/classic_control/cartpole/entry/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/classic_control/cartpole/entry/__init__.py -------------------------------------------------------------------------------- /dizoo/classic_control/cartpole/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .cartpole_env import CartPoleEnv 2 | -------------------------------------------------------------------------------- /dizoo/classic_control/mountain_car/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/classic_control/mountain_car/__init__.py -------------------------------------------------------------------------------- /dizoo/classic_control/mountain_car/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .mtcar_env import MountainCarEnv 2 | -------------------------------------------------------------------------------- /dizoo/classic_control/pendulum/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/classic_control/pendulum/__init__.py -------------------------------------------------------------------------------- /dizoo/classic_control/pendulum/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .pendulum_ddpg_config import pendulum_ddpg_config, pendulum_ddpg_create_config 2 | from .pendulum_td3_config import pendulum_td3_config, pendulum_td3_create_config 3 | from .pendulum_sac_config import pendulum_sac_config, pendulum_sac_create_config 4 | from .pendulum_d4pg_config import pendulum_d4pg_config, pendulum_d4pg_create_config 5 | from .pendulum_ppo_config import pendulum_ppo_config, pendulum_ppo_create_config 6 | from .pendulum_td3_bc_config import pendulum_td3_bc_config, pendulum_td3_bc_create_config 7 | from .pendulum_td3_data_generation_config import pendulum_td3_generation_config, pendulum_td3_generation_create_config 8 | -------------------------------------------------------------------------------- /dizoo/classic_control/pendulum/entry/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/classic_control/pendulum/entry/__init__.py -------------------------------------------------------------------------------- /dizoo/classic_control/pendulum/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .pendulum_env import PendulumEnv 2 | -------------------------------------------------------------------------------- /dizoo/classic_control/pendulum/pendulum.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/classic_control/pendulum/pendulum.gif 
-------------------------------------------------------------------------------- /dizoo/cliffwalking/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/cliffwalking/__init__.py -------------------------------------------------------------------------------- /dizoo/cliffwalking/cliff_walking.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/cliffwalking/cliff_walking.gif -------------------------------------------------------------------------------- /dizoo/cliffwalking/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .cliffwalking_env import CliffWalkingEnv 2 | -------------------------------------------------------------------------------- /dizoo/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/common/__init__.py -------------------------------------------------------------------------------- /dizoo/common/policy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/common/policy/__init__.py -------------------------------------------------------------------------------- /dizoo/competitive_rl/README.md: -------------------------------------------------------------------------------- 1 | The original repo of the "Competitive RL" environment is https://github.com/cuhkrlcourse/competitive-rl. 2 | You can refer to it for a guide on installation and usage. 
-------------------------------------------------------------------------------- /dizoo/competitive_rl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/competitive_rl/__init__.py -------------------------------------------------------------------------------- /dizoo/competitive_rl/competitive_rl.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/competitive_rl/competitive_rl.gif -------------------------------------------------------------------------------- /dizoo/competitive_rl/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .competitive_rl_env import CompetitiveRlEnv 2 | -------------------------------------------------------------------------------- /dizoo/competitive_rl/envs/resources/pong/checkpoint-alphapong.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/competitive_rl/envs/resources/pong/checkpoint-alphapong.pkl -------------------------------------------------------------------------------- /dizoo/competitive_rl/envs/resources/pong/checkpoint-medium.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/competitive_rl/envs/resources/pong/checkpoint-medium.pkl -------------------------------------------------------------------------------- /dizoo/competitive_rl/envs/resources/pong/checkpoint-strong.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/competitive_rl/envs/resources/pong/checkpoint-strong.pkl -------------------------------------------------------------------------------- /dizoo/competitive_rl/envs/resources/pong/checkpoint-weak.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/competitive_rl/envs/resources/pong/checkpoint-weak.pkl -------------------------------------------------------------------------------- /dizoo/d4rl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/d4rl/__init__.py -------------------------------------------------------------------------------- /dizoo/d4rl/config/__init__.py: -------------------------------------------------------------------------------- 1 | # from .hopper_cql_config import hopper_cql_config 2 | # from .hopper_expert_cql_config import hopper_expert_cql_config 3 | # from .hopper_medium_cql_config import hopper_medium_cql_config 4 | -------------------------------------------------------------------------------- /dizoo/d4rl/d4rl.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/d4rl/d4rl.gif -------------------------------------------------------------------------------- /dizoo/d4rl/entry/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/d4rl/entry/__init__.py -------------------------------------------------------------------------------- /dizoo/d4rl/entry/d4rl_bcq_main.py: 
-------------------------------------------------------------------------------- 1 | from ding.entry import serial_pipeline_offline 2 | from ding.config import read_config 3 | from pathlib import Path 4 | 5 | 6 | def train(args): 7 | # launch from anywhere 8 | config = Path(__file__).absolute().parent.parent / 'config' / args.config 9 | config = read_config(str(config)) 10 | config[0].exp_name = config[0].exp_name.replace('0', str(args.seed)) 11 | serial_pipeline_offline(config, seed=args.seed) 12 | 13 | 14 | if __name__ == "__main__": 15 | import argparse 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--seed', '-s', type=int, default=0) 19 | parser.add_argument('--config', '-c', type=str, default='halfcheetah_medium_bcq_config.py') 20 | args = parser.parse_args() 21 | train(args) 22 | -------------------------------------------------------------------------------- /dizoo/d4rl/entry/d4rl_cql_main.py: -------------------------------------------------------------------------------- 1 | from ding.entry import serial_pipeline_offline 2 | from ding.config import read_config 3 | from pathlib import Path 4 | 5 | 6 | def train(args): 7 | # launch from anywhere 8 | config = Path(__file__).absolute().parent.parent / 'config' / args.config 9 | config = read_config(str(config)) 10 | config[0].exp_name = config[0].exp_name.replace('0', str(args.seed)) 11 | serial_pipeline_offline(config, seed=args.seed) 12 | 13 | 14 | if __name__ == "__main__": 15 | import argparse 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--seed', '-s', type=int, default=10) 19 | parser.add_argument('--config', '-c', type=str, default='hopper_expert_cql_config.py') 20 | args = parser.parse_args() 21 | train(args) 22 | -------------------------------------------------------------------------------- /dizoo/d4rl/entry/d4rl_edac_main.py: -------------------------------------------------------------------------------- 1 | from ding.entry import 
serial_pipeline_offline 2 | from ding.config import read_config 3 | from pathlib import Path 4 | 5 | 6 | def train(args): 7 | # launch from anywhere 8 | config = Path(__file__).absolute().parent.parent / 'config' / args.config 9 | config = read_config(str(config)) 10 | config[0].exp_name = config[0].exp_name.replace('0', str(args.seed)) 11 | serial_pipeline_offline(config, seed=args.seed) 12 | 13 | 14 | if __name__ == "__main__": 15 | import argparse 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--seed', '-s', type=int, default=0) 19 | parser.add_argument('--config', '-c', type=str, default='halfcheetah_medium_edac_config.py') 20 | args = parser.parse_args() 21 | train(args) 22 | -------------------------------------------------------------------------------- /dizoo/d4rl/entry/d4rl_iql_main.py: -------------------------------------------------------------------------------- 1 | from ding.entry import serial_pipeline_offline 2 | from ding.config import read_config 3 | from pathlib import Path 4 | 5 | 6 | def train(args): 7 | # launch from anywhere 8 | config = Path(__file__).absolute().parent.parent / 'config' / args.config 9 | config = read_config(str(config)) 10 | config[0].exp_name = config[0].exp_name.replace('0', str(args.seed)) 11 | serial_pipeline_offline(config, seed=args.seed) 12 | 13 | 14 | if __name__ == "__main__": 15 | import argparse 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--seed', '-s', type=int, default=10) 19 | parser.add_argument('--config', '-c', type=str, default='halfcheetah_medium_iql_config.py') 20 | args = parser.parse_args() 21 | train(args) 22 | -------------------------------------------------------------------------------- /dizoo/d4rl/entry/d4rl_pd_main.py: -------------------------------------------------------------------------------- 1 | from ding.entry import serial_pipeline_offline 2 | from ding.config import read_config 3 | from pathlib import Path 4 | 5 | 6 | def train(args): 
7 | # launch from anywhere 8 | config = Path(__file__).absolute().parent.parent / 'config' / args.config 9 | config = read_config(str(config)) 10 | config[0].exp_name = config[0].exp_name.replace('0', str(args.seed)) 11 | serial_pipeline_offline(config, seed=args.seed) 12 | 13 | 14 | if __name__ == "__main__": 15 | import argparse 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--seed', '-s', type=int, default=10) 19 | parser.add_argument('--config', '-c', type=str, default='halfcheetah_medium_pd_config.py') 20 | args = parser.parse_args() 21 | train(args) 22 | -------------------------------------------------------------------------------- /dizoo/d4rl/entry/d4rl_td3_bc_main.py: -------------------------------------------------------------------------------- 1 | from ding.entry import serial_pipeline_offline 2 | from ding.config import read_config 3 | from pathlib import Path 4 | 5 | 6 | def train(args): 7 | # launch from anywhere 8 | config = Path(__file__).absolute().parent.parent / 'config' / args.config 9 | config = read_config(str(config)) 10 | config[0].exp_name = config[0].exp_name.replace('0', str(args.seed)) 11 | serial_pipeline_offline(config, seed=args.seed) 12 | 13 | 14 | if __name__ == "__main__": 15 | import argparse 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--seed', '-s', type=int, default=10) 19 | parser.add_argument('--config', '-c', type=str, default='hopper_medium_expert_td3bc_config.py') 20 | args = parser.parse_args() 21 | train(args) 22 | -------------------------------------------------------------------------------- /dizoo/d4rl/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .d4rl_env import D4RLEnv 2 | -------------------------------------------------------------------------------- /dizoo/dmc2gym/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/dmc2gym/__init__.py -------------------------------------------------------------------------------- /dizoo/dmc2gym/dmc2gym_cheetah.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/dmc2gym/dmc2gym_cheetah.png -------------------------------------------------------------------------------- /dizoo/dmc2gym/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .dmc2gym_env import DMC2GymEnv 2 | -------------------------------------------------------------------------------- /dizoo/evogym/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/evogym/__init__.py -------------------------------------------------------------------------------- /dizoo/evogym/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .evogym_env import EvoGymEnv 2 | -------------------------------------------------------------------------------- /dizoo/evogym/evogym.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/evogym/evogym.gif -------------------------------------------------------------------------------- /dizoo/frozen_lake/FrozenLake.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/frozen_lake/FrozenLake.gif -------------------------------------------------------------------------------- /dizoo/frozen_lake/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/frozen_lake/__init__.py -------------------------------------------------------------------------------- /dizoo/frozen_lake/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .frozen_lake_dqn_config import main_config, create_config 2 | -------------------------------------------------------------------------------- /dizoo/frozen_lake/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .frozen_lake_env import FrozenLakeEnv 2 | -------------------------------------------------------------------------------- /dizoo/gfootball/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gfootball/__init__.py -------------------------------------------------------------------------------- /dizoo/gfootball/entry/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gfootball/entry/__init__.py -------------------------------------------------------------------------------- /dizoo/gfootball/envs/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | try: 4 | from .gfootball_env import GfootballEnv 5 | except ImportError: 6 | warnings.warn("not found gfootball env, please install it") 7 | GfootballEnv = None 8 | -------------------------------------------------------------------------------- /dizoo/gfootball/envs/action/gfootball_action_runner.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 
3 | import numpy as np 4 | 5 | from ding.envs.common import EnvElementRunner 6 | from ding.envs.env.base_env import BaseEnv 7 | from .gfootball_action import GfootballRawAction 8 | 9 | 10 | class GfootballRawActionRunner(EnvElementRunner): 11 | 12 | def _init(self, cfg, *args, **kwargs) -> None: 13 | # set self._core and other state variable 14 | self._core = GfootballRawAction(cfg) 15 | 16 | def get(self, engine: BaseEnv) -> np.array: 17 | agent_action = copy.deepcopy(engine.agent_action) 18 | return agent_action 19 | 20 | def reset(self) -> None: 21 | pass 22 | -------------------------------------------------------------------------------- /dizoo/gfootball/envs/obs/gfootball_obs_runner.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | 5 | from ding.envs.common import EnvElementRunner, EnvElement 6 | from ding.envs.env.base_env import BaseEnv 7 | from .gfootball_obs import PlayerObs, MatchObs 8 | from ding.utils import deep_merge_dicts 9 | 10 | 11 | class GfootballObsRunner(EnvElementRunner): 12 | 13 | def _init(self, cfg, *args, **kwargs) -> None: 14 | # set self._core and other state variable 15 | self._obs_match = MatchObs(cfg) 16 | self._obs_player = PlayerObs(cfg) 17 | self._core = self._obs_player # placeholder 18 | 19 | def get(self, engine: BaseEnv) -> dict: 20 | ret = copy.deepcopy(engine._football_obs) 21 | # print(ret, type(ret)) 22 | assert isinstance(ret, dict) 23 | match_obs = self._obs_match._to_agent_processor(ret) 24 | players_obs = self._obs_player._to_agent_processor(ret) 25 | return deep_merge_dicts(match_obs, players_obs) 26 | 27 | def reset(self) -> None: 28 | pass 29 | 30 | # override 31 | @property 32 | def info(self): 33 | return {'match': self._obs_match.info, 'player': self._obs_player.info} 34 | -------------------------------------------------------------------------------- /dizoo/gfootball/envs/reward/gfootball_reward_runner.py: 
-------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | 5 | from ding.envs.common import EnvElementRunner 6 | from ding.envs.env.base_env import BaseEnv 7 | from .gfootball_reward import GfootballReward 8 | 9 | 10 | class GfootballRewardRunner(EnvElementRunner): 11 | 12 | def _init(self, cfg, *args, **kwargs) -> None: 13 | # set self._core and other state variable 14 | self._core = GfootballReward(cfg) 15 | self._cum_reward = 0.0 16 | 17 | def get(self, engine: BaseEnv) -> torch.tensor: 18 | ret = copy.deepcopy(engine._reward_of_action) 19 | self._cum_reward += ret 20 | return self._core._to_agent_processor(ret) 21 | 22 | def reset(self) -> None: 23 | self._cum_reward = 0.0 24 | 25 | @property 26 | def cum_reward(self) -> torch.tensor: 27 | return torch.FloatTensor([self._cum_reward]) 28 | -------------------------------------------------------------------------------- /dizoo/gfootball/gfootball.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gfootball/gfootball.gif -------------------------------------------------------------------------------- /dizoo/gfootball/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gfootball/model/__init__.py -------------------------------------------------------------------------------- /dizoo/gfootball/model/bots/TamakEriFever/config.yaml: -------------------------------------------------------------------------------- 1 | 2 | env_args: 3 | env: 'Football' 4 | source: 'football_ikki' 5 | frames_per_sec: 10 # we cannot change 6 | 7 | frame_skip: 0 8 | limit_steps: 3002 9 | 10 | train_args: 11 | gamma_per_sec: 0.97 12 | lambda_per_sec: 0.4 13 | forward_steps: 64 14 | compress_steps: 16 
15 | entropy_regularization: 1.3e-3 16 | monte_carlo_rate: 1.0 17 | update_episodes: 400 18 | batch_size: 192 19 | minimum_episodes: 3000 20 | maximum_episodes: 30000 21 | num_batchers: 23 22 | eval_rate: 0.1 23 | replay_rate: 0 # 0.1 24 | supervised_weight: 0 # 0.1 25 | record_dir: "records/" 26 | randomized_start_rate: 0.3 27 | randomized_start_max_steps: 400 28 | reward_reset: True 29 | worker: 30 | num_gather: 2 31 | num_process: 6 32 | seed: 1800 33 | restart_epoch: 1679 34 | 35 | entry_args: 36 | remote_host: '' 37 | num_gather: 2 38 | num_process: 6 39 | 40 | eval_args: 41 | remote_host: '' 42 | 43 | -------------------------------------------------------------------------------- /dizoo/gfootball/model/bots/TamakEriFever/readme.md: -------------------------------------------------------------------------------- 1 | This is the 5th place solution of the Kaggle gfootball competition. 2 | 3 | See https://www.kaggle.com/c/google-football/discussion/203412 for details. 4 | 5 | Thanks to [kyazuki](https://www.kaggle.com/kyazuki) and [@yuricat](https://www.kaggle.com/yuricat), who generously shared their code. -------------------------------------------------------------------------------- /dizoo/gfootball/model/bots/TamakEriFever/view_test.py: -------------------------------------------------------------------------------- 1 | # Set up the Environment. 
2 | 3 | import time 4 | 5 | from kaggle_environments import make 6 | 7 | # opponent = "football/idle.py" 8 | # opponent = "football/rulebaseC.py" 9 | opponent = "builtin_ai" 10 | 11 | video_title = "chain" 12 | video_path = "videos/" + video_title + "_" + opponent.split("/")[-1].replace(".py", 13 | "") + str(int(time.time())) + ".webm" 14 | 15 | env = make( 16 | "football", 17 | configuration={ 18 | "save_video": True, 19 | "scenario_name": "11_vs_11_kaggle", 20 | "running_in_notebook": False 21 | }, 22 | info={"LiveVideoPath": video_path}, 23 | debug=True 24 | ) 25 | output = env.run(["submission.py", opponent])[-1] 26 | 27 | scores = [output[i]['observation']['players_raw'][0]['score'][0] for i in range(2)] 28 | print('Left player: score = %s, status = %s, info = %s' % (scores[0], output[0]['status'], output[0]['info'])) 29 | print('Right player: score = %s, status = %s, info = %s' % (scores[1], output[1]['status'], output[1]['info'])) 30 | 31 | env.render(mode="human", width=800, height=600) 32 | -------------------------------------------------------------------------------- /dizoo/gfootball/model/bots/__init__.py: -------------------------------------------------------------------------------- 1 | from .kaggle_5th_place_model import FootballKaggle5thPlaceModel 2 | from .rule_based_bot_model import FootballRuleBaseModel 3 | -------------------------------------------------------------------------------- /dizoo/gfootball/policy/__init__.py: -------------------------------------------------------------------------------- 1 | from .ppo_lstm import PPOPolicy, PPOCommandModePolicy 2 | -------------------------------------------------------------------------------- /dizoo/gobigger_overview.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gobigger_overview.gif 
-------------------------------------------------------------------------------- /dizoo/gym_anytrading/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_anytrading/__init__.py -------------------------------------------------------------------------------- /dizoo/gym_anytrading/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .stocks_dqn_config import stocks_dqn_config, stocks_dqn_create_config 2 | -------------------------------------------------------------------------------- /dizoo/gym_anytrading/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .trading_env import TradingEnv, Actions, Positions 2 | from .stocks_env import StocksEnv 3 | -------------------------------------------------------------------------------- /dizoo/gym_anytrading/envs/data/README.md: -------------------------------------------------------------------------------- 1 | You can put stock data here. 2 | Your data file needs to be named like "STOCKS_GOOGL.csv", i.e. its name must end with the ".csv" suffix. 
-------------------------------------------------------------------------------- /dizoo/gym_anytrading/envs/position.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_anytrading/envs/position.png -------------------------------------------------------------------------------- /dizoo/gym_anytrading/envs/profit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_anytrading/envs/profit.png -------------------------------------------------------------------------------- /dizoo/gym_anytrading/envs/statemachine.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_anytrading/envs/statemachine.png -------------------------------------------------------------------------------- /dizoo/gym_anytrading/worker/__init__.py: -------------------------------------------------------------------------------- 1 | import imp 2 | from .trading_serial_evaluator import * 3 | -------------------------------------------------------------------------------- /dizoo/gym_hybrid/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_hybrid/__init__.py -------------------------------------------------------------------------------- /dizoo/gym_hybrid/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_hybrid/config/__init__.py 
-------------------------------------------------------------------------------- /dizoo/gym_hybrid/entry/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_hybrid/entry/__init__.py -------------------------------------------------------------------------------- /dizoo/gym_hybrid/envs/README.md: -------------------------------------------------------------------------------- 1 | # Modified gym-hybrid 2 | 3 | The gym-hybrid directory is modified from https://github.com/thomashirtz/gym-hybrid. 4 | We add the HardMove environment additionally. (Please refer to https://arxiv.org/abs/2109.05490 Section 5.1 for details about HardMove env.) 5 | 6 | Specifically, the modified gym-hybrid contains the following three types of environments: 7 | 8 | - Moving-v0 9 | - Sliding-v0 10 | - HardMove-v0 11 | 12 | ### Install Guide 13 | 14 | ```bash 15 | cd DI-engine/dizoo/gym_hybrid/envs/gym-hybrid 16 | pip install -e . 
17 | ``` 18 | 19 | ## Acknowledgement 20 | 21 | https://github.com/thomashirtz/gym-hybrid -------------------------------------------------------------------------------- /dizoo/gym_hybrid/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .gym_hybrid_env import GymHybridEnv 2 | -------------------------------------------------------------------------------- /dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | from gym_hybrid.environments import MovingEnv 3 | from gym_hybrid.environments import SlidingEnv 4 | from gym_hybrid.environments import HardMoveEnv 5 | 6 | register( 7 | id='Moving-v0', 8 | entry_point='gym_hybrid:MovingEnv', 9 | ) 10 | register( 11 | id='Sliding-v0', 12 | entry_point='gym_hybrid:SlidingEnv', 13 | ) 14 | register( 15 | id='HardMove-v0', 16 | entry_point='gym_hybrid:HardMoveEnv', 17 | ) 18 | -------------------------------------------------------------------------------- /dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/bg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/bg.jpg -------------------------------------------------------------------------------- /dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/target.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/target.png -------------------------------------------------------------------------------- /dizoo/gym_hybrid/envs/gym-hybrid/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | 
setup( 4 | name='gym_hybrid', 5 | version='0.0.2', # original gym_hybrid version='0.0.1' 6 | packages=['gym_hybrid'], 7 | install_requires=['gym', 'numpy'], 8 | ) 9 | -------------------------------------------------------------------------------- /dizoo/gym_hybrid/envs/gym-hybrid/tests/hardmove.py: -------------------------------------------------------------------------------- 1 | import time 2 | import gym 3 | import gym_hybrid 4 | 5 | if __name__ == '__main__': 6 | env = gym.make('HardMove-v0') 7 | env.reset() 8 | 9 | ACTION_SPACE = env.action_space[0].n 10 | PARAMETERS_SPACE = env.action_space[1].shape[0] 11 | OBSERVATION_SPACE = env.observation_space.shape[0] 12 | 13 | done = False 14 | while not done: 15 | state, reward, done, info = env.step(env.action_space.sample()) 16 | print(f'State: {state} Reward: {reward} Done: {done}') 17 | time.sleep(0.1) 18 | -------------------------------------------------------------------------------- /dizoo/gym_hybrid/envs/gym-hybrid/tests/moving.py: -------------------------------------------------------------------------------- 1 | import time 2 | import gym 3 | import gym_hybrid 4 | 5 | if __name__ == '__main__': 6 | env = gym.make('Moving-v0') 7 | env.reset() 8 | 9 | ACTION_SPACE = env.action_space[0].n 10 | PARAMETERS_SPACE = env.action_space[1].shape[0] 11 | OBSERVATION_SPACE = env.observation_space.shape[0] 12 | 13 | done = False 14 | while not done: 15 | state, reward, done, info = env.step(env.action_space.sample()) 16 | print(f'State: {state} Reward: {reward} Done: {done}') 17 | time.sleep(0.1) 18 | -------------------------------------------------------------------------------- /dizoo/gym_hybrid/envs/gym-hybrid/tests/record.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import gym_hybrid 3 | 4 | if __name__ == '__main__': 5 | env = gym.make('Sliding-v0') 6 | env = gym.wrappers.Monitor(env, "./video", force=True) 7 | env.metadata["render.modes"] = ["human", 
"rgb_array"] 8 | env.reset() 9 | 10 | done = False 11 | while not done: 12 | _, _, done, _ = env.step(env.action_space.sample()) 13 | 14 | env.close() 15 | -------------------------------------------------------------------------------- /dizoo/gym_hybrid/envs/gym-hybrid/tests/render.py: -------------------------------------------------------------------------------- 1 | import time 2 | import gym 3 | import gym_hybrid 4 | 5 | if __name__ == '__main__': 6 | env = gym.make('Sliding-v0') 7 | env.reset() 8 | 9 | done = False 10 | while not done: 11 | _, _, done, _ = env.step(env.action_space.sample()) 12 | env.render() 13 | time.sleep(0.1) 14 | 15 | time.sleep(1) 16 | env.close() 17 | -------------------------------------------------------------------------------- /dizoo/gym_hybrid/envs/gym-hybrid/tests/sliding.py: -------------------------------------------------------------------------------- 1 | import time 2 | import gym 3 | import gym_hybrid 4 | 5 | if __name__ == '__main__': 6 | env = gym.make('Sliding-v0') 7 | env.reset() 8 | 9 | ACTION_SPACE = env.action_space[0].n 10 | PARAMETERS_SPACE = env.action_space[1].shape[0] 11 | OBSERVATION_SPACE = env.observation_space.shape[0] 12 | 13 | done = False 14 | while not done: 15 | state, reward, done, info = env.step(env.action_space.sample()) 16 | print(f'State: {state} Reward: {reward} Done: {done}') 17 | time.sleep(0.1) 18 | -------------------------------------------------------------------------------- /dizoo/gym_hybrid/moving_v0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_hybrid/moving_v0.gif -------------------------------------------------------------------------------- /dizoo/gym_pybullet_drones/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_pybullet_drones/__init__.py -------------------------------------------------------------------------------- /dizoo/gym_pybullet_drones/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .gym_pybullet_drones_env import GymPybulletDronesEnv 2 | -------------------------------------------------------------------------------- /dizoo/gym_pybullet_drones/envs/test_ding_env.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from easydict import EasyDict 3 | import gym_pybullet_drones 4 | 5 | from ding.envs import BaseEnv, BaseEnvTimestep 6 | from dizoo.gym_pybullet_drones.envs.gym_pybullet_drones_env import GymPybulletDronesEnv 7 | 8 | 9 | @pytest.mark.envtest 10 | class TestGymPybulletDronesEnv: 11 | 12 | def test_naive(self): 13 | cfg = {"env_id": "takeoff-aviary-v0"} 14 | cfg = EasyDict(cfg) 15 | env = GymPybulletDronesEnv(cfg) 16 | 17 | env.reset() 18 | done = False 19 | while not done: 20 | action = env.action_space.sample() 21 | assert action.shape[0] == 4 22 | 23 | for i in range(action.shape[0]): 24 | assert action[i] >= env.action_space.low[i] and action[i] <= env.action_space.high[i] 25 | 26 | obs, reward, done, info = env.step(action) 27 | 28 | assert obs.shape[0] == 12 29 | for i in range(obs.shape[0]): 30 | assert obs[i] >= env.observation_space.low[i] and obs[i] <= env.observation_space.high[i] 31 | 32 | assert reward >= env.reward_space.low and reward <= env.reward_space.high 33 | -------------------------------------------------------------------------------- /dizoo/gym_pybullet_drones/envs/test_ori_env.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import gym 3 | import numpy as np 4 | 5 | import gym_pybullet_drones 6 | 7 | 8 | @pytest.mark.envtest 9 | class 
TestGymPybulletDronesOriEnv: 10 | 11 | def test_naive(self): 12 | env = gym.make("takeoff-aviary-v0") 13 | env.reset() 14 | done = False 15 | while not done: 16 | action = env.action_space.sample() 17 | assert action.shape[0] == 4 18 | 19 | for i in range(action.shape[0]): 20 | assert action[i] >= env.action_space.low[i] and action[i] <= env.action_space.high[i] 21 | 22 | obs, reward, done, info = env.step(action) 23 | assert obs.shape[0] == 12 24 | for i in range(obs.shape[0]): 25 | assert obs[i] >= env.observation_space.low[i] and obs[i] <= env.observation_space.high[i] 26 | 27 | assert reward >= env.reward_space.low and reward <= env.reward_space.high 28 | -------------------------------------------------------------------------------- /dizoo/gym_pybullet_drones/gym_pybullet_drones.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_pybullet_drones/gym_pybullet_drones.gif -------------------------------------------------------------------------------- /dizoo/gym_soccer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_soccer/__init__.py -------------------------------------------------------------------------------- /dizoo/gym_soccer/envs/README.md: -------------------------------------------------------------------------------- 1 | # How to replay a log 2 | 3 | 1. Set the log path to store episode logs by the following command: 4 | 5 | `env.enable_save_replay('./game_log')` 6 | 7 | 2. After running the game, you can see some log files in the game_log directory. 8 | 9 | 3. 
Execute the following command to replay the log file (*.rcg) 10 | 11 | ` env.replay_log("game_log/20211019011053-base_left_0-vs-base_right_0.rcg")` -------------------------------------------------------------------------------- /dizoo/gym_soccer/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_soccer/envs/__init__.py -------------------------------------------------------------------------------- /dizoo/gym_soccer/half_offensive.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_soccer/half_offensive.gif -------------------------------------------------------------------------------- /dizoo/image_classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/image_classification/__init__.py -------------------------------------------------------------------------------- /dizoo/image_classification/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import ImageNetDataset 2 | from .sampler import DistributedSampler 3 | -------------------------------------------------------------------------------- /dizoo/image_classification/imagenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/image_classification/imagenet.png -------------------------------------------------------------------------------- /dizoo/image_classification/policy/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.policy import ImageClassificationPolicy 2 | -------------------------------------------------------------------------------- /dizoo/ising_env/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/ising_env/__init__.py -------------------------------------------------------------------------------- /dizoo/ising_env/entry/ising_mfq_eval.py: -------------------------------------------------------------------------------- 1 | from dizoo.ising_env.config.ising_mfq_config import main_config, create_config 2 | from ding.entry import eval 3 | 4 | 5 | def main(): 6 | main_config.env.collector_env_num = 1 7 | main_config.env.evaluator_env_num = 1 8 | main_config.env.n_evaluator_episode = 1 9 | ckpt_path = './ckpt_best.pth.tar' 10 | replay_path = './replay_videos' 11 | eval((main_config, create_config), seed=1, load_path=ckpt_path, replay_path=replay_path) 12 | 13 | 14 | if __name__ == "__main__": 15 | main() 16 | -------------------------------------------------------------------------------- /dizoo/ising_env/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .ising_model_env import IsingModelEnv 2 | -------------------------------------------------------------------------------- /dizoo/ising_env/envs/ising_model/__init__.py: -------------------------------------------------------------------------------- 1 | import imp 2 | import os.path as osp 3 | 4 | 5 | def load(name): 6 | pathname = osp.join(osp.dirname(__file__), name) 7 | return imp.load_source('', pathname) 8 | -------------------------------------------------------------------------------- /dizoo/ising_env/envs/ising_model/multiagent/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/ising_env/envs/ising_model/multiagent/__init__.py -------------------------------------------------------------------------------- /dizoo/ising_env/ising_env.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/ising_env/ising_env.gif -------------------------------------------------------------------------------- /dizoo/league_demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/league_demo/__init__.py -------------------------------------------------------------------------------- /dizoo/league_demo/league_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/league_demo/league_demo.png -------------------------------------------------------------------------------- /dizoo/mario/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/mario/__init__.py -------------------------------------------------------------------------------- /dizoo/mario/mario.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/mario/mario.gif -------------------------------------------------------------------------------- /dizoo/maze/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register(id='Maze', 
entry_point='dizoo.maze.envs:Maze') 4 | -------------------------------------------------------------------------------- /dizoo/maze/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .maze_env import Maze 2 | -------------------------------------------------------------------------------- /dizoo/maze/envs/test_maze_env.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | import numpy as np 4 | from dizoo.maze.envs.maze_env import Maze 5 | from easydict import EasyDict 6 | import copy 7 | 8 | 9 | @pytest.mark.envtest 10 | class TestMazeEnv: 11 | 12 | def test_maze(self): 13 | env = Maze(EasyDict({'size': 16})) 14 | env.seed(314) 15 | assert env._seed == 314 16 | obs = env.reset() 17 | assert obs.shape == (16, 16, 3) 18 | min_val, max_val = 0, 3 19 | for i in range(100): 20 | random_action = np.random.randint(min_val, max_val, size=(1, )) 21 | timestep = env.step(random_action) 22 | print(timestep) 23 | print(timestep.obs.max()) 24 | assert isinstance(timestep.obs, np.ndarray) 25 | assert isinstance(timestep.done, bool) 26 | if timestep.done: 27 | env.reset() 28 | env.close() 29 | -------------------------------------------------------------------------------- /dizoo/metadrive/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/metadrive/__init__.py -------------------------------------------------------------------------------- /dizoo/metadrive/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/metadrive/config/__init__.py -------------------------------------------------------------------------------- /dizoo/metadrive/env/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/metadrive/env/__init__.py -------------------------------------------------------------------------------- /dizoo/metadrive/metadrive_env.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/metadrive/metadrive_env.gif -------------------------------------------------------------------------------- /dizoo/minigrid/__init__.py: -------------------------------------------------------------------------------- 1 | from gymnasium.envs.registration import register 2 | 3 | register(id='MiniGrid-AKTDT-7x7-1-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_7x7_1') 4 | 5 | register(id='MiniGrid-AKTDT-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure') 6 | 7 | register(id='MiniGrid-AKTDT-13x13-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_13x13') 8 | 9 | register(id='MiniGrid-AKTDT-13x13-1-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_13x13_1') 10 | 11 | register(id='MiniGrid-AKTDT-19x19-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_19x19') 12 | 13 | register(id='MiniGrid-AKTDT-19x19-3-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_19x19_3') 14 | 15 | register(id='MiniGrid-NoisyTV-v0', entry_point='dizoo.minigrid.envs:NoisyTVEnv') 16 | -------------------------------------------------------------------------------- /dizoo/minigrid/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .minigrid_env import MiniGridEnv 2 | from dizoo.minigrid.envs.app_key_to_door_treasure import AppleKeyToDoorTreasure, AppleKeyToDoorTreasure_13x13, AppleKeyToDoorTreasure_19x19, AppleKeyToDoorTreasure_13x13_1, AppleKeyToDoorTreasure_19x19_3, 
AppleKeyToDoorTreasure_7x7_1 3 | from dizoo.minigrid.envs.noisy_tv import NoisyTVEnv 4 | -------------------------------------------------------------------------------- /dizoo/minigrid/minigrid.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/minigrid/minigrid.gif -------------------------------------------------------------------------------- /dizoo/mujoco/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/mujoco/__init__.py -------------------------------------------------------------------------------- /dizoo/mujoco/addition/install_mesa.sh: -------------------------------------------------------------------------------- 1 | mkdir -p ~/rpm 2 | yumdownloader --destdir ~/rpm --resolve mesa-libOSMesa.x86_64 mesa-libOSMesa-devel.x86_64 patchelf.x86_64 3 | cd ~/rpm 4 | for rpm in `ls`; do rpm2cpio $rpm | cpio -id ; done 5 | -------------------------------------------------------------------------------- /dizoo/mujoco/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/mujoco/config/__init__.py -------------------------------------------------------------------------------- /dizoo/mujoco/entry/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/mujoco/entry/__init__.py -------------------------------------------------------------------------------- /dizoo/mujoco/entry/mujoco_cql_generation_main.py: -------------------------------------------------------------------------------- 1 | from 
dizoo.mujoco.config.hopper_sac_data_generation_config import main_config, create_config 2 | from ding.entry import collect_demo_data, eval 3 | import torch 4 | import copy 5 | 6 | 7 | def eval_ckpt(args): 8 | config = copy.deepcopy([main_config, create_config]) 9 | eval(config, seed=args.seed, load_path=main_config.policy.learn.learner.hook.load_ckpt_before_run) 10 | 11 | 12 | def generate(args): 13 | config = copy.deepcopy([main_config, create_config]) 14 | state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu') 15 | collect_demo_data( 16 | config, 17 | collect_count=main_config.policy.other.replay_buffer.replay_buffer_size, 18 | seed=args.seed, 19 | expert_data_path=main_config.policy.collect.save_path, 20 | state_dict=state_dict 21 | ) 22 | 23 | 24 | if __name__ == "__main__": 25 | import argparse 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--seed', '-s', type=int, default=0) 29 | args = parser.parse_args() 30 | 31 | eval_ckpt(args) 32 | generate(args) 33 | -------------------------------------------------------------------------------- /dizoo/mujoco/entry/mujoco_cql_main.py: -------------------------------------------------------------------------------- 1 | from dizoo.mujoco.config.hopper_cql_config import main_config, create_config 2 | from ding.entry import serial_pipeline_offline 3 | 4 | 5 | def train(args): 6 | config = [main_config, create_config] 7 | serial_pipeline_offline(config, seed=args.seed) 8 | 9 | 10 | if __name__ == "__main__": 11 | import argparse 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--seed', '-s', type=int, default=10) 15 | args = parser.parse_args() 16 | 17 | train(args) 18 | -------------------------------------------------------------------------------- /dizoo/mujoco/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .mujoco_env import MujocoEnv 2 | from .mujoco_disc_env import MujocoDiscEnv 3 
| -------------------------------------------------------------------------------- /dizoo/mujoco/envs/test/test_mujoco_gym_env.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import gym 3 | 4 | 5 | @pytest.mark.envtest 6 | def test_shapes(): 7 | from dizoo.mujoco.envs import mujoco_gym_env 8 | ant = gym.make('AntTruncatedObs-v2') 9 | assert ant.observation_space.shape == (27, ) 10 | assert ant.action_space.shape == (8, ) 11 | humanoid = gym.make('HumanoidTruncatedObs-v2') 12 | assert humanoid.observation_space.shape == (45, ) 13 | assert humanoid.action_space.shape == (17, ) 14 | -------------------------------------------------------------------------------- /dizoo/mujoco/mujoco.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/mujoco/mujoco.gif -------------------------------------------------------------------------------- /dizoo/multiagent_mujoco/README.md: -------------------------------------------------------------------------------- 1 | ## Multi Agent Mujoco Env 2 | 3 | Multi Agent Mujoco is an environment for Continuous Multi-Agent Robotic Control, based on OpenAI's Mujoco Gym environments. 
4 | 5 | The environment is described in the paper [Deep Multi-Agent Reinforcement Learning for Decentralized Continuous Cooperative Control](https://arxiv.org/abs/2003.06709) by Christian Schroeder de Witt, Bei Peng, Pierre-Alexandre Kamienny, Philip Torr, Wendelin Böhmer and Shimon Whiteson, Torr Vision Group and Whiteson Research Lab, University of Oxford, 2020 6 | 7 | You can find more details in [Multi-Agent Mujoco Environment](https://github.com/schroederdewitt/multiagent_mujoco) 8 | -------------------------------------------------------------------------------- /dizoo/multiagent_mujoco/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/multiagent_mujoco/__init__.py -------------------------------------------------------------------------------- /dizoo/multiagent_mujoco/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .mujoco_multi import MujocoMulti 2 | from .coupled_half_cheetah import CoupledHalfCheetah 3 | from .manyagent_swimmer import ManyAgentSwimmerEnv 4 | from .manyagent_ant import ManyAgentAntEnv 5 | -------------------------------------------------------------------------------- /dizoo/multiagent_mujoco/envs/assets/.gitignore: -------------------------------------------------------------------------------- 1 | *.auto.xml 2 | -------------------------------------------------------------------------------- /dizoo/multiagent_mujoco/envs/assets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/multiagent_mujoco/envs/assets/__init__.py -------------------------------------------------------------------------------- /dizoo/overcooked/README.md: 
-------------------------------------------------------------------------------- 1 | This is the overcooked-ai environment compatible with DI-engine. 2 | 3 | The original code is based on [Overcooked-AI](https://github.com/HumanCompatibleAI/overcooked_ai), which is a benchmark environment for fully cooperative human-AI task performance, based on the wildly popular video game [Overcooked](http://www.ghosttowngames.com/overcooked/). -------------------------------------------------------------------------------- /dizoo/overcooked/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/overcooked/__init__.py -------------------------------------------------------------------------------- /dizoo/overcooked/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .overcooked_demo_ppo_config import overcooked_demo_ppo_config 2 | -------------------------------------------------------------------------------- /dizoo/overcooked/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .overcooked_env import OvercookEnv, OvercookGameEnv 2 | -------------------------------------------------------------------------------- /dizoo/overcooked/overcooked.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/overcooked/overcooked.gif -------------------------------------------------------------------------------- /dizoo/petting_zoo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/petting_zoo/__init__.py 
-------------------------------------------------------------------------------- /dizoo/petting_zoo/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .ptz_simple_spread_atoc_config import ptz_simple_spread_atoc_config, ptz_simple_spread_atoc_create_config 2 | from .ptz_simple_spread_collaq_config import ptz_simple_spread_collaq_config, ptz_simple_spread_collaq_create_config 3 | from .ptz_simple_spread_coma_config import ptz_simple_spread_coma_config, ptz_simple_spread_coma_create_config 4 | from .ptz_simple_spread_mappo_config import ptz_simple_spread_mappo_config, ptz_simple_spread_mappo_create_config 5 | from .ptz_simple_spread_qmix_config import ptz_simple_spread_qmix_config, ptz_simple_spread_qmix_create_config 6 | from .ptz_simple_spread_qtran_config import ptz_simple_spread_qtran_config, ptz_simple_spread_qtran_create_config 7 | from .ptz_simple_spread_vdn_config import ptz_simple_spread_vdn_config, ptz_simple_spread_vdn_create_config 8 | from .ptz_simple_spread_wqmix_config import ptz_simple_spread_wqmix_config, ptz_simple_spread_wqmix_create_config 9 | from .ptz_simple_spread_madqn_config import ptz_simple_spread_madqn_config, ptz_simple_spread_madqn_create_config # noqa 10 | -------------------------------------------------------------------------------- /dizoo/petting_zoo/entry/ptz_simple_spread_eval.py: -------------------------------------------------------------------------------- 1 | from dizoo.petting_zoo.config.ptz_simple_spread_mappo_config import main_config, create_config 2 | from ding.entry import eval 3 | 4 | 5 | def main(): 6 | ckpt_path = './ckpt_best.pth.tar' 7 | replay_path = './replay_videos' 8 | eval((main_config, create_config), seed=0, load_path=ckpt_path, replay_path=replay_path) 9 | 10 | 11 | if __name__ == "__main__": 12 | main() 13 | -------------------------------------------------------------------------------- /dizoo/petting_zoo/envs/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/petting_zoo/envs/__init__.py -------------------------------------------------------------------------------- /dizoo/petting_zoo/petting_zoo_mpe_simple_spread.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/petting_zoo/petting_zoo_mpe_simple_spread.gif -------------------------------------------------------------------------------- /dizoo/pomdp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/pomdp/__init__.py -------------------------------------------------------------------------------- /dizoo/pomdp/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .atari_env import PomdpAtariEnv 2 | -------------------------------------------------------------------------------- /dizoo/pomdp/envs/test_atari_env.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import gym 3 | import numpy as np 4 | from easydict import EasyDict 5 | from dizoo.pomdp.envs import PomdpAtariEnv 6 | 7 | 8 | @pytest.mark.envtest 9 | def test_env(): 10 | cfg = { 11 | 'env_id': 'Pong-ramNoFrameskip-v4', 12 | 'frame_stack': 4, 13 | 'is_train': True, 14 | 'warp_frame': False, 15 | 'clip_reward': False, 16 | 'use_ram': True, 17 | 'render': False, 18 | 'pomdp': dict(noise_scale=0.001, zero_p=0.1, reward_noise=0.01, duplicate_p=0.2) 19 | } 20 | 21 | cfg = EasyDict(cfg) 22 | pong_env = PomdpAtariEnv(cfg) 23 | pong_env.seed(0) 24 | obs = pong_env.reset() 25 | act_dim = pong_env.info().act_space.shape[0] 26 | while True: 27 | random_action 
= np.random.choice(range(act_dim), size=(1, )) 28 | timestep = pong_env.step(random_action) 29 | assert timestep.obs.shape == (512, ) 30 | assert timestep.reward.shape == (1, ) 31 | # assert isinstance(timestep, tuple) 32 | if timestep.done: 33 | assert 'eval_episode_return' in timestep.info, timestep.info 34 | break 35 | pong_env.close() 36 | -------------------------------------------------------------------------------- /dizoo/procgen/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/procgen/__init__.py -------------------------------------------------------------------------------- /dizoo/procgen/coinrun.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/procgen/coinrun.gif -------------------------------------------------------------------------------- /dizoo/procgen/coinrun.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/procgen/coinrun.png -------------------------------------------------------------------------------- /dizoo/procgen/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .coinrun_dqn_config import main_config, create_config 2 | from .coinrun_ppo_config import main_config, create_config 3 | -------------------------------------------------------------------------------- /dizoo/procgen/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .procgen_env import ProcgenEnv 2 | -------------------------------------------------------------------------------- /dizoo/procgen/envs/test_coinrun_env.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | from easydict import EasyDict 4 | from dizoo.procgen.envs import ProcgenEnv 5 | 6 | 7 | @pytest.mark.envtest 8 | class TestProcgenEnv: 9 | 10 | def test_naive(self): 11 | env = ProcgenEnv(EasyDict({})) 12 | env.seed(314) 13 | assert env._seed == 314 14 | obs = env.reset() 15 | assert obs.shape == (3, 64, 64) 16 | for i in range(10): 17 | random_action = np.tanh(np.random.random(1)) 18 | timestep = env.step(random_action) 19 | assert timestep.obs.shape == (3, 64, 64) 20 | assert timestep.reward.shape == (1, ) 21 | assert timestep.reward >= env.info().rew_space.value['min'] 22 | assert timestep.reward <= env.info().rew_space.value['max'] 23 | # assert isinstance(timestep, tuple) 24 | print(env.info()) 25 | env.close() 26 | -------------------------------------------------------------------------------- /dizoo/procgen/maze.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/procgen/maze.gif -------------------------------------------------------------------------------- /dizoo/procgen/maze.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/procgen/maze.png -------------------------------------------------------------------------------- /dizoo/pybullet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/pybullet/__init__.py -------------------------------------------------------------------------------- /dizoo/pybullet/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .pybullet_env 
import PybulletEnv 2 | -------------------------------------------------------------------------------- /dizoo/pybullet/pybullet.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/pybullet/pybullet.gif -------------------------------------------------------------------------------- /dizoo/rocket/README.md: -------------------------------------------------------------------------------- 1 | # Install 2 | 3 | ```shell 4 | pip install git+https://github.com/nighood/rocket-recycling@master#egg=rocket_recycling 5 | ``` 6 | 7 | # Check Install 8 | ```shell 9 | pytest -sv test_rocket_env.py 10 | ``` 11 | -------------------------------------------------------------------------------- /dizoo/rocket/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/rocket/__init__.py -------------------------------------------------------------------------------- /dizoo/rocket/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/rocket/config/__init__.py -------------------------------------------------------------------------------- /dizoo/rocket/entry/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/rocket/entry/__init__.py -------------------------------------------------------------------------------- /dizoo/rocket/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .rocket_env import RocketEnv 2 | 
-------------------------------------------------------------------------------- /dizoo/slime_volley/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/slime_volley/__init__.py -------------------------------------------------------------------------------- /dizoo/slime_volley/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .slime_volley_env import SlimeVolleyEnv 2 | -------------------------------------------------------------------------------- /dizoo/slime_volley/slime_volley.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/slime_volley/slime_volley.gif -------------------------------------------------------------------------------- /dizoo/smac/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/__init__.py -------------------------------------------------------------------------------- /dizoo/smac/envs/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from .fake_smac_env import FakeSMACEnv 4 | try: 5 | from .smac_env import SMACEnv 6 | except ImportError: 7 | warnings.warn("not found pysc2 env, please install it") 8 | SMACEnv = None 9 | -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/README.md: -------------------------------------------------------------------------------- 1 | # Notes on Two Player Maps 2 | 3 | Before starting, you need to do the following things: 4 | 5 | 1. 
copy the maps in `maps/SMAC_Maps_two_player/*.SC2Map` to the directory `StarCraft II/Maps/SMAC_Maps_two_player/`. 6 | 2. copy the maps in `maps/SMAC_Maps/*.SC2Map` to the directory `StarCraft II/Maps/SMAC_Maps/`. 7 | 8 | A convenient bash script is: 9 | 10 | ```bash 11 | # In linux 12 | cp -r SMAC_Maps_two_player/ ~/StarCraftII/Maps/ 13 | cp -r SMAC_Maps/ ~/StarCraftII/Maps/ 14 | ``` 15 | 16 | -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/10m_vs_11m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/10m_vs_11m.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/1c3s5z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/1c3s5z.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/25m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/25m.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/27m_vs_30m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/27m_vs_30m.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/2c_vs_64zg.SC2Map: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/2c_vs_64zg.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/2m_vs_1z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/2m_vs_1z.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/2s3z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/2s3z.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/2s_vs_1sc.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/2s_vs_1sc.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/3m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/3m.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/3s5z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/3s5z.SC2Map 
-------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/3s5z_vs_3s6z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/3s5z_vs_3s6z.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/3s_vs_3z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/3s_vs_3z.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/3s_vs_4z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/3s_vs_4z.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/3s_vs_5z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/3s_vs_5z.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/5m_vs_6m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/5m_vs_6m.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/6h_vs_8z.SC2Map: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/6h_vs_8z.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/8m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/8m.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/8m_vs_9m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/8m_vs_9m.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/MMM.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/MMM.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/MMM2.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/MMM2.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/__init__.py -------------------------------------------------------------------------------- 
/dizoo/smac/envs/maps/SMAC_Maps/bane_vs_bane.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/bane_vs_bane.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/corridor.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/corridor.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/infestor_viper.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/infestor_viper.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps/so_many_baneling.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/so_many_baneling.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps_two_player/3m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps_two_player/3m.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps_two_player/3s5z.SC2Map: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps_two_player/3s5z.SC2Map -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/SMAC_Maps_two_player/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps_two_player/__init__.py -------------------------------------------------------------------------------- /dizoo/smac/envs/maps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/__init__.py -------------------------------------------------------------------------------- /dizoo/smac/smac.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/smac.gif -------------------------------------------------------------------------------- /dizoo/sokoban/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/sokoban/__init__.py -------------------------------------------------------------------------------- /dizoo/sokoban/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .sokoban_env import SokobanEnv 2 | -------------------------------------------------------------------------------- /dizoo/sokoban/envs/test_sokoban_env.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | import pytest 3 | import numpy as np 4 | from 
@pytest.mark.envtest
class TestSokoban:
    """Smoke test for ``SokobanEnv`` created with ``env_id='Sokoban-v0'``."""

    def test_sokoban(self):
        """Take 100 random steps and check the type/shape of every timestep.

        The env is reset whenever an episode ends, and is always closed, even
        when an assertion fails part-way through the rollout.
        """
        env = SokobanEnv(EasyDict({'env_id': 'Sokoban-v0'}))
        try:
            env.reset()
            for _ in range(100):
                # Random integer action in [0, 8), passed as a 0-d array.
                action = np.random.randint(8)
                timestep = env.step(np.array(action))
                print(timestep)
                print(timestep.obs.max())
                assert isinstance(timestep.obs, np.ndarray)
                assert isinstance(timestep.done, bool)
                assert timestep.obs.shape == (160, 160, 3)
                print(timestep.info)
                assert timestep.reward.shape == (1, )
                if timestep.done:
                    # Start a new episode so the remaining steps stay valid.
                    env.reset()
        finally:
            env.close()
@pytest.mark.envtest
class TestTabMWP:
    """Smoke test for the ``TabMWP`` environment (construct, seed, close)."""

    def test_tabmwp(self):
        """Build the env from a small config, seed it, and always close it."""
        config = dict(
            cand_number=20,
            train_number=100,
            engine='text-davinci-002',
            temperature=0.,
            max_tokens=512,
            top_p=1.,
            frequency_penalty=0.,
            presence_penalty=0.,
            option_inds=["A", "B", "C", "D", "E", "F"],
            # NOTE(review): empty key -- presumably no real API request is
            # issued by construct/seed/close; confirm against TabMWP.
            api_key='',
        )
        config = EasyDict(config)
        env = TabMWP(config)
        try:
            env.seed(0)
        finally:
            # Close the env even if seeding raises.
            env.close()
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/tabmwp/tabmwp.jpeg -------------------------------------------------------------------------------- /dizoo/taxi/Taxi-v3_episode_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/taxi/Taxi-v3_episode_0.gif -------------------------------------------------------------------------------- /dizoo/taxi/__init__.py: -------------------------------------------------------------------------------- 1 | from .envs import * 2 | -------------------------------------------------------------------------------- /dizoo/taxi/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .taxi_dqn_config import main_config, create_config 2 | -------------------------------------------------------------------------------- /dizoo/taxi/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .taxi_env import TaxiEnv 2 | -------------------------------------------------------------------------------- /docker/Dockerfile.rpc: -------------------------------------------------------------------------------- 1 | FROM snsao/pytorch:tensorpipe-fix as base 2 | 3 | WORKDIR /ding 4 | 5 | RUN apt update \ 6 | && apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git vim gcc \g++ make wget locales dnsutils -y \ 7 | && apt clean \ 8 | && rm -rf /var/cache/apt/* \ 9 | && sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen \ 10 | && locale-gen 11 | 12 | ENV LANG en_US.UTF-8 13 | ENV LANGUAGE en_US:UTF-8 14 | ENV LC_ALL en_US.UTF-8 15 | 16 | ADD setup.py setup.py 17 | ADD dizoo dizoo 18 | ADD ding ding 19 | ADD README.md README.md 20 | 21 | RUN python3 -m pip install 
#!/usr/bin/env bash
# Format Python sources in place with YAPF, using the repository's .style.yapf.
#
# Usage (run from the repository root, where .style.yapf lives):
#   bash format.sh [PATH] [--test]
#
#   PATH    File or directory to format recursively (defaults to ".").
#   --test  CI only: after formatting, fail if the working tree has changes,
#           i.e. the committed code was not already formatted.

# Check yapf version. (20200318 latest is 0.29.0. Format might be changed in
# future version.) -F makes grep match the dots in the version literally.
required_ver="0.29.0"
ver=$(yapf --version)
if ! echo "$ver" | grep -qF "$required_ver"; then
    # YAPF_DOWNLOAD_COMMAND_MSG may be supplied by the caller's environment;
    # fall back to a generic install hint when it is unset or empty.
    echo "Wrong YAPF version installed: $required_ver is required, not $ver. ${YAPF_DOWNLOAD_COMMAND_MSG:-Install it with: pip install yapf==$required_ver}"
    exit 1
fi

# Default to the current directory so that a bare `bash format.sh` works, and
# quote the path so names containing spaces are passed as a single argument.
target="${1:-.}"
yapf --in-place --recursive -p --verbose --style .style.yapf "$target"

if [[ "${2:-}" == '--test' ]]; then # Only for CI usage, user should not use --test flag.
    # `git diff --quiet` exits non-zero when the formatter modified any file.
    if ! git diff --quiet &>/dev/null; then
        echo '*** You have not reformatted your codes! Please run [bash format.sh] at root directory before commit! Thanks! ***'
        exit 1
    else
        echo "Code style test passed!"
    fi
fi