├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── PRINCIPLES.md ├── README.md ├── STYLE_GUIDE.md ├── broken_tests.txt ├── broken_tests_gpu.txt ├── docs ├── _book.yaml ├── _index.yaml ├── images │ ├── actor_learner_distributed_architecture.png │ ├── cql_sac_readme │ │ ├── antmaze-large-diverse-v0_graph.png │ │ ├── antmaze-large-play-v0_graph.png │ │ ├── antmaze-medium-diverse-v0_graph.png │ │ ├── antmaze-medium-play-v0_graph.png │ │ ├── halfcheetah-medium-expert-v0_graph.png │ │ ├── halfcheetah-medium-v0_graph.png │ │ ├── hopper-medium-expert-v0_graph.png │ │ ├── hopper-medium-v0_graph.png │ │ ├── walker2d-medium-expert-v0_graph.png │ │ └── walker2d-medium-v0_graph.png │ ├── dqn_readme │ │ └── pong-v0_graph.png │ ├── learner_detail.png │ ├── ppo_readme │ │ ├── halfcheetah-v2_graph.png │ │ ├── hopper-v2_graph.png │ │ ├── inverteddoublependulum-v2_graph.png │ │ ├── invertedpendulum-v2_graph.png │ │ ├── reacher-v2_graph.png │ │ ├── swimmer-v2_graph.png │ │ └── walker2d-v2_graph.png │ ├── rlds │ │ ├── flatten_rlds.png │ │ ├── pairs_to_trajectories.png │ │ ├── rlds_step_to_trajectory.png │ │ └── rlds_to_pairs.png │ └── sac_readme │ │ ├── ant-v2_graph.png │ │ ├── halfcheetah-v2_graph.png │ │ ├── hopper-v2_graph.png │ │ ├── humanoid-v2_graph.png │ │ └── walker2d-v2_graph.png ├── overview.md ├── rlds_to_reverb.md └── tutorials │ ├── 0_intro_rl.ipynb │ ├── 10_checkpointer_policysaver_tutorial.ipynb │ ├── 1_dqn_tutorial.ipynb │ ├── 2_environments_tutorial.ipynb │ ├── 3_policies_tutorial.ipynb │ ├── 4_drivers_tutorial.ipynb │ ├── 5_replay_buffers_tutorial.ipynb │ ├── 6_reinforce_tutorial.ipynb │ ├── 7_SAC_minitaur_tutorial.ipynb │ ├── 8_networks_tutorial.ipynb │ ├── 9_c51_tutorial.ipynb │ ├── bandits_tutorial.ipynb │ ├── colab_kernel_init.py │ ├── images │ ├── c51_distribution.png │ ├── cartpole.png │ └── rl_overview.png │ ├── intro_bandit.ipynb │ ├── per_arm_bandits_tutorial.ipynb │ └── ranking_tutorial.ipynb ├── pip_pkg.sh ├── setup.py 
├── test_individually.txt ├── tests_release.sh ├── tf_agents ├── AUTHORS ├── __init__.py ├── agents │ ├── __init__.py │ ├── behavioral_cloning │ │ ├── __init__.py │ │ ├── behavioral_cloning_agent.py │ │ └── behavioral_cloning_agent_test.py │ ├── categorical_dqn │ │ ├── __init__.py │ │ ├── categorical_dqn_agent.py │ │ ├── categorical_dqn_agent_test.py │ │ └── examples │ │ │ └── train_eval_atari.py │ ├── cql │ │ ├── __init__.py │ │ ├── cql_sac_agent.py │ │ └── cql_sac_agent_test.py │ ├── data_converter.py │ ├── data_converter_test.py │ ├── ddpg │ │ ├── __init__.py │ │ ├── actor_network.py │ │ ├── actor_network_test.py │ │ ├── actor_rnn_network.py │ │ ├── actor_rnn_network_test.py │ │ ├── critic_network.py │ │ ├── critic_network_test.py │ │ ├── critic_rnn_network.py │ │ ├── critic_rnn_network_test.py │ │ ├── ddpg_agent.py │ │ ├── ddpg_agent_test.py │ │ └── examples │ │ │ ├── __init__.py │ │ │ └── v2 │ │ │ ├── __init__.py │ │ │ ├── train_eval.py │ │ │ └── train_eval_rnn.py │ ├── dqn │ │ ├── __init__.py │ │ ├── dqn_agent.py │ │ ├── dqn_agent_test.py │ │ └── examples │ │ │ ├── __init__.py │ │ │ └── v2 │ │ │ ├── __init__.py │ │ │ ├── train_eval.py │ │ │ └── train_eval_test.py │ ├── ppo │ │ ├── __init__.py │ │ ├── examples │ │ │ ├── __init__.py │ │ │ └── v2 │ │ │ │ ├── __init__.py │ │ │ │ └── train_eval_clip_agent.py │ │ ├── ppo_actor_network.py │ │ ├── ppo_actor_network_test.py │ │ ├── ppo_agent.py │ │ ├── ppo_agent_test.py │ │ ├── ppo_clip_agent.py │ │ ├── ppo_kl_penalty_agent.py │ │ ├── ppo_policy.py │ │ ├── ppo_policy_test.py │ │ ├── ppo_utils.py │ │ └── ppo_utils_test.py │ ├── qtopt │ │ ├── qtopt_agent.py │ │ └── qtopt_agent_test.py │ ├── random │ │ ├── __init__.py │ │ ├── fixed_policy_agent.py │ │ ├── random_agent.py │ │ └── random_agent_test.py │ ├── reinforce │ │ ├── __init__.py │ │ ├── examples │ │ │ ├── __init__.py │ │ │ └── v2 │ │ │ │ ├── __init__.py │ │ │ │ └── train_eval.py │ │ ├── reinforce_agent.py │ │ └── reinforce_agent_test.py │ ├── sac │ │ ├── 
__init__.py │ │ ├── examples │ │ │ ├── __init__.py │ │ │ └── v2 │ │ │ │ ├── __init__.py │ │ │ │ ├── train_eval.py │ │ │ │ └── train_eval_rnn.py │ │ ├── sac_agent.py │ │ ├── sac_agent_test.py │ │ ├── tanh_normal_projection_network.py │ │ └── tanh_normal_projection_network_test.py │ ├── td3 │ │ ├── __init__.py │ │ ├── examples │ │ │ ├── __init__.py │ │ │ └── v2 │ │ │ │ ├── train_eval.py │ │ │ │ └── train_eval_rnn.py │ │ ├── td3_agent.py │ │ └── td3_agent_test.py │ ├── test_util.py │ ├── tf_agent.py │ └── tf_agent_test.py ├── bandits │ ├── README.md │ ├── __init__.py │ ├── agents │ │ ├── __init__.py │ │ ├── bernoulli_thompson_sampling_agent.py │ │ ├── bernoulli_thompson_sampling_agent_test.py │ │ ├── dropout_thompson_sampling_agent.py │ │ ├── dropout_thompson_sampling_agent_test.py │ │ ├── examples │ │ │ ├── __init__.py │ │ │ └── v2 │ │ │ │ ├── __init__.py │ │ │ │ ├── train_eval_bernoulli.py │ │ │ │ ├── train_eval_covertype.py │ │ │ │ ├── train_eval_dqn.py │ │ │ │ ├── train_eval_drifting_linear.py │ │ │ │ ├── train_eval_movielens.py │ │ │ │ ├── train_eval_mushroom.py │ │ │ │ ├── train_eval_per_arm_stationary_linear.py │ │ │ │ ├── train_eval_piecewise_linear.py │ │ │ │ ├── train_eval_ranking.py │ │ │ │ ├── train_eval_sparse_features.py │ │ │ │ ├── train_eval_stationary_linear.py │ │ │ │ ├── train_eval_structured_linear.py │ │ │ │ ├── train_eval_wheel.py │ │ │ │ ├── trainer.py │ │ │ │ ├── trainer_test.py │ │ │ │ └── trainer_test_utils.py │ │ ├── exp3_agent.py │ │ ├── exp3_agent_test.py │ │ ├── exp3_mixture_agent.py │ │ ├── exp3_mixture_agent_test.py │ │ ├── greedy_multi_objective_neural_agent.py │ │ ├── greedy_multi_objective_neural_agent_test.py │ │ ├── greedy_reward_prediction_agent.py │ │ ├── greedy_reward_prediction_agent_test.py │ │ ├── lin_ucb_agent.py │ │ ├── linear_bandit_agent.py │ │ ├── linear_bandit_agent_test.py │ │ ├── linear_thompson_sampling_agent.py │ │ ├── mixture_agent.py │ │ ├── mixture_agent_test.py │ │ ├── neural_boltzmann_agent.py │ │ ├── 
neural_boltzmann_agent_test.py │ │ ├── neural_epsilon_greedy_agent.py │ │ ├── neural_epsilon_greedy_agent_test.py │ │ ├── neural_falcon_agent.py │ │ ├── neural_falcon_agent_test.py │ │ ├── neural_linucb_agent.py │ │ ├── neural_linucb_agent_test.py │ │ ├── ranking_agent.py │ │ ├── ranking_agent_test.py │ │ ├── static_mixture_agent.py │ │ ├── utils.py │ │ └── utils_test.py │ ├── drivers │ │ ├── __init__.py │ │ └── driver_utils.py │ ├── environments │ │ ├── __init__.py │ │ ├── bandit_py_environment.py │ │ ├── bandit_tf_environment.py │ │ ├── bandit_tf_environment_test.py │ │ ├── bernoulli_action_mask_tf_environment.py │ │ ├── bernoulli_action_mask_tf_environment_test.py │ │ ├── bernoulli_py_environment.py │ │ ├── bernoulli_py_environment_test.py │ │ ├── classification_environment.py │ │ ├── classification_environment_test.py │ │ ├── dataset_utilities.py │ │ ├── dataset_utilities_test.py │ │ ├── drifting_linear_environment.py │ │ ├── drifting_linear_environment_test.py │ │ ├── environment_utilities.py │ │ ├── movielens_per_arm_py_environment.py │ │ ├── movielens_py_environment.py │ │ ├── non_stationary_stochastic_environment.py │ │ ├── non_stationary_stochastic_environment_test.py │ │ ├── piecewise_bernoulli_py_environment.py │ │ ├── piecewise_bernoulli_py_environment_test.py │ │ ├── piecewise_stochastic_environment.py │ │ ├── piecewise_stochastic_environment_test.py │ │ ├── random_bandit_environment.py │ │ ├── random_bandit_environment_test.py │ │ ├── ranking_environment.py │ │ ├── ranking_environment_test.py │ │ ├── stationary_stochastic_per_arm_py_environment.py │ │ ├── stationary_stochastic_per_arm_py_environment_test.py │ │ ├── stationary_stochastic_py_environment.py │ │ ├── stationary_stochastic_py_environment_test.py │ │ ├── stationary_stochastic_structured_py_environment.py │ │ ├── stationary_stochastic_structured_py_environment_test.py │ │ ├── wheel_py_environment.py │ │ └── wheel_py_environment_test.py │ ├── metrics │ │ ├── __init__.py │ │ ├── tf_metrics.py │ 
│ └── tf_metrics_test.py │ ├── multi_objective │ │ ├── __init__.py │ │ ├── multi_objective_scalarizer.py │ │ └── multi_objective_scalarizer_test.py │ ├── networks │ │ ├── __init__.py │ │ ├── global_and_arm_feature_network.py │ │ ├── global_and_arm_feature_network_test.py │ │ ├── heteroscedastic_q_network.py │ │ └── heteroscedastic_q_network_test.py │ ├── policies │ │ ├── __init__.py │ │ ├── bernoulli_thompson_sampling_policy.py │ │ ├── bernoulli_thompson_sampling_policy_test.py │ │ ├── boltzmann_reward_prediction_policy.py │ │ ├── boltzmann_reward_prediction_policy_test.py │ │ ├── categorical_policy.py │ │ ├── categorical_policy_test.py │ │ ├── constraints.py │ │ ├── constraints_test.py │ │ ├── falcon_reward_prediction_policy.py │ │ ├── falcon_reward_prediction_policy_test.py │ │ ├── greedy_multi_objective_neural_policy.py │ │ ├── greedy_multi_objective_neural_policy_test.py │ │ ├── greedy_reward_prediction_policy.py │ │ ├── greedy_reward_prediction_policy_test.py │ │ ├── lin_ucb_policy.py │ │ ├── linalg.py │ │ ├── linalg_test.py │ │ ├── linear_bandit_policy.py │ │ ├── linear_bandit_policy_test.py │ │ ├── linear_thompson_sampling_policy.py │ │ ├── loss_utils.py │ │ ├── loss_utils_test.py │ │ ├── mixture_policy.py │ │ ├── mixture_policy_test.py │ │ ├── neural_linucb_policy.py │ │ ├── neural_linucb_policy_test.py │ │ ├── policy_utilities_test.py │ │ ├── ranking_policy.py │ │ ├── ranking_policy_test.py │ │ ├── reward_prediction_base_policy.py │ │ └── reward_prediction_policies_test.py │ ├── replay_buffers │ │ ├── __init__.py │ │ └── bandit_replay_buffer.py │ └── specs │ │ ├── __init__.py │ │ └── utils.py ├── benchmark │ ├── __init__.py │ ├── cql_sac_benchmark.py │ ├── distribution_strategy_utils.py │ ├── dqn_benchmark.py │ ├── dqn_benchmark_test.py │ ├── perfzero_benchmark.py │ ├── perfzero_benchmark_test.py │ ├── ppo_benchmark.py │ ├── sac_benchmark.py │ ├── test_data │ │ ├── event_log_3m │ │ │ └── events.out.tfevents.1599310762 │ │ └── event_log_too_many │ │ │ ├── 
events.out.tfevents.1599310762 │ │ │ └── events.out.tfevents.1599379945 │ ├── utils.py │ └── utils_test.py ├── distributions │ ├── __init__.py │ ├── gumbel_softmax.py │ ├── gumbel_softmax_test.py │ ├── masked.py │ ├── masked_test.py │ ├── reparameterized_sampling.py │ ├── shifted_categorical.py │ ├── shifted_categorical_test.py │ ├── tanh_bijector_stable.py │ ├── utils.py │ └── utils_test.py ├── drivers │ ├── __init__.py │ ├── driver.py │ ├── dynamic_episode_driver.py │ ├── dynamic_episode_driver_test.py │ ├── dynamic_step_driver.py │ ├── dynamic_step_driver_test.py │ ├── py_driver.py │ ├── py_driver_test.py │ ├── test_utils.py │ ├── test_utils_test.py │ ├── tf_driver.py │ └── tf_driver_test.py ├── environments │ ├── __init__.py │ ├── atari_preprocessing.py │ ├── atari_preprocessing_test.py │ ├── atari_wrappers.py │ ├── atari_wrappers_test.py │ ├── batched_py_environment.py │ ├── batched_py_environment_test.py │ ├── configs │ │ ├── suite_bsuite.gin │ │ ├── suite_gym.gin │ │ ├── suite_gymnasium.gin │ │ ├── suite_mujoco.gin │ │ └── suite_pybullet.gin │ ├── dm_control_wrapper.py │ ├── dm_control_wrapper_test.py │ ├── examples │ │ ├── __init__.py │ │ ├── masked_cartpole.py │ │ ├── tic_tac_toe_environment.py │ │ └── tic_tac_toe_environment_test.py │ ├── gym_wrapper.py │ ├── gym_wrapper_test.py │ ├── gymnasium_wrapper.py │ ├── gymnasium_wrapper_test.py │ ├── parallel_py_environment.py │ ├── parallel_py_environment_test.py │ ├── py_environment.py │ ├── py_environment_test.py │ ├── py_to_dm_wrapper.py │ ├── random_py_environment.py │ ├── random_py_environment_test.py │ ├── random_tf_environment.py │ ├── random_tf_environment_test.py │ ├── suite_atari.py │ ├── suite_atari_test.py │ ├── suite_bsuite.py │ ├── suite_bsuite_test.py │ ├── suite_dm_control.py │ ├── suite_dm_control_test.py │ ├── suite_gym.py │ ├── suite_gym_test.py │ ├── suite_gymnasium.py │ ├── suite_gymnasium_test.py │ ├── suite_mujoco.py │ ├── suite_mujoco_test.py │ ├── suite_pybullet.py │ ├── 
suite_pybullet_test.py │ ├── test_envs.py │ ├── test_envs_test.py │ ├── tf_environment.py │ ├── tf_environment_test.py │ ├── tf_py_environment.py │ ├── tf_py_environment_test.py │ ├── tf_wrappers.py │ ├── tf_wrappers_test.py │ ├── trajectory_replay.py │ ├── trajectory_replay_test.py │ ├── utils.py │ ├── utils_test.py │ ├── wrappers.py │ └── wrappers_test.py ├── eval │ ├── __init__.py │ ├── metric_utils.py │ └── metric_utils_test.py ├── examples │ ├── __init__.py │ ├── cql_sac │ │ ├── README.md │ │ ├── __init__.py │ │ └── kumar20 │ │ │ ├── __init__.py │ │ │ ├── configs │ │ │ ├── __init__.py │ │ │ ├── antmaze.gin │ │ │ ├── mujoco.gin │ │ │ ├── mujoco_medium.gin │ │ │ └── mujoco_medium_expert.gin │ │ │ ├── cql_sac_train_eval.py │ │ │ ├── cql_sac_train_eval_test.py │ │ │ ├── d4rl_utils.py │ │ │ ├── data_utils.py │ │ │ └── dataset │ │ │ ├── __init__.py │ │ │ ├── dataset_generator.py │ │ │ ├── dataset_utils.py │ │ │ ├── dataset_utils_test.py │ │ │ ├── file_utils.py │ │ │ ├── file_utils_test.py │ │ │ └── test_data │ │ │ ├── antmaze-medium-play-v0_0.tfrecord │ │ │ └── antmaze-medium-play-v0_0.tfrecord.spec │ ├── dqn │ │ ├── README.md │ │ ├── __init__.py │ │ ├── dqn_train_eval.py │ │ ├── dqn_train_eval_rnn.py │ │ ├── gymnasium │ │ │ └── d3qn_train_eval.py │ │ └── mnih15 │ │ │ ├── __init__.py │ │ │ ├── configs │ │ │ ├── breakout.gin │ │ │ └── pong.gin │ │ │ └── dqn_train_eval_atari.py │ ├── ppo │ │ ├── README.md │ │ ├── __init__.py │ │ └── schulman17 │ │ │ ├── __init__.py │ │ │ ├── configs │ │ │ ├── half_cheetah.gin │ │ │ ├── hopper.gin │ │ │ ├── inverted_double_pendulum.gin │ │ │ ├── inverted_pendulum.gin │ │ │ ├── reacher.gin │ │ │ ├── swimmer.gin │ │ │ └── walker_2d.gin │ │ │ ├── ppo_clip_train_eval.py │ │ │ └── train_eval_lib.py │ └── sac │ │ ├── README.md │ │ ├── __init__.py │ │ └── haarnoja18 │ │ ├── __init__.py │ │ ├── configs │ │ ├── ant.gin │ │ ├── half_cheetah.gin │ │ ├── hopper.gin │ │ ├── humanoid.gin │ │ └── walker_2d.gin │ │ └── sac_train_eval.py ├── 
experimental │ ├── __init__.py │ ├── distributed │ │ ├── README.md │ │ ├── __init__.py │ │ ├── examples │ │ │ ├── eval_job.py │ │ │ ├── eval_job_test.py │ │ │ └── sac │ │ │ │ ├── README.md │ │ │ │ ├── sac_collect.py │ │ │ │ ├── sac_reverb_server.py │ │ │ │ └── sac_train.py │ │ ├── reverb_variable_container.py │ │ └── reverb_variable_container_test.py │ └── examples │ │ ├── __init__.py │ │ └── ppo │ │ ├── train_eval_lib.py │ │ └── train_eval_lib_test.py ├── keras_layers │ ├── __init__.py │ ├── bias_layer.py │ ├── bias_layer_test.py │ ├── dynamic_unroll_layer.py │ ├── dynamic_unroll_layer_test.py │ ├── inner_reshape.py │ ├── inner_reshape_test.py │ ├── permanent_variable_rate_dropout.py │ ├── permanent_variable_rate_dropout_test.py │ ├── rnn_wrapper.py │ ├── rnn_wrapper_test.py │ ├── squashed_outer_wrapper.py │ └── squashed_outer_wrapper_test.py ├── metrics │ ├── __init__.py │ ├── batched_py_metric.py │ ├── batched_py_metric_test.py │ ├── export_utils.py │ ├── metric_equality_test.py │ ├── py_metric.py │ ├── py_metric_test.py │ ├── py_metrics.py │ ├── py_metrics_test.py │ ├── tf_metric.py │ ├── tf_metrics.py │ ├── tf_metrics_test.py │ ├── tf_py_metric.py │ └── tf_py_metric_test.py ├── networks │ ├── __init__.py │ ├── actor_distribution_network.py │ ├── actor_distribution_network_test.py │ ├── actor_distribution_rnn_network.py │ ├── actor_distribution_rnn_network_test.py │ ├── categorical_projection_network.py │ ├── categorical_projection_network_test.py │ ├── categorical_q_network.py │ ├── categorical_q_network_test.py │ ├── dueling_q_network.py │ ├── encoding_network.py │ ├── encoding_network_test.py │ ├── expand_dims_layer.py │ ├── layer_utils.py │ ├── lstm_encoding_network.py │ ├── mask_splitter_network.py │ ├── mask_splitter_network_test.py │ ├── nest_map.py │ ├── nest_map_test.py │ ├── network.py │ ├── network_test.py │ ├── normal_projection_network.py │ ├── normal_projection_network_test.py │ ├── q_network.py │ ├── q_network_test.py │ ├── q_rnn_network.py │ ├── 
q_rnn_network_test.py │ ├── sequential.py │ ├── sequential_test.py │ ├── test_utils.py │ ├── utils.py │ ├── utils_test.py │ ├── value_network.py │ ├── value_network_test.py │ ├── value_rnn_network.py │ └── value_rnn_network_test.py ├── policies │ ├── __init__.py │ ├── actor_policy.py │ ├── actor_policy_test.py │ ├── async_policy_saver.py │ ├── async_policy_saver_test.py │ ├── batched_py_policy.py │ ├── batched_py_policy_test.py │ ├── boltzmann_policy.py │ ├── boltzmann_policy_test.py │ ├── categorical_q_policy.py │ ├── categorical_q_policy_test.py │ ├── epsilon_greedy_policy.py │ ├── epsilon_greedy_policy_test.py │ ├── fixed_policy.py │ ├── fixed_policy_test.py │ ├── gaussian_policy.py │ ├── gaussian_policy_test.py │ ├── greedy_policy.py │ ├── greedy_policy_test.py │ ├── ou_noise_policy.py │ ├── ou_noise_policy_test.py │ ├── policy_info_updater_wrapper.py │ ├── policy_info_updater_wrapper_test.py │ ├── policy_loader.py │ ├── policy_loader_test.py │ ├── policy_saver.py │ ├── policy_saver_test.py │ ├── py_epsilon_greedy_policy.py │ ├── py_epsilon_greedy_policy_test.py │ ├── py_policy.py │ ├── py_tf_eager_policy.py │ ├── py_tf_eager_policy_test.py │ ├── py_tf_policy.py │ ├── py_tf_policy_test.py │ ├── q_policy.py │ ├── q_policy_test.py │ ├── qtopt_cem_policy.py │ ├── qtopt_cem_policy_test.py │ ├── random_py_policy.py │ ├── random_py_policy_test.py │ ├── random_tf_policy.py │ ├── random_tf_policy_test.py │ ├── samplers │ │ ├── qtopt_cem_actions_sampler.py │ │ ├── qtopt_cem_actions_sampler_continuous.py │ │ ├── qtopt_cem_actions_sampler_continuous_and_one_hot.py │ │ ├── qtopt_cem_actions_sampler_continuous_and_one_hot_test.py │ │ ├── qtopt_cem_actions_sampler_continuous_test.py │ │ ├── qtopt_cem_actions_sampler_hybrid.py │ │ └── qtopt_cem_actions_sampler_hybrid_test.py │ ├── scripted_py_policy.py │ ├── scripted_py_policy_test.py │ ├── temporal_action_smoothing.py │ ├── temporal_action_smoothing_test.py │ ├── tf_policy.py │ ├── tf_policy_test.py │ ├── tf_py_policy.py │ 
├── tf_py_policy_test.py │ └── utils.py ├── replay_buffers │ ├── __init__.py │ ├── episodic_replay_buffer.py │ ├── episodic_replay_buffer_driver_test.py │ ├── episodic_replay_buffer_test.py │ ├── episodic_table.py │ ├── episodic_table_test.py │ ├── py_hashed_replay_buffer.py │ ├── py_replay_buffers_test.py │ ├── py_uniform_replay_buffer.py │ ├── replay_buffer.py │ ├── replay_buffer_test.py │ ├── reverb_replay_buffer.py │ ├── reverb_replay_buffer_test.py │ ├── reverb_utils.py │ ├── reverb_utils_test.py │ ├── rlds_to_reverb.py │ ├── rlds_to_reverb_test.py │ ├── table.py │ ├── table_test.py │ ├── tf_uniform_replay_buffer.py │ └── tf_uniform_replay_buffer_test.py ├── specs │ ├── __init__.py │ ├── array_spec.py │ ├── array_spec_test.py │ ├── bandit_spec_utils.py │ ├── distribution_spec.py │ ├── distribution_spec_test.py │ ├── specs_test.py │ ├── tensor_spec.py │ └── tensor_spec_test.py ├── system │ ├── __init__.py │ ├── default │ │ ├── __init__.py │ │ └── multiprocessing_core.py │ ├── multiprocessing_test.py │ └── system_multiprocessing.py ├── tf_agents_api_test.py ├── train │ ├── README.md │ ├── __init__.py │ ├── actor.py │ ├── actor_test.py │ ├── interval_trigger.py │ ├── learner.py │ ├── learner_test.py │ ├── ppo_learner.py │ ├── ppo_learner_test.py │ ├── ppo_learner_test_utils.py │ ├── step_per_second_tracker.py │ ├── tpu_ppo_learner_test.py │ ├── triggers.py │ └── utils │ │ ├── __init__.py │ │ ├── replay_buffer_utils.py │ │ ├── replay_buffer_utils_test.py │ │ ├── spec_utils.py │ │ ├── spec_utils_test.py │ │ ├── strategy_utils.py │ │ ├── strategy_utils_test.py │ │ ├── test_utils.py │ │ ├── train_utils.py │ │ └── train_utils_test.py ├── trajectories │ ├── __init__.py │ ├── policy_step.py │ ├── policy_step_test.py │ ├── test_utils.py │ ├── time_step.py │ ├── time_step_test.py │ ├── trajectory.py │ └── trajectory_test.py ├── typing │ ├── __init__.py │ └── types.py ├── utils │ ├── __init__.py │ ├── batched_observer_unbatching.py │ ├── batched_observer_unbatching_test.py 
│ ├── common.py │ ├── common_members_not_overridden_test.py │ ├── common_test.py │ ├── composite.py │ ├── composite_test.py │ ├── eager_utils.py │ ├── eager_utils_test.py │ ├── example_encoding.py │ ├── example_encoding_dataset.py │ ├── example_encoding_dataset_test.py │ ├── example_encoding_test.py │ ├── lazy_loader.py │ ├── nest_utils.py │ ├── nest_utils_test.py │ ├── numpy_storage.py │ ├── numpy_storage_test.py │ ├── object_identity.py │ ├── object_identity_test.py │ ├── session_utils.py │ ├── session_utils_test.py │ ├── tensor_normalizer.py │ ├── tensor_normalizer_test.py │ ├── test_utils.py │ ├── test_utils_test.py │ ├── timer.py │ ├── value_ops.py │ ├── value_ops_test.py │ ├── xla.py │ └── xla_test.py └── version.py └── tools ├── build_docs.py ├── docker ├── README.md ├── ubuntu_atari ├── ubuntu_ci ├── ubuntu_d4rl ├── ubuntu_mujoco ├── ubuntu_mujoco_oss └── ubuntu_tf_agents ├── graph_builder.py ├── graph_builder_test.py ├── release_builder.py ├── test_colabs.py └── test_data ├── event_log_ant_eval00 └── events.out.tfevents.1599310762 ├── event_log_ant_eval01 └── events.out.tfevents.1599379945 └── event_log_ant_eval02 └── events.out.tfevents.1599448596 /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .ipynb_checkpoints 3 | node_modules 4 | /pip_test 5 | /_python_build 6 | *.pyc 7 | __pycache__ 8 | *.swp 9 | .vscode/ 10 | .idea/ 11 | .eggs/ 12 | *.egg-info/ 13 | /.vs 14 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Interested in contributing to TF Agents? We appreciate all kinds 4 | of help! 5 | 6 | ## Pull Requests 7 | 8 | We gladly welcome [pull requests]( 9 | https://help.github.com/articles/about-pull-requests/). 
10 | 11 | Before making any changes, we recommend opening an issue (if it 12 | doesn't already exist) and discussing your proposed changes. This will 13 | let us give you advice on the proposed changes. If the changes are 14 | minor, then feel free to make them without discussion. 15 | 16 | Want to contribute but not sure what to work on? Here are a few suggestions: 17 | 18 | 1. Add a new example, colab, or tutorial. These are a great way to familiarize 19 | yourself and others with TF Agents. 20 | 21 | 22 | 2. Solve an [existing issue](https://github.com/tensorflow/agents/issues). 23 | These range from low-level software bugs to higher-level design problems. 24 | Check out the label [good first issue]( 25 | https://github.com/tensorflow/agents/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22). 26 | 27 | All submissions, including submissions by project members, require review. After 28 | a pull request is approved, we merge it. Note our merging process differs 29 | from GitHub in that we pull and submit the change into an internal version 30 | control system. This system automatically pushes a git commit to the GitHub 31 | repository (with credit to the original author) and closes the pull request. 32 | 33 | ## Style 34 | 35 | See the [TF Agents style guide](STYLE_GUIDE.md). 36 | 37 | ## Unit tests 38 | 39 | All TF Agents code-paths must be unit-tested. See existing unit tests for 40 | recommended test setup. 41 | 42 | Unit tests (a) ensure new features work correctly and (b) guard against future 43 | breaking changes (thus lowering maintenance costs). 44 | 45 | To run existing unit-tests, use the command: 46 | 47 | 48 | ```shell 49 | python setup.py test 50 | ``` 51 | 52 | from the root of the `tf_agents` repository, ideally inside a virtualenv. 53 | The tests will run with CPU or GPU, depending on which version of TensorFlow 54 | you have installed. 
55 | 56 | 57 | ## Contributor License Agreement 58 | 59 | Contributions to this project must be accompanied by a Contributor License 60 | Agreement. You (or your employer) retain the copyright to your contribution; 61 | this simply gives us permission to use and redistribute your contributions as 62 | part of the project. Head over to <https://cla.developers.google.com/> to see 63 | your current agreements on file or to sign a new one. 64 | 65 | You generally only need to submit a CLA once, so if you've already submitted one 66 | (even if it was for a different project), you probably don't need to do it 67 | again. 68 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include tf_agents *.gin 2 | -------------------------------------------------------------------------------- /broken_tests.txt: -------------------------------------------------------------------------------- 1 | agents.cql.cql_sac_agent_test # b/239448722 2 | environments.suite_atari_test # b/201537188 3 | examples.cql_sac.kumar20.dataset.file_utils_test # b/199583100 4 | examples.cql_sac.kumar20.dataset.dataset_utils_test # b/199583100 5 | examples.cql_sac.kumar20.cql_sac_train_eval_test # b/199583100 6 | policies.policy_saver_test # b/315708156 7 | -------------------------------------------------------------------------------- /broken_tests_gpu.txt: -------------------------------------------------------------------------------- 1 | agents.cql.cql_sac_agent_test # b/239448722 2 | environments.suite_atari_test # b/201537188 3 | examples.cql_sac.kumar20.dataset.file_utils_test # b/199583100 4 | examples.cql_sac.kumar20.dataset.dataset_utils_test # b/199583100 5 | examples.cql_sac.kumar20.cql_sac_train_eval_test # b/199583100 6 | replay_buffers.episodic_replay_buffer_driver_test # b/199587793 7 | replay_buffers.episodic_replay_buffer_test # b/199587793 8 | networks.q_rnn_network_test # b/237573967 
9 | networks.actor_distribution_rnn_network_test # b/237573967 10 | agents.ddpg.critic_rnn_network_test # b/237573967 11 | policies.policy_saver_test # b/315708156 12 | -------------------------------------------------------------------------------- /docs/_book.yaml: -------------------------------------------------------------------------------- 1 | upper_tabs: 2 | # Tabs left of dropdown menu 3 | - include: /_upper_tabs_left.yaml 4 | - include: /api_docs/_upper_tabs_api.yaml 5 | # Dropdown menu 6 | - name: Resources 7 | path: /resources 8 | is_default: true 9 | menu: 10 | - include: /resources/_menu_toc.yaml 11 | lower_tabs: 12 | # Subsite tabs 13 | other: 14 | - name: Guide & Tutorials 15 | contents: 16 | - title: Overview and install 17 | path: /agents/overview 18 | - title: Intro to RL 19 | path: /agents/tutorials/0_intro_rl 20 | - title: Intro to Multi-Armed Bandits 21 | path: /agents/tutorials/intro_bandit 22 | - title: Train a deep Q network 23 | path: /agents/tutorials/1_dqn_tutorial 24 | - heading: "Guide" 25 | - title: Environments 26 | path: /agents/tutorials/2_environments_tutorial 27 | - title: Policies 28 | path: /agents/tutorials/3_policies_tutorial 29 | - title: Drivers 30 | path: /agents/tutorials/4_drivers_tutorial 31 | - title: Replay buffers 32 | path: /agents/tutorials/5_replay_buffers_tutorial 33 | - title: Networks 34 | path: /agents/tutorials/8_networks_tutorial 35 | - title: Checkpointer and PolicySaver 36 | path: /agents/tutorials/10_checkpointer_policysaver_tutorial 37 | - heading: "Tutorials" 38 | - title: REINFORCE agent 39 | path: /agents/tutorials/6_reinforce_tutorial 40 | - title: Soft Actor Critic (SAC) in Minitaur 41 | path: /agents/tutorials/7_SAC_minitaur_tutorial 42 | - title: Categorical DQN 43 | path: /agents/tutorials/9_c51_tutorial 44 | - title: Multi-Armed Bandits Tutorial 45 | path: /agents/tutorials/bandits_tutorial 46 | - title: Ranking Tutorial 47 | path: /agents/tutorials/ranking_tutorial 48 | - title: Multi-Armed 
Bandits with Per-Arm 49 | path: agents/tutorials/per_arm_bandits_tutorial 50 | - name: API 51 | skip_translation: true 52 | contents: 53 | - include: /agents/api_docs/python/tf_agents/_toc.yaml 54 | 55 | - include: /_upper_tabs_right.yaml 56 | -------------------------------------------------------------------------------- /docs/images/actor_learner_distributed_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/actor_learner_distributed_architecture.png -------------------------------------------------------------------------------- /docs/images/cql_sac_readme/antmaze-large-diverse-v0_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/cql_sac_readme/antmaze-large-diverse-v0_graph.png -------------------------------------------------------------------------------- /docs/images/cql_sac_readme/antmaze-large-play-v0_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/cql_sac_readme/antmaze-large-play-v0_graph.png -------------------------------------------------------------------------------- /docs/images/cql_sac_readme/antmaze-medium-diverse-v0_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/cql_sac_readme/antmaze-medium-diverse-v0_graph.png -------------------------------------------------------------------------------- /docs/images/cql_sac_readme/antmaze-medium-play-v0_graph.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/cql_sac_readme/antmaze-medium-play-v0_graph.png -------------------------------------------------------------------------------- /docs/images/cql_sac_readme/halfcheetah-medium-expert-v0_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/cql_sac_readme/halfcheetah-medium-expert-v0_graph.png -------------------------------------------------------------------------------- /docs/images/cql_sac_readme/halfcheetah-medium-v0_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/cql_sac_readme/halfcheetah-medium-v0_graph.png -------------------------------------------------------------------------------- /docs/images/cql_sac_readme/hopper-medium-expert-v0_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/cql_sac_readme/hopper-medium-expert-v0_graph.png -------------------------------------------------------------------------------- /docs/images/cql_sac_readme/hopper-medium-v0_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/cql_sac_readme/hopper-medium-v0_graph.png -------------------------------------------------------------------------------- /docs/images/cql_sac_readme/walker2d-medium-expert-v0_graph.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/cql_sac_readme/walker2d-medium-expert-v0_graph.png -------------------------------------------------------------------------------- /docs/images/cql_sac_readme/walker2d-medium-v0_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/cql_sac_readme/walker2d-medium-v0_graph.png -------------------------------------------------------------------------------- /docs/images/dqn_readme/pong-v0_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/dqn_readme/pong-v0_graph.png -------------------------------------------------------------------------------- /docs/images/learner_detail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/learner_detail.png -------------------------------------------------------------------------------- /docs/images/ppo_readme/halfcheetah-v2_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/ppo_readme/halfcheetah-v2_graph.png -------------------------------------------------------------------------------- /docs/images/ppo_readme/hopper-v2_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/ppo_readme/hopper-v2_graph.png -------------------------------------------------------------------------------- 
/docs/images/ppo_readme/inverteddoublependulum-v2_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/ppo_readme/inverteddoublependulum-v2_graph.png -------------------------------------------------------------------------------- /docs/images/ppo_readme/invertedpendulum-v2_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/ppo_readme/invertedpendulum-v2_graph.png -------------------------------------------------------------------------------- /docs/images/ppo_readme/reacher-v2_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/ppo_readme/reacher-v2_graph.png -------------------------------------------------------------------------------- /docs/images/ppo_readme/swimmer-v2_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/ppo_readme/swimmer-v2_graph.png -------------------------------------------------------------------------------- /docs/images/ppo_readme/walker2d-v2_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/ppo_readme/walker2d-v2_graph.png -------------------------------------------------------------------------------- /docs/images/rlds/flatten_rlds.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/rlds/flatten_rlds.png -------------------------------------------------------------------------------- /docs/images/rlds/pairs_to_trajectories.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/rlds/pairs_to_trajectories.png -------------------------------------------------------------------------------- /docs/images/rlds/rlds_step_to_trajectory.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/rlds/rlds_step_to_trajectory.png -------------------------------------------------------------------------------- /docs/images/rlds/rlds_to_pairs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/rlds/rlds_to_pairs.png -------------------------------------------------------------------------------- /docs/images/sac_readme/ant-v2_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/sac_readme/ant-v2_graph.png -------------------------------------------------------------------------------- /docs/images/sac_readme/halfcheetah-v2_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/sac_readme/halfcheetah-v2_graph.png -------------------------------------------------------------------------------- /docs/images/sac_readme/hopper-v2_graph.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/sac_readme/hopper-v2_graph.png -------------------------------------------------------------------------------- /docs/images/sac_readme/humanoid-v2_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/sac_readme/humanoid-v2_graph.png -------------------------------------------------------------------------------- /docs/images/sac_readme/walker2d-v2_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/docs/images/sac_readme/walker2d-v2_graph.png -------------------------------------------------------------------------------- /docs/tutorials/colab_kernel_init.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
def WaitForFilePath(path_pattern, timeout_sec):
  """Polls until `path_pattern` matches at least one path or time runs out.

  The pattern is globbed before the timeout is checked, so a path that already
  exists is found even when `timeout_sec` is zero or negative.

  Args:
    path_pattern: Glob pattern handed to `glob.glob`.
    timeout_sec: Maximum number of seconds to keep polling.

  Returns:
    The list of matching paths, or an empty list if nothing matched before the
    timeout elapsed.
  """
  start = time.time()
  while True:
    result = glob.glob(path_pattern)
    if result:
      # Return immediately; there is nothing to gain from sleeping after a hit.
      return result
    if time.time() - start > timeout_sec:
      return result
    time.sleep(0.1)


def SetDisplayFromWebTest():
  """Set up display from web test.

  Colab test sets up display using xvfb for front end web test suite. We just
  ensure that DISPLAY environment variable is properly set for colab kernel
  (backend) which can be used for open gym environment rendering.

  Raises:
    AssertionError: If "/tmp/.X11-unix" or a path matching
      "/tmp/.X11-unix/X*" does not appear within 60 seconds each.
  """
  x11_dir = "/tmp/.X11-unix"
  # Raise explicitly instead of using a bare `assert` so the check still runs
  # under `python -O` and the failure says what was being waited for.
  if not WaitForFilePath(x11_dir, 60):
    raise AssertionError("Timed out waiting for %s" % x11_dir)

  pattern = "/tmp/.X11-unix/X*"
  res = WaitForFilePath(pattern, 60)
  if not res:
    raise AssertionError("Timed out waiting for a path matching %s" % pattern)

  # If we find "/tmp/.X11-unix/X1", then we will set DISPLAY to be ":1".
  # `len(pattern) - 1` is the length of the literal prefix before the "*".
  display = ":" + res[0][len(pattern) - 1 :]
  os.environ["DISPLAY"] = display
  logging.info("Set DISPLAY=%s", display)
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Builds a wheel with setup.py and writes it to the given directory.
#
# Usage:
#   pip_pkg /path/to/destination/directory [--release]
#
# Environment:
#   PYTHON_VERSION  python binary used to run setup.py (required).
set -e
set -x

# Requiring PYTHON_VERSION (path to the python binary to use) mitigates the risk
# of testing/building the modules with one version of Python and packaging the
# wheel with another.
if [ -z "$PYTHON_VERSION" ]; then
  echo "ENV var PYTHON_VERSION must point to an installed python binary."
  exit 1
fi

# NOTE: PLATFORM is not referenced again in this script.
PLATFORM="$(uname -s | tr 'A-Z' 'a-z')"

if [[ $# -lt 1 ]] ; then
  echo "Usage:"
  echo "pip_pkg /path/to/destination/directory [--release]"
  exit 1
fi

# Create the destination directory, then do dirname on a non-existent file
# inside it to give us a path with tilde characters resolved (readlink -f is
# another way of doing this but is not available on a fresh macOS install).
# Finally, use cd and pwd to get an absolute path, in case a relative one was
# given.
mkdir -p "$1"
DEST=$(dirname "${1}/does_not_exist")
DEST=$(cd "$DEST" && pwd)

# Pass through remaining arguments (following the first argument, which
# specifies the output dir) to setup.py, e.g.,
#  ./pip_pkg /tmp/tf_agents_pkg --release
# passes `--release` to setup.py.
# "${@:2}" is quoted so each pass-through argument reaches setup.py as a single
# word, without word splitting or glob expansion; it expands to nothing when no
# extra arguments were given.
$PYTHON_VERSION setup.py bdist_wheel "${@:2}" --dist-dir="$DEST" >/dev/null

set +x
echo -e "\nBuild complete. Wheel files are in $DEST"
6 | environments.random_py_environment_test 7 | environments.random_tf_environment_test 8 | train.learner_test 9 | -------------------------------------------------------------------------------- /tf_agents/AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the list of TF-Agent's significant contributors. 2 | # 3 | # This does not necessarily list everyone who has contributed code, 4 | # especially since many employees of one corporation may be contributing. 5 | # To see the full list of contributors, see the revision history in 6 | # source control. 7 | Google LLC 8 | Sergio Guadarrama 9 | Anoop Korattikara 10 | Oscar Ramirez 11 | Pablo Castro 12 | Ethan Holly 13 | Sam Fishman 14 | Ke Wang 15 | Ekaterina Gonina 16 | Neal Wu 17 | Efi Kokiopoulou 18 | Luciano Sbaiz 19 | Jamie Smith 20 | Gábor Bartók 21 | Jesse Berent 22 | Chris Harris 23 | Vincent Vanhoucke 24 | Eugene Brevdo 25 | James Davidson 26 | Toby Boyd 27 | Summer Yue 28 | Robert Ormandi 29 | Kuang-Huei Lee 30 | Alexa Greenberg 31 | Amir Yazdanbakhsh 32 | Yao Lu 33 | Gaurav Jain 34 | Christof Angermueller 35 | Mark Daoust 36 | Adam Wood -------------------------------------------------------------------------------- /tf_agents/agents/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Module importing all agents.""" 17 | from tf_agents.agents import behavioral_cloning 18 | from tf_agents.agents import categorical_dqn 19 | from tf_agents.agents import cql 20 | from tf_agents.agents import data_converter 21 | from tf_agents.agents import ddpg 22 | from tf_agents.agents import dqn 23 | from tf_agents.agents import ppo 24 | from tf_agents.agents import reinforce 25 | from tf_agents.agents import sac 26 | from tf_agents.agents import td3 27 | from tf_agents.agents import tf_agent 28 | from tf_agents.agents.behavioral_cloning.behavioral_cloning_agent import BehavioralCloningAgent 29 | from tf_agents.agents.categorical_dqn.categorical_dqn_agent import CategoricalDqnAgent 30 | from tf_agents.agents.cql.cql_sac_agent import CqlSacAgent 31 | from tf_agents.agents.ddpg.ddpg_agent import DdpgAgent 32 | from tf_agents.agents.dqn.dqn_agent import DqnAgent 33 | from tf_agents.agents.ppo.ppo_agent import PPOAgent 34 | from tf_agents.agents.ppo.ppo_clip_agent import PPOClipAgent 35 | from tf_agents.agents.ppo.ppo_kl_penalty_agent import PPOKLPenaltyAgent 36 | from tf_agents.agents.reinforce.reinforce_agent import ReinforceAgent 37 | from tf_agents.agents.sac.sac_agent import SacAgent 38 | from tf_agents.agents.td3.td3_agent import Td3Agent 39 | from tf_agents.agents.tf_agent import TFAgent 40 | -------------------------------------------------------------------------------- /tf_agents/agents/behavioral_cloning/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A Behavioral Cloning agent.""" 17 | from tf_agents.agents.behavioral_cloning import behavioral_cloning_agent 18 | -------------------------------------------------------------------------------- /tf_agents/agents/categorical_dqn/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A Categorical DQN (C51) agent.""" 17 | from tf_agents.agents.categorical_dqn import categorical_dqn_agent 18 | -------------------------------------------------------------------------------- /tf_agents/agents/cql/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A CQL-SAC agent.""" 17 | from tf_agents.agents.cql import cql_sac_agent 18 | -------------------------------------------------------------------------------- /tf_agents/agents/ddpg/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """A Deep Deterministic Policy Gradient (DDPG) agent and its networks.""" 17 | 18 | from tf_agents.agents.ddpg import actor_network 19 | from tf_agents.agents.ddpg import actor_rnn_network 20 | from tf_agents.agents.ddpg import critic_network 21 | from tf_agents.agents.ddpg import critic_rnn_network 22 | from tf_agents.agents.ddpg import ddpg_agent 23 | -------------------------------------------------------------------------------- /tf_agents/agents/ddpg/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/agents/ddpg/examples/v2/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/agents/dqn/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A DQN (Deep Q Network) agent.""" 17 | from tf_agents.agents.dqn import dqn_agent 18 | -------------------------------------------------------------------------------- /tf_agents/agents/dqn/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | -------------------------------------------------------------------------------- /tf_agents/agents/dqn/examples/v2/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/agents/dqn/examples/v2/train_eval_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class TrainEval(tf.test.TestCase):
  """Smoke tests that run the DQN `train_eval` example for one iteration."""

  def _TrainOnce(self, **extra_kwargs):
    """Runs a single-iteration `train_eval` in a temp dir and returns its result.

    Skips the calling test when TensorFlow is not executing eagerly.

    Args:
      **extra_kwargs: Additional keyword arguments forwarded to
        `train_eval.train_eval`.

    Returns:
      Whatever `train_eval.train_eval` returns.
    """
    if not tf.executing_eagerly():
      self.skipTest('Binary is eager-only.')

    return train_eval.train_eval(
        self.get_temp_dir(),
        num_iterations=1,
        num_eval_episodes=1,
        initial_collect_steps=10,
        **extra_kwargs
    )

  def testDQNCartPole(self):
    """Trains without overriding the environment and expects a positive loss."""
    loss_info = self._TrainOnce()
    self.assertGreater(loss_info.loss, 0.0)

  def testRNNDQNMaskedCartPole(self):
    """Trains on 'MaskedCartPole-v0' with length-2 sequences; expects loss > 0."""
    loss_info = self._TrainOnce(
        env_name='MaskedCartPole-v0',
        train_sequence_length=2,
    )
    self.assertGreater(loss_info.loss, 0.0)
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """PPO Agents.""" 17 | from tf_agents.agents.ppo import ppo_actor_network 18 | from tf_agents.agents.ppo import ppo_agent 19 | from tf_agents.agents.ppo import ppo_clip_agent 20 | from tf_agents.agents.ppo import ppo_kl_penalty_agent 21 | from tf_agents.agents.ppo import ppo_policy 22 | from tf_agents.agents.ppo import ppo_utils 23 | -------------------------------------------------------------------------------- /tf_agents/agents/ppo/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/agents/ppo/examples/v2/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/agents/random/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A random Agent.""" 17 | from tf_agents.agents.random import fixed_policy_agent 18 | from tf_agents.agents.random import random_agent 19 | -------------------------------------------------------------------------------- /tf_agents/agents/reinforce/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A REINFORCE agent.""" 17 | from tf_agents.agents.reinforce import reinforce_agent 18 | -------------------------------------------------------------------------------- /tf_agents/agents/reinforce/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/agents/reinforce/examples/v2/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/agents/sac/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A Soft Actor Critic agent.""" 17 | from tf_agents.agents.sac import sac_agent 18 | -------------------------------------------------------------------------------- /tf_agents/agents/sac/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/agents/sac/examples/v2/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/agents/td3/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Twin Delayed Deep Deterministic policy gradient (TD3) agent.""" 17 | from tf_agents.agents.td3 import td3_agent 18 | -------------------------------------------------------------------------------- /tf_agents/agents/td3/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/bandits/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """TF-Agents Bandits.""" 17 | 18 | from tf_agents.bandits import agents 19 | from tf_agents.bandits import drivers 20 | from tf_agents.bandits import environments 21 | from tf_agents.bandits import metrics 22 | from tf_agents.bandits import multi_objective 23 | from tf_agents.bandits import networks 24 | from tf_agents.bandits import policies 25 | from tf_agents.bandits import specs 26 | -------------------------------------------------------------------------------- /tf_agents/bandits/agents/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Module importing all agents.""" 17 | 18 | from tf_agents.bandits.agents import bernoulli_thompson_sampling_agent 19 | from tf_agents.bandits.agents import dropout_thompson_sampling_agent 20 | from tf_agents.bandits.agents import examples 21 | from tf_agents.bandits.agents import exp3_agent 22 | from tf_agents.bandits.agents import exp3_mixture_agent 23 | from tf_agents.bandits.agents import greedy_multi_objective_neural_agent 24 | from tf_agents.bandits.agents import greedy_reward_prediction_agent 25 | from tf_agents.bandits.agents import lin_ucb_agent 26 | from tf_agents.bandits.agents import linear_bandit_agent 27 | from tf_agents.bandits.agents import linear_thompson_sampling_agent 28 | from tf_agents.bandits.agents import mixture_agent 29 | from tf_agents.bandits.agents import neural_boltzmann_agent 30 | from tf_agents.bandits.agents import neural_epsilon_greedy_agent 31 | from tf_agents.bandits.agents import neural_falcon_agent 32 | from tf_agents.bandits.agents import neural_linucb_agent 33 | from tf_agents.bandits.agents import ranking_agent 34 | from tf_agents.bandits.agents import static_mixture_agent 35 | from tf_agents.bandits.agents import utils 36 | -------------------------------------------------------------------------------- /tf_agents/bandits/agents/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/bandits/agents/examples/v2/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/bandits/agents/static_mixture_agent.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """An agent that mixes a list of agents with a constant mixture distribution.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import gin 22 | from tf_agents.bandits.agents import mixture_agent 23 | 24 | 25 | @gin.configurable 26 | class StaticMixtureAgent(mixture_agent.MixtureAgent): 27 | """An agent that mixes a set of agents with a given static mixture. 28 | 29 | For every data sample, the agent updates the sub-agent that was used to make 30 | the action choice in that sample. For this update to happen, the mixture agent 31 | needs to have the information on which sub-agent is "responsible" for the 32 | action. This information is in a policy info field `mixture_agent_id`. 33 | 34 | Note that this agent makes use of `tf.dynamic_partition`, and thus it is not 35 | compatible with XLA. 36 | """ 37 | 38 | def _update_mixture_distribution(self, experience): 39 | pass 40 | -------------------------------------------------------------------------------- /tf_agents/bandits/drivers/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Module importing all driver libraries.""" 17 | 18 | from tf_agents.bandits.drivers import driver_utils 19 | -------------------------------------------------------------------------------- /tf_agents/bandits/drivers/driver_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Driver utilities for use with bandit policies and environments.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import 23 | from tf_agents.trajectories import trajectory 24 | from tf_agents.typing import types 25 | 26 | nest = tf.compat.v2.nest 27 | 28 | 29 | def trajectory_for_bandit( 30 | initial_step: types.TimeStep, 31 | action_step: types.PolicyStep, 32 | final_step: types.TimeStep, 33 | ) -> types.NestedTensor: 34 | """Builds a trajectory from a single-step bandit episode. 35 | 36 | Since all episodes consist of a single step, the returned `Trajectory` has no 37 | time dimension. All input and output `Tensor`s/arrays are expected to have 38 | shape `[batch_size, ...]`. 39 | 40 | Args: 41 | initial_step: A `TimeStep` returned from `environment.step(...)`. 
42 | action_step: A `PolicyStep` returned by `policy.action(...)`. 43 | final_step: A `TimeStep` returned from `environment.step(...)`. 44 | 45 | Returns: 46 | A `Trajectory` whose `discount` and `next_step_type` come from `final_step` 47 | and whose `step_type` comes from `initial_step`. 48 | """ 49 | return trajectory.Trajectory( 50 | observation=initial_step.observation, 51 | action=action_step.action, 52 | policy_info=action_step.info, 53 | reward=final_step.reward, 54 | discount=final_step.discount, 55 | step_type=initial_step.step_type, 56 | next_step_type=final_step.step_type, 57 | ) 58 | -------------------------------------------------------------------------------- /tf_agents/bandits/environments/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | """Module importing all environments.""" 17 | 18 | from tf_agents.bandits.environments import bandit_py_environment 19 | from tf_agents.bandits.environments import bandit_tf_environment 20 | from tf_agents.bandits.environments import bernoulli_action_mask_tf_environment 21 | from tf_agents.bandits.environments import bernoulli_py_environment 22 | from tf_agents.bandits.environments import classification_environment 23 | from tf_agents.bandits.environments import dataset_utilities 24 | from tf_agents.bandits.environments import drifting_linear_environment 25 | from tf_agents.bandits.environments import movielens_per_arm_py_environment 26 | from tf_agents.bandits.environments import movielens_py_environment 27 | from tf_agents.bandits.environments import non_stationary_stochastic_environment 28 | from tf_agents.bandits.environments import piecewise_bernoulli_py_environment 29 | from tf_agents.bandits.environments import piecewise_stochastic_environment 30 | from tf_agents.bandits.environments import random_bandit_environment 31 | from tf_agents.bandits.environments import ranking_environment 32 | from tf_agents.bandits.environments import stationary_stochastic_per_arm_py_environment 33 | from tf_agents.bandits.environments import stationary_stochastic_py_environment 34 | from tf_agents.bandits.environments import stationary_stochastic_structured_py_environment 35 | from tf_agents.bandits.environments import wheel_py_environment 36 | -------------------------------------------------------------------------------- /tf_agents/bandits/environments/bernoulli_py_environment_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for the Bernoulli Bandit environment.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import 23 | from tf_agents.bandits.environments import bernoulli_py_environment 24 | 25 | 26 | class BernoulliBanditPyEnvironmentTest(tf.test.TestCase): 27 | 28 | def test_bernoulli_bandit_py_environment(self): 29 | env = bernoulli_py_environment.BernoulliPyEnvironment( 30 | [0.1, 0.2, 0.3], batch_size=2 31 | ) 32 | observation_step = env.reset() 33 | self.assertAllEqual(observation_step.observation.shape, [2]) 34 | reward_step = env.step([0, 1]) 35 | self.assertAllEqual(len(reward_step.reward), 2) 36 | 37 | def test_out_of_bound_parameter(self): 38 | with self.assertRaisesRegex( 39 | ValueError, r'All parameters should be floats in \[0, 1\]\.' 40 | ): 41 | bernoulli_py_environment.BernoulliPyEnvironment( 42 | [0.1, 1.2, 0.3], batch_size=1 43 | ) 44 | 45 | 46 | if __name__ == '__main__': 47 | tf.test.main() 48 | -------------------------------------------------------------------------------- /tf_agents/bandits/environments/dataset_utilities_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for tf_agents.bandits.environments.dataset_utilities.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import numpy as np 23 | import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import 24 | from tf_agents.bandits.environments import dataset_utilities 25 | 26 | 27 | class DatasetUtilitiesTest(tf.test.TestCase): 28 | 29 | def testOneHot(self): 30 | data = np.array([[1, 2], [1, 3], [2, 2], [1, 1]], dtype=np.int32) 31 | encoded = dataset_utilities._one_hot(data) 32 | expected = [ 33 | [1, 0, 0, 1, 0], 34 | [1, 0, 0, 0, 1], 35 | [0, 1, 0, 1, 0], 36 | [1, 0, 1, 0, 0], 37 | ] 38 | np.testing.assert_array_equal(encoded, expected) 39 | 40 | def testRewardDistribution(self): 41 | reward_distr = dataset_utilities.mushroom_reward_distribution( 42 | r_noeat=0.0, 43 | r_eat_safe=5.0, 44 | r_eat_poison_bad=-35.0, 45 | r_eat_poison_good=5.0, 46 | prob_poison_bad=0.5, 47 | ) 48 | self.assertAllEqual(reward_distr.mean(), [[0, -15.0], [0, 5.0]]) 49 | 50 | 51 | if __name__ == '__main__': 52 | tf.test.main() 53 | -------------------------------------------------------------------------------- /tf_agents/bandits/environments/piecewise_bernoulli_py_environment_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for the Piecewise Bernoulli Bandit environment.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from absl.testing import parameterized 23 | import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import 24 | from tf_agents.bandits.environments import piecewise_bernoulli_py_environment as pbe 25 | 26 | 27 | class PiecewiseBernoulliBanditPyEnvironmentTest( 28 | tf.test.TestCase, parameterized.TestCase 29 | ): 30 | 31 | def deterministic_duration_generator(self): 32 | while True: 33 | yield 10 34 | 35 | def test_out_of_bound_parameter(self): 36 | with self.assertRaisesRegex( 37 | ValueError, r'All parameters should be floats in \[0, 1\]\.'
38 | ): 39 | pbe.PiecewiseBernoulliPyEnvironment( 40 | [[0.1, 1.2, 0.3]], self.deterministic_duration_generator() 41 | ) 42 | 43 | @parameterized.named_parameters( 44 | dict(testcase_name='_batch_1', batch_size=1), 45 | dict(testcase_name='_batch_4', batch_size=4), 46 | ) 47 | def test_correct_piece(self, batch_size): 48 | env = pbe.PiecewiseBernoulliPyEnvironment( 49 | [[0.1, 0.2, 0.3], [0.3, 0.2, 0.1], [0.1, 0.12, 0.14]], 50 | self.deterministic_duration_generator(), 51 | batch_size, 52 | ) 53 | for t in range(100): 54 | env.reset() 55 | self.assertEqual(int(t / 10) % 3, env._current_piece) 56 | _ = env.step([0]) 57 | 58 | 59 | if __name__ == '__main__': 60 | tf.test.main() 61 | -------------------------------------------------------------------------------- /tf_agents/bandits/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Module importing all metrics.""" 17 | 18 | from tf_agents.bandits.metrics import tf_metrics 19 | -------------------------------------------------------------------------------- /tf_agents/bandits/multi_objective/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Module importing all multi_objective modules.""" 17 | 18 | from tf_agents.bandits.multi_objective import multi_objective_scalarizer 19 | -------------------------------------------------------------------------------- /tf_agents/bandits/networks/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Module importing all networks.""" 17 | 18 | from tf_agents.bandits.networks import global_and_arm_feature_network 19 | from tf_agents.bandits.networks import heteroscedastic_q_network 20 | -------------------------------------------------------------------------------- /tf_agents/bandits/policies/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Module importing all policies.""" 17 | 18 | from tf_agents.bandits.policies import categorical_policy 19 | from tf_agents.bandits.policies import falcon_reward_prediction_policy 20 | from tf_agents.bandits.policies import greedy_multi_objective_neural_policy 21 | from tf_agents.bandits.policies import greedy_reward_prediction_policy 22 | from tf_agents.bandits.policies import lin_ucb_policy 23 | from tf_agents.bandits.policies import linalg 24 | from tf_agents.bandits.policies import linear_thompson_sampling_policy 25 | from tf_agents.bandits.policies import mixture_policy 26 | from tf_agents.bandits.policies import neural_linucb_policy 27 | from tf_agents.bandits.policies import ranking_policy 28 | from tf_agents.policies import utils as policy_utilities 29 | -------------------------------------------------------------------------------- /tf_agents/bandits/policies/greedy_reward_prediction_policy.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Policy for greedy reward prediction.""" 17 | 18 | import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import 19 | import tensorflow_probability as tfp 20 | from tf_agents.bandits.policies import reward_prediction_base_policy 21 | from tf_agents.policies import utils as policy_utilities 22 | 23 | 24 | class GreedyRewardPredictionPolicy( 25 | reward_prediction_base_policy.RewardPredictionBasePolicy 26 | ): 27 | """Class to build GreedyNNPredictionPolicies.""" 28 | 29 | def _action_distribution(self, mask, predicted_rewards): 30 | """Returns the action with largest predicted reward.""" 31 | # Argmax. 32 | batch_size = tf.shape(predicted_rewards)[0] 33 | if mask is not None: 34 | actions = policy_utilities.masked_argmax( 35 | predicted_rewards, mask, output_type=self.action_spec.dtype 36 | ) 37 | else: 38 | actions = tf.argmax( 39 | predicted_rewards, axis=-1, output_type=self.action_spec.dtype 40 | ) 41 | 42 | actions += self._action_offset 43 | 44 | bandit_policy_values = tf.fill( 45 | [batch_size, 1], policy_utilities.BanditPolicyType.GREEDY 46 | ) 47 | return tfp.distributions.Deterministic(loc=actions), bandit_policy_values 48 | 49 | def _distribution(self, time_step, policy_state): 50 | step = super(GreedyRewardPredictionPolicy, self)._distribution( 51 | time_step, policy_state 52 | ) 53 | # Greedy is deterministic, so we know the chosen arm features here. We 54 | # save it here so the chosen arm features get correctly returned by 55 | # `tf_agents.policies.epsilon_greedy_policy.EpsilonGreedyPolicy` wrapping a 56 | # `GreedyRewardPredictionPolicy` because `EpsilonGreedyPolicy` only accesses 57 | # the `distribution` method of the wrapped policy via 58 | # `tf_agents.policies.greedy_policy.GreedyPolicy`.
59 | action = step.action.sample() 60 | return self._maybe_save_chosen_arm_features(time_step, action, step) 61 | -------------------------------------------------------------------------------- /tf_agents/bandits/policies/loss_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for tf_agents.bandits.policies.loss_utils.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import numpy as np 23 | import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import 24 | from tf_agents.bandits.policies import loss_utils 25 | 26 | 27 | tf.compat.v1.enable_v2_behavior() 28 | 29 | 30 | class LossUtilsTest(tf.test.TestCase): 31 | 32 | def testBaseCase(self): 33 | # Example taken from: 34 | # https://en.wikipedia.org/wiki/Quantile_regression 35 | # Random variable takes values 1...9 with equal probability. 36 | y_true = tf.constant(np.arange(1, 10), dtype=tf.float32) 37 | 38 | # Compute the loss for the median. 39 | # We see that the value `y_pred = 5` minimizes the loss.
40 | 41 | p_loss = loss_utils.pinball_loss( 42 | y_true, y_pred=3 * tf.ones_like(y_true), quantile=0.5 43 | ) 44 | self.assertNear(24.0, 9.0 / 0.5 * self.evaluate(p_loss), err=1e-3) 45 | 46 | p_loss = loss_utils.pinball_loss( 47 | y_true, y_pred=4 * tf.ones_like(y_true), quantile=0.5 48 | ) 49 | self.assertNear(21.0, 9.0 / 0.5 * self.evaluate(p_loss), err=1e-3) 50 | 51 | p_loss = loss_utils.pinball_loss( 52 | y_true, y_pred=5 * tf.ones_like(y_true), quantile=0.5 53 | ) 54 | self.assertNear(20.0, 9.0 / 0.5 * self.evaluate(p_loss), err=1e-3) 55 | 56 | p_loss = loss_utils.pinball_loss( 57 | y_true, y_pred=6 * tf.ones_like(y_true), quantile=0.5 58 | ) 59 | self.assertNear(21.0, 9.0 / 0.5 * self.evaluate(p_loss), err=1e-3) 60 | 61 | p_loss = loss_utils.pinball_loss( 62 | y_true, y_pred=7 * tf.ones_like(y_true), quantile=0.5 63 | ) 64 | self.assertNear(24.0, 9.0 / 0.5 * self.evaluate(p_loss), err=1e-3) 65 | 66 | 67 | if __name__ == '__main__': 68 | tf.test.main() 69 | -------------------------------------------------------------------------------- /tf_agents/bandits/replay_buffers/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Module importing bandit replay buffers.""" 17 | 18 | from tf_agents.bandits.replay_buffers import bandit_replay_buffer 19 | -------------------------------------------------------------------------------- /tf_agents/bandits/replay_buffers/bandit_replay_buffer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A replay buffer for bandit algorithms.""" 17 | 18 | from tf_agents.replay_buffers import tf_uniform_replay_buffer 19 | 20 | 21 | BanditReplayBuffer = tf_uniform_replay_buffer.TFUniformReplayBuffer 22 | -------------------------------------------------------------------------------- /tf_agents/bandits/specs/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Module importing all specs modules.""" 17 | 18 | from tf_agents.bandits.specs import utils 19 | -------------------------------------------------------------------------------- /tf_agents/bandits/specs/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Forwarding utils for backwards compatibility.""" 17 | 18 | from tf_agents.specs import bandit_spec_utils as _utils 19 | 20 | GLOBAL_FEATURE_KEY = _utils.GLOBAL_FEATURE_KEY 21 | PER_ARM_FEATURE_KEY = _utils.PER_ARM_FEATURE_KEY 22 | NUM_ACTIONS_FEATURE_KEY = _utils.NUM_ACTIONS_FEATURE_KEY 23 | 24 | REWARD_SPEC_KEY = _utils.REWARD_SPEC_KEY 25 | CONSTRAINTS_SPEC_KEY = _utils.CONSTRAINTS_SPEC_KEY 26 | 27 | create_per_arm_observation_spec = _utils.create_per_arm_observation_spec 28 | get_context_dims_from_spec = _utils.get_context_dims_from_spec 29 | drop_arm_observation = _utils.drop_arm_observation 30 | -------------------------------------------------------------------------------- /tf_agents/benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | -------------------------------------------------------------------------------- /tf_agents/benchmark/test_data/event_log_3m/events.out.tfevents.1599310762: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/tf_agents/benchmark/test_data/event_log_3m/events.out.tfevents.1599310762 -------------------------------------------------------------------------------- /tf_agents/benchmark/test_data/event_log_too_many/events.out.tfevents.1599310762: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/tf_agents/benchmark/test_data/event_log_too_many/events.out.tfevents.1599310762 -------------------------------------------------------------------------------- /tf_agents/benchmark/test_data/event_log_too_many/events.out.tfevents.1599379945: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/tf_agents/benchmark/test_data/event_log_too_many/events.out.tfevents.1599379945 -------------------------------------------------------------------------------- /tf_agents/distributions/__init__.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Distributions module.""" 17 | from tf_agents.distributions import masked 18 | from tf_agents.distributions import shifted_categorical 19 | from tf_agents.distributions import tanh_bijector_stable 20 | from tf_agents.distributions import utils 21 | -------------------------------------------------------------------------------- /tf_agents/distributions/gumbel_softmax_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class GumbelSoftmaxTest(tf.test.TestCase):
  """Exercises `log_prob`, sampling and `mode` of `GumbelSoftmax`."""

  def testLogProb(self):
    """Compares `log_prob` of a one-hot point with a precomputed value."""
    dist = gumbel_softmax.GumbelSoftmax(
        0.8, [0.3, 0.1, 0.4], validate_args=True
    )
    one_hot_point = tf.constant([0, 0, 1])
    actual = self.evaluate(dist.log_prob(one_hot_point))
    self.assertAllClose(-0.972918868065, actual)

  def testSample(self):
    """A sample converted to one-hot keeps the dtype and sums to one."""
    dist = gumbel_softmax.GumbelSoftmax(
        0.8, [0.3, 0.1, 0.4], dtype=tf.int64, validate_args=True
    )
    raw_sample = dist.sample()
    actions = dist.convert_to_one_hot(raw_sample)
    self.assertEqual(actions.dtype, tf.int64)
    total = tf.reduce_sum(actions, axis=-1)
    self.assertEqual(self.evaluate(total), 1)

  def testMode(self):
    """The mode is the one-hot vector at the position of the largest logit."""
    dist = gumbel_softmax.GumbelSoftmax(
        1.0, [0.3, 0.1, 0.4], validate_args=True
    )
    expected_mode = self.evaluate(tf.constant([0, 0, 1]))
    self.assertAllEqual(self.evaluate(dist.mode()), expected_mode)


if __name__ == '__main__':
  tf.test.main()
class MaskedCategoricalTest(tf.test.TestCase):
  """Tests copying and masking behavior of `masked.MaskedCategorical`."""

  def testCopy(self):
    """Confirm we can copy the distribution."""
    original = masked.MaskedCategorical(
        [100.0, 100.0, 100.0], mask=[True, False, True]
    )
    duplicate = original.copy()
    with self.cached_session() as sess:
      copied_probs = sess.run(duplicate.probs_parameter())
      copied_logits = sess.run(duplicate.logits_parameter())
      reference_probs = sess.run(original.probs_parameter())
      reference_logits = sess.run(original.logits_parameter())
    self.assertAllEqual(reference_logits, copied_logits)
    self.assertAllEqual(reference_probs, copied_probs)

  def testMasking(self):
    """Masked entries get zero probability and are never sampled."""
    distribution = masked.MaskedCategorical(
        [100.0, 100.0, 100.0], mask=[True, False, True], neg_inf=None
    )
    sample = distribution.sample()

    probs_tensor = distribution.probs_parameter()
    logits_tensor = distribution.logits_parameter()

    with self.cached_session() as sess:
      probs_np = sess.run(probs_tensor)
      logits_np = sess.run(logits_tensor)

      # Draw samples & confirm we never draw a masked sample
      draws = [sess.run(sample) for _ in range(100)]

    self.assertAllEqual([0.5, 0, 0.5], probs_np)
    # With `neg_inf=None` the masked logit is expected to equal the dtype's
    # minimum value.
    self.assertAllEqual([100, logits_tensor.dtype.min, 100], logits_np)
    self.assertNotIn(1, draws)
def sample(distribution, reparam=False, **kwargs):
  """Sample from distribution either with reparameterized sampling or regular sampling.

  Args:
    distribution: A `tfp.distributions.Distribution` instance.
    reparam: Whether to use reparameterized sampling.
    **kwargs: Parameters to be passed to distribution's sample() function.

  Returns:
    The result of `distribution.sample(**kwargs)`. When `reparam` is False and
    `distribution` is a `gumbel_softmax.GumbelSoftmax`, the sample is first
    passed through `distribution.convert_to_one_hot`.

  Raises:
    ValueError: If `reparam` is True but the distribution's
      `reparameterization_type` is not
      `tfp.distributions.FULLY_REPARAMETERIZED`.
  """
  if not reparam:
    draws = distribution.sample(**kwargs)
    if isinstance(distribution, gumbel_softmax.GumbelSoftmax):
      return distribution.convert_to_one_hot(draws)
    return draws

  fully_reparameterized = tfp.distributions.FULLY_REPARAMETERIZED
  if distribution.reparameterization_type != fully_reparameterized:
    raise ValueError(
        'This distribution cannot be reparameterized: {}'.format(distribution)
    )
  return distribution.sample(**kwargs)


class Tanh(bijector.Bijector):
  """Bijector that computes `Y = tanh(X)`, therefore `Y in (-1, 1)`.

  The same map can be obtained as an affine transform of the Sigmoid bijector,
  i.e., it is equivalent to
  ```
  tfb.Chain([tfb.Affine(shift=-1, scale=2.),
             tfb.Sigmoid(),
             tfb.Affine(scale=2.)])
  ```

  However, using the `Tanh` bijector directly is slightly faster and more
  numerically stable.
  """

  def __init__(self, validate_args=False, name="tanh"):
    """Creates the bijector.

    Args:
      validate_args: Forwarded unchanged to the base `Bijector` constructor.
      name: Forwarded unchanged to the base `Bijector` constructor.
    """
    # Keep this as the first statement: it snapshots the constructor's local
    # variables before any other local name is bound.
    ctor_locals = dict(locals())
    super(Tanh, self).__init__(
        forward_min_event_ndims=0,
        validate_args=validate_args,
        parameters=ctor_locals,
        name=name,
    )

  def _forward(self, x):
    """Returns `tanh(x)`."""
    return tf.nn.tanh(x)

  def _inverse(self, y):
    """Returns `atanh(y)`, clipping in-range inputs away from +/-1 first."""
    # 0.99999997 is the maximum value such that atanh(x) is valid for both
    # tf.float32 and tf.float64
    within_unit_interval = tf.less_equal(tf.abs(y), 1.0)
    clipped = tf.clip_by_value(y, -0.99999997, 0.99999997)
    # Values with |y| > 1 are passed through to atanh unchanged.
    safe_y = tf.where(within_unit_interval, clipped, y)
    return tf.atanh(safe_y)

  def _forward_log_det_jacobian(self, x):
    """Returns `log(1 - tanh(x)^2)` computed in a numerically stable form."""
    # This formula is mathematically equivalent to
    # `tf.log1p(-tf.square(tf.tanh(x)))`, however this code is more numerically
    # stable.

    # Derivation:
    #   log(1 - tanh(x)^2)
    #   = log(sech(x)^2)
    #   = 2 * log(sech(x))
    #   = 2 * log(2e^-x / (e^-2x + 1))
    #   = 2 * (log(2) - x - log(e^-2x + 1))
    #   = 2 * (log(2) - x - softplus(-2x))
    log_two = tf.math.log(tf.constant(2.0, dtype=x.dtype))
    softplus_term = tf.nn.softplus(-2.0 * x)
    return 2.0 * (log_two - x - softplus_term)
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Drivers for running a policy in an environment.""" 17 | 18 | from tf_agents.drivers import driver 19 | from tf_agents.drivers import dynamic_episode_driver 20 | from tf_agents.drivers import dynamic_step_driver 21 | from tf_agents.drivers import py_driver 22 | -------------------------------------------------------------------------------- /tf_agents/drivers/test_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class TestUtilsTest(parameterized.TestCase, test_utils.TestCase):
  """Tests for `driver_test_utils.NumEpisodesObserver`."""

  @parameterized.named_parameters([
      ('BatchOfOneTrajectoryOfLengthThree', 1, 3),
      ('BatchOfOneTrajectoryOfLengthSeven', 1, 7),
      ('BatchOfOneTrajectoryOfLengthNine', 1, 9),
      ('BatchOfTwoTrajectorieOfLengthThree', 2, 3),
      ('BatchOfTwoTrajectorieOfLengthSeven', 2, 7),
      ('BatchOfTwoTrajectorieOfLengthNine', 2, 9),
      ('BatchOfFiveTrajectorieOfLengthThree', 5, 3),
      ('BatchOfFiveTrajectorieOfLengthSeven', 5, 7),
      ('BatchOfFiveTrajectorieOfLengthNine', 5, 9),
  ])
  def testNumEpisodesObserverEpisodeTotal(self, batch_size, traj_len):
    """The observer counts one episode per batch entry.

    Every row of the batch is a single episode: one FIRST step, then
    `traj_len - 2` MID steps, then one LAST step.

    Args:
      batch_size: Number of identical episodes stacked along the first axis.
      traj_len: Number of steps in each episode.
    """
    middle_steps = np.repeat(ts.StepType.MID, traj_len - 2)
    episode_step_types = np.concatenate(
        [[ts.StepType.FIRST], middle_steps, [ts.StepType.LAST]]
    )
    batched_step_types = np.tile(episode_step_types, (batch_size, 1))
    batch_shape = (batch_size, traj_len)

    traj = trajectory.Trajectory(
        observation=np.random.rand(*batch_shape),
        action=np.random.rand(*batch_shape),
        policy_info=(),
        reward=np.random.rand(*batch_shape),
        discount=np.ones(batch_shape),
        step_type=batched_step_types,
        next_step_type=np.zeros(batch_shape),
    )

    episode_counter = driver_test_utils.NumEpisodesObserver()
    episode_counter(traj)
    self.assertEqual(episode_counter.num_episodes, batch_size)
class AtariTimeLimitTest(test_utils.TestCase):
  """Tests `atari_wrappers.AtariTimeLimit` around a mocked base environment."""

  def _make_envs(self, max_steps):
    """Builds a mocked base env and the time-limited wrapper around it.

    The wrapper is constructed before the mock is configured, mirroring the
    statement order the tests rely on.

    Args:
      max_steps: Step limit handed to `AtariTimeLimit`.

    Returns:
      A `(base_env, wrapped_env)` tuple.
    """
    base_env = mock.MagicMock()
    wrapped_env = atari_wrappers.AtariTimeLimit(base_env, max_steps)

    base_env.gym.game_over = False
    base_env.reset.return_value = ts.restart(1)  # pytype: disable=wrong-arg-types
    base_env.step.return_value = ts.transition(2, 0)  # pytype: disable=wrong-arg-types
    return base_env, wrapped_env

  def test_game_over_after_limit(self):
    """The episode ends only on the step taken after `max_steps` steps."""
    step_limit = 5
    _, wrapped_env = self._make_envs(step_limit)
    action = 1

    self.assertFalse(wrapped_env.game_over)

    for _ in range(step_limit):
      time_step = wrapped_env.step(action)  # pytype: disable=wrong-arg-types
      self.assertFalse(time_step.is_last())
      self.assertFalse(wrapped_env.game_over)

    final_step = wrapped_env.step(action)  # pytype: disable=wrong-arg-types
    self.assertTrue(final_step.is_last())
    self.assertTrue(wrapped_env.game_over)

  def test_resets_after_limit(self):
    """Stepping past the limit resets the base env and clears game over."""
    step_limit = 5
    base_env, wrapped_env = self._make_envs(step_limit)
    action = 1

    for _ in range(step_limit + 1):
      wrapped_env.step(action)  # pytype: disable=wrong-arg-types

    self.assertTrue(wrapped_env.game_over)
    self.assertEqual(1, base_env.reset.call_count)

    # One more step after the limit was hit triggers a second reset.
    wrapped_env.step(action)  # pytype: disable=wrong-arg-types
    self.assertFalse(wrapped_env.game_over)
    self.assertEqual(2, base_env.reset.call_count)


if __name__ == '__main__':
  test_utils.main()
12 | BSUITE_ID = 'deep_sea/0' 13 | RECORD = True 14 | SAVE_PATH = None 15 | LOGGING_MODE = 'terminal' 16 | -------------------------------------------------------------------------------- /tf_agents/environments/configs/suite_gym.gin: -------------------------------------------------------------------------------- 1 | #-*-Python-*- 2 | import tf_agents.environments.suite_gym 3 | 4 | ## Configure Environment 5 | ENVIRONMENT = @suite_gym.load() 6 | suite_gym.load.environment_name = %ENVIRONMENT_NAME 7 | # Note: The ENVIRONMENT_NAME can be overridden by passing the command line flag: 8 | # --params="ENVIRONMENT_NAME='CartPole-v1'" 9 | ENVIRONMENT_NAME = 'CartPole-v1' 10 | -------------------------------------------------------------------------------- /tf_agents/environments/configs/suite_gymnasium.gin: -------------------------------------------------------------------------------- 1 | #-*-Python-*- 2 | import tf_agents.environments.suite_gymnasium 3 | 4 | ## Configure Environment 5 | ENVIRONMENT = @suite_gymnasium.load() 6 | suite_gymnasium.load.environment_name = %ENVIRONMENT_NAME 7 | # Note: The ENVIRONMENT_NAME can be overridden by passing the command line flag: 8 | # --params="ENVIRONMENT_NAME='CartPole-v1'" 9 | ENVIRONMENT_NAME = 'CartPole-v1' 10 | -------------------------------------------------------------------------------- /tf_agents/environments/configs/suite_mujoco.gin: -------------------------------------------------------------------------------- 1 | #-*-Python-*- 2 | import tf_agents.environments.suite_mujoco 3 | 4 | ## Configure Environment 5 | ENVIRONMENT = @suite_mujoco.load() 6 | suite_mujoco.load.environment_name = %ENVIRONMENT_NAME 7 | # Note: The ENVIRONMENT_NAME can be overridden by passing the command line flag: 8 | # --params="ENVIRONMENT_NAME='HalfCheetah-v2'" 9 | ENVIRONMENT_NAME = 'HalfCheetah-v2' 10 | -------------------------------------------------------------------------------- /tf_agents/environments/configs/suite_pybullet.gin: 
-------------------------------------------------------------------------------- 1 | #-*-Python-*- 2 | import tf_agents.environments.suite_pybullet 3 | 4 | ## Configure Environment 5 | ENVIRONMENT = @suite_pybullet.load() 6 | suite_pybullet.load.environment_name = %ENVIRONMENT_NAME 7 | # Note: The ENVIRONMENT_NAME can be overridden by passing the command line flag: 8 | # --binding="ENVIRONMENT_NAME='MinitaurBulletEnv-v0'" 9 | ENVIRONMENT_NAME = 'InvertedPendulumBulletEnv-v0' 10 | 11 | -------------------------------------------------------------------------------- /tf_agents/environments/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/environments/examples/masked_cartpole.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
class MaskedCartPoleEnv(cartpole.CartPoleEnv):
  """Cartpole environment with masked velocity components.

  Only the observation entries at index 0 and 2 are exposed; the entries at
  index 1 and 3 (the velocity components) are dropped. This environment is
  useful as a unit test for agents that utilize recurrent networks.
  """

  def __init__(self):
    """Initializes the base env and narrows the observation space to 2-D."""
    super(MaskedCartPoleEnv, self).__init__()
    # Bounds for the two retained observation components.
    bounds = np.array([
        2 * self.x_threshold,
        2 * self.theta_threshold_radians,
    ])
    self.observation_space = gym.spaces.Box(-bounds, bounds)

  def _mask_observation(self, observation):
    """Returns `observation` restricted to the entries at index 0 and 2."""
    kept_indices = [0, 2]
    return observation[kept_indices]

  def reset(self):
    """Resets the base environment and returns the masked observation."""
    full_observation = super(MaskedCartPoleEnv, self).reset()
    # Get rid of velocity components at index 1, and 3.
    return self._mask_observation(full_observation)

  def step(self, action):
    """Steps the base environment and masks the returned observation.

    Args:
      action: Action forwarded unchanged to the base environment.

    Returns:
      The base environment's `(observation, reward, done, info)` tuple with
      the observation masked.
    """
    base_result = super(MaskedCartPoleEnv, self).step(action)
    full_observation, reward, done, info = base_result
    # Get rid of velocity components at index 1, and 3.
    masked_observation = self._mask_observation(full_observation)
    return masked_observation, reward, done, info


# Two registrations of the same class with different episode limits and
# reward thresholds.
register(
    id='MaskedCartPole-v0',
    entry_point=MaskedCartPoleEnv,
    max_episode_steps=200,
    reward_threshold=195.0,
)

register(
    id='MaskedCartPole-v1',
    entry_point=MaskedCartPoleEnv,
    max_episode_steps=500,
    reward_threshold=475.0,
)
class SuiteAtariTest(test_utils.TestCase):
  """Tests name construction and environment loading in `suite_atari`."""

  def testGameName(self):
    """With only the game given, the name uses the defaults."""
    self.assertEqual(suite_atari.game('Pong'), 'PongNoFrameskip-v0')

  def testGameObsType(self):
    """An `obs_type` of 'ram' adds a '-ram' marker after the game name."""
    env_name = suite_atari.game('Pong', obs_type='ram')
    self.assertEqual(env_name, 'Pong-ramNoFrameskip-v0')

  def testGameMode(self):
    """The `mode` argument replaces the default 'NoFrameskip' part."""
    env_name = suite_atari.game('Pong', mode='Deterministic')
    self.assertEqual(env_name, 'PongDeterministic-v0')

  def testGameVersion(self):
    """The `version` argument replaces the default 'v0' suffix."""
    env_name = suite_atari.game('Pong', version='v4')
    self.assertEqual(env_name, 'PongNoFrameskip-v4')

  def testGameSetAll(self):
    """All four arguments given positionally combine into one name."""
    env_name = suite_atari.game('Pong', 'ram', 'Deterministic', 'v4')
    self.assertEqual(env_name, 'Pong-ramDeterministic-v4')

  def testAtariEnvRegistered(self):
    """Loading yields a PyEnvironment that is an `AtariTimeLimit`."""
    loaded_env = suite_atari.load('Pong-v0')
    self.assertIsInstance(loaded_env, py_environment.PyEnvironment)
    self.assertIsInstance(loaded_env, atari_wrappers.AtariTimeLimit)

  def testAtariObsSpec(self):
    """Observations are `uint8` with shape (84, 84, 1)."""
    loaded_env = suite_atari.load('Pong-v0')
    self.assertIsInstance(loaded_env, py_environment.PyEnvironment)
    observation_spec = loaded_env.observation_spec()
    self.assertEqual(np.uint8, observation_spec.dtype)
    self.assertEqual((84, 84, 1), observation_spec.shape)

  def testAtariActionSpec(self):
    """Actions are scalar `int64` values."""
    loaded_env = suite_atari.load('Pong-v0')
    self.assertIsInstance(loaded_env, py_environment.PyEnvironment)
    action_spec = loaded_env.action_spec()
    self.assertEqual(np.int64, action_spec.dtype)
    self.assertEqual((), action_spec.shape)


if __name__ == '__main__':
  test_utils.main()
class SuiteBsuiteTest(test_utils.TestCase):
  """Tests that suite_bsuite.load returns PyEnvironment instances."""

  def setUp(self):
    """Skips the test when suite_bsuite reports bsuite as unavailable."""
    super(SuiteBsuiteTest, self).setUp()
    if suite_bsuite.is_available():
      return
    self.skipTest('bsuite is not available.')

  def tearDown(self):
    """Clears gin bindings so config parsed by one test does not leak."""
    gin.clear_config()
    super(SuiteBsuiteTest, self).tearDown()

  def testBsuiteEnvRegisteredWithRecord(self):
    load_kwargs = dict(record=True, save_path=None, logging_mode='terminal')
    recorded_env = suite_bsuite.load('deep_sea/0', **load_kwargs)
    self.assertIsInstance(recorded_env, py_environment.PyEnvironment)

  def testBsuiteEnvRegistered(self):
    plain_env = suite_bsuite.load('deep_sea/0', record=False)
    self.assertIsInstance(plain_env, py_environment.PyEnvironment)

  def testGinConfig(self):
    # load() is called without arguments, so they must come from the gin file.
    config_path = test_utils.test_src_dir_path(
        'environments/configs/suite_bsuite.gin'
    )
    gin.parse_config_file(config_path)
    configured_env = suite_bsuite.load()
    self.assertIsInstance(configured_env, py_environment.PyEnvironment)
class SuiteDMControlTest(test_utils.TestCase):
  """Tests loading the 'ball_in_cup' / 'catch' task through suite_dm_control."""

  def setUp(self):
    """Skips the test when suite_dm_control reports dm_control as unavailable."""
    super(SuiteDMControlTest, self).setUp()
    if suite_dm_control.is_available():
      return
    self.skipTest('dm_control is not available.')

  def testEnvRegistered(self):
    catch_env = suite_dm_control.load('ball_in_cup', 'catch')
    self.assertIsInstance(catch_env, py_environment.PyEnvironment)

    utils.validate_py_environment(catch_env)

  def testObservationSpec(self):
    catch_env = suite_dm_control.load('ball_in_cup', 'catch')
    spec = catch_env.observation_spec()
    self.assertEqual(np.float32, spec['position'].dtype)
    self.assertEqual((4,), spec['position'].shape)

  def testActionSpec(self):
    catch_env = suite_dm_control.load('ball_in_cup', 'catch')
    spec = catch_env.action_spec()
    self.assertEqual(np.float32, spec.dtype)
    self.assertEqual((2,), spec.shape)

  def testPixelObservationSpec(self):
    # Rendered at width=100, height=50; the expected shape is (50, 100, 3).
    pixel_env = suite_dm_control.load_pixels(
        'ball_in_cup', 'catch', render_kwargs={'width': 100, 'height': 50}
    )
    spec = pixel_env.observation_spec()

    self.assertEqual(np.uint8, spec['pixels'].dtype)
    self.assertEqual((50, 100, 3), spec['pixels'].shape)
class SuiteMujocoTest(test_utils.TestCase):
  """Tests loading 'HalfCheetah-v2' through suite_mujoco."""

  def setUp(self):
    """Skips the test when suite_mujoco reports itself as unavailable."""
    super(SuiteMujocoTest, self).setUp()
    if suite_mujoco.is_available():
      return
    self.skipTest('suite_mujoco is not available.')

  def tearDown(self):
    """Clears gin bindings so config parsed by one test does not leak."""
    gin.clear_config()
    super(SuiteMujocoTest, self).tearDown()

  def _assert_time_limited_py_env(self, candidate):
    """Asserts `candidate` is a PyEnvironment wrapped in wrappers.TimeLimit."""
    self.assertIsInstance(candidate, py_environment.PyEnvironment)
    self.assertIsInstance(candidate, wrappers.TimeLimit)

  def testMujocoEnvRegistered(self):
    self._assert_time_limited_py_env(suite_mujoco.load('HalfCheetah-v2'))

  def testObservationSpec(self):
    cheetah = suite_mujoco.load('HalfCheetah-v2')
    self.assertEqual(np.float32, cheetah.observation_spec().dtype)
    self.assertEqual((17,), cheetah.observation_spec().shape)

  def testActionSpec(self):
    cheetah = suite_mujoco.load('HalfCheetah-v2')
    self.assertEqual(np.float32, cheetah.action_spec().dtype)
    self.assertEqual((6,), cheetah.action_spec().shape)

  def testGinConfig(self):
    # load() is called without arguments, so they must come from the gin file.
    config_path = test_utils.test_src_dir_path(
        'environments/configs/suite_mujoco.gin'
    )
    gin.parse_config_file(config_path)
    self._assert_time_limited_py_env(suite_mujoco.load())
r"""Suite for loading pybullet Gym environments.

Importing pybullet_envs registers the environments. Once this is done the
regular gym loading mechanism used in suite_gym will generate pybullet envs.

For a list of registered pybullet environments take a look at:
  pybullet_envs/__init__.py

To visualize a pybullet environment as it is being run you can launch the
example browser BEFORE you start the training.

```bash
ExampleBrowser -- --start_demo_name="PhysicsServer"
```
"""
import gin
from tf_agents.environments import suite_gym

# Imported only for its side effect: it registers the pybullet environments.
# pylint: disable=unused-import
import pybullet_envs
# pylint: enable=unused-import

# Expose suite_gym.load to gin under this module's own configurable name.
load = gin.external_configurable(
    suite_gym.load,
    name='suite_pybullet.load',
)
class SuitePybulletTest(test_utils.TestCase):
  """Tests loading pybullet environments through suite_pybullet."""

  def tearDown(self):
    """Clears gin bindings so config parsed by one test does not leak."""
    gin.clear_config()
    super(SuitePybulletTest, self).tearDown()

  def _assert_time_limited_py_env(self, candidate):
    """Asserts `candidate` is a PyEnvironment wrapped in wrappers.TimeLimit."""
    self.assertIsInstance(candidate, py_environment.PyEnvironment)
    self.assertIsInstance(candidate, wrappers.TimeLimit)

  def testPybulletEnvRegistered(self):
    pendulum = suite_pybullet.load('InvertedPendulumBulletEnv-v0')
    self._assert_time_limited_py_env(pendulum)

  def testGinConfig(self):
    # load() is called without arguments, so they must come from the gin file.
    config_path = test_utils.test_src_dir_path(
        'environments/configs/suite_pybullet.gin'
    )
    gin.parse_config_file(config_path)
    self._assert_time_limited_py_env(suite_pybullet.load())
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Eval module.""" 17 | 18 | from tf_agents.eval import metric_utils 19 | -------------------------------------------------------------------------------- /tf_agents/eval/metric_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class MetricUtilsTest(tf.test.TestCase):
  """Tests for metric_utils.compute."""

  def testMetricIsComputedCorrectly(self):
    """The reported average return equals total emitted reward / episodes."""
    # Running sum of every reward handed out, held in a mutable container so
    # the nested reward function can update it.
    tally = {'total': 0}

    def reward_fn(*unused_args):
      sampled = np.random.uniform()
      tally['total'] += sampled
      return sampled

    action_spec = array_spec.BoundedArraySpec((1,), np.int32, -10, 10)
    observation_spec = array_spec.BoundedArraySpec((1,), np.int32, -10, 10)
    random_env = random_py_environment.RandomPyEnvironment(
        observation_spec, action_spec, reward_fn=reward_fn
    )
    random_policy = random_py_policy.RandomPyPolicy(
        time_step_spec=None, action_spec=action_spec
    )

    metric = py_metrics.AverageReturnMetric()

    episode_count = 10
    computed = metric_utils.compute(
        [metric], random_env, random_policy, episode_count
    )
    expected_average = tally['total'] / episode_count
    self.assertAlmostEqual(expected_average, computed[metric.name], places=5)
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/examples/cql_sac/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/examples/cql_sac/kumar20/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/examples/cql_sac/kumar20/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | -------------------------------------------------------------------------------- /tf_agents/examples/cql_sac/kumar20/configs/antmaze.gin: -------------------------------------------------------------------------------- 1 | train_eval.bc_steps = 10000 2 | train_eval.reward_scale_factor = 4.0 3 | train_eval.reward_shift = -0.5 4 | train_eval.critic_learning_rate = 3e-4 5 | train_eval.actor_learning_rate = 1e-4 6 | train_eval.cql_tau = 5.0 7 | train_eval.reward_noise_variance = 0.1 8 | train_eval.log_cql_alpha_clipping = (-1, 100.0) 9 | -------------------------------------------------------------------------------- /tf_agents/examples/cql_sac/kumar20/configs/mujoco.gin: -------------------------------------------------------------------------------- 1 | train_eval.reward_scale_factor = 0.1 2 | train_eval.softmax_temperature = 50.0 3 | train_eval.use_lagrange_cql_alpha = False 4 | train_eval.actor_learning_rate = 3e-4 5 | train_eval.critic_learning_rate = 3e-4 6 | train_eval.action_clipping = (-0.995, 0.995) 7 | train_eval.bc_steps = 10000 8 | -------------------------------------------------------------------------------- /tf_agents/examples/cql_sac/kumar20/configs/mujoco_medium.gin: -------------------------------------------------------------------------------- 1 | include 'tf_agents/examples/cql_sac/kumar20/configs/mujoco.gin' 2 | 3 | train_eval.cql_alpha = 0.1 4 | -------------------------------------------------------------------------------- /tf_agents/examples/cql_sac/kumar20/configs/mujoco_medium_expert.gin: -------------------------------------------------------------------------------- 1 | include 'tf_agents/examples/cql_sac/kumar20/configs/mujoco.gin' 2 | 3 | train_eval.cql_alpha = 1.0 4 | -------------------------------------------------------------------------------- /tf_agents/examples/cql_sac/kumar20/d4rl_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The 
@gin.configurable
def load_d4rl(env_name, default_time_limit=1000):
  """Builds a D4RL gym environment and wraps it for TF-Agents.

  Args:
    env_name: Gym id of the environment; passed to `gym.make` and `gym.spec`.
    default_time_limit: Value used as `max_episode_steps` for a `TimeLimit`
      wrapper when the gym spec's `max_episode_steps` is 0 or None.

  Returns:
    The gym environment wrapped in a `gym_wrapper.GymWrapper`.
  """
  raw_env = gym.make(env_name)
  declared_limit = gym.spec(env_name).max_episode_steps

  # Keep the limit declared by the spec; only add one when the spec has none.
  if declared_limit in (0, None):
    raw_env = TimeLimit(raw_env, max_episode_steps=default_time_limit)

  # Adapt the gym environment to the TF-Agents environment interface.
  return gym_wrapper.GymWrapper(raw_env)
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/examples/cql_sac/kumar20/dataset/test_data/antmaze-medium-play-v0_0.tfrecord: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/tf_agents/examples/cql_sac/kumar20/dataset/test_data/antmaze-medium-play-v0_0.tfrecord -------------------------------------------------------------------------------- /tf_agents/examples/cql_sac/kumar20/dataset/test_data/antmaze-medium-play-v0_0.tfrecord.spec: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/tf_agents/examples/cql_sac/kumar20/dataset/test_data/antmaze-medium-play-v0_0.tfrecord.spec -------------------------------------------------------------------------------- /tf_agents/examples/dqn/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/examples/dqn/mnih15/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/examples/dqn/mnih15/configs/breakout.gin: -------------------------------------------------------------------------------- 1 | # v4 aligns the action space with the Deepmind paper. 2 | # Deterministic ensures that a fixed frameskip of 4 is applied. 
3 | train_eval.env_name='BreakoutDeterministic-v4' 4 | AtariPreprocessing.terminal_on_life_loss = True 5 | -------------------------------------------------------------------------------- /tf_agents/examples/dqn/mnih15/configs/pong.gin: -------------------------------------------------------------------------------- 1 | train_eval.env_name='Pong-v0' 2 | AtariPreprocessing.terminal_on_life_loss = True 3 | -------------------------------------------------------------------------------- /tf_agents/examples/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/examples/ppo/schulman17/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/examples/ppo/schulman17/configs/half_cheetah.gin: -------------------------------------------------------------------------------- 1 | train_eval.env_name = 'HalfCheetah-v2' 2 | train_eval.eval_episodes = 100 3 | -------------------------------------------------------------------------------- /tf_agents/examples/ppo/schulman17/configs/hopper.gin: -------------------------------------------------------------------------------- 1 | train_eval.env_name = 'Hopper-v2' 2 | train_eval.eval_episodes = 100 3 | -------------------------------------------------------------------------------- /tf_agents/examples/ppo/schulman17/configs/inverted_double_pendulum.gin: -------------------------------------------------------------------------------- 1 | train_eval.env_name = 'InvertedDoublePendulum-v2' 2 | train_eval.eval_episodes = 100 3 | -------------------------------------------------------------------------------- /tf_agents/examples/ppo/schulman17/configs/inverted_pendulum.gin: -------------------------------------------------------------------------------- 1 | train_eval.env_name = 'InvertedPendulum-v2' 2 | train_eval.eval_episodes = 100 3 | -------------------------------------------------------------------------------- /tf_agents/examples/ppo/schulman17/configs/reacher.gin: -------------------------------------------------------------------------------- 1 | train_eval.env_name = 'Reacher-v2' 2 | train_eval.eval_episodes = 100 3 | 
-------------------------------------------------------------------------------- /tf_agents/examples/ppo/schulman17/configs/swimmer.gin: -------------------------------------------------------------------------------- 1 | train_eval.env_name = 'Swimmer-v2' 2 | train_eval.eval_episodes = 100 3 | -------------------------------------------------------------------------------- /tf_agents/examples/ppo/schulman17/configs/walker_2d.gin: -------------------------------------------------------------------------------- 1 | train_eval.env_name = 'Walker2d-v2' 2 | train_eval.eval_episodes = 100 3 | -------------------------------------------------------------------------------- /tf_agents/examples/sac/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/examples/sac/haarnoja18/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/examples/sac/haarnoja18/configs/ant.gin: -------------------------------------------------------------------------------- 1 | train_eval.env_name = 'Ant-v2' 2 | train_eval.initial_collect_steps = 10000 3 | -------------------------------------------------------------------------------- /tf_agents/examples/sac/haarnoja18/configs/half_cheetah.gin: -------------------------------------------------------------------------------- 1 | train_eval.env_name = 'HalfCheetah-v2' 2 | train_eval.initial_collect_steps = 10000 3 | -------------------------------------------------------------------------------- /tf_agents/examples/sac/haarnoja18/configs/hopper.gin: -------------------------------------------------------------------------------- 1 | train_eval.env_name = 'Hopper-v2' 2 | train_eval.initial_collect_steps = 1000 3 | -------------------------------------------------------------------------------- /tf_agents/examples/sac/haarnoja18/configs/humanoid.gin: -------------------------------------------------------------------------------- 1 | train_eval.env_name = 'Humanoid-v2' 2 | train_eval.initial_collect_steps = 1000 3 | -------------------------------------------------------------------------------- /tf_agents/examples/sac/haarnoja18/configs/walker_2d.gin: -------------------------------------------------------------------------------- 1 | train_eval.env_name = 'Walker2d-v2' 2 | train_eval.initial_collect_steps = 1000 3 | 
-------------------------------------------------------------------------------- /tf_agents/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """TF-Agents Experimental Modules. 17 | 18 | These utilities, libraries, and tools have not been rigorously tested for 19 | production use. For example, experimental examples may not have associated 20 | nightly regression tests. 21 | """ 22 | # Aliasing the already moved `tf_agent.train` module from its new location here 23 | # for backward compatibility. 24 | # TODO(b/175303833): Remove this when everyone uses the new dependencies. 25 | from tf_agents import train 26 | from tf_agents.experimental import distributed 27 | -------------------------------------------------------------------------------- /tf_agents/experimental/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """TF-Agents Experimental Distributed Library.""" 17 | 18 | from tf_agents.experimental.distributed.reverb_variable_container import ReverbVariableContainer 19 | -------------------------------------------------------------------------------- /tf_agents/experimental/distributed/examples/sac/README.md: -------------------------------------------------------------------------------- 1 | # Distributed SAC Launch Instructions 2 | 3 | ## Launching locally: 4 | In separate terminals, launch the following three (3) jobs. You can use a different 5 | port number: 6 | 7 | 1) Reverb server: 8 | 9 | ```shell 10 | python tf_agents/experimental/distributed/examples/sac/sac_reverb_server.py -- \ 11 | --root_dir=/tmp/sac_train/ \ 12 | --port=8008 \ 13 | --alsologtostderr 14 | ``` 15 | 16 | 2) Collect job: 17 | 18 | ```shell 19 | $ python tf_agents/experimental/distributed/examples/sac/sac_collect.py -- \ 20 | --root_dir=/tmp/sac_train/ \ 21 | --gin_bindings='collect.environment_name="HalfCheetah-v2"' \ 22 | --replay_buffer_server_address=localhost:8008 \ 23 | --variable_container_server_address=localhost:8008 \ 24 | --alsologtostderr 25 | ``` 26 | 27 | 3) Train job: 28 | 29 | ```shell 30 | $ python tf_agents/experimental/distributed/examples/sac/sac_train.py -- \ 31 | --root_dir=/tmp/sac_train/ \ 32 | --gin_bindings='train.environment_name="HalfCheetah-v2"' \ 33 | --gin_bindings='train.learning_rate=0.0003' \ 34 | --replay_buffer_server_address=localhost:8008 \ 35 | --variable_container_server_address=localhost:8008 \ 36 | 
--alsologtostderr 37 | ``` 38 | 39 | 4) Eval job (optional): 40 | 41 | This job is not SAC-specific. The evaluator job simply reads the greedy policy from an 42 | arbitrary actor-learner `root_dir`, instantiates an environment (defined by the 43 | GIN bindings to `evaluate.environment_name` and `evaluate.suite_load_fn`; 44 | the environment's dependencies are assumed to be installed already), then evaluates the 45 | policy iteratively on policy parameters provided by the variable container. 46 | 47 | ```shell 48 | $ python tf_agents/experimental/distributed/examples/ckpt_evaluator.py -- \ 49 | --root_dir=/tmp/sac_train/ \ 50 | --env_name='HalfCheetah-v2' \ 51 | --alsologtostderr 52 | ``` 53 | -------------------------------------------------------------------------------- /tf_agents/experimental/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /tf_agents/experimental/examples/ppo/train_eval_lib_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
class TrainEvalLibTest(test_utils.TestCase):
  """Smoke test for `train_eval_lib.train_eval`."""

  def test_train_eval(self):
    """Runs `train_eval` once with a small configuration; makes no assertions."""
    training_params = dict(
        num_iterations=2,
        actor_fc_layers=(20, 10),
        value_fc_layers=(20, 10),
        learning_rate=1e-3,
        collect_sequence_length=10,
        minibatch_size=None,
        num_epochs=2,
    )
    agent_params = dict(
        importance_ratio_clipping=0.2,
        lambda_value=0.95,
        discount_factor=0.99,
        entropy_regularization=0.0,
        value_pred_loss_coef=0.5,
        use_gae=True,
        use_td_lambda_return=True,
        gradient_clipping=None,
        value_clipping=None,
    )
    replay_params = dict(
        reverb_port=None,
        replay_capacity=10000,
    )
    # Every periodic interval is passed as 0.
    other_params = dict(
        policy_save_interval=0,
        summary_interval=0,
        eval_interval=0,
    )

    train_eval_lib.train_eval(
        root_dir=self.create_tempdir(),
        env_name='HalfCheetah-v2',
        **training_params,
        **agent_params,
        **replay_params,
        **other_params,
    )
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Keras Layers Module.""" 17 | 18 | from tf_agents.keras_layers.bias_layer import BiasLayer 19 | from tf_agents.keras_layers.dynamic_unroll_layer import DynamicUnroll 20 | from tf_agents.keras_layers.inner_reshape import InnerReshape 21 | from tf_agents.keras_layers.rnn_wrapper import RNNWrapper 22 | from tf_agents.keras_layers.squashed_outer_wrapper import SquashedOuterWrapper 23 | -------------------------------------------------------------------------------- /tf_agents/keras_layers/bias_layer_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class BiasLayerTest(tf.test.TestCase):
  """Checks the outputs and variables of `bias_layer.BiasLayer`."""

  def _init_variables(self):
    """Runs the global variables initializer."""
    self.evaluate(tf.compat.v1.global_variables_initializer())

  def testBuild(self):
    """A (2, 3) input of ones is expected to come back as all ones."""
    layer = bias_layer.BiasLayer()
    result = layer(tf.ones((2, 3)))
    self._init_variables()

    expected = [[1.0] * 3] * 2
    np.testing.assert_almost_equal(expected, self.evaluate(result))

  def testBuildScalar(self):
    """A rank-1 input of ones is expected to come back as all ones."""
    layer = bias_layer.BiasLayer()
    result = layer(tf.ones((2,)))
    self._init_variables()

    expected = [1.0] * 2
    np.testing.assert_almost_equal(expected, self.evaluate(result))

  def testTrainableVariables(self):
    """With a constant-1 initializer the trainable variables evaluate to ones."""
    layer = bias_layer.BiasLayer(
        bias_initializer=tf.constant_initializer(value=1.0)
    )
    # The layer is called once before its variables are read.
    layer(tf.zeros((2, 3)))
    self._init_variables()

    trainable = layer.trainable_variables
    expected = [[1.0] * 3]
    np.testing.assert_almost_equal(expected, self.evaluate(trainable))
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for tf_agents.keras_layers.inner_reshape.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import numpy as np 23 | import tensorflow as tf 24 | from tf_agents.keras_layers import inner_reshape 25 | from tf_agents.utils import test_utils 26 | 27 | 28 | class InnerReshapeTest(test_utils.TestCase): 29 | 30 | def testInnerReshapeSimple(self): 31 | layer = inner_reshape.InnerReshape([3, 4], [12]) 32 | out = layer(np.arange(2 * 12).reshape(2, 3, 4)) 33 | self.assertAllEqual(self.evaluate(out), np.arange(2 * 12).reshape(2, 12)) 34 | out = layer(np.arange(4 * 12).reshape(2, 2, 3, 4)) 35 | self.assertAllEqual(self.evaluate(out), np.arange(4 * 12).reshape(2, 2, 12)) 36 | 37 | def testInnerReshapeUnknowns(self): 38 | layer = inner_reshape.InnerReshape([None, None], [-1]) 39 | out = layer(np.arange(3 * 20).reshape(3, 4, 5)) 40 | self.assertAllEqual(self.evaluate(out), np.arange(3 * 20).reshape(3, 20)) 41 | out = layer(np.arange(6 * 20).reshape(2, 3, 4, 5)) 42 | self.assertAllEqual(self.evaluate(out), np.arange(6 * 20).reshape(2, 3, 20)) 43 | 44 | def testIncompatibleShapes(self): 45 | with self.assertRaisesRegex(ValueError, 'must have known rank'): 46 | inner_reshape.InnerReshape(tf.TensorShape(None), [1]) 47 | 48 | with self.assertRaisesRegex(ValueError, 'Mismatched number of elements'): 49 | inner_reshape.InnerReshape([1, 2], []) 50 | 51 | with self.assertRaisesRegex(ValueError, r'Shapes.*are incompatible'): 52 | 
class PermanentVariableRateDropout(tf.keras.layers.Dropout):
  """Dropout layer with an optionally callable rate that can stay on at serving.

  Construct this layer the same way as `tf.keras.layers.Dropout`, with two
  differences:
    * `rate` may be either a plain value or a zero-argument callable. A callable
      is invoked on each call in which dropout is applied.
    * `permanent` is an extra boolean. When true, dropout is applied on every
      call, whatever value `training` has.
  """

  def __init__(self, rate, permanent=False, **kwargs):
    """Initializes the layer.

    Args:
      rate: The dropout rate, or a zero-argument callable returning it.
      permanent: If true, dropout is applied regardless of `training`.
      **kwargs: Forwarded unchanged to `tf.keras.layers.Dropout`.
    """
    self._permanent = permanent
    super().__init__(rate, **kwargs)

  def call(self, inputs, training=None):
    """Applies dropout to `inputs` when the layer is active.

    Args:
      inputs: The tensor to apply dropout to.
      training: Whether to apply dropout. Ignored when the layer is permanent;
        when `None`, the Keras learning phase is used instead.

    Returns:
      The result of `tf.nn.dropout` when dropout is applied, otherwise `inputs`
      unchanged.
    """
    if self._permanent:
      # A permanent layer keeps dropout on and never consults `training`.
      apply_dropout = True
    elif training is None:
      apply_dropout = tf.keras.backend.learning_phase()
    else:
      apply_dropout = training

    if not apply_dropout:
      return inputs

    # The rate is resolved at call time so a callable can change between calls.
    return tf.nn.dropout(
        inputs,
        rate=self._get_dropout_value(),
        noise_shape=self._get_noise_shape(inputs),
        seed=self.seed,
    )

  def _get_dropout_value(self):
    """Returns the current rate, invoking `self.rate` if it is callable."""
    return self.rate() if callable(self.rate) else self.rate
class PermanentVariableRateDropoutTest(test_utils.TestCase):
  """Tests `PermanentVariableRateDropout` with a variable-backed rate."""

  def _assert_zero_or_scaled(self, out, scaled, inputs):
    """Asserts each element of `out` is either zero or equal to `scaled`."""
    self.assertAllClose(out * (scaled - out), tf.zeros_like(inputs))

  def testPermanent(self):
    """A permanent layer applies dropout with and without `training=False`."""
    rate_var = tf.Variable(0.5, dtype=tf.float32)
    layer = permanent_variable_rate_dropout.PermanentVariableRateDropout(
        rate=lambda: tf.identity(rate_var), permanent=True
    )
    inputs = tf.reshape(tf.range(4 * 12, dtype=tf.float32), shape=(2, 2, 3, 4))

    # With rate 0.5 the kept elements are expected to be doubled.
    dropped = layer(inputs)
    doubled = inputs * 2
    self._assert_zero_or_scaled(dropped, doubled, inputs)

    # `training=False` must not switch a permanent layer off.
    dropped = layer(inputs, training=False)
    self._assert_zero_or_scaled(dropped, doubled, inputs)

    # After the variable changes to 0.3 the kept elements are scaled by 1 / 0.7.
    rate_var.assign(0.3)
    dropped = layer(inputs)
    rescaled = inputs / 0.7
    self._assert_zero_or_scaled(dropped, rescaled, inputs)

  def testNonPermanent(self):
    """A non-permanent layer applies dropout only when `training=True`."""
    rate_var = tf.Variable(0.5, dtype=tf.float32)
    layer = permanent_variable_rate_dropout.PermanentVariableRateDropout(
        rate=lambda: tf.identity(rate_var)
    )
    inputs = tf.reshape(tf.range(4 * 12, dtype=tf.float32), shape=(2, 2, 3, 4))

    dropped = layer(inputs, training=True)
    doubled = inputs * 2
    self._assert_zero_or_scaled(dropped, doubled, inputs)

    # Outside training the inputs are expected to pass through untouched.
    passthrough = layer(inputs, training=False)
    self.assertAllClose(passthrough, inputs)
class SquashedOuterWrapperTest(test_utils.TestCase):
  """Tests `SquashedOuterWrapper` wrapped around a batch-norm layer."""

  def testFromConfigBatchNorm(self):
    """A layer rebuilt from its config reports the same config."""
    original = squashed_outer_wrapper.SquashedOuterWrapper(
        tf.keras.layers.BatchNormalization(axis=-1), inner_rank=3
    )
    config = original.get_config()
    rebuilt = squashed_outer_wrapper.SquashedOuterWrapper.from_config(config)
    self.assertEqual(original.get_config(), rebuilt.get_config())

  def testSquashedOuterWrapperSimple(self):
    """Two outer dims and one pre-flattened outer dim give matching outputs."""
    batch_norm = tf.keras.layers.BatchNormalization(axis=-1)
    wrapper = squashed_outer_wrapper.SquashedOuterWrapper(
        batch_norm, inner_rank=3
    )

    flat = tf.range(3 * 4 * 5 * 6 * 7, dtype=tf.float32)

    # Same data, first with outer dims [3, 4] ...
    two_outer_in = tf.reshape(flat, [3, 4, 5, 6, 7])
    two_outer_out = wrapper(two_outer_in)

    # ... then with those outer dims merged into a single dim of 12.
    one_outer_in = tf.reshape(flat, [3 * 4, 5, 6, 7])
    one_outer_out = wrapper(one_outer_in)
    one_outer_out_unmerged = tf.reshape(one_outer_out, [3, 4, 5, 6, 7])

    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.assertAllClose(
        self.evaluate(two_outer_out), self.evaluate(one_outer_out_unmerged)
    )

  def testIncompatibleShapes(self):
    """Tracing the layer with an unknown-rank spec raises `ValueError`."""
    batch_norm = tf.keras.layers.BatchNormalization(axis=-1)
    wrapper = squashed_outer_wrapper.SquashedOuterWrapper(
        batch_norm, inner_rank=3
    )
    unknown_rank_spec = tf.TensorSpec(shape=None, dtype=tf.float32)

    with self.assertRaisesRegex(ValueError, 'must have known rank'):
      traced = common.function(wrapper)
      traced.get_concrete_function(unknown_rank_spec)
-------------------------------------------------------------------------------- /tf_agents/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Metrics module.""" 17 | 18 | from tf_agents.metrics import batched_py_metric 19 | from tf_agents.metrics import export_utils 20 | from tf_agents.metrics import py_metric 21 | from tf_agents.metrics import py_metrics 22 | from tf_agents.metrics import tf_metric 23 | from tf_agents.metrics import tf_metrics 24 | from tf_agents.metrics import tf_py_metric 25 | -------------------------------------------------------------------------------- /tf_agents/metrics/export_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def export_metrics(step, metrics, loss_info=None):
  """Writes every metric result, and optionally the loss, to `logging.info`.

  One line of the form `[step=<step>] <name> = <value>.` is logged per metric,
  in iteration order, followed by a `loss` line when `loss_info` is given.

  Args:
    step: Integer denoting the round at which the metrics are logged.
    metrics: Iterable of metrics; each must provide `name` and `result()`.
    loss_info: Optional object whose `loss` attribute is logged under the name
      `loss`. Nothing is logged for it when `None`.
  """

  def _named_values():
    # Lazy on purpose: each value is read right before its line is logged.
    for metric in metrics:
      yield metric.name, metric.result()
    if loss_info is not None:
      yield 'loss', loss_info.loss

  for name, value in _named_values():
    logging.info(f'[step={step}] {name} = {value}.')
15 | 16 | """Networks Module.""" 17 | 18 | from tf_agents.networks import actor_distribution_network 19 | from tf_agents.networks import actor_distribution_rnn_network 20 | from tf_agents.networks import categorical_projection_network 21 | from tf_agents.networks import categorical_q_network 22 | from tf_agents.networks import dueling_q_network 23 | from tf_agents.networks import encoding_network 24 | from tf_agents.networks import expand_dims_layer 25 | from tf_agents.networks import lstm_encoding_network 26 | from tf_agents.networks import mask_splitter_network 27 | from tf_agents.networks import nest_map 28 | from tf_agents.networks import network 29 | from tf_agents.networks import normal_projection_network 30 | from tf_agents.networks import q_network 31 | from tf_agents.networks import q_rnn_network 32 | from tf_agents.networks import sequential 33 | from tf_agents.networks import utils 34 | from tf_agents.networks import value_network 35 | from tf_agents.networks import value_rnn_network 36 | from tf_agents.networks.nest_map import NestFlatten 37 | from tf_agents.networks.nest_map import NestMap 38 | from tf_agents.networks.network import Network 39 | from tf_agents.networks.sequential import Sequential 40 | -------------------------------------------------------------------------------- /tf_agents/networks/categorical_projection_network_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
def _get_inputs(batch_size, num_input_dims):
  """Returns a `[batch_size, num_input_dims]` tensor from `tf.random.uniform`."""
  shape = [batch_size, num_input_dims]
  return tf.random.uniform(shape)


class CategoricalProjectionNetworkTest(tf.test.TestCase):
  """Tests for `CategoricalProjectionNetwork` output shapes and variables."""

  def testBuild(self):
    """Checks the distribution type, logits shape and sample shape."""
    action_spec = tensor_spec.BoundedTensorSpec([2, 3], tf.int32, 0, 1)
    projection = categorical_projection_network.CategoricalProjectionNetwork(
        action_spec
    )
    batch = _get_inputs(batch_size=3, num_input_dims=5)

    dist, _ = projection(batch, outer_rank=1)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    drawn = self.evaluate(dist.sample())

    self.assertEqual(tfp.distributions.Categorical, type(dist))
    # Batch = 3; 2x3 action choices, 2x actions per choice.
    self.assertEqual((3, 2, 3, 2), dist.logits.shape)
    self.assertAllEqual((3, 2, 3), drawn.shape)

  def testTrainableVariables(self):
    """Expects exactly two trainable variables, shaped (5, 4) and (4,)."""
    action_spec = tensor_spec.BoundedTensorSpec([2], tf.int32, 0, 1)
    projection = categorical_projection_network.CategoricalProjectionNetwork(
        action_spec
    )
    batch = _get_inputs(batch_size=3, num_input_dims=5)

    projection(batch, outer_rank=1)
    self.evaluate(tf.compat.v1.global_variables_initializer())

    # Dense kernel, dense bias.
    trainable = projection.trainable_variables
    self.assertEqual(2, len(trainable))
    self.assertEqual((5, 4), trainable[0].shape)
    self.assertEqual((4,), trainable[1].shape)
class KerasLayersNet(network.Network):
  """Test network that passes its inputs through a single wrapped layer."""

  def __init__(self, observation_spec, action_spec, layer, name=None):
    """Creates the network.

    Args:
      observation_spec: Input spec handed to the `network.Network` constructor.
      action_spec: Accepted but not used by this class.
      layer: Callable applied to the inputs in `call`.
      name: Optional name handed to the `network.Network` constructor.
    """
    super(KerasLayersNet, self).__init__(
        observation_spec,
        state_spec=(),
        name=name,
    )
    self._layer = layer

  def call(self, inputs, step_type=None, network_state=()):
    """Returns `(layer(inputs), network_state)`; `step_type` is ignored."""
    del step_type
    outputs = self._layer(inputs)
    return outputs, network_state
15 | 16 | """Policies Module.""" 17 | 18 | from tf_agents.policies import actor_policy 19 | from tf_agents.policies import boltzmann_policy 20 | from tf_agents.policies import epsilon_greedy_policy 21 | from tf_agents.policies import fixed_policy 22 | from tf_agents.policies import gaussian_policy 23 | from tf_agents.policies import greedy_policy 24 | from tf_agents.policies import ou_noise_policy 25 | from tf_agents.policies import policy_saver 26 | from tf_agents.policies import py_policy 27 | from tf_agents.policies import py_tf_eager_policy 28 | from tf_agents.policies import py_tf_policy 29 | from tf_agents.policies import q_policy 30 | from tf_agents.policies import random_py_policy 31 | from tf_agents.policies import random_tf_policy 32 | from tf_agents.policies import scripted_py_policy 33 | from tf_agents.policies import tf_policy 34 | from tf_agents.policies import tf_py_policy 35 | from tf_agents.policies import utils 36 | from tf_agents.policies.actor_policy import ActorPolicy 37 | from tf_agents.policies.epsilon_greedy_policy import EpsilonGreedyPolicy 38 | from tf_agents.policies.greedy_policy import GreedyPolicy 39 | from tf_agents.policies.policy_saver import PolicySaver 40 | from tf_agents.policies.py_tf_eager_policy import PyTFEagerPolicy 41 | from tf_agents.policies.py_tf_eager_policy import SavedModelPyTFEagerPolicy 42 | from tf_agents.policies.tf_policy import TFPolicy 43 | -------------------------------------------------------------------------------- /tf_agents/replay_buffers/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Replay Buffers Module.""" 17 | 18 | from tf_agents.replay_buffers import episodic_replay_buffer 19 | from tf_agents.replay_buffers import py_hashed_replay_buffer 20 | from tf_agents.replay_buffers import py_uniform_replay_buffer 21 | from tf_agents.replay_buffers import replay_buffer 22 | from tf_agents.replay_buffers import reverb_replay_buffer 23 | from tf_agents.replay_buffers import reverb_utils 24 | from tf_agents.replay_buffers import table 25 | from tf_agents.replay_buffers import tf_uniform_replay_buffer 26 | from tf_agents.replay_buffers.reverb_replay_buffer import ReverbReplayBuffer 27 | from tf_agents.replay_buffers.reverb_utils import ReverbAddEpisodeObserver 28 | from tf_agents.replay_buffers.reverb_utils import ReverbAddTrajectoryObserver 29 | from tf_agents.replay_buffers.tf_uniform_replay_buffer import TFUniformReplayBuffer 30 | -------------------------------------------------------------------------------- /tf_agents/replay_buffers/replay_buffer_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
class ReplayBufferTestClass(replay_buffer.ReplayBuffer):
  """Bare `ReplayBuffer` subclass that overrides nothing."""


class ReplayBufferInitTest(tf.test.TestCase):
  """Checks constructor-exposed properties and the unimplemented methods."""

  def _data_spec(self):
    """Returns a nested spec: an action spec and a (lidar, camera) pair."""
    action = specs.TensorSpec([3], tf.float32, 'action')
    lidar = specs.TensorSpec([5], tf.float32, 'lidar')
    camera = specs.TensorSpec([3, 2], tf.float32, 'camera')
    return (action, (lidar, camera))

  def testReplayBufferInit(self):
    data_spec = self._data_spec()
    max_items = 10
    buffer = ReplayBufferTestClass(data_spec, max_items)
    self.assertEqual(buffer.data_spec, data_spec)
    self.assertEqual(buffer.capacity, max_items)

  def testReplayBufferInitWithStatefulDataset(self):
    data_spec = self._data_spec()
    max_items = 10
    buffer = ReplayBufferTestClass(
        data_spec, max_items, stateful_dataset=True
    )
    self.assertEqual(buffer.data_spec, data_spec)
    self.assertEqual(buffer.capacity, max_items)
    self.assertEqual(buffer.stateful_dataset, True)

  def testMethods(self):
    buffer = ReplayBufferTestClass(self._data_spec(), 10)
    # Each entry defers one call that is expected to be unimplemented on the
    # bare subclass.
    unimplemented_calls = (
        lambda: buffer.as_dataset(),
        lambda: buffer.as_dataset(single_deterministic_pass=True),
        lambda: buffer.get_next(),
        lambda: buffer.add_batch(items=None),
    )
    for invoke in unimplemented_calls:
      with self.assertRaises(NotImplementedError):
        invoke()
20 | from tf_agents.specs import array_spec 21 | from tf_agents.specs import bandit_spec_utils 22 | from tf_agents.specs import distribution_spec 23 | from tf_agents.specs import tensor_spec 24 | from tf_agents.specs.array_spec import ArraySpec 25 | from tf_agents.specs.array_spec import BoundedArraySpec 26 | from tf_agents.specs.tensor_spec import BoundedTensorSpec 27 | from tf_agents.specs.tensor_spec import from_spec 28 | from tf_agents.specs.tensor_spec import is_bounded 29 | from tf_agents.specs.tensor_spec import is_continuous 30 | from tf_agents.specs.tensor_spec import is_discrete 31 | from tf_agents.specs.tensor_spec import sample_spec_nest 32 | from tf_agents.specs.tensor_spec import TensorSpec 33 | from tf_agents.specs.tensor_spec import zero_spec_nest 34 | -------------------------------------------------------------------------------- /tf_agents/specs/distribution_spec_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class DistributionSpecTest(tf.test.TestCase):
  """Covers building a distribution from a `DistributionSpec`."""

  def testBuildsDistribution(self):
    reference = tfd.Categorical([0.2, 0.3, 0.5], validate_args=True)
    reference_params = reference.parameters

    # The spec is seeded with the reference distribution's constructor
    # parameters.
    spec = distribution_spec.DistributionSpec(
        tfd.Categorical,
        tensor_spec.TensorSpec((3,), dtype=tf.float32),
        sample_spec=tensor_spec.TensorSpec((1,), dtype=tf.int32),
        **reference_params
    )
    self.assertEqual(
        reference_params['logits'], spec.distribution_parameters['logits']
    )

    # Building with explicit logits should override the stored ones while
    # keeping the other stored parameters (here `validate_args`).
    new_logits = [0.1, 0.4, 0.5]
    built = spec.build_distribution(logits=new_logits)

    self.assertTrue(isinstance(built, tfd.Categorical))
    self.assertTrue(built.parameters['validate_args'])
    self.assertEqual(new_logits, built.parameters['logits'])
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Module importing all system utilities.""" 17 | from tf_agents.system import system_multiprocessing as multiprocessing 18 | -------------------------------------------------------------------------------- /tf_agents/system/default/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | -------------------------------------------------------------------------------- /tf_agents/tf_agents_api_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
class RootAPITest(tf.test.TestCase):
  """Smoke test that basic subpackages are reachable from the package root."""

  def test_entries(self):
    # Resolve every dotted path with plain attribute lookups so that a missing
    # entry surfaces as an AttributeError, exactly like direct access would.
    expected_entries = (
        'agents',
        'experimental',
        'policies',
        'networks',
        'bandits.agents',
        'bandits.policies',
        'bandits.networks',
    )
    for dotted_path in expected_entries:
      resolved = tf_agents
      for attribute in dotted_path.split('.'):
        resolved = getattr(resolved, attribute)
class IntervalTrigger(object):
  """Invokes a callback whenever enough progress has accumulated.

  The callback fires as soon as the supplied value is at least `interval`
  beyond the value recorded at the previous firing. The gap between two
  firings can therefore be larger than `interval`; it is never smaller unless
  a firing is forced.
  """

  def __init__(self, interval: int, fn: Callable[[], None], start: int = 0):
    """Constructs the IntervalTrigger.

    Args:
      interval: Minimum distance between two triggered values. A value <= 0
        disables the trigger entirely.
      fn: Zero-argument callable invoked on every trigger.
      start: Value treated as the most recent trigger point initially.
    """
    self._fn = fn
    self._interval = interval
    self._original_start_value = start
    self._last_trigger_value = start

    if self._interval > 0:
      return
    logging.info(
        'IntervalTrigger will not be triggered because interval is set to %d',
        self._interval,
    )

  def __call__(self, value: int, force_trigger: bool = False) -> None:
    """Fires the callback if `value` is due, or if firing is forced.

    Args:
      value: Current value to compare against the last triggered value.
      force_trigger: If True, fire regardless of the interval, except when
        `value` equals the last triggered value.
    """
    if self._interval <= 0:
      # A disabled trigger never fires, not even when forced.
      return

    forced = force_trigger and value != self._last_trigger_value
    if forced or value >= self._last_trigger_value + self._interval:
      # Record the new trigger point before running the callback.
      self._last_trigger_value = value
      self._fn()

  def reset(self) -> None:
    """Restores the last trigger value to the constructor's `start`."""
    self._last_trigger_value = self._original_start_value

  def set_start(self, start: int) -> None:
    """Overrides the last trigger value; the value used by `reset` is kept."""
    self._last_trigger_value = start
class StepPerSecondTracker(object):
  """Utility class for measuring steps/second."""

  def __init__(self, step):
    """Creates an instance of the StepPerSecondTracker.

    Args:
      step: Object whose `numpy()` method returns the current number of train
        steps (a `tf.Variable` in training code).
    """
    self.step = step
    self.last_iteration = 0
    self.last_time = 0
    self.restart()

  def restart(self):
    """Starts a new measurement window at the current step and time."""
    self.last_iteration = self.step.numpy()
    self.last_time = time.time()

  def steps_per_second(self):
    """Returns the average steps per second since the last `restart`.

    Returns:
      Steps taken since the last restart divided by the elapsed wall-clock
      seconds, or 0.0 when no positive amount of time has elapsed.
    """
    steps_taken = self.step.numpy() - self.last_iteration
    elapsed_seconds = time.time() - self.last_time
    # `time.time()` is a wall clock: it can report the same value twice on
    # coarse timers and can move backwards when the system clock is adjusted.
    # Guard against dividing by zero or reporting a negative rate.
    if elapsed_seconds <= 0:
      return 0.0
    return steps_taken / elapsed_seconds
def get_tensor_specs(env):
  """Builds tensor specs from the specs exposed by `env`.

  Args:
    env: environment instance used for collection.

  Returns:
    A tuple `(observation_tensor_spec, action_tensor_spec,
    time_step_tensor_spec)`, where the time step spec is derived from the
    observation tensor spec.
  """
  obs_spec = tensor_spec.from_spec(env.observation_spec())
  act_spec = tensor_spec.from_spec(env.action_spec())
  step_spec = ts.time_step_spec(obs_spec)
  return obs_spec, act_spec, step_spec


def get_collect_data_spec_from_policy_and_env(env, policy):
  """Returns collect data spec from policy and environment.

  Meant to be used for collection jobs (i.e. Actors) without having to
  construct an agent instance but directly from a policy (which can be loaded
  from a saved model).

  Args:
    env: instance of the environment used for collection.
    policy: policy for collection; its `policy_step_spec` is converted.

  Returns:
    The result of `trajectory.from_transition` applied to the time step spec,
    the policy step spec, and the time step spec again.
  """
  step_spec = ts.time_step_spec(tensor_spec.from_spec(env.observation_spec()))
  policy_spec = tensor_spec.from_spec(policy.policy_step_spec)
  # The same time step spec serves as both the current and the next step.
  return trajectory.from_transition(step_spec, policy_spec, step_spec)
class SpecUtilsTest(test_utils.TestCase):
  """Checks the spec names produced by the `spec_utils` helpers."""

  def test_get_tensor_specs(self):
    cartpole = suite_gym.load('CartPole-v0')
    specs = spec_utils.get_tensor_specs(cartpole)
    obs_spec, act_spec, step_spec = specs
    self.assertEqual(obs_spec.name, 'observation')
    self.assertEqual(act_spec.name, 'action')
    self.assertEqual(step_spec.observation.name, 'observation')
    self.assertEqual(step_spec.reward.name, 'reward')

  def test_get_collect_data_spec(self):
    cartpole = suite_gym.load('CartPole-v0')
    mock_policy = driver_test_utils.PyPolicyMock(
        cartpole.time_step_spec(), cartpole.action_spec()
    )
    data_spec = spec_utils.get_collect_data_spec_from_policy_and_env(
        cartpole, mock_policy
    )
    self.assertEqual(data_spec.observation.name, 'observation')
    self.assertEqual(data_spec.reward.name, 'reward')
def get_strategy(tpu, use_gpu):
  """Creates a `tf.distribute` strategy for TPU, GPU, or the default case.

  When neither `tpu` nor `use_gpu` is set, the strategy returned by
  `tf.distribute.get_strategy()` is used.

  Args:
    tpu: BNS address of TPU to use. Note the flag and param are called TPU as
      that is what the xmanager utilities call.
    use_gpu: Whether a GPU should be used. This will create a MirroredStrategy.

  Returns:
    A TPUStrategy when `tpu` is set, a MirroredStrategy when `use_gpu` is set,
    otherwise the result of `tf.distribute.get_strategy()`.

  Raises:
    ValueError: If both `tpu` and `use_gpu` are set.
  """
  if tpu and use_gpu:
    raise ValueError('Only one of tpu or use_gpu should be provided.')

  if not (tpu or use_gpu):
    return tf.distribute.get_strategy()

  logging.info('Devices: \n%s', tf.config.list_logical_devices())
  if use_gpu:
    accelerator_strategy = tf.distribute.MirroredStrategy()
  else:
    # Exactly one option is set at this point, so this is the TPU path: the
    # cluster is connected and the TPU system initialized before the strategy
    # is created.
    cluster = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu)
    tf.config.experimental_connect_to_cluster(cluster)
    tf.tpu.experimental.initialize_tpu_system(cluster)
    accelerator_strategy = tf.distribute.TPUStrategy(cluster)
  logging.info(
      'Devices after getting strategy:\n%s', tf.config.list_logical_devices()
  )
  return accelerator_strategy
28 | default_strategy = tf.distribute.get_strategy() 29 | 30 | strategy = strategy_utils.get_strategy(tpu=False, use_gpu=False) 31 | self.assertIsInstance(strategy, type(default_strategy)) 32 | 33 | @mock.patch.object(tf.distribute, 'TPUStrategy') 34 | @mock.patch.object(tf.tpu.experimental, 'initialize_tpu_system') 35 | @mock.patch.object(tf.config, 'experimental_connect_to_cluster') 36 | @mock.patch.object(tf.distribute.cluster_resolver, 'TPUClusterResolver') 37 | def test_tpu_strategy( 38 | self, 39 | mock_tpu_cluster_resolver, 40 | mock_experimental_connect_to_cluster, 41 | mock_initialize_tpu_system, 42 | mock_tpu_strategy, 43 | ): 44 | resolver = mock.MagicMock() 45 | mock_tpu_cluster_resolver.return_value = resolver 46 | mock_strategy = mock.MagicMock() 47 | mock_tpu_strategy.return_value = mock_strategy 48 | 49 | strategy = strategy_utils.get_strategy(tpu='bns_address', use_gpu=False) 50 | 51 | mock_tpu_cluster_resolver.assert_called_with(tpu='bns_address') 52 | mock_experimental_connect_to_cluster.assert_called_with(resolver) 53 | mock_initialize_tpu_system.assert_called_with(resolver) 54 | self.assertIs(strategy, mock_strategy) 55 | 56 | @mock.patch.object(tf.distribute, 'MirroredStrategy') 57 | def test_mirrored_strategy(self, mock_mirrored_strategy): 58 | mirrored_strategy = mock.MagicMock() 59 | mock_mirrored_strategy.return_value = mirrored_strategy 60 | 61 | strategy = strategy_utils.get_strategy(False, use_gpu=True) 62 | self.assertIs(strategy, mirrored_strategy) 63 | 64 | 65 | if __name__ == '__main__': 66 | test_utils.main() 67 | -------------------------------------------------------------------------------- /tf_agents/trajectories/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Trajectories module.""" 17 | 18 | from tf_agents.trajectories import policy_step 19 | from tf_agents.trajectories import time_step 20 | from tf_agents.trajectories import trajectory 21 | from tf_agents.trajectories.policy_step import PolicyInfo 22 | from tf_agents.trajectories.policy_step import PolicyStep 23 | from tf_agents.trajectories.time_step import restart 24 | from tf_agents.trajectories.time_step import StepType 25 | from tf_agents.trajectories.time_step import termination 26 | from tf_agents.trajectories.time_step import time_step_spec 27 | from tf_agents.trajectories.time_step import TimeStep 28 | from tf_agents.trajectories.time_step import transition 29 | from tf_agents.trajectories.time_step import truncation 30 | from tf_agents.trajectories.trajectory import boundary 31 | from tf_agents.trajectories.trajectory import first 32 | from tf_agents.trajectories.trajectory import from_transition 33 | from tf_agents.trajectories.trajectory import last 34 | from tf_agents.trajectories.trajectory import mid 35 | from tf_agents.trajectories.trajectory import single_step 36 | from tf_agents.trajectories.trajectory import to_n_step_transition 37 | from tf_agents.trajectories.trajectory import to_transition 38 | from tf_agents.trajectories.trajectory import to_transition_spec 39 | from tf_agents.trajectories.trajectory import Trajectory 40 | from tf_agents.trajectories.trajectory import Transition 41 | -------------------------------------------------------------------------------- 
class PolicyStepTest(tf.test.TestCase):
  """Checks `PolicyStep` construction with explicit and default fields."""

  def _assert_fields(self, step, action, state, info):
    """Asserts the three fields of `step` in action, state, info order."""
    self.assertEqual(step.action, action)
    self.assertEqual(step.state, state)
    self.assertEqual(step.info, info)

  def testCreate(self):
    built = policy_step.PolicyStep(action=1, state=2, info=3)
    self._assert_fields(built, 1, 2, 3)

  def testCreateWithAllDefaults(self):
    # Only the action is supplied; state and info should be empty tuples.
    built = policy_step.PolicyStep(1)
    self._assert_fields(built, 1, (), ())

  def testCreateWithDefaultInfo(self):
    # Action and state are positional; info should be an empty tuple.
    built = policy_step.PolicyStep(1, 2)
    self._assert_fields(built, 1, 2, ())
def stacked_trajectory_from_transition(time_step, action_step, next_time_step):
  """Builds a two-step, time-stacked `Trajectory` from a single transition.

  Every tensor in the returned `Trajectory` gains a time dimension at axis 1
  (i.e., a shape of `[B, T, ...]` with T = 2). The result can be handed to
  `agent.train()` or to `to_transition` without a separate `next_trajectory`
  argument.

  Args:
    time_step: A `time_step.TimeStep` representing the first step in a
      transition.
    action_step: A `policy_step.PolicyStep` representing actions corresponding
      to observations from time_step.
    next_time_step: A `time_step.TimeStep` representing the second step in a
      transition.

  Returns:
    A time stacked `Trajectory`.
  """

  def _stack_along_time(first, second):
    # Axis 0 is the batch dimension, so the new time axis goes at position 1.
    return tf.stack([first, second], axis=1)

  first_frame = trajectory.from_transition(
      time_step, action_step, next_time_step
  )
  # The second frame deliberately reuses `action_step` and `next_time_step` so
  # that action, policy_info, next_step_type, reward, and discount are the same
  # at both positions of the time dimension.
  second_frame = trajectory.from_transition(
      next_time_step, action_step, next_time_step
  )
  return tf.nest.map_structure(_stack_along_time, first_frame, second_frame)
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utils module.""" 17 | 18 | from tf_agents.utils import common 19 | from tf_agents.utils import composite 20 | from tf_agents.utils import eager_utils 21 | from tf_agents.utils import example_encoding 22 | from tf_agents.utils import lazy_loader 23 | from tf_agents.utils import nest_utils 24 | from tf_agents.utils import numpy_storage 25 | from tf_agents.utils import session_utils 26 | from tf_agents.utils import tensor_normalizer 27 | from tf_agents.utils import test_utils 28 | from tf_agents.utils import timer 29 | from tf_agents.utils import value_ops 30 | -------------------------------------------------------------------------------- /tf_agents/utils/lazy_loader.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TF-Agents Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class LazyLoader(types.ModuleType):
  """Module stand-in that postpones the real import until first use.

  Heavy, rarely needed dependencies (`contrib` and `ffmpeg` are examples) can be
  exposed through an instance of this class so they are only imported when an
  attribute is actually looked up on them.
  """

  # The lint error here is incorrect.
  def __init__(self, local_name, parent_module_globals, name, warning=None):
    """Records where the real module must be installed once it is loaded.

    Args:
      local_name: Key under which the loaded module is stored in
        `parent_module_globals`.
      parent_module_globals: Mutable mapping (typically the importing module's
        `globals()`) that receives the loaded module.
      name: Fully qualified name of the module to import lazily.
      warning: Optional message logged once, the first time the module loads.
    """
    self._local_name = local_name
    self._parent_module_globals = parent_module_globals
    self._warning = warning

    super(LazyLoader, self).__init__(name)

  def _load(self):
    """Imports the target module and publishes it in the parent's globals."""
    loaded = importlib.import_module(self.__name__)
    # Replace the placeholder in the parent namespace with the real module.
    self._parent_module_globals[self._local_name] = loaded

    pending_warning = self._warning
    if pending_warning:
      logging.warning(pending_warning)
      # Clearing the message guarantees it is emitted a single time.
      self._warning = None

    # Mirror the module's namespace onto this object: anyone still holding the
    # LazyLoader then resolves attributes directly, because __getattr__ only
    # runs for lookups that fail.
    self.__dict__.update(loaded.__dict__)

    return loaded

  def __getattr__(self, item):
    return getattr(self._load(), item)

  def __dir__(self):
    return dir(self._load())
class NumpyStorageTest(tf.test.TestCase):
  """Round-trips a `NumpyState` through `tf.train.Checkpoint`."""

  @test_util.run_in_graph_and_eager_modes()
  def testSaveRestore(self):
    shape = [3, 4]
    state = numpy_storage.NumpyState()
    ckpt = tf.train.Checkpoint(numpy_arrays=state)
    state.x = np.ones(shape)
    saved_to = ckpt.save(os.path.join(self.get_temp_dir(), 'ckpt'))

    # Clobber the array in place, then verify that restore brings it back.
    state.x[:] = 0.0
    self.assertAllEqual(state.x, np.zeros(shape))
    ckpt.restore(saved_to).assert_consumed()
    self.assertAllEqual(state.x, np.ones(shape))

    # A fresh NumpyState has no `x` attribute yet; restore() is expected to
    # create it automatically.
    fresh_ckpt = tf.train.Checkpoint(numpy_arrays=numpy_storage.NumpyState())
    fresh_ckpt.restore(saved_to).assert_consumed()
    self.assertAllEqual(np.ones(shape), fresh_ckpt.numpy_arrays.x)
class TestUtilsTest(test_utils.TestCase):
  """Exercises `test_utils.contains` with single samples and whole batches."""

  def testBatchContainsSample(self):
    haystack = np.array([[1, 2], [3, 4]])
    needle = np.array([3, 4])
    found = test_utils.contains(haystack, [needle])
    self.assertTrue(found)

  def testBatchDoesNotContainSample(self):
    haystack = np.array([[1, 2], [3, 4]])
    # [2, 4] mixes elements of both rows but matches neither row.
    needle = np.array([2, 4])
    found = test_utils.contains(haystack, [needle])
    self.assertFalse(found)

  def testBatchContainsBatch(self):
    haystack = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
    # Repeated and reordered rows should still count as contained.
    needles = np.array([[3, 4], [3, 4], [1, 2]])
    found = test_utils.contains(haystack, needles)
    self.assertTrue(found)

  def testBatchDoesNotContainBatch(self):
    haystack = np.array([[1, 2], [3, 4]])
    # One missing row ([5, 6]) is enough for the check to fail.
    needles = np.array([[1, 2], [5, 6]])
    found = test_utils.contains(haystack, needles)
    self.assertFalse(found)
class Timer(object):
  """Context manager that accumulates wall-clock time over timed blocks.

  The timer can be driven explicitly with `start()` / `stop()` or used in a
  `with` statement. Elapsed seconds from every start/stop pair are summed until
  `reset()` is called.

  Example:
    timer = Timer()
    with timer:
      do_work()
    print(timer.value())
  """

  def __init__(self):
    # Total seconds accumulated over all completed start/stop intervals.
    self._accumulator = 0
    # Timestamp of the pending `start()`, or None when the timer is idle.
    self._last = None

  def __enter__(self):
    self.start()
    # Returning self makes `with Timer() as t:` bind the timer instead of None.
    return self

  def __exit__(self, *args):
    self.stop()

  def start(self):
    """Begins (or restarts) an interval at the current time."""
    self._last = time.time()

  def stop(self):
    """Ends the running interval and adds its duration to the total.

    Calling `stop()` while no interval is running is a no-op, so a stray or
    repeated `stop()` neither raises nor counts the same interval twice.
    """
    if self._last is None:
      return
    self._accumulator += time.time() - self._last
    # Mark the timer idle so a second stop() cannot re-add stale time.
    self._last = None

  def value(self):
    """Returns the total accumulated time in seconds."""
    return self._accumulator

  def reset(self):
    """Clears the accumulated total; a running interval is left untouched."""
    self._accumulator = 0
# We follow Semantic Versioning (https://semver.org/)
_MAJOR_VERSION = '0'
_MINOR_VERSION = '20'
_PATCH_VERSION = '0'

# When building releases, we can update this value on the release branch to
# reflect the current release candidate ('rc0', 'rc1') or, finally, the official
# stable release (indicated by `_REL_SUFFIX = ''`). Outside the context of a
# release branch, the current version is by default assumed to be a
# 'development' version, labeled 'dev'.
_DEV_SUFFIX = 'dev'
_REL_SUFFIX = ''

# Bare 'MAJOR.MINOR.PATCH' string with no suffix; '0.20.0' for the values above.
__version__ = '.'.join([
    _MAJOR_VERSION,
    _MINOR_VERSION,
    _PATCH_VERSION,
])
# Development version: the dev suffix is appended after a dot, e.g. '0.20.0.dev'.
__dev_version__ = '{}.{}'.format(__version__, _DEV_SUFFIX)
# Release version: the release suffix is appended directly, e.g. '0.10.0rc0' for
# a release candidate, or identical to `__version__` when `_REL_SUFFIX` is ''.
__rel_version__ = '{}{}'.format(__version__, _REL_SUFFIX)
6 | 7 | ## Standard TF-Agents Docker 8 | 9 | **Build with tf-agents nightly** 10 | 11 | ```shell 12 | $ docker build -t tf_agents/core \ 13 | --build-arg tf_agents_pip_spec=tf-agents-nightly[reverb] \ 14 | -f tools/docker/ubuntu_tf_agents ./tools/docker 15 | ``` 16 | 17 | **Build with tf-agents latest stable** 18 | 19 | ```shell 20 | $ docker build --pull -t tf_agents/core \ 21 | --build-arg tf_agents_pip_spec=tf-agents[reverb] \ 22 | --build-arg tensorflow_pip_spec=tensorflow \ 23 | -f tools/docker/ubuntu_tf_agents ./tools/docker 24 | ``` 25 | 26 | **Run** 27 | 28 | Starts the Docker image above and gives you a bash prompt and the TF-Agents' 29 | repo mounted as `/workspace`. 30 | 31 | ```shell 32 | $ sudo docker run -it -u $(id -u):$(id -g) -v $(pwd):/workspace \ 33 | tf_agents/core bash 34 | ``` 35 | 36 | ## TF-Agents plus Mujoco 37 | 38 | The steps below build and run a docker with Mujoco. 39 | 40 | Note: `tf_agents/core` is the base for this docker and built above. 41 | 42 | ```shell 43 | $ docker build -t tf_agents/mujoco \ 44 | -f tools/docker/ubuntu_mujoco ./tools/docker 45 | ``` 46 | 47 | Start the container: 48 | 49 | ```shell 50 | $ docker run --rm -it \ 51 | -v $(pwd):/workspace \ 52 | tf_agents/mujoco bash 53 | ``` 54 | 55 | Start the container with GPU support: 56 | 57 | ```shell 58 | $ docker run --rm -it \ 59 | --gpus all \ 60 | -v $(pwd):/workspace \ 61 | tf_agents/mujoco bash 62 | ``` 63 | -------------------------------------------------------------------------------- /tools/docker/ubuntu_atari: -------------------------------------------------------------------------------- 1 | # Docker for running tf-agents with atari and user provided ROMs. 2 | # 3 | # For current details on installing ROMs for ale-py read their instructions: 4 | # https://github.com/mgbellemare/Arcade-Learning-Environment 5 | # 6 | # Example usage: 7 | # This docker builds on the core tf-agents docker. 
# Base image for TF-Agents continuous integration.
#
# Run the following commands in order:
#
#   docker build --tag tf_agents:ci -f tools/docker/ubuntu_ci tools/docker/
#
# Test that everything worked (starts a shell in the image built above, then
# checks that the interpreter and pip are usable):
#   docker run -it --rm -v $(pwd):/workspace --workdir /workspace tf_agents:ci bash
#   python3.9 -m pip --version
ARG base_image="ubuntu:20.04"

FROM $base_image

LABEL maintainer="tobyboyd@google.com"

# Interpreter(s) that get pip and the pip dependencies below. The install loop
# at the end iterates over this value, so a space separated list also works.
ARG python_version="python3.9"
# apt-get with retries and automatic "yes" so flaky mirrors do not fail builds.
ARG APT_COMMAND="apt-get -o Acquire::Retries=3 -y"

# Stops tzdata from asking about timezones and blocking install on user input.
ARG DEBIAN_FRONTEND=noninteractive
ENV TZ=America/Los_Angeles

RUN ${APT_COMMAND} update && ${APT_COMMAND} install --no-install-recommends \
    software-properties-common \
    curl \
    less

# Adds repository to pull versions of python from.
RUN add-apt-repository ppa:deadsnakes/ppa

# Installs the python versions used by CI and picks up some TF dependencies.
RUN ${APT_COMMAND} update && ${APT_COMMAND} install --no-install-recommends \
    build-essential \
    cmake \
    zlib1g-dev \
    libpng-dev \
    lsb-core \
    vim \
    ca-certificates \
    wget \
    zip \
    xvfb \
    freeglut3-dev \
    ffmpeg \
    python3.9-dev \
    python3.10-dev \
    # python >= 3.8 needs distutils for packaging.
    python3.9-distutils \
    python3.10-distutils \
    gfortran \
    libopenblas-dev \
    liblapack-dev

RUN curl -O https://bootstrap.pypa.io/get-pip.py

ARG pip_dependencies=' \
    virtualenv \
    scipy'

# No need to install tf-agents as that should be part of the test setup.
# Doing it for now to test.
RUN for python in ${python_version}; do \
    $python get-pip.py && \
    $python -mpip --no-cache-dir install $pip_dependencies; \
  done


CMD ["/bin/bash"]
# Builds on the locally tagged `tf_agents/mujoco` image, which must exist before
# this file is built.
# NOTE(review): the usage header of this file names both `ubuntu_mujoco_oss` and
# `ubuntu_mujoco` as the Dockerfile for that base image -- confirm which one is
# current.
FROM tf_agents/mujoco as d4rl
# Interpreter used to run pip below; override with --build-arg python_version=...
ARG python_version="python3"

# Symlink to MuJoCo that is needed for D4RL installation.
# Makes the existing mujoco210 directory also reachable as mujoco210_linux.
RUN ln -s /root/.mujoco/mujoco210 /root/.mujoco/mujoco210_linux

# Install D4RL.
# The package is pulled from the `master` branch of the GitHub repository, so
# the installed revision depends on when the image is built.
RUN $python_version -m pip install git+https://github.com/rail-berkeley/d4rl@master#egg=d4rl
FROM tf_agents/core as mujoco
# Interpreter used to run pip below; override with --build-arg python_version=...
ARG python_version="python3"

# Prerequisites for MuJoCo.
RUN apt-get update && apt-get install -y \
    --no-install-recommends \
    libglew-dev \
    libosmesa6-dev \
    patchelf

# Downloads MuJoCo 2.1.0 and unpacks it to /root/.mujoco/mujoco210.
RUN mkdir -p /root/.mujoco \
    && wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz \
    && tar -xf mujoco210-linux-x86_64.tar.gz -C /root/.mujoco/ \
    && rm mujoco210-linux-x86_64.tar.gz

# On some hosts LD_LIBRARY_PATH is unset when running the container. After hours
# of internet research along with trial and error, a root cause was not found.
# It works fine on google cloud instances but fails on my workstation, where it
# worked fine a few months ago. `docker inspect` shows LD_LIBRARY_PATH is set in
# the container config. If you get the following message when trying to use
# MuJoCo, use the work around below:
#
#   Please add following line to .bashrc:
#   export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin
#
# Starting docker with the arguments below resolved the issue when
# LD_LIBRARY_PATH would not stay set:
#   docker run --rm -it tf_agents/mujoco /bin/bash -c \
#     "LD_LIBRARY_PATH=/root/.mujoco/mujoco210/bin python3 foo.py --bar 10"
ENV LD_LIBRARY_PATH="/root/.mujoco/mujoco210/bin:${LD_LIBRARY_PATH}"
# When LD_LIBRARY_PATH is not sticky, this sets it on login.
RUN echo 'export LD_LIBRARY_PATH=/root/.mujoco/mujoco210/bin:${LD_LIBRARY_PATH}' >> /etc/bash.bashrc

# Setting LD_LIBRARY_PATH inline is redundant but needed on hosts where
# LD_LIBRARY_PATH is not sticky. See comments above. The path must match the
# mujoco210 directory unpacked above.
RUN LD_LIBRARY_PATH=/root/.mujoco/mujoco210/bin $python_version -m pip install -U 'mujoco-py<2.2,>=2.1'
41 | self.assertAlmostEqual(agg_results[-1][-1], 5674.96354, places=4) 42 | # Median at step 3M. 43 | self.assertAlmostEqual(agg_results[-1][-2], 5573.90380, places=4) 44 | 45 | def test_output_graph(self): 46 | """Tests outputing a graph to a file does not error out. 47 | 48 | There is no validation that the output graph is correct. 49 | """ 50 | output_path = self.create_tempdir() 51 | event_log_dirs = [ 52 | 'event_log_ant_eval00', 53 | 'event_log_ant_eval01', 54 | 'event_log_ant_eval02', 55 | ] 56 | event_log_paths = [ 57 | os.path.join(TEST_DATA, log_dir) for log_dir in event_log_dirs 58 | ] 59 | stat_builder = graph_builder.StatsBuilder( 60 | event_log_paths, 61 | 'AverageReturn', 62 | graph_agg=graph_builder.GraphAggTypes.MEDIAN, 63 | output_path=output_path.full_path, 64 | ) 65 | data_collector, _ = stat_builder._gather_data() 66 | agg_results = stat_builder._align_and_aggregate(data_collector) 67 | stat_builder._output_graph(agg_results, len(event_log_dirs)) 68 | 69 | 70 | if __name__ == '__main__': 71 | absltest.main() 72 | -------------------------------------------------------------------------------- /tools/test_data/event_log_ant_eval00/events.out.tfevents.1599310762: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/tools/test_data/event_log_ant_eval00/events.out.tfevents.1599310762 -------------------------------------------------------------------------------- /tools/test_data/event_log_ant_eval01/events.out.tfevents.1599379945: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/tools/test_data/event_log_ant_eval01/events.out.tfevents.1599379945 -------------------------------------------------------------------------------- /tools/test_data/event_log_ant_eval02/events.out.tfevents.1599448596: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/agents/0e27d54b6bc0133d51e8e10f956bf515ac3bd3b6/tools/test_data/event_log_ant_eval02/events.out.tfevents.1599448596 --------------------------------------------------------------------------------