The response has been limited to 50k tokens of the smallest files in the repo. You can remove this limitation by removing the max tokens filter.
├── .coveragerc
├── .flake8
├── .github
    ├── ISSUE_TEMPLATE
    │   └── custom.md
    ├── PULL_REQUEST_TEMPLATE.md
    └── workflows
    │   ├── algo_test.yml
    │   ├── badge.yml
    │   ├── deploy.yml
    │   ├── doc.yml
    │   ├── envpool_test.yml
    │   ├── platform_test.yml
    │   ├── release.yml
    │   ├── release_conda.yml
    │   ├── style.yml
    │   └── unit_test.yml
├── .gitignore
├── .style.yapf
├── CHANGELOG
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── Makefile
├── README.md
├── assets
    └── wechat.jpeg
├── cloc.sh
├── codecov.yml
├── conda
    ├── conda_build_config.yaml
    └── meta.yaml
├── ding
    ├── __init__.py
    ├── bonus
    │   ├── __init__.py
    │   ├── a2c.py
    │   ├── c51.py
    │   ├── common.py
    │   ├── config.py
    │   ├── ddpg.py
    │   ├── dqn.py
    │   ├── model.py
    │   ├── pg.py
    │   ├── ppo_offpolicy.py
    │   ├── ppof.py
    │   ├── sac.py
    │   ├── sql.py
    │   └── td3.py
    ├── compatibility.py
    ├── config
    │   ├── __init__.py
    │   ├── config.py
    │   ├── example
    │   │   ├── A2C
    │   │   │   ├── __init__.py
    │   │   │   ├── gym_bipedalwalker_v3.py
    │   │   │   └── gym_lunarlander_v2.py
    │   │   ├── C51
    │   │   │   ├── __init__.py
    │   │   │   ├── gym_lunarlander_v2.py
    │   │   │   ├── gym_pongnoframeskip_v4.py
    │   │   │   ├── gym_qbertnoframeskip_v4.py
    │   │   │   └── gym_spaceInvadersnoframeskip_v4.py
    │   │   ├── DDPG
    │   │   │   ├── __init__.py
    │   │   │   ├── gym_bipedalwalker_v3.py
    │   │   │   ├── gym_halfcheetah_v3.py
    │   │   │   ├── gym_hopper_v3.py
    │   │   │   ├── gym_lunarlandercontinuous_v2.py
    │   │   │   ├── gym_pendulum_v1.py
    │   │   │   └── gym_walker2d_v3.py
    │   │   ├── DQN
    │   │   │   ├── __init__.py
    │   │   │   ├── gym_lunarlander_v2.py
    │   │   │   ├── gym_pongnoframeskip_v4.py
    │   │   │   ├── gym_qbertnoframeskip_v4.py
    │   │   │   └── gym_spaceInvadersnoframeskip_v4.py
    │   │   ├── PG
    │   │   │   ├── __init__.py
    │   │   │   └── gym_pendulum_v1.py
    │   │   ├── PPOF
    │   │   │   ├── __init__.py
    │   │   │   ├── gym_lunarlander_v2.py
    │   │   │   └── gym_lunarlandercontinuous_v2.py
    │   │   ├── PPOOffPolicy
    │   │   │   ├── __init__.py
    │   │   │   ├── gym_lunarlander_v2.py
    │   │   │   ├── gym_pongnoframeskip_v4.py
    │   │   │   ├── gym_qbertnoframeskip_v4.py
    │   │   │   └── gym_spaceInvadersnoframeskip_v4.py
    │   │   ├── SAC
    │   │   │   ├── __init__.py
    │   │   │   ├── gym_bipedalwalker_v3.py
    │   │   │   ├── gym_halfcheetah_v3.py
    │   │   │   ├── gym_hopper_v3.py
    │   │   │   ├── gym_lunarlandercontinuous_v2.py
    │   │   │   ├── gym_pendulum_v1.py
    │   │   │   └── gym_walker2d_v3.py
    │   │   ├── SQL
    │   │   │   ├── __init__.py
    │   │   │   └── gym_lunarlander_v2.py
    │   │   ├── TD3
    │   │   │   ├── __init__.py
    │   │   │   ├── gym_bipedalwalker_v3.py
    │   │   │   ├── gym_halfcheetah_v3.py
    │   │   │   ├── gym_hopper_v3.py
    │   │   │   ├── gym_lunarlandercontinuous_v2.py
    │   │   │   ├── gym_pendulum_v1.py
    │   │   │   └── gym_walker2d_v3.py
    │   │   └── __init__.py
    │   ├── tests
    │   │   └── test_config_formatted.py
    │   └── utils.py
    ├── data
    │   ├── __init__.py
    │   ├── buffer
    │   │   ├── __init__.py
    │   │   ├── buffer.py
    │   │   ├── deque_buffer.py
    │   │   ├── deque_buffer_wrapper.py
    │   │   ├── middleware
    │   │   │   ├── __init__.py
    │   │   │   ├── clone_object.py
    │   │   │   ├── group_sample.py
    │   │   │   ├── padding.py
    │   │   │   ├── priority.py
    │   │   │   ├── sample_range_view.py
    │   │   │   ├── staleness_check.py
    │   │   │   └── use_time_check.py
    │   │   └── tests
    │   │   │   ├── test_buffer.py
    │   │   │   ├── test_buffer_benchmark.py
    │   │   │   └── test_middleware.py
    │   ├── level_replay
    │   │   ├── __init__.py
    │   │   ├── level_sampler.py
    │   │   └── tests
    │   │   │   └── test_level_sampler.py
    │   ├── model_loader.py
    │   ├── shm_buffer.py
    │   ├── storage
    │   │   ├── __init__.py
    │   │   ├── file.py
    │   │   ├── storage.py
    │   │   └── tests
    │   │   │   └── test_storage.py
    │   ├── storage_loader.py
    │   └── tests
    │   │   ├── test_model_loader.py
    │   │   ├── test_shm_buffer.py
    │   │   └── test_storage_loader.py
    ├── design
    │   ├── dataloader-sequence.png
    │   ├── dataloader-sequence.puml
    │   ├── env_state.png
    │   ├── parallel_main-sequence.png
    │   ├── parallel_main-sequence.puml
    │   ├── serial_collector-activity.png
    │   ├── serial_collector-activity.puml
    │   ├── serial_evaluator-activity.png
    │   ├── serial_evaluator-activity.puml
    │   ├── serial_learner-activity.png
    │   ├── serial_learner-activity.puml
    │   ├── serial_main-sequence.png
    │   └── serial_main.puml
    ├── entry
    │   ├── __init__.py
    │   ├── application_entry.py
    │   ├── application_entry_trex_collect_data.py
    │   ├── cli.py
    │   ├── cli_ditask.py
    │   ├── cli_parsers
    │   │   ├── __init__.py
    │   │   ├── k8s_parser.py
    │   │   ├── slurm_parser.py
    │   │   └── tests
    │   │   │   ├── test_k8s_parser.py
    │   │   │   └── test_slurm_parser.py
    │   ├── dist_entry.py
    │   ├── parallel_entry.py
    │   ├── predefined_config.py
    │   ├── serial_entry.py
    │   ├── serial_entry_bc.py
    │   ├── serial_entry_bco.py
    │   ├── serial_entry_dqfd.py
    │   ├── serial_entry_gail.py
    │   ├── serial_entry_guided_cost.py
    │   ├── serial_entry_mbrl.py
    │   ├── serial_entry_ngu.py
    │   ├── serial_entry_offline.py
    │   ├── serial_entry_onpolicy.py
    │   ├── serial_entry_onpolicy_ppg.py
    │   ├── serial_entry_pc.py
    │   ├── serial_entry_plr.py
    │   ├── serial_entry_preference_based_irl.py
    │   ├── serial_entry_preference_based_irl_onpolicy.py
    │   ├── serial_entry_r2d3.py
    │   ├── serial_entry_reward_model_offpolicy.py
    │   ├── serial_entry_reward_model_onpolicy.py
    │   ├── serial_entry_sqil.py
    │   ├── serial_entry_td3_vae.py
    │   ├── tests
    │   │   ├── config
    │   │   │   ├── agconfig.yaml
    │   │   │   ├── dijob-cartpole.yaml
    │   │   │   └── k8s-config.yaml
    │   │   ├── test_application_entry.py
    │   │   ├── test_application_entry_trex_collect_data.py
    │   │   ├── test_cli_ditask.py
    │   │   ├── test_parallel_entry.py
    │   │   ├── test_random_collect.py
    │   │   ├── test_serial_entry.py
    │   │   ├── test_serial_entry_algo.py
    │   │   ├── test_serial_entry_bc.py
    │   │   ├── test_serial_entry_bco.py
    │   │   ├── test_serial_entry_dqfd.py
    │   │   ├── test_serial_entry_for_anytrading.py
    │   │   ├── test_serial_entry_guided_cost.py
    │   │   ├── test_serial_entry_mbrl.py
    │   │   ├── test_serial_entry_onpolicy.py
    │   │   ├── test_serial_entry_preference_based_irl.py
    │   │   ├── test_serial_entry_preference_based_irl_onpolicy.py
    │   │   ├── test_serial_entry_reward_model.py
    │   │   └── test_serial_entry_sqil.py
    │   └── utils.py
    ├── envs
    │   ├── __init__.py
    │   ├── common
    │   │   ├── __init__.py
    │   │   ├── common_function.py
    │   │   ├── env_element.py
    │   │   ├── env_element_runner.py
    │   │   └── tests
    │   │   │   └── test_common_function.py
    │   ├── env
    │   │   ├── __init__.py
    │   │   ├── base_env.py
    │   │   ├── default_wrapper.py
    │   │   ├── ding_env_wrapper.py
    │   │   ├── env_implementation_check.py
    │   │   └── tests
    │   │   │   ├── __init__.py
    │   │   │   ├── demo_env.py
    │   │   │   ├── test_ding_env_wrapper.py
    │   │   │   └── test_env_implementation_check.py
    │   ├── env_manager
    │   │   ├── __init__.py
    │   │   ├── base_env_manager.py
    │   │   ├── ding_env_manager.py
    │   │   ├── env_supervisor.py
    │   │   ├── envpool_env_manager.py
    │   │   ├── gym_vector_env_manager.py
    │   │   ├── subprocess_env_manager.py
    │   │   └── tests
    │   │   │   ├── __init__.py
    │   │   │   ├── conftest.py
    │   │   │   ├── test_base_env_manager.py
    │   │   │   ├── test_env_supervisor.py
    │   │   │   ├── test_envpool_env_manager.py
    │   │   │   ├── test_gym_vector_env_manager.py
    │   │   │   ├── test_shm.py
    │   │   │   └── test_subprocess_env_manager.py
    │   ├── env_wrappers
    │   │   ├── __init__.py
    │   │   └── env_wrappers.py
    │   └── gym_env.py
    ├── example
    │   ├── __init__.py
    │   ├── bcq.py
    │   ├── c51_nstep.py
    │   ├── collect_demo_data.py
    │   ├── cql.py
    │   ├── d4pg.py
    │   ├── ddpg.py
    │   ├── dqn.py
    │   ├── dqn_eval.py
    │   ├── dqn_frozen_lake.py
    │   ├── dqn_her.py
    │   ├── dqn_new_env.py
    │   ├── dqn_nstep.py
    │   ├── dqn_nstep_gymnasium.py
    │   ├── dqn_per.py
    │   ├── dqn_rnd.py
    │   ├── dt.py
    │   ├── edac.py
    │   ├── impala.py
    │   ├── iqn_nstep.py
    │   ├── mappo.py
    │   ├── masac.py
    │   ├── pdqn.py
    │   ├── ppg_offpolicy.py
    │   ├── ppo.py
    │   ├── ppo_lunarlander.py
    │   ├── ppo_lunarlander_continuous.py
    │   ├── ppo_offpolicy.py
    │   ├── ppo_with_complex_obs.py
    │   ├── qgpo.py
    │   ├── qrdqn_nstep.py
    │   ├── r2d2.py
    │   ├── sac.py
    │   ├── sqil.py
    │   ├── sqil_continuous.py
    │   ├── sql.py
    │   ├── td3.py
    │   └── trex.py
    ├── framework
    │   ├── __init__.py
    │   ├── context.py
    │   ├── event_loop.py
    │   ├── message_queue
    │   │   ├── __init__.py
    │   │   ├── mq.py
    │   │   ├── nng.py
    │   │   ├── redis.py
    │   │   └── tests
    │   │   │   ├── test_nng.py
    │   │   │   └── test_redis.py
    │   ├── middleware
    │   │   ├── __init__.py
    │   │   ├── barrier.py
    │   │   ├── ckpt_handler.py
    │   │   ├── collector.py
    │   │   ├── data_fetcher.py
    │   │   ├── distributer.py
    │   │   ├── functional
    │   │   │   ├── __init__.py
    │   │   │   ├── advantage_estimator.py
    │   │   │   ├── collector.py
    │   │   │   ├── ctx_helper.py
    │   │   │   ├── data_processor.py
    │   │   │   ├── enhancer.py
    │   │   │   ├── evaluator.py
    │   │   │   ├── explorer.py
    │   │   │   ├── logger.py
    │   │   │   ├── priority.py
    │   │   │   ├── termination_checker.py
    │   │   │   ├── timer.py
    │   │   │   └── trainer.py
    │   │   ├── learner.py
    │   │   └── tests
    │   │   │   ├── __init__.py
    │   │   │   ├── mock_for_test.py
    │   │   │   ├── test_advantage_estimator.py
    │   │   │   ├── test_barrier.py
    │   │   │   ├── test_ckpt_handler.py
    │   │   │   ├── test_collector.py
    │   │   │   ├── test_data_processor.py
    │   │   │   ├── test_distributer.py
    │   │   │   ├── test_enhancer.py
    │   │   │   ├── test_evaluator.py
    │   │   │   ├── test_explorer.py
    │   │   │   ├── test_logger.py
    │   │   │   ├── test_priority.py
    │   │   │   └── test_trainer.py
    │   ├── parallel.py
    │   ├── supervisor.py
    │   ├── task.py
    │   ├── tests
    │   │   ├── context_fake_data.py
    │   │   ├── test_context.py
    │   │   ├── test_event_loop.py
    │   │   ├── test_parallel.py
    │   │   ├── test_supervisor.py
    │   │   ├── test_task.py
    │   │   └── test_wrapper.py
    │   └── wrapper
    │   │   ├── __init__.py
    │   │   └── step_timer.py
    ├── hpc_rl
    │   ├── README.md
    │   ├── __init__.py
    │   ├── tests
    │   │   ├── test_dntd.py
    │   │   ├── test_gae.py
    │   │   ├── test_lstm.py
    │   │   ├── test_ppo.py
    │   │   ├── test_qntd.py
    │   │   ├── test_qntd_rescale.py
    │   │   ├── test_scatter.py
    │   │   ├── test_tdlambda.py
    │   │   ├── test_upgo.py
    │   │   ├── test_vtrace.py
    │   │   └── testbase.py
    │   └── wrapper.py
    ├── interaction
    │   ├── __init__.py
    │   ├── base
    │   │   ├── __init__.py
    │   │   ├── app.py
    │   │   ├── common.py
    │   │   ├── network.py
    │   │   └── threading.py
    │   ├── config
    │   │   ├── __init__.py
    │   │   └── base.py
    │   ├── exception
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   ├── master.py
    │   │   └── slave.py
    │   ├── master
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   ├── connection.py
    │   │   ├── master.py
    │   │   └── task.py
    │   ├── slave
    │   │   ├── __init__.py
    │   │   ├── action.py
    │   │   └── slave.py
    │   └── tests
    │   │   ├── __init__.py
    │   │   ├── base
    │   │       ├── __init__.py
    │   │       ├── test_app.py
    │   │       ├── test_common.py
    │   │       ├── test_network.py
    │   │       └── test_threading.py
    │   │   ├── config
    │   │       ├── __init__.py
    │   │       └── test_base.py
    │   │   ├── exception
    │   │       ├── __init__.py
    │   │       ├── test_base.py
    │   │       ├── test_master.py
    │   │       └── test_slave.py
    │   │   ├── interaction
    │   │       ├── __init__.py
    │   │       ├── bases.py
    │   │       ├── test_errors.py
    │   │       └── test_simple.py
    │   │   └── test_utils
    │   │       ├── __init__.py
    │   │       ├── random.py
    │   │       └── stream.py
    ├── league
    │   ├── __init__.py
    │   ├── algorithm.py
    │   ├── base_league.py
    │   ├── metric.py
    │   ├── one_vs_one_league.py
    │   ├── player.py
    │   ├── shared_payoff.py
    │   ├── starcraft_player.py
    │   └── tests
    │   │   ├── conftest.py
    │   │   ├── league_test_default_config.py
    │   │   ├── test_league_metric.py
    │   │   ├── test_one_vs_one_league.py
    │   │   ├── test_payoff.py
    │   │   └── test_player.py
    ├── model
    │   ├── __init__.py
    │   ├── common
    │   │   ├── __init__.py
    │   │   ├── encoder.py
    │   │   ├── head.py
    │   │   ├── tests
    │   │   │   ├── test_encoder.py
    │   │   │   └── test_head.py
    │   │   └── utils.py
    │   ├── template
    │   │   ├── __init__.py
    │   │   ├── acer.py
    │   │   ├── atoc.py
    │   │   ├── bc.py
    │   │   ├── bcq.py
    │   │   ├── collaq.py
    │   │   ├── coma.py
    │   │   ├── decision_transformer.py
    │   │   ├── diffusion.py
    │   │   ├── ebm.py
    │   │   ├── edac.py
    │   │   ├── havac.py
    │   │   ├── hpt.py
    │   │   ├── language_transformer.py
    │   │   ├── madqn.py
    │   │   ├── maqac.py
    │   │   ├── mavac.py
    │   │   ├── ngu.py
    │   │   ├── pdqn.py
    │   │   ├── pg.py
    │   │   ├── ppg.py
    │   │   ├── procedure_cloning.py
    │   │   ├── q_learning.py
    │   │   ├── qac.py
    │   │   ├── qac_dist.py
    │   │   ├── qgpo.py
    │   │   ├── qmix.py
    │   │   ├── qtran.py
    │   │   ├── qvac.py
    │   │   ├── sqn.py
    │   │   ├── tests
    │   │   │   ├── test_acer.py
    │   │   │   ├── test_atoc.py
    │   │   │   ├── test_bc.py
    │   │   │   ├── test_bcq.py
    │   │   │   ├── test_collaq.py
    │   │   │   ├── test_coma_nn.py
    │   │   │   ├── test_decision_transformer.py
    │   │   │   ├── test_ebm.py
    │   │   │   ├── test_edac.py
    │   │   │   ├── test_havac.py
    │   │   │   ├── test_hpt.py
    │   │   │   ├── test_hybrid_qac.py
    │   │   │   ├── test_language_transformer.py
    │   │   │   ├── test_madqn.py
    │   │   │   ├── test_maqac.py
    │   │   │   ├── test_mavac.py
    │   │   │   ├── test_ngu.py
    │   │   │   ├── test_pdqn.py
    │   │   │   ├── test_pg.py
    │   │   │   ├── test_procedure_cloning.py
    │   │   │   ├── test_q_learning.py
    │   │   │   ├── test_qac.py
    │   │   │   ├── test_qac_dist.py
    │   │   │   ├── test_qmix.py
    │   │   │   ├── test_qtran.py
    │   │   │   ├── test_vac.py
    │   │   │   ├── test_vae.py
    │   │   │   └── test_wqmix.py
    │   │   ├── vac.py
    │   │   ├── vae.py
    │   │   └── wqmix.py
    │   └── wrapper
    │   │   ├── __init__.py
    │   │   ├── model_wrappers.py
    │   │   └── test_model_wrappers.py
    ├── policy
    │   ├── __init__.py
    │   ├── a2c.py
    │   ├── acer.py
    │   ├── atoc.py
    │   ├── base_policy.py
    │   ├── bc.py
    │   ├── bcq.py
    │   ├── bdq.py
    │   ├── c51.py
    │   ├── collaq.py
    │   ├── coma.py
    │   ├── command_mode_policy_instance.py
    │   ├── common_utils.py
    │   ├── cql.py
    │   ├── d4pg.py
    │   ├── ddpg.py
    │   ├── dqfd.py
    │   ├── dqn.py
    │   ├── dt.py
    │   ├── edac.py
    │   ├── fqf.py
    │   ├── happo.py
    │   ├── ibc.py
    │   ├── il.py
    │   ├── impala.py
    │   ├── iql.py
    │   ├── iqn.py
    │   ├── madqn.py
    │   ├── mbpolicy
    │   │   ├── __init__.py
    │   │   ├── dreamer.py
    │   │   ├── mbsac.py
    │   │   ├── tests
    │   │   │   └── test_mbpolicy_utils.py
    │   │   └── utils.py
    │   ├── mdqn.py
    │   ├── ngu.py
    │   ├── offppo_collect_traj.py
    │   ├── pc.py
    │   ├── pdqn.py
    │   ├── pg.py
    │   ├── plan_diffuser.py
    │   ├── policy_factory.py
    │   ├── ppg.py
    │   ├── ppo.py
    │   ├── ppof.py
    │   ├── prompt_awr.py
    │   ├── prompt_pg.py
    │   ├── qgpo.py
    │   ├── qmix.py
    │   ├── qrdqn.py
    │   ├── qtran.py
    │   ├── r2d2.py
    │   ├── r2d2_collect_traj.py
    │   ├── r2d2_gtrxl.py
    │   ├── r2d3.py
    │   ├── rainbow.py
    │   ├── sac.py
    │   ├── sql.py
    │   ├── sqn.py
    │   ├── td3.py
    │   ├── td3_bc.py
    │   ├── td3_vae.py
    │   ├── tests
    │   │   ├── test_common_utils.py
    │   │   ├── test_cql.py
    │   │   ├── test_r2d3.py
    │   │   └── test_stdim.py
    │   └── wqmix.py
    ├── reward_model
    │   ├── __init__.py
    │   ├── base_reward_model.py
    │   ├── drex_reward_model.py
    │   ├── gail_irl_model.py
    │   ├── guided_cost_reward_model.py
    │   ├── her_reward_model.py
    │   ├── icm_reward_model.py
    │   ├── ngu_reward_model.py
    │   ├── pdeil_irl_model.py
    │   ├── pwil_irl_model.py
    │   ├── red_irl_model.py
    │   ├── rnd_reward_model.py
    │   ├── tests
    │   │   └── test_gail_irl_model.py
    │   └── trex_reward_model.py
    ├── rl_utils
    │   ├── README.md
    │   ├── __init__.py
    │   ├── a2c.py
    │   ├── acer.py
    │   ├── adder.py
    │   ├── beta_function.py
    │   ├── coma.py
    │   ├── exploration.py
    │   ├── gae.py
    │   ├── grpo.py
    │   ├── happo.py
    │   ├── isw.py
    │   ├── log_prob_utils.py
    │   ├── ppg.py
    │   ├── ppo.py
    │   ├── retrace.py
    │   ├── rloo.py
    │   ├── sampler.py
    │   ├── td.py
    │   ├── tests
    │   │   ├── test_a2c.py
    │   │   ├── test_adder.py
    │   │   ├── test_coma.py
    │   │   ├── test_exploration.py
    │   │   ├── test_gae.py
    │   │   ├── test_grpo_rlhf.py
    │   │   ├── test_happo.py
    │   │   ├── test_log_prob_fn.py
    │   │   ├── test_log_prob_utils.py
    │   │   ├── test_ppg.py
    │   │   ├── test_ppo.py
    │   │   ├── test_ppo_rlhf.py
    │   │   ├── test_retrace.py
    │   │   ├── test_rloo_rlhf.py
    │   │   ├── test_td.py
    │   │   ├── test_upgo.py
    │   │   ├── test_value_rescale.py
    │   │   └── test_vtrace.py
    │   ├── upgo.py
    │   ├── value_rescale.py
    │   └── vtrace.py
    ├── scripts
    │   ├── dijob-qbert.yaml
    │   ├── docker-test-entry.sh
    │   ├── docker-test.sh
    │   ├── install-k8s-tools.sh
    │   ├── kill.sh
    │   ├── local_parallel.sh
    │   ├── local_serial.sh
    │   ├── main_league.sh
    │   ├── main_league_slurm.sh
    │   └── tests
    │   │   ├── test_parallel_socket.py
    │   │   └── test_parallel_socket.sh
    ├── torch_utils
    │   ├── __init__.py
    │   ├── backend_helper.py
    │   ├── checkpoint_helper.py
    │   ├── data_helper.py
    │   ├── dataparallel.py
    │   ├── diffusion_SDE
    │   │   ├── __init__.py
    │   │   └── dpm_solver_pytorch.py
    │   ├── distribution.py
    │   ├── loss
    │   │   ├── __init__.py
    │   │   ├── contrastive_loss.py
    │   │   ├── cross_entropy_loss.py
    │   │   ├── multi_logits_loss.py
    │   │   └── tests
    │   │   │   ├── test_contrastive_loss.py
    │   │   │   ├── test_cross_entropy_loss.py
    │   │   │   └── test_multi_logits_loss.py
    │   ├── lr_scheduler.py
    │   ├── math_helper.py
    │   ├── metric.py
    │   ├── model_helper.py
    │   ├── network
    │   │   ├── __init__.py
    │   │   ├── activation.py
    │   │   ├── diffusion.py
    │   │   ├── dreamer.py
    │   │   ├── gtrxl.py
    │   │   ├── gumbel_softmax.py
    │   │   ├── merge.py
    │   │   ├── nn_module.py
    │   │   ├── normalization.py
    │   │   ├── popart.py
    │   │   ├── res_block.py
    │   │   ├── resnet.py
    │   │   ├── rnn.py
    │   │   ├── scatter_connection.py
    │   │   ├── soft_argmax.py
    │   │   ├── tests
    │   │   │   ├── test_activation.py
    │   │   │   ├── test_diffusion.py
    │   │   │   ├── test_dreamer.py
    │   │   │   ├── test_gtrxl.py
    │   │   │   ├── test_gumbel_softmax.py
    │   │   │   ├── test_merge.py
    │   │   │   ├── test_nn_module.py
    │   │   │   ├── test_normalization.py
    │   │   │   ├── test_popart.py
    │   │   │   ├── test_res_block.py
    │   │   │   ├── test_resnet.py
    │   │   │   ├── test_rnn.py
    │   │   │   ├── test_scatter.py
    │   │   │   ├── test_soft_argmax.py
    │   │   │   └── test_transformer.py
    │   │   └── transformer.py
    │   ├── nn_test_helper.py
    │   ├── optimizer_helper.py
    │   ├── parameter.py
    │   ├── reshape_helper.py
    │   └── tests
    │   │   ├── test_backend_helper.py
    │   │   ├── test_ckpt_helper.py
    │   │   ├── test_data_helper.py
    │   │   ├── test_distribution.py
    │   │   ├── test_feature_merge.py
    │   │   ├── test_lr_scheduler.py
    │   │   ├── test_math_helper.py
    │   │   ├── test_metric.py
    │   │   ├── test_model_helper.py
    │   │   ├── test_nn_test_helper.py
    │   │   ├── test_optimizer.py
    │   │   ├── test_parameter.py
    │   │   └── test_reshape_helper.py
    ├── utils
    │   ├── __init__.py
    │   ├── autolog
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   ├── data.py
    │   │   ├── model.py
    │   │   ├── tests
    │   │   │   ├── __init__.py
    │   │   │   ├── test_data.py
    │   │   │   ├── test_model.py
    │   │   │   └── test_time.py
    │   │   ├── time_ctl.py
    │   │   └── value.py
    │   ├── bfs_helper.py
    │   ├── collection_helper.py
    │   ├── compression_helper.py
    │   ├── data
    │   │   ├── __init__.py
    │   │   ├── base_dataloader.py
    │   │   ├── collate_fn.py
    │   │   ├── dataloader.py
    │   │   ├── dataset.py
    │   │   ├── rlhf_offline_dataset.py
    │   │   ├── rlhf_online_dataset.py
    │   │   ├── structure
    │   │   │   ├── __init__.py
    │   │   │   ├── cache.py
    │   │   │   └── lifo_deque.py
    │   │   └── tests
    │   │   │   ├── dataloader_speed
    │   │   │       └── experiment_dataloader_speed.py
    │   │   │   ├── test_cache.py
    │   │   │   ├── test_collate_fn.py
    │   │   │   ├── test_dataloader.py
    │   │   │   ├── test_dataset.py
    │   │   │   ├── test_rlhf_offline_dataset.py
    │   │   │   └── test_rlhf_online_dataset.py
    │   ├── default_helper.py
    │   ├── deprecation.py
    │   ├── design_helper.py
    │   ├── dict_helper.py
    │   ├── fake_linklink.py
    │   ├── fast_copy.py
    │   ├── file_helper.py
    │   ├── import_helper.py
    │   ├── k8s_helper.py
    │   ├── linklink_dist_helper.py
    │   ├── loader
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   ├── collection.py
    │   │   ├── dict.py
    │   │   ├── exception.py
    │   │   ├── mapping.py
    │   │   ├── norm.py
    │   │   ├── number.py
    │   │   ├── string.py
    │   │   ├── tests
    │   │   │   ├── __init__.py
    │   │   │   ├── loader
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── test_base.py
    │   │   │   │   ├── test_collection.py
    │   │   │   │   ├── test_dict.py
    │   │   │   │   ├── test_mapping.py
    │   │   │   │   ├── test_norm.py
    │   │   │   │   ├── test_number.py
    │   │   │   │   ├── test_string.py
    │   │   │   │   ├── test_types.py
    │   │   │   │   └── test_utils.py
    │   │   │   └── test_cartpole_dqn_serial_config_loader.py
    │   │   ├── types.py
    │   │   └── utils.py
    │   ├── lock_helper.py
    │   ├── log_helper.py
    │   ├── log_writer_helper.py
    │   ├── memory_helper.py
    │   ├── normalizer_helper.py
    │   ├── orchestrator_launcher.py
    │   ├── profiler_helper.py
    │   ├── pytorch_ddp_dist_helper.py
    │   ├── registry.py
    │   ├── registry_factory.py
    │   ├── render_helper.py
    │   ├── scheduler_helper.py
    │   ├── segment_tree.py
    │   ├── slurm_helper.py
    │   ├── system_helper.py
    │   ├── tests
    │   │   ├── config
    │   │   │   └── k8s-config.yaml
    │   │   ├── test_bfs_helper.py
    │   │   ├── test_collection_helper.py
    │   │   ├── test_compression_helper.py
    │   │   ├── test_config_helper.py
    │   │   ├── test_default_helper.py
    │   │   ├── test_deprecation.py
    │   │   ├── test_design_helper.py
    │   │   ├── test_file_helper.py
    │   │   ├── test_import_helper.py
    │   │   ├── test_k8s_launcher.py
    │   │   ├── test_lock.py
    │   │   ├── test_log_helper.py
    │   │   ├── test_log_writer_helper.py
    │   │   ├── test_memory_helper.py
    │   │   ├── test_normalizer_helper.py
    │   │   ├── test_profiler_helper.py
    │   │   ├── test_registry.py
    │   │   ├── test_scheduler_helper.py
    │   │   ├── test_segment_tree.py
    │   │   ├── test_system_helper.py
    │   │   └── test_time_helper.py
    │   ├── time_helper.py
    │   ├── time_helper_base.py
    │   ├── time_helper_cuda.py
    │   └── type_helper.py
    ├── worker
    │   ├── __init__.py
    │   ├── adapter
    │   │   ├── __init__.py
    │   │   ├── learner_aggregator.py
    │   │   └── tests
    │   │   │   └── test_learner_aggregator.py
    │   ├── collector
    │   │   ├── __init__.py
    │   │   ├── base_parallel_collector.py
    │   │   ├── base_serial_collector.py
    │   │   ├── base_serial_evaluator.py
    │   │   ├── battle_episode_serial_collector.py
    │   │   ├── battle_interaction_serial_evaluator.py
    │   │   ├── battle_sample_serial_collector.py
    │   │   ├── comm
    │   │   │   ├── __init__.py
    │   │   │   ├── base_comm_collector.py
    │   │   │   ├── flask_fs_collector.py
    │   │   │   ├── tests
    │   │   │   │   └── test_collector_with_coordinator.py
    │   │   │   └── utils.py
    │   │   ├── episode_serial_collector.py
    │   │   ├── interaction_serial_evaluator.py
    │   │   ├── marine_parallel_collector.py
    │   │   ├── metric_serial_evaluator.py
    │   │   ├── sample_serial_collector.py
    │   │   ├── tests
    │   │   │   ├── __init__.py
    │   │   │   ├── fake_cls_policy.py
    │   │   │   ├── fake_cpong_dqn_config.py
    │   │   │   ├── speed_test
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── fake_env.py
    │   │   │   │   ├── fake_policy.py
    │   │   │   │   ├── test_collector_profile.py
    │   │   │   │   └── utils.py
    │   │   │   ├── test_base_serial_collector.py
    │   │   │   ├── test_episode_serial_collector.py
    │   │   │   ├── test_marine_parallel_collector.py
    │   │   │   ├── test_metric_serial_evaluator.py
    │   │   │   └── test_sample_serial_collector.py
    │   │   └── zergling_parallel_collector.py
    │   ├── coordinator
    │   │   ├── __init__.py
    │   │   ├── base_parallel_commander.py
    │   │   ├── base_serial_commander.py
    │   │   ├── comm_coordinator.py
    │   │   ├── coordinator.py
    │   │   ├── one_vs_one_parallel_commander.py
    │   │   ├── operator_server.py
    │   │   ├── resource_manager.py
    │   │   ├── solo_parallel_commander.py
    │   │   └── tests
    │   │   │   ├── conftest.py
    │   │   │   ├── test_coordinator.py
    │   │   │   ├── test_fake_operator_server.py
    │   │   │   └── test_one_vs_one_commander.py
    │   ├── learner
    │   │   ├── __init__.py
    │   │   ├── base_learner.py
    │   │   ├── comm
    │   │   │   ├── __init__.py
    │   │   │   ├── base_comm_learner.py
    │   │   │   ├── flask_fs_learner.py
    │   │   │   ├── tests
    │   │   │   │   └── test_learner_with_coordinator.py
    │   │   │   └── utils.py
    │   │   ├── learner_hook.py
    │   │   └── tests
    │   │   │   ├── test_base_learner.py
    │   │   │   └── test_learner_hook.py
    │   └── replay_buffer
    │   │   ├── __init__.py
    │   │   ├── advanced_buffer.py
    │   │   ├── base_buffer.py
    │   │   ├── episode_buffer.py
    │   │   ├── naive_buffer.py
    │   │   ├── tests
    │   │       ├── conftest.py
    │   │       ├── test_advanced_buffer.py
    │   │       └── test_naive_buffer.py
    │   │   └── utils.py
    └── world_model
    │   ├── __init__.py
    │   ├── base_world_model.py
    │   ├── ddppo.py
    │   ├── dreamer.py
    │   ├── idm.py
    │   ├── mbpo.py
    │   ├── model
    │       ├── __init__.py
    │       ├── ensemble.py
    │       ├── networks.py
    │       └── tests
    │       │   ├── test_ensemble.py
    │       │   └── test_networks.py
    │   ├── tests
    │       ├── test_ddppo.py
    │       ├── test_dreamerv3.py
    │       ├── test_idm.py
    │       ├── test_mbpo.py
    │       ├── test_world_model.py
    │       └── test_world_model_utils.py
    │   └── utils.py
├── dizoo
    ├── __init__.py
    ├── atari
    │   ├── __init__.py
    │   ├── atari.gif
    │   ├── config
    │   │   ├── __init__.py
    │   │   └── serial
    │   │   │   ├── __init__.py
    │   │   │   ├── asterix
    │   │   │       ├── __init__.py
    │   │   │       └── asterix_mdqn_config.py
    │   │   │   ├── demon_attack
    │   │   │       └── demon_attack_dqn_config.py
    │   │   │   ├── enduro
    │   │   │       ├── __init__.py
    │   │   │       ├── enduro_dqn_config.py
    │   │   │       ├── enduro_impala_config.py
    │   │   │       ├── enduro_mdqn_config.py
    │   │   │       ├── enduro_onppo_config.py
    │   │   │       ├── enduro_qrdqn_config.py
    │   │   │       └── enduro_rainbow_config.py
    │   │   │   ├── montezuma
    │   │   │       └── montezuma_ngu_config.py
    │   │   │   ├── phoenix
    │   │   │       ├── phoenix_fqf_config.py
    │   │   │       └── phoenix_iqn_config.py
    │   │   │   ├── pitfall
    │   │   │       └── pitfall_ngu_config.py
    │   │   │   ├── pong
    │   │   │       ├── __init__.py
    │   │   │       ├── pong_a2c_config.py
    │   │   │       ├── pong_acer_config.py
    │   │   │       ├── pong_c51_config.py
    │   │   │       ├── pong_cql_config.py
    │   │   │       ├── pong_dqfd_config.py
    │   │   │       ├── pong_dqn_config.py
    │   │   │       ├── pong_dqn_ddp_config.py
    │   │   │       ├── pong_dqn_envpool_config.py
    │   │   │       ├── pong_dqn_multi_gpu_config.py
    │   │   │       ├── pong_dqn_render_config.py
    │   │   │       ├── pong_dqn_stdim_config.py
    │   │   │       ├── pong_dt_config.py
    │   │   │       ├── pong_fqf_config.py
    │   │   │       ├── pong_gail_dqn_config.py
    │   │   │       ├── pong_impala_config.py
    │   │   │       ├── pong_iqn_config.py
    │   │   │       ├── pong_ngu_config.py
    │   │   │       ├── pong_ppg_config.py
    │   │   │       ├── pong_ppo_config.py
    │   │   │       ├── pong_ppo_ddp_config.py
    │   │   │       ├── pong_qrdqn_config.py
    │   │   │       ├── pong_qrdqn_generation_data_config.py
    │   │   │       ├── pong_r2d2_config.py
    │   │   │       ├── pong_r2d2_gtrxl_config.py
    │   │   │       ├── pong_r2d2_residual_config.py
    │   │   │       ├── pong_r2d3_offppoexpert_config.py
    │   │   │       ├── pong_r2d3_r2d2expert_config.py
    │   │   │       ├── pong_rainbow_config.py
    │   │   │       ├── pong_sqil_config.py
    │   │   │       ├── pong_sql_config.py
    │   │   │       ├── pong_trex_offppo_config.py
    │   │   │       └── pong_trex_sql_config.py
    │   │   │   ├── qbert
    │   │   │       ├── __init__.py
    │   │   │       ├── qbert_a2c_config.py
    │   │   │       ├── qbert_acer_config.py
    │   │   │       ├── qbert_c51_config.py
    │   │   │       ├── qbert_cql_config.py
    │   │   │       ├── qbert_dqfd_config.py
    │   │   │       ├── qbert_dqn_config.py
    │   │   │       ├── qbert_fqf_config.py
    │   │   │       ├── qbert_impala_config.py
    │   │   │       ├── qbert_iqn_config.py
    │   │   │       ├── qbert_ngu_config.py
    │   │   │       ├── qbert_offppo_config.py
    │   │   │       ├── qbert_onppo_config.py
    │   │   │       ├── qbert_ppg_config.py
    │   │   │       ├── qbert_qrdqn_config.py
    │   │   │       ├── qbert_qrdqn_generation_data_config.py
    │   │   │       ├── qbert_r2d2_config.py
    │   │   │       ├── qbert_r2d2_gtrxl_config.py
    │   │   │       ├── qbert_rainbow_config.py
    │   │   │       ├── qbert_sqil_config.py
    │   │   │       ├── qbert_sql_config.py
    │   │   │       ├── qbert_trex_dqn_config.py
    │   │   │       └── qbert_trex_offppo_config.py
    │   │   │   └── spaceinvaders
    │   │   │       ├── __init__.py
    │   │   │       ├── spaceinvaders_a2c_config.py
    │   │   │       ├── spaceinvaders_acer_config.py
    │   │   │       ├── spaceinvaders_c51_config.py
    │   │   │       ├── spaceinvaders_dqfd_config.py
    │   │   │       ├── spaceinvaders_dqn_config.py
    │   │   │       ├── spaceinvaders_dqn_config_multi_gpu_ddp.py
    │   │   │       ├── spaceinvaders_dqn_config_multi_gpu_dp.py
    │   │   │       ├── spaceinvaders_fqf_config.py
    │   │   │       ├── spaceinvaders_impala_config.py
    │   │   │       ├── spaceinvaders_iqn_config.py
    │   │   │       ├── spaceinvaders_mdqn_config.py
    │   │   │       ├── spaceinvaders_ngu_config.py
    │   │   │       ├── spaceinvaders_offppo_config.py
    │   │   │       ├── spaceinvaders_onppo_config.py
    │   │   │       ├── spaceinvaders_ppg_config.py
    │   │   │       ├── spaceinvaders_qrdqn_config.py
    │   │   │       ├── spaceinvaders_r2d2_config.py
    │   │   │       ├── spaceinvaders_r2d2_gtrxl_config.py
    │   │   │       ├── spaceinvaders_r2d2_residual_config.py
    │   │   │       ├── spaceinvaders_rainbow_config.py
    │   │   │       ├── spaceinvaders_sqil_config.py
    │   │   │       ├── spaceinvaders_sql_config.py
    │   │   │       ├── spaceinvaders_trex_dqn_config.py
    │   │   │       └── spaceinvaders_trex_offppo_config.py
    │   ├── entry
    │   │   ├── __init__.py
    │   │   ├── atari_dqn_main.py
    │   │   ├── atari_dt_main.py
    │   │   ├── atari_impala_main.py
    │   │   ├── atari_ppg_main.py
    │   │   ├── phoenix_fqf_main.py
    │   │   ├── phoenix_iqn_main.py
    │   │   ├── pong_cql_main.py
    │   │   ├── pong_dqn_envpool_main.py
    │   │   ├── pong_fqf_main.py
    │   │   ├── qbert_cql_main.py
    │   │   ├── qbert_fqf_main.py
    │   │   ├── spaceinvaders_dqn_eval.py
    │   │   ├── spaceinvaders_dqn_main_multi_gpu_ddp.py
    │   │   ├── spaceinvaders_dqn_main_multi_gpu_dp.py
    │   │   └── spaceinvaders_fqf_main.py
    │   ├── envs
    │   │   ├── __init__.py
    │   │   ├── atari_env.py
    │   │   ├── atari_wrappers.py
    │   │   └── test_atari_env.py
    │   └── example
    │   │   ├── atari_dqn.py
    │   │   ├── atari_dqn_ddp.py
    │   │   ├── atari_dqn_dist.py
    │   │   ├── atari_dqn_dist_ddp.py
    │   │   ├── atari_dqn_dist_rdma.py
    │   │   ├── atari_dqn_dp.py
    │   │   ├── atari_ppo.py
    │   │   └── atari_ppo_ddp.py
    ├── beergame
    │   ├── __init__.py
    │   ├── beergame.png
    │   ├── config
    │   │   └── beergame_onppo_config.py
    │   ├── entry
    │   │   └── beergame_eval.py
    │   └── envs
    │   │   ├── BGAgent.py
    │   │   ├── __init__.py
    │   │   ├── beergame_core.py
    │   │   ├── beergame_env.py
    │   │   ├── clBeergame.py
    │   │   ├── plotting.py
    │   │   └── utils.py
    ├── bitflip
    │   ├── README.md
    │   ├── __init__.py
    │   ├── bitflip.gif
    │   ├── config
    │   │   ├── __init__.py
    │   │   ├── bitflip_her_dqn_config.py
    │   │   └── bitflip_pure_dqn_config.py
    │   ├── entry
    │   │   ├── __init__.py
    │   │   └── bitflip_dqn_main.py
    │   └── envs
    │   │   ├── __init__.py
    │   │   ├── bitflip_env.py
    │   │   └── test_bitfilp_env.py
    ├── box2d
    │   ├── __init__.py
    │   ├── bipedalwalker
    │   │   ├── __init__.py
    │   │   ├── config
    │   │   │   ├── __init__.py
    │   │   │   ├── bipedalwalker_a2c_config.py
    │   │   │   ├── bipedalwalker_bco_config.py
    │   │   │   ├── bipedalwalker_ddpg_config.py
    │   │   │   ├── bipedalwalker_dt_config.py
    │   │   │   ├── bipedalwalker_gail_sac_config.py
    │   │   │   ├── bipedalwalker_impala_config.py
    │   │   │   ├── bipedalwalker_pg_config.py
    │   │   │   ├── bipedalwalker_ppo_config.py
    │   │   │   ├── bipedalwalker_ppopg_config.py
    │   │   │   ├── bipedalwalker_sac_config.py
    │   │   │   └── bipedalwalker_td3_config.py
    │   │   ├── entry
    │   │   │   ├── __init__.py
    │   │   │   └── bipedalwalker_ppo_eval.py
    │   │   ├── envs
    │   │   │   ├── __init__.py
    │   │   │   ├── bipedalwalker_env.py
    │   │   │   └── test_bipedalwalker.py
    │   │   └── original.gif
    │   ├── carracing
    │   │   ├── __init__.py
    │   │   ├── car_racing.gif
    │   │   ├── config
    │   │   │   ├── __init__.py
    │   │   │   └── carracing_dqn_config.py
    │   │   └── envs
    │   │   │   ├── __init__.py
    │   │   │   ├── carracing_env.py
    │   │   │   └── test_carracing_env.py
    │   └── lunarlander
    │   │   ├── __init__.py
    │   │   ├── config
    │   │       ├── __init__.py
    │   │       ├── lunarlander_a2c_config.py
    │   │       ├── lunarlander_acer_config.py
    │   │       ├── lunarlander_bco_config.py
    │   │       ├── lunarlander_c51_config.py
    │   │       ├── lunarlander_cont_ddpg_config.py
    │   │       ├── lunarlander_cont_sac_config.py
    │   │       ├── lunarlander_cont_td3_config.py
    │   │       ├── lunarlander_cont_td3_vae_config.py
    │   │       ├── lunarlander_discrete_sac_config.py
    │   │       ├── lunarlander_dqfd_config.py
    │   │       ├── lunarlander_dqn_config.py
    │   │       ├── lunarlander_dqn_deque_config.py
    │   │       ├── lunarlander_dt_config.py
    │   │       ├── lunarlander_gail_dqn_config.py
    │   │       ├── lunarlander_gcl_config.py
    │   │       ├── lunarlander_hpt_config.py
    │   │       ├── lunarlander_impala_config.py
    │   │       ├── lunarlander_ngu_config.py
    │   │       ├── lunarlander_offppo_config.py
    │   │       ├── lunarlander_pg_config.py
    │   │       ├── lunarlander_ppo_config.py
    │   │       ├── lunarlander_ppo_continuous_config.py
    │   │       ├── lunarlander_qrdqn_config.py
    │   │       ├── lunarlander_r2d2_config.py
    │   │       ├── lunarlander_r2d2_gtrxl_config.py
    │   │       ├── lunarlander_r2d3_ppoexpert_config.py
    │   │       ├── lunarlander_r2d3_r2d2expert_config.py
    │   │       ├── lunarlander_rnd_onppo_config.py
    │   │       ├── lunarlander_sqil_config.py
    │   │       ├── lunarlander_sql_config.py
    │   │       ├── lunarlander_trex_dqn_config.py
    │   │       └── lunarlander_trex_offppo_config.py
    │   │   ├── entry
    │   │       ├── __init__.py
    │   │       ├── lunarlander_dqn_eval.py
    │   │       ├── lunarlander_dqn_example.py
    │   │       └── lunarlander_hpt_example.py
    │   │   ├── envs
    │   │       ├── __init__.py
    │   │       ├── lunarlander_env.py
    │   │       └── test_lunarlander_env.py
    │   │   ├── lunarlander.gif
    │   │   └── offline_data
    │   │       ├── collect_dqn_data_config.py
    │   │       ├── lunarlander_collect_data.py
    │   │       └── lunarlander_show_data.py
    ├── bsuite
    │   ├── __init__.py
    │   ├── bsuite.png
    │   ├── config
    │   │   ├── __init__.py
    │   │   └── serial
    │   │   │   ├── bandit_noise
    │   │   │       └── bandit_noise_0_dqn_config.py
    │   │   │   ├── cartpole_swingup
    │   │   │       └── cartpole_swingup_0_dqn_config.py
    │   │   │   └── memory_len
    │   │   │       ├── memory_len_0_a2c_config.py
    │   │   │       ├── memory_len_0_dqn_config.py
    │   │   │       ├── memory_len_15_r2d2_config.py
    │   │   │       └── memory_len_15_r2d2_gtrxl_config.py
    │   └── envs
    │   │   ├── __init__.py
    │   │   ├── bsuite_env.py
    │   │   └── test_bsuite_env.py
    ├── classic_control
    │   ├── __init__.py
    │   ├── acrobot
    │   │   ├── __init__.py
    │   │   ├── acrobot.gif
    │   │   ├── config
    │   │   │   ├── __init__.py
    │   │   │   └── acrobot_dqn_config.py
    │   │   └── envs
    │   │   │   ├── __init__.py
    │   │   │   ├── acrobot_env.py
    │   │   │   └── test_acrobot_env.py
    │   ├── cartpole
    │   │   ├── __init__.py
    │   │   ├── cartpole.gif
    │   │   ├── config
    │   │   │   ├── __init__.py
    │   │   │   ├── cartpole_a2c_config.py
    │   │   │   ├── cartpole_acer_config.py
    │   │   │   ├── cartpole_bc_config.py
    │   │   │   ├── cartpole_bco_config.py
    │   │   │   ├── cartpole_c51_config.py
    │   │   │   ├── cartpole_cql_config.py
    │   │   │   ├── cartpole_decision_transformer.py
    │   │   │   ├── cartpole_dqfd_config.py
    │   │   │   ├── cartpole_dqn_config.py
    │   │   │   ├── cartpole_dqn_ddp_config.py
    │   │   │   ├── cartpole_dqn_gail_config.py
    │   │   │   ├── cartpole_dqn_rnd_config.py
    │   │   │   ├── cartpole_dqn_stdim_config.py
    │   │   │   ├── cartpole_drex_dqn_config.py
    │   │   │   ├── cartpole_dt_config.py
    │   │   │   ├── cartpole_fqf_config.py
    │   │   │   ├── cartpole_gcl_config.py
    │   │   │   ├── cartpole_impala_config.py
    │   │   │   ├── cartpole_iqn_config.py
    │   │   │   ├── cartpole_mdqn_config.py
    │   │   │   ├── cartpole_ngu_config.py
    │   │   │   ├── cartpole_pg_config.py
    │   │   │   ├── cartpole_ppg_config.py
    │   │   │   ├── cartpole_ppo_config.py
    │   │   │   ├── cartpole_ppo_ddp_config.py
    │   │   │   ├── cartpole_ppo_icm_config.py
    │   │   │   ├── cartpole_ppo_offpolicy_config.py
    │   │   │   ├── cartpole_ppo_stdim_config.py
    │   │   │   ├── cartpole_ppopg_config.py
    │   │   │   ├── cartpole_qrdqn_config.py
    │   │   │   ├── cartpole_qrdqn_generation_data_config.py
    │   │   │   ├── cartpole_r2d2_config.py
    │   │   │   ├── cartpole_r2d2_gtrxl_config.py
    │   │   │   ├── cartpole_r2d2_residual_config.py
    │   │   │   ├── cartpole_rainbow_config.py
    │   │   │   ├── cartpole_rnd_onppo_config.py
    │   │   │   ├── cartpole_sac_config.py
    │   │   │   ├── cartpole_sqil_config.py
    │   │   │   ├── cartpole_sql_config.py
    │   │   │   ├── cartpole_sqn_config.py
    │   │   │   ├── cartpole_trex_dqn_config.py
    │   │   │   ├── cartpole_trex_offppo_config.py
    │   │   │   ├── cartpole_trex_onppo_config.py
    │   │   │   └── parallel
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── cartpole_dqn_config.py
    │   │   │   │   ├── cartpole_dqn_config_k8s.py
    │   │   │   │   └── cartpole_dqn_dist.sh
    │   │   ├── entry
    │   │   │   ├── __init__.py
    │   │   │   ├── cartpole_c51_deploy.py
    │   │   │   ├── cartpole_c51_main.py
    │   │   │   ├── cartpole_cql_main.py
    │   │   │   ├── cartpole_dqn_buffer_main.py
    │   │   │   ├── cartpole_dqn_eval.py
    │   │   │   ├── cartpole_dqn_main.py
    │   │   │   ├── cartpole_dqn_pwil_main.py
    │   │   │   ├── cartpole_fqf_main.py
    │   │   │   ├── cartpole_ppg_main.py
    │   │   │   ├── cartpole_ppo_main.py
    │   │   │   └── cartpole_ppo_offpolicy_main.py
    │   │   └── envs
    │   │   │   ├── __init__.py
    │   │   │   ├── cartpole_env.py
    │   │   │   ├── test_cartpole_env.py
    │   │   │   └── test_cartpole_env_manager.py
    │   ├── mountain_car
    │   │   ├── __init__.py
    │   │   ├── config
    │   │   │   └── mtcar_rainbow_config.py
    │   │   └── envs
    │   │   │   ├── __init__.py
    │   │   │   ├── mtcar_env.py
    │   │   │   └── test_mtcar_env.py
    │   └── pendulum
    │   │   ├── __init__.py
    │   │   ├── config
    │   │       ├── __init__.py
    │   │       ├── mbrl
    │   │       │   ├── pendulum_mbsac_ddppo_config.py
    │   │       │   ├── pendulum_mbsac_mbpo_config.py
    │   │       │   ├── pendulum_sac_ddppo_config.py
    │   │       │   ├── pendulum_sac_mbpo_config.py
    │   │       │   └── pendulum_stevesac_mbpo_config.py
    │   │       ├── pendulum_a2c_config.py
    │   │       ├── pendulum_bdq_config.py
    │   │       ├── pendulum_cql_config.py
    │   │       ├── pendulum_d4pg_config.py
    │   │       ├── pendulum_ddpg_config.py
    │   │       ├── pendulum_dqn_config.py
    │   │       ├── pendulum_ibc_config.py
    │   │       ├── pendulum_pg_config.py
    │   │       ├── pendulum_ppo_config.py
    │   │       ├── pendulum_sac_config.py
    │   │       ├── pendulum_sac_data_generation_config.py
    │   │       ├── pendulum_sqil_sac_config.py
    │   │       ├── pendulum_td3_bc_config.py
    │   │       ├── pendulum_td3_config.py
    │   │       └── pendulum_td3_data_generation_config.py
    │   │   ├── entry
    │   │       ├── __init__.py
    │   │       ├── pendulum_cql_ddpg_main.py
    │   │       ├── pendulum_cql_main.py
    │   │       ├── pendulum_d4pg_main.py
    │   │       ├── pendulum_ddpg_main.py
    │   │       ├── pendulum_dqn_eval.py
    │   │       ├── pendulum_ppo_main.py
    │   │       ├── pendulum_td3_bc_main.py
    │   │       └── pendulum_td3_main.py
    │   │   ├── envs
    │   │       ├── __init__.py
    │   │       ├── pendulum_env.py
    │   │       └── test_pendulum_env.py
    │   │   └── pendulum.gif
    ├── cliffwalking
    │   ├── __init__.py
    │   ├── cliff_walking.gif
    │   ├── config
    │   │   └── cliffwalking_dqn_config.py
    │   ├── entry
    │   │   ├── cliffwalking_dqn_deploy.py
    │   │   └── cliffwalking_dqn_main.py
    │   └── envs
    │   │   ├── __init__.py
    │   │   ├── cliffwalking_env.py
    │   │   └── test_cliffwalking_env.py
    ├── common
    │   ├── __init__.py
    │   └── policy
    │   │   ├── __init__.py
    │   │   ├── md_dqn.py
    │   │   ├── md_ppo.py
    │   │   └── md_rainbow_dqn.py
    ├── competitive_rl
    │   ├── README.md
    │   ├── __init__.py
    │   ├── competitive_rl.gif
    │   ├── config
    │   │   └── cpong_dqn_config.py
    │   └── envs
    │   │   ├── __init__.py
    │   │   ├── competitive_rl_env.py
    │   │   ├── competitive_rl_env_wrapper.py
    │   │   ├── resources
    │   │       └── pong
    │   │       │   ├── checkpoint-alphapong.pkl
    │   │       │   ├── checkpoint-medium.pkl
    │   │       │   ├── checkpoint-strong.pkl
    │   │       │   └── checkpoint-weak.pkl
    │   │   └── test_competitive_rl.py
    ├── d4rl
    │   ├── __init__.py
    │   ├── config
    │   │   ├── __init__.py
    │   │   ├── antmaze_umaze_pd_config.py
    │   │   ├── halfcheetah_expert_cql_config.py
    │   │   ├── halfcheetah_expert_dt_config.py
    │   │   ├── halfcheetah_expert_td3bc_config.py
    │   │   ├── halfcheetah_medium_bcq_config.py
    │   │   ├── halfcheetah_medium_cql_config.py
    │   │   ├── halfcheetah_medium_dt_config.py
    │   │   ├── halfcheetah_medium_edac_config.py
    │   │   ├── halfcheetah_medium_expert_bcq_config.py
    │   │   ├── halfcheetah_medium_expert_cql_config.py
    │   │   ├── halfcheetah_medium_expert_dt_config.py
    │   │   ├── halfcheetah_medium_expert_edac_config.py
    │   │   ├── halfcheetah_medium_expert_iql_config.py
    │   │   ├── halfcheetah_medium_expert_pd_config.py
    │   │   ├── halfcheetah_medium_expert_qgpo_config.py
    │   │   ├── halfcheetah_medium_expert_td3bc_config.py
    │   │   ├── halfcheetah_medium_iql_config.py
    │   │   ├── halfcheetah_medium_pd_config.py
    │   │   ├── halfcheetah_medium_replay_cql_config.py
    │   │   ├── halfcheetah_medium_replay_dt_config.py
    │   │   ├── halfcheetah_medium_replay_iql_config.py
    │   │   ├── halfcheetah_medium_replay_td3bc_config.py
    │   │   ├── halfcheetah_medium_td3bc_config.py
    │   │   ├── halfcheetah_random_cql_config.py
    │   │   ├── halfcheetah_random_dt_config.py
    │   │   ├── halfcheetah_random_td3bc_config.py
    │   │   ├── hopper_expert_cql_config.py
    │   │   ├── hopper_expert_dt_config.py
    │   │   ├── hopper_expert_td3bc_config.py
    │   │   ├── hopper_medium_bcq_config.py
    │   │   ├── hopper_medium_cql_config.py
    │   │   ├── hopper_medium_dt_config.py
    │   │   ├── hopper_medium_edac_config.py
    │   │   ├── hopper_medium_expert_bc_config.py
    │   │   ├── hopper_medium_expert_bcq_config.py
    │   │   ├── hopper_medium_expert_cql_config.py
    │   │   ├── hopper_medium_expert_dt_config.py
    │   │   ├── hopper_medium_expert_edac_config.py
    │   │   ├── hopper_medium_expert_ibc_ar_config.py
    │   │   ├── hopper_medium_expert_ibc_config.py
    │   │   ├── hopper_medium_expert_ibc_mcmc_config.py
    │   │   ├── hopper_medium_expert_iql_config.py
    │   │   ├── hopper_medium_expert_pd_config.py
    │   │   ├── hopper_medium_expert_qgpo_config.py
    │   │   ├── hopper_medium_expert_td3bc_config.py
    │   │   ├── hopper_medium_iql_config.py
    │   │   ├── hopper_medium_pd_config.py
    │   │   ├── hopper_medium_replay_cql_config.py
    │   │   ├── hopper_medium_replay_dt_config.py
    │   │   ├── hopper_medium_replay_iql_config.py
    │   │   ├── hopper_medium_replay_td3bc_config.py
    │   │   ├── hopper_medium_td3bc_config.py
    │   │   ├── hopper_random_cql_config.py
    │   │   ├── hopper_random_dt_config.py
    │   │   ├── hopper_random_td3bc_config.py
    │   │   ├── kitchen_complete_bc_config.py
    │   │   ├── kitchen_complete_ibc_ar_config.py
    │   │   ├── kitchen_complete_ibc_config.py
    │   │   ├── kitchen_complete_ibc_mcmc_config.py
    │   │   ├── maze2d_large_pd_config.py
    │   │   ├── maze2d_medium_pd_config.py
    │   │   ├── maze2d_umaze_pd_config.py
    │   │   ├── pen_human_bc_config.py
    │   │   ├── pen_human_ibc_ar_config.py
    │   │   ├── pen_human_ibc_config.py
    │   │   ├── pen_human_ibc_mcmc_config.py
    │   │   ├── walker2d_expert_cql_config.py
    │   │   ├── walker2d_expert_dt_config.py
    │   │   ├── walker2d_expert_td3bc_config.py
    │   │   ├── walker2d_medium_cql_config.py
    │   │   ├── walker2d_medium_dt_config.py
    │   │   ├── walker2d_medium_expert_cql_config.py
    │   │   ├── walker2d_medium_expert_dt_config.py
    │   │   ├── walker2d_medium_expert_iql_config.py
    │   │   ├── walker2d_medium_expert_pd_config.py
    │   │   ├── walker2d_medium_expert_qgpo_config.py
    │   │   ├── walker2d_medium_expert_td3bc_config.py
    │   │   ├── walker2d_medium_iql_config.py
    │   │   ├── walker2d_medium_pd_config.py
    │   │   ├── walker2d_medium_replay_cql_config.py
    │   │   ├── walker2d_medium_replay_dt_config.py
    │   │   ├── walker2d_medium_replay_iql_config.py
    │   │   ├── walker2d_medium_replay_td3bc_config.py
    │   │   ├── walker2d_medium_td3bc_config.py
    │   │   ├── walker2d_random_cql_config.py
    │   │   ├── walker2d_random_dt_config.py
    │   │   └── walker2d_random_td3bc_config.py
    │   ├── d4rl.gif
    │   ├── entry
    │   │   ├── __init__.py
    │   │   ├── d4rl_bcq_main.py
    │   │   ├── d4rl_cql_main.py
    │   │   ├── d4rl_dt_mujoco.py
    │   │   ├── d4rl_edac_main.py
    │   │   ├── d4rl_ibc_main.py
    │   │   ├── d4rl_iql_main.py
    │   │   ├── d4rl_pd_main.py
    │   │   └── d4rl_td3_bc_main.py
    │   └── envs
    │   │   ├── __init__.py
    │   │   ├── d4rl_env.py
    │   │   └── d4rl_wrappers.py
    ├── dmc2gym
    │   ├── __init__.py
    │   ├── config
    │   │   ├── cartpole_balance
    │   │   │   └── cartpole_balance_dreamer_config.py
    │   │   ├── cheetah_run
    │   │   │   └── cheetah_run_dreamer_config.py
    │   │   ├── dmc2gym_dreamer_config.py
    │   │   ├── dmc2gym_ppo_config.py
    │   │   ├── dmc2gym_sac_pixel_config.py
    │   │   ├── dmc2gym_sac_state_config.py
    │   │   └── walker_walk
    │   │   │   └── walker_walk_dreamer_config.py
    │   ├── dmc2gym_cheetah.png
    │   ├── entry
    │   │   ├── dmc2gym_onppo_main.py
    │   │   ├── dmc2gym_sac_pixel_main.py
    │   │   ├── dmc2gym_sac_state_main.py
    │   │   └── dmc2gym_save_replay_example.py
    │   └── envs
    │   │   ├── __init__.py
    │   │   ├── dmc2gym_env.py
    │   │   └── test_dmc2gym_env.py
    ├── evogym
    │   ├── __init__.py
    │   ├── config
    │   │   ├── bridgewalker_ddpg_config.py
    │   │   ├── carrier_ppo_config.py
    │   │   ├── walker_ddpg_config.py
    │   │   └── walker_ppo_config.py
    │   ├── entry
    │   │   └── walker_ppo_eval.py
    │   ├── envs
    │   │   ├── __init__.py
    │   │   ├── evogym_env.py
    │   │   ├── test
    │   │   │   ├── test_evogym_env.py
    │   │   │   └── visualize_simple_env.py
    │   │   └── world_data
    │   │   │   ├── carry_bot.json
    │   │   │   ├── simple_evironment.json
    │   │   │   └── speed_bot.json
    │   └── evogym.gif
    ├── frozen_lake
    │   ├── FrozenLake.gif
    │   ├── __init__.py
    │   ├── config
    │   │   ├── __init__.py
    │   │   └── frozen_lake_dqn_config.py
    │   └── envs
    │   │   ├── __init__.py
    │   │   ├── frozen_lake_env.py
    │   │   └── test_frozen_lake_env.py
    ├── gfootball
    │   ├── README.md
    │   ├── __init__.py
    │   ├── config
    │   │   ├── gfootball_counter_mappo_config.py
    │   │   ├── gfootball_counter_masac_config.py
    │   │   ├── gfootball_keeper_mappo_config.py
    │   │   └── gfootball_keeper_masac_config.py
    │   ├── entry
    │   │   ├── __init__.py
    │   │   ├── gfootball_bc_config.py
    │   │   ├── gfootball_bc_kaggle5th_main.py
    │   │   ├── gfootball_bc_rule_lt0_main.py
    │   │   ├── gfootball_bc_rule_main.py
    │   │   ├── gfootball_dqn_config.py
    │   │   ├── parallel
    │   │   │   ├── gfootball_il_parallel_config.py
    │   │   │   └── gfootball_ppo_parallel_config.py
    │   │   ├── show_dataset.py
    │   │   └── test_accuracy.py
    │   ├── envs
    │   │   ├── __init__.py
    │   │   ├── action
    │   │   │   ├── gfootball_action.py
    │   │   │   └── gfootball_action_runner.py
    │   │   ├── fake_dataset.py
    │   │   ├── gfootball_academy_env.py
    │   │   ├── gfootball_env.py
    │   │   ├── gfootballsp_env.py
    │   │   ├── obs
    │   │   │   ├── encoder.py
    │   │   │   ├── gfootball_obs.py
    │   │   │   └── gfootball_obs_runner.py
    │   │   ├── reward
    │   │   │   ├── gfootball_reward.py
    │   │   │   └── gfootball_reward_runner.py
    │   │   └── tests
    │   │   │   ├── test_env_gfootball.py
    │   │   │   └── test_env_gfootball_academy.py
    │   ├── gfootball.gif
    │   ├── model
    │   │   ├── __init__.py
    │   │   ├── bots
    │   │   │   ├── TamakEriFever
    │   │   │   │   ├── config.yaml
    │   │   │   │   ├── football
    │   │   │   │   │   └── util.py
    │   │   │   │   ├── football_ikki.py
    │   │   │   │   ├── handyrl_core
    │   │   │   │   │   ├── model.py
    │   │   │   │   │   └── util.py
    │   │   │   │   ├── readme.md
    │   │   │   │   ├── submission.py
    │   │   │   │   └── view_test.py
    │   │   │   ├── __init__.py
    │   │   │   ├── kaggle_5th_place_model.py
    │   │   │   └── rule_based_bot_model.py
    │   │   ├── conv1d
    │   │   │   ├── conv1d.py
    │   │   │   └── conv1d_default_config.py
    │   │   └── q_network
    │   │   │   ├── football_q_network.py
    │   │   │   ├── football_q_network_default_config.py
    │   │   │   └── tests
    │   │   │       └── test_football_model.py
    │   ├── policy
    │   │   ├── __init__.py
    │   │   └── ppo_lstm.py
    │   └── replay.py
    ├── gobigger_overview.gif
    ├── gym_anytrading
    │   ├── __init__.py
    │   ├── config
    │   │   ├── __init__.py
    │   │   └── stocks_dqn_config.py
    │   ├── envs
    │   │   ├── README.md
    │   │   ├── __init__.py
    │   │   ├── data
    │   │   │   └── README.md
    │   │   ├── position.png
    │   │   ├── profit.png
    │   │   ├── statemachine.png
    │   │   ├── stocks_env.py
    │   │   ├── test_stocks_env.py
    │   │   └── trading_env.py
    │   └── worker
    │   │   ├── __init__.py
    │   │   └── trading_serial_evaluator.py
    ├── gym_hybrid
    │   ├── __init__.py
    │   ├── config
    │   │   ├── __init__.py
    │   │   ├── gym_hybrid_ddpg_config.py
    │   │   ├── gym_hybrid_hppo_config.py
    │   │   ├── gym_hybrid_mpdqn_config.py
    │   │   └── gym_hybrid_pdqn_config.py
    │   ├── entry
    │   │   ├── __init__.py
    │   │   ├── gym_hybrid_ddpg_eval.py
    │   │   └── gym_hybrid_ddpg_main.py
    │   ├── envs
    │   │   ├── README.md
    │   │   ├── __init__.py
    │   │   ├── gym-hybrid
    │   │   │   ├── README.md
    │   │   │   ├── gym_hybrid
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── agents.py
    │   │   │   │   ├── bg.jpg
    │   │   │   │   ├── environments.py
    │   │   │   │   └── target.png
    │   │   │   ├── setup.py
    │   │   │   └── tests
    │   │   │   │   ├── hardmove.py
    │   │   │   │   ├── moving.py
    │   │   │   │   ├── record.py
    │   │   │   │   ├── render.py
    │   │   │   │   └── sliding.py
    │   │   ├── gym_hybrid_env.py
    │   │   └── test_gym_hybrid_env.py
    │   └── moving_v0.gif
    ├── gym_pybullet_drones
    │   ├── __init__.py
    │   ├── config
    │   │   ├── flythrugate_onppo_config.py
    │   │   └── takeoffaviary_onppo_config.py
    │   ├── entry
    │   │   ├── flythrugate_onppo_eval.py
    │   │   └── takeoffaviary_onppo_eval.py
    │   ├── envs
    │   │   ├── __init__.py
    │   │   ├── gym_pybullet_drones_env.py
    │   │   ├── test_ding_env.py
    │   │   └── test_ori_env.py
    │   └── gym_pybullet_drones.gif
    ├── gym_soccer
    │   ├── __init__.py
    │   ├── config
    │   │   └── gym_soccer_pdqn_config.py
    │   ├── envs
    │   │   ├── README.md
    │   │   ├── __init__.py
    │   │   ├── gym_soccer_env.py
    │   │   └── test_gym_soccer_env.py
    │   └── half_offensive.gif
    ├── image_classification
    │   ├── __init__.py
    │   ├── data
    │   │   ├── __init__.py
    │   │   ├── dataset.py
    │   │   └── sampler.py
    │   ├── entry
    │   │   ├── imagenet_res18_config.py
    │   │   └── imagenet_res18_main.py
    │   ├── imagenet.png
    │   └── policy
    │   │   ├── __init__.py
    │   │   └── policy.py
    ├── ising_env
    │   ├── __init__.py
    │   ├── config
    │   │   └── ising_mfq_config.py
    │   ├── entry
    │   │   └── ising_mfq_eval.py
    │   ├── envs
    │   │   ├── __init__.py
    │   │   ├── ising_model
    │   │   │   ├── Ising.py
    │   │   │   ├── __init__.py
    │   │   │   └── multiagent
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── core.py
    │   │   │   │   └── environment.py
    │   │   ├── ising_model_env.py
    │   │   └── test_ising_model_env.py
    │   └── ising_env.gif
    ├── league_demo
    │   ├── __init__.py
    │   ├── demo_league.py
    │   ├── game_env.py
    │   ├── league_demo.png
    │   ├── league_demo_collector.py
    │   ├── league_demo_ppo_config.py
    │   ├── league_demo_ppo_main.py
    │   ├── selfplay_demo_ppo_config.py
    │   └── selfplay_demo_ppo_main.py
    ├── mario
    │   ├── __init__.py
    │   ├── mario.gif
    │   ├── mario_dqn_config.py
    │   ├── mario_dqn_example.py
    │   └── mario_dqn_main.py
    ├── maze
    │   ├── __init__.py
    │   ├── config
    │   │   ├── maze_bc_config.py
    │   │   └── maze_pc_config.py
    │   ├── entry
    │   │   └── maze_bc_main.py
    │   └── envs
    │   │   ├── __init__.py
    │   │   ├── maze_env.py
    │   │   └── test_maze_env.py
    ├── metadrive
    │   ├── __init__.py
    │   ├── config
    │   │   ├── __init__.py
    │   │   ├── metadrive_onppo_config.py
    │   │   └── metadrive_onppo_eval_config.py
    │   ├── env
    │   │   ├── __init__.py
    │   │   ├── drive_env.py
    │   │   ├── drive_utils.py
    │   │   └── drive_wrapper.py
    │   └── metadrive_env.gif
    ├── minigrid
    │   ├── __init__.py
    │   ├── config
    │   │   ├── minigrid_dreamer_config.py
    │   │   ├── minigrid_icm_offppo_config.py
    │   │   ├── minigrid_icm_onppo_config.py
    │   │   ├── minigrid_ngu_config.py
    │   │   ├── minigrid_offppo_config.py
    │   │   ├── minigrid_onppo_config.py
    │   │   ├── minigrid_onppo_stdim_config.py
    │   │   ├── minigrid_r2d2_config.py
    │   │   └── minigrid_rnd_onppo_config.py
    │   ├── entry
    │   │   └── minigrid_onppo_main.py
    │   ├── envs
    │   │   ├── __init__.py
    │   │   ├── app_key_to_door_treasure.py
    │   │   ├── minigrid_env.py
    │   │   ├── minigrid_wrapper.py
    │   │   ├── noisy_tv.py
    │   │   └── test_minigrid_env.py
    │   ├── minigrid.gif
    │   └── utils
    │   │   └── eval.py
    ├── mujoco
    │   ├── __init__.py
    │   ├── addition
    │   │   └── install_mesa.sh
    │   ├── config
    │   │   ├── __init__.py
    │   │   ├── ant_ddpg_config.py
    │   │   ├── ant_gail_sac_config.py
    │   │   ├── ant_onppo_config.py
    │   │   ├── ant_ppo_config.py
    │   │   ├── ant_sac_config.py
    │   │   ├── ant_td3_config.py
    │   │   ├── ant_trex_onppo_config.py
    │   │   ├── ant_trex_sac_config.py
    │   │   ├── halfcheetah_bco_config.py
    │   │   ├── halfcheetah_bdq_config.py
    │   │   ├── halfcheetah_d4pg_config.py
    │   │   ├── halfcheetah_ddpg_config.py
    │   │   ├── halfcheetah_gail_sac_config.py
    │   │   ├── halfcheetah_gcl_sac_config.py
    │   │   ├── halfcheetah_onppo_config.py
    │   │   ├── halfcheetah_sac_config.py
    │   │   ├── halfcheetah_sqil_sac_config.py
    │   │   ├── halfcheetah_td3_config.py
    │   │   ├── halfcheetah_trex_onppo_config.py
    │   │   ├── halfcheetah_trex_sac_config.py
    │   │   ├── hopper_bco_config.py
    │   │   ├── hopper_bdq_config.py
    │   │   ├── hopper_cql_config.py
    │   │   ├── hopper_d4pg_config.py
    │   │   ├── hopper_ddpg_config.py
    │   │   ├── hopper_gail_sac_config.py
    │   │   ├── hopper_gcl_config.py
    │   │   ├── hopper_onppo_config.py
    │   │   ├── hopper_sac_config.py
    │   │   ├── hopper_sac_data_generation_config.py
    │   │   ├── hopper_sqil_sac_config.py
    │   │   ├── hopper_td3_bc_config.py
    │   │   ├── hopper_td3_config.py
    │   │   ├── hopper_td3_data_generation_config.py
    │   │   ├── hopper_trex_onppo_config.py
    │   │   ├── hopper_trex_sac_config.py
    │   │   ├── mbrl
    │   │   │   ├── halfcheetah_mbsac_mbpo_config.py
    │   │   │   ├── halfcheetah_sac_mbpo_config.py
    │   │   │   ├── halfcheetah_stevesac_mbpo_config.py
    │   │   │   ├── hopper_mbsac_mbpo_config.py
    │   │   │   ├── hopper_sac_mbpo_config.py
    │   │   │   ├── hopper_stevesac_mbpo_config.py
    │   │   │   ├── walker2d_mbsac_mbpo_config.py
    │   │   │   ├── walker2d_sac_mbpo_config.py
    │   │   │   └── walker2d_stevesac_mbpo_config.py
    │   │   ├── walker2d_d4pg_config.py
    │   │   ├── walker2d_ddpg_config.py
    │   │   ├── walker2d_gail_ddpg_config.py
    │   │   ├── walker2d_gail_sac_config.py
    │   │   ├── walker2d_gcl_config.py
    │   │   ├── walker2d_onppo_config.py
    │   │   ├── walker2d_sac_config.py
    │   │   ├── walker2d_sqil_sac_config.py
    │   │   ├── walker2d_td3_config.py
    │   │   ├── walker2d_trex_onppo_config.py
    │   │   └── walker2d_trex_sac_config.py
    │   ├── entry
    │   │   ├── __init__.py
    │   │   ├── mujoco_cql_generation_main.py
    │   │   ├── mujoco_cql_main.py
    │   │   ├── mujoco_d4pg_main.py
    │   │   ├── mujoco_ddpg_eval.py
    │   │   ├── mujoco_ddpg_main.py
    │   │   ├── mujoco_ppo_main.py
    │   │   └── mujoco_td3_bc_main.py
    │   ├── envs
    │   │   ├── __init__.py
    │   │   ├── mujoco_disc_env.py
    │   │   ├── mujoco_env.py
    │   │   ├── mujoco_gym_env.py
    │   │   ├── mujoco_wrappers.py
    │   │   └── test
    │   │   │   ├── test_mujoco_disc_env.py
    │   │   │   ├── test_mujoco_env.py
    │   │   │   └── test_mujoco_gym_env.py
    │   ├── example
    │   │   ├── mujoco_bc_main.py
    │   │   └── mujoco_sac.py
    │   └── mujoco.gif
    ├── multiagent_mujoco
    │   ├── README.md
    │   ├── __init__.py
    │   ├── config
    │   │   ├── ant_maddpg_config.py
    │   │   ├── ant_mappo_config.py
    │   │   ├── ant_masac_config.py
    │   │   ├── ant_matd3_config.py
    │   │   ├── halfcheetah_happo_config.py
    │   │   ├── halfcheetah_mappo_config.py
    │   │   └── walker2d_happo_config.py
    │   └── envs
    │   │   ├── __init__.py
    │   │   ├── assets
    │   │       ├── .gitignore
    │   │       ├── __init__.py
    │   │       ├── coupled_half_cheetah.xml
    │   │       ├── manyagent_ant.xml
    │   │       ├── manyagent_ant.xml.template
    │   │       ├── manyagent_ant__stage1.xml
    │   │       ├── manyagent_swimmer.xml.template
    │   │       ├── manyagent_swimmer__bckp2.xml
    │   │       └── manyagent_swimmer_bckp.xml
    │   │   ├── coupled_half_cheetah.py
    │   │   ├── manyagent_ant.py
    │   │   ├── manyagent_swimmer.py
    │   │   ├── mujoco_multi.py
    │   │   ├── multi_mujoco_env.py
    │   │   ├── multiagentenv.py
    │   │   └── obsk.py
    ├── overcooked
    │   ├── README.md
    │   ├── __init__.py
    │   ├── config
    │   │   ├── __init__.py
    │   │   └── overcooked_ppo_config.py
    │   ├── envs
    │   │   ├── __init__.py
    │   │   ├── overcooked_env.py
    │   │   └── test_overcooked_env.py
    │   └── overcooked.gif
    ├── petting_zoo
    │   ├── __init__.py
    │   ├── config
    │   │   ├── __init__.py
    │   │   ├── ptz_pistonball_qmix_config.py
    │   │   ├── ptz_simple_spread_atoc_config.py
    │   │   ├── ptz_simple_spread_collaq_config.py
    │   │   ├── ptz_simple_spread_coma_config.py
    │   │   ├── ptz_simple_spread_happo_config.py
    │   │   ├── ptz_simple_spread_maddpg_config.py
    │   │   ├── ptz_simple_spread_madqn_config.py
    │   │   ├── ptz_simple_spread_mappo_config.py
    │   │   ├── ptz_simple_spread_masac_config.py
    │   │   ├── ptz_simple_spread_qmix_config.py
    │   │   ├── ptz_simple_spread_qtran_config.py
    │   │   ├── ptz_simple_spread_vdn_config.py
    │   │   └── ptz_simple_spread_wqmix_config.py
    │   ├── entry
    │   │   └── ptz_simple_spread_eval.py
    │   ├── envs
    │   │   ├── __init__.py
    │   │   ├── petting_zoo_pistonball_env.py
    │   │   ├── petting_zoo_simple_spread_env.py
    │   │   ├── test_petting_zoo_pistonball_env.py
    │   │   └── test_petting_zoo_simple_spread_env.py
    │   └── petting_zoo_mpe_simple_spread.gif
    ├── pomdp
    │   ├── __init__.py
    │   ├── config
    │   │   ├── pomdp_dqn_config.py
    │   │   └── pomdp_ppo_config.py
    │   └── envs
    │   │   ├── __init__.py
    │   │   ├── atari_env.py
    │   │   ├── atari_wrappers.py
    │   │   └── test_atari_env.py
    ├── procgen
    │   ├── README.md
    │   ├── __init__.py
    │   ├── coinrun.gif
    │   ├── coinrun.png
    │   ├── coinrun_dqn.svg
    │   ├── coinrun_ppo.svg
    │   ├── config
    │   │   ├── __init__.py
    │   │   ├── bigfish_plr_config.py
    │   │   ├── bigfish_ppg_config.py
    │   │   ├── coinrun_dqn_config.py
    │   │   ├── coinrun_ppg_config.py
    │   │   ├── coinrun_ppo_config.py
    │   │   ├── maze_dqn_config.py
    │   │   ├── maze_ppg_config.py
    │   │   └── maze_ppo_config.py
    │   ├── entry
    │   │   └── coinrun_onppo_main.py
    │   ├── envs
    │   │   ├── __init__.py
    │   │   ├── procgen_env.py
    │   │   └── test_coinrun_env.py
    │   ├── maze.gif
    │   ├── maze.png
    │   └── maze_dqn.svg
    ├── pybullet
    │   ├── __init__.py
    │   ├── envs
    │   │   ├── __init__.py
    │   │   ├── pybullet_env.py
    │   │   └── pybullet_wrappers.py
    │   └── pybullet.gif
    ├── rocket
    │   ├── README.md
    │   ├── __init__.py
    │   ├── config
    │   │   ├── __init__.py
    │   │   ├── rocket_hover_ppo_config.py
    │   │   └── rocket_landing_ppo_config.py
    │   ├── entry
    │   │   ├── __init__.py
    │   │   ├── rocket_hover_onppo_main_v2.py
    │   │   ├── rocket_hover_ppo_main.py
    │   │   ├── rocket_landing_onppo_main_v2.py
    │   │   └── rocket_landing_ppo_main.py
    │   └── envs
    │   │   ├── __init__.py
    │   │   ├── rocket_env.py
    │   │   └── test_rocket_env.py
    ├── slime_volley
    │   ├── __init__.py
    │   ├── config
    │   │   ├── slime_volley_league_ppo_config.py
    │   │   └── slime_volley_ppo_config.py
    │   ├── entry
    │   │   ├── slime_volley_league_ppo_main.py
    │   │   └── slime_volley_selfplay_ppo_main.py
    │   ├── envs
    │   │   ├── __init__.py
    │   │   ├── slime_volley_env.py
    │   │   └── test_slime_volley_env.py
    │   └── slime_volley.gif
    ├── smac
    │   ├── README.md
    │   ├── __init__.py
    │   ├── config
    │   │   ├── smac_10m11m_mappo_config.py
    │   │   ├── smac_10m11m_masac_config.py
    │   │   ├── smac_25m_mappo_config.py
    │   │   ├── smac_25m_masac_config.py
    │   │   ├── smac_27m30m_mappo_config.py
    │   │   ├── smac_2c64zg_mappo_config.py
    │   │   ├── smac_2c64zg_masac_config.py
    │   │   ├── smac_2c64zg_qmix_config.py
    │   │   ├── smac_2s3z_qmix_config.py
    │   │   ├── smac_2s3z_qtran_config.py
    │   │   ├── smac_3m_masac_config.py
    │   │   ├── smac_3s5z_collaq_config.py
    │   │   ├── smac_3s5z_collaq_per_config.py
    │   │   ├── smac_3s5z_coma_config.py
    │   │   ├── smac_3s5z_madqn_config.py
    │   │   ├── smac_3s5z_mappo_config.py
    │   │   ├── smac_3s5z_masac_config.py
    │   │   ├── smac_3s5z_qmix_config.py
    │   │   ├── smac_3s5z_qtran_config.py
    │   │   ├── smac_3s5z_wqmix_config.py
    │   │   ├── smac_3s5zvs3s6z_madqn_config.py
    │   │   ├── smac_3s5zvs3s6z_mappo_config.py
    │   │   ├── smac_3s5zvs3s6z_masac_config.py
    │   │   ├── smac_5m6m_collaq_config.py
    │   │   ├── smac_5m6m_madqn_config.py
    │   │   ├── smac_5m6m_mappo_config.py
    │   │   ├── smac_5m6m_masac_config.py
    │   │   ├── smac_5m6m_qmix_config.py
    │   │   ├── smac_5m6m_qtran_config.py
    │   │   ├── smac_5m6m_wqmix_config.py
    │   │   ├── smac_8m9m_madqn_config.py
    │   │   ├── smac_8m9m_mappo_config.py
    │   │   ├── smac_8m9m_masac_config.py
    │   │   ├── smac_MMM2_collaq_config.py
    │   │   ├── smac_MMM2_coma_config.py
    │   │   ├── smac_MMM2_madqn_config.py
    │   │   ├── smac_MMM2_mappo_config.py
    │   │   ├── smac_MMM2_masac_config.py
    │   │   ├── smac_MMM2_qmix_config.py
    │   │   ├── smac_MMM2_wqmix_config.py
    │   │   ├── smac_MMM_collaq_config.py
    │   │   ├── smac_MMM_coma_config.py
    │   │   ├── smac_MMM_madqn_config.py
    │   │   ├── smac_MMM_mappo_config.py
    │   │   ├── smac_MMM_masac_config.py
    │   │   ├── smac_MMM_qmix_config.py
    │   │   ├── smac_MMM_qtran_config.py
    │   │   ├── smac_MMM_wqmix_config.py
    │   │   ├── smac_corridor_mappo_config.py
    │   │   └── smac_corridor_masac_config.py
    │   ├── envs
    │   │   ├── __init__.py
    │   │   ├── fake_smac_env.py
    │   │   ├── maps
    │   │   │   ├── README.md
    │   │   │   ├── SMAC_Maps
    │   │   │   │   ├── 10m_vs_11m.SC2Map
    │   │   │   │   ├── 1c3s5z.SC2Map
    │   │   │   │   ├── 25m.SC2Map
    │   │   │   │   ├── 27m_vs_30m.SC2Map
    │   │   │   │   ├── 2c_vs_64zg.SC2Map
    │   │   │   │   ├── 2m_vs_1z.SC2Map
    │   │   │   │   ├── 2s3z.SC2Map
    │   │   │   │   ├── 2s_vs_1sc.SC2Map
    │   │   │   │   ├── 3m.SC2Map
    │   │   │   │   ├── 3s5z.SC2Map
    │   │   │   │   ├── 3s5z_vs_3s6z.SC2Map
    │   │   │   │   ├── 3s_vs_3z.SC2Map
    │   │   │   │   ├── 3s_vs_4z.SC2Map
    │   │   │   │   ├── 3s_vs_5z.SC2Map
    │   │   │   │   ├── 5m_vs_6m.SC2Map
    │   │   │   │   ├── 6h_vs_8z.SC2Map
    │   │   │   │   ├── 8m.SC2Map
    │   │   │   │   ├── 8m_vs_9m.SC2Map
    │   │   │   │   ├── MMM.SC2Map
    │   │   │   │   ├── MMM2.SC2Map
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── bane_vs_bane.SC2Map
    │   │   │   │   ├── corridor.SC2Map
    │   │   │   │   ├── infestor_viper.SC2Map
    │   │   │   │   └── so_many_baneling.SC2Map
    │   │   │   ├── SMAC_Maps_two_player
    │   │   │   │   ├── 3m.SC2Map
    │   │   │   │   ├── 3s5z.SC2Map
    │   │   │   │   └── __init__.py
    │   │   │   └── __init__.py
    │   │   ├── smac_action.py
    │   │   ├── smac_env.py
    │   │   ├── smac_map.py
    │   │   ├── smac_reward.py
    │   │   └── test_smac_env.py
    │   ├── smac.gif
    │   └── utils
    │   │   └── eval.py
    ├── sokoban
    │   ├── __init__.py
    │   └── envs
    │   │   ├── __init__.py
    │   │   ├── sokoban_env.py
    │   │   ├── sokoban_wrappers.py
    │   │   └── test_sokoban_env.py
    ├── tabmwp
    │   ├── README.md
    │   ├── __init__.py
    │   ├── benchmark.png
    │   ├── config
    │   │   ├── tabmwp_awr_config.py
    │   │   └── tabmwp_pg_config.py
    │   ├── envs
    │   │   ├── __init__.py
    │   │   ├── tabmwp_env.py
    │   │   ├── test_tabmwp_env.py
    │   │   └── utils.py
    │   └── tabmwp.jpeg
    └── taxi
    │   ├── Taxi-v3_episode_0.gif
    │   ├── __init__.py
    │   ├── config
    │       ├── __init__.py
    │       └── taxi_dqn_config.py
    │   ├── entry
    │       └── taxi_dqn_deploy.py
    │   └── envs
    │       ├── __init__.py
    │       ├── taxi_env.py
    │       └── test_taxi_env.py
├── docker
    ├── Dockerfile.base
    ├── Dockerfile.env
    ├── Dockerfile.hpc
    └── Dockerfile.rpc
├── format.sh
├── pytest.ini
└── setup.py


/.coveragerc:
--------------------------------------------------------------------------------
 1 | [run]
 2 | concurrency = multiprocessing,thread
 3 | omit =
 4 |   ding/utils/slurm_helper.py
 5 |   ding/utils/file_helper.py
 6 |   ding/utils/linklink_dist_helper.py
 7 |   ding/utils/pytorch_ddp_dist_helper.py
 8 |   ding/utils/k8s_helper.py
 9 |   ding/utils/tests/test_k8s_launcher.py
10 |   ding/utils/time_helper_cuda.py
11 |   ding/utils/time_helper_base.py
12 |   ding/utils/data/tests/test_dataloader.py
13 |   ding/config/utils.py
14 |   ding/entry/tests/test_serial_entry_algo.py
15 |   ding/entry/tests/test_serial_entry.py
16 |   ding/entry/dist_entry.py
17 |   ding/entry/cli.py
18 |   ding/entry/predefined_config.py
19 |   ding/hpc_rl/*
20 |   ding/worker/collector/tests/speed_test/*
21 |   ding/envs/env_wrappers/env_wrappers.py
22 |   ding/envs/env_manager/tests/test_env_supervisor.py
23 |   ding/rl_utils/tests/test_ppg.py
24 |   ding/scripts/tests/test_parallel_socket.py
25 |   ding/data/buffer/tests/test_benchmark.py
26 |   ding/example/*
27 |   ding/torch_utils/loss/tests/*
28 | 


--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore=F401,F841,F403,E226,E126,W504,E265,E722,W503,W605,E741,E122,E731
3 | max-line-length=120
4 | statistics
5 | 


--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/custom.md:
--------------------------------------------------------------------------------
 1 | ---
 2 | name: Custom issue template
 3 | about: Describe this issue template's purpose here.
 4 | title: ''
 5 | labels: ''
 6 | assignees: ''
 7 | 
 8 | ---
 9 | 
10 | - [ ] I have marked all applicable categories:
11 |     + [ ] exception-raising bug
12 |     + [ ] RL algorithm bug
13 |     + [ ] system worker bug
14 |     + [ ] system utils bug
15 |     + [ ] code design/refactor
16 |     + [ ] documentation request
17 |     + [ ] new feature request
18 | - [ ] I have visited the [readme](https://github.com/opendilab/DI-engine/blob/github-dev/README.md) and [doc](https://opendilab.github.io/DI-engine/)
19 | - [ ] I have searched through the [issue tracker](https://github.com/opendilab/DI-engine/issues) and [pr tracker](https://github.com/opendilab/DI-engine/pulls)
20 | - [ ] I have mentioned version numbers, operating system and environment, where applicable:
21 |   ```python
22 |   import ding, torch, sys
23 |   print(ding.__version__, torch.__version__, sys.version, sys.platform)
24 |   ```
25 | 


--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
 1 | ## Description
 2 | 
 3 | 
 4 | ## Related Issue
 5 | 
 6 | 
 7 | ## TODO
 8 | 
 9 | 
10 | ## Check List
11 | 
12 | - [ ] merge the latest version of the source branch/repo and resolve all conflicts
13 | - [ ] pass style check
14 | - [ ] pass all the tests
15 | 


--------------------------------------------------------------------------------
/.github/workflows/algo_test.yml:
--------------------------------------------------------------------------------
 1 | # This workflow runs the algorithm tests (`make algotest`)
 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
 3 | 
 4 | name: algo_test
 5 | 
 6 | on:
 7 |   push:
 8 |     paths:
 9 |       - "ding/policy/*"
10 |       - "ding/model/*"
11 |       - "ding/rl_utils/*"
12 | 
13 | jobs:
14 |   test_algotest:
15 |     runs-on: ubuntu-latest
16 |     if: "!contains(github.event.head_commit.message, 'ci skip')"
17 |     strategy:
18 |       matrix:
19 |         python-version: [3.8, 3.9]
20 | 
21 |     steps:
22 |       - uses: actions/checkout@v4
23 |       - name: Set up Python ${{ matrix.python-version }}
24 |         uses: actions/setup-python@v5
25 |         with:
26 |           python-version: ${{ matrix.python-version }}
27 |       - name: do_algotest
28 |         env:
29 |           WORKERS: 4
30 |           DURATIONS: 600
31 |         run: |
32 |           python -m pip install .
33 |           python -m pip install ".[test,k8s]"
34 |           python -m pip install transformers
35 |           ./ding/scripts/install-k8s-tools.sh
36 |           make algotest
37 | 


--------------------------------------------------------------------------------
/.github/workflows/envpool_test.yml:
--------------------------------------------------------------------------------
 1 | # This workflow runs the EnvPool tests (`make envpooltest`)
 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
 3 | 
 4 | name: envpool_test
 5 | 
 6 | on: [push, pull_request]
 7 | 
 8 | jobs:
 9 |   test_envpooltest:
10 |     runs-on: ubuntu-latest
11 |     if: "!contains(github.event.head_commit.message, 'ci skip')"
12 |     strategy:
13 |       matrix:
14 |         python-version: [3.8] # Envpool only supports python>=3.7
15 | 
16 |     steps:
17 |       - uses: actions/checkout@v4
18 |       - name: Set up Python ${{ matrix.python-version }}
19 |         uses: actions/setup-python@v5
20 |         with:
21 |           python-version: ${{ matrix.python-version }}
22 |       - name: do_envpool_test
23 |         run: |
24 |           python -m pip install .
25 |           python -m pip install ".[test,k8s]"
26 |           python -m pip install ".[envpool]"
27 |           python -m pip install transformers
28 |           ./ding/scripts/install-k8s-tools.sh
29 |           make envpooltest
30 | 


--------------------------------------------------------------------------------
/.github/workflows/platform_test.yml:
--------------------------------------------------------------------------------
 1 | # This workflow runs the cross-platform (macOS/Windows) tests (`make platformtest`)
 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
 3 | 
 4 | name: platform_test
 5 | 
 6 | on: [push, pull_request]
 7 | 
 8 | jobs:
 9 |   test_unittest:
10 |     runs-on: ${{ matrix.os }}
11 |     if: "!contains(github.event.head_commit.message, 'ci skip')"
12 |     strategy:
13 |       matrix:
14 |         os: [macos-13, windows-latest]
15 |         python-version: [3.8, 3.9]
16 | 
17 |     steps:
18 |       - uses: actions/checkout@v4
19 |       - name: Set up Python ${{ matrix.python-version }}
20 |         uses: actions/setup-python@v5
21 |         with:
22 |           python-version: ${{ matrix.python-version }}
23 |       - name: do_platform_test
24 |         timeout-minutes: 30
25 |         run: |
26 |           python -m pip install .
27 |           python -m pip install ".[test,k8s]"
28 |           python -m pip install transformers
29 |           python -m pip uninstall pytest-timeouts -y
30 |           make platformtest
31 | 


--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
 1 | name: package_release
 2 | 
 3 | on: [push]
 4 | 
 5 | jobs:
 6 |   release:
 7 |     name: Publish to official pypi
 8 |     runs-on: ${{ matrix.os }}
 9 |     if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
10 |     strategy:
11 |       matrix:
12 |         os:
13 |           - ubuntu-latest
14 |         python-version: [3.8]
15 | 
16 |     steps:
17 |       - name: Checkout code
18 |         uses: actions/checkout@v4
19 |       - name: Set up python ${{ matrix.python-version }}
20 |         uses: actions/setup-python@v5
21 |         with:
22 |           python-version: ${{ matrix.python-version }}
23 |       - name: Set up python dependences
24 |         run: |
25 |           pip install --upgrade pip
26 |           pip install --upgrade flake8 setuptools wheel twine
27 |           pip install .
28 |           pip install --upgrade build
29 |       - name: Build packages
30 |         run: |
31 |           python -m build --sdist --wheel --outdir dist/
32 |       - name: Publish package
33 |         uses: pypa/gh-action-pypi-publish@release/v1
34 |         with:
35 |           user: __token__
36 |           password: ${{ secrets.PYPI_API_TOKEN }}
37 | 


--------------------------------------------------------------------------------
/.github/workflows/release_conda.yml:
--------------------------------------------------------------------------------
 1 | name: package_release_conda
 2 | 
 3 | on: [push]
 4 |     
 5 | jobs:
 6 |   release:
 7 |     runs-on: ubuntu-latest
 8 |     if: github.event_name == 'push' && (startsWith(github.ref, 'refs/tags') || contains(github.event.head_commit.message, 'conda'))
 9 |     steps:
10 |     - uses: actions/checkout@v4
11 |     - name: publish-to-conda
12 |       uses: fcakyon/conda-publish-action@v1.3
13 |       with:
14 |         subdir: 'conda'
15 |         anacondatoken: ${{ secrets.ANACONDA_TOKEN }}
16 |         platforms: 'win osx linux'
17 | 


--------------------------------------------------------------------------------
/.github/workflows/style.yml:
--------------------------------------------------------------------------------
 1 | # This workflow will check flake style
 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
 3 | 
 4 | name: style
 5 | 
 6 | on: [push, pull_request]
 7 | 
 8 | jobs:
 9 |   style:
10 |     runs-on: ubuntu-latest
11 |     strategy:
12 |       matrix:
13 |         python-version: ['3.8', '3.9', '3.10']
14 | 
15 |     steps:
16 |       - uses: actions/checkout@v2
17 |       - name: Set up Python ${{ matrix.python-version }}
18 |         uses: actions/setup-python@v5
19 |         with:
20 |           python-version: ${{ matrix.python-version }}
21 |       - name: code style
22 |         run: |
23 |           python -m pip install "yapf==0.29.0"
24 |           python -m pip install "flake8<=3.9.2"
25 |           python -m pip install "importlib-metadata<5.0.0"
26 |           make format_test flake_check
27 | 


--------------------------------------------------------------------------------
/.style.yapf:
--------------------------------------------------------------------------------
 1 | [style]
 2 | # For explanation and more information: https://github.com/google/yapf
 3 | BASED_ON_STYLE=pep8
 4 | DEDENT_CLOSING_BRACKETS=True
 5 | SPLIT_BEFORE_FIRST_ARGUMENT=True
 6 | ALLOW_SPLIT_BEFORE_DICT_VALUE=False
 7 | JOIN_MULTIPLE_LINES=False
 8 | COLUMN_LIMIT=120
 9 | BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF=True
10 | BLANK_LINES_AROUND_TOP_LEVEL_DEFINITION=2
11 | SPACES_AROUND_POWER_OPERATOR=True
12 | 


--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | [Git Guide](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/git_guide.html)
2 | 
3 | [GitHub Cooperation Guide](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/issue_pr.html)
4 | 
5 |   - [Code Style](https://di-engine-docs.readthedocs.io/en/latest/21_code_style/index.html)
6 |   - [Unit Test](https://di-engine-docs.readthedocs.io/en/latest/22_test/index.html)
7 |   - [Code Review](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/issue_pr.html#pr-s-code-review)
8 | 


--------------------------------------------------------------------------------
/assets/wechat.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/assets/wechat.jpeg


--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | coverage:
2 |   status:
3 |     project:
4 |       default:
5 |         # basic
6 |         target: auto
7 |         threshold: 0.5%
 8 |         if_ci_failed: success # one of: success, failure, error, ignore
9 | 


--------------------------------------------------------------------------------
/conda/conda_build_config.yaml:
--------------------------------------------------------------------------------
1 | python:
2 |   - 3.7
3 | 


--------------------------------------------------------------------------------
/conda/meta.yaml:
--------------------------------------------------------------------------------
 1 | {% set data = load_setup_py_data() %}
 2 | package:
 3 |   name: di-engine
 4 |   version: v0.5.3
 5 | 
 6 | source:
 7 |   path: ..
 8 | 
 9 | build:
10 |   number: 0
11 |   script: python -m pip install . -vv
12 |   entry_points:
13 |     - ding = ding.entry.cli:cli
14 | 
15 | requirements:
16 |   build:
17 |     - python
18 |     - setuptools
19 |   run:
20 |     - python
21 | 
22 | test:
23 |   imports:
24 |     - ding
25 |     - dizoo
26 | 
27 | about:
28 |   home: https://github.com/opendilab/DI-engine
29 |   license: Apache-2.0
30 |   license_file: LICENSE
31 |   summary: DI-engine is a generalized Decision Intelligence engine (https://github.com/opendilab/DI-engine).
32 |   description: Please refer to https://di-engine-docs.readthedocs.io/en/latest/00_intro/index.html#what-is-di-engine
33 |   dev_url: https://github.com/opendilab/DI-engine
34 |   doc_url: https://di-engine-docs.readthedocs.io/en/latest/index.html
35 |   doc_source_url: https://github.com/opendilab/DI-engine-docs
36 | 


--------------------------------------------------------------------------------
/ding/__init__.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | __TITLE__ = 'DI-engine'
 4 | __VERSION__ = 'v0.5.3'
 5 | __DESCRIPTION__ = 'Decision AI Engine'
 6 | __AUTHOR__ = "OpenDILab Contributors"
 7 | __AUTHOR_EMAIL__ = "opendilab@pjlab.org.cn"
 8 | __version__ = __VERSION__
 9 | 
10 | enable_hpc_rl = os.environ.get('ENABLE_DI_HPC', 'false').lower() == 'true'
11 | enable_linklink = os.environ.get('ENABLE_LINKLINK', 'false').lower() == 'true'
12 | enable_numba = True
13 | 


--------------------------------------------------------------------------------
/ding/bonus/common.py:
--------------------------------------------------------------------------------
 1 | from dataclasses import dataclass
 2 | import numpy as np
 3 | 
 4 | 
 5 | @dataclass
 6 | class TrainingReturn:
 7 |     '''
 8 |     Attributions
 9 |     wandb_url: The weight & biases (wandb) project url of the trainning experiment.
10 |     '''
11 |     wandb_url: str
12 | 
13 | 
14 | @dataclass
15 | class EvalReturn:
16 |     '''
17 |     Attributions
18 |     eval_value: The mean of evaluation return.
19 |     eval_value_std: The standard deviation of evaluation return.
20 |     '''
21 |     eval_value: np.float32
22 |     eval_value_std: np.float32
23 | 


--------------------------------------------------------------------------------
/ding/compatibility.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | 
 4 | def torch_ge_131():
 5 |     return int("".join(list(filter(str.isdigit, torch.__version__)))) >= 131
 6 | 
 7 | 
 8 | def torch_ge_180():
 9 |     return int("".join(list(filter(str.isdigit, torch.__version__)))) >= 180
10 | 


--------------------------------------------------------------------------------
/ding/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import Config, read_config, save_config, compile_config, compile_config_parallel, read_config_directly, \
2 |     read_config_with_system, save_config_py
3 | from .utils import parallel_transform, parallel_transform_slurm
4 | from .example import A2C, C51, DDPG, DQN, PG, PPOF, PPOOffPolicy, SAC, SQL, TD3
5 | 


--------------------------------------------------------------------------------
/ding/config/example/A2C/__init__.py:
--------------------------------------------------------------------------------
 1 | from easydict import EasyDict
 2 | from . import gym_bipedalwalker_v3
 3 | from . import gym_lunarlander_v2
 4 | 
 5 | supported_env_cfg = {
 6 |     gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.cfg,
 7 |     gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg,
 8 | }
 9 | 
10 | supported_env_cfg = EasyDict(supported_env_cfg)
11 | 
12 | supported_env = {
13 |     gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.env,
14 |     gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env,
15 | }
16 | 
17 | supported_env = EasyDict(supported_env)
18 | 


--------------------------------------------------------------------------------
/ding/config/example/A2C/gym_lunarlander_v2.py:
--------------------------------------------------------------------------------
 1 | from easydict import EasyDict
 2 | import ding.envs.gym_env
 3 | 
 4 | cfg = dict(
 5 |     exp_name='LunarLander-v2-A2C',
 6 |     env=dict(
 7 |         collector_env_num=8,
 8 |         evaluator_env_num=8,
 9 |         env_id='LunarLander-v2',
10 |         n_evaluator_episode=8,
11 |         stop_value=260,
12 |     ),
13 |     policy=dict(
14 |         cuda=True,
15 |         model=dict(
16 |             obs_shape=8,
17 |             action_shape=4,
18 |         ),
19 |         learn=dict(
20 |             batch_size=64,
21 |             learning_rate=3e-4,
22 |             entropy_weight=0.001,
23 |             adv_norm=True,
24 |         ),
25 |         collect=dict(
26 |             n_sample=64,
27 |             discount_factor=0.99,
28 |             gae_lambda=0.95,
29 |         ),
30 |     ),
31 |     wandb_logger=dict(
32 |         gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
33 |     ),
34 | )
35 | 
36 | cfg = EasyDict(cfg)
37 | 
38 | env = ding.envs.gym_env.env
39 | 


--------------------------------------------------------------------------------
/ding/config/example/C51/__init__.py:
--------------------------------------------------------------------------------
 1 | from easydict import EasyDict
 2 | from . import gym_lunarlander_v2
 3 | from . import gym_pongnoframeskip_v4
 4 | from . import gym_qbertnoframeskip_v4
 5 | from . import gym_spaceInvadersnoframeskip_v4
 6 | 
 7 | supported_env_cfg = {
 8 |     gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg,
 9 |     gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.cfg,
10 |     gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.cfg,
11 |     gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.cfg,
12 | }
13 | 
14 | supported_env_cfg = EasyDict(supported_env_cfg)
15 | 
16 | supported_env = {
17 |     gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env,
18 |     gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.env,
19 |     gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.env,
20 |     gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.env,
21 | }
22 | 
23 | supported_env = EasyDict(supported_env)
24 | 


--------------------------------------------------------------------------------
/ding/config/example/DQN/__init__.py:
--------------------------------------------------------------------------------
 1 | from easydict import EasyDict
 2 | from . import gym_lunarlander_v2
 3 | from . import gym_pongnoframeskip_v4
 4 | from . import gym_qbertnoframeskip_v4
 5 | from . import gym_spaceInvadersnoframeskip_v4
 6 | 
 7 | supported_env_cfg = {
 8 |     gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg,
 9 |     gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.cfg,
10 |     gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.cfg,
11 |     gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.cfg,
12 | }
13 | 
14 | supported_env_cfg = EasyDict(supported_env_cfg)
15 | 
16 | supported_env = {
17 |     gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env,
18 |     gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.env,
19 |     gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.env,
20 |     gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.env,
21 | }
22 | 
23 | supported_env = EasyDict(supported_env)
24 | 


--------------------------------------------------------------------------------
/ding/config/example/PG/__init__.py:
--------------------------------------------------------------------------------
 1 | from easydict import EasyDict
 2 | from . import gym_pendulum_v1
 3 | 
 4 | supported_env_cfg = {
 5 |     gym_pendulum_v1.cfg.env.env_id: gym_pendulum_v1.cfg,
 6 | }
 7 | 
 8 | supported_env_cfg = EasyDict(supported_env_cfg)
 9 | 
10 | supported_env = {
11 |     gym_pendulum_v1.cfg.env.env_id: gym_pendulum_v1.env,
12 | }
13 | 
14 | supported_env = EasyDict(supported_env)
15 | 


--------------------------------------------------------------------------------
/ding/config/example/PG/gym_pendulum_v1.py:
--------------------------------------------------------------------------------
 1 | from easydict import EasyDict
 2 | import ding.envs.gym_env
 3 | 
 4 | cfg = dict(
 5 |     exp_name='Pendulum-v1-PG',
 6 |     seed=0,
 7 |     env=dict(
 8 |         env_id='Pendulum-v1',
 9 |         collector_env_num=8,
10 |         evaluator_env_num=5,
11 |         n_evaluator_episode=5,
12 |         stop_value=-200,
13 |         act_scale=True,
14 |     ),
15 |     policy=dict(
16 |         cuda=False,
17 |         action_space='continuous',
18 |         model=dict(
19 |             action_space='continuous',
20 |             obs_shape=3,
21 |             action_shape=1,
22 |         ),
23 |         learn=dict(
24 |             batch_size=4000,
25 |             learning_rate=0.001,
26 |             entropy_weight=0.001,
27 |         ),
28 |         collect=dict(
29 |             n_episode=20,
30 |             unroll_len=1,
31 |             discount_factor=0.99,
32 |         ),
33 |         eval=dict(evaluator=dict(eval_freq=1, ))
34 |     ),
35 |     wandb_logger=dict(
36 |         gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
37 |     ),
38 | )
39 | 
40 | cfg = EasyDict(cfg)
41 | 
42 | env = ding.envs.gym_env.env
43 | 


--------------------------------------------------------------------------------
/ding/config/example/PPOF/__init__.py:
--------------------------------------------------------------------------------
 1 | from easydict import EasyDict
 2 | from . import gym_lunarlander_v2
 3 | from . import gym_lunarlandercontinuous_v2
 4 | 
 5 | supported_env_cfg = {
 6 |     gym_lunarlander_v2.cfg.env_id: gym_lunarlander_v2.cfg,
 7 |     gym_lunarlandercontinuous_v2.cfg.env_id: gym_lunarlandercontinuous_v2.cfg,
 8 | }
 9 | 
10 | supported_env_cfg = EasyDict(supported_env_cfg)
11 | 
12 | supported_env = {
13 |     gym_lunarlander_v2.cfg.env_id: gym_lunarlander_v2.env,
14 |     gym_lunarlandercontinuous_v2.cfg.env_id: gym_lunarlandercontinuous_v2.env,
15 | }
16 | 
17 | supported_env = EasyDict(supported_env)
18 | 


--------------------------------------------------------------------------------
/ding/config/example/PPOF/gym_lunarlander_v2.py:
--------------------------------------------------------------------------------
 1 | from easydict import EasyDict
 2 | import ding.envs.gym_env
 3 | 
 4 | cfg = dict(
 5 |     exp_name='LunarLander-v2-PPO',
 6 |     env_id='LunarLander-v2',
 7 |     n_sample=400,
 8 |     value_norm='popart',
 9 | )
10 | 
11 | cfg = EasyDict(cfg)
12 | 
13 | env = ding.envs.gym_env.env
14 | 


--------------------------------------------------------------------------------
/ding/config/example/PPOF/gym_lunarlandercontinuous_v2.py:
--------------------------------------------------------------------------------
 1 | from easydict import EasyDict
 2 | from functools import partial
 3 | import ding.envs.gym_env
 4 | 
 5 | cfg = dict(
 6 |     exp_name='LunarLanderContinuous-V2-PPO',
 7 |     env_id='LunarLanderContinuous-v2',
 8 |     action_space='continuous',
 9 |     n_sample=400,
10 |     act_scale=True,
11 | )
12 | 
13 | cfg = EasyDict(cfg)
14 | 
15 | env = partial(ding.envs.gym_env.env, continuous=True)
16 | 


--------------------------------------------------------------------------------
/ding/config/example/PPOOffPolicy/__init__.py:
--------------------------------------------------------------------------------
 1 | from easydict import EasyDict
 2 | from . import gym_lunarlander_v2
 3 | from . import gym_pongnoframeskip_v4
 4 | from . import gym_qbertnoframeskip_v4
 5 | from . import gym_spaceInvadersnoframeskip_v4
 6 | 
 7 | supported_env_cfg = {
 8 |     gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg,
 9 |     gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.cfg,
10 |     gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.cfg,
11 |     gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.cfg,
12 | }
13 | 
14 | supported_env_cfg = EasyDict(supported_env_cfg)
15 | 
16 | supported_env = {
17 |     gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env,
18 |     gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.env,
19 |     gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.env,
20 |     gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.env,
21 | }
22 | 
23 | supported_env = EasyDict(supported_env)
24 | 


--------------------------------------------------------------------------------
/ding/config/example/SQL/__init__.py:
--------------------------------------------------------------------------------
 1 | from easydict import EasyDict
 2 | from . import gym_lunarlander_v2
 3 | 
 4 | supported_env_cfg = {
 5 |     gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg,
 6 | }
 7 | 
 8 | supported_env_cfg = EasyDict(supported_env_cfg)
 9 | 
10 | supported_env = {
11 |     gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env,
12 | }
13 | 
14 | supported_env = EasyDict(supported_env)
15 | 


--------------------------------------------------------------------------------
/ding/config/example/TD3/gym_hopper_v3.py:
--------------------------------------------------------------------------------
 1 | from easydict import EasyDict
 2 | import ding.envs.gym_env
 3 | 
 4 | cfg = dict(
 5 |     exp_name='Hopper-v3-TD3',
 6 |     seed=0,
 7 |     env=dict(
 8 |         env_id='Hopper-v3',
 9 |         collector_env_num=8,
10 |         evaluator_env_num=8,
11 |         n_evaluator_episode=8,
12 |         stop_value=6000,
13 |         env_wrapper='mujoco_default',
14 |     ),
15 |     policy=dict(
16 |         cuda=True,
17 |         random_collect_size=25000,
18 |         model=dict(
19 |             obs_shape=11,
20 |             action_shape=3,
21 |             actor_head_hidden_size=256,
22 |             critic_head_hidden_size=256,
23 |             action_space='regression',
24 |         ),
25 |         collect=dict(n_sample=1, ),
26 |         other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
27 |     ),
28 |     wandb_logger=dict(
29 |         gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
30 |     ),
31 | )
32 | 
33 | cfg = EasyDict(cfg)
34 | 
35 | env = ding.envs.gym_env.env
36 | 


--------------------------------------------------------------------------------
/ding/config/example/__init__.py:
--------------------------------------------------------------------------------
 1 | from . import A2C
 2 | from . import C51
 3 | from . import DDPG
 4 | from . import DQN
 5 | from . import PG
 6 | from . import PPOF
 7 | from . import PPOOffPolicy
 8 | from . import SAC
 9 | from . import SQL
10 | from . import TD3
11 | 


--------------------------------------------------------------------------------
/ding/data/__init__.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | from ding.utils.data import create_dataset, offline_data_save_type  # for compatibility
3 | from .buffer import *
4 | from .storage import *
5 | from .storage_loader import StorageLoader, FileStorageLoader
6 | from .shm_buffer import ShmBufferContainer, ShmBuffer
7 | from .model_loader import ModelLoader, FileModelLoader
8 | 


--------------------------------------------------------------------------------
/ding/data/buffer/__init__.py:
--------------------------------------------------------------------------------
1 | from .buffer import Buffer, apply_middleware, BufferedData
2 | from .deque_buffer import DequeBuffer
3 | from .deque_buffer_wrapper import DequeBufferWrapper
4 | 


--------------------------------------------------------------------------------
/ding/data/buffer/middleware/__init__.py:
--------------------------------------------------------------------------------
1 | from .clone_object import clone_object
2 | from .use_time_check import use_time_check
3 | from .staleness_check import staleness_check
4 | from .priority import PriorityExperienceReplay
5 | from .padding import padding
6 | from .group_sample import group_sample
7 | from .sample_range_view import sample_range_view
8 | 


--------------------------------------------------------------------------------
/ding/data/buffer/middleware/sample_range_view.py:
--------------------------------------------------------------------------------
 1 | from typing import Callable, Any, List, Optional, Union, TYPE_CHECKING
 2 | from ding.data.buffer import BufferedData
 3 | if TYPE_CHECKING:
 4 |     from ding.data.buffer.buffer import Buffer
 5 | 
 6 | 
 7 | def sample_range_view(buffer_: 'Buffer', start: Optional[int] = None, end: Optional[int] = None) -> Callable:
 8 |     """
 9 |     Overview:
10 |         The middleware that places restrictions on the range of indices during sampling.
11 |     Arguments:
12 |         - start (:obj:`int`): The starting index.
13 |         - end (:obj:`int`): One above the ending index.
14 |     """
15 |     assert start is not None or end is not None
16 |     if start and start < 0:
17 |         start = buffer_.size + start
18 |     if end and end < 0:
19 |         end = buffer_.size + end
20 |     sample_range = slice(start, end)
21 | 
22 |     def _sample_range_view(action: str, chain: Callable, *args, **kwargs) -> Any:
23 |         if action == "sample":
24 |             return chain(*args, sample_range=sample_range)
25 |         return chain(*args, **kwargs)
26 | 
27 |     return _sample_range_view
28 | 


--------------------------------------------------------------------------------
/ding/data/level_replay/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/data/level_replay/__init__.py


--------------------------------------------------------------------------------
/ding/data/storage/__init__.py:
--------------------------------------------------------------------------------
1 | from .storage import Storage
2 | from .file import FileStorage, FileModelStorage
3 | 


--------------------------------------------------------------------------------
/ding/data/storage/file.py:
--------------------------------------------------------------------------------
 1 | from typing import Any
 2 | from ding.data.storage import Storage
 3 | import pickle
 4 | 
 5 | from ding.utils.file_helper import read_file, save_file
 6 | 
 7 | 
 8 | class FileStorage(Storage):
 9 | 
10 |     def save(self, data: Any) -> None:
11 |         with open(self.path, "wb") as f:
12 |             pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
13 | 
14 |     def load(self) -> Any:
15 |         with open(self.path, "rb") as f:
16 |             return pickle.load(f)
17 | 
18 | 
19 | class FileModelStorage(Storage):
20 | 
21 |     def save(self, state_dict: object) -> None:
22 |         save_file(self.path, state_dict)
23 | 
24 |     def load(self) -> object:
25 |         return read_file(self.path)
26 | 


--------------------------------------------------------------------------------
/ding/data/storage/storage.py:
--------------------------------------------------------------------------------
 1 | from abc import ABC, abstractmethod
 2 | from typing import Any
 3 | 
 4 | 
 5 | class Storage(ABC):
 6 | 
 7 |     def __init__(self, path: str) -> None:
 8 |         self.path = path
 9 | 
10 |     @abstractmethod
11 |     def save(self, data: Any) -> None:
12 |         raise NotImplementedError
13 | 
14 |     @abstractmethod
15 |     def load(self) -> Any:
16 |         raise NotImplementedError
17 | 


--------------------------------------------------------------------------------
/ding/data/storage/tests/test_storage.py:
--------------------------------------------------------------------------------
 1 | import tempfile
 2 | import pytest
 3 | import os
 4 | from os import path
 5 | from ding.data.storage import FileStorage
 6 | 
 7 | 
 8 | @pytest.mark.unittest
 9 | def test_file_storage():
10 |     path_ = path.join(tempfile.gettempdir(), "test_storage.txt")
11 |     try:
12 |         storage = FileStorage(path=path_)
13 |         storage.save("test")
14 |         content = storage.load()
15 |         assert content == "test"
16 |     finally:
17 |         if path.exists(path_):
18 |             os.remove(path_)
19 | 


--------------------------------------------------------------------------------
/ding/data/tests/test_shm_buffer.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import numpy as np
 3 | import timeit
 4 | from ding.data.shm_buffer import ShmBuffer
 5 | import multiprocessing as mp
 6 | 
 7 | 
 8 | def subprocess(shm_buf):
 9 |     data = np.random.rand(1024, 1024).astype(np.float32)
10 |     res = timeit.repeat(lambda: shm_buf.fill(data), repeat=5, number=1000)
11 |     print("Mean: {:.4f}s, STD: {:.4f}s, Mean each call: {:.4f}ms".format(np.mean(res), np.std(res), np.mean(res)))
12 | 
13 | 
14 | @pytest.mark.benchmark
15 | def test_shm_buffer():
16 |     data = np.random.rand(1024, 1024).astype(np.float32)
17 |     shm_buf = ShmBuffer(data.dtype, data.shape, copy_on_get=False)
18 |     proc = mp.Process(target=subprocess, args=[shm_buf])
19 |     proc.start()
20 |     proc.join()
21 | 


--------------------------------------------------------------------------------
/ding/design/dataloader-sequence.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/design/dataloader-sequence.png


--------------------------------------------------------------------------------
/ding/design/env_state.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/design/env_state.png


--------------------------------------------------------------------------------
/ding/design/parallel_main-sequence.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/design/parallel_main-sequence.png


--------------------------------------------------------------------------------
/ding/design/serial_collector-activity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/design/serial_collector-activity.png


--------------------------------------------------------------------------------
/ding/design/serial_evaluator-activity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/design/serial_evaluator-activity.png


--------------------------------------------------------------------------------
/ding/design/serial_evaluator-activity.puml:
--------------------------------------------------------------------------------
 1 | @startuml serial_evaluator
 2 | header Serial Pipeline
 3 | title Serial Evaluator
 4 | 
 5 | |#99CCCC|serial_controller|
 6 | |#99CCFF|env_manager|
 7 | |#CCCCFF|policy|
 8 | |#FFCCCC|evaluator|
 9 | 
10 | |#99CCCC|serial_controller|
11 | start
12 | :init evaluator, set its \nenv_manager and \neval_mode policy;
13 | |#99CCFF|env_manager|
14 | repeat
15 |   :return current obs;
16 |   |#CCCCFF|policy|
17 |   :<b>[network]</b> forward with obs;
18 |   |#99CCFF|env_manager|
19 |   :env step with action;
20 |   |#FFCCCC|evaluator|
21 |   if (for every env: env i is done?) then (yes)
22 |     |#99CCFF|env_manager|
23 |     :env i reset;
24 |     |#FFCCCC|evaluator|
25 |     :log eval_episode_info;
26 |   endif
27 | repeat while (evaluated episodes are not enough?)
28 | |#FFCCCC|evaluator|
29 | :return eval_episode_return;
30 | stop
31 | @enduml
32 | 


--------------------------------------------------------------------------------
/ding/design/serial_learner-activity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/design/serial_learner-activity.png


--------------------------------------------------------------------------------
/ding/design/serial_learner-activity.puml:
--------------------------------------------------------------------------------
 1 | @startuml serial_learner
 2 | header Serial Pipeline
 3 | title Serial Learner
 4 | 
 5 | |#99CCCC|serial_controller|
 6 | |#CCCCFF|policy|
 7 | |#99CCFF|learner|
 8 | 
 9 | |#99CCCC|serial_controller|
10 | start
11 | :init learner, \nset its learn_mode policy;
12 | |#99CCFF|learner|
13 | :get data from buffer;
14 | |#CCCCFF|policy|
15 | :data forward;
16 | :loss backward;
17 | :optimizer step, gradient update;
18 | |#99CCFF|learner|
19 | :update train info (loss, value) and log;
20 | :update learn info (iteration, priority);
21 | stop
22 | @enduml
23 | 


--------------------------------------------------------------------------------
/ding/design/serial_main-sequence.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/design/serial_main-sequence.png


--------------------------------------------------------------------------------
/ding/entry/cli_parsers/__init__.py:
--------------------------------------------------------------------------------
1 | from .slurm_parser import slurm_parser
2 | from .k8s_parser import k8s_parser
3 | PLATFORM_PARSERS = {"slurm": slurm_parser, "k8s": k8s_parser}
4 | 


--------------------------------------------------------------------------------
/ding/entry/tests/config/agconfig.yaml:
--------------------------------------------------------------------------------
 1 | apiVersion: diengine.opendilab.org/v1alpha1
 2 | kind: AggregatorConfig
 3 | metadata:
 4 |   name: aggregator-config
 5 |   namespace: di-system
 6 | spec:
 7 |   aggregator:
 8 |     template:
 9 |       spec:
10 |         containers:
11 |         - name: di-container
12 |           image: diorchestrator/ding:v0.1.1
13 |           imagePullPolicy: IfNotPresent
14 |           env:
15 |           - name: PYTHONUNBUFFERED
16 |             value: "1"
17 |           command: ["/bin/bash", "-c",]
18 |           args:
19 |           - |
20 |             # if code has been changed in the mount path, we have to reinstall the cli
21 |             # pip install --no-cache-dir -e .;
22 |             # pip install --no-cache-dir -e .[common_env]
23 | 
24 |             ding -m dist --module learner_aggregator
25 |           ports:
26 |           - name: di-port
27 |             containerPort: 22270
28 | 


--------------------------------------------------------------------------------
/ding/entry/tests/config/k8s-config.yaml:
--------------------------------------------------------------------------------
1 | type: k3s # k3s or local
2 | name: di-cluster
3 | servers: 1 # number of k8s masters
4 | agents: 0 # number of k8s nodes
5 | preload_images:
6 | - diorchestrator/ding:v0.1.1 # di-engine image for training should be preloaded
7 | 


--------------------------------------------------------------------------------
/ding/entry/tests/test_cli_ditask.py:
--------------------------------------------------------------------------------
 1 | from time import sleep
 2 | import pytest
 3 | import pathlib
 4 | import os
 5 | from ding.entry.cli_ditask import _cli_ditask
 6 | 
 7 | 
 8 | def cli_ditask_main():
 9 |     sleep(0.1)
10 | 
11 | 
12 | @pytest.mark.unittest
13 | def test_cli_ditask():
14 |     kwargs = {
15 |         "package": os.path.dirname(pathlib.Path(__file__)),
16 |         "main": "test_cli_ditask.cli_ditask_main",
17 |         "parallel_workers": 1,
18 |         "topology": "mesh",
19 |         "platform": "k8s",
20 |         "protocol": "tcp",
21 |         "ports": 50501,
22 |         "attach_to": "",
23 |         "address": "127.0.0.1",
24 |         "labels": "",
25 |         "node_ids": 0,
26 |         "mq_type": "nng",
27 |         "redis_host": "",
28 |         "redis_port": "",
29 |         "startup_interval": 1
30 |     }
31 |     os.environ["DI_NODES"] = '127.0.0.1'
32 |     os.environ["DI_RANK"] = '0'
33 |     try:
34 |         _cli_ditask(**kwargs)
35 |     finally:
36 |         del os.environ["DI_NODES"]
37 |         del os.environ["DI_RANK"]
38 | 


--------------------------------------------------------------------------------
/ding/entry/tests/test_parallel_entry.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | from copy import deepcopy
 3 | from ding.entry import parallel_pipeline
 4 | from dizoo.classic_control.cartpole.config.parallel.cartpole_dqn_config import main_config, create_config,\
 5 |     system_config
 6 | 
 7 | 
 8 | # @pytest.mark.unittest
 9 | @pytest.mark.execution_timeout(120.0, method='thread')
10 | def test_dqn():
11 |     config = tuple([deepcopy(main_config), deepcopy(create_config), deepcopy(system_config)])
12 |     config[0].env.stop_value = 9
13 |     try:
14 |         parallel_pipeline(config, seed=0)
15 |     except Exception:
16 |         assert False, "pipeline fail"
17 | 


--------------------------------------------------------------------------------
/ding/entry/tests/test_serial_entry_guided_cost.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import torch
 3 | from copy import deepcopy
 4 | from ding.entry import serial_pipeline_onpolicy, serial_pipeline_guided_cost
 5 | from dizoo.classic_control.cartpole.config import cartpole_ppo_config, cartpole_ppo_create_config
 6 | from dizoo.classic_control.cartpole.config import cartpole_gcl_ppo_onpolicy_config, \
 7 |     cartpole_gcl_ppo_onpolicy_create_config
 8 | 
 9 | 
10 | @pytest.mark.unittest
11 | def test_guided_cost():
12 |     expert_policy_state_dict_path = './expert_policy.pth'
13 |     config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
14 |     expert_policy = serial_pipeline_onpolicy(config, seed=0)
15 |     torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)
16 | 
17 |     config = [deepcopy(cartpole_gcl_ppo_onpolicy_config), deepcopy(cartpole_gcl_ppo_onpolicy_create_config)]
18 |     config[0].policy.collect.model_path = expert_policy_state_dict_path
19 |     config[0].policy.learn.update_per_collect = 1
20 |     try:
21 |         serial_pipeline_guided_cost(config, seed=0, max_train_iter=1)
22 |     except Exception:
23 |         assert False, "pipeline fail"
24 | 


--------------------------------------------------------------------------------
/ding/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .env import *
2 | from .env_wrappers import *
3 | from .env_manager import *
4 | from .env_manager.ding_env_manager import setup_ding_env_manager
5 | from . import gym_env
6 | 


--------------------------------------------------------------------------------
/ding/envs/common/__init__.py:
--------------------------------------------------------------------------------
1 | from .common_function import num_first_one_hot, sqrt_one_hot, div_one_hot, div_func, clip_one_hot, \
2 |     reorder_one_hot, reorder_one_hot_array, reorder_boolean_vector, affine_transform, \
3 |     batch_binary_encode, get_postion_vector, save_frames_as_gif
4 | from .env_element import EnvElement, EnvElementInfo
5 | from .env_element_runner import EnvElementRunner
6 | 


--------------------------------------------------------------------------------
/ding/envs/common/env_element_runner.py:
--------------------------------------------------------------------------------
 1 | from abc import abstractmethod
 2 | from typing import Any
 3 | 
 4 | from .env_element import EnvElement, IEnvElement, EnvElementInfo
 5 | from ..env.base_env import BaseEnv
 6 | 
 7 | 
 8 | class IEnvElementRunner(IEnvElement):
 9 | 
10 |     @abstractmethod
11 |     def get(self, engine: BaseEnv) -> Any:
12 |         raise NotImplementedError
13 | 
14 |     @abstractmethod
15 |     def reset(self, *args, **kwargs) -> None:
16 |         raise NotImplementedError
17 | 
18 | 
19 | class EnvElementRunner(IEnvElementRunner):
20 | 
21 |     def __init__(self, *args, **kwargs) -> None:
22 |         self._init(*args, **kwargs)
23 |         self._check()
24 | 
25 |     @abstractmethod
26 |     def _init(self, *args, **kwargs) -> None:
27 |         # set self._core and other state variable
28 |         raise NotImplementedError
29 | 
30 |     def _check(self) -> None:
31 |         flag = [hasattr(self, '_core'), isinstance(self._core, EnvElement)]
32 |         assert all(flag), flag
33 | 
34 |     def __repr__(self) -> str:
35 |         return repr(self._core)
36 | 
37 |     @property
38 |     def info(self) -> 'EnvElementInfo':
39 |         return self._core.info
40 | 


--------------------------------------------------------------------------------
/ding/envs/env/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_env import BaseEnv, get_vec_env_setting, BaseEnvTimestep, get_env_cls, create_model_env
2 | from .ding_env_wrapper import DingEnvWrapper
3 | from .default_wrapper import get_default_wrappers
4 | from .env_implementation_check import check_space_dtype, check_array_space, check_reset, check_step, \
5 |     check_different_memory, check_obs_deepcopy, check_all, demonstrate_correct_procedure
6 | 


--------------------------------------------------------------------------------
/ding/envs/env/tests/__init__.py:
--------------------------------------------------------------------------------
1 | from .demo_env import DemoEnv
2 | 


--------------------------------------------------------------------------------
/ding/envs/env_manager/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_env_manager import BaseEnvManager, BaseEnvManagerV2, create_env_manager, get_env_manager_cls
2 | from .subprocess_env_manager import AsyncSubprocessEnvManager, SyncSubprocessEnvManager, SubprocessEnvManagerV2
3 | from .gym_vector_env_manager import GymVectorEnvManager
4 | # Do not import PoolEnvManager here, because it depends on installation of `envpool`
5 | from .env_supervisor import EnvSupervisor
6 | 


--------------------------------------------------------------------------------
/ding/envs/env_manager/ding_env_manager.py:
--------------------------------------------------------------------------------
 1 | from . import BaseEnvManagerV2, SubprocessEnvManagerV2
 2 | from ..env import DingEnvWrapper
 3 | from typing import Optional
 4 | from functools import partial
 5 | 
 6 | 
 7 | def setup_ding_env_manager(
 8 |         env: DingEnvWrapper,
 9 |         env_num: int,
10 |         context: Optional[str] = None,
11 |         debug: bool = False,
12 |         caller: str = 'collector'
13 | ) -> BaseEnvManagerV2:
14 |     assert caller in ['evaluator', 'collector']
15 |     if debug:
16 |         env_cls = BaseEnvManagerV2
17 |         manager_cfg = env_cls.default_config()
18 |     else:
19 |         env_cls = SubprocessEnvManagerV2
20 |         manager_cfg = env_cls.default_config()
21 |         if context is not None:
22 |             manager_cfg.context = context
23 |     return env_cls([partial(env.clone, caller) for _ in range(env_num)], manager_cfg)
24 | 


--------------------------------------------------------------------------------
/ding/envs/env_manager/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/envs/env_manager/tests/__init__.py


--------------------------------------------------------------------------------
/ding/envs/env_manager/tests/test_shm.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import time
 3 | import numpy as np
 4 | import torch
 5 | from multiprocessing import Process
 6 | 
 7 | from ding.envs.env_manager.subprocess_env_manager import ShmBuffer
 8 | 
 9 | 
10 | def writer(shm):
11 |     while True:
12 |         shm.fill(np.random.random(size=(4, 84, 84)).astype(np.float32))
13 |         time.sleep(1)
14 | 
15 | 
16 | @pytest.mark.unittest
17 | def test_shm():
18 | 
19 |     shm = ShmBuffer(dtype=np.float32, shape=(4, 84, 84), copy_on_get=False)
20 |     writer_process = Process(target=writer, args=(shm, ))
21 |     writer_process.start()
22 | 
23 |     time.sleep(0.1)
24 | 
25 |     data1 = shm.get()
26 |     time.sleep(1)
27 |     data2 = shm.get()
28 |     # same memory
29 |     assert (data1 == data2).all()
30 | 
31 |     time.sleep(1)
32 |     data3 = shm.get().copy()
33 |     time.sleep(1)
34 |     data4 = shm.get()
35 |     assert (data3 != data4).all()
36 | 
37 |     writer_process.terminate()
38 | 


--------------------------------------------------------------------------------
/ding/envs/env_wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | from .env_wrappers import *
2 | 


--------------------------------------------------------------------------------
/ding/envs/gym_env.py:
--------------------------------------------------------------------------------
1 | from ding.envs import BaseEnv, DingEnvWrapper
2 | 
3 | 
4 | def env(cfg, seed_api=True, caller='collector', **kwargs) -> BaseEnv:
5 |     import gym
6 |     return DingEnvWrapper(gym.make(cfg.env_id, **kwargs), cfg=cfg, seed_api=seed_api, caller=caller)
7 | 


--------------------------------------------------------------------------------
/ding/example/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/example/__init__.py


--------------------------------------------------------------------------------
/ding/framework/__init__.py:
--------------------------------------------------------------------------------
 1 | from .context import Context, OnlineRLContext, OfflineRLContext
 2 | from .task import Task, task, VoidMiddleware
 3 | from .parallel import Parallel
 4 | from .event_loop import EventLoop
 5 | from .supervisor import Supervisor
 6 | from easydict import EasyDict
 7 | from ding.utils import DistributedWriter
 8 | 
 9 | 
10 | def ding_init(cfg: EasyDict):
11 |     DistributedWriter.get_instance(cfg.exp_name)
12 | 


--------------------------------------------------------------------------------
/ding/framework/message_queue/__init__.py:
--------------------------------------------------------------------------------
1 | from .mq import MQ
2 | from .redis import RedisMQ
3 | from .nng import NNGMQ
4 | 


--------------------------------------------------------------------------------
/ding/framework/message_queue/tests/test_nng.py:
--------------------------------------------------------------------------------
 1 | from time import sleep
 2 | import pytest
 3 | 
 4 | import multiprocessing as mp
 5 | from ding.framework.message_queue.nng import NNGMQ
 6 | 
 7 | 
 8 | def nng_main(i):
 9 |     if i == 0:
10 |         listen_to = "tcp://127.0.0.1:50515"
11 |         attach_to = None
12 |         mq = NNGMQ(listen_to=listen_to, attach_to=attach_to)
13 |         mq.listen()
14 |         for _ in range(10):
15 |             mq.publish("t", b"data")
16 |             sleep(0.1)
17 |     else:
18 |         listen_to = "tcp://127.0.0.1:50516"
19 |         attach_to = ["tcp://127.0.0.1:50515"]
20 |         mq = NNGMQ(listen_to=listen_to, attach_to=attach_to)
21 |         mq.listen()
22 |         topic, msg = mq.recv()
23 |         assert topic == "t"
24 |         assert msg == b"data"
25 | 
26 | 
27 | @pytest.mark.unittest
28 | @pytest.mark.execution_timeout(10)
29 | def test_nng():
30 |     ctx = mp.get_context("spawn")
31 |     with ctx.Pool(processes=2) as pool:
32 |         pool.map(nng_main, range(2))
33 | 


--------------------------------------------------------------------------------
/ding/framework/middleware/__init__.py:
--------------------------------------------------------------------------------
1 | from .functional import *
2 | from .collector import StepCollector, EpisodeCollector, PPOFStepCollector
3 | from .learner import OffPolicyLearner, HERLearner
4 | from .ckpt_handler import CkptSaver
5 | from .distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger
6 | from .barrier import Barrier, BarrierRuntime
7 | from .data_fetcher import OfflineMemoryDataFetcher
8 | 


--------------------------------------------------------------------------------
/ding/framework/middleware/functional/__init__.py:
--------------------------------------------------------------------------------
 1 | from .trainer import trainer, multistep_trainer
 2 | from .data_processor import offpolicy_data_fetcher, data_pusher, offline_data_fetcher, offline_data_saver, \
 3 |     offline_data_fetcher_from_mem, sqil_data_pusher, buffer_saver
 4 | from .collector import inferencer, rolloutor, TransitionList
 5 | from .evaluator import interaction_evaluator, interaction_evaluator_ttorch
 6 | from .termination_checker import termination_checker, ddp_termination_checker
 7 | from .logger import online_logger, offline_logger, wandb_online_logger, wandb_offline_logger
 8 | from .ctx_helper import final_ctx_saver
 9 | 
10 | # algorithm
11 | from .explorer import eps_greedy_handler, eps_greedy_masker
12 | from .advantage_estimator import gae_estimator, ppof_adv_estimator, montecarlo_return_estimator
13 | from .enhancer import reward_estimator, her_data_enhancer, nstep_reward_enhancer
14 | from .priority import priority_calculator
15 | from .timer import epoch_timer
16 | 


--------------------------------------------------------------------------------
/ding/framework/middleware/functional/ctx_helper.py:
--------------------------------------------------------------------------------
 1 | from typing import TYPE_CHECKING, Callable
 2 | import os
 3 | import pickle
 4 | import dataclasses
 5 | from ding.framework import task
 6 | if TYPE_CHECKING:
 7 |     from ding.framework import Context
 8 | 
 9 | 
def final_ctx_saver(name: str) -> Callable:
    """
    Overview:
        Middleware that, once ``task.finish`` is set, pickles a summary of the final context into \
        ``<name>/result.pkl``. The summary always contains ``total_step``, ``train_iter``, ``last_eval_iter`` \
        and ``eval_value`` (taken from ``ctx.last_eval_value``), plus ``env_step``/``env_episode`` and \
        ``trained_env_step`` when the context has them.
    Arguments:
        - name (:obj:`str`): Directory in which ``result.pkl`` is written. It is created if it does not exist.
    Returns:
        - _save (:obj:`Callable`): The middleware function, called with the context.
    """

    def _save(ctx: "Context"):
        if not task.finish:
            return
        # Make sure the items to be recorded are all kept in the context. They are read *before* the
        # output file is opened, so a missing attribute raises without leaving an empty/truncated
        # ``result.pkl`` behind.
        final_data = {
            'total_step': ctx.total_step,
            'train_iter': ctx.train_iter,
            'last_eval_iter': ctx.last_eval_iter,
            'eval_value': ctx.last_eval_value,
        }
        # Optional fields, only recorded when the context defines them.
        if ctx.has_attr('env_step'):
            final_data['env_step'] = ctx.env_step
            final_data['env_episode'] = ctx.env_episode
        if ctx.has_attr('trained_env_step'):
            final_data['trained_env_step'] = ctx.trained_env_step
        # An empty name means the current directory, which always exists.
        if name:
            os.makedirs(name, exist_ok=True)
        with open(os.path.join(name, 'result.pkl'), 'wb') as f:
            pickle.dump(final_data, f)

    return _save
30 | 


--------------------------------------------------------------------------------
/ding/framework/middleware/functional/priority.py:
--------------------------------------------------------------------------------
 1 | from typing import TYPE_CHECKING, Callable
 2 | from ding.framework import task
 3 | if TYPE_CHECKING:
 4 |     from ding.framework import OnlineRLContext
 5 | 
 6 | 
 7 | def priority_calculator(priority_calculation_fn: Callable) -> Callable:
 8 |     """
 9 |     Overview:
10 |         The middleware that calculates the priority of the collected data.
11 |     Arguments:
12 |         - priority_calculation_fn (:obj:`Callable`): The function that calculates the priority of the collected data.
13 |     """
14 | 
15 |     if task.router.is_active and not task.has_role(task.role.COLLECTOR):
16 |         return task.void()
17 | 
18 |     def _priority_calculator(ctx: "OnlineRLContext") -> None:
19 | 
20 |         priority = priority_calculation_fn(ctx.trajectories)
21 |         for i in range(len(priority)):
22 |             ctx.trajectories[i]['priority'] = priority[i]
23 | 
24 |     return _priority_calculator
25 | 


--------------------------------------------------------------------------------
/ding/framework/middleware/functional/timer.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | from collections import deque
 3 | from ditk import logging
 4 | from time import time
 5 | 
 6 | from ding.framework import task
 7 | from typing import TYPE_CHECKING
 8 | if TYPE_CHECKING:
 9 |     from ding.framework.context import Context
10 | 
11 | 
def epoch_timer(print_per: int = 1, smooth_window: int = 10):
    """
    Overview:
        Middleware that measures the wall-clock time between entering it and resuming after its ``yield`` \
        in each epoch, and logs that cost together with a running mean.
    Arguments:
        - print_per (:obj:`int`): Log only when ``ctx.total_step`` is a multiple of this value.
        - smooth_window (:obj:`int`): The mean covers the last ``print_per * smooth_window`` epochs.
    Returns:
        - _epoch_timer (:obj:`Callable`): Generator middleware taking the context.
    """
    history = deque(maxlen=smooth_window * print_per)

    def _epoch_timer(ctx: "Context"):
        begin = time()
        yield
        elapsed = time() - begin
        history.append(elapsed)
        if ctx.total_step % print_per != 0:
            return
        node = task.router.node_id or 0
        logging.info(
            "[Epoch Timer][Node:{:>2}]: Cost: {:.2f}ms, Mean: {:.2f}ms".format(
                node, elapsed * 1000,
                np.mean(history) * 1000
            )
        )

    return _epoch_timer
36 | 


--------------------------------------------------------------------------------
/ding/framework/middleware/tests/__init__.py:
--------------------------------------------------------------------------------
1 | from .mock_for_test import MockEnv, MockPolicy, MockHerRewardModel, CONFIG
2 | 


--------------------------------------------------------------------------------
/ding/framework/middleware/tests/test_explorer.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import copy
 3 | from ding.framework import OnlineRLContext
 4 | from ding.framework.middleware import eps_greedy_handler, eps_greedy_masker
 5 | from ding.framework.middleware.tests import MockPolicy, MockEnv, CONFIG
 6 | 
 7 | 
 8 | @pytest.mark.unittest
 9 | def test_eps_greedy_handler():
10 |     cfg = copy.deepcopy(CONFIG)
11 |     ctx = OnlineRLContext()
12 | 
13 |     ctx.env_step = 0
14 |     next(eps_greedy_handler(cfg)(ctx))
15 |     assert ctx.collect_kwargs['eps'] == 0.95
16 | 
17 |     ctx.env_step = 1000000
18 |     next(eps_greedy_handler(cfg)(ctx))
19 |     assert ctx.collect_kwargs['eps'] == 0.1
20 | 
21 | 
@pytest.mark.unittest
def test_eps_greedy_masker():
    """Applying a freshly built masker ten times leaves ``collect_kwargs['eps']`` at -1."""
    ctx = OnlineRLContext()
    remaining = 10
    while remaining > 0:
        eps_greedy_masker()(ctx)
        remaining -= 1
    assert ctx.collect_kwargs['eps'] == -1
28 | 


--------------------------------------------------------------------------------
/ding/framework/middleware/tests/test_priority.py:
--------------------------------------------------------------------------------
 1 | #unittest for priority_calculator
 2 | 
 3 | import unittest
 4 | import pytest
 5 | import numpy as np
 6 | from unittest.mock import Mock, patch
 7 | from ding.framework import OnlineRLContext, OfflineRLContext
 8 | from ding.framework import task, Parallel
 9 | from ding.framework.middleware.functional import priority_calculator
10 | 
11 | 
class MockPolicy(Mock):
    """Mock policy whose priority function returns random priorities."""

    def priority_fun(self, data):
        """Return ``len(data)`` priorities drawn uniformly from [0, 1) as a numpy array."""
        n_items = len(data)
        return np.random.rand(n_items)
16 | 
17 | 
@pytest.mark.unittest
def test_priority_calculator():
    """Each of ten random transitions gets a float ``priority`` from the mock policy."""
    policy = MockPolicy()
    ctx = OnlineRLContext()
    trajectories = []
    for _ in range(10):
        trajectories.append(
            {
                'obs': np.random.rand(2, 2),
                'next_obs': np.random.rand(2, 2),
                'reward': np.random.rand(1),
                'info': {}
            }
        )
    ctx.trajectories = trajectories

    middleware = priority_calculator(priority_calculation_fn=policy.priority_fun)
    middleware(ctx)

    assert len(ctx.trajectories) == 10
    for traj in ctx.trajectories:
        assert isinstance(traj['priority'], float)
34 | 


--------------------------------------------------------------------------------
/ding/framework/wrapper/__init__.py:
--------------------------------------------------------------------------------
1 | from .step_timer import StepTimer
2 | 


--------------------------------------------------------------------------------
/ding/hpc_rl/README.md:
--------------------------------------------------------------------------------
Step 0. Remove any previously installed version:
```bash
rm ~/.local/lib/python3.6/site-packages/hpc_*.so
rm ~/.local/lib/python3.6/site-packages/hpc_rl* -rf
rm ~/.local/lib/python3.6/site-packages/di_hpc_rl* -rf
```
 5 | 
Step 1. Install the wheel for the current user and check that its files are in place:
```bash
pip install di_hpc_rll-0.0.1-cp36-cp36m-linux_x86_64.whl --user
ls ~/.local/lib/python3.6/site-packages/di_hpc_rl*
ls ~/.local/lib/python3.6/site-packages/hpc_rl*
```
10 | 
Step 2. Run the test script to verify the installation:
```bash
python3 tests/test_gae.py
```


--------------------------------------------------------------------------------
/ding/hpc_rl/__init__.py:
--------------------------------------------------------------------------------
1 | from .wrapper import hpc_wrapper
2 | 


--------------------------------------------------------------------------------
/ding/hpc_rl/tests/testbase.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import numpy as np
 3 | 
# Print tensors with 6 decimal places, so values shown in test output keep that precision.
torch.set_printoptions(precision=6)

# NOTE(review): presumably the number of repetitions used by the tests that import this module — confirm.
times = 6
 7 | 
 8 | 
 9 | def mean_relative_error(y_true, y_pred):
10 |     eps = 1e-5
11 |     relative_error = np.average(np.abs(y_true - y_pred) / (y_true + eps), axis=0)
12 |     return relative_error
13 | 


--------------------------------------------------------------------------------
/ding/interaction/__init__.py:
--------------------------------------------------------------------------------
1 | from .master import *
2 | from .slave import *
3 | 


--------------------------------------------------------------------------------
/ding/interaction/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .app import CommonErrorCode, success_response, failure_response, get_values_from_response, flask_response, \
2 |     ResponsibleException, responsible
3 | from .common import random_token, translate_dict_func, ControllableService, ControllableContext, default_func
4 | from .network import get_host_ip, get_http_engine_class, HttpEngine, split_http_address
5 | from .threading import DblEvent
6 | 


--------------------------------------------------------------------------------
/ding/interaction/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import MIN_HEARTBEAT_CHECK_SPAN, MIN_HEARTBEAT_SPAN, DEFAULT_MASTER_PORT, DEFAULT_SLAVE_PORT, \
2 |     DEFAULT_CHANNEL, DEFAULT_HEARTBEAT_CHECK_SPAN, DEFAULT_HEARTBEAT_TOLERANCE, DEFAULT_HEARTBEAT_SPAN, LOCAL_HOST, \
3 |     GLOBAL_HOST, DEFAULT_REQUEST_RETRIES, DEFAULT_REQUEST_RETRY_WAITING
4 | 


--------------------------------------------------------------------------------
/ding/interaction/config/base.py:
--------------------------------------------------------------------------------
# System configs
# Wildcard address: binding to it listens on all network interfaces.
GLOBAL_HOST = '0.0.0.0'
# IPv4 loopback address (reachable from this machine only).
LOCAL_HOST = '127.0.0.1'

# General request
# Retry count and wait between retries for requests; the wait is presumably in seconds — confirm with callers.
DEFAULT_REQUEST_RETRIES = 5
DEFAULT_REQUEST_RETRY_WAITING = 1.0

# Slave configs
# NOTE(review): presumably the slave's heartbeat interval (seconds) and its allowed minimum — confirm.
MIN_HEARTBEAT_SPAN = 0.2
DEFAULT_HEARTBEAT_SPAN = 3.0
DEFAULT_SLAVE_PORT = 7236

# Master configs
# NOTE(review): presumably how often (seconds) the master checks heartbeats, its allowed minimum, and how
# long a missing heartbeat is tolerated — confirm in the master implementation.
MIN_HEARTBEAT_CHECK_SPAN = 0.1
DEFAULT_HEARTBEAT_CHECK_SPAN = 1.0
DEFAULT_HEARTBEAT_TOLERANCE = 17.0
DEFAULT_MASTER_PORT = 7235

# Two-side configs
# Default channel, used by both master and slave.
DEFAULT_CHANNEL = 0
22 | 


--------------------------------------------------------------------------------
/ding/interaction/exception/__init__.py:
--------------------------------------------------------------------------------
 1 | from .base import ResponseException
 2 | from .master import MasterErrorCode, get_master_exception_by_error, MasterResponseException, MasterSuccess, \
 3 |     MasterChannelInvalid, MasterChannelNotGiven, MasterMasterTokenInvalid, MasterMasterTokenNotGiven, \
 4 |     MasterSelfTokenInvalid, MasterSelfTokenNotGiven, MasterSlaveTokenInvalid, MasterSlaveTokenNotGiven, \
 5 |     MasterSystemShuttingDown, MasterTaskDataInvalid
 6 | from .slave import SlaveErrorCode, get_slave_exception_by_error, SlaveResponseException, SlaveSuccess, \
 7 |     SlaveChannelInvalid, SlaveChannelNotFound, SlaveSelfTokenInvalid, SlaveTaskAlreadyExist, SlaveTaskRefused, \
 8 |     SlaveMasterTokenInvalid, SlaveMasterTokenNotFound, SlaveSelfTokenNotFound, SlaveSlaveAlreadyConnected, \
 9 |     SlaveSlaveConnectionRefused, SlaveSlaveDisconnectionRefused, SlaveSlaveNotConnected, SlaveSystemShuttingDown
10 | 


--------------------------------------------------------------------------------
/ding/interaction/master/__init__.py:
--------------------------------------------------------------------------------
1 | from .master import Master
2 | 


--------------------------------------------------------------------------------
/ding/interaction/master/base.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Mapping, Any, Optional
2 | 
3 | from requests import RequestException
4 | 
# Type aliases for the master's request hook callables.
# NOTE(review): the roles below are inferred from the alias names and types only — confirm against the master code.
# "Before" hook: takes arbitrary arguments and returns a str-keyed mapping (presumably the request payload).
_BEFORE_HOOK_TYPE = Callable[..., Mapping[str, Any]]
# "After" hook: takes (int, bool, int, optional str, optional mapping) — presumably
# (status code, success flag, response code, message, data) of a response.
_AFTER_HOOK_TYPE = Callable[[int, bool, int, Optional[str], Optional[Mapping[str, Any]]], Any]
# "Error" hook: takes the ``requests`` RequestException that was raised.
_ERROR_HOOK_TYPE = Callable[[RequestException], Any]
8 | 


--------------------------------------------------------------------------------
/ding/interaction/slave/__init__.py:
--------------------------------------------------------------------------------
1 | from .action import TaskRefuse, DisconnectionRefuse, ConnectionRefuse, TaskFail
2 | from .slave import Slave
3 | 


--------------------------------------------------------------------------------
/ding/interaction/tests/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import *
2 | from .config import *
3 | from .exception import *
4 | from .interaction import *
5 | 


--------------------------------------------------------------------------------
/ding/interaction/tests/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .test_app import TestInteractionBaseApp, TestInteractionBaseResponsibleException
2 | from .test_common import TestInteractionBaseCommon, TestInteractionBaseControllableService
3 | from .test_network import TestInteractionBaseHttpEngine, TestInteractionBaseNetwork
4 | from .test_threading import TestInteractionBaseThreading
5 | 


--------------------------------------------------------------------------------
/ding/interaction/tests/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .test_base import TestInteractionConfig
2 | 


--------------------------------------------------------------------------------
/ding/interaction/tests/config/test_base.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | 
 3 | from ...config import GLOBAL_HOST, LOCAL_HOST
 4 | 
 5 | 
 6 | @pytest.mark.unittest
 7 | class TestInteractionConfig:
 8 | 
 9 |     def test_base_host(self):
10 |         assert GLOBAL_HOST == '0.0.0.0'
11 |         assert LOCAL_HOST == '127.0.0.1'
12 | 


--------------------------------------------------------------------------------
/ding/interaction/tests/exception/__init__.py:
--------------------------------------------------------------------------------
1 | from .test_master import TestInteractionExceptionMaster
2 | from .test_slave import TestInteractionExceptionSlave
3 | 


--------------------------------------------------------------------------------
/ding/interaction/tests/interaction/__init__.py:
--------------------------------------------------------------------------------
1 | from .test_errors import TestInteractionErrors
2 | from .test_simple import TestInteractionSimple
3 | 


--------------------------------------------------------------------------------
/ding/interaction/tests/test_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .random import random_port, random_channel
2 | from .stream import silence, silence_function
3 | 


--------------------------------------------------------------------------------
/ding/interaction/tests/test_utils/random.py:
--------------------------------------------------------------------------------
 1 | import random
 2 | from typing import Iterable
 3 | 
 4 | 
 5 | def random_port(excludes: Iterable[int] = None) -> int:
 6 |     return random.choice(list(set(range(10000, 20000)) - set(excludes or [])))
 7 | 
 8 | 
 9 | def random_channel(excludes: Iterable[int] = None) -> int:
10 |     excludes = set(list(excludes or []))
11 |     while True:
12 |         _channel = random.randint(1000, (1 << 31) - 1)
13 |         if _channel not in excludes:
14 |             return _channel
15 | 


--------------------------------------------------------------------------------
/ding/league/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_league import BaseLeague, create_league
2 | from .one_vs_one_league import OneVsOneLeague
3 | from .player import Player, ActivePlayer, HistoricalPlayer, create_player
4 | from .starcraft_player import MainPlayer, MainExploiter, LeagueExploiter
5 | from .shared_payoff import create_payoff
6 | from .metric import get_elo, get_elo_array, LeagueMetricEnv
7 | 


--------------------------------------------------------------------------------
/ding/league/tests/conftest.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import pytest
 3 | 
 4 | 
 5 | @pytest.fixture(scope='session')
 6 | def random_job_result():
 7 | 
 8 |     def fn():
 9 |         p = np.random.uniform()
10 |         if p < 1. / 3:
11 |             return "wins"
12 |         elif p < 2. / 3:
13 |             return "draws"
14 |         else:
15 |             return "losses"
16 | 
17 |     return fn
18 | 
19 | 
@pytest.fixture(scope='session')
def get_job_result_categories():
    """Session fixture: all job result categories, in the order wins, draws, losses."""
    categories = ["wins", 'draws', 'losses']
    return categories
23 | 


--------------------------------------------------------------------------------
/ding/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .common import *
2 | from .template import *
3 | from .wrapper import *
4 | 


--------------------------------------------------------------------------------
/ding/model/common/__init__.py:
--------------------------------------------------------------------------------
1 | from .head import DiscreteHead, DuelingHead, DistributionHead, RainbowHead, QRDQNHead, StochasticDuelingHead, \
2 |     QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, BranchingHead, head_cls_map, \
3 |     independent_normal_dist, AttentionPolicyHead, PopArtVHead, EnsembleHead
4 | from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder, GaussianFourierProjectionTimeEncoder
5 | from .utils import create_model
6 | 


--------------------------------------------------------------------------------
/ding/model/template/__init__.py:
--------------------------------------------------------------------------------
 1 | # general
 2 | from .q_learning import DQN, RainbowDQN, QRDQN, IQN, FQF, DRQN, C51DQN, BDQ, GTrXLDQN
 3 | from .qac import DiscreteQAC, ContinuousQAC
 4 | from .pdqn import PDQN
 5 | from .vac import VAC, DREAMERVAC
 6 | from .bc import DiscreteBC, ContinuousBC
 7 | from .language_transformer import LanguageTransformer
 8 | # algorithm-specific
 9 | from .pg import PG
10 | from .ppg import PPG
11 | from .qmix import Mixer, QMix
12 | from .collaq import CollaQ
13 | from .wqmix import WQMix
14 | from .coma import COMA
15 | from .atoc import ATOC
16 | from .sqn import SQN
17 | from .acer import ACER
18 | from .qtran import QTran
19 | from .mavac import MAVAC
20 | from .ngu import NGU
21 | from .qac_dist import QACDIST
22 | from .maqac import DiscreteMAQAC, ContinuousMAQAC
23 | from .madqn import MADQN
24 | from .vae import VanillaVAE
25 | from .decision_transformer import DecisionTransformer
26 | from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS
27 | from .bcq import BCQ
28 | from .edac import EDAC
29 | from .hpt import HPT
30 | from .qgpo import QGPO
31 | from .ebm import EBM, AutoregressiveEBM
32 | from .havac import HAVAC
33 | 


--------------------------------------------------------------------------------
/ding/model/template/sqn.py:
--------------------------------------------------------------------------------
 1 | from typing import Dict
 2 | import torch
 3 | import torch.nn as nn
 4 | 
 5 | from ding.utils import MODEL_REGISTRY
 6 | from .q_learning import DQN
 7 | 
 8 | 
 9 | @MODEL_REGISTRY.register('sqn')
10 | class SQN(nn.Module):
11 | 
12 |     def __init__(self, *args, **kwargs) -> None:
13 |         super(SQN, self).__init__()
14 |         self.q0 = DQN(*args, **kwargs)
15 |         self.q1 = DQN(*args, **kwargs)
16 | 
17 |     def forward(self, data: torch.Tensor) -> Dict:
18 |         output0 = self.q0(data)
19 |         output1 = self.q1(data)
20 |         return {
21 |             'q_value': [output0['logit'], output1['logit']],
22 |             'logit': output0['logit'],
23 |         }
24 | 


--------------------------------------------------------------------------------
/ding/model/wrapper/__init__.py:
--------------------------------------------------------------------------------
1 | from .model_wrappers import model_wrap, register_wrapper, IModelWrapper
2 | 


--------------------------------------------------------------------------------
/ding/policy/mbpolicy/__init__.py:
--------------------------------------------------------------------------------
1 | from .mbsac import MBSACPolicy
2 | from .dreamer import DREAMERPolicy
3 | 


--------------------------------------------------------------------------------
/ding/policy/mbpolicy/tests/test_mbpolicy_utils.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import torch
 3 | from ding.policy.mbpolicy.utils import q_evaluation
 4 | 
 5 | 
 6 | @pytest.mark.unittest
 7 | def test_q_evaluation():
 8 |     T, B, O, A = 10, 20, 100, 30
 9 |     obss = torch.randn(T, B, O)
10 |     actions = torch.randn(T, B, A)
11 | 
12 |     def fake_q_fn(obss, actions):
13 |         # obss:    flatten_B * O
14 |         # actions: flatten_B * A
15 |         # return:  flatten_B
16 |         return obss.sum(-1)
17 | 
18 |     q_value = q_evaluation(obss, actions, fake_q_fn)
19 |     assert q_value.shape == (T, B)
20 | 


--------------------------------------------------------------------------------
/ding/reward_model/__init__.py:
--------------------------------------------------------------------------------
 1 | from .base_reward_model import BaseRewardModel, create_reward_model, get_reward_model_cls
 2 | # inverse RL
 3 | from .pdeil_irl_model import PdeilRewardModel
 4 | from .gail_irl_model import GailRewardModel
 5 | from .pwil_irl_model import PwilRewardModel
 6 | from .red_irl_model import RedRewardModel
 7 | from .trex_reward_model import TrexRewardModel
 8 | from .drex_reward_model import DrexRewardModel
 9 | # sparse reward
10 | from .her_reward_model import HerRewardModel
11 | # exploration
12 | from .rnd_reward_model import RndRewardModel
13 | from .guided_cost_reward_model import GuidedCostRewardModel
14 | from .ngu_reward_model import RndNGURewardModel, EpisodicNGURewardModel
15 | from .icm_reward_model import ICMRewardModel
16 | 


--------------------------------------------------------------------------------
/ding/rl_utils/tests/test_gae.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import torch
 3 | from ding.rl_utils import gae_data, gae
 4 | 
 5 | 
 6 | @pytest.mark.unittest
 7 | def test_gae():
 8 |     # batch trajectory case
 9 |     T, B = 32, 4
10 |     value = torch.randn(T, B)
11 |     next_value = torch.randn(T, B)
12 |     reward = torch.randn(T, B)
13 |     done = torch.zeros((T, B))
14 |     data = gae_data(value, next_value, reward, done, None)
15 |     adv = gae(data)
16 |     assert adv.shape == (T, B)
17 |     # single trajectory case/concat trajectory case
18 |     T = 24
19 |     value = torch.randn(T)
20 |     next_value = torch.randn(T)
21 |     reward = torch.randn(T)
22 |     done = torch.zeros((T))
23 |     data = gae_data(value, next_value, reward, done, None)
24 |     adv = gae(data)
25 |     assert adv.shape == (T, )
26 | 
27 | 
@pytest.mark.unittest
def test_gae_multi_agent():
    """
    Multi-agent case: ``value``/``next_value`` have an extra agent dim ``A`` while ``reward`` and ``done`` \
    do not; the advantage keeps the agent dim, i.e. has shape (T, B, A).
    """
    # The unittest marker was missing, so runs filtered with `-m unittest` silently skipped this test.
    T, B, A = 32, 4, 8
    value = torch.randn(T, B, A)
    next_value = torch.randn(T, B, A)
    reward = torch.randn(T, B)
    done = torch.zeros(T, B)
    data = gae_data(value, next_value, reward, done, None)
    adv = gae(data)
    assert adv.shape == (T, B, A)
37 | 


--------------------------------------------------------------------------------
/ding/rl_utils/tests/test_retrace.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import torch
 3 | from ding.rl_utils import compute_q_retraces
 4 | 
 5 | 
 6 | @pytest.mark.unittest
 7 | def test_compute_q_retraces():
 8 |     T, B, N = 64, 32, 6
 9 |     q_values = torch.randn(T + 1, B, N)
10 |     v_pred = torch.randn(T + 1, B, 1)
11 |     rewards = torch.randn(T, B)
12 |     ratio = torch.rand(T, B, N) * 0.4 + 0.8
13 |     assert ratio.max() <= 1.2 and ratio.min() >= 0.8
14 |     weights = torch.rand(T, B)
15 |     actions = torch.randint(0, N, size=(T, B))
16 |     with torch.no_grad():
17 |         q_retraces = compute_q_retraces(q_values, v_pred, rewards, actions, weights, ratio, gamma=0.99)
18 |     assert q_retraces.shape == (T + 1, B, 1)
19 | 


--------------------------------------------------------------------------------
/ding/scripts/docker-test-entry.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | 
 3 | CONTAINER_ID=$(docker run --rm -d opendilab/ding:nightly tail -f /dev/null)
 4 | 
 5 | trap "docker rm -f $CONTAINER_ID" EXIT
 6 | 
 7 | docker exec $CONTAINER_ID rm -rf /ding &&
 8 |   docker cp $(pwd) ${CONTAINER_ID}:/ding &&
 9 |   docker exec -it $CONTAINER_ID /ding/ding/scripts/docker-test.sh
10 | 


--------------------------------------------------------------------------------
/ding/scripts/docker-test.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | 
 3 | if [ ! -f /.dockerenv ]; then
 4 |   echo "This script should be executed in docker container"
 5 |   exit 1
 6 | fi
 7 | 
 8 | pip install --ignore-installed 'PyYAML<6.0'
 9 | pip install -e .[test,k8s] &&
10 |   ./ding/scripts/install-k8s-tools.sh &&
11 |   make test
12 | 


--------------------------------------------------------------------------------
/ding/scripts/install-k8s-tools.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | 
 3 | set -e
 4 | 
 5 | ROOT_DIR="$(dirname "$0")"
 6 | : ${USE_SUDO:="true"}
 7 | 
 8 | # runs the given command as root (detects if we are root already)
 9 | runAsRoot() {
10 |     local CMD="$*"
11 | 
12 |     if [ $EUID -ne 0 -a $USE_SUDO = "true" ]; then
13 |         CMD="sudo $CMD"
14 |     fi
15 | 
16 |     $CMD
17 | }
18 | 
19 | # install k3d
20 | curl -s https://raw.githubusercontent.com/rancher/k3d/main/install.sh | TAG=v4.4.8 bash
21 | 
22 | # install kubectl
23 | if [[ $(which kubectl) == "" ]]; then
24 |     echo "Installing kubectl..."
25 |     curl -LO https://dl.k8s.io/release/v1.21.3/bin/linux/amd64/kubectl
26 |     chmod +x kubectl
27 |     runAsRoot mv kubectl /usr/local/bin/kubectl
28 | fi
29 | 


--------------------------------------------------------------------------------
/ding/scripts/kill.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# SIGKILL every process whose `ps -ef` line contains "ding", except this script itself.
# WARNING: this is a plain substring match over the whole ps line, so unrelated processes
# (e.g. a command line containing "embedding") are killed as well.
# `xargs -r` avoids running a bare `kill -9` (a usage error) when nothing matches.
ps -ef | grep 'ding' | grep -v grep | awk -v self="$$" '$2 != self {print $2}' | xargs -r kill -9
2 | 


--------------------------------------------------------------------------------
/ding/scripts/local_parallel.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Launch DI-engine in parallel mode.
# Usage: local_parallel.sh <config> <seed>
#   $1 is forwarded to `ding -c` (presumably the config file), $2 to `ding -s` (presumably the seed).
if [ "$#" -lt 2 ]; then
  echo "Usage: $0 <config> <seed>" >&2
  exit 1
fi
# Quoted so paths with spaces or glob characters are passed through unchanged.
ding -m parallel -c "$1" -s "$2"
2 | 


--------------------------------------------------------------------------------
/ding/scripts/local_serial.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Launch DI-engine in serial mode.
# Usage: local_serial.sh <config> <seed>
#   $1 is forwarded to `ding -c` (presumably the config file), $2 to `ding -s` (presumably the seed).
if [ "$#" -lt 2 ]; then
  echo "Usage: $0 <config> <seed>" >&2
  exit 1
fi
# Quoted so paths with spaces or glob characters are passed through unchanged.
ding -m serial -c "$1" -s "$2"
2 | 


--------------------------------------------------------------------------------
/ding/scripts/main_league_slurm.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Launch main_league.main on slurm through ditask: two "league,collect" nodes (10, 11), one "evaluate"
# node (20) attached to them, and three "learn" nodes (31-33) attached to nodes 10, 11 and 20.

export LC_ALL=en_US.utf-8
export LANG=en_US.utf-8
# Directory of this script; quoted below so the script also works from paths containing spaces.
BASEDIR=$(dirname "$0")
# srun -p partition_name --quotatype=reserved --mpi=pmi2 -n6 --ntasks-per-node=3 bash ding/scripts/main_league_slurm.sh
ditask --package "$BASEDIR/../entry" --main main_league.main --platform slurm --platform-spec '{"tasks":[{"labels":"league,collect","node_ids":10},{"labels":"league,collect","node_ids":11},{"labels":"evaluate","node_ids":20,"attach_to":"$node.10,$node.11"},{"labels":"learn","node_ids":31,"attach_to":"$node.10,$node.11,$node.20"},{"labels":"learn","node_ids":32,"attach_to":"$node.10,$node.11,$node.20"},{"labels":"learn","node_ids":33,"attach_to":"$node.10,$node.11,$node.20"}]}'
8 | 


--------------------------------------------------------------------------------
/ding/scripts/tests/test_parallel_socket.sh:
--------------------------------------------------------------------------------
 1 | total_epoch=1200            # the total num of msg
 2 | interval=0.1                # msg send interval
 3 | size=16                     # data size (MB)
 4 | test_start_time=20          # network fail time (s)
 5 | test_duration=40            # network fail duration (s)
 6 | output_file="my_test.log"   # the python script will write its output into this file
 7 | ip="0.0.0.0"
 8 | 
 9 | rm -f pytmp_*
10 | 
11 | nohup python test_parallel_socket.py -t $total_epoch -i $interval -s $size 1>$output_file 2>&1 &
12 | 
13 | flag=true
14 | while $flag
15 | do
16 |     for file in `ls`
17 |     do
18 |         if [[ $file =~ "pytmp" ]]; then
19 |             ip=`cat $file`
20 |             flag=false
21 |             break
22 |         fi
23 |     done
24 |     sleep 0.1
25 | done
26 | echo "get ip: $ip"
27 | 
28 | sleep $test_start_time
29 | echo "Network shutsown . . ."
30 | sudo iptables -A INPUT -p tcp -s $ip --dport 50516 -j DROP
31 | 
32 | sleep $test_duration
33 | sudo iptables -D INPUT -p tcp -s $ip --dport 50516 -j DROP
34 | echo "Network recovered . . ."
35 | 
36 | 
37 | 


--------------------------------------------------------------------------------
/ding/torch_utils/__init__.py:
--------------------------------------------------------------------------------
 1 | from .checkpoint_helper import build_checkpoint_helper, CountVar, auto_checkpoint
 2 | from .data_helper import to_device, to_tensor, to_ndarray, to_list, to_dtype, same_shape, tensor_to_list, \
 3 |     build_log_buffer, CudaFetcher, get_tensor_data, unsqueeze, squeeze, get_null_data, get_shape0, to_item, \
 4 |     zeros_like
 5 | from .distribution import CategoricalPd, CategoricalPdPytorch
 6 | from .metric import levenshtein_distance, hamming_distance
 7 | from .network import *
 8 | from .loss import *
 9 | from .optimizer_helper import Adam, RMSprop, calculate_grad_norm, calculate_grad_norm_without_bias_two_norm
10 | from .nn_test_helper import is_differentiable
11 | from .math_helper import cov
12 | from .dataparallel import DataParallel
13 | from .reshape_helper import fold_batch, unfold_batch, unsqueeze_repeat
14 | from .parameter import NonegativeParameter, TanhParameter
15 | 


--------------------------------------------------------------------------------
/ding/torch_utils/backend_helper.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | 
 4 | def enable_tf32() -> None:
 5 |     """
 6 |     Overview:
 7 |         Enable tf32 on matmul and cudnn for faster computation. This only works on Ampere GPU devices. \
 8 |         For detailed information, please refer to: \
 9 |         https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices.
10 |     """
11 |     torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
12 |     torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
13 | 


--------------------------------------------------------------------------------
/ding/torch_utils/diffusion_SDE/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/torch_utils/diffusion_SDE/__init__.py


--------------------------------------------------------------------------------
/ding/torch_utils/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .cross_entropy_loss import LabelSmoothCELoss, SoftFocalLoss, build_ce_criterion
2 | from .multi_logits_loss import MultiLogitsLoss
3 | from .contrastive_loss import ContrastiveLoss
4 | 


--------------------------------------------------------------------------------
/ding/torch_utils/loss/tests/test_cross_entropy_loss.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import torch
 3 | import torch.nn as nn
 4 | 
 5 | from ding.torch_utils import LabelSmoothCELoss, SoftFocalLoss
 6 | 
 7 | 
 8 | @pytest.mark.unittest
 9 | class TestLabelSmoothCE:
10 | 
11 |     def test_label_smooth_ce_loss(self):
12 |         logits = torch.randn(4, 6)
13 |         labels = torch.LongTensor([i for i in range(4)])
14 |         criterion1 = LabelSmoothCELoss(0)
15 |         criterion2 = nn.CrossEntropyLoss()
16 |         assert (torch.abs(criterion1(logits, labels) - criterion2(logits, labels)) < 1e-6)
17 | 
18 | 
@pytest.mark.unittest
class TestSoftFocalLoss:

    def test_soft_focal_loss(self):
        logits = torch.randn(4, 6).requires_grad_(True)
        labels = torch.LongTensor([i for i in range(4)])
        criterion = SoftFocalLoss()
        loss = criterion(logits, labels)
        assert loss.shape == ()
        # The reduced loss must be a finite scalar that back-propagates into the logits.
        assert torch.isfinite(loss)
        loss.backward()
        assert isinstance(logits.grad, torch.Tensor)
        assert logits.grad.shape == logits.shape
29 | 


--------------------------------------------------------------------------------
/ding/torch_utils/loss/tests/test_multi_logits_loss.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import torch
 3 | from ding.torch_utils import MultiLogitsLoss
 4 | 
 5 | 
 6 | @pytest.mark.unittest
 7 | @pytest.mark.parametrize('criterion_type', ['cross_entropy', 'label_smooth_ce'])
 8 | def test_multi_logits_loss(criterion_type):
 9 |     logits = torch.randn(4, 8).requires_grad_(True)
10 |     label = torch.LongTensor([0, 1, 3, 2])
11 |     criterion = MultiLogitsLoss(criterion=criterion_type)
12 |     loss = criterion(logits, label)
13 |     assert loss.shape == ()
14 |     assert logits.grad is None
15 |     loss.backward()
16 |     assert isinstance(logits, torch.Tensor)
17 | 


--------------------------------------------------------------------------------
/ding/torch_utils/model_helper.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | 
 4 | def get_num_params(model: torch.nn.Module) -> int:
 5 |     """
 6 |     Overview:
 7 |         Return the number of parameters in the model.
 8 |     Arguments:
 9 |         - model (:obj:`torch.nn.Module`): The model object to calculate the parameter number.
10 |     Returns:
11 |         - n_params (:obj:`int`): The calculated number of parameters.
12 |     Examples:
13 |         >>> model = torch.nn.Linear(3, 5)
14 |         >>> num = get_num_params(model)
15 |         >>> assert num == 15
16 |     """
17 |     return sum(p.numel() for p in model.parameters())
18 | 


--------------------------------------------------------------------------------
/ding/torch_utils/network/__init__.py:
--------------------------------------------------------------------------------
 1 | from .activation import build_activation, Swish
 2 | from .res_block import ResBlock, ResFCBlock
 3 | from .nn_module import fc_block, conv2d_block, one_hot, deconv2d_block, BilinearUpsample, NearestUpsample, \
 4 |     binary_encode, NoiseLinearLayer, noise_block, MLP, Flatten, normed_linear, normed_conv2d, conv1d_block
 5 | from .normalization import build_normalization
 6 | from .rnn import get_lstm, sequence_mask
 7 | from .soft_argmax import SoftArgmax
 8 | from .transformer import Transformer, ScaledDotProductAttention
 9 | from .scatter_connection import ScatterConnection
10 | from .resnet import resnet18, ResNet
11 | from .gumbel_softmax import GumbelSoftmax
12 | from .gtrxl import GTrXL, GRUGatingUnit
13 | from .popart import PopArt
14 | #from .dreamer import Conv2dSame, DreamerLayerNorm, ActionHead, DenseHead
15 | from .merge import GatingType, SumMerge, VectorMerge
16 | 


--------------------------------------------------------------------------------
/ding/torch_utils/network/tests/test_gumbel_softmax.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import pytest
 3 | import torch
 4 | 
 5 | from ding.torch_utils.network import GumbelSoftmax, gumbel_softmax
 6 | 
 7 | 
 8 | @pytest.mark.unittest
 9 | class TestGumbelSoftmax:
10 | 
11 |     def test(self):
12 |         B = 4
13 |         N = 10
14 |         model = GumbelSoftmax()
15 |         # data case 1
16 |         for _ in range(N):
17 |             data = torch.rand((4, 10))
18 |             data = torch.log(data)
19 |             gumbelsoftmax = model(data, hard=False)
20 |             assert gumbelsoftmax.shape == (B, N)
21 |         # data case 2
22 |         for _ in range(N):
23 |             data = torch.rand((4, 10))
24 |             data = torch.log(data)
25 |             gumbelsoftmax = model(data, hard=True)
26 |             assert gumbelsoftmax.shape == (B, N)
27 | 


--------------------------------------------------------------------------------
/ding/torch_utils/network/tests/test_transformer.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import torch
 3 | 
 4 | from ding.torch_utils import Transformer
 5 | 
 6 | 
 7 | @pytest.mark.unittest
 8 | class TestTransformer:
 9 | 
10 |     def test(self):
11 |         batch_size = 2
12 |         num_entries = 2
13 |         C = 2
14 |         masks = [None, torch.rand(batch_size, num_entries).round().bool()]
15 |         for mask in masks:
16 |             output_dim = 4
17 |             model = Transformer(
18 |                 input_dim=C,
19 |                 head_dim=2,
20 |                 hidden_dim=3,
21 |                 output_dim=output_dim,
22 |                 head_num=2,
23 |                 mlp_num=2,
24 |                 layer_num=2,
25 |             )
26 |             input = torch.rand(batch_size, num_entries, C).requires_grad_(True)
27 |             output = model(input, mask)
28 |             loss = output.mean()
29 |             loss.backward()
30 |             assert isinstance(input.grad, torch.Tensor)
31 |             assert output.shape == (batch_size, num_entries, output_dim)
32 | 


--------------------------------------------------------------------------------
/ding/torch_utils/tests/test_backend_helper.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import torch
 3 | 
 4 | from ding.torch_utils.backend_helper import enable_tf32
 5 | 
 6 | 
 7 | @pytest.mark.cudatest
 8 | class TestBackendHelper:
 9 | 
10 |     def test_tf32(self):
11 |         r"""
12 |         Overview:
13 |             Test the tf32.
14 |         """
15 |         enable_tf32()
16 |         net = torch.nn.Linear(3, 4)
17 |         x = torch.randn(1, 3)
18 |         y = torch.sum(net(x))
19 |         net.zero_grad()
20 |         y.backward()
21 |         assert net.weight.grad is not None
22 | 


--------------------------------------------------------------------------------
/ding/torch_utils/tests/test_lr_scheduler.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import torch
 3 | from torch.optim import Adam
 4 | 
 5 | from ding.torch_utils.lr_scheduler import cos_lr_scheduler
 6 | 
 7 | 
 8 | @pytest.mark.unittest
 9 | class TestLRSchedulerHelper:
10 | 
11 |     def test_cos_lr_scheduler(self):
12 |         r"""
13 |         Overview:
14 |             Test the cos lr scheduler.
15 |         """
16 |         net = torch.nn.Linear(3, 4)
17 |         opt = Adam(net.parameters(), lr=1e-2)
18 |         scheduler = cos_lr_scheduler(opt, learning_rate=1e-2, min_lr=6e-5)
19 |         scheduler.step(101)
20 |         assert opt.param_groups[0]['lr'] == 6e-5
21 | 


--------------------------------------------------------------------------------
/ding/torch_utils/tests/test_model_helper.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import torch
 3 | 
 4 | from ding.torch_utils.model_helper import get_num_params
 5 | 
 6 | 
 7 | @pytest.mark.unittest
 8 | class TestModelHelper:
 9 | 
10 |     def test_model_helper(self):
11 |         r"""
12 |         Overview:
13 |             Test the model helper.
14 |         """
15 |         net = torch.nn.Linear(3, 4, bias=False)
16 |         assert get_num_params(net) == 12
17 | 
18 |         net = torch.nn.Conv2d(3, 3, kernel_size=3, bias=False)
19 |         assert get_num_params(net) == 81
20 | 


--------------------------------------------------------------------------------
/ding/torch_utils/tests/test_parameter.py:
--------------------------------------------------------------------------------
 1 | import unittest
 2 | import pytest
 3 | import torch
 4 | from ding.torch_utils.parameter import NonegativeParameter, TanhParameter
 5 | 
 6 | 
 7 | @pytest.mark.unittest
 8 | def test_nonegative_parameter():
 9 |     nonegative_parameter = NonegativeParameter(torch.tensor([2.0, 3.0]))
10 |     assert torch.sum(torch.abs(nonegative_parameter() - torch.tensor([2.0, 3.0]))) == 0
11 |     nonegative_parameter.set_data(torch.tensor(1))
12 |     assert nonegative_parameter() == 1
13 | 
14 | 
@pytest.mark.unittest
def test_tanh_parameter():
    tanh_parameter = TanhParameter(torch.tensor([0.5, -0.2]))
    assert torch.isclose(tanh_parameter() - torch.tensor([0.5, -0.2]), torch.zeros(2), atol=1e-6).all()
    tanh_parameter.set_data(torch.tensor(0.3))
    # The read-back value may not reproduce 0.3 bit-exactly in float32, so use the same tolerance as above.
    assert torch.isclose(tanh_parameter(), torch.tensor(0.3), atol=1e-6).all()
21 | 
22 | 
if __name__ == "__main__":
    # Run the tests in this module directly, in definition order, without pytest.
    for _test_fn in (test_nonegative_parameter, test_tanh_parameter):
        _test_fn()
26 | 


--------------------------------------------------------------------------------
/ding/utils/autolog/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import TimeMode
2 | from .data import RangedData, TimeRangedData
3 | from .model import LoggedModel
4 | from .time_ctl import BaseTime, NaturalTime, TickTime, TimeProxy
5 | from .value import LoggedValue
6 | 
# This package only re-exports the autolog API; it has no script entry point.
9 | 


--------------------------------------------------------------------------------
/ding/utils/autolog/base.py:
--------------------------------------------------------------------------------
 1 | from enum import unique, IntEnum
 2 | from typing import TypeVar, Union
 3 | 
 4 | _LOGGED_VALUE__PROPERTY_NAME = '__property_name__'
 5 | _LOGGED_MODEL__PROPERTIES = '__properties__'
 6 | _LOGGED_MODEL__PROPERTY_ATTR_PREFIX = '_property_'
 7 | 
 8 | _TimeType = TypeVar('_TimeType', bound=Union[float, int])
 9 | _ValueType = TypeVar('_ValueType')
10 | 
11 | 
12 | @unique
13 | class TimeMode(IntEnum):
14 |     """
15 |     Overview:
16 |         Mode that used to decide the format of range_values function
17 | 
18 |         ABSOLUTE: use absolute time
19 |         RELATIVE_LIFECYCLE: use relative time based on property's lifecycle
20 |         RELATIVE_CURRENT_TIME: use relative time based on current time
21 |     """
22 |     ABSOLUTE = 0
23 |     RELATIVE_LIFECYCLE = 1
24 |     RELATIVE_CURRENT_TIME = 2
25 | 


--------------------------------------------------------------------------------
/ding/utils/autolog/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/utils/autolog/tests/__init__.py


--------------------------------------------------------------------------------
/ding/utils/collection_helper.py:
--------------------------------------------------------------------------------
 1 | from typing import Iterable, TypeVar, Callable
 2 | 
 3 | _IterType = TypeVar('_IterType')
 4 | _IterTargetType = TypeVar('_IterTargetType')
 5 | 
 6 | 
 7 | def iter_mapping(iter_: Iterable[_IterType], mapping: Callable[[_IterType], _IterTargetType]):
 8 |     """
 9 |     Overview:
10 |         Map a list of iterable elements to input iteration callable
11 |     Arguments:
12 |         - iter_(:obj:`_IterType list`): The list for iteration
13 |         - mapping (:obj:`Callable [[_IterType], _IterTargetType]`): A callable that maps iterable elements function.
14 |     Return:
15 |         - (:obj:`iter_mapping object`): Iteration results
16 |     Example:
17 |         >>> iterable_list = [1, 2, 3, 4, 5]
18 |         >>> _iter = iter_mapping(iterable_list, lambda x: x ** 2)
19 |         >>> print(list(_iter))
20 |         [1, 4, 9, 16, 25]
21 |     """
22 |     for item in iter_:
23 |         yield mapping(item)
24 | 


--------------------------------------------------------------------------------
/ding/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .collate_fn import diff_shape_collate, default_collate, default_decollate, timestep_collate, ttorch_collate
2 | from .dataloader import AsyncDataLoader
3 | from .dataset import NaiveRLDataset, D4RLDataset, HDF5Dataset, BCODataset, \
4 |     create_dataset, hdf5_save, offline_data_save_type
5 | from .rlhf_online_dataset import OnlineRLDataset
6 | from .rlhf_offline_dataset import OfflineRLDataset
7 | 


--------------------------------------------------------------------------------
/ding/utils/data/structure/__init__.py:
--------------------------------------------------------------------------------
1 | from .cache import Cache
2 | from .lifo_deque import LifoDeque
3 | 


--------------------------------------------------------------------------------
/ding/utils/data/structure/lifo_deque.py:
--------------------------------------------------------------------------------
 1 | from queue import LifoQueue
 2 | from collections import deque
 3 | 
 4 | 
 5 | class LifoDeque(LifoQueue):
 6 |     """
 7 |     Overview:
 8 |         Like LifoQueue, but automatically replaces the oldest data when the queue is full.
 9 |     Interfaces:
10 |         ``_init``, ``_put``, ``_get``
11 |     """
12 | 
13 |     def _init(self, maxsize):
14 |         self.maxsize = maxsize + 1
15 |         self.queue = deque(maxlen=maxsize)
16 | 


--------------------------------------------------------------------------------
/ding/utils/design_helper.py:
--------------------------------------------------------------------------------
 1 | from abc import ABCMeta
 2 | 
 3 | 
 4 | # ABCMeta is a subclass of type, extending ABCMeta makes this metaclass is compatible with some classes
 5 | # which extends ABC
 6 | class SingletonMetaclass(ABCMeta):
 7 |     """
 8 |     Overview:
 9 |         Returns the given type instance in input class
10 |     Interfaces:
11 |         ``__call__``
12 |     """
13 |     instances = {}
14 | 
15 |     def __call__(cls: type, *args, **kwargs) -> object:
16 |         """
17 |         Overview:
18 |             Returns the given type instance in input class
19 |         """
20 | 
21 |         if cls not in SingletonMetaclass.instances:
22 |             SingletonMetaclass.instances[cls] = super(SingletonMetaclass, cls).__call__(*args, **kwargs)
23 |             cls.instance = SingletonMetaclass.instances[cls]
24 |         return SingletonMetaclass.instances[cls]
25 | 


--------------------------------------------------------------------------------
/ding/utils/dict_helper.py:
--------------------------------------------------------------------------------
 1 | from easydict import EasyDict
 2 | 
 3 | 
 4 | def convert_easy_dict_to_dict(easy_dict: EasyDict) -> dict:
 5 |     """
 6 |     Overview:
 7 |         Convert an EasyDict object to a dict object recursively.
 8 |     Arguments:
 9 |         - easy_dict (:obj:`EasyDict`): The EasyDict object to be converted.
10 |     Returns:
11 |         - dict: The converted dict object.
12 |     """
13 |     return {k: convert_easy_dict_to_dict(v) if isinstance(v, EasyDict) else v for k, v in easy_dict.items()}
14 | 


--------------------------------------------------------------------------------
/ding/utils/fake_linklink.py:
--------------------------------------------------------------------------------
 1 | from collections import namedtuple
 2 | 
 3 | 
 4 | class FakeClass:
 5 |     """
 6 |     Overview:
 7 |         Fake class.
 8 |     """
 9 | 
10 |     def __init__(self, *args, **kwargs):
11 |         pass
12 | 
13 | 
14 | class FakeNN:
15 |     """
16 |     Overview:
17 |         Fake nn class.
18 |     """
19 | 
20 |     SyncBatchNorm2d = FakeClass
21 | 
22 | 
23 | class FakeLink:
24 |     """
25 |     Overview:
26 |         Fake link class.
27 |     """
28 | 
29 |     nn = FakeNN()
30 |     syncbnVarMode_t = namedtuple("syncbnVarMode_t", "L2")(L2=None)
31 |     allreduceOp_t = namedtuple("allreduceOp_t", ['Sum', 'Max'])
32 | 
33 | 
34 | link = FakeLink()
35 | 


--------------------------------------------------------------------------------
/ding/utils/loader/__init__.py:
--------------------------------------------------------------------------------
 1 | from .base import Loader
 2 | from .collection import collection, CollectionError, length, length_is, contains, tuple_, cofilter, tpselector
 3 | from .dict import DictError, dict_
 4 | from .exception import CompositeStructureError
 5 | from .mapping import mapping, MappingError, mpfilter, mpkeys, mpvalues, mpitems, item, item_or
 6 | from .norm import norm, normfunc, lnot, land, lor, lin, lis, lisnot, lsum, lcmp
 7 | from .number import interval, numeric, negative, positive, plus, minus, minus_with, multi, divide, divide_with, power, \
 8 |     power_with, msum, mmulti, mcmp, is_negative, is_positive, non_negative, non_positive
 9 | from .string import enum, rematch, regrep
10 | from .types import is_type, to_type, is_callable, prop, method, fcall, fpartial
11 | from .utils import keep, optional, check_only, raw, check
12 | 


--------------------------------------------------------------------------------
/ding/utils/loader/exception.py:
--------------------------------------------------------------------------------
 1 | from abc import ABCMeta, abstractmethod
 2 | from typing import List, Union, Tuple
 3 | 
 4 | INDEX_TYPING = Union[int, str]
 5 | ERROR_ITEM_TYPING = Tuple[INDEX_TYPING, Exception]
 6 | ERROR_ITEMS = List[ERROR_ITEM_TYPING]
 7 | 
 8 | 
 9 | class CompositeStructureError(ValueError, metaclass=ABCMeta):
10 |     """
11 |     Overview:
12 |         Composite structure error.
13 |     Interfaces:
14 |         ``__init__``, ``errors``
15 |     Properties:
16 |         ``errors``
17 |     """
18 | 
19 |     @property
20 |     @abstractmethod
21 |     def errors(self) -> ERROR_ITEMS:
22 |         """
23 |         Overview:
24 |             Get the errors.
25 |         """
26 | 
27 |         raise NotImplementedError
28 | 


--------------------------------------------------------------------------------
/ding/utils/loader/tests/__init__.py:
--------------------------------------------------------------------------------
1 | from .loader import *
2 | from .test_cartpole_dqn_serial_config_loader import test_main_config, test_create_config
3 | 


--------------------------------------------------------------------------------
/ding/utils/loader/tests/loader/__init__.py:
--------------------------------------------------------------------------------
 1 | from .test_base import TestConfigLoaderBase
 2 | from .test_collection import TestConfigLoaderCollection
 3 | from .test_dict import TestConfigLoaderDict
 4 | from .test_mapping import TestConfigLoaderMapping
 5 | from .test_norm import TestConfigLoaderNorm
 6 | from .test_number import TestConfigLoaderNumber
 7 | from .test_string import TestConfigLoaderString
 8 | from .test_types import TestConfigLoaderTypes
 9 | from .test_utils import TestConfigLoaderUtils
10 | 


--------------------------------------------------------------------------------
/ding/utils/loader/utils.py:
--------------------------------------------------------------------------------
 1 | from .base import Loader, ILoaderClass
 2 | 
 3 | 
 4 | def keep() -> ILoaderClass:
 5 |     """
 6 |     Overview:
 7 |         Create a keep loader.
 8 |     """
 9 | 
10 |     return Loader(lambda v: v)
11 | 
12 | 
def raw(value) -> ILoaderClass:
    """
    Overview:
        Create a raw loader, which ignores its input and always produces ``value``.
    Arguments:
        - value (:obj:`Any`): The constant to produce.
    Returns:
        - loader (:obj:`ILoaderClass`): A ``Loader`` wrapping a constant function.
    """

    def _constant(_ignored):
        return value

    return Loader(_constant)
20 | 
21 | 
def optional(loader) -> ILoaderClass:
    """
    Overview:
        Create an optional loader: ``Loader(loader)`` combined with ``None`` via ``|`` (presumably the \
        value may then also be ``None`` — see ``Loader.__or__``).
    Arguments:
        - loader (:obj:`ILoaderClass`): The loader.
    Returns:
        - loader (:obj:`ILoaderClass`): The combined loader.
    """
    base = Loader(loader)
    return base | None
31 | 
32 | 
def check_only(loader) -> ILoaderClass:
    """
    Overview:
        Create a check only loader: ``Loader(loader)`` combined with ``keep()`` via ``&`` (presumably \
        validating with ``loader`` but passing the original value through — see ``Loader.__and__``).
    Arguments:
        - loader (:obj:`ILoaderClass`): The loader.
    Returns:
        - loader (:obj:`ILoaderClass`): The combined loader.
    """
    validator = Loader(loader)
    return validator & keep()
42 | 
43 | 
def check(loader) -> ILoaderClass:
    """
    Overview:
        Create a check loader, whose output is the result of ``Loader(loader).check`` applied to its input.
    Arguments:
        - loader (:obj:`ILoaderClass`): The loader.
    Returns:
        - loader (:obj:`ILoaderClass`): The check loader.
    """

    def _run_check(value):
        # ``Loader(loader)`` is built lazily on every call.
        return Loader(loader).check(value)

    return Loader(_run_check)
53 | 


--------------------------------------------------------------------------------
/ding/utils/tests/config/k8s-config.yaml:
--------------------------------------------------------------------------------
1 | type: k3s # k3s or local
2 | name: di-cluster
servers: 1 # number of k8s masters
agents: 0 # number of k8s nodes
5 | preload_images: 
6 | - busybox:latest
7 | - hello-world:latest
8 | 


--------------------------------------------------------------------------------
/ding/utils/tests/test_bfs_helper.py:
--------------------------------------------------------------------------------
 1 | import easydict
 2 | import numpy
 3 | import pytest
 4 | 
 5 | from ding.utils import get_vi_sequence
 6 | from dizoo.maze.envs.maze_env import Maze
 7 | 
 8 | 
 9 | @pytest.mark.unittest
10 | class TestBFSHelper:
11 | 
12 |     def test_bfs(self):
13 | 
14 |         def load_env(seed):
15 |             ccc = easydict.EasyDict({'size': 16})
16 |             e = Maze(ccc)
17 |             e.seed(seed)
18 |             e.reset()
19 |             return e
20 | 
21 |         env = load_env(314)
22 |         start_obs = env.process_states(env._get_obs(), env.get_maze_map())
23 |         vi_sequence, track_back = get_vi_sequence(env, start_obs)
24 |         assert vi_sequence.shape[1:] == (16, 16)
25 |         assert track_back[0][0].shape == (16, 16, 3)
26 |         assert isinstance(track_back[0][1], numpy.int32)
27 | 


--------------------------------------------------------------------------------
/ding/utils/tests/test_collection_helper.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | 
 3 | from ding.utils.collection_helper import iter_mapping
 4 | 
 5 | 
 6 | @pytest.mark.unittest
 7 | class TestCollectionHelper:
 8 | 
 9 |     def test_iter_mapping(self):
10 |         _iter = iter_mapping([1, 2, 3, 4, 5], lambda x: x ** 2)
11 | 
12 |         assert not isinstance(_iter, list)
13 |         assert list(_iter) == [1, 4, 9, 16, 25]
14 | 


--------------------------------------------------------------------------------
/ding/utils/tests/test_compression_helper.py:
--------------------------------------------------------------------------------
 1 | import random
 2 | import numpy as np
 3 | from ding.utils.compression_helper import get_data_compressor, get_data_decompressor
 4 | 
 5 | import pytest
 6 | 
 7 | 
 8 | @pytest.mark.unittest
 9 | class TestCompression():
10 | 
11 |     def get_step_data(self):
12 |         return {'input': [random.randint(10, 100) for i in range(100)]}
13 | 
14 |     def testnaive(self):
15 |         compress_names = ['lz4', 'zlib', 'none']
16 |         for s in compress_names:
17 |             compressor = get_data_compressor(s)
18 |             decompressor = get_data_decompressor(s)
19 |             data = self.get_step_data()
20 |             assert data == decompressor(compressor(data))
21 | 
22 |     def test_arr_to_st(self):
23 |         data = np.random.randint(0, 255, (96, 96, 3), dtype=np.uint8)
24 |         compress_names = ['jpeg']
25 |         for s in compress_names:
26 |             compressor = get_data_compressor(s)
27 |             decompressor = get_data_decompressor(s)
28 |             assert data.shape == decompressor(compressor(data)).shape
29 | 


--------------------------------------------------------------------------------
/ding/utils/tests/test_deprecation.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import warnings
 3 | from ding.utils.deprecation import deprecated
 4 | 
 5 | 
 6 | @pytest.mark.unittest
 7 | def test_deprecated():
 8 | 
 9 |     @deprecated('0.4.1', '0.5.1')
10 |     def deprecated_func1():
11 |         pass
12 | 
13 |     @deprecated('0.4.1', '0.5.1', 'deprecated_func3')
14 |     def deprecated_func2():
15 |         pass
16 | 
17 |     with warnings.catch_warnings(record=True) as w:
18 |         deprecated_func1()
19 |         assert (
20 |             'API `test_deprecation.deprecated_func1` is deprecated '
21 |             'since version 0.4.1 and will be removed in version 0.5.1.'
22 |         ) == str(w[-1].message)
23 |         deprecated_func2()
24 |         assert (
25 |             'API `test_deprecation.deprecated_func2` is deprecated '
26 |             'since version 0.4.1 and will be removed in version 0.5.1, '
27 |             'please use `deprecated_func3` instead.'
28 |         ) == str(w[-1].message)
29 | 


--------------------------------------------------------------------------------
/ding/utils/tests/test_design_helper.py:
--------------------------------------------------------------------------------
 1 | import random
 2 | 
 3 | import pytest
 4 | 
 5 | from ding.utils import SingletonMetaclass
 6 | 
 7 | 
 8 | @pytest.mark.unittest
 9 | def test_singleton():
10 |     global count
11 |     count = 0
12 | 
13 |     class A(object, metaclass=SingletonMetaclass):
14 | 
15 |         def __init__(self, t):
16 |             self.t = t
17 |             self.p = random.randint(0, 10)
18 |             global count
19 |             count += 1
20 | 
21 |     obj = [A(i) for i in range(3)]
22 |     assert count == 1
23 |     assert all([o.t == 0 for o in obj])
24 |     assert all([o.p == obj[0].p for o in obj])
25 |     assert all([id(o) == id(obj[0]) for o in obj])
26 |     assert id(A.instance) == id(obj[0])
27 | 
28 |     # subclass test
29 |     class B(A):
30 |         pass
31 | 
32 |     obj = [B(i) for i in range(3, 6)]
33 |     assert count == 2
34 |     assert all([o.t == 3 for o in obj])
35 |     assert all([o.p == obj[0].p for o in obj])
36 |     assert all([id(o) == id(obj[0]) for o in obj])
37 |     assert id(B.instance) == id(obj[0])
38 | 


--------------------------------------------------------------------------------
/ding/utils/tests/test_file_helper.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import random
 3 | import pickle
 4 | 
 5 | from ding.utils.file_helper import read_file, read_from_file, remove_file, save_file, read_from_path, save_file_ceph
 6 | 
 7 | 
 8 | @pytest.mark.unittest
 9 | def test_normal_file():
10 |     data1 = {'a': [random.randint(0, 100) for i in range(100)]}
11 |     save_file('./f', data1)
12 |     data2 = read_file("./f")
13 |     assert (data2 == data1)
14 |     with open("./f1", "wb") as f1:
15 |         pickle.dump(data1, f1)
16 |     data3 = read_from_file("./f1")
17 |     assert (data3 == data1)
18 |     data4 = read_from_path("./f1")
19 |     assert (data4 == data1)
20 |     save_file_ceph("./f2", data1)
21 |     assert (data1 == read_from_file("./f2"))
22 |     # test lock
23 |     save_file('./f3', data1, use_lock=True)
24 |     data_read = read_file('./f3', use_lock=True)
25 |     assert isinstance(data_read, dict)
26 | 
27 |     remove_file("./f")
28 |     remove_file("./f1")
29 |     remove_file("./f2")
30 |     remove_file("./f3")
31 |     remove_file('./f.lock')
32 |     remove_file('./f2.lock')
33 |     remove_file('./f3.lock')
34 |     remove_file('./name.txt')
35 | 


--------------------------------------------------------------------------------
/ding/utils/tests/test_import_helper.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | 
 3 | import ding
 4 | from ding.utils.import_helper import try_import_ceph, try_import_mc, try_import_redis, try_import_rediscluster, \
 5 |     try_import_link, import_module
 6 | 
 7 | 
 8 | @pytest.mark.unittest
 9 | def test_try_import():
10 |     try_import_ceph()
11 |     try_import_mc()
12 |     try_import_redis()
13 |     try_import_rediscluster()
14 |     try_import_link()
15 |     import_module(['ding.utils'])
16 |     ding.enable_linklink = True
17 |     try_import_link()
18 | 


--------------------------------------------------------------------------------
/ding/utils/tests/test_lock.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import numpy as np
 3 | from collections import deque
 4 | 
 5 | from ding.utils import LockContext, LockContextType, get_rw_file_lock
 6 | 
 7 | 
 8 | @pytest.mark.unittest
 9 | def test_usage():
10 |     lock = LockContext(LockContextType.PROCESS_LOCK)
11 |     queue = deque(maxlen=10)
12 |     data = np.random.randn(4)
13 |     with lock:
14 |         queue.append(np.copy(data))
15 |     with lock:
16 |         output = queue.popleft()
17 |     assert (output == data).all()
18 |     lock.acquire()
19 |     queue.append(np.copy(data))
20 |     lock.release()
21 |     lock.acquire()
22 |     output = queue.popleft()
23 |     lock.release()
24 |     assert (output == data).all()
25 | 
26 | 
@pytest.mark.unittest
def test_get_rw_file_lock():
    import os
    path = 'tmp.npy'
    # TODO real read-write case
    read_lock = get_rw_file_lock(path, 'read')
    write_lock = get_rw_file_lock(path, 'write')
    try:
        with write_lock:
            np.save(path, np.random.randint(0, 1, size=(3, 4)))
        with read_lock:
            data = np.load(path)
        assert data.shape == (3, 4)
    finally:
        # Do not leave the temporary array behind in the working directory.
        if os.path.exists(path):
            os.remove(path)
38 | 


--------------------------------------------------------------------------------
/ding/utils/tests/test_registry.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | from ding.utils.registry import Registry
 3 | 
 4 | 
 5 | @pytest.mark.unittest
 6 | def test_registry():
 7 |     TEST_REGISTRY = Registry()
 8 | 
 9 |     @TEST_REGISTRY.register('a')
10 |     class A:
11 |         pass
12 | 
13 |     instance = TEST_REGISTRY.build('a')
14 |     assert isinstance(instance, A)
15 | 
16 |     with pytest.raises(AssertionError):
17 | 
18 |         @TEST_REGISTRY.register('a')
19 |         class A1:
20 |             pass
21 | 
22 |     @TEST_REGISTRY.register('a', force_overwrite=True)
23 |     class A2:
24 |         pass
25 | 
26 |     instance = TEST_REGISTRY.build('a')
27 |     assert isinstance(instance, A2)
28 | 


--------------------------------------------------------------------------------
/ding/utils/tests/test_system_helper.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | 
 3 | from ding.utils.system_helper import get_ip, get_pid, get_task_uid
 4 | 
 5 | 
 6 | @pytest.mark.unittest
 7 | class TestSystemHelper():
 8 | 
 9 |     def test_get(self):
10 |         try:
11 |             get_ip()
12 |         except:
13 |             pass
14 |         assert isinstance(get_pid(), int)
15 |         assert isinstance(get_task_uid(), str)
16 | 


--------------------------------------------------------------------------------
/ding/utils/time_helper_base.py:
--------------------------------------------------------------------------------
 1 | class TimeWrapper(object):
 2 |     """
 3 |     Overview:
 4 |         Abstract class method that defines ``TimeWrapper`` class
 5 | 
 6 |     Interfaces:
 7 |         ``wrapper``, ``start_time``, ``end_time``
 8 |     """
 9 | 
10 |     @classmethod
11 |     def wrapper(cls, fn):
12 |         """
13 |         Overview:
14 |             Classmethod wrapper, wrap a function and automatically return its running time
15 |         Arguments:
16 |             - fn (:obj:`function`): The function to be wrap and timed
17 |         """
18 | 
19 |         def time_func(*args, **kwargs):
20 |             cls.start_time()
21 |             ret = fn(*args, **kwargs)
22 |             t = cls.end_time()
23 |             return ret, t
24 | 
25 |         return time_func
26 | 
27 |     @classmethod
28 |     def start_time(cls):
29 |         """
30 |         Overview:
31 |             Abstract classmethod, start timing
32 |         """
33 |         raise NotImplementedError
34 | 
35 |     @classmethod
36 |     def end_time(cls):
37 |         """
38 |         Overview:
39 |             Abstract classmethod, stop timing
40 |         """
41 |         raise NotImplementedError
42 | 


--------------------------------------------------------------------------------
/ding/utils/type_helper.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from typing import List, Tuple, TypeVar
3 | 
# Type variable for sequence-like containers, constrained to ``List``, ``Tuple`` or ``namedtuple``.
# NOTE(review): ``namedtuple`` is a factory function rather than a type; static checkers may not
# treat it as a meaningful constraint -- confirm this is intended.
SequenceType = TypeVar('SequenceType', List, Tuple, namedtuple)
# Unconstrained type variable named after ``torch.Tensor``, used as a tensor annotation placeholder
# (presumably so this module need not import torch).
Tensor = TypeVar('torch.Tensor')
6 | 


--------------------------------------------------------------------------------
/ding/worker/__init__.py:
--------------------------------------------------------------------------------
1 | from .collector import *
2 | from .learner import *
3 | from .replay_buffer import *
4 | from .coordinator import *
5 | from .adapter import *
6 | 


--------------------------------------------------------------------------------
/ding/worker/adapter/__init__.py:
--------------------------------------------------------------------------------
1 | from .learner_aggregator import LearnerAggregator
2 | 


--------------------------------------------------------------------------------
/ding/worker/collector/__init__.py:
--------------------------------------------------------------------------------
 1 | # serial
 2 | from .base_serial_collector import ISerialCollector, create_serial_collector, get_serial_collector_cls, \
 3 |     to_tensor_transitions
 4 | 
 5 | from .sample_serial_collector import SampleSerialCollector
 6 | from .episode_serial_collector import EpisodeSerialCollector
 7 | from .battle_episode_serial_collector import BattleEpisodeSerialCollector
 8 | from .battle_sample_serial_collector import BattleSampleSerialCollector
 9 | 
10 | from .base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor, create_serial_evaluator
11 | from .interaction_serial_evaluator import InteractionSerialEvaluator
12 | from .battle_interaction_serial_evaluator import BattleInteractionSerialEvaluator
13 | from .metric_serial_evaluator import MetricSerialEvaluator, IMetric
14 | # parallel
15 | from .base_parallel_collector import BaseParallelCollector, create_parallel_collector, get_parallel_collector_cls
16 | from .zergling_parallel_collector import ZerglingParallelCollector
17 | from .marine_parallel_collector import MarineParallelCollector
18 | from .comm import BaseCommCollector, FlaskFileSystemCollector, create_comm_collector, NaiveCollector
19 | 


--------------------------------------------------------------------------------
/ding/worker/collector/comm/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_comm_collector import BaseCommCollector, create_comm_collector
2 | from .flask_fs_collector import FlaskFileSystemCollector
3 | from .utils import NaiveCollector  # for test
4 | 


--------------------------------------------------------------------------------
/ding/worker/collector/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/worker/collector/tests/__init__.py


--------------------------------------------------------------------------------
/ding/worker/collector/tests/fake_cls_policy.py:
--------------------------------------------------------------------------------
 1 | from ding.policy import Policy
 2 | from ding.model import model_wrap
 3 | 
 4 | 
 5 | class fake_policy(Policy):
 6 | 
 7 |     def _init_learn(self):
 8 |         pass
 9 | 
10 |     def _forward_learn(self, data):
11 |         pass
12 | 
13 |     def _init_eval(self):
14 |         self._eval_model = model_wrap(self._model, 'base')
15 | 
16 |     def _forward_eval(self, data):
17 |         self._eval_model.eval()
18 |         output = self._eval_model.forward(data)
19 |         return output
20 | 
21 |     def _monitor_vars_learn(self):
22 |         return ['forward_time', 'backward_time', 'sync_time']
23 | 
24 |     def _init_collect(self):
25 |         pass
26 | 
27 |     def _forward_collect(self, data):
28 |         pass
29 | 
30 |     def _process_transition(self):
31 |         pass
32 | 
33 |     def _get_train_sample(self):
34 |         pass
35 | 


--------------------------------------------------------------------------------
/ding/worker/collector/tests/speed_test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/worker/collector/tests/speed_test/__init__.py


--------------------------------------------------------------------------------
/ding/worker/collector/tests/speed_test/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
def random_change(number):
    """
    Overview:
        Jitter ``number`` by a random multiplicative factor drawn (approximately) uniformly from
        ``[0.7, 1.3)`` using numpy's global RNG.
    Arguments:
        - number: Value to perturb.
    Returns:
        - The perturbed value, ``number * factor``.
    """
    offset = np.random.random() - 0.5  # in [-0.5, 0.5)
    return number * (1 + offset * 0.6)
6 | 


--------------------------------------------------------------------------------
/ding/worker/coordinator/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_serial_commander import BaseSerialCommander
2 | from .base_parallel_commander import create_parallel_commander, get_parallel_commander_cls
3 | from .coordinator import Coordinator
4 | 


--------------------------------------------------------------------------------
/ding/worker/learner/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_learner import BaseLearner, create_learner
2 | from .comm import BaseCommLearner, FlaskFileSystemLearner, create_comm_learner
3 | from .learner_hook import register_learner_hook, add_learner_hook, merge_hooks, LearnerHook, build_learner_hook_by_cfg
4 | 


--------------------------------------------------------------------------------
/ding/worker/learner/comm/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_comm_learner import BaseCommLearner, create_comm_learner
2 | from .flask_fs_learner import FlaskFileSystemLearner
3 | from .utils import NaiveLearner  # for test
4 | 


--------------------------------------------------------------------------------
/ding/worker/replay_buffer/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_buffer import IBuffer, create_buffer, get_buffer_cls
2 | from .naive_buffer import NaiveReplayBuffer, SequenceReplayBuffer
3 | from .advanced_buffer import AdvancedReplayBuffer
4 | from .episode_buffer import EpisodeReplayBuffer
5 | 


--------------------------------------------------------------------------------
/ding/worker/replay_buffer/episode_buffer.py:
--------------------------------------------------------------------------------
 1 | from typing import List
 2 | from ding.worker.replay_buffer import NaiveReplayBuffer
 3 | from ding.utils import BUFFER_REGISTRY
 4 | 
 5 | 
 6 | @BUFFER_REGISTRY.register('episode')
 7 | class EpisodeReplayBuffer(NaiveReplayBuffer):
 8 |     r"""
 9 |     Overview:
10 |         Episode replay buffer is a buffer to store complete episodes, i.e. Each element in episode buffer is an episode.
11 |         Some algorithms do not want to sample `batch_size` complete episodes, however, they want some transitions with
12 |         some fixed length. As a result, ``sample`` should be overwritten for those requirements.
13 |     Interface:
14 |         start, close, push, update, sample, clear, count, state_dict, load_state_dict, default_config
15 |     """
16 | 
17 |     @property
18 |     def episode_len(self) -> List[int]:
19 |         return [len(episode) for episode in self._data]
20 | 


--------------------------------------------------------------------------------
/ding/worker/replay_buffer/tests/conftest.py:
--------------------------------------------------------------------------------
 1 | from typing import List
 2 | import numpy as np
 3 | from ding.utils import save_file
 4 | 
# Counter incremented by ``generate_data`` so each generated sample gets a distinct ``data_id``.
ID_COUNT = 0
# Seed numpy's global RNG at import time so generated test data is deterministic for a given call sequence.
np.random.seed(1)
 7 | 
 8 | 
 9 | def generate_data(meta: bool = False) -> dict:
10 |     global ID_COUNT
11 |     ret = {'obs': np.random.randn(4), 'data_id': str(ID_COUNT)}
12 |     ID_COUNT += 1
13 |     p_weight = np.random.uniform()
14 |     if p_weight < 1 / 3:
15 |         pass  # no key 'priority'
16 |     elif p_weight < 2 / 3:
17 |         ret['priority'] = None
18 |     else:
19 |         ret['priority'] = np.random.uniform() + 1e-3
20 |     if not meta:
21 |         return ret
22 |     else:
23 |         obs = ret.pop('obs')
24 |         save_file(ret['data_id'], obs)
25 |         return ret
26 | 
27 | 
def generate_data_list(count: int, meta: bool = False) -> List[dict]:
    """
    Overview:
        Return ``count`` freshly generated samples, each built by ``generate_data(meta)``.
    """
    samples = []
    for _ in range(count):
        samples.append(generate_data(meta))
    return samples
30 | 


--------------------------------------------------------------------------------
/ding/world_model/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_world_model import WorldModel, DynaWorldModel, DreamWorldModel, HybridWorldModel, \
2 |     get_world_model_cls, create_world_model
3 | 


--------------------------------------------------------------------------------
/ding/world_model/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/ding/world_model/model/__init__.py


--------------------------------------------------------------------------------
/ding/world_model/model/tests/test_ensemble.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import torch
 3 | from itertools import product
 4 | from ding.world_model.model.ensemble import EnsembleFC, EnsembleModel
 5 | 
# arguments
# Candidate sizes; ``args`` is their Cartesian product of (state_size, action_size, reward_size)
# tuples, fed to ``test_EnsembleModel`` through ``pytest.mark.parametrize``.
state_size = [16]
action_size = [16, 1]
reward_size = [1]
args = list(product(*[state_size, action_size, reward_size]))
11 | 
12 | 
@pytest.mark.unittest
def test_EnsembleFC():
    """Check that ``EnsembleFC`` maps ``(ensemble, batch, in_dim)`` inputs to ``(ensemble, batch, out_dim)``."""
    in_dim, out_dim = 4, 8
    ensemble_size, batch = 7, 64
    layer = EnsembleFC(in_dim, out_dim, ensemble_size)
    out = layer(torch.randn(ensemble_size, batch, in_dim))
    assert out.shape == (ensemble_size, batch, out_dim)
20 | 
21 | 
@pytest.mark.unittest
@pytest.mark.parametrize('state_size, action_size, reward_size', args)
def test_EnsembleModel(state_size, action_size, reward_size):
    """
    Overview:
        Check that ``EnsembleModel`` returns two tensors, each shaped
        ``(ensemble_size, B, state_size + reward_size)``, for a concatenated state-action input.
        Carries the same ``unittest`` marker as ``test_EnsembleFC`` so marker-filtered runs include it.
    """
    ensemble_size, B = 7, 64
    model = EnsembleModel(state_size, action_size, reward_size, ensemble_size)
    x = torch.randn(ensemble_size, B, state_size + action_size)
    y = model(x)
    assert len(y) == 2
    assert y[0].shape == y[1].shape == (ensemble_size, B, state_size + reward_size)
30 | 


--------------------------------------------------------------------------------
/ding/world_model/model/tests/test_networks.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from itertools import product
4 | 


--------------------------------------------------------------------------------
/ding/world_model/tests/test_world_model_utils.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | from easydict import EasyDict
 3 | from ding.world_model.utils import get_rollout_length_scheduler
 4 | 
 5 | 
 6 | @pytest.mark.unittest
 7 | def test_get_rollout_length_scheduler():
 8 |     fake_cfg = EasyDict(
 9 |         type='linear',
10 |         rollout_start_step=20000,
11 |         rollout_end_step=150000,
12 |         rollout_length_min=1,
13 |         rollout_length_max=25,
14 |     )
15 |     scheduler = get_rollout_length_scheduler(fake_cfg)
16 |     assert scheduler(0) == 1
17 |     assert scheduler(19999) == 1
18 |     assert scheduler(150000) == 25
19 |     assert scheduler(1500000) == 25
20 | 


--------------------------------------------------------------------------------
/ding/world_model/utils.py:
--------------------------------------------------------------------------------
 1 | from easydict import EasyDict
 2 | from typing import Callable
 3 | 
 4 | 
 5 | def get_rollout_length_scheduler(cfg: EasyDict) -> Callable[[int], int]:
 6 |     """
 7 |     Overview:
 8 |         Get the rollout length scheduler that adapts rollout length based\
 9 |         on the current environment steps.
10 |     Returns:
11 |         - scheduler (:obj:`Callble`): The function that takes envstep and\
12 |           return the current rollout length.
13 |     """
14 |     if cfg.type == 'linear':
15 |         x0 = cfg.rollout_start_step
16 |         x1 = cfg.rollout_end_step
17 |         y0 = cfg.rollout_length_min
18 |         y1 = cfg.rollout_length_max
19 |         w = (y1 - y0) / (x1 - x0)
20 |         b = y0
21 |         return lambda x: int(min(max(w * (x - x0) + b, y0), y1))
22 |     elif cfg.type == 'constant':
23 |         return lambda x: cfg.rollout_length
24 |     else:
25 |         raise KeyError("not implemented key: {}".format(cfg.type))
26 | 


--------------------------------------------------------------------------------
/dizoo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/__init__.py


--------------------------------------------------------------------------------
/dizoo/atari/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/atari/__init__.py


--------------------------------------------------------------------------------
/dizoo/atari/atari.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/atari/atari.gif


--------------------------------------------------------------------------------
/dizoo/atari/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/atari/config/__init__.py


--------------------------------------------------------------------------------
/dizoo/atari/config/serial/__init__.py:
--------------------------------------------------------------------------------
1 | from dizoo.atari.config.serial.enduro import *
2 | from dizoo.atari.config.serial.pong import *
3 | from dizoo.atari.config.serial.qbert import *
4 | from dizoo.atari.config.serial.spaceinvaders import *
5 | from dizoo.atari.config.serial.asterix import *
6 | 


--------------------------------------------------------------------------------
/dizoo/atari/config/serial/asterix/__init__.py:
--------------------------------------------------------------------------------
1 | from .asterix_mdqn_config import asterix_mdqn_config, asterix_mdqn_create_config
2 | 


--------------------------------------------------------------------------------
/dizoo/atari/config/serial/enduro/__init__.py:
--------------------------------------------------------------------------------
1 | from .enduro_dqn_config import enduro_dqn_config, enduro_dqn_create_config
2 | 


--------------------------------------------------------------------------------
/dizoo/atari/config/serial/pong/__init__.py:
--------------------------------------------------------------------------------
1 | from .pong_dqn_config import pong_dqn_config, pong_dqn_create_config
2 | from .pong_dqn_envpool_config import pong_dqn_envpool_config, pong_dqn_envpool_create_config
3 | from .pong_dqfd_config import pong_dqfd_config, pong_dqfd_create_config
4 | 


--------------------------------------------------------------------------------
/dizoo/atari/config/serial/qbert/__init__.py:
--------------------------------------------------------------------------------
1 | from .qbert_dqn_config import qbert_dqn_config, qbert_dqn_create_config
2 | from .qbert_dqfd_config import qbert_dqfd_config, qbert_dqfd_create_config
3 | 


--------------------------------------------------------------------------------
/dizoo/atari/config/serial/spaceinvaders/__init__.py:
--------------------------------------------------------------------------------
1 | from .spaceinvaders_dqn_config import spaceinvaders_dqn_config, spaceinvaders_dqn_create_config
2 | from .spaceinvaders_dqfd_config import spaceinvaders_dqfd_config, spaceinvaders_dqfd_create_config
3 | 


--------------------------------------------------------------------------------
/dizoo/atari/entry/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/atari/entry/__init__.py


--------------------------------------------------------------------------------
/dizoo/atari/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .atari_env import AtariEnv, AtariEnvMR
2 | 


--------------------------------------------------------------------------------
/dizoo/beergame/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/beergame/__init__.py


--------------------------------------------------------------------------------
/dizoo/beergame/beergame.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/beergame/beergame.png


--------------------------------------------------------------------------------
/dizoo/beergame/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .clBeergame import clBeerGame
2 | from .beergame_core import BeerGame
3 | 


--------------------------------------------------------------------------------
/dizoo/bitflip/README.md:
--------------------------------------------------------------------------------
 1 | ## BitFlip Environment
A simple environment in which the agent flips bits of a binary (0/1) sequence until it matches a target state. The task becomes harder as the number of bits increases.
It is well suited for testing Hindsight Experience Replay (HER).
 4 | 
 5 | ## DI-engine's HER on BitFlip
 6 | 
The table shows the minimum number of env steps that DI-engine's PureDQN and HER-DQN implementations need to converge. '-' means the algorithm did not converge within 20M env steps.
 8 | 
 9 | | n_bit  | PureDQN | HER-DQN |
10 | | ------ | ------- | ------- |
11 | | 15     | -       | 150K    |
12 | | 20     | -       | 1.5M    |

DI-engine's HER-DQN converges for both bit lengths, whereas PureDQN does not converge within 20M env steps.
14 | 
15 | You can refer to the RL algorithm doc for implementation and experiment details.
16 | 


--------------------------------------------------------------------------------
/dizoo/bitflip/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/bitflip/__init__.py


--------------------------------------------------------------------------------
/dizoo/bitflip/bitflip.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/bitflip/bitflip.gif


--------------------------------------------------------------------------------
/dizoo/bitflip/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .bitflip_her_dqn_config import bitflip_her_dqn_config, bitflip_her_dqn_create_config
2 | from .bitflip_pure_dqn_config import bitflip_pure_dqn_config, bitflip_pure_dqn_create_config
3 | 


--------------------------------------------------------------------------------
/dizoo/bitflip/entry/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/bitflip/entry/__init__.py


--------------------------------------------------------------------------------
/dizoo/bitflip/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .bitflip_env import BitFlipEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/bitflip/envs/test_bitfilp_env.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | from easydict import EasyDict
 3 | import numpy as np
 4 | from dizoo.bitflip.envs import BitFlipEnv
 5 | 
 6 | 
 7 | @pytest.mark.envtest
 8 | def test_bitfilp_env():
 9 |     n_bits = 10
10 |     env = BitFlipEnv(EasyDict({'n_bits': n_bits}))
11 |     env.seed(314)
12 |     assert env._seed == 314
13 |     obs = env.reset()
14 |     assert obs.shape == (2 * n_bits, )
15 |     for i in range(10):
16 |         # Both ``env.random_action()``, and utilizing ``np.random`` as well as action space,
17 |         # can generate legal random action.
18 |         if i < 5:
19 |             action = np.random.randint(0, n_bits, size=(1, ))
20 |         else:
21 |             action = env.random_action()
22 |         timestep = env.step(action)
23 |         assert timestep.obs.shape == (2 * n_bits, )
24 |         assert timestep.reward.shape == (1, )
25 | 


--------------------------------------------------------------------------------
/dizoo/box2d/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/box2d/__init__.py


--------------------------------------------------------------------------------
/dizoo/box2d/bipedalwalker/__init__.py:
--------------------------------------------------------------------------------
1 | from dizoo.box2d.bipedalwalker.config import *
2 | 


--------------------------------------------------------------------------------
/dizoo/box2d/bipedalwalker/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .bipedalwalker_sac_config import bipedalwalker_sac_config, bipedalwalker_sac_create_config
2 | 


--------------------------------------------------------------------------------
/dizoo/box2d/bipedalwalker/entry/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/box2d/bipedalwalker/entry/__init__.py


--------------------------------------------------------------------------------
/dizoo/box2d/bipedalwalker/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .bipedalwalker_env import BipedalWalkerEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/box2d/bipedalwalker/envs/test_bipedalwalker.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | from easydict import EasyDict
 3 | import numpy as np
 4 | from dizoo.box2d.bipedalwalker.envs import BipedalWalkerEnv
 5 | 
 6 | 
 7 | @pytest.mark.envtest
 8 | class TestBipedalWalkerEnv:
 9 | 
10 |     def test_naive(self):
11 |         env = BipedalWalkerEnv(EasyDict({'act_scale': True, 'rew_clip': True, 'replay_path': None}))
12 |         env.seed(123)
13 |         assert env._seed == 123
14 |         obs = env.reset()
15 |         assert obs.shape == (24, )
16 |         for i in range(10):
17 |             random_action = env.random_action()
18 |             timestep = env.step(random_action)
19 |             print(timestep)
20 |             assert isinstance(timestep.obs, np.ndarray)
21 |             assert isinstance(timestep.done, bool)
22 |             assert timestep.obs.shape == (24, )
23 |             assert timestep.reward.shape == (1, )
24 |             assert timestep.reward >= env.reward_space.low
25 |             assert timestep.reward <= env.reward_space.high
26 |             # assert isinstance(timestep, tuple)
27 |         print(env.observation_space, env.action_space, env.reward_space)
28 |         env.close()
29 | 


--------------------------------------------------------------------------------
/dizoo/box2d/bipedalwalker/original.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/box2d/bipedalwalker/original.gif


--------------------------------------------------------------------------------
/dizoo/box2d/carracing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/box2d/carracing/__init__.py


--------------------------------------------------------------------------------
/dizoo/box2d/carracing/car_racing.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/box2d/carracing/car_racing.gif


--------------------------------------------------------------------------------
/dizoo/box2d/carracing/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .carracing_dqn_config import carracing_dqn_config, carracing_dqn_create_config
2 | 


--------------------------------------------------------------------------------
/dizoo/box2d/carracing/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .carracing_env import CarRacingEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/box2d/carracing/envs/test_carracing_env.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import numpy as np
 3 | from easydict import EasyDict
 4 | from carracing_env import CarRacingEnv
 5 | 
 6 | 
 7 | @pytest.mark.envtest
 8 | @pytest.mark.parametrize('cfg', [EasyDict({'env_id': 'CarRacing-v2', 'continuous': False, 'act_scale': False})])
 9 | class TestCarRacing:
10 | 
11 |     def test_naive(self, cfg):
12 |         env = CarRacingEnv(cfg)
13 |         env.seed(314)
14 |         assert env._seed == 314
15 |         obs = env.reset()
16 |         assert obs.shape == (3, 96, 96)
17 |         for i in range(10):
18 |             random_action = env.random_action()
19 |             timestep = env.step(random_action)
20 |             print(timestep)
21 |             assert isinstance(timestep.obs, np.ndarray)
22 |             assert isinstance(timestep.done, bool)
23 |             assert timestep.obs.shape == (3, 96, 96)
24 |             assert timestep.reward.shape == (1, )
25 |             assert timestep.reward >= env.reward_space.low
26 |             assert timestep.reward <= env.reward_space.high
27 |         print(env.observation_space, env.action_space, env.reward_space)
28 |         env.close()
29 | 


--------------------------------------------------------------------------------
/dizoo/box2d/lunarlander/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/box2d/lunarlander/__init__.py


--------------------------------------------------------------------------------
/dizoo/box2d/lunarlander/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .lunarlander_dqn_config import lunarlander_dqn_config, lunarlander_dqn_create_config
2 | from .lunarlander_gail_dqn_config import lunarlander_dqn_gail_create_config, lunarlander_dqn_gail_config
3 | from .lunarlander_dqfd_config import lunarlander_dqfd_config, lunarlander_dqfd_create_config
4 | from .lunarlander_qrdqn_config import lunarlander_qrdqn_config, lunarlander_qrdqn_create_config
5 | from .lunarlander_trex_dqn_config import lunarlander_trex_dqn_config, lunarlander_trex_dqn_create_config
6 | from .lunarlander_trex_offppo_config import lunarlander_trex_ppo_config, lunarlander_trex_ppo_create_config
7 | 


--------------------------------------------------------------------------------
/dizoo/box2d/lunarlander/entry/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/box2d/lunarlander/entry/__init__.py


--------------------------------------------------------------------------------
/dizoo/box2d/lunarlander/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .lunarlander_env import LunarLanderEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/box2d/lunarlander/lunarlander.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/box2d/lunarlander/lunarlander.gif


--------------------------------------------------------------------------------
/dizoo/bsuite/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/bsuite/__init__.py


--------------------------------------------------------------------------------
/dizoo/bsuite/bsuite.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/bsuite/bsuite.png


--------------------------------------------------------------------------------
/dizoo/bsuite/config/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 


--------------------------------------------------------------------------------
/dizoo/bsuite/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .bsuite_env import BSuiteEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/classic_control/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/classic_control/__init__.py


--------------------------------------------------------------------------------
/dizoo/classic_control/acrobot/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/classic_control/acrobot/__init__.py


--------------------------------------------------------------------------------
/dizoo/classic_control/acrobot/acrobot.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/classic_control/acrobot/acrobot.gif


--------------------------------------------------------------------------------
/dizoo/classic_control/acrobot/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .acrobot_dqn_config import acrobot_dqn_config, acrobot_dqn_create_config
2 | 


--------------------------------------------------------------------------------
/dizoo/classic_control/acrobot/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .acrobot_env import AcroBotEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/classic_control/cartpole/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/classic_control/cartpole/__init__.py


--------------------------------------------------------------------------------
/dizoo/classic_control/cartpole/cartpole.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/classic_control/cartpole/cartpole.gif


--------------------------------------------------------------------------------
/dizoo/classic_control/cartpole/config/parallel/__init__.py:
--------------------------------------------------------------------------------
1 | from .cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config, cartpole_dqn_system_config
2 | 


--------------------------------------------------------------------------------
/dizoo/classic_control/cartpole/entry/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/classic_control/cartpole/entry/__init__.py


--------------------------------------------------------------------------------
/dizoo/classic_control/cartpole/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .cartpole_env import CartPoleEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/classic_control/mountain_car/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/classic_control/mountain_car/__init__.py


--------------------------------------------------------------------------------
/dizoo/classic_control/mountain_car/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .mtcar_env import MountainCarEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/classic_control/pendulum/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/classic_control/pendulum/__init__.py


--------------------------------------------------------------------------------
/dizoo/classic_control/pendulum/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .pendulum_ddpg_config import pendulum_ddpg_config, pendulum_ddpg_create_config
2 | from .pendulum_td3_config import pendulum_td3_config, pendulum_td3_create_config
3 | from .pendulum_sac_config import pendulum_sac_config, pendulum_sac_create_config
4 | from .pendulum_d4pg_config import pendulum_d4pg_config, pendulum_d4pg_create_config
5 | from .pendulum_ppo_config import pendulum_ppo_config, pendulum_ppo_create_config
6 | from .pendulum_td3_bc_config import pendulum_td3_bc_config, pendulum_td3_bc_create_config
7 | from .pendulum_td3_data_generation_config import pendulum_td3_generation_config, pendulum_td3_generation_create_config
8 | 


--------------------------------------------------------------------------------
/dizoo/classic_control/pendulum/entry/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/classic_control/pendulum/entry/__init__.py


--------------------------------------------------------------------------------
/dizoo/classic_control/pendulum/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .pendulum_env import PendulumEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/classic_control/pendulum/pendulum.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/classic_control/pendulum/pendulum.gif


--------------------------------------------------------------------------------
/dizoo/cliffwalking/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/cliffwalking/__init__.py


--------------------------------------------------------------------------------
/dizoo/cliffwalking/cliff_walking.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/cliffwalking/cliff_walking.gif


--------------------------------------------------------------------------------
/dizoo/cliffwalking/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .cliffwalking_env import CliffWalkingEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/common/__init__.py


--------------------------------------------------------------------------------
/dizoo/common/policy/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/common/policy/__init__.py


--------------------------------------------------------------------------------
/dizoo/competitive_rl/README.md:
--------------------------------------------------------------------------------
1 | The original repository of the "Competitive RL" environment is https://github.com/cuhkrlcourse/competitive-rl.
2 | Refer to it for installation and usage instructions.


--------------------------------------------------------------------------------
/dizoo/competitive_rl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/competitive_rl/__init__.py


--------------------------------------------------------------------------------
/dizoo/competitive_rl/competitive_rl.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/competitive_rl/competitive_rl.gif


--------------------------------------------------------------------------------
/dizoo/competitive_rl/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .competitive_rl_env import CompetitiveRlEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/competitive_rl/envs/resources/pong/checkpoint-alphapong.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/competitive_rl/envs/resources/pong/checkpoint-alphapong.pkl


--------------------------------------------------------------------------------
/dizoo/competitive_rl/envs/resources/pong/checkpoint-medium.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/competitive_rl/envs/resources/pong/checkpoint-medium.pkl


--------------------------------------------------------------------------------
/dizoo/competitive_rl/envs/resources/pong/checkpoint-strong.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/competitive_rl/envs/resources/pong/checkpoint-strong.pkl


--------------------------------------------------------------------------------
/dizoo/competitive_rl/envs/resources/pong/checkpoint-weak.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/competitive_rl/envs/resources/pong/checkpoint-weak.pkl


--------------------------------------------------------------------------------
/dizoo/d4rl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/d4rl/__init__.py


--------------------------------------------------------------------------------
/dizoo/d4rl/config/__init__.py:
--------------------------------------------------------------------------------
1 | # from .hopper_cql_config import hopper_cql_config
2 | # from .hopper_expert_cql_config import hopper_expert_cql_config
3 | # from .hopper_medium_cql_config import hopper_medium_cql_config
4 | 


--------------------------------------------------------------------------------
/dizoo/d4rl/d4rl.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/d4rl/d4rl.gif


--------------------------------------------------------------------------------
/dizoo/d4rl/entry/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/d4rl/entry/__init__.py


--------------------------------------------------------------------------------
/dizoo/d4rl/entry/d4rl_bcq_main.py:
--------------------------------------------------------------------------------
 1 | from ding.entry import serial_pipeline_offline
 2 | from ding.config import read_config
 3 | from pathlib import Path
 4 | 
 5 | 
 6 | def train(args):
 7 |     # launch from anywhere
 8 |     config = Path(__file__).absolute().parent.parent / 'config' / args.config
 9 |     config = read_config(str(config))
10 |     config[0].exp_name = config[0].exp_name.replace('0', str(args.seed))
11 |     serial_pipeline_offline(config, seed=args.seed)
12 | 
13 | 
14 | if __name__ == "__main__":
15 |     import argparse
16 | 
17 |     parser = argparse.ArgumentParser()
18 |     parser.add_argument('--seed', '-s', type=int, default=0)
19 |     parser.add_argument('--config', '-c', type=str, default='halfcheetah_medium_bcq_config.py')
20 |     args = parser.parse_args()
21 |     train(args)
22 | 


--------------------------------------------------------------------------------
/dizoo/d4rl/entry/d4rl_cql_main.py:
--------------------------------------------------------------------------------
 1 | from ding.entry import serial_pipeline_offline
 2 | from ding.config import read_config
 3 | from pathlib import Path
 4 | 
 5 | 
 6 | def train(args):
 7 |     # launch from anywhere
 8 |     config = Path(__file__).absolute().parent.parent / 'config' / args.config
 9 |     config = read_config(str(config))
10 |     config[0].exp_name = config[0].exp_name.replace('0', str(args.seed))
11 |     serial_pipeline_offline(config, seed=args.seed)
12 | 
13 | 
14 | if __name__ == "__main__":
15 |     import argparse
16 | 
17 |     parser = argparse.ArgumentParser()
18 |     parser.add_argument('--seed', '-s', type=int, default=10)
19 |     parser.add_argument('--config', '-c', type=str, default='hopper_expert_cql_config.py')
20 |     args = parser.parse_args()
21 |     train(args)
22 | 


--------------------------------------------------------------------------------
/dizoo/d4rl/entry/d4rl_edac_main.py:
--------------------------------------------------------------------------------
 1 | from ding.entry import serial_pipeline_offline
 2 | from ding.config import read_config
 3 | from pathlib import Path
 4 | 
 5 | 
 6 | def train(args):
 7 |     # launch from anywhere
 8 |     config = Path(__file__).absolute().parent.parent / 'config' / args.config
 9 |     config = read_config(str(config))
10 |     config[0].exp_name = config[0].exp_name.replace('0', str(args.seed))
11 |     serial_pipeline_offline(config, seed=args.seed)
12 | 
13 | 
14 | if __name__ == "__main__":
15 |     import argparse
16 | 
17 |     parser = argparse.ArgumentParser()
18 |     parser.add_argument('--seed', '-s', type=int, default=0)
19 |     parser.add_argument('--config', '-c', type=str, default='halfcheetah_medium_edac_config.py')
20 |     args = parser.parse_args()
21 |     train(args)
22 | 


--------------------------------------------------------------------------------
/dizoo/d4rl/entry/d4rl_iql_main.py:
--------------------------------------------------------------------------------
 1 | from ding.entry import serial_pipeline_offline
 2 | from ding.config import read_config
 3 | from pathlib import Path
 4 | 
 5 | 
 6 | def train(args):
 7 |     # launch from anywhere
 8 |     config = Path(__file__).absolute().parent.parent / 'config' / args.config
 9 |     config = read_config(str(config))
10 |     config[0].exp_name = config[0].exp_name.replace('0', str(args.seed))
11 |     serial_pipeline_offline(config, seed=args.seed)
12 | 
13 | 
14 | if __name__ == "__main__":
15 |     import argparse
16 | 
17 |     parser = argparse.ArgumentParser()
18 |     parser.add_argument('--seed', '-s', type=int, default=10)
19 |     parser.add_argument('--config', '-c', type=str, default='halfcheetah_medium_iql_config.py')
20 |     args = parser.parse_args()
21 |     train(args)
22 | 


--------------------------------------------------------------------------------
/dizoo/d4rl/entry/d4rl_pd_main.py:
--------------------------------------------------------------------------------
 1 | from ding.entry import serial_pipeline_offline
 2 | from ding.config import read_config
 3 | from pathlib import Path
 4 | 
 5 | 
 6 | def train(args):
 7 |     # launch from anywhere
 8 |     config = Path(__file__).absolute().parent.parent / 'config' / args.config
 9 |     config = read_config(str(config))
10 |     config[0].exp_name = config[0].exp_name.replace('0', str(args.seed))
11 |     serial_pipeline_offline(config, seed=args.seed)
12 | 
13 | 
14 | if __name__ == "__main__":
15 |     import argparse
16 | 
17 |     parser = argparse.ArgumentParser()
18 |     parser.add_argument('--seed', '-s', type=int, default=10)
19 |     parser.add_argument('--config', '-c', type=str, default='halfcheetah_medium_pd_config.py')
20 |     args = parser.parse_args()
21 |     train(args)
22 | 


--------------------------------------------------------------------------------
/dizoo/d4rl/entry/d4rl_td3_bc_main.py:
--------------------------------------------------------------------------------
 1 | from ding.entry import serial_pipeline_offline
 2 | from ding.config import read_config
 3 | from pathlib import Path
 4 | 
 5 | 
 6 | def train(args):
 7 |     # launch from anywhere
 8 |     config = Path(__file__).absolute().parent.parent / 'config' / args.config
 9 |     config = read_config(str(config))
10 |     config[0].exp_name = config[0].exp_name.replace('0', str(args.seed))
11 |     serial_pipeline_offline(config, seed=args.seed)
12 | 
13 | 
14 | if __name__ == "__main__":
15 |     import argparse
16 | 
17 |     parser = argparse.ArgumentParser()
18 |     parser.add_argument('--seed', '-s', type=int, default=10)
19 |     parser.add_argument('--config', '-c', type=str, default='hopper_medium_expert_td3bc_config.py')
20 |     args = parser.parse_args()
21 |     train(args)
22 | 


--------------------------------------------------------------------------------
/dizoo/d4rl/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .d4rl_env import D4RLEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/dmc2gym/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/dmc2gym/__init__.py


--------------------------------------------------------------------------------
/dizoo/dmc2gym/dmc2gym_cheetah.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/dmc2gym/dmc2gym_cheetah.png


--------------------------------------------------------------------------------
/dizoo/dmc2gym/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .dmc2gym_env import DMC2GymEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/evogym/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/evogym/__init__.py


--------------------------------------------------------------------------------
/dizoo/evogym/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .evogym_env import EvoGymEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/evogym/evogym.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/evogym/evogym.gif


--------------------------------------------------------------------------------
/dizoo/frozen_lake/FrozenLake.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/frozen_lake/FrozenLake.gif


--------------------------------------------------------------------------------
/dizoo/frozen_lake/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/frozen_lake/__init__.py


--------------------------------------------------------------------------------
/dizoo/frozen_lake/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .frozen_lake_dqn_config import main_config, create_config
2 | 


--------------------------------------------------------------------------------
/dizoo/frozen_lake/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .frozen_lake_env import FrozenLakeEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/gfootball/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gfootball/__init__.py


--------------------------------------------------------------------------------
/dizoo/gfootball/entry/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gfootball/entry/__init__.py


--------------------------------------------------------------------------------
/dizoo/gfootball/envs/__init__.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | 
3 | try:
4 |     from .gfootball_env import GfootballEnv
5 | except ImportError:
6 |     warnings.warn("not found gfootball env, please install it")
7 |     GfootballEnv = None
8 | 


--------------------------------------------------------------------------------
/dizoo/gfootball/envs/action/gfootball_action_runner.py:
--------------------------------------------------------------------------------
 1 | import copy
 2 | 
 3 | import numpy as np
 4 | 
 5 | from ding.envs.common import EnvElementRunner
 6 | from ding.envs.env.base_env import BaseEnv
 7 | from .gfootball_action import GfootballRawAction
 8 | 
 9 | 
class GfootballRawActionRunner(EnvElementRunner):
    """
    Overview:
        Env element runner that returns a deep copy of the engine's ``agent_action``. \
        Its core element is a ``GfootballRawAction`` built from ``cfg``.
    """

    def _init(self, cfg, *args, **kwargs) -> None:
        # Set self._core (the only state this runner keeps).
        self._core = GfootballRawAction(cfg)

    def get(self, engine: BaseEnv) -> np.ndarray:
        """
        Overview:
            Return a deep copy of ``engine.agent_action``.
        Arguments:
            - engine (:obj:`BaseEnv`): Env exposing an ``agent_action`` attribute.
        Returns:
            - agent_action (:obj:`np.ndarray`): Copy of the action. The annotation was previously the \
                ``np.array`` factory function, which is not a type; presumably an ndarray -- confirm.
        """
        # Deep copy so that later mutation of the returned value cannot alter the action held by ``engine``.
        agent_action = copy.deepcopy(engine.agent_action)
        return agent_action

    def reset(self) -> None:
        # This runner keeps no per-episode state, so there is nothing to reset.
        pass


--------------------------------------------------------------------------------
/dizoo/gfootball/envs/obs/gfootball_obs_runner.py:
--------------------------------------------------------------------------------
 1 | import copy
 2 | 
 3 | import numpy as np
 4 | 
 5 | from ding.envs.common import EnvElementRunner, EnvElement
 6 | from ding.envs.env.base_env import BaseEnv
 7 | from .gfootball_obs import PlayerObs, MatchObs
 8 | from ding.utils import deep_merge_dicts
 9 | 
10 | 
class GfootballObsRunner(EnvElementRunner):
    """
    Overview:
        Env element runner that converts the engine's raw football observation into agent-side observations. \
        It combines a match-level part (``MatchObs``) and a player-level part (``PlayerObs``).
    """

    def _init(self, cfg, *args, **kwargs) -> None:
        self._obs_match = MatchObs(cfg)
        self._obs_player = PlayerObs(cfg)
        # Placeholder only: ``get`` and ``info`` use both obs elements explicitly. Presumably the base
        # EnvElementRunner expects ``_core`` to be set -- confirm.
        self._core = self._obs_player

    def get(self, engine: BaseEnv) -> dict:
        """
        Overview:
            Deep-copy ``engine._football_obs`` and require it to be a dict. Run both the match and the player \
            processors on it and return their outputs merged with ``deep_merge_dicts``.
        """
        raw_obs = copy.deepcopy(engine._football_obs)
        assert isinstance(raw_obs, dict)
        return deep_merge_dicts(
            self._obs_match._to_agent_processor(raw_obs),
            self._obs_player._to_agent_processor(raw_obs),
        )

    def reset(self) -> None:
        # Stateless runner: nothing to reset between episodes.
        pass

    # override
    @property
    def info(self):
        # Report the info of both observation parts, keyed by part name.
        match_info = self._obs_match.info
        player_info = self._obs_player.info
        return {'match': match_info, 'player': player_info}
34 | 


--------------------------------------------------------------------------------
/dizoo/gfootball/envs/reward/gfootball_reward_runner.py:
--------------------------------------------------------------------------------
 1 | import copy
 2 | 
 3 | import torch
 4 | 
 5 | from ding.envs.common import EnvElementRunner
 6 | from ding.envs.env.base_env import BaseEnv
 7 | from .gfootball_reward import GfootballReward
 8 | 
 9 | 
class GfootballRewardRunner(EnvElementRunner):
    """
    Overview:
        Env element runner that reads the engine's per-step reward, accumulates it into an episode \
        total and converts it with its ``GfootballReward`` core element.
    """

    def _init(self, cfg, *args, **kwargs) -> None:
        # Set self._core and the running reward total.
        self._core = GfootballReward(cfg)
        self._cum_reward = 0.0

    def get(self, engine: BaseEnv) -> torch.Tensor:
        """
        Overview:
            Deep-copy ``engine._reward_of_action`` and add it to the cumulative reward. Return it after \
            ``GfootballReward._to_agent_processor``.
        Returns:
            - reward (:obj:`torch.Tensor`): Processed reward. The annotation was previously the ``torch.tensor`` \
                factory function, which is not a type; presumably a tensor -- confirm against GfootballReward.
        """
        ret = copy.deepcopy(engine._reward_of_action)
        self._cum_reward += ret
        return self._core._to_agent_processor(ret)

    def reset(self) -> None:
        # Start a fresh episode total.
        self._cum_reward = 0.0

    @property
    def cum_reward(self) -> torch.Tensor:
        # Reward accumulated since the last ``reset``, as a 1-element float tensor.
        return torch.FloatTensor([self._cum_reward])
28 | 


--------------------------------------------------------------------------------
/dizoo/gfootball/gfootball.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gfootball/gfootball.gif


--------------------------------------------------------------------------------
/dizoo/gfootball/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gfootball/model/__init__.py


--------------------------------------------------------------------------------
/dizoo/gfootball/model/bots/TamakEriFever/config.yaml:
--------------------------------------------------------------------------------
 1 | 
 2 | env_args:
 3 |     env: 'Football'
 4 |     source: 'football_ikki'
 5 |     frames_per_sec: 10  # we cannot change
 6 | 
 7 |     frame_skip: 0
 8 |     limit_steps: 3002
 9 | 
10 | train_args:
11 |     gamma_per_sec: 0.97
12 |     lambda_per_sec: 0.4
13 |     forward_steps: 64
14 |     compress_steps: 16
15 |     entropy_regularization: 1.3e-3
16 |     monte_carlo_rate: 1.0
17 |     update_episodes: 400
18 |     batch_size: 192
19 |     minimum_episodes: 3000
20 |     maximum_episodes: 30000
21 |     num_batchers: 23
22 |     eval_rate: 0.1
23 |     replay_rate: 0  # 0.1
24 |     supervised_weight: 0  # 0.1
25 |     record_dir: "records/"
26 |     randomized_start_rate: 0.3
27 |     randomized_start_max_steps: 400
28 |     reward_reset: True
29 |     worker:
30 |         num_gather: 2
31 |         num_process: 6
32 |     seed: 1800
33 |     restart_epoch: 1679
34 | 
35 | entry_args:
36 |     remote_host: ''
37 |     num_gather: 2
38 |     num_process: 6
39 | 
40 | eval_args:
41 |     remote_host: ''
42 | 
43 | 


--------------------------------------------------------------------------------
/dizoo/gfootball/model/bots/TamakEriFever/readme.md:
--------------------------------------------------------------------------------
1 | This is the 5th-place solution of the Kaggle Google Research Football (gfootball) competition.
2 | 
3 | See https://www.kaggle.com/c/google-football/discussion/203412 for details.
4 | 
5 | Thanks to [kyazuki](https://www.kaggle.com/kyazuki) and [@yuricat](https://www.kaggle.com/yuricat) for generously sharing their code.


--------------------------------------------------------------------------------
/dizoo/gfootball/model/bots/TamakEriFever/view_test.py:
--------------------------------------------------------------------------------
 1 | # Set up the Environment.
 2 | 
 3 | import time
 4 | 
 5 | from kaggle_environments import make
 6 | 
 7 | # opponent = "football/idle.py"
 8 | # opponent = "football/rulebaseC.py"
 9 | opponent = "builtin_ai"
10 | 
11 | video_title = "chain"
12 | video_path = "videos/" + video_title + "_" + opponent.split("/")[-1].replace(".py",
13 |                                                                              "") + str(int(time.time())) + ".webm"
14 | 
15 | env = make(
16 |     "football",
17 |     configuration={
18 |         "save_video": True,
19 |         "scenario_name": "11_vs_11_kaggle",
20 |         "running_in_notebook": False
21 |     },
22 |     info={"LiveVideoPath": video_path},
23 |     debug=True
24 | )
25 | output = env.run(["submission.py", opponent])[-1]
26 | 
27 | scores = [output[i]['observation']['players_raw'][0]['score'][0] for i in range(2)]
28 | print('Left player: score = %s, status = %s, info = %s' % (scores[0], output[0]['status'], output[0]['info']))
29 | print('Right player: score = %s, status = %s, info = %s' % (scores[1], output[1]['status'], output[1]['info']))
30 | 
31 | env.render(mode="human", width=800, height=600)
32 | 


--------------------------------------------------------------------------------
/dizoo/gfootball/model/bots/__init__.py:
--------------------------------------------------------------------------------
1 | from .kaggle_5th_place_model import FootballKaggle5thPlaceModel
2 | from .rule_based_bot_model import FootballRuleBaseModel
3 | 


--------------------------------------------------------------------------------
/dizoo/gfootball/policy/__init__.py:
--------------------------------------------------------------------------------
1 | from .ppo_lstm import PPOPolicy, PPOCommandModePolicy
2 | 


--------------------------------------------------------------------------------
/dizoo/gobigger_overview.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gobigger_overview.gif


--------------------------------------------------------------------------------
/dizoo/gym_anytrading/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_anytrading/__init__.py


--------------------------------------------------------------------------------
/dizoo/gym_anytrading/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .stocks_dqn_config import stocks_dqn_config, stocks_dqn_create_config
2 | 


--------------------------------------------------------------------------------
/dizoo/gym_anytrading/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .trading_env import TradingEnv, Actions, Positions
2 | from .stocks_env import StocksEnv
3 | 


--------------------------------------------------------------------------------
/dizoo/gym_anytrading/envs/data/README.md:
--------------------------------------------------------------------------------
1 | You can put stock data files here.
2 | Each data file must be named like "STOCKS_GOOGL.csv", i.e. it must end with the ".csv" suffix.


--------------------------------------------------------------------------------
/dizoo/gym_anytrading/envs/position.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_anytrading/envs/position.png


--------------------------------------------------------------------------------
/dizoo/gym_anytrading/envs/profit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_anytrading/envs/profit.png


--------------------------------------------------------------------------------
/dizoo/gym_anytrading/envs/statemachine.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_anytrading/envs/statemachine.png


--------------------------------------------------------------------------------
/dizoo/gym_anytrading/worker/__init__.py:
--------------------------------------------------------------------------------
1 | import imp
2 | from .trading_serial_evaluator import *
3 | 


--------------------------------------------------------------------------------
/dizoo/gym_hybrid/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_hybrid/__init__.py


--------------------------------------------------------------------------------
/dizoo/gym_hybrid/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_hybrid/config/__init__.py


--------------------------------------------------------------------------------
/dizoo/gym_hybrid/entry/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_hybrid/entry/__init__.py


--------------------------------------------------------------------------------
/dizoo/gym_hybrid/envs/README.md:
--------------------------------------------------------------------------------
 1 | # Modified gym-hybrid
 2 | 
 3 | The gym-hybrid directory is modified from https://github.com/thomashirtz/gym-hybrid.
 4 | We additionally add the HardMove environment (see Section 5.1 of https://arxiv.org/abs/2109.05490 for details about the HardMove env).
 5 | 
 6 | Specifically, the modified gym-hybrid contains the following three types of environments:
 7 | 
 8 | - Moving-v0 
 9 | - Sliding-v0
10 | - HardMove-v0 
11 | 
12 | ### Install Guide
13 | 
14 | ```bash
15 | cd DI-engine/dizoo/gym_hybrid/envs/gym-hybrid
16 | pip install -e .
17 | ```
18 | 
19 | ## Acknowledgement
20 | 
21 | https://github.com/thomashirtz/gym-hybrid


--------------------------------------------------------------------------------
/dizoo/gym_hybrid/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .gym_hybrid_env import GymHybridEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/__init__.py:
--------------------------------------------------------------------------------
 1 | from gym.envs.registration import register
 2 | from gym_hybrid.environments import MovingEnv
 3 | from gym_hybrid.environments import SlidingEnv
 4 | from gym_hybrid.environments import HardMoveEnv
 5 | 
 6 | register(
 7 |     id='Moving-v0',
 8 |     entry_point='gym_hybrid:MovingEnv',
 9 | )
10 | register(
11 |     id='Sliding-v0',
12 |     entry_point='gym_hybrid:SlidingEnv',
13 | )
14 | register(
15 |     id='HardMove-v0',
16 |     entry_point='gym_hybrid:HardMoveEnv',
17 | )
18 | 


--------------------------------------------------------------------------------
/dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/bg.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/bg.jpg


--------------------------------------------------------------------------------
/dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/target.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/target.png


--------------------------------------------------------------------------------
/dizoo/gym_hybrid/envs/gym-hybrid/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | 
# Packaging metadata for the locally modified gym_hybrid package.
setup(
    name='gym_hybrid',
    packages=['gym_hybrid'],
    install_requires=['gym', 'numpy'],
    # Upstream gym_hybrid is version 0.0.1; this modified copy is 0.0.2.
    version='0.0.2',
)
9 | 


--------------------------------------------------------------------------------
/dizoo/gym_hybrid/envs/gym-hybrid/tests/hardmove.py:
--------------------------------------------------------------------------------
 1 | import time
 2 | import gym
 3 | import gym_hybrid
 4 | 
 5 | if __name__ == '__main__':
 6 |     env = gym.make('HardMove-v0')
 7 |     env.reset()
 8 | 
 9 |     ACTION_SPACE = env.action_space[0].n
10 |     PARAMETERS_SPACE = env.action_space[1].shape[0]
11 |     OBSERVATION_SPACE = env.observation_space.shape[0]
12 | 
13 |     done = False
14 |     while not done:
15 |         state, reward, done, info = env.step(env.action_space.sample())
16 |         print(f'State: {state} Reward: {reward} Done: {done}')
17 |         time.sleep(0.1)
18 | 


--------------------------------------------------------------------------------
/dizoo/gym_hybrid/envs/gym-hybrid/tests/moving.py:
--------------------------------------------------------------------------------
 1 | import time
 2 | import gym
 3 | import gym_hybrid
 4 | 
 5 | if __name__ == '__main__':
 6 |     env = gym.make('Moving-v0')
 7 |     env.reset()
 8 | 
 9 |     ACTION_SPACE = env.action_space[0].n
10 |     PARAMETERS_SPACE = env.action_space[1].shape[0]
11 |     OBSERVATION_SPACE = env.observation_space.shape[0]
12 | 
13 |     done = False
14 |     while not done:
15 |         state, reward, done, info = env.step(env.action_space.sample())
16 |         print(f'State: {state} Reward: {reward} Done: {done}')
17 |         time.sleep(0.1)
18 | 


--------------------------------------------------------------------------------
/dizoo/gym_hybrid/envs/gym-hybrid/tests/record.py:
--------------------------------------------------------------------------------
 1 | import gym
 2 | import gym_hybrid
 3 | 
 4 | if __name__ == '__main__':
 5 |     env = gym.make('Sliding-v0')
 6 |     env = gym.wrappers.Monitor(env, "./video", force=True)
 7 |     env.metadata["render.modes"] = ["human", "rgb_array"]
 8 |     env.reset()
 9 | 
10 |     done = False
11 |     while not done:
12 |         _, _, done, _ = env.step(env.action_space.sample())
13 | 
14 |     env.close()
15 | 


--------------------------------------------------------------------------------
/dizoo/gym_hybrid/envs/gym-hybrid/tests/render.py:
--------------------------------------------------------------------------------
 1 | import time
 2 | import gym
 3 | import gym_hybrid
 4 | 
 5 | if __name__ == '__main__':
 6 |     env = gym.make('Sliding-v0')
 7 |     env.reset()
 8 | 
 9 |     done = False
10 |     while not done:
11 |         _, _, done, _ = env.step(env.action_space.sample())
12 |         env.render()
13 |         time.sleep(0.1)
14 | 
15 |     time.sleep(1)
16 |     env.close()
17 | 


--------------------------------------------------------------------------------
/dizoo/gym_hybrid/envs/gym-hybrid/tests/sliding.py:
--------------------------------------------------------------------------------
 1 | import time
 2 | import gym
 3 | import gym_hybrid
 4 | 
 5 | if __name__ == '__main__':
 6 |     env = gym.make('Sliding-v0')
 7 |     env.reset()
 8 | 
 9 |     ACTION_SPACE = env.action_space[0].n
10 |     PARAMETERS_SPACE = env.action_space[1].shape[0]
11 |     OBSERVATION_SPACE = env.observation_space.shape[0]
12 | 
13 |     done = False
14 |     while not done:
15 |         state, reward, done, info = env.step(env.action_space.sample())
16 |         print(f'State: {state} Reward: {reward} Done: {done}')
17 |         time.sleep(0.1)
18 | 


--------------------------------------------------------------------------------
/dizoo/gym_hybrid/moving_v0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_hybrid/moving_v0.gif


--------------------------------------------------------------------------------
/dizoo/gym_pybullet_drones/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_pybullet_drones/__init__.py


--------------------------------------------------------------------------------
/dizoo/gym_pybullet_drones/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .gym_pybullet_drones_env import GymPybulletDronesEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/gym_pybullet_drones/envs/test_ding_env.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | from easydict import EasyDict
 3 | import gym_pybullet_drones
 4 | 
 5 | from ding.envs import BaseEnv, BaseEnvTimestep
 6 | from dizoo.gym_pybullet_drones.envs.gym_pybullet_drones_env import GymPybulletDronesEnv
 7 | 
 8 | 
 9 | @pytest.mark.envtest
10 | class TestGymPybulletDronesEnv:
11 | 
12 |     def test_naive(self):
13 |         cfg = {"env_id": "takeoff-aviary-v0"}
14 |         cfg = EasyDict(cfg)
15 |         env = GymPybulletDronesEnv(cfg)
16 | 
17 |         env.reset()
18 |         done = False
19 |         while not done:
20 |             action = env.action_space.sample()
21 |             assert action.shape[0] == 4
22 | 
23 |             for i in range(action.shape[0]):
24 |                 assert action[i] >= env.action_space.low[i] and action[i] <= env.action_space.high[i]
25 | 
26 |             obs, reward, done, info = env.step(action)
27 | 
28 |             assert obs.shape[0] == 12
29 |             for i in range(obs.shape[0]):
30 |                 assert obs[i] >= env.observation_space.low[i] and obs[i] <= env.observation_space.high[i]
31 | 
32 |             assert reward >= env.reward_space.low and reward <= env.reward_space.high
33 | 


--------------------------------------------------------------------------------
/dizoo/gym_pybullet_drones/envs/test_ori_env.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import gym
 3 | import numpy as np
 4 | 
 5 | import gym_pybullet_drones
 6 | 
 7 | 
 8 | @pytest.mark.envtest
 9 | class TestGymPybulletDronesOriEnv:
10 | 
11 |     def test_naive(self):
12 |         env = gym.make("takeoff-aviary-v0")
13 |         env.reset()
14 |         done = False
15 |         while not done:
16 |             action = env.action_space.sample()
17 |             assert action.shape[0] == 4
18 | 
19 |             for i in range(action.shape[0]):
20 |                 assert action[i] >= env.action_space.low[i] and action[i] <= env.action_space.high[i]
21 | 
22 |             obs, reward, done, info = env.step(action)
23 |             assert obs.shape[0] == 12
24 |             for i in range(obs.shape[0]):
25 |                 assert obs[i] >= env.observation_space.low[i] and obs[i] <= env.observation_space.high[i]
26 | 
27 |             assert reward >= env.reward_space.low and reward <= env.reward_space.high
28 | 


--------------------------------------------------------------------------------
/dizoo/gym_pybullet_drones/gym_pybullet_drones.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_pybullet_drones/gym_pybullet_drones.gif


--------------------------------------------------------------------------------
/dizoo/gym_soccer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_soccer/__init__.py


--------------------------------------------------------------------------------
/dizoo/gym_soccer/envs/README.md:
--------------------------------------------------------------------------------
 1 | # How to replay a log
 2 | 
 3 | 1. Enable saving episode logs to a directory with the following command:
 4 | 
 5 |    `env.enable_save_replay('./game_log')`
 6 | 
 7 | 2. After running the game, you can see some log files in the game_log directory.
 8 | 
 9 | 3. Execute the following command to replay a log file (`*.rcg`):
10 | 
11 |    `env.replay_log("game_log/20211019011053-base_left_0-vs-base_right_0.rcg")`


--------------------------------------------------------------------------------
/dizoo/gym_soccer/envs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_soccer/envs/__init__.py


--------------------------------------------------------------------------------
/dizoo/gym_soccer/half_offensive.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/gym_soccer/half_offensive.gif


--------------------------------------------------------------------------------
/dizoo/image_classification/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/image_classification/__init__.py


--------------------------------------------------------------------------------
/dizoo/image_classification/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import ImageNetDataset
2 | from .sampler import DistributedSampler
3 | 


--------------------------------------------------------------------------------
/dizoo/image_classification/imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/image_classification/imagenet.png


--------------------------------------------------------------------------------
/dizoo/image_classification/policy/__init__.py:
--------------------------------------------------------------------------------
1 | from .policy import ImageClassificationPolicy
2 | 


--------------------------------------------------------------------------------
/dizoo/ising_env/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/ising_env/__init__.py


--------------------------------------------------------------------------------
/dizoo/ising_env/entry/ising_mfq_eval.py:
--------------------------------------------------------------------------------
 1 | from dizoo.ising_env.config.ising_mfq_config import main_config, create_config
 2 | from ding.entry import eval
 3 | 
 4 | 
 5 | def main():
 6 |     main_config.env.collector_env_num = 1
 7 |     main_config.env.evaluator_env_num = 1
 8 |     main_config.env.n_evaluator_episode = 1
 9 |     ckpt_path = './ckpt_best.pth.tar'
10 |     replay_path = './replay_videos'
11 |     eval((main_config, create_config), seed=1, load_path=ckpt_path, replay_path=replay_path)
12 | 
13 | 
14 | if __name__ == "__main__":
15 |     main()
16 | 


--------------------------------------------------------------------------------
/dizoo/ising_env/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .ising_model_env import IsingModelEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/ising_env/envs/ising_model/__init__.py:
--------------------------------------------------------------------------------
1 | import imp
2 | import os.path as osp
3 | 
4 | 
def load(name):
    """Load a Python source file from this package's directory and return it as a module.

    Args:
        name: file name resolved relative to this package's directory via
            ``os.path.join`` (an absolute path is therefore used unchanged).

    Returns:
        A newly created and executed module object for that file.

    Raises:
        FileNotFoundError: if the file does not exist.
    """
    import importlib.machinery
    import importlib.util

    pathname = osp.join(osp.dirname(__file__), name)
    # ``imp`` is deprecated and was removed in Python 3.12. Also, the former
    # ``imp.load_source('', pathname)`` stored every file under the single
    # ``sys.modules['']`` entry, so each call re-executed into (and mutated)
    # the module returned by the previous call. Build an independent module
    # per call instead, without touching ``sys.modules``.
    module_name = osp.splitext(osp.basename(pathname))[0]
    # An explicit SourceFileLoader (as imp used) accepts any file extension.
    loader = importlib.machinery.SourceFileLoader(module_name, pathname)
    spec = importlib.util.spec_from_file_location(module_name, pathname, loader=loader)
    module = importlib.util.module_from_spec(spec)
    loader.exec_module(module)
    return module
8 | 


--------------------------------------------------------------------------------
/dizoo/ising_env/envs/ising_model/multiagent/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/ising_env/envs/ising_model/multiagent/__init__.py


--------------------------------------------------------------------------------
/dizoo/ising_env/ising_env.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/ising_env/ising_env.gif


--------------------------------------------------------------------------------
/dizoo/league_demo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/league_demo/__init__.py


--------------------------------------------------------------------------------
/dizoo/league_demo/league_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/league_demo/league_demo.png


--------------------------------------------------------------------------------
/dizoo/mario/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/mario/__init__.py


--------------------------------------------------------------------------------
/dizoo/mario/mario.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/mario/mario.gif


--------------------------------------------------------------------------------
/dizoo/maze/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.registration import register
2 | 
# Expose the DI-engine Maze env to gym under the id 'Maze'.
register(
    id='Maze',
    entry_point='dizoo.maze.envs:Maze',
)
4 | 


--------------------------------------------------------------------------------
/dizoo/maze/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .maze_env import Maze
2 | 


--------------------------------------------------------------------------------
/dizoo/maze/envs/test_maze_env.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import os
 3 | import numpy as np
 4 | from dizoo.maze.envs.maze_env import Maze
 5 | from easydict import EasyDict
 6 | import copy
 7 | 
 8 | 
 9 | @pytest.mark.envtest
10 | class TestMazeEnv:
11 | 
12 |     def test_maze(self):
13 |         env = Maze(EasyDict({'size': 16}))
14 |         env.seed(314)
15 |         assert env._seed == 314
16 |         obs = env.reset()
17 |         assert obs.shape == (16, 16, 3)
18 |         min_val, max_val = 0, 3
19 |         for i in range(100):
20 |             random_action = np.random.randint(min_val, max_val, size=(1, ))
21 |             timestep = env.step(random_action)
22 |             print(timestep)
23 |             print(timestep.obs.max())
24 |             assert isinstance(timestep.obs, np.ndarray)
25 |             assert isinstance(timestep.done, bool)
26 |             if timestep.done:
27 |                 env.reset()
28 |         env.close()
29 | 


--------------------------------------------------------------------------------
/dizoo/metadrive/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/metadrive/__init__.py


--------------------------------------------------------------------------------
/dizoo/metadrive/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/metadrive/config/__init__.py


--------------------------------------------------------------------------------
/dizoo/metadrive/env/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/metadrive/env/__init__.py


--------------------------------------------------------------------------------
/dizoo/metadrive/metadrive_env.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/metadrive/metadrive_env.gif


--------------------------------------------------------------------------------
/dizoo/minigrid/__init__.py:
--------------------------------------------------------------------------------
 1 | from gymnasium.envs.registration import register
 2 | 
 3 | register(id='MiniGrid-AKTDT-7x7-1-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_7x7_1')
 4 | 
 5 | register(id='MiniGrid-AKTDT-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure')
 6 | 
 7 | register(id='MiniGrid-AKTDT-13x13-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_13x13')
 8 | 
 9 | register(id='MiniGrid-AKTDT-13x13-1-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_13x13_1')
10 | 
11 | register(id='MiniGrid-AKTDT-19x19-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_19x19')
12 | 
13 | register(id='MiniGrid-AKTDT-19x19-3-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_19x19_3')
14 | 
15 | register(id='MiniGrid-NoisyTV-v0', entry_point='dizoo.minigrid.envs:NoisyTVEnv')
16 | 


--------------------------------------------------------------------------------
/dizoo/minigrid/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .minigrid_env import MiniGridEnv
2 | from dizoo.minigrid.envs.app_key_to_door_treasure import AppleKeyToDoorTreasure, AppleKeyToDoorTreasure_13x13, AppleKeyToDoorTreasure_19x19, AppleKeyToDoorTreasure_13x13_1, AppleKeyToDoorTreasure_19x19_3, AppleKeyToDoorTreasure_7x7_1
3 | from dizoo.minigrid.envs.noisy_tv import NoisyTVEnv
4 | 


--------------------------------------------------------------------------------
/dizoo/minigrid/minigrid.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/minigrid/minigrid.gif


--------------------------------------------------------------------------------
/dizoo/mujoco/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/mujoco/__init__.py


--------------------------------------------------------------------------------
/dizoo/mujoco/addition/install_mesa.sh:
--------------------------------------------------------------------------------
# Download the OSMesa and patchelf RPMs (plus dependencies) into ~/rpm and
# unpack them there without root (presumably for headless MuJoCo rendering).
# Stop at the first failing step instead of unpacking a partial download.
set -e

mkdir -p ~/rpm
yumdownloader --destdir ~/rpm --resolve mesa-libOSMesa.x86_64 mesa-libOSMesa-devel.x86_64 patchelf.x86_64
cd ~/rpm
# Glob for *.rpm instead of parsing `ls`: on re-runs this skips the directories
# extracted earlier. Quoting keeps unusual file names intact.
for rpm in ./*.rpm; do
    [ -e "$rpm" ] || continue  # no match: the glob is left as a literal string
    rpm2cpio "$rpm" | cpio -id
done
5 | 


--------------------------------------------------------------------------------
/dizoo/mujoco/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/mujoco/config/__init__.py


--------------------------------------------------------------------------------
/dizoo/mujoco/entry/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/mujoco/entry/__init__.py


--------------------------------------------------------------------------------
/dizoo/mujoco/entry/mujoco_cql_generation_main.py:
--------------------------------------------------------------------------------
 1 | from dizoo.mujoco.config.hopper_sac_data_generation_config import main_config, create_config
 2 | from ding.entry import collect_demo_data, eval
 3 | import torch
 4 | import copy
 5 | 
 6 | 
 7 | def eval_ckpt(args):
 8 |     config = copy.deepcopy([main_config, create_config])
 9 |     eval(config, seed=args.seed, load_path=main_config.policy.learn.learner.hook.load_ckpt_before_run)
10 | 
11 | 
def generate(args):
    """Collect demonstration data with the expert weights found at ``learner.load_path``."""
    cfg_pair = copy.deepcopy([main_config, create_config])
    # map_location='cpu' loads the tensors onto the CPU regardless of the saving device.
    expert_state = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
    policy_cfg = main_config.policy
    collect_demo_data(
        cfg_pair,
        collect_count=policy_cfg.other.replay_buffer.replay_buffer_size,
        seed=args.seed,
        expert_data_path=policy_cfg.collect.save_path,
        state_dict=expert_state,
    )
22 | 
23 | 
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--seed', '-s', type=int, default=0)
    cli_args = cli.parse_args()

    # Evaluate the checkpoint first, then collect demonstration data.
    eval_ckpt(cli_args)
    generate(cli_args)
33 | 


--------------------------------------------------------------------------------
/dizoo/mujoco/entry/mujoco_cql_main.py:
--------------------------------------------------------------------------------
 1 | from dizoo.mujoco.config.hopper_cql_config import main_config, create_config
 2 | from ding.entry import serial_pipeline_offline
 3 | 
 4 | 
 5 | def train(args):
 6 |     config = [main_config, create_config]
 7 |     serial_pipeline_offline(config, seed=args.seed)
 8 | 
 9 | 
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--seed', '-s', type=int, default=10)
    train(cli.parse_args())
18 | 


--------------------------------------------------------------------------------
/dizoo/mujoco/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .mujoco_env import MujocoEnv
2 | from .mujoco_disc_env import MujocoDiscEnv
3 | 


--------------------------------------------------------------------------------
/dizoo/mujoco/envs/test/test_mujoco_gym_env.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import gym
 3 | 
 4 | 
 5 | @pytest.mark.envtest
 6 | def test_shapes():
 7 |     from dizoo.mujoco.envs import mujoco_gym_env
 8 |     ant = gym.make('AntTruncatedObs-v2')
 9 |     assert ant.observation_space.shape == (27, )
10 |     assert ant.action_space.shape == (8, )
11 |     humanoid = gym.make('HumanoidTruncatedObs-v2')
12 |     assert humanoid.observation_space.shape == (45, )
13 |     assert humanoid.action_space.shape == (17, )
14 | 


--------------------------------------------------------------------------------
/dizoo/mujoco/mujoco.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/mujoco/mujoco.gif


--------------------------------------------------------------------------------
/dizoo/multiagent_mujoco/README.md:
--------------------------------------------------------------------------------
1 | ## Multi Agent Mujoco Env
2 | 
3 | Multi Agent Mujoco is an environment for Continuous Multi-Agent Robotic Control, based on OpenAI's Mujoco Gym environments.
4 | 
5 | The environment is described in the paper [Deep Multi-Agent Reinforcement Learning for Decentralized Continuous Cooperative Control](https://arxiv.org/abs/2003.06709) by Christian Schroeder de Witt, Bei Peng, Pierre-Alexandre Kamienny, Philip Torr, Wendelin Böhmer and Shimon Whiteson, Torr Vision Group and Whiteson Research Lab, University of Oxford, 2020
6 | 
7 | You can find more details in [Multi-Agent Mujoco Environment](https://github.com/schroederdewitt/multiagent_mujoco)
8 | 


--------------------------------------------------------------------------------
/dizoo/multiagent_mujoco/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/multiagent_mujoco/__init__.py


--------------------------------------------------------------------------------
/dizoo/multiagent_mujoco/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .mujoco_multi import MujocoMulti
2 | from .coupled_half_cheetah import CoupledHalfCheetah
3 | from .manyagent_swimmer import ManyAgentSwimmerEnv
4 | from .manyagent_ant import ManyAgentAntEnv
5 | 


--------------------------------------------------------------------------------
/dizoo/multiagent_mujoco/envs/assets/.gitignore:
--------------------------------------------------------------------------------
1 | *.auto.xml
2 | 


--------------------------------------------------------------------------------
/dizoo/multiagent_mujoco/envs/assets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/multiagent_mujoco/envs/assets/__init__.py


--------------------------------------------------------------------------------
/dizoo/overcooked/README.md:
--------------------------------------------------------------------------------
1 | This is the overcooked-ai environment compatiable to DI-engine.
2 | 
3 | The code is adapted from [Overcooked-AI](https://github.com/HumanCompatibleAI/overcooked_ai), a benchmark environment for fully cooperative human-AI task performance built on the wildly popular video game [Overcooked](http://www.ghosttowngames.com/overcooked/).


--------------------------------------------------------------------------------
/dizoo/overcooked/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/overcooked/__init__.py


--------------------------------------------------------------------------------
/dizoo/overcooked/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .overcooked_demo_ppo_config import overcooked_demo_ppo_config
2 | 


--------------------------------------------------------------------------------
/dizoo/overcooked/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .overcooked_env import OvercookEnv, OvercookGameEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/overcooked/overcooked.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/overcooked/overcooked.gif


--------------------------------------------------------------------------------
/dizoo/petting_zoo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/petting_zoo/__init__.py


--------------------------------------------------------------------------------
/dizoo/petting_zoo/config/__init__.py:
--------------------------------------------------------------------------------
 1 | from .ptz_simple_spread_atoc_config import ptz_simple_spread_atoc_config, ptz_simple_spread_atoc_create_config
 2 | from .ptz_simple_spread_collaq_config import ptz_simple_spread_collaq_config, ptz_simple_spread_collaq_create_config
 3 | from .ptz_simple_spread_coma_config import ptz_simple_spread_coma_config, ptz_simple_spread_coma_create_config
 4 | from .ptz_simple_spread_mappo_config import ptz_simple_spread_mappo_config, ptz_simple_spread_mappo_create_config
 5 | from .ptz_simple_spread_qmix_config import ptz_simple_spread_qmix_config, ptz_simple_spread_qmix_create_config
 6 | from .ptz_simple_spread_qtran_config import ptz_simple_spread_qtran_config, ptz_simple_spread_qtran_create_config
 7 | from .ptz_simple_spread_vdn_config import ptz_simple_spread_vdn_config, ptz_simple_spread_vdn_create_config
 8 | from .ptz_simple_spread_wqmix_config import ptz_simple_spread_wqmix_config, ptz_simple_spread_wqmix_create_config
 9 | from .ptz_simple_spread_madqn_config import ptz_simple_spread_madqn_config, ptz_simple_spread_madqn_create_config  # noqa
10 | 


--------------------------------------------------------------------------------
/dizoo/petting_zoo/entry/ptz_simple_spread_eval.py:
--------------------------------------------------------------------------------
 1 | from dizoo.petting_zoo.config.ptz_simple_spread_mappo_config import main_config, create_config
 2 | from ding.entry import eval
 3 | 
 4 | 
 5 | def main():
 6 |     ckpt_path = './ckpt_best.pth.tar'
 7 |     replay_path = './replay_videos'
 8 |     eval((main_config, create_config), seed=0, load_path=ckpt_path, replay_path=replay_path)
 9 | 
10 | 
11 | if __name__ == "__main__":
12 |     main()
13 | 


--------------------------------------------------------------------------------
/dizoo/petting_zoo/envs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/petting_zoo/envs/__init__.py


--------------------------------------------------------------------------------
/dizoo/petting_zoo/petting_zoo_mpe_simple_spread.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/petting_zoo/petting_zoo_mpe_simple_spread.gif


--------------------------------------------------------------------------------
/dizoo/pomdp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/pomdp/__init__.py


--------------------------------------------------------------------------------
/dizoo/pomdp/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .atari_env import PomdpAtariEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/pomdp/envs/test_atari_env.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import gym
 3 | import numpy as np
 4 | from easydict import EasyDict
 5 | from dizoo.pomdp.envs import PomdpAtariEnv
 6 | 
 7 | 
 8 | @pytest.mark.envtest
 9 | def test_env():
10 |     cfg = {
11 |         'env_id': 'Pong-ramNoFrameskip-v4',
12 |         'frame_stack': 4,
13 |         'is_train': True,
14 |         'warp_frame': False,
15 |         'clip_reward': False,
16 |         'use_ram': True,
17 |         'render': False,
18 |         'pomdp': dict(noise_scale=0.001, zero_p=0.1, reward_noise=0.01, duplicate_p=0.2)
19 |     }
20 | 
21 |     cfg = EasyDict(cfg)
22 |     pong_env = PomdpAtariEnv(cfg)
23 |     pong_env.seed(0)
24 |     obs = pong_env.reset()
25 |     act_dim = pong_env.info().act_space.shape[0]
26 |     while True:
27 |         random_action = np.random.choice(range(act_dim), size=(1, ))
28 |         timestep = pong_env.step(random_action)
29 |         assert timestep.obs.shape == (512, )
30 |         assert timestep.reward.shape == (1, )
31 |         # assert isinstance(timestep, tuple)
32 |         if timestep.done:
33 |             assert 'eval_episode_return' in timestep.info, timestep.info
34 |             break
35 |     pong_env.close()
36 | 


--------------------------------------------------------------------------------
/dizoo/procgen/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/procgen/__init__.py


--------------------------------------------------------------------------------
/dizoo/procgen/coinrun.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/procgen/coinrun.gif


--------------------------------------------------------------------------------
/dizoo/procgen/coinrun.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/procgen/coinrun.png


--------------------------------------------------------------------------------
/dizoo/procgen/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .coinrun_dqn_config import main_config, create_config
2 | from .coinrun_ppo_config import main_config, create_config
3 | 


--------------------------------------------------------------------------------
/dizoo/procgen/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .procgen_env import ProcgenEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/procgen/envs/test_coinrun_env.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | import numpy as np
 3 | from easydict import EasyDict
 4 | from dizoo.procgen.envs import ProcgenEnv
 5 | 
 6 | 
 7 | @pytest.mark.envtest
 8 | class TestProcgenEnv:
 9 | 
10 |     def test_naive(self):
11 |         env = ProcgenEnv(EasyDict({}))
12 |         env.seed(314)
13 |         assert env._seed == 314
14 |         obs = env.reset()
15 |         assert obs.shape == (3, 64, 64)
16 |         for i in range(10):
17 |             random_action = np.tanh(np.random.random(1))
18 |             timestep = env.step(random_action)
19 |             assert timestep.obs.shape == (3, 64, 64)
20 |             assert timestep.reward.shape == (1, )
21 |             assert timestep.reward >= env.info().rew_space.value['min']
22 |             assert timestep.reward <= env.info().rew_space.value['max']
23 |             # assert isinstance(timestep, tuple)
24 |         print(env.info())
25 |         env.close()
26 | 


--------------------------------------------------------------------------------
/dizoo/procgen/maze.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/procgen/maze.gif


--------------------------------------------------------------------------------
/dizoo/procgen/maze.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/procgen/maze.png


--------------------------------------------------------------------------------
/dizoo/pybullet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/pybullet/__init__.py


--------------------------------------------------------------------------------
/dizoo/pybullet/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .pybullet_env import PybulletEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/pybullet/pybullet.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/pybullet/pybullet.gif


--------------------------------------------------------------------------------
/dizoo/rocket/README.md:
--------------------------------------------------------------------------------
 1 | # Install
 2 | 
 3 | ```shell
 4 | pip install git+https://github.com/nighood/rocket-recycling@master#egg=rocket_recycling
 5 | ```
 6 | 
 7 | # Check Installation
 8 | ```shell
 9 | pytest -sv test_rocket_env.py
10 | ```
11 | 


--------------------------------------------------------------------------------
/dizoo/rocket/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/rocket/__init__.py


--------------------------------------------------------------------------------
/dizoo/rocket/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/rocket/config/__init__.py


--------------------------------------------------------------------------------
/dizoo/rocket/entry/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/rocket/entry/__init__.py


--------------------------------------------------------------------------------
/dizoo/rocket/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .rocket_env import RocketEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/slime_volley/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/slime_volley/__init__.py


--------------------------------------------------------------------------------
/dizoo/slime_volley/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .slime_volley_env import SlimeVolleyEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/slime_volley/slime_volley.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/slime_volley/slime_volley.gif


--------------------------------------------------------------------------------
/dizoo/smac/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/__init__.py


--------------------------------------------------------------------------------
/dizoo/smac/envs/__init__.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | 
3 | from .fake_smac_env import FakeSMACEnv
4 | try:
5 |     from .smac_env import SMACEnv
6 | except ImportError:
7 |     warnings.warn("not found pysc2 env, please install it")
8 |     SMACEnv = None
9 | 


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/README.md:
--------------------------------------------------------------------------------
 1 | # Notes on Two Player Maps
 2 | 
 3 | Before starting, you need to do the following things:
 4 | 
 5 | 1. copy the maps in `maps/SMAC_Maps_two_player/*.SC2Map` to the directory `StarCraft II/Maps/SMAC_Maps_two_player/`.
 6 | 2. copy the maps in `maps/SMAC_Maps/*.SC2Map` to the directory `StarCraft II/Maps/SMAC_Maps/`.
 7 | 
 8 | A convenient bash script (run it from this `maps` directory, since the source paths are relative) is:
 9 | 
10 | ```bash
11 | # In linux
12 | cp -r SMAC_Maps_two_player/ ~/StarCraftII/Maps/
13 | cp -r SMAC_Maps/ ~/StarCraftII/Maps/
14 | ```
15 | 
16 | 


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/10m_vs_11m.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/10m_vs_11m.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/1c3s5z.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/1c3s5z.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/25m.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/25m.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/27m_vs_30m.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/27m_vs_30m.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/2c_vs_64zg.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/2c_vs_64zg.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/2m_vs_1z.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/2m_vs_1z.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/2s3z.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/2s3z.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/2s_vs_1sc.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/2s_vs_1sc.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/3m.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/3m.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/3s5z.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/3s5z.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/3s5z_vs_3s6z.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/3s5z_vs_3s6z.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/3s_vs_3z.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/3s_vs_3z.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/3s_vs_4z.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/3s_vs_4z.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/3s_vs_5z.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/3s_vs_5z.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/5m_vs_6m.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/5m_vs_6m.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/6h_vs_8z.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/6h_vs_8z.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/8m.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/8m.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/8m_vs_9m.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/8m_vs_9m.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/MMM.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/MMM.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/MMM2.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/MMM2.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/__init__.py


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/bane_vs_bane.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/bane_vs_bane.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/corridor.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/corridor.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/infestor_viper.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/infestor_viper.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps/so_many_baneling.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps/so_many_baneling.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps_two_player/3m.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps_two_player/3m.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps_two_player/3s5z.SC2Map:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps_two_player/3s5z.SC2Map


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/SMAC_Maps_two_player/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/SMAC_Maps_two_player/__init__.py


--------------------------------------------------------------------------------
/dizoo/smac/envs/maps/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/envs/maps/__init__.py


--------------------------------------------------------------------------------
/dizoo/smac/smac.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/smac/smac.gif


--------------------------------------------------------------------------------
/dizoo/sokoban/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/sokoban/__init__.py


--------------------------------------------------------------------------------
/dizoo/sokoban/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .sokoban_env import SokobanEnv
2 | 


--------------------------------------------------------------------------------
/dizoo/sokoban/envs/test_sokoban_env.py:
--------------------------------------------------------------------------------
 1 | from easydict import EasyDict
 2 | import pytest
 3 | import numpy as np
 4 | from dizoo.sokoban.envs.sokoban_env import SokobanEnv
 5 | 
 6 | 
 7 | @pytest.mark.envtest
 8 | class TestSokoban:
 9 | 
10 |     def test_sokoban(self):
11 |         env = SokobanEnv(EasyDict({'env_id': 'Sokoban-v0'}))
12 |         env.reset()
13 |         for i in range(100):
14 |             action = np.random.randint(8)
15 |             timestep = env.step(np.array(action))
16 |             print(timestep)
17 |             print(timestep.obs.max())
18 |             assert isinstance(timestep.obs, np.ndarray)
19 |             assert isinstance(timestep.done, bool)
20 |             assert timestep.obs.shape == (160, 160, 3)
21 |             print(timestep.info)
22 |             assert timestep.reward.shape == (1, )
23 |             if timestep.done:
24 |                 env.reset()
25 |         env.close()
26 | 


--------------------------------------------------------------------------------
/dizoo/tabmwp/README.md:
--------------------------------------------------------------------------------
 1 | ## TabMWP Env
 2 | 
 3 | ## Dataset
 4 | 
 5 | The **TabMWP** dataset contains 38,431 tabular math word problems. Each question in **TabMWP** is aligned with a tabular context, which is presented as an image, semi-structured text, and a structured table. There are two types of questions: *free-text* and *multi-choice*, and each problem is annotated with gold solutions to reveal the multi-step reasoning process.
 6 | 
 7 | The environment is described in the paper [Dynamic Prompt Learning via Policy Gradient for Semi-structured Mathematical Reasoning](https://arxiv.org/abs/2209.14610) by Pan Lu, Liang Qiu, Kai-Wei Chang, Ying Nian Wu, Song-Chun Zhu, Tanmay Rajpurohit, Peter Clark, Ashwin Kalyan, 2023.
 8 | 
 9 | You can find more details in [Prompt PG](https://github.com/lupantech/PromptPG)
10 | 
11 | ## Benchmark
12 | 
13 | - We have pre-collected GPT-3 responses on a reduced dataset with 80 training samples and 16 candidates, so users do not need to query GPT-3 with their own OpenAI API key.
14 | - You can directly reproduce the benchmark by running ``python dizoo/tabmwp/configs/tabmwp_pg_config.py``
15 | 
16 | ![origin](./benchmark.png)
17 | 


--------------------------------------------------------------------------------
/dizoo/tabmwp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/tabmwp/__init__.py


--------------------------------------------------------------------------------
/dizoo/tabmwp/benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/tabmwp/benchmark.png


--------------------------------------------------------------------------------
/dizoo/tabmwp/envs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/tabmwp/envs/__init__.py


--------------------------------------------------------------------------------
/dizoo/tabmwp/envs/test_tabmwp_env.py:
--------------------------------------------------------------------------------
 1 | from easydict import EasyDict
 2 | import pytest
 3 | from dizoo.tabmwp.envs.tabmwp_env import TabMWP
 4 | 
 5 | 
 6 | @pytest.mark.envtest
 7 | class TestSokoban:
 8 | 
 9 |     def test_tabmwp(self):
10 |         config = dict(
11 |             cand_number=20,
12 |             train_number=100,
13 |             engine='text-davinci-002',
14 |             temperature=0.,
15 |             max_tokens=512,
16 |             top_p=1.,
17 |             frequency_penalty=0.,
18 |             presence_penalty=0.,
19 |             option_inds=["A", "B", "C", "D", "E", "F"],
20 |             api_key='',
21 |         )
22 |         config = EasyDict(config)
23 |         env = TabMWP(config)
24 |         env.seed(0)
25 |         env.close()
26 | 


--------------------------------------------------------------------------------
/dizoo/tabmwp/tabmwp.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/tabmwp/tabmwp.jpeg


--------------------------------------------------------------------------------
/dizoo/taxi/Taxi-v3_episode_0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/DI-engine/f6ee768d135b189ca1a4736d4179db36dae019ad/dizoo/taxi/Taxi-v3_episode_0.gif


--------------------------------------------------------------------------------
/dizoo/taxi/__init__.py:
--------------------------------------------------------------------------------
1 | from .envs import *
2 | 


--------------------------------------------------------------------------------
/dizoo/taxi/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .taxi_dqn_config import main_config, create_config
2 | 


--------------------------------------------------------------------------------
/dizoo/taxi/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .taxi_env import TaxiEnv
2 | 


--------------------------------------------------------------------------------
/docker/Dockerfile.rpc:
--------------------------------------------------------------------------------
 1 | FROM snsao/pytorch:tensorpipe-fix AS base
 2 | 
 3 | WORKDIR /ding
 4 | 
 5 | # Install system dependencies, then enable and generate the en_US.UTF-8 locale.
 6 | # apt-get is used because apt's CLI is not meant for scripts; package lists are removed to shrink the layer.
 7 | RUN apt-get update \
 8 |     && apt-get install -y libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git vim gcc g++ make wget locales dnsutils \
 9 |     && apt-get clean \
10 |     && rm -rf /var/cache/apt/* /var/lib/apt/lists/* \
11 |     && sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen \
12 |     && locale-gen
13 | 
14 | ENV LANG=en_US.UTF-8
15 | ENV LANGUAGE=en_US:UTF-8
16 | ENV LC_ALL=en_US.UTF-8
17 | 
18 | COPY setup.py setup.py
19 | COPY dizoo dizoo
20 | COPY ding ding
21 | COPY README.md README.md
22 | 
23 | # Quote the extras spec so /bin/sh never treats '.[fast,test]' as a glob pattern.
24 | RUN python3 -m pip install --upgrade pip \
25 |     && python3 -m pip install --ignore-installed 'PyYAML<6.0' \
26 |     && python3 -m pip install --no-cache-dir '.[fast,test]'
27 | 


--------------------------------------------------------------------------------
/format.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | # Format Python sources in place with yapf using the repository style file (.style.yapf).
 3 | # Usage (from the repository root): bash format.sh [TARGET_DIR [--test]]
 4 | #   TARGET_DIR defaults to "."; --test (CI only) fails if the working tree has unstaged changes after formatting.
 5 | 
 6 | YAPF_VERSION="0.29.0"
 7 | YAPF_DOWNLOAD_COMMAND_MSG="${YAPF_DOWNLOAD_COMMAND_MSG:-Please run: pip install yapf==${YAPF_VERSION}}"
 8 | 
 9 | if ! command -v yapf >/dev/null 2>&1; then
10 |   echo "yapf is not installed. $YAPF_DOWNLOAD_COMMAND_MSG"
11 |   exit 1
12 | fi
13 | 
14 | # Check yapf version. (20200318 latest is 0.29.0. Format might be changed in future version.)
15 | ver=$(yapf --version)
16 | if ! echo "$ver" | grep -qF "$YAPF_VERSION"; then
17 |   echo "Wrong YAPF version installed: $YAPF_VERSION is required, not $ver. $YAPF_DOWNLOAD_COMMAND_MSG"
18 |   exit 1
19 | fi
20 | 
21 | yapf --in-place --recursive -p --verbose --style .style.yapf "${1:-.}"
22 | 
23 | if [[ "$2" == '--test' ]]; then # Only for CI usage, user should not use --test flag.
24 |   if ! git diff --quiet &>/dev/null; then
25 |     echo '*** You have not reformatted your codes! Please run [bash format.sh] at root directory before commit! Thanks! ***'
26 |     exit 1
27 |   else
28 |     echo "Code style test passed!"
29 |   fi
30 | fi
31 | 


--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
 1 | [pytest]
 2 | execution_timeout = 600
 3 | markers =
 4 |     unittest
 5 |     platformtest
 6 |     envtest
 7 |     cudatest
 8 |     algotest
 9 |     benchmark
10 |     envpooltest
11 |     other
12 |     tmp
13 | 
14 | norecursedirs = ding/hpc_rl/tests
15 | 


--------------------------------------------------------------------------------