├── .circleci └── config.yml ├── .codecov.yml ├── .gitignore ├── .gitmodules ├── .isort.cfg ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── _static │ └── empty ├── api │ ├── modules.rst │ ├── reagent.core.rst │ ├── reagent.data.rst │ ├── reagent.evaluation.feature_importance.rst │ ├── reagent.evaluation.rst │ ├── reagent.gym.agents.rst │ ├── reagent.gym.datasets.rst │ ├── reagent.gym.envs.dynamics.rst │ ├── reagent.gym.envs.functionality.rst │ ├── reagent.gym.envs.pomdp.rst │ ├── reagent.gym.envs.rst │ ├── reagent.gym.envs.wrappers.rst │ ├── reagent.gym.policies.rst │ ├── reagent.gym.policies.samplers.rst │ ├── reagent.gym.policies.scorers.rst │ ├── reagent.gym.preprocessors.rst │ ├── reagent.gym.rst │ ├── reagent.gym.runners.rst │ ├── reagent.gym.tests.preprocessors.rst │ ├── reagent.gym.tests.rst │ ├── reagent.lite.rst │ ├── reagent.mab.rst │ ├── reagent.model_managers.actor_critic.rst │ ├── reagent.model_managers.discrete.rst │ ├── reagent.model_managers.model_based.rst │ ├── reagent.model_managers.parametric.rst │ ├── reagent.model_managers.policy_gradient.rst │ ├── reagent.model_managers.ranking.rst │ ├── reagent.model_managers.rst │ ├── reagent.model_utils.rst │ ├── reagent.models.rst │ ├── reagent.net_builder.categorical_dqn.rst │ ├── reagent.net_builder.continuous_actor.rst │ ├── reagent.net_builder.discrete_actor.rst │ ├── reagent.net_builder.discrete_dqn.rst │ ├── reagent.net_builder.parametric_dqn.rst │ ├── reagent.net_builder.quantile_dqn.rst │ ├── reagent.net_builder.rst │ ├── reagent.net_builder.slate_ranking.rst │ ├── reagent.net_builder.slate_reward.rst │ ├── reagent.net_builder.synthetic_reward.rst │ ├── reagent.net_builder.value.rst │ ├── reagent.ope.datasets.rst │ ├── reagent.ope.estimators.rst │ ├── reagent.ope.rst │ ├── reagent.ope.test.rst │ ├── reagent.ope.test.unit_tests.rst │ ├── reagent.ope.trainers.rst │ ├── reagent.optimizer.rst │ ├── reagent.prediction.ranking.rst │ ├── reagent.prediction.rst │ ├── 
reagent.prediction.synthetic_reward.rst │ ├── reagent.preprocessing.rst │ ├── reagent.publishers.rst │ ├── reagent.replay_memory.rst │ ├── reagent.reporting.rst │ ├── reagent.rst │ ├── reagent.samplers.rst │ ├── reagent.scripts.rst │ ├── reagent.training.cb.rst │ ├── reagent.training.cfeval.rst │ ├── reagent.training.gradient_free.rst │ ├── reagent.training.ranking.rst │ ├── reagent.training.rst │ ├── reagent.training.world_model.rst │ ├── reagent.validators.rst │ └── reagent.workflow.rst ├── build.sh ├── conf.py ├── continuous_integration.rst ├── distributed.rst ├── index.rst ├── installation.rst ├── license.rst ├── rasp_tutorial.rst └── usage.rst ├── logo ├── horizon_banner.png ├── horizon_logo.png ├── horizon_logo_256.png ├── horizon_logo_inverted.png ├── reagent_banner.png └── reagent_logo.png ├── preprocessing ├── pom.xml └── src │ ├── main │ └── scala │ │ └── com │ │ └── facebook │ │ └── spark │ │ └── rl │ │ ├── Constants.scala │ │ ├── Helper.scala │ │ ├── MultiStepTimeline.scala │ │ ├── Timeline.scala │ │ └── Udfs.scala │ └── test │ └── scala │ └── com │ └── facebook │ └── spark │ ├── common │ └── testutil │ │ ├── PipelineTester.scala │ │ ├── TestLogger.scala │ │ └── TestLogging.scala │ └── rl │ └── TimelineTest.scala ├── pyproject.toml ├── rasp_requirements.txt ├── reagent ├── __init__.py ├── core │ ├── __init__.py │ ├── aggregators.py │ ├── base_dataclass.py │ ├── configuration.py │ ├── dataclasses.py │ ├── debug_on_error.py │ ├── fb_checker.py │ ├── multiprocess_utils.py │ ├── observers.py │ ├── oss_tensorboard_logger.py │ ├── parameters.py │ ├── parameters_seq2slate.py │ ├── registry_meta.py │ ├── report_utils.py │ ├── result_registries.py │ ├── result_types.py │ ├── running_stats.py │ ├── tagged_union.py │ ├── tensorboardX.py │ ├── torch_utils.py │ ├── tracker.py │ ├── types.py │ └── utils.py ├── data │ ├── __init__.py │ ├── data_fetcher.py │ ├── manual_data_module.py │ ├── oss_data_fetcher.py │ ├── reagent_data_module.py │ └── spark_utils.py ├── 
evaluation │ ├── __init__.py │ ├── cb │ │ ├── __init__.py │ │ ├── base_evaluator.py │ │ ├── policy_evaluator.py │ │ ├── run_synthetic_bandit.py │ │ ├── synthetic_contextual_bandit_data.py │ │ └── utils.py │ ├── cpe.py │ ├── doubly_robust_estimator.py │ ├── evaluation_data_page.py │ ├── evaluator.py │ ├── feature_importance │ │ ├── __init__.py │ │ ├── feature_importance_base.py │ │ └── feature_importance_perturbation.py │ ├── ope_adapter.py │ ├── sequential_doubly_robust_estimator.py │ ├── weighted_sequential_doubly_robust_estimator.py │ └── world_model_evaluator.py ├── gym │ ├── __init__.py │ ├── agents │ │ ├── __init__.py │ │ ├── agent.py │ │ └── post_step.py │ ├── datasets │ │ ├── __init__.py │ │ ├── episodic_dataset.py │ │ └── replay_buffer_dataset.py │ ├── envs │ │ ├── __init__.py │ │ ├── changing_arms.py │ │ ├── dynamics │ │ │ ├── __init__.py │ │ │ └── linear_dynamics.py │ │ ├── env_wrapper.py │ │ ├── functionality │ │ │ ├── __init__.py │ │ │ └── possible_actions_mask_tester.py │ │ ├── gym.py │ │ ├── oracle_pvm.py │ │ ├── pomdp │ │ │ ├── __init__.py │ │ │ ├── pocman.py │ │ │ ├── state_embed_env.py │ │ │ ├── string_game.py │ │ │ └── string_game_v1.py │ │ ├── recsim.py │ │ ├── toy_vm.py │ │ ├── utils.py │ │ └── wrappers │ │ │ ├── __init__.py │ │ │ ├── recsim.py │ │ │ └── simple_minigrid.py │ ├── normalizers.py │ ├── policies │ │ ├── __init__.py │ │ ├── policy.py │ │ ├── predictor_policies.py │ │ ├── random_policies.py │ │ ├── samplers │ │ │ ├── __init__.py │ │ │ ├── continuous_sampler.py │ │ │ ├── discrete_sampler.py │ │ │ └── top_k_sampler.py │ │ └── scorers │ │ │ ├── __init__.py │ │ │ ├── continuous_scorer.py │ │ │ ├── discrete_scorer.py │ │ │ └── slate_q_scorer.py │ ├── preprocessors │ │ ├── __init__.py │ │ ├── default_preprocessors.py │ │ ├── replay_buffer_inserters.py │ │ └── trainer_preprocessor.py │ ├── runners │ │ ├── __init__.py │ │ └── gymrunner.py │ ├── tests │ │ ├── __init__.py │ │ ├── configs │ │ │ ├── cartpole │ │ │ │ ├── 
discrete_c51_cartpole_online.yaml │ │ │ │ ├── discrete_crr_cartpole_online.yaml │ │ │ │ ├── discrete_dqn_cartpole_online.yaml │ │ │ │ ├── discrete_ppo_cartpole_online.yaml │ │ │ │ ├── discrete_qr_cartpole_online.yaml │ │ │ │ ├── discrete_reinforce_cartpole_online.yaml │ │ │ │ ├── parametric_dqn_cartpole_online.yaml │ │ │ │ └── parametric_sarsa_cartpole_online.yaml │ │ │ ├── functionality │ │ │ │ └── dqn_possible_actions_mask.yaml │ │ │ ├── open_gridworld │ │ │ │ └── discrete_dqn_open_gridworld.yaml │ │ │ ├── pendulum │ │ │ │ ├── continuous_crr_pendulum_online.yaml │ │ │ │ ├── sac_pendulum_online.yaml │ │ │ │ └── td3_pendulum_online.yaml │ │ │ ├── recsim │ │ │ │ ├── slate_q_recsim_online.yaml │ │ │ │ ├── slate_q_recsim_online_maxq_topk.yaml │ │ │ │ ├── slate_q_recsim_online_multi_selection.yaml │ │ │ │ ├── slate_q_recsim_online_multi_selection_avg_curr.yaml │ │ │ │ └── slate_q_recsim_online_with_time_scale.yaml │ │ │ ├── sparse │ │ │ │ └── discrete_dqn_changing_arms_online.yaml │ │ │ └── world_model │ │ │ │ ├── cartpole_features.yaml │ │ │ │ ├── cem_cartpole_offline.yaml │ │ │ │ ├── cem_many_world_models_linear_dynamics_offline.yaml │ │ │ │ ├── cem_single_world_model_linear_dynamics_offline.yaml │ │ │ │ ├── discrete_dqn_string.yaml │ │ │ │ └── seq2reward_test.yaml │ │ ├── preprocessors │ │ │ ├── __init__.py │ │ │ ├── test_default_preprocessors.py │ │ │ └── test_replay_buffer_inserters.py │ │ ├── test_epsilon_greedy_action_sampler.py │ │ ├── test_gym.py │ │ ├── test_gym_datasets.py │ │ ├── test_gym_offline.py │ │ ├── test_gym_replay_buffer.py │ │ ├── test_linear_dynamics.py │ │ ├── test_pomdp.py │ │ └── test_world_model.py │ ├── types.py │ └── utils.py ├── lite │ ├── __init__.py │ └── optimizer.py ├── mab │ ├── __init__.py │ ├── mab_algorithm.py │ ├── simulation.py │ ├── thompson_sampling.py │ └── ucb.py ├── model_managers │ ├── __init__.py │ ├── actor_critic │ │ ├── __init__.py │ │ ├── sac.py │ │ └── td3.py │ ├── actor_critic_base.py │ ├── discrete │ │ ├── 
__init__.py │ │ ├── discrete_c51dqn.py │ │ ├── discrete_crr.py │ │ ├── discrete_dqn.py │ │ └── discrete_qrdqn.py │ ├── discrete_dqn_base.py │ ├── model_based │ │ ├── __init__.py │ │ ├── cross_entropy_method.py │ │ ├── seq2reward_model.py │ │ ├── synthetic_reward.py │ │ └── world_model.py │ ├── model_manager.py │ ├── parametric │ │ ├── __init__.py │ │ └── parametric_dqn.py │ ├── parametric_dqn_base.py │ ├── policy_gradient │ │ ├── __init__.py │ │ ├── ppo.py │ │ └── reinforce.py │ ├── ranking │ │ ├── __init__.py │ │ └── slate_q.py │ ├── slate_q_base.py │ ├── union.py │ └── world_model_base.py ├── model_utils │ ├── __init__.py │ └── seq2slate_utils.py ├── models │ ├── __init__.py │ ├── actor.py │ ├── base.py │ ├── bcq.py │ ├── categorical_dqn.py │ ├── cb_base_model.py │ ├── cb_fully_connected_network.py │ ├── cem_planner.py │ ├── containers.py │ ├── convolutional_network.py │ ├── critic.py │ ├── deep_represent_linucb.py │ ├── disjoint_linucb_predictor.py │ ├── dqn.py │ ├── dueling_q_network.py │ ├── embedding_bag_concat.py │ ├── fully_connected_network.py │ ├── linear_regression.py │ ├── mab.py │ ├── mdn_rnn.py │ ├── mlp_scorer.py │ ├── model_feature_config_provider.py │ ├── no_soft_update_embedding.py │ ├── probabilistic_fully_connected_network.py │ ├── residual_wrapper.py │ ├── seq2reward_model.py │ ├── seq2slate.py │ ├── seq2slate_reward.py │ ├── sparse_dqn.py │ ├── synthetic_reward.py │ ├── synthetic_reward_sparse_arch.py │ └── world_model.py ├── net_builder │ ├── __init__.py │ ├── categorical_dqn │ │ ├── __init__.py │ │ └── categorical.py │ ├── categorical_dqn_net_builder.py │ ├── continuous_actor │ │ ├── __init__.py │ │ ├── dirichlet_fully_connected.py │ │ ├── fully_connected.py │ │ └── gaussian_fully_connected.py │ ├── continuous_actor_net_builder.py │ ├── discrete_actor │ │ ├── __init__.py │ │ └── fully_connected.py │ ├── discrete_actor_net_builder.py │ ├── discrete_dqn │ │ ├── __init__.py │ │ ├── dueling.py │ │ ├── fully_connected.py │ │ └── 
fully_connected_with_embedding.py │ ├── discrete_dqn_net_builder.py │ ├── parametric_dqn │ │ ├── __init__.py │ │ └── fully_connected.py │ ├── parametric_dqn_net_builder.py │ ├── quantile_dqn │ │ ├── __init__.py │ │ ├── dueling_quantile.py │ │ └── quantile.py │ ├── quantile_dqn_net_builder.py │ ├── slate_ranking │ │ ├── __init__.py │ │ ├── slate_ranking_scorer.py │ │ └── slate_ranking_transformer.py │ ├── slate_ranking_net_builder.py │ ├── slate_reward │ │ ├── __init__.py │ │ ├── slate_reward_gru.py │ │ └── slate_reward_transformer.py │ ├── slate_reward_net_builder.py │ ├── synthetic_reward │ │ ├── __init__.py │ │ ├── ngram_synthetic_reward.py │ │ ├── sequence_synthetic_reward.py │ │ ├── single_step_synthetic_reward.py │ │ ├── single_step_synthetic_reward_sparse_arch.py │ │ └── transformer_synthetic_reward.py │ ├── synthetic_reward_net_builder.py │ ├── unions.py │ ├── value │ │ ├── __init__.py │ │ ├── fully_connected.py │ │ └── seq2reward_rnn.py │ └── value_net_builder.py ├── notebooks │ ├── PPO_for_CartPole_Control.ipynb │ └── REINFORCE_for_CartPole_Control.ipynb ├── ope │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ └── logged_dataset.py │ ├── estimators │ │ ├── __init__.py │ │ ├── contextual_bandits_estimators.py │ │ ├── estimator.py │ │ ├── sequential_estimators.py │ │ ├── slate_estimators.py │ │ └── types.py │ ├── test │ │ ├── __init__.py │ │ ├── cartpole.py │ │ ├── configs │ │ │ ├── ecoli_config.json │ │ │ ├── letter_recog_config.json │ │ │ ├── mslr_web30k_config.json │ │ │ ├── optdigits_config.json │ │ │ ├── pendigits_config.json │ │ │ ├── satimage_config.json │ │ │ └── yandex_web_search_config.json │ │ ├── data │ │ │ ├── ecoli.data │ │ │ ├── ecoli.names │ │ │ ├── letter-recognition.data │ │ │ ├── letter-recognition.names │ │ │ ├── optdigits.data │ │ │ ├── optdigits.names │ │ │ ├── pendigits.data │ │ │ ├── pendigits.names │ │ │ ├── satimage.data │ │ │ └── satimage.names │ │ ├── envs.py │ │ ├── gridworld.py │ │ ├── mslr_slate.py │ │ ├── 
multiclass_bandits.py │ │ ├── notebooks │ │ │ ├── CartpoleExperiments.ipynb │ │ │ ├── GridWorldExperiments.ipynb │ │ │ ├── contextual_bandit_experiments.ipynb │ │ │ ├── contextual_bandit_randomized_experiments.ipynb │ │ │ └── img │ │ │ │ ├── bias.png │ │ │ │ ├── rmse.png │ │ │ │ └── variance.png │ │ ├── unit_tests │ │ │ ├── __init__.py │ │ │ ├── test_contextual_bandit_estimators.py │ │ │ ├── test_slate_estimators.py │ │ │ ├── test_types.py │ │ │ └── test_utils.py │ │ └── yandex_web_search.py │ ├── trainers │ │ ├── __init__.py │ │ ├── linear_trainers.py │ │ └── rl_tabular_trainers.py │ └── utils.py ├── optimizer │ ├── __init__.py │ ├── optimizer.py │ ├── scheduler.py │ ├── scheduler_union.py │ ├── soft_update.py │ ├── uninferrable_optimizers.py │ ├── uninferrable_schedulers.py │ ├── union.py │ └── utils.py ├── prediction │ ├── __init__.py │ ├── cfeval │ │ ├── __init__.py │ │ └── predictor_wrapper.py │ ├── predictor_wrapper.py │ ├── ranking │ │ ├── __init__.py │ │ └── predictor_wrapper.py │ └── synthetic_reward │ │ ├── __init__.py │ │ └── synthetic_reward_predictor_wrapper.py ├── preprocessing │ ├── __init__.py │ ├── batch_preprocessor.py │ ├── identify_types.py │ ├── normalization.py │ ├── postprocessor.py │ ├── preprocessor.py │ ├── sparse_preprocessor.py │ ├── sparse_to_dense.py │ ├── transforms.py │ └── types.py ├── publishers │ ├── __init__.py │ ├── file_system_publisher.py │ ├── model_publisher.py │ ├── no_publishing.py │ └── union.py ├── replay_memory │ ├── __init__.py │ ├── circular_replay_buffer.py │ ├── prioritized_replay_buffer.py │ ├── sum_tree.py │ └── utils.py ├── reporting │ ├── __init__.py │ ├── actor_critic_reporter.py │ ├── compound_reporter.py │ ├── discrete_crr_reporter.py │ ├── discrete_dqn_reporter.py │ ├── parametric_dqn_reporter.py │ ├── reporter_base.py │ ├── reward_network_reporter.py │ ├── seq2reward_reporter.py │ ├── slate_q_reporter.py │ ├── td3_reporter.py │ └── world_model_reporter.py ├── samplers │ ├── __init__.py │ └── frechet.py ├── 
scripts │ ├── __init__.py │ └── hparam_tuning.py ├── test │ ├── __init__.py │ ├── base │ │ ├── __init__.py │ │ ├── horizon_test_base.py │ │ ├── test_tensorboardX.py │ │ ├── test_types.py │ │ ├── test_utils.py │ │ └── utils.py │ ├── core │ │ ├── __init__.py │ │ ├── aggregators_test.py │ │ ├── test_config_parsing.py │ │ ├── test_utils.py │ │ └── tracker_test.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── cb │ │ │ ├── __init__.py │ │ │ ├── test_integration.py │ │ │ ├── test_policy_evaluator.py │ │ │ ├── test_synthetic_contextual_bandit.py │ │ │ └── test_utils.py │ │ ├── test_evaluation_data_page.py │ │ └── test_ope_integration.py │ ├── lite │ │ ├── __init__.py │ │ └── test_combo_optimizer.py │ ├── mab │ │ ├── __init__.py │ │ └── test_mab.py │ ├── models │ │ ├── __init__.py │ │ ├── test_actor.py │ │ ├── test_bcq.py │ │ ├── test_cb_fully_connected.py │ │ ├── test_critic.py │ │ ├── test_deep_represent_linucb_model.py │ │ ├── test_disjoint_linucb_predictor.py │ │ ├── test_dqn.py │ │ ├── test_dueling_q_network.py │ │ ├── test_linear_regression_ucb.py │ │ ├── test_mab.py │ │ ├── test_no_soft_update_embedding.py │ │ ├── test_residual_wrapper.py │ │ ├── test_sparse_dqn_net.py │ │ ├── test_synthetic_reward_net.py │ │ └── test_utils.py │ ├── net_builder │ │ ├── __init__.py │ │ ├── test_continuous_actor_net_builder.py │ │ ├── test_discrete_dqn_net_builder.py │ │ ├── test_parametric_dqn_net_builder.py │ │ ├── test_synthetic_reward_net_builder.py │ │ └── test_value_net_builder.py │ ├── optimizer │ │ ├── __init__.py │ │ └── test_make_optimizer.py │ ├── prediction │ │ ├── __init__.py │ │ ├── test_model_with_preprocessor.py │ │ ├── test_prediction_utils.py │ │ └── test_predictor_wrapper.py │ ├── preprocessing │ │ ├── __init__.py │ │ ├── preprocessing_util.py │ │ ├── test_postprocessing.py │ │ ├── test_preprocessing.py │ │ ├── test_sparse_to_dense.py │ │ ├── test_transforms.py │ │ └── test_type_identification.py │ ├── ranking │ │ ├── __init__.py │ │ ├── seq2slate_utils.py │ │ ├── 
test_seq2slate_inference.py │ │ ├── test_seq2slate_off_policy.py │ │ ├── test_seq2slate_on_policy.py │ │ ├── test_seq2slate_simulation.py │ │ └── test_seq2slate_trainer.py │ ├── replay_memory │ │ ├── __init__.py │ │ ├── circular_replay_buffer_test.py │ │ ├── create_from_env_test.py │ │ ├── extra_replay_buffer_test.py │ │ ├── prioritized_replay_buffer_test.py │ │ └── sum_tree_test.py │ ├── samplers │ │ ├── __init__.py │ │ └── test_frechet_sort.py │ ├── simulators │ │ └── __init__.py │ ├── test_data │ │ └── ex_mdps.py │ ├── training │ │ ├── __init__.py │ │ ├── cb │ │ │ ├── __init__.py │ │ │ ├── test_deep_represent_linucb.py │ │ │ ├── test_disjoint_linucb.py │ │ │ ├── test_linucb.py │ │ │ ├── test_supervised_trainer.py │ │ │ └── test_utils.py │ │ ├── test_ars_optimizer.py │ │ ├── test_behavioral_cloning.py │ │ ├── test_crr.py │ │ ├── test_dqn.py │ │ ├── test_dqn_base.py │ │ ├── test_multi_stage_trainer.py │ │ ├── test_ppo.py │ │ ├── test_probabilistic.py │ │ ├── test_qrdqn.py │ │ └── test_synthetic_reward_training.py │ ├── workflow │ │ ├── __init__.py │ │ ├── reagent_sql_test_base.py │ │ ├── test_data │ │ │ ├── __init__.py │ │ │ ├── continuous_action │ │ │ │ ├── action_norm.json │ │ │ │ ├── pendulum_eval.json.bz2 │ │ │ │ ├── pendulum_training.json.bz2 │ │ │ │ └── state_features_norm.json │ │ │ ├── discrete_action │ │ │ │ ├── cartpole_norm.json │ │ │ │ └── dqn_workflow.zip │ │ │ └── parametric_action │ │ │ │ ├── action_norm.json │ │ │ │ ├── cartpole_eval.json.bz2 │ │ │ │ ├── cartpole_training.json.bz2 │ │ │ │ ├── cartpole_training_data.json │ │ │ │ └── state_features_norm.json │ │ ├── test_oss_workflows.py │ │ ├── test_preprocessing.py │ │ ├── test_query_data.py │ │ └── test_query_data_parametric.py │ └── world_model │ │ ├── __init__.py │ │ ├── simulated_world_model.py │ │ ├── test_mdnrnn.py │ │ └── test_seq2reward.py ├── training │ ├── __init__.py │ ├── behavioral_cloning_trainer.py │ ├── c51_trainer.py │ ├── cb │ │ ├── README.md │ │ ├── __init__.py │ │ ├── 
base_trainer.py │ │ ├── deep_represent_linucb_trainer.py │ │ ├── disjoint_linucb_trainer.py │ │ ├── linucb_trainer.py │ │ ├── mab_trainer.py │ │ ├── supervised_trainer.py │ │ └── utils.py │ ├── cem_trainer.py │ ├── cfeval │ │ ├── __init__.py │ │ ├── bandit_reward_network_trainer.py │ │ └── bayes_by_backprop_trainer.py │ ├── discrete_crr_trainer.py │ ├── dqn_trainer.py │ ├── dqn_trainer_base.py │ ├── gradient_free │ │ ├── __init__.py │ │ ├── ars_util.py │ │ ├── es_worker.py │ │ └── evolution_pool.py │ ├── imitator_training.py │ ├── multi_stage_trainer.py │ ├── parameters.py │ ├── parametric_dqn_trainer.py │ ├── ppo_trainer.py │ ├── qrdqn_trainer.py │ ├── ranking │ │ ├── __init__.py │ │ ├── helper.py │ │ ├── seq2slate_attn_trainer.py │ │ ├── seq2slate_sim_trainer.py │ │ ├── seq2slate_tf_trainer.py │ │ └── seq2slate_trainer.py │ ├── reagent_lightning_module.py │ ├── reinforce_trainer.py │ ├── reward_network_trainer.py │ ├── rl_trainer_pytorch.py │ ├── sac_trainer.py │ ├── slate_q_trainer.py │ ├── td3_trainer.py │ ├── utils.py │ └── world_model │ │ ├── __init__.py │ │ ├── compress_model_trainer.py │ │ ├── mdnrnn_trainer.py │ │ └── seq2reward_trainer.py ├── validators │ ├── __init__.py │ ├── model_validator.py │ ├── no_validation.py │ └── union.py └── workflow │ ├── __init__.py │ ├── cli.py │ ├── env.py │ ├── gym_batch_rl.py │ ├── identify_types_flow.py │ ├── sample_configs │ ├── continuous_action │ │ └── timeline.json │ ├── discrete_action │ │ ├── dqn_example.json │ │ └── timeline.json │ ├── discrete_dqn_cartpole_offline.yaml │ ├── parametric_action │ │ ├── parametric_dqn_example.json │ │ └── timeline.json │ └── sac_pendulum_offline.yaml │ ├── training.py │ ├── training_reports.py │ ├── types.py │ └── utils.py ├── scripts └── recurring_training_sac_offline.sh ├── serving ├── CMakeLists.txt ├── README.md ├── examples │ ├── __init__.py │ └── ecommerce │ │ ├── __init__.py │ │ ├── customer_simulator.py │ │ ├── plans │ │ ├── contextual_bandit.json │ │ ├── heuristic.json │ │ 
└── multi_armed_bandit.json │ │ └── training │ │ └── contextual_bandit.yaml ├── reagent │ └── serving │ │ ├── cli │ │ ├── Main.cpp │ │ ├── Server.cpp │ │ └── Server.h │ │ ├── config │ │ ├── applications │ │ │ ├── __init__.py │ │ │ └── example │ │ │ │ ├── __init__.py │ │ │ │ └── example.py │ │ ├── builder.py │ │ ├── config.py │ │ ├── main.py │ │ ├── namespace.py │ │ ├── operators.py │ │ └── serialize.py │ │ ├── core │ │ ├── ActionValueScorer.cpp │ │ ├── ActionValueScorer.h │ │ ├── ConfigProvider.cpp │ │ ├── ConfigProvider.h │ │ ├── Containers.cpp │ │ ├── Containers.h │ │ ├── DecisionPlan.cpp │ │ ├── DecisionPlan.h │ │ ├── DecisionService.cpp │ │ ├── DecisionService.h │ │ ├── DecisionServiceException.cpp │ │ ├── DecisionServiceException.h │ │ ├── DiskConfigProvider.cpp │ │ ├── DiskConfigProvider.h │ │ ├── Headers.cpp │ │ ├── Headers.h │ │ ├── InMemoryLogJoiner.cpp │ │ ├── InMemoryLogJoiner.h │ │ ├── LocalRealTimeCounter.cpp │ │ ├── LocalRealTimeCounter.h │ │ ├── LogJoiner.cpp │ │ ├── LogJoiner.h │ │ ├── Operator.cpp │ │ ├── Operator.h │ │ ├── OperatorFactory.cpp │ │ ├── OperatorFactory.h │ │ ├── OperatorRunner.cpp │ │ ├── OperatorRunner.h │ │ ├── PytorchActionValueScorer.cpp │ │ ├── PytorchActionValueScorer.h │ │ ├── RealTimeCounter.cpp │ │ ├── RealTimeCounter.h │ │ ├── SharedParameterHandler.cpp │ │ └── SharedParameterHandler.h │ │ ├── operators │ │ ├── ActionValueScoring.cpp │ │ ├── ActionValueScoring.h │ │ ├── EpsilonGreedyRanker.cpp │ │ ├── EpsilonGreedyRanker.h │ │ ├── Expression.cpp │ │ ├── Expression.h │ │ ├── Frechet.cpp │ │ ├── Frechet.h │ │ ├── InputFromRequest.cpp │ │ ├── InputFromRequest.h │ │ ├── PropensityFit.cpp │ │ ├── PropensityFit.h │ │ ├── Softmax.cpp │ │ ├── Softmax.h │ │ ├── SoftmaxRanker.cpp │ │ ├── SoftmaxRanker.h │ │ ├── Ucb.cpp │ │ └── Ucb.h │ │ └── test │ │ ├── DecisionService_test.cpp │ │ ├── EpsilonGreedyRanker_test.cpp │ │ ├── Expression_test.cpp │ │ ├── Frechet_test.cpp │ │ ├── InMemoryLogJoiner_test.cpp │ │ ├── InputFromRequest_test.cpp 
│ │ ├── PlanProvider_test.cpp │ │ ├── PropensityFit_test.cpp │ │ ├── PytorchScoring_test.cpp │ │ ├── SoftmaxRanker_test.cpp │ │ ├── Softmax_test.cpp │ │ ├── TestHeaders.cpp │ │ ├── TestHeaders.h │ │ └── Ucb_test.cpp ├── requirements.txt ├── scripts │ ├── __init__.py │ └── rasp_to_model.py └── setup.py ├── setup.cfg ├── setup.py └── tox.ini /.codecov.yml: -------------------------------------------------------------------------------- 1 | ignore: 2 | # These are more experimental stuffs 3 | - "reagent/ope/**/*" 4 | - "reagent/training/gradient_free/**/*" 5 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "serving/external/googletest"] 2 | path = serving/external/googletest 3 | url = https://github.com/google/googletest.git 4 | [submodule "serving/external/json"] 5 | path = serving/external/nlohmann_json 6 | url = https://github.com/nlohmann/json.git 7 | [submodule "serving/external/exprtk"] 8 | path = serving/external/exprtk 9 | url = https://github.com/ArashPartow/exprtk.git 10 | [submodule "serving/external/SimpleWebServer"] 11 | path = serving/external/SimpleWebServer 12 | url = https://gitlab.com/eidheim/Simple-Web-Server.git 13 | [submodule "serving/external/cpp-taskflow"] 14 | path = serving/external/cpp-taskflow 15 | url = https://github.com/cpp-taskflow/cpp-taskflow.git 16 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | multi_line_output=3 3 | include_trailing_comma=True 4 | force_grid_wrap=0 5 | use_parentheses=True 6 | line_length=88 7 | lines_after_imports=2 8 | reverse_relative=True 9 | default_section=THIRDPARTY 10 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: 
-------------------------------------------------------------------------------- 1 | # Contributing to ReAgent 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: <https://code.facebook.com/cla> 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | ## License 26 | By contributing to ReAgent, you agree that your contributions will be licensed 27 | under the LICENSE file in the root directory of this source tree. 28 | -------------------------------------------------------------------------------- /docs/_static/empty: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/docs/_static/empty -------------------------------------------------------------------------------- /docs/api/modules.rst: -------------------------------------------------------------------------------- 1 | reagent 2 | ======= 3 | 4 | .. 
toctree:: 5 | :maxdepth: 4 6 | 7 | reagent 8 | -------------------------------------------------------------------------------- /docs/api/reagent.data.rst: -------------------------------------------------------------------------------- 1 | reagent.data package 2 | ==================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.data.data\_fetcher module 8 | --------------------------------- 9 | 10 | .. automodule:: reagent.data.data_fetcher 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.data.manual\_data\_module module 16 | ---------------------------------------- 17 | 18 | .. automodule:: reagent.data.manual_data_module 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | reagent.data.oss\_data\_fetcher module 24 | -------------------------------------- 25 | 26 | .. automodule:: reagent.data.oss_data_fetcher 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | reagent.data.reagent\_data\_module module 32 | ----------------------------------------- 33 | 34 | .. automodule:: reagent.data.reagent_data_module 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | reagent.data.spark\_utils module 40 | -------------------------------- 41 | 42 | .. automodule:: reagent.data.spark_utils 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | Module contents 48 | --------------- 49 | 50 | .. 
automodule:: reagent.data 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | -------------------------------------------------------------------------------- /docs/api/reagent.evaluation.feature_importance.rst: -------------------------------------------------------------------------------- 1 | reagent.evaluation.feature\_importance package 2 | ============================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.evaluation.feature\_importance.feature\_importance\_base module 8 | ----------------------------------------------------------------------- 9 | 10 | .. automodule:: reagent.evaluation.feature_importance.feature_importance_base 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.evaluation.feature\_importance.feature\_importance\_perturbation module 16 | ------------------------------------------------------------------------------- 17 | 18 | .. automodule:: reagent.evaluation.feature_importance.feature_importance_perturbation 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: reagent.evaluation.feature_importance 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/api/reagent.gym.agents.rst: -------------------------------------------------------------------------------- 1 | reagent.gym.agents package 2 | ========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.gym.agents.agent module 8 | ------------------------------- 9 | 10 | .. automodule:: reagent.gym.agents.agent 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.gym.agents.post\_step module 16 | ------------------------------------ 17 | 18 | .. automodule:: reagent.gym.agents.post_step 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. 
automodule:: reagent.gym.agents 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/api/reagent.gym.datasets.rst: -------------------------------------------------------------------------------- 1 | reagent.gym.datasets package 2 | ============================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.gym.datasets.episodic\_dataset module 8 | --------------------------------------------- 9 | 10 | .. automodule:: reagent.gym.datasets.episodic_dataset 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.gym.datasets.replay\_buffer\_dataset module 16 | --------------------------------------------------- 17 | 18 | .. automodule:: reagent.gym.datasets.replay_buffer_dataset 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: reagent.gym.datasets 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/api/reagent.gym.envs.dynamics.rst: -------------------------------------------------------------------------------- 1 | reagent.gym.envs.dynamics package 2 | ================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.gym.envs.dynamics.linear\_dynamics module 8 | ------------------------------------------------- 9 | 10 | .. automodule:: reagent.gym.envs.dynamics.linear_dynamics 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. 
automodule:: reagent.gym.envs.dynamics 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/api/reagent.gym.envs.functionality.rst: -------------------------------------------------------------------------------- 1 | reagent.gym.envs.functionality package 2 | ====================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.gym.envs.functionality.possible\_actions\_mask\_tester module 8 | --------------------------------------------------------------------- 9 | 10 | .. automodule:: reagent.gym.envs.functionality.possible_actions_mask_tester 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: reagent.gym.envs.functionality 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/api/reagent.gym.envs.pomdp.rst: -------------------------------------------------------------------------------- 1 | reagent.gym.envs.pomdp package 2 | ============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.gym.envs.pomdp.pocman module 8 | ------------------------------------ 9 | 10 | .. automodule:: reagent.gym.envs.pomdp.pocman 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.gym.envs.pomdp.state\_embed\_env module 16 | ----------------------------------------------- 17 | 18 | .. automodule:: reagent.gym.envs.pomdp.state_embed_env 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | reagent.gym.envs.pomdp.string\_game module 24 | ------------------------------------------ 25 | 26 | .. automodule:: reagent.gym.envs.pomdp.string_game 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | reagent.gym.envs.pomdp.string\_game\_v1 module 32 | ---------------------------------------------- 33 | 34 | .. 
automodule:: reagent.gym.envs.pomdp.string_game_v1 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | Module contents 40 | --------------- 41 | 42 | .. automodule:: reagent.gym.envs.pomdp 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | -------------------------------------------------------------------------------- /docs/api/reagent.gym.envs.wrappers.rst: -------------------------------------------------------------------------------- 1 | reagent.gym.envs.wrappers package 2 | ================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.gym.envs.wrappers.recsim module 8 | --------------------------------------- 9 | 10 | .. automodule:: reagent.gym.envs.wrappers.recsim 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.gym.envs.wrappers.simple\_minigrid module 16 | ------------------------------------------------- 17 | 18 | .. automodule:: reagent.gym.envs.wrappers.simple_minigrid 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: reagent.gym.envs.wrappers 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/api/reagent.gym.policies.rst: -------------------------------------------------------------------------------- 1 | reagent.gym.policies package 2 | ============================ 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | reagent.gym.policies.samplers 11 | reagent.gym.policies.scorers 12 | 13 | Submodules 14 | ---------- 15 | 16 | reagent.gym.policies.policy module 17 | ---------------------------------- 18 | 19 | .. automodule:: reagent.gym.policies.policy 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | 24 | reagent.gym.policies.predictor\_policies module 25 | ----------------------------------------------- 26 | 27 | .. 
automodule:: reagent.gym.policies.predictor_policies 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | 32 | reagent.gym.policies.random\_policies module 33 | -------------------------------------------- 34 | 35 | .. automodule:: reagent.gym.policies.random_policies 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | 40 | Module contents 41 | --------------- 42 | 43 | .. automodule:: reagent.gym.policies 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | -------------------------------------------------------------------------------- /docs/api/reagent.gym.policies.samplers.rst: -------------------------------------------------------------------------------- 1 | reagent.gym.policies.samplers package 2 | ===================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.gym.policies.samplers.continuous\_sampler module 8 | -------------------------------------------------------- 9 | 10 | .. automodule:: reagent.gym.policies.samplers.continuous_sampler 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.gym.policies.samplers.discrete\_sampler module 16 | ------------------------------------------------------ 17 | 18 | .. automodule:: reagent.gym.policies.samplers.discrete_sampler 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | reagent.gym.policies.samplers.top\_k\_sampler module 24 | ---------------------------------------------------- 25 | 26 | .. automodule:: reagent.gym.policies.samplers.top_k_sampler 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. 
automodule:: reagent.gym.policies.samplers 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/api/reagent.gym.policies.scorers.rst: -------------------------------------------------------------------------------- 1 | reagent.gym.policies.scorers package 2 | ==================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.gym.policies.scorers.continuous\_scorer module 8 | ------------------------------------------------------ 9 | 10 | .. automodule:: reagent.gym.policies.scorers.continuous_scorer 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.gym.policies.scorers.discrete\_scorer module 16 | ---------------------------------------------------- 17 | 18 | .. automodule:: reagent.gym.policies.scorers.discrete_scorer 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | reagent.gym.policies.scorers.slate\_q\_scorer module 24 | ---------------------------------------------------- 25 | 26 | .. automodule:: reagent.gym.policies.scorers.slate_q_scorer 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: reagent.gym.policies.scorers 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/api/reagent.gym.preprocessors.rst: -------------------------------------------------------------------------------- 1 | reagent.gym.preprocessors package 2 | ================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.gym.preprocessors.default\_preprocessors module 8 | ------------------------------------------------------- 9 | 10 | .. 
automodule:: reagent.gym.preprocessors.default_preprocessors 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.gym.preprocessors.replay\_buffer\_inserters module 16 | ---------------------------------------------------------- 17 | 18 | .. automodule:: reagent.gym.preprocessors.replay_buffer_inserters 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | reagent.gym.preprocessors.trainer\_preprocessor module 24 | ------------------------------------------------------ 25 | 26 | .. automodule:: reagent.gym.preprocessors.trainer_preprocessor 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: reagent.gym.preprocessors 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/api/reagent.gym.rst: -------------------------------------------------------------------------------- 1 | reagent.gym package 2 | =================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | reagent.gym.agents 11 | reagent.gym.datasets 12 | reagent.gym.envs 13 | reagent.gym.policies 14 | reagent.gym.preprocessors 15 | reagent.gym.runners 16 | reagent.gym.tests 17 | 18 | Submodules 19 | ---------- 20 | 21 | reagent.gym.normalizers module 22 | ------------------------------ 23 | 24 | .. automodule:: reagent.gym.normalizers 25 | :members: 26 | :undoc-members: 27 | :show-inheritance: 28 | 29 | reagent.gym.types module 30 | ------------------------ 31 | 32 | .. automodule:: reagent.gym.types 33 | :members: 34 | :undoc-members: 35 | :show-inheritance: 36 | 37 | reagent.gym.utils module 38 | ------------------------ 39 | 40 | .. automodule:: reagent.gym.utils 41 | :members: 42 | :undoc-members: 43 | :show-inheritance: 44 | 45 | Module contents 46 | --------------- 47 | 48 | .. 
automodule:: reagent.gym 49 | :members: 50 | :undoc-members: 51 | :show-inheritance: 52 | -------------------------------------------------------------------------------- /docs/api/reagent.gym.runners.rst: -------------------------------------------------------------------------------- 1 | reagent.gym.runners package 2 | =========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.gym.runners.gymrunner module 8 | ------------------------------------ 9 | 10 | .. automodule:: reagent.gym.runners.gymrunner 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: reagent.gym.runners 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/api/reagent.gym.tests.preprocessors.rst: -------------------------------------------------------------------------------- 1 | reagent.gym.tests.preprocessors package 2 | ======================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.gym.tests.preprocessors.test\_default\_preprocessors module 8 | ------------------------------------------------------------------- 9 | 10 | .. automodule:: reagent.gym.tests.preprocessors.test_default_preprocessors 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.gym.tests.preprocessors.test\_replay\_buffer\_inserters module 16 | ---------------------------------------------------------------------- 17 | 18 | .. automodule:: reagent.gym.tests.preprocessors.test_replay_buffer_inserters 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. 
automodule:: reagent.gym.tests.preprocessors 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/api/reagent.lite.rst: -------------------------------------------------------------------------------- 1 | reagent.lite package 2 | ==================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.lite.optimizer module 8 | ----------------------------- 9 | 10 | .. automodule:: reagent.lite.optimizer 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: reagent.lite 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/api/reagent.mab.rst: -------------------------------------------------------------------------------- 1 | reagent.mab package 2 | =================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.mab.mab\_algorithm module 8 | --------------------------------- 9 | 10 | .. automodule:: reagent.mab.mab_algorithm 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.mab.simulation module 16 | ----------------------------- 17 | 18 | .. automodule:: reagent.mab.simulation 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | reagent.mab.thompson\_sampling module 24 | ------------------------------------- 25 | 26 | .. automodule:: reagent.mab.thompson_sampling 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | reagent.mab.ucb module 32 | ---------------------- 33 | 34 | .. automodule:: reagent.mab.ucb 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | Module contents 40 | --------------- 41 | 42 | .. 
automodule:: reagent.mab 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | -------------------------------------------------------------------------------- /docs/api/reagent.model_managers.actor_critic.rst: -------------------------------------------------------------------------------- 1 | reagent.model\_managers.actor\_critic package 2 | ============================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.model\_managers.actor\_critic.sac module 8 | ------------------------------------------------ 9 | 10 | .. automodule:: reagent.model_managers.actor_critic.sac 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.model\_managers.actor\_critic.td3 module 16 | ------------------------------------------------ 17 | 18 | .. automodule:: reagent.model_managers.actor_critic.td3 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: reagent.model_managers.actor_critic 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/api/reagent.model_managers.discrete.rst: -------------------------------------------------------------------------------- 1 | reagent.model\_managers.discrete package 2 | ======================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.model\_managers.discrete.discrete\_c51dqn module 8 | -------------------------------------------------------- 9 | 10 | .. automodule:: reagent.model_managers.discrete.discrete_c51dqn 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.model\_managers.discrete.discrete\_crr module 16 | ----------------------------------------------------- 17 | 18 | .. 
automodule:: reagent.model_managers.discrete.discrete_crr 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | reagent.model\_managers.discrete.discrete\_dqn module 24 | ----------------------------------------------------- 25 | 26 | .. automodule:: reagent.model_managers.discrete.discrete_dqn 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | reagent.model\_managers.discrete.discrete\_qrdqn module 32 | ------------------------------------------------------- 33 | 34 | .. automodule:: reagent.model_managers.discrete.discrete_qrdqn 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | Module contents 40 | --------------- 41 | 42 | .. automodule:: reagent.model_managers.discrete 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | -------------------------------------------------------------------------------- /docs/api/reagent.model_managers.model_based.rst: -------------------------------------------------------------------------------- 1 | reagent.model\_managers.model\_based package 2 | ============================================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.model\_managers.model\_based.cross\_entropy\_method module 8 | ------------------------------------------------------------------ 9 | 10 | .. automodule:: reagent.model_managers.model_based.cross_entropy_method 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.model\_managers.model\_based.seq2reward\_model module 16 | ------------------------------------------------------------- 17 | 18 | .. automodule:: reagent.model_managers.model_based.seq2reward_model 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | reagent.model\_managers.model\_based.synthetic\_reward module 24 | ------------------------------------------------------------- 25 | 26 | .. 
automodule:: reagent.model_managers.model_based.synthetic_reward 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | reagent.model\_managers.model\_based.world\_model module 32 | -------------------------------------------------------- 33 | 34 | .. automodule:: reagent.model_managers.model_based.world_model 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | Module contents 40 | --------------- 41 | 42 | .. automodule:: reagent.model_managers.model_based 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | -------------------------------------------------------------------------------- /docs/api/reagent.model_managers.parametric.rst: -------------------------------------------------------------------------------- 1 | reagent.model\_managers.parametric package 2 | ========================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.model\_managers.parametric.parametric\_dqn module 8 | --------------------------------------------------------- 9 | 10 | .. automodule:: reagent.model_managers.parametric.parametric_dqn 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: reagent.model_managers.parametric 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/api/reagent.model_managers.policy_gradient.rst: -------------------------------------------------------------------------------- 1 | reagent.model\_managers.policy\_gradient package 2 | ================================================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.model\_managers.policy\_gradient.ppo module 8 | --------------------------------------------------- 9 | 10 | .. 
automodule:: reagent.model_managers.policy_gradient.ppo 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.model\_managers.policy\_gradient.reinforce module 16 | --------------------------------------------------------- 17 | 18 | .. automodule:: reagent.model_managers.policy_gradient.reinforce 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: reagent.model_managers.policy_gradient 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/api/reagent.model_managers.ranking.rst: -------------------------------------------------------------------------------- 1 | reagent.model\_managers.ranking package 2 | ======================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.model\_managers.ranking.slate\_q module 8 | ----------------------------------------------- 9 | 10 | .. automodule:: reagent.model_managers.ranking.slate_q 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: reagent.model_managers.ranking 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/api/reagent.model_utils.rst: -------------------------------------------------------------------------------- 1 | reagent.model\_utils package 2 | ============================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.model\_utils.seq2slate\_utils module 8 | -------------------------------------------- 9 | 10 | .. automodule:: reagent.model_utils.seq2slate_utils 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. 
automodule:: reagent.model_utils 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/api/reagent.net_builder.categorical_dqn.rst: -------------------------------------------------------------------------------- 1 | reagent.net\_builder.categorical\_dqn package 2 | ============================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.net\_builder.categorical\_dqn.categorical module 8 | -------------------------------------------------------- 9 | 10 | .. automodule:: reagent.net_builder.categorical_dqn.categorical 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: reagent.net_builder.categorical_dqn 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/api/reagent.net_builder.continuous_actor.rst: -------------------------------------------------------------------------------- 1 | reagent.net\_builder.continuous\_actor package 2 | ============================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.net\_builder.continuous\_actor.dirichlet\_fully\_connected module 8 | ------------------------------------------------------------------------- 9 | 10 | .. automodule:: reagent.net_builder.continuous_actor.dirichlet_fully_connected 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.net\_builder.continuous\_actor.fully\_connected module 16 | -------------------------------------------------------------- 17 | 18 | .. automodule:: reagent.net_builder.continuous_actor.fully_connected 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | reagent.net\_builder.continuous\_actor.gaussian\_fully\_connected module 24 | ------------------------------------------------------------------------ 25 | 26 | .. 
automodule:: reagent.net_builder.continuous_actor.gaussian_fully_connected 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: reagent.net_builder.continuous_actor 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/api/reagent.net_builder.discrete_actor.rst: -------------------------------------------------------------------------------- 1 | reagent.net\_builder.discrete\_actor package 2 | ============================================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.net\_builder.discrete\_actor.fully\_connected module 8 | ------------------------------------------------------------ 9 | 10 | .. automodule:: reagent.net_builder.discrete_actor.fully_connected 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: reagent.net_builder.discrete_actor 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/api/reagent.net_builder.discrete_dqn.rst: -------------------------------------------------------------------------------- 1 | reagent.net\_builder.discrete\_dqn package 2 | ========================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.net\_builder.discrete\_dqn.dueling module 8 | ------------------------------------------------- 9 | 10 | .. automodule:: reagent.net_builder.discrete_dqn.dueling 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.net\_builder.discrete\_dqn.fully\_connected module 16 | ---------------------------------------------------------- 17 | 18 | .. 
automodule:: reagent.net_builder.discrete_dqn.fully_connected 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | reagent.net\_builder.discrete\_dqn.fully\_connected\_with\_embedding module 24 | --------------------------------------------------------------------------- 25 | 26 | .. automodule:: reagent.net_builder.discrete_dqn.fully_connected_with_embedding 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: reagent.net_builder.discrete_dqn 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/api/reagent.net_builder.parametric_dqn.rst: -------------------------------------------------------------------------------- 1 | reagent.net\_builder.parametric\_dqn package 2 | ============================================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.net\_builder.parametric\_dqn.fully\_connected module 8 | ------------------------------------------------------------ 9 | 10 | .. automodule:: reagent.net_builder.parametric_dqn.fully_connected 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: reagent.net_builder.parametric_dqn 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/api/reagent.net_builder.quantile_dqn.rst: -------------------------------------------------------------------------------- 1 | reagent.net\_builder.quantile\_dqn package 2 | ========================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.net\_builder.quantile\_dqn.dueling\_quantile module 8 | ----------------------------------------------------------- 9 | 10 | .. 
automodule:: reagent.net_builder.quantile_dqn.dueling_quantile 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.net\_builder.quantile\_dqn.quantile module 16 | -------------------------------------------------- 17 | 18 | .. automodule:: reagent.net_builder.quantile_dqn.quantile 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: reagent.net_builder.quantile_dqn 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/api/reagent.net_builder.slate_ranking.rst: -------------------------------------------------------------------------------- 1 | reagent.net\_builder.slate\_ranking package 2 | =========================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.net\_builder.slate\_ranking.slate\_ranking\_scorer module 8 | ----------------------------------------------------------------- 9 | 10 | .. automodule:: reagent.net_builder.slate_ranking.slate_ranking_scorer 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.net\_builder.slate\_ranking.slate\_ranking\_transformer module 16 | ---------------------------------------------------------------------- 17 | 18 | .. automodule:: reagent.net_builder.slate_ranking.slate_ranking_transformer 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. 
automodule:: reagent.net_builder.slate_ranking 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/api/reagent.net_builder.slate_reward.rst: -------------------------------------------------------------------------------- 1 | reagent.net\_builder.slate\_reward package 2 | ========================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.net\_builder.slate\_reward.slate\_reward\_gru module 8 | ------------------------------------------------------------ 9 | 10 | .. automodule:: reagent.net_builder.slate_reward.slate_reward_gru 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.net\_builder.slate\_reward.slate\_reward\_transformer module 16 | -------------------------------------------------------------------- 17 | 18 | .. automodule:: reagent.net_builder.slate_reward.slate_reward_transformer 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: reagent.net_builder.slate_reward 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/api/reagent.net_builder.value.rst: -------------------------------------------------------------------------------- 1 | reagent.net\_builder.value package 2 | ================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.net\_builder.value.fully\_connected module 8 | -------------------------------------------------- 9 | 10 | .. automodule:: reagent.net_builder.value.fully_connected 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.net\_builder.value.seq2reward\_rnn module 16 | ------------------------------------------------- 17 | 18 | .. 
automodule:: reagent.net_builder.value.seq2reward_rnn 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: reagent.net_builder.value 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/api/reagent.ope.datasets.rst: -------------------------------------------------------------------------------- 1 | reagent.ope.datasets package 2 | ============================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.ope.datasets.logged\_dataset module 8 | ------------------------------------------- 9 | 10 | .. automodule:: reagent.ope.datasets.logged_dataset 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: reagent.ope.datasets 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/api/reagent.ope.estimators.rst: -------------------------------------------------------------------------------- 1 | reagent.ope.estimators package 2 | ============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.ope.estimators.contextual\_bandits\_estimators module 8 | ------------------------------------------------------------- 9 | 10 | .. automodule:: reagent.ope.estimators.contextual_bandits_estimators 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.ope.estimators.estimator module 16 | --------------------------------------- 17 | 18 | .. automodule:: reagent.ope.estimators.estimator 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | reagent.ope.estimators.sequential\_estimators module 24 | ---------------------------------------------------- 25 | 26 | .. 
automodule:: reagent.ope.estimators.sequential_estimators 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | reagent.ope.estimators.slate\_estimators module 32 | ----------------------------------------------- 33 | 34 | .. automodule:: reagent.ope.estimators.slate_estimators 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | reagent.ope.estimators.types module 40 | ----------------------------------- 41 | 42 | .. automodule:: reagent.ope.estimators.types 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | Module contents 48 | --------------- 49 | 50 | .. automodule:: reagent.ope.estimators 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | -------------------------------------------------------------------------------- /docs/api/reagent.ope.rst: -------------------------------------------------------------------------------- 1 | reagent.ope package 2 | =================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | reagent.ope.datasets 11 | reagent.ope.estimators 12 | reagent.ope.test 13 | reagent.ope.trainers 14 | 15 | Submodules 16 | ---------- 17 | 18 | reagent.ope.utils module 19 | ------------------------ 20 | 21 | .. automodule:: reagent.ope.utils 22 | :members: 23 | :undoc-members: 24 | :show-inheritance: 25 | 26 | Module contents 27 | --------------- 28 | 29 | .. automodule:: reagent.ope 30 | :members: 31 | :undoc-members: 32 | :show-inheritance: 33 | -------------------------------------------------------------------------------- /docs/api/reagent.ope.test.unit_tests.rst: -------------------------------------------------------------------------------- 1 | reagent.ope.test.unit\_tests package 2 | ==================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.ope.test.unit\_tests.test\_contextual\_bandit\_estimators module 8 | ------------------------------------------------------------------------ 9 | 10 | .. 
automodule:: reagent.ope.test.unit_tests.test_contextual_bandit_estimators 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.ope.test.unit\_tests.test\_slate\_estimators module 16 | ----------------------------------------------------------- 17 | 18 | .. automodule:: reagent.ope.test.unit_tests.test_slate_estimators 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | reagent.ope.test.unit\_tests.test\_types module 24 | ----------------------------------------------- 25 | 26 | .. automodule:: reagent.ope.test.unit_tests.test_types 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | reagent.ope.test.unit\_tests.test\_utils module 32 | ----------------------------------------------- 33 | 34 | .. automodule:: reagent.ope.test.unit_tests.test_utils 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | Module contents 40 | --------------- 41 | 42 | .. automodule:: reagent.ope.test.unit_tests 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | -------------------------------------------------------------------------------- /docs/api/reagent.ope.trainers.rst: -------------------------------------------------------------------------------- 1 | reagent.ope.trainers package 2 | ============================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.ope.trainers.linear\_trainers module 8 | -------------------------------------------- 9 | 10 | .. automodule:: reagent.ope.trainers.linear_trainers 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.ope.trainers.rl\_tabular\_trainers module 16 | ------------------------------------------------- 17 | 18 | .. automodule:: reagent.ope.trainers.rl_tabular_trainers 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. 
automodule:: reagent.ope.trainers 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/api/reagent.prediction.ranking.rst: -------------------------------------------------------------------------------- 1 | reagent.prediction.ranking package 2 | ================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.prediction.ranking.predictor\_wrapper module 8 | ---------------------------------------------------- 9 | 10 | .. automodule:: reagent.prediction.ranking.predictor_wrapper 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: reagent.prediction.ranking 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/api/reagent.prediction.rst: -------------------------------------------------------------------------------- 1 | reagent.prediction package 2 | ========================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | reagent.prediction.ranking 11 | reagent.prediction.synthetic_reward 12 | 13 | Submodules 14 | ---------- 15 | 16 | reagent.prediction.predictor\_wrapper module 17 | -------------------------------------------- 18 | 19 | .. automodule:: reagent.prediction.predictor_wrapper 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | 24 | Module contents 25 | --------------- 26 | 27 | .. 
automodule:: reagent.prediction 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /docs/api/reagent.prediction.synthetic_reward.rst: -------------------------------------------------------------------------------- 1 | reagent.prediction.synthetic\_reward package 2 | ============================================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.prediction.synthetic\_reward.synthetic\_reward\_predictor\_wrapper module 8 | --------------------------------------------------------------------------------- 9 | 10 | .. automodule:: reagent.prediction.synthetic_reward.synthetic_reward_predictor_wrapper 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: reagent.prediction.synthetic_reward 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/api/reagent.publishers.rst: -------------------------------------------------------------------------------- 1 | reagent.publishers package 2 | ========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.publishers.file\_system\_publisher module 8 | ------------------------------------------------- 9 | 10 | .. automodule:: reagent.publishers.file_system_publisher 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.publishers.model\_publisher module 16 | ------------------------------------------ 17 | 18 | .. automodule:: reagent.publishers.model_publisher 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | reagent.publishers.no\_publishing module 24 | ---------------------------------------- 25 | 26 | .. 
automodule:: reagent.publishers.no_publishing 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | reagent.publishers.union module 32 | ------------------------------- 33 | 34 | .. automodule:: reagent.publishers.union 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | Module contents 40 | --------------- 41 | 42 | .. automodule:: reagent.publishers 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | -------------------------------------------------------------------------------- /docs/api/reagent.replay_memory.rst: -------------------------------------------------------------------------------- 1 | reagent.replay\_memory package 2 | ============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.replay\_memory.circular\_replay\_buffer module 8 | ------------------------------------------------------ 9 | 10 | .. automodule:: reagent.replay_memory.circular_replay_buffer 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.replay\_memory.prioritized\_replay\_buffer module 16 | --------------------------------------------------------- 17 | 18 | .. automodule:: reagent.replay_memory.prioritized_replay_buffer 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | reagent.replay\_memory.sum\_tree module 24 | --------------------------------------- 25 | 26 | .. automodule:: reagent.replay_memory.sum_tree 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | reagent.replay\_memory.utils module 32 | ----------------------------------- 33 | 34 | .. automodule:: reagent.replay_memory.utils 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | Module contents 40 | --------------- 41 | 42 | .. 
automodule:: reagent.replay_memory 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | -------------------------------------------------------------------------------- /docs/api/reagent.rst: -------------------------------------------------------------------------------- 1 | reagent package 2 | =============== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | reagent.core 11 | reagent.data 12 | reagent.evaluation 13 | reagent.gym 14 | reagent.lite 15 | reagent.mab 16 | reagent.model_managers 17 | reagent.model_utils 18 | reagent.models 19 | reagent.net_builder 20 | reagent.ope 21 | reagent.optimizer 22 | reagent.prediction 23 | reagent.preprocessing 24 | reagent.publishers 25 | reagent.replay_memory 26 | reagent.reporting 27 | reagent.samplers 28 | reagent.scripts 29 | reagent.training 30 | reagent.validators 31 | reagent.workflow 32 | 33 | Module contents 34 | --------------- 35 | 36 | .. automodule:: reagent 37 | :members: 38 | :undoc-members: 39 | :show-inheritance: 40 | -------------------------------------------------------------------------------- /docs/api/reagent.samplers.rst: -------------------------------------------------------------------------------- 1 | reagent.samplers package 2 | ======================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.samplers.frechet module 8 | ------------------------------- 9 | 10 | .. automodule:: reagent.samplers.frechet 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. 
automodule:: reagent.samplers 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/api/reagent.scripts.rst: -------------------------------------------------------------------------------- 1 | reagent.scripts package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.scripts.hparam\_tuning module 8 | ------------------------------------- 9 | 10 | .. automodule:: reagent.scripts.hparam_tuning 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: reagent.scripts 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/api/reagent.training.cb.rst: -------------------------------------------------------------------------------- 1 | reagent.training.cb package 2 | =========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.training.cb.linucb\_trainer module 8 | ------------------------------------------ 9 | 10 | .. automodule:: reagent.training.cb.linucb_trainer 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: reagent.training.cb 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/api/reagent.training.cfeval.rst: -------------------------------------------------------------------------------- 1 | reagent.training.cfeval package 2 | =============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.training.cfeval.bandit\_reward\_network\_trainer module 8 | --------------------------------------------------------------- 9 | 10 | .. 
automodule:: reagent.training.cfeval.bandit_reward_network_trainer 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: reagent.training.cfeval 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/api/reagent.training.gradient_free.rst: -------------------------------------------------------------------------------- 1 | reagent.training.gradient\_free package 2 | ======================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.training.gradient\_free.ars\_util module 8 | ------------------------------------------------ 9 | 10 | .. automodule:: reagent.training.gradient_free.ars_util 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.training.gradient\_free.es\_worker module 16 | ------------------------------------------------- 17 | 18 | .. automodule:: reagent.training.gradient_free.es_worker 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | reagent.training.gradient\_free.evolution\_pool module 24 | ------------------------------------------------------ 25 | 26 | .. automodule:: reagent.training.gradient_free.evolution_pool 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: reagent.training.gradient_free 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/api/reagent.training.world_model.rst: -------------------------------------------------------------------------------- 1 | reagent.training.world\_model package 2 | ===================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.training.world\_model.compress\_model\_trainer module 8 | ------------------------------------------------------------- 9 | 10 | .. 
automodule:: reagent.training.world_model.compress_model_trainer 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.training.world\_model.mdnrnn\_trainer module 16 | ---------------------------------------------------- 17 | 18 | .. automodule:: reagent.training.world_model.mdnrnn_trainer 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | reagent.training.world\_model.seq2reward\_trainer module 24 | -------------------------------------------------------- 25 | 26 | .. automodule:: reagent.training.world_model.seq2reward_trainer 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: reagent.training.world_model 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/api/reagent.validators.rst: -------------------------------------------------------------------------------- 1 | reagent.validators package 2 | ========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | reagent.validators.model\_validator module 8 | ------------------------------------------ 9 | 10 | .. automodule:: reagent.validators.model_validator 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | reagent.validators.no\_validation module 16 | ---------------------------------------- 17 | 18 | .. automodule:: reagent.validators.no_validation 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | reagent.validators.union module 24 | ------------------------------- 25 | 26 | .. automodule:: reagent.validators.union 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. 
automodule:: reagent.validators 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | rm -rf api/* && rm -rf ~/github/HorizonDocs && sphinx-build -b html -E -v . ~/github/HorizonDocs 3 | -------------------------------------------------------------------------------- /docs/continuous_integration.rst: -------------------------------------------------------------------------------- 1 | .. _continuous_integration: 2 | 3 | Continuous Integration 4 | ====================== 5 | 6 | We have CI set up on `CircleCI <https://circleci.com/gh/facebookresearch/ReAgent>`_. 7 | It's a pretty basic setup. You should follow the local testing instructions in ``.circleci/config.yml``. 8 | -------------------------------------------------------------------------------- /logo/horizon_banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/logo/horizon_banner.png -------------------------------------------------------------------------------- /logo/horizon_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/logo/horizon_logo.png -------------------------------------------------------------------------------- /logo/horizon_logo_256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/logo/horizon_logo_256.png -------------------------------------------------------------------------------- /logo/horizon_logo_inverted.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/logo/horizon_logo_inverted.png -------------------------------------------------------------------------------- /logo/reagent_banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/logo/reagent_banner.png -------------------------------------------------------------------------------- /logo/reagent_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/logo/reagent_logo.png -------------------------------------------------------------------------------- /preprocessing/src/main/scala/com/facebook/spark/rl/Udfs.scala: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
2 | package com.facebook.spark.rl 3 | 4 | import org.apache.spark.sql.functions.coalesce 5 | import org.apache.spark.sql.functions.udf 6 | 7 | object Udfs { 8 | def unionListOfMaps[A, B](input: Seq[Map[A, B]]): Map[A, B] = 9 | input.flatten.toMap 10 | 11 | def prepend[A](x: A, arr: Seq[A]): Seq[A] = 12 | x +: arr 13 | 14 | def sort_list_of_map[B](x: Seq[Map[Long, B]]): Seq[B] = 15 | x.sortBy(_.head._1).map(_.head._2) 16 | 17 | def drop_last[A](x: Seq[A]): Seq[A] = 18 | x.dropRight(1) 19 | 20 | val emptyMap = udf(() => Map.empty[Long, Double]) 21 | val emptyMapOfIds = udf(() => Map.empty[Long, Seq[Long]]) 22 | val emptyMapOfMap = udf(() => Map.empty[Long, Map[Long, Double]]) 23 | val emptyMapOfArrOfMap = udf(() => Map.empty[Long, Seq[Map[Long, Double]]]) 24 | val emptyStr = udf(() => "") 25 | val emptyArrOfLong = udf(() => Array.empty[Long]) 26 | val emptyArrOfStr = udf(() => Array.empty[String]) 27 | val emptyArrOfDbl = udf(() => Array.empty[Double]) 28 | val emptyArrOfMap = udf(() => Array.empty[Map[Long, Double]]) 29 | val emptyArrOfMapStr = udf(() => Array.empty[Map[String, Double]]) 30 | 31 | } 32 | -------------------------------------------------------------------------------- /preprocessing/src/test/scala/com/facebook/spark/common/testutil/TestLogger.scala: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
2 | package com.facebook.spark.common.testutil 3 | 4 | import org.slf4j.{Logger, LoggerFactory} 5 | 6 | trait TestLogger { 7 | lazy val log: Logger = LoggerFactory.getLogger(this.getClass.getName) 8 | } 9 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools >= 42", 4 | "setuptools_scm[toml] >= 3.4", 5 | "wheel" 6 | ] 7 | build-backend = "setuptools.build_meta" 8 | [tool.setuptools_scm] 9 | -------------------------------------------------------------------------------- /rasp_requirements.txt: -------------------------------------------------------------------------------- 1 | boost 2 | cmake 3 | gflags==2.2.2 4 | glog==0.4.0 5 | eigen==3.3.7 6 | -------------------------------------------------------------------------------- /reagent/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | # pyre-unsafe 4 | -------------------------------------------------------------------------------- /reagent/core/base_dataclass.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | 6 | """ 7 | We should revisit this at some point. Config classes shouldn't subclass from this. 
8 | """ 9 | 10 | import dataclasses 11 | from typing import cast 12 | 13 | 14 | class BaseDataClass: 15 | def _replace(self, **kwargs): 16 | return cast(type(self), dataclasses.replace(self, **kwargs)) 17 | -------------------------------------------------------------------------------- /reagent/core/debug_on_error.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | 6 | import sys 7 | 8 | 9 | def start() -> None: 10 | def info(type, value, tb): 11 | if hasattr(sys, "ps1") or not sys.stderr.isatty(): 12 | # we are in interactive mode or we don't have a tty-like 13 | # device, so we call the default hook 14 | sys.__excepthook__(type, value, tb) 15 | else: 16 | import pdb 17 | import traceback 18 | 19 | # we are NOT in interactive mode, print the exception... 20 | traceback.print_exception(type, value, tb) 21 | print 22 | # ...then start the debugger in post-mortem mode. 23 | # pdb.pm() # deprecated 24 | pdb.post_mortem(tb) # more "modern" 25 | 26 | sys.excepthook = info 27 | -------------------------------------------------------------------------------- /reagent/core/fb_checker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
3 | 4 | # pyre-unsafe 5 | import importlib.util 6 | import os 7 | 8 | 9 | def is_fb_environment() -> bool: 10 | if importlib.util.find_spec("fblearner") is not None: 11 | if not bool(int(os.environ.get("FORCE_OSS_ENVIRONMENT", False))): 12 | return True 13 | return False 14 | 15 | 16 | IS_FB_ENVIRONMENT: bool = is_fb_environment() 17 | -------------------------------------------------------------------------------- /reagent/core/report_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | # pyre-unsafe 5 | 6 | import logging 7 | from math import ceil 8 | from typing import Dict, List 9 | 10 | import numpy as np 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def get_mean_of_recent_values( 17 | values: Dict[str, List[float]], min_window_size=10 18 | ) -> Dict[str, float]: 19 | ret = {} 20 | for key, vals in values.items(): 21 | window_size = max(min_window_size, int(ceil(0.1 * len(vals)))) 22 | ret[key] = np.mean(vals[-window_size:]) 23 | return ret 24 | 25 | 26 | def calculate_recent_window_average(arr, window_size, num_entries): 27 | if len(arr) > 0: 28 | begin = max(0, len(arr) - window_size) 29 | return np.mean(np.array(arr[begin:]), axis=0) 30 | else: 31 | logger.error("Not enough samples for evaluation.") 32 | if num_entries == 1: 33 | return float("nan") 34 | else: 35 | return [float("nan")] * num_entries 36 | -------------------------------------------------------------------------------- /reagent/core/result_registries.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
3 | 4 | # pyre-unsafe 5 | 6 | from reagent.core.dataclasses import dataclass 7 | from reagent.core.registry_meta import RegistryMeta 8 | 9 | 10 | class TrainingReport(metaclass=RegistryMeta): 11 | pass 12 | 13 | 14 | @dataclass 15 | class PublishingResult(metaclass=RegistryMeta): 16 | success: bool 17 | 18 | 19 | @dataclass 20 | class ValidationResult(metaclass=RegistryMeta): 21 | should_publish: bool 22 | -------------------------------------------------------------------------------- /reagent/core/result_types.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | 6 | from reagent.core.dataclasses import dataclass 7 | from reagent.core.result_registries import PublishingResult, ValidationResult 8 | 9 | 10 | @dataclass 11 | class NoPublishingResults(PublishingResult): 12 | __registry_name__ = "no_publishing_results" 13 | 14 | 15 | @dataclass 16 | class NoValidationResults(ValidationResult): 17 | __registry_name__ = "no_validation_results" 18 | -------------------------------------------------------------------------------- /reagent/data/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | 6 | from .data_fetcher import DataFetcher 7 | from .manual_data_module import ManualDataModule 8 | from .reagent_data_module import ReAgentDataModule 9 | 10 | __all__ = [ 11 | "DataFetcher", 12 | "ManualDataModule", 13 | "ReAgentDataModule", 14 | ] 15 | -------------------------------------------------------------------------------- /reagent/data/data_fetcher.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
3 | 4 | # pyre-unsafe 5 | 6 | import logging 7 | from typing import List, Optional, Tuple 8 | 9 | # pyre-fixme[21]: Could not find module `reagent.workflow.types`. 10 | from reagent.workflow.types import Dataset, TableSpec 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class DataFetcher: 17 | def query_data( 18 | self, 19 | # pyre-fixme[11]: Annotation `TableSpec` is not defined as a type. 20 | input_table_spec: TableSpec, 21 | discrete_action: bool, 22 | actions: Optional[List[str]] = None, 23 | include_possible_actions=True, 24 | custom_reward_expression: Optional[str] = None, 25 | sample_range: Optional[Tuple[float, float]] = None, 26 | multi_steps: Optional[int] = None, 27 | gamma: Optional[float] = None, 28 | # pyre-fixme[11]: Annotation `Dataset` is not defined as a type. 29 | ) -> Dataset: 30 | raise NotImplementedError() 31 | 32 | def query_data_synthetic_reward( 33 | self, 34 | input_table_spec: TableSpec, 35 | discrete_action_names: Optional[List[str]] = None, 36 | sample_range: Optional[Tuple[float, float]] = None, 37 | max_seq_len: Optional[int] = None, 38 | ) -> Dataset: 39 | raise NotImplementedError() 40 | -------------------------------------------------------------------------------- /reagent/data/reagent_data_module.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
3 | 4 | # pyre-unsafe 5 | 6 | import abc 7 | from typing import Dict, List, Optional 8 | 9 | import pytorch_lightning as pl 10 | from reagent.core.parameters import NormalizationData 11 | 12 | 13 | class ReAgentDataModule(pl.LightningDataModule): 14 | def __init__(self) -> None: 15 | super().__init__() 16 | 17 | @abc.abstractmethod 18 | def get_normalization_data_map( 19 | self, 20 | keys: Optional[List[str]] = None, 21 | ) -> Dict[str, NormalizationData]: 22 | pass 23 | 24 | @abc.abstractproperty 25 | def train_dataset(self): 26 | pass 27 | 28 | @abc.abstractproperty 29 | def eval_dataset(self): 30 | pass 31 | 32 | @abc.abstractproperty 33 | def test_dataset(self): 34 | pass 35 | -------------------------------------------------------------------------------- /reagent/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/evaluation/cb/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/reagent/evaluation/cb/__init__.py -------------------------------------------------------------------------------- /reagent/evaluation/feature_importance/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/evaluation/feature_importance/feature_importance_base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All rights reserved. 3 | 4 | # pyre-unsafe 5 | from typing import List 6 | 7 | import pandas as pd 8 | import torch.nn as nn 9 | from reagent.core.dataclasses import dataclass 10 | 11 | 12 | @dataclass 13 | class FeatureImportanceBase: 14 | model: nn.Module 15 | sorted_feature_ids: List[int] 16 | 17 | def compute_feature_importance(self) -> pd.DataFrame: 18 | raise NotImplementedError() 19 | -------------------------------------------------------------------------------- /reagent/gym/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | 6 | from .agents.agent import Agent 7 | from .envs.gym import Gym 8 | 9 | 10 | __all__ = ["Agent", "Gym"] 11 | -------------------------------------------------------------------------------- /reagent/gym/agents/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/gym/agents/post_step.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | 6 | 7 | import logging 8 | 9 | import gym 10 | from reagent.gym.preprocessors import make_replay_buffer_inserter 11 | from reagent.gym.types import Transition 12 | from reagent.replay_memory.circular_replay_buffer import ReplayBuffer 13 | 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def add_replay_buffer_post_step( 19 | replay_buffer: ReplayBuffer, 20 | env: gym.Env, 21 | replay_buffer_inserter=None, 22 | ): 23 | """ 24 | Simply add transitions to replay_buffer. 
25 | """ 26 | 27 | if replay_buffer_inserter is None: 28 | replay_buffer_inserter = make_replay_buffer_inserter(env) 29 | 30 | def post_step(transition: Transition) -> None: 31 | replay_buffer_inserter(replay_buffer, transition) 32 | 33 | return post_step 34 | -------------------------------------------------------------------------------- /reagent/gym/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/gym/datasets/episodic_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | 6 | import logging 7 | from typing import Optional 8 | 9 | import torch 10 | from reagent.gym.agents.agent import Agent 11 | from reagent.gym.envs.gym import Gym 12 | from reagent.gym.runners.gymrunner import run_episode 13 | 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class EpisodicDataset(torch.utils.data.IterableDataset): 19 | def __init__( 20 | self, 21 | env: Gym, 22 | agent: Agent, 23 | num_episodes: int, 24 | seed: int = 0, 25 | max_steps: Optional[int] = None, 26 | ): 27 | self.env = env 28 | self.agent = agent 29 | self.num_episodes = num_episodes 30 | self.seed = seed 31 | self.max_steps = max_steps 32 | 33 | def __iter__(self): 34 | self.env.reset() 35 | for i in range(self.num_episodes): 36 | trajectory = run_episode( 37 | self.env, self.agent, max_steps=self.max_steps, mdp_id=i 38 | ) 39 | yield trajectory.to_dict() 40 | 41 | def __len__(self): 42 | return self.num_episodes 43 | -------------------------------------------------------------------------------- /reagent/gym/envs/dynamics/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | # pyre-unsafe 4 | -------------------------------------------------------------------------------- /reagent/gym/envs/functionality/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | # pyre-unsafe 4 | -------------------------------------------------------------------------------- /reagent/gym/envs/pomdp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | # pyre-unsafe 4 | -------------------------------------------------------------------------------- /reagent/gym/envs/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | 6 | import logging 7 | 8 | from gym.envs.registration import register, registry 9 | 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def register_if_not_exists(id, entry_point): 15 | """ 16 | Preventing tests from failing trying to re-register environments 17 | """ 18 | if id not in registry.env_specs: 19 | logger.debug(f"Registering id={id}, entry_point={entry_point}.") 20 | register(id=id, entry_point=entry_point) 21 | -------------------------------------------------------------------------------- /reagent/gym/envs/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
class SimpleObsWrapper(gym.core.ObservationWrapper):
    """
    Encode the agent's position & direction in a one-hot vector

    The vector has width * height * NUM_DIRECTIONS entries; exactly one of
    them is set to 1.0 for the agent's current position and direction.
    """

    def __init__(self, env) -> None:
        super().__init__(env)

        flat_size = self.env.width * self.env.height * NUM_DIRECTIONS
        self.observation_space = spaces.Box(
            low=0, high=1, shape=(flat_size,), dtype="float32"
        )

    def observation(self, obs):
        # The incoming `obs` is ignored; the encoding is rebuilt from the
        # wrapped env's agent_pos / agent_dir.
        height = self.env.height
        one_hot = np.zeros(
            (self.env.width * height * NUM_DIRECTIONS,), dtype=np.float32
        )
        pos = self.env.agent_pos
        # Same flat index as pos[0]*height*D + pos[1]*D + dir, factored.
        hot_index = (pos[0] * height + pos[1]) * NUM_DIRECTIONS + self.env.agent_dir
        one_hot[hot_index] = 1.0
        return one_hot
class Policy:
    def __init__(self, scorer: Scorer, sampler: Sampler) -> None:
        """
        A Policy chains a scorer and a sampler to produce actions.

        Args:
            scorer: maps preprocessed input to the intermediate scores
                that actions are sampled from
            sampler: turns the scorer's output into a sampled action.
        """
        self.scorer = scorer
        self.sampler = sampler

    def act(
        self, obs: Any, possible_actions_mask: Optional[torch.Tensor] = None
    ) -> rlt.ActorOutput:
        """
        Score the observation, sample an action from the scores, and return
        the result moved to CPU and detached.

        The mask is passed to the scorer as a second positional argument
        only when it is provided.

        These are the actions being put into the replay buffer, not
        necessarily the actions taken by the environment!
        """
        if possible_actions_mask is None:
            scores = self.scorer(obs)
        else:
            scores = self.scorer(obs, possible_actions_mask)
        sampled = self.sampler.sample_action(scores)
        return sampled.cpu().detach()
# reagent/gym/policies/samplers/top_k_sampler.py
class TopKSampler(Sampler):
    """
    Sampler that selects the indices of the k largest scores along dim 1.
    """

    def __init__(self, k: int) -> None:
        self.k = k

    def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
        """
        Return the top-k indices as the action.

        `log_prob` is a zero placeholder with one row per row of the
        selected indices; it is allocated on the same device as the
        indices so the output does not mix devices.
        """
        # Only the indices are used; the top values themselves are discarded.
        _, item_idxs = torch.topk(scores, self.k, dim=1)
        return rlt.ActorOutput(
            action=item_idxs,
            log_prob=torch.zeros(item_idxs.shape[0], 1, device=item_idxs.device),
        )

    def log_prob(self, scores: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


# reagent/gym/policies/scorers/continuous_scorer.py
def sac_scorer(actor_network: ModelBase) -> Scorer:
    """
    Build a scorer that returns the actor network's Gaussian parameters.

    The returned callable runs the network in eval mode with gradients
    disabled and wraps the resulting `loc` / `scale_log` in a
    `GaussianSamplerScore`.
    """

    @torch.no_grad()
    def score(preprocessed_obs: rlt.FeatureData) -> GaussianSamplerScore:
        # Remember the current mode so scoring does not force a network that
        # was in eval mode into train mode.
        # NOTE(review): relies on the nn.Module `training` flag and
        # `train(mode)` signature -- confirm ModelBase is an nn.Module.
        was_training = actor_network.training
        actor_network.eval()
        try:
            # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
            loc, scale_log = actor_network._get_loc_and_scale_log(preprocessed_obs)
        finally:
            # Restore the previous mode even if the forward call raised.
            actor_network.train(was_training)
        return GaussianSamplerScore(loc=loc, scale_log=scale_log)

    return score
2 | 3 | # pyre-unsafe 4 | -------------------------------------------------------------------------------- /reagent/gym/tests/configs/cartpole/discrete_c51_cartpole_online.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | Gym: 3 | env_name: CartPole-v1 4 | model: 5 | DiscreteC51DQN: 6 | trainer_param: 7 | actions: 8 | - 0 9 | - 1 10 | rl: 11 | gamma: 0.9 12 | target_update_rate: 0.05 13 | maxq_learning: true 14 | temperature: 1.0 15 | double_q_learning: true 16 | minibatches_per_step: 1 17 | num_atoms: 21 18 | qmin: 0 19 | qmax: 40 20 | optimizer: 21 | AdamW: 22 | lr: 0.001 23 | amsgrad: true 24 | net_builder: 25 | Categorical: 26 | sizes: 27 | - 64 28 | - 64 29 | activations: 30 | - leaky_relu 31 | - leaky_relu 32 | eval_parameters: 33 | calc_cpe_in_training: false 34 | replay_memory_size: 100000 35 | train_every_ts: 1 36 | train_after_ts: 20000 37 | num_train_episodes: 40 38 | num_eval_episodes: 20 39 | passing_score_bar: 100.0 40 | use_gpu: false 41 | minibatch_size: 512 42 | -------------------------------------------------------------------------------- /reagent/gym/tests/configs/cartpole/discrete_crr_cartpole_online.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | Gym: 3 | env_name: CartPole-v0 4 | model: 5 | DiscreteCRR: 6 | trainer_param: 7 | actions: 8 | - 0 9 | - 1 10 | rl: 11 | gamma: 0.99 12 | target_update_rate: 0.2 13 | temperature: 0.1 14 | q_network_optimizer: 15 | Adam: 16 | lr: 0.001 17 | actor_network_optimizer: 18 | Adam: 19 | lr: 0.001 20 | use_target_actor: false 21 | double_q_learning: true 22 | delayed_policy_update: 1 23 | actor_net_builder: 24 | FullyConnected: 25 | exploration_variance: 0.0000001 26 | sizes: 27 | - 1024 28 | - 1024 29 | activations: 30 | - relu 31 | - relu 32 | critic_net_builder: 33 | FullyConnected: 34 | sizes: 35 | - 1024 36 | - 1024 37 | activations: 38 | - relu 39 | - relu 40 | eval_parameters: 41 | 
calc_cpe_in_training: false 42 | replay_memory_size: 20000 43 | train_every_ts: 1 44 | train_after_ts: 5000 45 | num_train_episodes: 25 46 | num_eval_episodes: 20 47 | passing_score_bar: 100 48 | use_gpu: false 49 | minibatch_size: 256 50 | -------------------------------------------------------------------------------- /reagent/gym/tests/configs/cartpole/discrete_dqn_cartpole_online.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | Gym: 3 | env_name: CartPole-v0 4 | model: 5 | DiscreteDQN: 6 | trainer_param: 7 | actions: 8 | - 0 9 | - 1 10 | rl: 11 | gamma: 0.99 12 | target_update_rate: 0.2 13 | maxq_learning: true 14 | temperature: 1.0 15 | double_q_learning: true 16 | minibatches_per_step: 1 17 | optimizer: 18 | Adam: 19 | lr: 0.01 20 | net_builder: 21 | FullyConnected: 22 | sizes: 23 | - 128 24 | - 64 25 | activations: 26 | - leaky_relu 27 | - leaky_relu 28 | eval_parameters: 29 | calc_cpe_in_training: false 30 | replay_memory_size: 100000 31 | train_every_ts: 1 32 | train_after_ts: 30000 33 | num_train_episodes: 120 34 | num_eval_episodes: 20 35 | passing_score_bar: 100.0 36 | use_gpu: false 37 | minibatch_size: 512 38 | -------------------------------------------------------------------------------- /reagent/gym/tests/configs/cartpole/discrete_ppo_cartpole_online.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | Gym: 3 | env_name: CartPole-v0 4 | model: 5 | PPO: 6 | trainer_param: 7 | actions: 8 | - 0 9 | - 1 10 | gamma: 0.99 11 | ppo_epsilon: 0.2 12 | optimizer: 13 | Adam: 14 | lr: 0.001 15 | weight_decay: 0.001 16 | update_freq: 2 17 | update_epochs: 1 18 | ppo_batch_size: 2 19 | policy_net_builder: 20 | FullyConnected: 21 | sizes: 22 | - 32 23 | - 32 24 | activations: 25 | - leaky_relu 26 | - leaky_relu 27 | sampler_temperature: 1.0 28 | num_train_episodes: 1000 29 | num_eval_episodes: 100 30 | passing_score_bar: 180.0 31 | use_gpu: false 32 | 
-------------------------------------------------------------------------------- /reagent/gym/tests/configs/cartpole/discrete_qr_cartpole_online.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | Gym: 3 | env_name: CartPole-v1 4 | model: 5 | DiscreteQRDQN: 6 | trainer_param: 7 | actions: 8 | - 0 9 | - 1 10 | rl: 11 | gamma: 0.9 12 | target_update_rate: 0.05 13 | maxq_learning: true 14 | temperature: 1.0 15 | double_q_learning: true 16 | minibatches_per_step: 1 17 | num_atoms: 11 18 | optimizer: 19 | AdamW: 20 | lr: 0.001 21 | amsgrad: true 22 | net_builder: 23 | DuelingQuantile: 24 | sizes: 25 | - 64 26 | - 64 27 | activations: 28 | - leaky_relu 29 | - leaky_relu 30 | eval_parameters: 31 | calc_cpe_in_training: false 32 | replay_memory_size: 100000 33 | train_every_ts: 1 34 | train_after_ts: 20000 35 | num_train_episodes: 40 36 | num_eval_episodes: 20 37 | passing_score_bar: 100.0 38 | use_gpu: false 39 | minibatch_size: 512 40 | -------------------------------------------------------------------------------- /reagent/gym/tests/configs/cartpole/discrete_reinforce_cartpole_online.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | Gym: 3 | env_name: CartPole-v0 4 | model: 5 | Reinforce: 6 | trainer_param: 7 | actions: 8 | - 0 9 | - 1 10 | gamma: 0.99 11 | off_policy: False 12 | optimizer: 13 | Adam: 14 | lr: 0.001 15 | normalize: False 16 | subtract_mean: True 17 | policy_net_builder: 18 | FullyConnected: 19 | sizes: 20 | - 64 21 | activations: 22 | - leaky_relu 23 | sampler_temperature: 1.0 24 | num_train_episodes: 1000 25 | num_eval_episodes: 100 26 | passing_score_bar: 180.0 27 | use_gpu: false 28 | -------------------------------------------------------------------------------- /reagent/gym/tests/configs/cartpole/parametric_dqn_cartpole_online.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | Gym: 3 | env_name: 
CartPole-v1 4 | model: 5 | ParametricDQN: 6 | trainer_param: 7 | rl: 8 | gamma: 0.99 9 | target_update_rate: 0.1 10 | maxq_learning: true 11 | temperature: 1.0 12 | double_q_learning: true 13 | minibatches_per_step: 1 14 | optimizer: 15 | AdamW: 16 | lr: 0.001 17 | amsgrad: true 18 | net_builder: 19 | FullyConnected: 20 | sizes: 21 | - 128 22 | - 64 23 | activations: 24 | - leaky_relu 25 | - leaky_relu 26 | eval_parameters: 27 | calc_cpe_in_training: false 28 | replay_memory_size: 100000 29 | train_every_ts: 1 30 | train_after_ts: 20000 31 | num_train_episodes: 90 32 | num_eval_episodes: 20 33 | passing_score_bar: 100.0 34 | use_gpu: false 35 | minibatch_size: 1024 36 | -------------------------------------------------------------------------------- /reagent/gym/tests/configs/cartpole/parametric_sarsa_cartpole_online.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | Gym: 3 | env_name: CartPole-v0 4 | model: 5 | ParametricDQN: 6 | trainer_param: 7 | rl: 8 | gamma: 0.99 9 | target_update_rate: 0.2 10 | # disabling maxq for sarsa 11 | # although strictly speaking, this is still 12 | # off-policy (due to replay buffer) so not 13 | # vanilla, on-policy sarsa 14 | maxq_learning: false 15 | temperature: 0.35 16 | double_q_learning: true 17 | minibatches_per_step: 1 18 | optimizer: 19 | Adam: 20 | lr: 0.05 21 | net_builder: 22 | FullyConnected: 23 | sizes: 24 | - 64 25 | - 64 26 | activations: 27 | - leaky_relu 28 | - leaky_relu 29 | eval_parameters: 30 | calc_cpe_in_training: false 31 | replay_memory_size: 50000 32 | train_every_ts: 1 33 | train_after_ts: 25000 34 | num_train_episodes: 30 35 | num_eval_episodes: 20 36 | passing_score_bar: 100.0 37 | use_gpu: false 38 | minibatch_size: 1024 39 | -------------------------------------------------------------------------------- /reagent/gym/tests/configs/functionality/dqn_possible_actions_mask.yaml: 
-------------------------------------------------------------------------------- 1 | env: 2 | Gym: 3 | env_name: PossibleActionsMaskTester-v0 4 | model: 5 | DiscreteDQN: 6 | trainer_param: 7 | actions: 8 | - 0 9 | - 1 10 | - 2 11 | - 3 12 | rl: 13 | gamma: 1.0 14 | target_update_rate: 0.2 15 | maxq_learning: true 16 | temperature: 1.0 17 | double_q_learning: true 18 | minibatches_per_step: 1 19 | optimizer: 20 | Adam: 21 | lr: 0.05 22 | net_builder: 23 | FullyConnected: 24 | sizes: 25 | - 128 26 | - 64 27 | activations: 28 | - leaky_relu 29 | - leaky_relu 30 | eval_parameters: 31 | calc_cpe_in_training: false 32 | replay_memory_size: 5000 33 | train_every_ts: 1 34 | train_after_ts: 500 35 | num_train_episodes: 5 36 | num_eval_episodes: 3 37 | passing_score_bar: 200.0 38 | use_gpu: false 39 | minibatch_size: 512 40 | -------------------------------------------------------------------------------- /reagent/gym/tests/configs/open_gridworld/discrete_dqn_open_gridworld.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | Gym: 3 | env_name: MiniGrid-Empty-5x5-v0 4 | model: 5 | DiscreteDQN: 6 | trainer_param: 7 | actions: 8 | - 101 9 | - 102 10 | - 103 11 | - 104 12 | - 105 13 | - 106 14 | - 107 15 | rl: 16 | gamma: 0.99 17 | epsilon: 0.05 18 | target_update_rate: 0.1 19 | maxq_learning: true 20 | temperature: 0.01 21 | q_network_loss: mse 22 | double_q_learning: true 23 | minibatches_per_step: 1 24 | optimizer: 25 | Adam: 26 | lr: 0.01 27 | weight_decay: 0.01 28 | net_builder: 29 | FullyConnected: 30 | sizes: [] 31 | activations: [] 32 | eval_parameters: 33 | calc_cpe_in_training: false 34 | replay_memory_size: 2000 35 | train_every_ts: 3 36 | train_after_ts: 1 37 | num_train_episodes: 125 38 | num_eval_episodes: 20 39 | passing_score_bar: 0.9 40 | use_gpu: false 41 | minibatch_size: 512 42 | -------------------------------------------------------------------------------- 
/reagent/gym/tests/configs/pendulum/sac_pendulum_online.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | Gym: 3 | env_name: Pendulum-v0 4 | model: 5 | SAC: 6 | trainer_param: 7 | rl: 8 | gamma: 0.99 9 | target_update_rate: 0.005 10 | softmax_policy: true 11 | entropy_temperature: 0.3 12 | q_network_optimizer: 13 | Adam: 14 | lr: 0.001 15 | value_network_optimizer: 16 | Adam: 17 | lr: 0.001 18 | actor_network_optimizer: 19 | Adam: 20 | lr: 0.001 21 | alpha_optimizer: 22 | Adam: 23 | lr: 0.001 24 | actor_net_builder: 25 | GaussianFullyConnected: 26 | sizes: 27 | - 64 28 | - 64 29 | activations: 30 | - leaky_relu 31 | - leaky_relu 32 | critic_net_builder: 33 | FullyConnected: 34 | sizes: 35 | - 64 36 | - 64 37 | activations: 38 | - leaky_relu 39 | - leaky_relu 40 | value_net_builder: 41 | FullyConnected: 42 | sizes: 43 | - 64 44 | - 64 45 | activations: 46 | - leaky_relu 47 | - leaky_relu 48 | eval_parameters: 49 | calc_cpe_in_training: false 50 | replay_memory_size: 100000 51 | train_every_ts: 1 52 | train_after_ts: 20000 53 | num_train_episodes: 40 54 | num_eval_episodes: 20 55 | # Though maximal score is 0, we set lower bar to let tests finish in time 56 | passing_score_bar: -500 57 | use_gpu: false 58 | minibatch_size: 256 59 | -------------------------------------------------------------------------------- /reagent/gym/tests/configs/pendulum/td3_pendulum_online.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | Gym: 3 | env_name: Pendulum-v0 4 | model: 5 | TD3: 6 | trainer_param: 7 | rl: 8 | gamma: 0.99 9 | target_update_rate: 0.005 10 | q_network_optimizer: 11 | Adam: 12 | lr: 0.01 13 | actor_network_optimizer: 14 | Adam: 15 | lr: 0.005 16 | noise_variance: 0.2 17 | noise_clip: 0.5 18 | delayed_policy_update: 2 19 | actor_net_builder: 20 | FullyConnected: 21 | exploration_variance: 0.01 22 | sizes: 23 | - 64 24 | - 64 25 | activations: 26 | - leaky_relu 27 | - 
leaky_relu 28 | critic_net_builder: 29 | FullyConnected: 30 | sizes: 31 | - 64 32 | - 64 33 | activations: 34 | - leaky_relu 35 | - leaky_relu 36 | eval_parameters: 37 | calc_cpe_in_training: false 38 | replay_memory_size: 100000 39 | train_every_ts: 1 40 | train_after_ts: 5000 41 | num_train_episodes: 40 42 | num_eval_episodes: 1 43 | # Though maximal score is 0, we set lower bar to let tests finish in time 44 | passing_score_bar: -750 45 | use_gpu: false 46 | minibatch_size: 256 47 | -------------------------------------------------------------------------------- /reagent/gym/tests/configs/recsim/slate_q_recsim_online.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | RecSim: 3 | slate_size: 3 4 | num_candidates: 10 5 | model: 6 | SlateQ: 7 | slate_size: 3 8 | num_candidates: 10 9 | slate_feature_id: 1 # filler 10 | slate_score_id: [42, 42] # filler 11 | trainer_param: 12 | optimizer: 13 | Adam: 14 | lr: 0.001 15 | net_builder: 16 | FullyConnected: 17 | sizes: 18 | - 64 19 | - 64 20 | activations: 21 | - leaky_relu 22 | - leaky_relu 23 | replay_memory_size: 100000 24 | train_every_ts: 1 25 | train_after_ts: 5000 26 | num_train_episodes: 300 27 | num_eval_episodes: 20 28 | passing_score_bar: 154.0 29 | use_gpu: false 30 | minibatch_size: 1024 31 | -------------------------------------------------------------------------------- /reagent/gym/tests/configs/recsim/slate_q_recsim_online_maxq_topk.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | RecSim: 3 | slate_size: 3 4 | num_candidates: 10 5 | model: 6 | SlateQ: 7 | slate_size: 3 8 | num_candidates: 10 9 | slate_feature_id: 1 # filler 10 | slate_score_id: [42, 42] # filler 11 | trainer_param: 12 | rl: 13 | maxq_learning: True 14 | optimizer: 15 | Adam: 16 | lr: 0.001 17 | net_builder: 18 | FullyConnected: 19 | sizes: 20 | - 64 21 | - 64 22 | activations: 23 | - leaky_relu 24 | - leaky_relu 25 | replay_memory_size: 
100000 26 | train_every_ts: 1 27 | train_after_ts: 5000 28 | num_train_episodes: 300 29 | num_eval_episodes: 20 30 | passing_score_bar: 154.0 31 | use_gpu: false 32 | minibatch_size: 1024 33 | -------------------------------------------------------------------------------- /reagent/gym/tests/configs/recsim/slate_q_recsim_online_multi_selection.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | RecSim: 3 | slate_size: 3 4 | num_candidates: 10 5 | model: 6 | SlateQ: 7 | slate_size: 3 8 | num_candidates: 10 9 | slate_feature_id: 1 # filler 10 | slate_score_id: [42, 42] # filler 11 | trainer_param: 12 | single_selection: False 13 | next_slate_value_norm_method: "norm_by_next_slate_size" 14 | optimizer: 15 | Adam: 16 | lr: 0.001 17 | net_builder: 18 | FullyConnected: 19 | sizes: 20 | - 64 21 | - 64 22 | activations: 23 | - leaky_relu 24 | - leaky_relu 25 | replay_memory_size: 100000 26 | train_every_ts: 1 27 | train_after_ts: 5000 28 | num_train_episodes: 300 29 | num_eval_episodes: 20 30 | passing_score_bar: 154.0 31 | use_gpu: false 32 | minibatch_size: 1024 33 | -------------------------------------------------------------------------------- /reagent/gym/tests/configs/recsim/slate_q_recsim_online_multi_selection_avg_curr.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | RecSim: 3 | slate_size: 3 4 | num_candidates: 10 5 | model: 6 | SlateQ: 7 | slate_size: 3 8 | num_candidates: 10 9 | slate_feature_id: 1 # filler 10 | slate_score_id: [42, 42] # filler 11 | trainer_param: 12 | single_selection: False 13 | next_slate_value_norm_method: "norm_by_current_slate_size" 14 | optimizer: 15 | Adam: 16 | lr: 0.001 17 | net_builder: 18 | FullyConnected: 19 | sizes: 20 | - 64 21 | - 64 22 | activations: 23 | - leaky_relu 24 | - leaky_relu 25 | replay_memory_size: 100000 26 | train_every_ts: 1 27 | train_after_ts: 5000 28 | num_train_episodes: 300 29 | num_eval_episodes: 20 30 | 
passing_score_bar: 154.0 31 | use_gpu: false 32 | minibatch_size: 1024 33 | -------------------------------------------------------------------------------- /reagent/gym/tests/configs/recsim/slate_q_recsim_online_with_time_scale.yaml: -------------------------------------------------------------------------------- 1 | env: 2 | RecSim: 3 | slate_size: 3 4 | num_candidates: 10 5 | model: 6 | SlateQ: 7 | slate_size: 3 8 | num_candidates: 10 9 | slate_feature_id: 1 # filler 10 | slate_score_id: [42, 42] # filler 11 | trainer_param: 12 | discount_time_scale: 2 13 | optimizer: 14 | Adam: 15 | lr: 0.001 16 | net_builder: 17 | FullyConnected: 18 | sizes: 19 | - 64 20 | - 64 21 | activations: 22 | - leaky_relu 23 | - leaky_relu 24 | replay_memory_size: 100000 25 | train_every_ts: 1 26 | train_after_ts: 5000 27 | num_train_episodes: 300 28 | num_eval_episodes: 20 29 | passing_score_bar: 154.0 30 | use_gpu: false 31 | minibatch_size: 1024 32 | -------------------------------------------------------------------------------- /reagent/gym/tests/configs/world_model/cartpole_features.yaml: -------------------------------------------------------------------------------- 1 | env_name: CartPole-v0 2 | model: 3 | WorldModel: 4 | trainer_param: 5 | hidden_size: 50 6 | num_hidden_layers: 2 7 | learning_rate: 0.001 8 | not_terminal_loss_weight: 1 9 | next_state_loss_weight: 1 10 | reward_loss_weight: 1 11 | num_gaussians: 1 12 | num_train_transitions: 100000 # approx. 500 episodes 13 | num_test_transitions: 6000 # approx. 
30 episodes 14 | seq_len: 1 15 | batch_size: 1024 16 | num_train_epochs: 30 17 | use_gpu: false 18 | saved_mdnrnn_path: null 19 | -------------------------------------------------------------------------------- /reagent/gym/tests/configs/world_model/cem_cartpole_offline.yaml: -------------------------------------------------------------------------------- 1 | env_name: CartPole-v0 2 | model: 3 | CrossEntropyMethod: 4 | trainer_param: 5 | plan_horizon_length: 10 6 | num_world_models: 1 7 | cem_population_size: 100 8 | cem_num_iterations: 10 9 | ensemble_population_size: 1 10 | num_elites: 15 11 | mdnrnn: 12 | hidden_size: 100 13 | num_hidden_layers: 2 14 | learning_rate: 0.001 15 | not_terminal_loss_weight: 200.0 16 | next_state_loss_weight: 1.0 17 | reward_loss_weight: 1.0 18 | num_gaussians: 1 19 | rl: 20 | gamma: 1.0 21 | softmax_policy: 0 22 | replay_memory_size: 200000 23 | num_batches_per_epoch: 1000 24 | num_train_epochs: 1 25 | num_eval_episodes: 1 26 | passing_score_bar: 100.0 27 | minibatch_size: 1024 28 | use_gpu: false 29 | -------------------------------------------------------------------------------- /reagent/gym/tests/configs/world_model/cem_many_world_models_linear_dynamics_offline.yaml: -------------------------------------------------------------------------------- 1 | env_name: LinearDynamics-v0 2 | model: 3 | CrossEntropyMethod: 4 | trainer_param: 5 | plan_horizon_length: 4 6 | num_world_models: 2 7 | cem_population_size: 100 8 | cem_num_iterations: 10 9 | ensemble_population_size: 1 10 | num_elites: 15 11 | mdnrnn: 12 | hidden_size: 100 13 | num_hidden_layers: 2 14 | learning_rate: 0.001 15 | not_terminal_loss_weight: 0.0 16 | next_state_loss_weight: 1.0 17 | reward_loss_weight: 1.0 18 | num_gaussians: 1 19 | rl: 20 | gamma: 1.0 21 | softmax_policy: 0 22 | replay_memory_size: 50000 23 | num_batches_per_epoch: 5000 24 | num_train_epochs: 1 25 | num_eval_episodes: 1 26 | passing_score_bar: -2.5 27 | minibatch_size: 1024 28 | use_gpu: false 29 | 
-------------------------------------------------------------------------------- /reagent/gym/tests/configs/world_model/cem_single_world_model_linear_dynamics_offline.yaml: -------------------------------------------------------------------------------- 1 | env_name: LinearDynamics-v0 2 | model: 3 | CrossEntropyMethod: 4 | trainer_param: 5 | plan_horizon_length: 4 6 | num_world_models: 1 7 | cem_population_size: 100 8 | cem_num_iterations: 10 9 | ensemble_population_size: 1 10 | num_elites: 15 11 | mdnrnn: 12 | hidden_size: 100 13 | num_hidden_layers: 2 14 | learning_rate: 0.001 15 | not_terminal_loss_weight: 0.0 16 | next_state_loss_weight: 1.0 17 | reward_loss_weight: 1.0 18 | num_gaussians: 1 19 | rl: 20 | gamma: 1.0 21 | softmax_policy: 0 22 | minibatch_size: 1024 23 | replay_memory_size: 50000 24 | num_batches_per_epoch: 5000 25 | num_train_epochs: 1 26 | num_eval_episodes: 1 27 | passing_score_bar: -2.5 28 | use_gpu: false 29 | -------------------------------------------------------------------------------- /reagent/gym/tests/configs/world_model/discrete_dqn_string.yaml: -------------------------------------------------------------------------------- 1 | env_name: StringGame-v0 2 | 3 | # for training embedding model 4 | embedding_model: 5 | WorldModel: 6 | trainer_param: 7 | hidden_size: 20 8 | num_hidden_layers: 2 9 | learning_rate: 0.001 10 | not_terminal_loss_weight: 0 11 | next_state_loss_weight: 0 12 | reward_loss_weight: 1 13 | num_gaussians: 1 14 | seq_len: 3 15 | batch_size: 1024 16 | num_embedding_train_transitions: 24000 # approx. 
4000 episodes 17 | num_embedding_train_epochs: 15 18 | saved_mdnrnn_path: null 19 | 20 | # for training agent 21 | num_state_embed_transitions: 6000 # approx 1000 episodes 22 | train_model: 23 | DiscreteDQN: 24 | trainer_param: 25 | actions: 26 | - 0 27 | - 1 28 | rl: 29 | gamma: 0.99 30 | target_update_rate: 0.1 31 | maxq_learning: true 32 | q_network_loss: mse 33 | double_q_learning: true 34 | minibatch_size: 1024 35 | minibatches_per_step: 1 36 | optimizer: 37 | Adam: 38 | lr: 0.001 39 | net_builder: 40 | FullyConnected: 41 | sizes: 42 | - 128 43 | - 64 44 | activations: 45 | - leaky_relu 46 | - leaky_relu 47 | eval_parameters: 48 | calc_cpe_in_training: false 49 | num_agent_train_epochs: 100 50 | num_agent_eval_epochs: 10 51 | use_gpu: false 52 | # highest score, which requires history insight, is 10.0 53 | passing_score_bar: 10.0 54 | -------------------------------------------------------------------------------- /reagent/gym/tests/configs/world_model/seq2reward_test.yaml: -------------------------------------------------------------------------------- 1 | env_name: StringGame-v0 2 | model: 3 | Seq2RewardModel: 4 | trainer_param: 5 | learning_rate: 0.005 6 | multi_steps: 6 7 | action_names: ["0","1"] 8 | num_train_transitions: 100000 # approx. 500 episodes 9 | num_test_transitions: 6000 # approx. 30 episodes 10 | seq_len: 6 11 | batch_size: 1024 12 | num_train_epochs: 20 13 | use_gpu: false 14 | saved_seq2reward_path: null 15 | -------------------------------------------------------------------------------- /reagent/gym/tests/preprocessors/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/lite/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/mab/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | # pyre-unsafe 4 | -------------------------------------------------------------------------------- /reagent/model_managers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/model_managers/actor_critic/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | 6 | from .sac import SAC 7 | from .td3 import TD3 8 | 9 | 10 | __all__ = ["SAC", "TD3"] 11 | -------------------------------------------------------------------------------- /reagent/model_managers/discrete/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
3 | 4 | # pyre-unsafe 5 | 6 | from .discrete_c51dqn import DiscreteC51DQN 7 | from .discrete_crr import DiscreteCRR 8 | from .discrete_dqn import DiscreteDQN 9 | from .discrete_qrdqn import DiscreteQRDQN 10 | 11 | __all__ = ["DiscreteC51DQN", "DiscreteDQN", "DiscreteQRDQN", "DiscreteCRR"] 12 | -------------------------------------------------------------------------------- /reagent/model_managers/model_based/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | 6 | from .cross_entropy_method import CrossEntropyMethod 7 | from .seq2reward_model import Seq2RewardModel 8 | from .synthetic_reward import SyntheticReward 9 | from .world_model import WorldModel 10 | 11 | 12 | __all__ = ["WorldModel", "CrossEntropyMethod", "Seq2RewardModel", "SyntheticReward"] 13 | -------------------------------------------------------------------------------- /reagent/model_managers/parametric/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | 6 | from .parametric_dqn import ParametricDQN 7 | 8 | 9 | __all__ = ["ParametricDQN"] 10 | -------------------------------------------------------------------------------- /reagent/model_managers/policy_gradient/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
3 | 4 | # pyre-unsafe 5 | 6 | from .ppo import PPO 7 | from .reinforce import Reinforce 8 | 9 | __all__ = ["Reinforce", "PPO"] 10 | -------------------------------------------------------------------------------- /reagent/model_managers/ranking/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | 6 | from .slate_q import SlateQ 7 | 8 | 9 | __all__ = ["SlateQ"] 10 | -------------------------------------------------------------------------------- /reagent/model_utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
3 | 4 | # pyre-unsafe 5 | 6 | from .actor import ( 7 | DirichletFullyConnectedActor, 8 | FullyConnectedActor, 9 | GaussianFullyConnectedActor, 10 | ) 11 | from .base import ModelBase 12 | from .bcq import BatchConstrainedDQN 13 | from .categorical_dqn import CategoricalDQN 14 | from .containers import Sequential 15 | from .critic import FullyConnectedCritic 16 | from .dqn import FullyConnectedDQN 17 | from .dueling_q_network import DuelingQNetwork, ParametricDuelingQNetwork 18 | from .embedding_bag_concat import EmbeddingBagConcat 19 | from .fully_connected_network import FullyConnectedNetwork 20 | from .mlp_scorer import MLPScorer 21 | from .seq2reward_model import Seq2RewardNetwork 22 | 23 | 24 | __all__ = [ 25 | "ModelBase", 26 | "Sequential", 27 | "FullyConnectedDQN", 28 | "DuelingQNetwork", 29 | "ParametricDuelingQNetwork", 30 | "BatchConstrainedDQN", 31 | "CategoricalDQN", 32 | "EmbeddingBagConcat", 33 | "FullyConnectedNetwork", 34 | "FullyConnectedCritic", 35 | "GaussianFullyConnectedActor", 36 | "DirichletFullyConnectedActor", 37 | "FullyConnectedActor", 38 | "MLPScorer", 39 | "Seq2RewardNetwork", 40 | ] 41 | -------------------------------------------------------------------------------- /reagent/models/bcq.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
class BatchConstrainedDQN(ModelBase):
    """
    Q-network wrapper that penalizes actions an imitator network deems unlikely.

    The wrapped ``q_network`` scores the state as usual; the
    ``imitator_network`` is run on ``state.float_features`` and its softmax
    output (taken over ``dim=1``) decides which entries get a large negative
    penalty added to their Q-value.
    """

    def __init__(
        self, state_dim, q_network, imitator_network, bcq_drop_threshold
    ) -> None:
        """
        Args:
            state_dim: must be strictly positive (checked with an assert).
            q_network: module called with the full state in ``forward``.
            imitator_network: module called with ``state.float_features``.
            bcq_drop_threshold: entries whose probability, relative to the
                per-row maximum, falls below this value are penalized.
        """
        super().__init__()
        assert state_dim > 0, "state_dim must be > 0, got {}".format(state_dim)
        self.state_dim = state_dim
        # Keep q_network before imitator_network so submodule registration
        # order is unchanged.
        self.q_network = q_network
        self.imitator_network = imitator_network
        # Constant added to the Q-value of every dropped action.
        self.invalid_action_penalty = -1e10
        self.bcq_drop_threshold = bcq_drop_threshold

    def input_prototype(self):
        """Delegate to the wrapped Q-network's input prototype."""
        return self.q_network.input_prototype()

    def forward(self, state: rlt.FeatureData):
        """
        Return the Q-network output with ``invalid_action_penalty`` added to
        every entry whose relative imitator probability is below
        ``bcq_drop_threshold``.
        """
        scores = self.q_network(state)
        logits = self.imitator_network(state.float_features)
        probs = torch.nn.functional.softmax(logits, dim=1)
        # Normalize each row by its largest probability so the best entry is 1.
        row_max, _ = probs.max(keepdim=True, dim=1)
        relative_probs = probs / row_max
        dropped = (relative_probs < self.bcq_drop_threshold).float()
        return scores + self.invalid_action_penalty * dropped
class UCBBaseModel(torch.nn.Module, ABC):
    """
    Abstract base class for UCB-style CB models.

    Subclasses implement ``forward``; ``forward_inference`` is the hook used
    by the inference wrapper and defaults to calling ``forward``.
    """

    def __init__(self, input_dim: int):
        """Store ``input_dim``, which sizes the tensor from ``input_prototype``."""
        super().__init__()
        self.input_dim = input_dim

    def input_prototype(self) -> torch.Tensor:
        """Return a random example input of shape ``(1, input_dim)``."""
        example_shape = (1, self.input_dim)
        return torch.randn(*example_shape)

    @abstractmethod
    def forward(
        self, inp: torch.Tensor, ucb_alpha: Optional[float] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Model forward pass.
        Returns pred_label, pred_sigma, ucb (where ucb = pred_label + ucb_alpha*pred_sigma)
        """
        # Intentionally empty: concrete models provide the implementation.

    def forward_inference(
        self, inp: torch.Tensor, ucb_alpha: Optional[float] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Forward method invoked by the inference wrapper.

        Behaves exactly like ``forward()`` unless a subclass overrides it to
        customize inference-time behavior.
        """
        outputs = self.forward(inp, ucb_alpha=ucb_alpha)
        return outputs
class MLPScorer(ModelBase):
    """
    Scores ranking candidates with a wrapped MLP.

    Original note from the authors: log-space in and out.
    """

    def __init__(
        self,
        mlp: torch.nn.Module,
        has_user_feat: bool = False,
    ) -> None:
        """
        Args:
            mlp: module applied to the ranking state.
            has_user_feat: forwarded to ``obs.get_ranking_state`` in ``forward``.
        """
        super().__init__()
        self.mlp = mlp
        self.has_user_feat = has_user_feat

    def forward(self, obs: rlt.FeatureData):
        """Run the MLP on the ranking state and squeeze the last dimension."""
        ranking_state = obs.get_ranking_state(self.has_user_feat)
        return self.mlp(ranking_state).squeeze(-1)

    def input_prototype(self):
        """Build a small random ``FeatureData`` sample (sizes are arbitrary)."""
        # Sample config for input
        batch_size, state_dim, num_docs, candidate_dim = 2, 5, 3, 4
        # State features are drawn before candidate features so the random
        # number stream is consumed in the same order as before.
        state_features = torch.randn((batch_size, state_dim))
        doc_features = torch.randn(batch_size, num_docs, candidate_dim)
        docs = rlt.DocList(float_features=doc_features)
        return rlt.FeatureData(
            float_features=state_features,
            candidate_docs=docs,
        )
class ResidualWrapper(nn.Module):
    """
    Adds a skip connection around a single wrapped layer.

    The output is the input plus the wrapped module's output, so the wrapped
    module must produce something that can be added to its input.

    Example:
        layers = []
        for layer in layer_generator:
            layers.append(ResidualWrapper(layer))
        model = torch.nn.Sequential(*layers)
    """

    def __init__(self, module: nn.Module):
        """Register ``module`` as the wrapped residual branch."""
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x + module(x)``."""
        branch_output = self.module(x)
        return x + branch_output
@dataclass
class Dueling(DiscreteDQNNetBuilder):
    """
    Net builder that produces a fully connected ``DuelingQNetwork``.

    ``sizes`` and ``activations`` are parallel lists and must have the same
    length; both are passed straight to
    ``DuelingQNetwork.make_fully_connected``.
    """

    __hash__ = param_hash

    sizes: List[int] = field(default_factory=lambda: [256, 128])
    activations: List[str] = field(default_factory=lambda: ["relu", "relu"])

    def __post_init_post_parse__(self) -> None:
        """Check that every entry in ``sizes`` has a matching activation."""
        num_sizes = len(self.sizes)
        num_activations = len(self.activations)
        assert num_sizes == num_activations, (
            f"Must have the same numbers of sizes and activations; got: "
            f"{self.sizes}, {self.activations}"
        )

    def build_q_network(
        self,
        state_feature_config: rlt.ModelFeatureConfig,
        state_normalization_data: NormalizationData,
        output_dim: int,
    ) -> ModelBase:
        """
        Build the dueling Q-network.

        The input width comes from ``self._get_input_dim`` applied to
        ``state_normalization_data``; ``state_feature_config`` is not used here.
        """
        input_dim = self._get_input_dim(state_normalization_data)
        network = DuelingQNetwork.make_fully_connected(
            input_dim, output_dim, self.sizes, self.activations
        )
        return network
class SlateRankingNetBuilder:
    """
    Base class for slate ranking network builders.

    NOTE: the class does not use ``abc.ABCMeta``, so the ``abstractmethod``
    marker below is advisory only and does not block instantiation.
    """

    @abc.abstractmethod
    def build_slate_ranking_network(
        self, state_dim, candidate_dim, candidate_size, slate_size
    ) -> torch.nn.Module:
        """
        Build and return the slate ranking network.

        Subclasses are expected to override this; the base implementation
        does nothing and returns ``None``.
        """
        return None
3 | 4 | # pyre-unsafe 5 | 6 | from typing import Optional 7 | 8 | from reagent.core.registry_meta import wrap_oss_with_dataclass 9 | from reagent.core.tagged_union import TaggedUnion 10 | 11 | from .slate_reward_gru import SlateRewardGRU as SlateRewardGRUType 12 | from .slate_reward_transformer import ( 13 | SlateRewardTransformer as SlateRewardTransformerType, 14 | ) 15 | 16 | 17 | @wrap_oss_with_dataclass 18 | class SlateRewardNetBuilder__Union(TaggedUnion): 19 | SlateRewardGRU: Optional[SlateRewardGRUType] = None 20 | SlateRewardTransformer: Optional[SlateRewardTransformerType] = None 21 | -------------------------------------------------------------------------------- /reagent/net_builder/slate_reward/slate_reward_gru.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | 6 | from reagent.core.dataclasses import dataclass, field 7 | from reagent.core.parameters import GRUParameters, param_hash 8 | from reagent.models.base import ModelBase 9 | from reagent.models.seq2slate_reward import Seq2SlateGRURewardNet 10 | from reagent.net_builder.slate_reward_net_builder import SlateRewardNetBuilder 11 | 12 | 13 | @dataclass 14 | class SlateRewardGRU(SlateRewardNetBuilder): 15 | __hash__ = param_hash 16 | 17 | gru: GRUParameters = field( 18 | default_factory=lambda: GRUParameters(dim_model=16, num_stacked_layers=2) 19 | ) 20 | fit_slate_wise_reward: bool = True 21 | 22 | def build_slate_reward_network( 23 | self, state_dim, candidate_dim, candidate_size, slate_size 24 | ) -> ModelBase: 25 | seq2slate_reward_net = Seq2SlateGRURewardNet( 26 | state_dim=state_dim, 27 | candidate_dim=candidate_dim, 28 | num_stacked_layers=self.gru.num_stacked_layers, 29 | dim_model=self.gru.dim_model, 30 | max_src_seq_len=candidate_size, 31 | max_tgt_seq_len=slate_size, 32 | ) 33 | return seq2slate_reward_net 34 | 35 | @property 36 | 
class SlateRewardNetBuilder:
    """
    Base class for slate reward network builder.

    NOTE: the class does not use ``abc.ABCMeta``, so the abstract markers
    below are advisory only and do not block instantiation.
    """

    @abc.abstractmethod
    def build_slate_reward_network(
        self, state_dim, candidate_dim, candidate_size, slate_size
    ) -> torch.nn.Module:
        """
        Build and return the slate reward network.

        Subclasses are expected to override this; the base implementation
        does nothing and returns ``None``.
        """
        pass

    # ``abc.abstractproperty`` has been deprecated since Python 3.3; stacking
    # ``property`` on ``abc.abstractmethod`` is the documented replacement and
    # keeps the ``__isabstractmethod__`` flag set.
    @property
    @abc.abstractmethod
    def expect_slate_wise_reward(self) -> bool:
        """
        Flag that subclasses override as a property.

        The base implementation returns ``None``.
        """
        pass
@dataclass
class Seq2RewardNetBuilder(ValueNetBuilder):
    """
    Value-network builder that constructs a ``Seq2RewardNetwork``.

    The state dimension is derived from the normalization data; the other
    constructor arguments come from this builder's fields.
    """

    __hash__ = param_hash
    action_dim: int = 2
    num_hiddens: int = 64
    num_hidden_layers: int = 2

    def build_value_network(
        self, state_normalization_data: NormalizationData
    ) -> torch.nn.Module:
        """
        Build the network, sizing its state input with
        ``get_num_output_features`` applied to the dense normalization
        parameters of ``state_normalization_data``.
        """
        dense_params = state_normalization_data.dense_normalization_parameters
        network_kwargs = dict(
            state_dim=get_num_output_features(dense_params),
            action_dim=self.action_dim,
            num_hiddens=self.num_hiddens,
            num_hidden_layers=self.num_hidden_layers,
        )
        return Seq2RewardNetwork(**network_kwargs)
15 | """ 16 | 17 | @abc.abstractmethod 18 | def build_value_network( 19 | self, state_normalization_data: NormalizationData 20 | ) -> torch.nn.Module: 21 | pass 22 | -------------------------------------------------------------------------------- /reagent/ope/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/ope/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | -------------------------------------------------------------------------------- /reagent/ope/estimators/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/ope/test/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/ope/test/configs/ecoli_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "file": "data/ecoli.data", 4 | "sep": "\\s+", 5 | "index_col": 0, 6 | "label_col": 8 7 | } 8 | } -------------------------------------------------------------------------------- /reagent/ope/test/configs/letter_recog_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "file": "data/letter-recognition.data", 4 | "sep": ",", 5 | "label_col": 0 6 | } 7 | } -------------------------------------------------------------------------------- /reagent/ope/test/configs/optdigits_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "file": "data/optdigits.data", 4 | "sep": ",", 5 | "label_col": 64 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /reagent/ope/test/configs/pendigits_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "file": "data/pendigits.data", 4 | "sep": ",", 5 | "label_col": 16 6 | } 7 | } -------------------------------------------------------------------------------- /reagent/ope/test/configs/satimage_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "file": "data/satimage.data", 4 | "sep": " ", 5 | "label_col": 36 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /reagent/ope/test/configs/yandex_web_search_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "raw_data": { 3 | "folder": "data/Yandex_Web_Search", 4 | "cache_folder": "", 5 | "source_file": "train", 6 | "total_days": 27 7 
| }, 8 | "log_data": { 9 | "folder": "data/Yandex_Web_Search", 10 | "base_file_name": "train", 11 | "days": [ 12 | 1, 13 | 3, 14 | 5, 15 | 7, 16 | 11, 17 | 13, 18 | 15, 19 | 17, 20 | 19, 21 | 21, 22 | 23, 23 | 25 24 | ], 25 | "cache_file": "log_dataset.pickle", 26 | "min_query_count": 10 27 | }, 28 | "target_data": { 29 | "folder": "data/Yandex_Web_Search", 30 | "base_file_name": "train", 31 | "days": [ 32 | 2, 33 | 4, 34 | 6, 35 | 8, 36 | 10, 37 | 12, 38 | 14, 39 | 16, 40 | 20, 41 | 22, 42 | 24, 43 | 26 44 | ], 45 | "cache_file": "target_dataset.pickle", 46 | "min_query_count": 10 47 | }, 48 | "test_data": { 49 | "folder": "data/Yandex_Web_Search", 50 | "base_file_name": "train", 51 | "days": [ 52 | 9 53 | ], 54 | "cache_file_name": "test_log" 55 | } 56 | } -------------------------------------------------------------------------------- /reagent/ope/test/notebooks/img/bias.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/reagent/ope/test/notebooks/img/bias.png -------------------------------------------------------------------------------- /reagent/ope/test/notebooks/img/rmse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/reagent/ope/test/notebooks/img/rmse.png -------------------------------------------------------------------------------- /reagent/ope/test/notebooks/img/variance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/reagent/ope/test/notebooks/img/variance.png -------------------------------------------------------------------------------- /reagent/ope/test/unit_tests/__init__.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | -------------------------------------------------------------------------------- /reagent/ope/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | 6 | from .soft_update import SoftUpdate 7 | from .union import Optimizer__Union 8 | 9 | 10 | __all__ = ["Optimizer__Union", "SoftUpdate"] 11 | -------------------------------------------------------------------------------- /reagent/optimizer/scheduler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | 6 | import inspect 7 | from typing import Any, Dict 8 | 9 | import torch 10 | from reagent.core.dataclasses import dataclass 11 | from reagent.core.registry_meta import RegistryMeta 12 | 13 | from .utils import is_torch_lr_scheduler 14 | 15 | 16 | @dataclass(frozen=True) 17 | class LearningRateSchedulerConfig(metaclass=RegistryMeta): 18 | def make_from_optimizer( 19 | self, optimizer: torch.optim.Optimizer 20 | ) -> torch.optim.lr_scheduler._LRScheduler: 21 | torch_lr_scheduler_class = getattr( 22 | torch.optim.lr_scheduler, type(self).__name__ 23 | ) 24 | assert is_torch_lr_scheduler( 25 | torch_lr_scheduler_class 26 | ), f"{torch_lr_scheduler_class} is not a scheduler." 
def is_strict_subclass(a: object, b: object):
    """
    Return True when ``a`` and ``b`` are both classes, ``a`` is a subclass of
    ``b``, and ``a != b``; return False if either argument is not a class.
    """
    both_are_classes = inspect.isclass(a) and inspect.isclass(b)
    if not both_are_classes:
        return False
    if not issubclass(a, b):
        return False
    return a != b


def is_torch_optimizer(cls):
    """Return True when ``cls`` is a strict subclass of ``torch.optim.Optimizer``."""
    optimizer_base = torch.optim.Optimizer
    return is_strict_subclass(cls, optimizer_base)


def is_torch_lr_scheduler(cls):
    """
    Return True when ``cls`` is a strict subclass of either
    ``torch.optim.lr_scheduler._LRScheduler`` or
    ``torch.optim.lr_scheduler.LRScheduler``.
    """
    schedulers = torch.optim.lr_scheduler
    # The public ``LRScheduler`` name is only looked up when the first check
    # fails, matching the original short-circuit evaluation.
    if is_strict_subclass(cls, schedulers._LRScheduler):
        return True
    return is_strict_subclass(cls, schedulers.LRScheduler)
3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/prediction/ranking/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/prediction/synthetic_reward/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/publishers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/publishers/union.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
# Tagged union over model publishers.
# The class body is intentionally empty: everything it exposes comes from the
# `ModelPublisher.fill_union()` decorator applied below.
# NOTE(review): presumably the decorator adds one member per publisher
# registered on ModelPublisher (hence the `# noqa` imports of the concrete
# publishers in this module) — confirm against RegistryMeta.fill_union.
@ModelPublisher.fill_union()
class ModelPublisher__Union(TaggedUnion):
    pass
class CompoundReporter(ReporterBase):
    """Reporter that groups several child reporters behind one interface.

    It does no logging of its own: flushing is forwarded to the children (or
    to a custom flush function) and the training report is produced by the
    supplied merge function.
    """

    def __init__(
        self,
        reporters: List[ReporterBase],
        merge_function: Callable[[List[ReporterBase]], TrainingReport],
    ) -> None:
        """
        Args:
            reporters: child reporters to flush and merge.
            merge_function: called with the list of child reporters to build
                the combined ``TrainingReport``.
        """
        # The compound reporter itself carries no observers of its own.
        super().__init__({}, {})
        self._reporters = reporters
        self._merge_function = merge_function
        self._flush_function = None

    def set_flush_function(self, flush_function) -> None:
        """Install a custom flush callable, invoked as ``flush_function(self, epoch)``."""
        self._flush_function = flush_function

    def log(self, **kwargs) -> None:
        """Always raise: values must be logged on the child reporters instead.

        Raises:
            RuntimeError: unconditionally.
        """
        # The original message said "You should call log()", which is the
        # opposite of what raising here means.
        raise RuntimeError("You should not call log() on this reporter")

    def flush(self, epoch: int) -> None:
        """Flush via the custom flush function if one is set, else flush every child."""
        if self._flush_function:
            self._flush_function(self, epoch)
        else:
            for reporter in self._reporters:
                reporter.flush(epoch)

    def generate_training_report(self) -> TrainingReport:
        """Return the report produced by the merge function over all child reporters."""
        return self._merge_function(self._reporters)
3 | 4 | # pyre-unsafe 5 | 6 | from .frechet import FrechetSort 7 | 8 | 9 | __all__ = ["FrechetSort"] 10 | -------------------------------------------------------------------------------- /reagent/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/test/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | -------------------------------------------------------------------------------- /reagent/test/base/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | -------------------------------------------------------------------------------- /reagent/test/core/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | -------------------------------------------------------------------------------- /reagent/test/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
3 | -------------------------------------------------------------------------------- /reagent/test/evaluation/cb/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/reagent/test/evaluation/cb/__init__.py -------------------------------------------------------------------------------- /reagent/test/lite/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # light APIs for solving optimization problems. 5 | -------------------------------------------------------------------------------- /reagent/test/mab/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | -------------------------------------------------------------------------------- /reagent/test/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | -------------------------------------------------------------------------------- /reagent/test/models/test_cb_fully_connected.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
def run_model_jit_trace(
    model: ModelBase,
    script_model,
    compare_func: Optional[Callable] = None,
):
    """Trace ``script_model`` and check it reproduces ``model``'s output.

    The model's input prototype is wrapped in a tuple if it is a single
    object, and the ``float_features`` of each entry are used as the tensor
    inputs for ``torch.jit.trace``. ``model`` is then called with the original
    prototype objects and the traced module with the tensors.

    Args:
        model: object exposing ``input_prototype()`` and callable on the
            prototype entries.
        script_model: module passed to ``torch.jit.trace``.
        compare_func: optional callable ``compare_func(expected, actual)``
            that performs the comparison itself; when omitted the outputs
            must be tensors (or equal-length sequences of tensors) that match
            element-wise.

    Raises:
        AssertionError: if the default comparison finds a mismatch.
    """
    input_prototype = model.input_prototype()
    if not isinstance(input_prototype, (list, tuple)):
        input_prototype = (input_prototype,)
    tensor_input_prototype = tuple(x.float_features for x in input_prototype)
    traced_model = torch.jit.trace(script_model, tensor_input_prototype)

    x = model(*input_prototype)
    y = traced_model(*tensor_input_prototype)

    if compare_func:
        compare_func(x, y)
    elif isinstance(x, (list, tuple)):
        assert isinstance(y, (list, tuple))
        # Guard against zip() silently truncating to the shorter sequence.
        assert len(x) == len(y)
        # Pair the i-th expected output with the i-th traced output. The
        # previous `for xx, yy in x, y` iterated over the pair (x, y) itself
        # and therefore never compared model output against traced output.
        for xx, yy in zip(x, y):
            assert isinstance(xx, torch.Tensor)
            assert isinstance(yy, torch.Tensor)
            assert torch.all(xx == yy)
    else:
        assert isinstance(x, torch.Tensor)
        assert isinstance(y, torch.Tensor)
        assert torch.all(x == y)
19 | chooser = ValueNetBuilder__Union( 20 | FullyConnected=value.fully_connected.FullyConnected() 21 | ) 22 | builder = chooser.value 23 | state_dim = 3 24 | normalization_data = NormalizationData( 25 | dense_normalization_parameters={ 26 | i: NormalizationParameters(feature_type=CONTINUOUS) 27 | for i in range(state_dim) 28 | } 29 | ) 30 | value_network = builder.build_value_network(normalization_data) 31 | batch_size = 5 32 | x = FeatureData(float_features=torch.randn(batch_size, state_dim)) 33 | y = value_network(x) 34 | self.assertEqual(y.shape, (batch_size, 1)) 35 | -------------------------------------------------------------------------------- /reagent/test/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | -------------------------------------------------------------------------------- /reagent/test/prediction/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | -------------------------------------------------------------------------------- /reagent/test/prediction/test_prediction_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
def _cont_norm():
    """Build continuous-feature normalization parameters with mean 0 and stddev 1."""
    params = NormalizationParameters(feature_type=CONTINUOUS, mean=0.0, stddev=1.0)
    return params


def _cont_action_norm():
    """Build continuous-action normalization parameters bounded to [-3, 3]."""
    params = NormalizationParameters(
        feature_type=CONTINUOUS_ACTION, min_value=-3.0, max_value=3.0
    )
    return params


def change_cand_size_slate_ranking(input_prototype, candidate_size_override):
    """Return a fresh input prototype whose candidate axis has a new size.

    ``input_prototype`` is a ``(state, candidate)`` pair, each itself a pair
    of tensors. The first candidate slot (index 0 on dim 1) of both candidate
    tensors is repeated ``candidate_size_override`` times; the result only
    keeps those shapes: the first tensor of every pair is replaced by random
    normal values and the second by ones.
    """
    state_prototype, candidate_prototype = input_prototype
    repeats = (1, candidate_size_override, 1)
    resized_first = candidate_prototype[0][:, :1, :].repeat(*repeats)
    resized_second = candidate_prototype[1][:, :1, :].repeat(*repeats)

    # Same generation order as before: state pair first, then candidate pair.
    state_out = (
        torch.randn_like(state_prototype[0]),
        torch.ones_like(state_prototype[1]),
    )
    candidate_out = (
        torch.randn_like(resized_first),
        torch.ones_like(resized_second),
    )
    return state_out, candidate_out
class TestTypeIdentification(unittest.TestCase):
    """Checks ``identify_types.identify_type`` against the shared test data."""

    def test_identification(self) -> None:
        """Every known feature id must be identified as its expected type."""
        types = {
            name: identify_types.identify_type(values)
            for name, values in read_data().items()
        }

        # Expected types, established through manual inspection of the data.
        expectations = (
            (BINARY_FEATURE_ID, identify_types.BINARY),
            (CONTINUOUS_FEATURE_ID, identify_types.CONTINUOUS),
            # The boxcox type is not yet detected; it is reported as continuous.
            (BOXCOX_FEATURE_ID, identify_types.CONTINUOUS),
            # The quantile type is not yet detected; it is reported as continuous.
            (QUANTILE_FEATURE_ID, identify_types.CONTINUOUS),
            (ENUM_FEATURE_ID, identify_types.ENUM),
            (PROBABILITY_FEATURE_ID, identify_types.PROBABILITY),
        )
        for feature_id, expected_type in expectations:
            self.assertEqual(types[feature_id], expected_type)
3 | -------------------------------------------------------------------------------- /reagent/test/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | -------------------------------------------------------------------------------- /reagent/test/simulators/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | -------------------------------------------------------------------------------- /reagent/test/training/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | -------------------------------------------------------------------------------- /reagent/test/training/cb/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | -------------------------------------------------------------------------------- /reagent/test/workflow/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/test/workflow/test_data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/reagent/test/workflow/test_data/__init__.py -------------------------------------------------------------------------------- /reagent/test/workflow/test_data/continuous_action/action_norm.json: -------------------------------------------------------------------------------- 1 | {"3": "{\"feature_type\": \"CONTINUOUS\", \"boxcox_lambda\": null, \"boxcox_shift\": null, \"mean\": 0.3202674388885498, \"stddev\": 1.6936050653457642, \"possible_values\": null, \"quantiles\": null, \"min_value\": -3.3081326484680176, \"max_value\": 3.120173692703247}"} -------------------------------------------------------------------------------- /reagent/test/workflow/test_data/continuous_action/pendulum_eval.json.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/reagent/test/workflow/test_data/continuous_action/pendulum_eval.json.bz2 -------------------------------------------------------------------------------- /reagent/test/workflow/test_data/continuous_action/pendulum_training.json.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/reagent/test/workflow/test_data/continuous_action/pendulum_training.json.bz2 -------------------------------------------------------------------------------- /reagent/test/workflow/test_data/continuous_action/state_features_norm.json: -------------------------------------------------------------------------------- 1 | {"0": "{\"feature_type\": \"BOXCOX\", 
\"boxcox_lambda\": 0.11620849064705728, \"boxcox_shift\": 1.0, \"mean\": -2.3017139434814453, \"stddev\": 1.366979718208313, \"possible_values\": null, \"quantiles\": null, \"min_value\": -1.0, \"max_value\": 0.9999994039535522}", "1": "{\"feature_type\": \"CONTINUOUS\", \"boxcox_lambda\": null, \"boxcox_shift\": null, \"mean\": -0.061158571392297745, \"stddev\": 0.43823540210723877, \"possible_values\": null, \"quantiles\": null, \"min_value\": -1.0, \"max_value\": 0.9987553358078003}", "2": "{\"feature_type\": \"CONTINUOUS\", \"boxcox_lambda\": null, \"boxcox_shift\": null, \"mean\": -0.02112729288637638, \"stddev\": 1.3180352449417114, \"possible_values\": null, \"quantiles\": null, \"min_value\": -5.819076061248779, \"max_value\": 6.1764092445373535}"} -------------------------------------------------------------------------------- /reagent/test/workflow/test_data/discrete_action/dqn_workflow.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/reagent/test/workflow/test_data/discrete_action/dqn_workflow.zip -------------------------------------------------------------------------------- /reagent/test/workflow/test_data/parametric_action/action_norm.json: -------------------------------------------------------------------------------- 1 | {"4": "{\"feature_type\": \"BINARY\", \"boxcox_lambda\": null, \"boxcox_shift\": 0, \"mean\": 0, \"stddev\": 1, \"possible_values\": null, \"quantiles\": null, \"min_value\": 0.0, \"max_value\": 1.0}", "5": "{\"feature_type\": \"BINARY\", \"boxcox_lambda\": null, \"boxcox_shift\": 0, \"mean\": 0, \"stddev\": 1, \"possible_values\": null, \"quantiles\": null, \"min_value\": 0.0, \"max_value\": 1.0}"} -------------------------------------------------------------------------------- /reagent/test/workflow/test_data/parametric_action/cartpole_eval.json.bz2: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/reagent/test/workflow/test_data/parametric_action/cartpole_eval.json.bz2 -------------------------------------------------------------------------------- /reagent/test/workflow/test_data/parametric_action/cartpole_training.json.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/reagent/test/workflow/test_data/parametric_action/cartpole_training.json.bz2 -------------------------------------------------------------------------------- /reagent/test/workflow/test_data/parametric_action/state_features_norm.json: -------------------------------------------------------------------------------- 1 | {"0": "{\"feature_type\": \"CONTINUOUS\", \"boxcox_lambda\": null, \"boxcox_shift\": null, \"mean\": 0.1436615288257599, \"stddev\": 0.32602012157440186, \"possible_values\": null, \"quantiles\": null, \"min_value\": -0.7986576557159424, \"max_value\": 1.4071340560913086}", "1": "{\"feature_type\": \"CONTINUOUS\", \"boxcox_lambda\": null, \"boxcox_shift\": null, \"mean\": 0.09220181405544281, \"stddev\": 0.3116305470466614, \"possible_values\": null, \"quantiles\": null, \"min_value\": -1.7964972257614136, \"max_value\": 1.5968587398529053}", "2": "{\"feature_type\": \"CONTINUOUS\", \"boxcox_lambda\": null, \"boxcox_shift\": null, \"mean\": 0.0007701526628807187, \"stddev\": 0.04096432402729988, \"possible_values\": null, \"quantiles\": null, \"min_value\": -0.20418913662433624, \"max_value\": 0.20443527400493622}", "3": "{\"feature_type\": \"CONTINUOUS\", \"boxcox_lambda\": null, \"boxcox_shift\": null, \"mean\": -0.00856819935142994, \"stddev\": 0.31406718492507935, \"possible_values\": null, \"quantiles\": null, \"min_value\": -2.3652102947235107, \"max_value\": 
class MABTrainer(BaseCBTrainerWithEval):
    """Trainer that delegates learning to the policy's ``MABBaseModel`` scorer.

    No optimizer is configured; every training step simply calls
    ``scorer.learn`` on the batch.
    """

    def __init__(self, policy: Policy, *args, **kwargs):
        """
        Args:
            policy: policy whose ``scorer`` must be a ``MABBaseModel``.
            *args, **kwargs: forwarded to the base trainer, together with
                ``automatic_optimization=False``.
        """
        super().__init__(*args, automatic_optimization=False, **kwargs)
        scorer = policy.scorer
        assert isinstance(scorer, MABBaseModel)
        self.scorer = scorer

    def cb_training_step(self, batch: CBInput, batch_idx: int, optimizer_idx: int = 0):
        """Feed ``batch`` to the scorer; the index arguments are unused."""
        self.scorer.learn(batch)

    def configure_optimizers(self):
        """Return None: there are no optimizers because state is updated manually."""
        return None
def get_valid_actions_from_imitator(imitator, input, drop_threshold):
    """Build a 0/1 float mask of the actions considered viable by the imitator.

    A ``torch.nn.Module`` imitator is called on ``input.float_features`` and
    its outputs are soft-maxed over dim 1; any other callable is called on
    the CPU copy of the features and its result is taken as probabilities
    directly. Each row is divided by its maximum, and an action is kept (1.0)
    when that ratio is at least ``drop_threshold``.
    """
    if not isinstance(imitator, torch.nn.Module):
        # Non-torch model (e.g. scikit-learn): output is already probabilities.
        action_probs = torch.tensor(imitator(input.float_features.cpu()))
    else:
        raw_scores = imitator(input.float_features)
        action_probs = torch.nn.functional.softmax(raw_scores, dim=1)

    row_max = action_probs.max(keepdim=True, dim=1)[0]
    relative_probs = action_probs / row_max
    keep_mask = relative_probs >= drop_threshold
    return keep_mask.float()
def ips_clamp(impt_smpl, ips_clamp: Optional[IPSClamp]):
    """Clamp importance-sampling weights according to ``ips_clamp``.

    Args:
        impt_smpl: tensor of importance-sampling weights.
        ips_clamp: clamp configuration, or a falsy value for no clamping.

    Returns:
        A new tensor:
        * no configuration: an unmodified clone of ``impt_smpl``;
        * ``UNIVERSAL``: values clamped into ``[0, clamp_max]``;
        * ``AGGRESSIVE``: values above ``clamp_max`` replaced by zero.

    Raises:
        ValueError: if ``ips_clamp.clamp_method`` is not a supported method
            (previously this fell through and silently returned ``None``).
    """
    if not ips_clamp:
        return impt_smpl.clone()

    method = ips_clamp.clamp_method
    if method == IPSClampMethod.UNIVERSAL:
        return torch.clamp(impt_smpl, 0, ips_clamp.clamp_max)
    if method == IPSClampMethod.AGGRESSIVE:
        return torch.where(
            impt_smpl > ips_clamp.clamp_max, torch.zeros_like(impt_smpl), impt_smpl
        )
    raise ValueError(f"Unsupported IPS clamp method: {method}")
@dataclass
class NoValidation(ModelValidator):
    """
    Example validator showing the minimal shape of a `ModelValidator`.
    It performs no validation at all and always approves publishing; a real
    validator would perform actual checks in `do_validate()`.
    """

    def do_validate(
        self,
        # pyre-fixme[11]: Annotation `RLTrainingOutput` is not defined as a type.
        training_output: RLTrainingOutput,
        result_history: Optional[List[RLTrainingOutput]] = None,
        # pyre-fixme[11]: Annotation `TableSpec` is not defined as a type.
        input_table_spec: Optional[TableSpec] = None,
    ) -> NoValidationResults:
        """Ignore every argument and return a result with `should_publish=True`."""
        approval = NoValidationResults(should_publish=True)
        return approval
20 | import fblearner.flow.projects.rl.validation.common # noqa 21 | 22 | 23 | @ModelValidator.fill_union() 24 | class ModelValidator__Union(TaggedUnion): 25 | pass 26 | -------------------------------------------------------------------------------- /reagent/workflow/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | -------------------------------------------------------------------------------- /reagent/workflow/env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | 4 | # pyre-unsafe 5 | 6 | from typing import List 7 | 8 | # pyre-fixme[21]: Could not find module `reagent.workflow.types`. 9 | from reagent.workflow.types import ModuleNameToEntityId 10 | 11 | 12 | def get_workflow_id() -> int: 13 | # This is just stub. You will want to replace this file. 14 | return 987654321 15 | 16 | 17 | # pyre-fixme[11]: Annotation `ModuleNameToEntityId` is not defined as a type. 
def get_new_named_entity_ids(module_names: List[str]) -> ModuleNameToEntityId:
    """Assign a stub entity id to every name in ``module_names``.

    The first name receives ``get_workflow_id()``; the name at position ``k``
    (``k >= 1``) receives ``987654321 - k``. A name that occurs more than once
    keeps the id assigned at its last occurrence. An empty input yields an
    empty mapping.
    """
    entity_ids = {}
    for position, module_name in enumerate(module_names):
        if position == 0:
            entity_ids[module_name] = get_workflow_id()
            continue
        # this is just random, you'll want to replace
        entity_ids[module_name] = 987654321 - position
    return entity_ids
"target_update_rate": 0.1, 25 | "maxq_learning": true, 26 | "epsilon": 0.2, 27 | "temperature": 0.35, 28 | "softmax_policy": 0 29 | }, 30 | "rainbow": { 31 | "double_q_learning": true, 32 | "dueling_architecture": false 33 | }, 34 | "training": { 35 | "layers": [ 36 | -1, 37 | 128, 38 | 64, 39 | -1 40 | ], 41 | "activations": [ 42 | "relu", 43 | "relu", 44 | "linear" 45 | ], 46 | "minibatch_size": 512, 47 | "learning_rate": 0.01, 48 | "optimizer": "ADAM", 49 | "lr_decay": 0.999, 50 | "warm_start_model_path": null 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /reagent/workflow/sample_configs/discrete_action/timeline.json: -------------------------------------------------------------------------------- 1 | { 2 | "timeline": { 3 | "startDs": "2019-01-01", 4 | "endDs": "2019-01-01", 5 | "addTerminalStateRow": true, 6 | "actionDiscrete": true, 7 | "inputTableName": "cartpole_discrete", 8 | "outputTableName": "cartpole_discrete_training", 9 | "evalTableName": "cartpole_discrete_eval", 10 | "numOutputShards": 1, 11 | "includePossibleActions": true, 12 | "percentileFunction": "percentile_approx", 13 | "rewardColumns": ["reward", "metrics"], 14 | "extraFeatureColumns": [] 15 | }, 16 | "query": { 17 | "tableSample": 100, 18 | "actions": ["0", "1"] 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /reagent/workflow/sample_configs/discrete_dqn_cartpole_offline.yaml: -------------------------------------------------------------------------------- 1 | env_name: CartPole-v0 2 | model_path: "cartpole_batch_rl_model.torchscript" 3 | pkl_path: "/tmp/tmp_pickle.pkl" 4 | input_table_spec: 5 | table_name: test_table 6 | table_sample: 90 7 | eval_table_sample: 10 8 | model: 9 | DiscreteDQN: 10 | trainer_param: 11 | actions: 12 | - 0 13 | - 1 14 | rl: 15 | gamma: 0.99 16 | target_update_rate: 0.1 17 | maxq_learning: true 18 | temperature: 0.35 19 | softmax_policy: false 20 | 
q_network_loss: mse 21 | double_q_learning: true 22 | minibatch_size: 512 23 | minibatches_per_step: 1 24 | optimizer: 25 | Adam: 26 | lr: 0.01 27 | weight_decay: 0.01 28 | net_builder: 29 | FullyConnected: 30 | sizes: 31 | - 128 32 | - 64 33 | activations: 34 | - relu 35 | - relu 36 | cpe_net_builder: 37 | FullyConnected: 38 | sizes: 39 | - 128 40 | - 64 41 | activations: 42 | - relu 43 | - relu 44 | preprocessing_options: 45 | num_samples: 1000 46 | eval_parameters: 47 | calc_cpe_in_training: true 48 | num_train_transitions: 30000 # approx. 150 episodes 49 | max_steps: 200 50 | seed: 42 51 | num_epochs: 20 52 | publisher: 53 | FileSystemPublisher: {} 54 | num_eval_episodes: 50 55 | passing_score_bar: 120 56 | -------------------------------------------------------------------------------- /reagent/workflow/sample_configs/parametric_action/parametric_dqn_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "training_data_path": "training_data/cartpole_parametric_training_data.json", 3 | "eval_data_path": "training_data/cartpole_parametric_eval_data.json", 4 | "state_norm_data_path": "training_data/state_features_norm.json", 5 | "action_norm_data_path": "training_data/action_norm.json", 6 | "model_output_path": "outputs/", 7 | "use_gpu": true, 8 | "use_all_avail_gpus": true, 9 | "num_nodes": 1, 10 | "init_method": "file:///tmp/parametric_dqn_example.lock", 11 | "norm_params": { 12 | "output_dir": "training_data/", 13 | "cols_to_norm": [ 14 | "state_features", 15 | "action" 16 | ], 17 | "num_samples": 1000 18 | }, 19 | "epochs": 1, 20 | "rl": { 21 | "gamma": 0.99, 22 | "target_update_rate": 0.2, 23 | 24 | "maxq_learning": true, 25 | "epsilon": 0.2, 26 | "temperature": 0.35, 27 | "softmax_policy": 0 28 | }, 29 | "rainbow": { 30 | "double_q_learning": true, 31 | "dueling_architecture": false 32 | }, 33 | "training": { 34 | "layers": [ 35 | -1, 36 | 128, 37 | 64, 38 | -1 39 | ], 40 | "activations": [ 41 | "relu", 42 | 
"relu", 43 | "linear" 44 | ], 45 | "minibatch_size": 128, 46 | "learning_rate": 0.001, 47 | "optimizer": "ADAM", 48 | "lr_decay": 0.999, 49 | "warm_start_model_path": null 50 | } 51 | } -------------------------------------------------------------------------------- /reagent/workflow/sample_configs/parametric_action/timeline.json: -------------------------------------------------------------------------------- 1 | { 2 | "timeline": { 3 | "startDs": "2019-01-01", 4 | "endDs": "2019-01-01", 5 | "addTerminalStateRow": false, 6 | "actionDiscrete": false, 7 | "inputTableName": "cartpole_parametric", 8 | "outputTableName": "cartpole_parametric_training", 9 | "evalTableName": "cartpole_parametric_eval", 10 | "numOutputShards": 1, 11 | "includePossibleActions": true, 12 | "percentileFunction": "percentile_approx", 13 | "rewardColumns": ["reward", "metrics"], 14 | "extraFeatureColumns": [] 15 | }, 16 | "query": { 17 | "tableSample": 15 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /scripts/recurring_training_sac_offline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x -e 4 | 5 | rm -f /tmp/file_system_publisher 6 | rm -Rf test_warmstart model_* pl_log* runs 7 | 8 | CONFIG=reagent/workflow/sample_configs/sac_pendulum_offline.yaml 9 | 10 | python ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.offline_gym_random "$CONFIG" 11 | rm -Rf spark-warehouse derby.log metastore_db 12 | python ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.timeline_operator "$CONFIG" 13 | python ./reagent/workflow/cli.py run reagent.workflow.training.identify_and_train_network "$CONFIG" 14 | python ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.evaluate_gym "$CONFIG" 15 | 16 | for _ in {0..30} 17 | do 18 | python ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.offline_gym_predictor "$CONFIG" 19 | rm -Rf spark-warehouse derby.log 
metastore_db 20 | python ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.timeline_operator "$CONFIG" 21 | python ./reagent/workflow/cli.py run reagent.workflow.training.identify_and_train_network "$CONFIG" 22 | python ./reagent/workflow/cli.py run reagent.workflow.gym_batch_rl.evaluate_gym "$CONFIG" 23 | done 24 | -------------------------------------------------------------------------------- /serving/README.md: -------------------------------------------------------------------------------- 1 | TODO -------------------------------------------------------------------------------- /serving/examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/serving/examples/__init__.py -------------------------------------------------------------------------------- /serving/examples/ecommerce/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/serving/examples/ecommerce/__init__.py -------------------------------------------------------------------------------- /serving/examples/ecommerce/plans/contextual_bandit.json: -------------------------------------------------------------------------------- 1 | { 2 | "operators": [ 3 | { 4 | "name": "ActionValueScoringOp", 5 | "op_name": "ActionValueScoring", 6 | "input_dep_map": { 7 | "model_id": "model_id", 8 | "snapshot_id": "snapshot_id" 9 | } 10 | }, 11 | { 12 | "name": "SoftmaxRankerOp", 13 | "op_name": "SoftmaxRanker", 14 | "input_dep_map": { 15 | "temperature": "constant_2", 16 | "values": "ActionValueScoringOp" 17 | } 18 | } 19 | ], 20 | "constants": [ 21 | { 22 | "name": "model_id", 23 | "value": { 24 | "int_value": 0 25 | } 26 | }, 27 | { 28 | "name": "snapshot_id", 29 | "value": { 30 | "int_value": 0 31 | } 32 | }, 33 | { 34 | 
"name": "constant_2", 35 | "value": { 36 | "double_value": 0.001 37 | } 38 | } 39 | ], 40 | "num_actions_to_choose": 1, 41 | "reward_function": "reward", 42 | "reward_aggregator": "sum" 43 | } -------------------------------------------------------------------------------- /serving/examples/ecommerce/plans/heuristic.json: -------------------------------------------------------------------------------- 1 | { 2 | "operators": [ 3 | { 4 | "name": "SoftmaxRanker_1", 5 | "op_name": "SoftmaxRanker", 6 | "input_dep_map": { 7 | "temperature": "constant_2", 8 | "values": "constant_3" 9 | } 10 | } 11 | ], 12 | "constants": [ 13 | { 14 | "name": "constant_2", 15 | "value": { 16 | "double_value": 1.0 17 | } 18 | }, 19 | { 20 | "name": "constant_3", 21 | "value": { 22 | "map_double_value": { 23 | "Bacon": 1.1, 24 | "Ribs": 1.0 25 | } 26 | } 27 | } 28 | ], 29 | "num_actions_to_choose": 1, 30 | "reward_function": "reward", 31 | "reward_aggregator": "sum" 32 | } -------------------------------------------------------------------------------- /serving/examples/ecommerce/plans/multi_armed_bandit.json: -------------------------------------------------------------------------------- 1 | { 2 | "operators": [ 3 | { 4 | "name": "UCB_1", 5 | "op_name": "Ucb", 6 | "input_dep_map": { 7 | "method": "constant_2", 8 | "batch_size": "constant_3" 9 | } 10 | } 11 | ], 12 | "constants": [ 13 | { 14 | "name": "constant_2", 15 | "value": { 16 | "string_value": "UCB1" 17 | } 18 | }, 19 | { 20 | "name": "constant_3", 21 | "value": { 22 | "int_value": 8 23 | } 24 | } 25 | ], 26 | "num_actions_to_choose": 1, 27 | "reward_function": "reward", 28 | "reward_aggregator": "sum" 29 | } -------------------------------------------------------------------------------- /serving/examples/ecommerce/training/contextual_bandit.yaml: -------------------------------------------------------------------------------- 1 | pkl_path: "/tmp/input_df.pkl" 2 | input_table_spec: 3 | table_name: ecom_cb_input_data 4 | 
table_sample: 90 5 | eval_table_sample: 10 6 | model: 7 | DiscreteDQN: 8 | trainer_param: 9 | actions: 10 | - Bacon 11 | - Ribs 12 | rl: 13 | gamma: 0.0 # zero gamma for bandit setting 14 | target_update_rate: 1.0 15 | maxq_learning: true 16 | temperature: 0.35 17 | softmax_policy: false 18 | q_network_loss: mse 19 | double_q_learning: true 20 | minibatch_size: 128 21 | minibatches_per_step: 1 22 | optimizer: 23 | Adam: 24 | lr: 0.01 25 | eval_parameters: 26 | calc_cpe_in_training: true 27 | net_builder: 28 | FullyConnected: 29 | sizes: [] 30 | activations: [] 31 | cpe_net_builder: 32 | FullyConnected: 33 | sizes: [] 34 | activations: [] 35 | preprocessing_options: 36 | num_samples: 1000 37 | num_epochs: 10 38 | publisher: 39 | FileSystemPublisher: {} 40 | -------------------------------------------------------------------------------- /serving/reagent/serving/cli/Main.cpp: -------------------------------------------------------------------------------- 1 | #include "reagent/serving/core/Headers.h" 2 | 3 | #include 4 | #include 5 | 6 | #include "reagent/serving/cli/Server.h" 7 | #include "reagent/serving/core/DecisionService.h" 8 | #include "reagent/serving/core/DiskConfigProvider.h" 9 | #include "reagent/serving/core/InMemoryLogJoiner.h" 10 | #include "reagent/serving/core/LocalRealTimeCounter.h" 11 | #include "reagent/serving/core/PytorchActionValueScorer.h" 12 | #include "reagent/serving/core/SharedParameterHandler.h" 13 | 14 | namespace reagent { 15 | int Main(int argc, char** argv) { 16 | gflags::ParseCommandLineFlags(&argc, &argv, true); 17 | google::InitGoogleLogging(argv[0]); 18 | 19 | auto service = std::make_shared( 20 | std::make_shared("serving/examples/ecommerce/plans"), 21 | std::make_shared(), 22 | std::make_shared("/tmp/rasp_logging/log.txt"), 23 | std::make_shared(), 24 | std::make_shared()); 25 | 26 | Server server(service, 3000); 27 | server.start(); 28 | 29 | while (true) { 30 | sleep(1); 31 | } 32 | } 33 | } // namespace reagent 34 | 35 | int 
main(int argc, char** argv) { 36 | return reagent::Main(argc, argv); 37 | } 38 | -------------------------------------------------------------------------------- /serving/reagent/serving/cli/Server.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "reagent/serving/core/DecisionService.h" 4 | #include "reagent/serving/core/Headers.h" 5 | 6 | #include "SimpleWebServer/client_http.hpp" 7 | #include "SimpleWebServer/server_http.hpp" 8 | using HttpServer = SimpleWeb::Server; 9 | using HttpClient = SimpleWeb::Client; 10 | 11 | namespace reagent { 12 | class Server { 13 | public: 14 | Server(std::shared_ptr decisionService, int port); 15 | 16 | void start(); 17 | void shutdown(); 18 | 19 | protected: 20 | HttpServer server_; 21 | std::shared_ptr serverThread_; 22 | std::shared_ptr decisionService_; 23 | int port_; 24 | }; 25 | } // namespace reagent 26 | -------------------------------------------------------------------------------- /serving/reagent/serving/config/applications/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
# Collects the name of every module discovered below.
__all__ = []
# NOTE(review): this prints the discovered module infos to stdout every time
# the package is imported -- looks like leftover debugging output; confirm
# with the owners before removing.
print(list(pkgutil.walk_packages(__path__)))

# Import each module found under this package's path and expose it as a
# package attribute under its discovered name.
for _, module_name, _ in pkgutil.walk_packages(__path__):
    __all__.append(module_name)
    _module = importlib.import_module(f"{__name__}.{module_name}")
    globals()[module_name] = _module
class DecisionOperator:
    """Common base for decision operators.

    A new instance has ``id`` set to ``None`` and ``op_name`` set to the name
    of its concrete class. Subclasses are expected to provide ``arguments()``.
    """

    def __init__(self):
        self.id = None
        concrete_class = type(self)
        self.op_name = concrete_class.__name__

    def arguments(self):
        """Return the operator's arguments; not implemented on the base class."""
        raise NotImplementedError


def DecisionOperation(op):
    """Create a ``DecisionOperator`` subclass named after the callable ``op``.

    Instantiating the generated class binds the supplied positional and
    keyword arguments against ``op``'s signature (``op`` itself is never
    called) and stores the resulting name -> value mapping, which
    ``arguments()`` returns. Arguments that do not fit the signature raise
    ``TypeError`` from the binding step.
    """

    def __init__(self, *args, **kwargs):
        bound_arguments = inspect.getcallargs(op, *args, **kwargs)
        self.args = bound_arguments
        DecisionOperator.__init__(self)

    def arguments(self):
        return self.args

    class_namespace = {"__init__": __init__, "arguments": arguments}
    return type(op.__name__, (DecisionOperator,), class_namespace)
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "reagent/serving/core/Headers.h" 4 | 5 | #include "reagent/serving/core/DecisionPlan.h" 6 | 7 | namespace reagent { 8 | class DecisionService; 9 | 10 | class ConfigProvider { 11 | public: 12 | ConfigProvider() {} 13 | 14 | virtual void initialize(DecisionService* decisionService) { 15 | decisionService_ = decisionService; 16 | } 17 | 18 | virtual ~ConfigProvider() = default; 19 | 20 | protected: 21 | DecisionService* decisionService_; 22 | }; 23 | } // namespace reagent 24 | -------------------------------------------------------------------------------- /serving/reagent/serving/core/DecisionPlan.cpp: -------------------------------------------------------------------------------- 1 | #include "reagent/serving/core/DecisionPlan.h" 2 | 3 | #include 4 | 5 | namespace reagent { 6 | DecisionPlan::DecisionPlan( 7 | const DecisionConfig& config, 8 | const std::vector>& operators, 9 | const StringOperatorDataMap& constants) 10 | : config_(config), operators_(operators), constants_(constants) {} 11 | 12 | double DecisionPlan::evaluateRewardFunction(const StringDoubleMap& metrics) { 13 | exprtk::symbol_table symbolTable; 14 | for (const auto& it : metrics) { 15 | symbolTable.add_constant(it.first, it.second); 16 | } 17 | exprtk::expression expression; 18 | expression.register_symbol_table(symbolTable); 19 | 20 | exprtk::parser parser; 21 | parser.compile(config_.reward_function, expression); 22 | 23 | double value = expression.value(); 24 | return value; 25 | } 26 | 27 | } // namespace reagent 28 | -------------------------------------------------------------------------------- /serving/reagent/serving/core/DecisionPlan.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "reagent/serving/core/Headers.h" 4 | 5 | #include "reagent/serving/core/Operator.h" 6 | 7 | namespace reagent { 8 | class 
// Exception type for decision-service errors. It adds nothing to
// std::runtime_error beyond a distinct type that callers can catch.
class DecisionServiceException : public std::runtime_error {
 public:
  // Construct from a std::string message.
  explicit DecisionServiceException(const std::string& message)
      : std::runtime_error(message) {}

  // Construct from a C-string message.
  explicit DecisionServiceException(const char* message)
      : std::runtime_error(message) {}
};
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "reagent/serving/core/Headers.h" 4 | 5 | #include "reagent/serving/core/ConfigProvider.h" 6 | 7 | namespace reagent { 8 | class DiskConfigProvider : public ConfigProvider { 9 | public: 10 | explicit DiskConfigProvider(std::string config_dir) { 11 | configDir_ = config_dir; 12 | } 13 | 14 | protected: 15 | std::string configDir_; 16 | 17 | void initialize(DecisionService* decisionService) override; 18 | void readConfig(const std::string& path); 19 | }; 20 | } // namespace reagent 21 | -------------------------------------------------------------------------------- /serving/reagent/serving/core/Headers.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | #include "reagent/serving/core/Containers.h" 26 | #include "reagent/serving/core/DecisionServiceException.h" 27 | 28 | #define LOG_AND_THROW(MSG_STREAM) \ 29 | { \ 30 | std::ostringstream errorStream; \ 31 | errorStream << MSG_STREAM; \ 32 | LOG(ERROR) << errorStream.str(); \ 33 | throw reagent::DecisionServiceException(errorStream.str()); \ 34 | } 35 | 36 | namespace reagent { 37 | std::string generateUuid4(); 38 | 39 | } // namespace reagent 40 | -------------------------------------------------------------------------------- /serving/reagent/serving/core/LocalRealTimeCounter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "reagent/serving/core/RealTimeCounter.h" 4 | 5 | namespace reagent { 6 | 7 | class LocalRealTimeCounter : public RealTimeCounter { 8 | public: 9 | LocalRealTimeCounter() : 
windowSize_(1024 * 1024) {} 10 | 11 | virtual ~LocalRealTimeCounter() override = default; 12 | 13 | virtual int64_t getNumSamples(const std::string& key) override; 14 | 15 | virtual double getMean(const std::string& key) override; 16 | 17 | virtual double getVariance(const std::string& key) override; 18 | 19 | virtual void addValue(const std::string& key, double value) override; 20 | 21 | void setWindowSize(int windowSize) { 22 | windowSize_ = windowSize; 23 | } 24 | 25 | virtual void clear(const std::string& key) override; 26 | 27 | protected: 28 | std::unordered_map> counts_; 29 | int windowSize_; 30 | }; 31 | 32 | } // namespace reagent 33 | -------------------------------------------------------------------------------- /serving/reagent/serving/core/LogJoiner.cpp: -------------------------------------------------------------------------------- 1 | #include "reagent/serving/core/LogJoiner.h" 2 | 3 | #include "reagent/serving/core/DecisionService.h" 4 | 5 | namespace reagent { 6 | void LogJoiner::addRewardToFeedback(Feedback* feedback) {} 7 | 8 | DecisionWithFeedback LogJoiner::deserializeAndJoinDecisionAndFeedback( 9 | StringList decisionAndFeedback) { 10 | if (decisionAndFeedback.size() != 2) { 11 | LOG_AND_THROW( 12 | "Somehow ended up with more than 2 values for the same key: " 13 | << decisionAndFeedback.size()); 14 | } 15 | 16 | DecisionWithFeedback first = json::parse(decisionAndFeedback.at(0)); 17 | const DecisionWithFeedback& second = json::parse(decisionAndFeedback.at(1)); 18 | if (bool(first.feedback)) { 19 | if (bool(second.feedback)) { 20 | LOG_AND_THROW("Got two feedbacks for the same key"); 21 | } 22 | first.decision_request = second.decision_request; 23 | first.decision_response = second.decision_response; 24 | first.operator_outputs = second.operator_outputs; 25 | } else { 26 | if (bool(second.decision_request)) { 27 | LOG_AND_THROW("Got two requests for the same key"); 28 | } 29 | first.feedback = second.feedback; 30 | } 31 | if 
(decisionService_) { 32 | decisionService_->_giveFeedback(first); 33 | } 34 | return first; 35 | } 36 | 37 | } // namespace reagent 38 | -------------------------------------------------------------------------------- /serving/reagent/serving/core/LogJoiner.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "reagent/serving/core/Headers.h" 4 | 5 | namespace reagent { 6 | class DecisionService; 7 | 8 | class LogJoiner { 9 | public: 10 | LogJoiner() : decisionService_(nullptr) {} 11 | 12 | virtual ~LogJoiner() {} 13 | 14 | virtual void registerDecisionService(DecisionService* decisionService) { 15 | decisionService_ = decisionService; 16 | } 17 | 18 | virtual void logDecision( 19 | const DecisionRequest& request, 20 | const DecisionResponse& decisionResponse, 21 | const StringOperatorDataMap& operator_outputs) = 0; 22 | 23 | virtual void logFeedback(Feedback feedback) = 0; 24 | 25 | virtual DecisionWithFeedback deserializeAndJoinDecisionAndFeedback( 26 | StringList decisionAndFeedback); 27 | 28 | protected: 29 | DecisionService* decisionService_; 30 | 31 | virtual void addRewardToFeedback(Feedback* feedback); 32 | }; 33 | 34 | } // namespace reagent 35 | -------------------------------------------------------------------------------- /serving/reagent/serving/core/Operator.cpp: -------------------------------------------------------------------------------- 1 | #include "reagent/serving/core/Operator.h" 2 | 3 | #include "reagent/serving/core/DecisionService.h" 4 | 5 | namespace reagent { 6 | Operator::Operator( 7 | const std::string& name, 8 | const std::string& planName, 9 | const StringStringMap& inputDepMap, 10 | const DecisionService* const decisionService) 11 | : name_(name), 12 | planName_(planName), 13 | inputDepMap_(inputDepMap), 14 | actionValueScorer_(decisionService->getActionValueScorer()), 15 | logJoiner_(decisionService->getLogJoiner()), 16 | 
realTimeCounter_(decisionService->getRealTimeCounter()), 17 | sharedParameterHandler_(decisionService->getSharedParameterHandler()) { 18 | for (const auto& it : inputDepMap) { 19 | deps_.insert(it.second); 20 | } 21 | } 22 | } // namespace reagent 23 | -------------------------------------------------------------------------------- /serving/reagent/serving/core/OperatorRunner.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "reagent/serving/core/Headers.h" 4 | 5 | #include "reagent/serving/core/Operator.h" 6 | 7 | namespace reagent { 8 | class OperatorRunner { 9 | public: 10 | OperatorRunner() {} 11 | 12 | StringOperatorDataMap run( 13 | const std::vector>& ops, 14 | const StringOperatorDataMap& constants, 15 | const DecisionRequest& request, 16 | const OperatorData& extraInput); 17 | 18 | protected: 19 | tf::Executor taskExecutor_; 20 | }; 21 | } // namespace reagent 22 | -------------------------------------------------------------------------------- /serving/reagent/serving/core/PytorchActionValueScorer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "reagent/serving/core/ActionValueScorer.h" 4 | 5 | #include 6 | 7 | namespace reagent { 8 | 9 | class PytorchActionValueScorer : public ActionValueScorer { 10 | public: 11 | PytorchActionValueScorer(); 12 | 13 | virtual ~PytorchActionValueScorer() override = default; 14 | 15 | StringDoubleMap 16 | predict(const DecisionRequest& request, int modelId, int snapshotId) override; 17 | 18 | protected: 19 | std::unordered_map models_; 20 | }; 21 | 22 | } // namespace reagent 23 | -------------------------------------------------------------------------------- /serving/reagent/serving/core/RealTimeCounter.cpp: -------------------------------------------------------------------------------- 
// Abstract interface for per-key statistics over double values.
// NOTE(review): the exact semantics (windowing, population vs. sample
// variance, behavior for unknown keys) are defined by the implementations and
// are not visible in this header -- confirm there before relying on them.
class RealTimeCounter {
 public:
  virtual ~RealTimeCounter() = default;

  // Number of samples for `key` (presumably those recorded via addValue()).
  virtual int64_t getNumSamples(const std::string& key) = 0;

  // Mean of the values for `key`.
  virtual double getMean(const std::string& key) = 0;

  // Variance of the values for `key`.
  virtual double getVariance(const std::string& key) = 0;

  // Record `value` under `key`.
  virtual void addValue(const std::string& key, double value) = 0;

  // Clear the state kept for `key`.
  virtual void clear(const std::string& key) = 0;
};
26 | } 27 | it->second->updateValues(values); 28 | } 29 | 30 | } // namespace reagent 31 | -------------------------------------------------------------------------------- /serving/reagent/serving/core/SharedParameterHandler.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "reagent/serving/core/Headers.h" 4 | 5 | namespace reagent { 6 | class SharedParameterInfo { 7 | public: 8 | explicit SharedParameterInfo(std::string name) : name_(name) {} 9 | 10 | time_t getLastFetchTime() { 11 | return lastFetchTime_; 12 | } 13 | 14 | StringDoubleMap getValues() { 15 | return values_; 16 | } 17 | 18 | void updateValues(StringDoubleMap values) { 19 | values_ = values; 20 | lastFetchTime_ = time(nullptr); 21 | } 22 | 23 | protected: 24 | std::string name_; 25 | time_t lastFetchTime_; 26 | StringDoubleMap values_; 27 | }; 28 | 29 | class SharedParameterHandler { 30 | public: 31 | SharedParameterHandler(); 32 | 33 | virtual ~SharedParameterHandler() = default; 34 | 35 | virtual StringDoubleMap getValues(const std::string& name); 36 | 37 | virtual bool acquireLockToModifyParameter(const std::string& name); 38 | 39 | // This doesn't guarantee that we acquired the lock, maybe there's a better 40 | // architecture? 
41 | virtual void updateParameter( 42 | const std::string& name, 43 | const StringDoubleMap& values); 44 | 45 | protected: 46 | std::unordered_map> 47 | parameters_; 48 | 49 | inline std::string get_parameter_store_name( 50 | const std::string& parameter_name) { 51 | return std::string("Parameter_Store_") + parameter_name; 52 | } 53 | }; 54 | } // namespace reagent 55 | -------------------------------------------------------------------------------- /serving/reagent/serving/operators/ActionValueScoring.cpp: -------------------------------------------------------------------------------- 1 | #include "reagent/serving/operators/ActionValueScoring.h" 2 | 3 | #include "reagent/serving/core/ActionValueScorer.h" 4 | #include "reagent/serving/core/OperatorFactory.h" 5 | 6 | namespace reagent { 7 | OperatorData ActionValueScoring::run( 8 | const DecisionRequest& request, 9 | const StringOperatorDataMap& namedInputs) { 10 | int modelId = int(std::get(namedInputs.at("model_id"))); 11 | int snapshotId = int(std::get(namedInputs.at("snapshot_id"))); 12 | OperatorData ret; 13 | ret = runInternal(modelId, snapshotId, request); 14 | return ret; 15 | } 16 | 17 | StringDoubleMap ActionValueScoring::runInternal( 18 | int modelId, 19 | int snapshotId, 20 | const DecisionRequest& request) { 21 | return actionValueScorer_->predict(request, modelId, snapshotId); 22 | } 23 | 24 | REGISTER_OPERATOR(ActionValueScoring, "ActionValueScoring"); 25 | 26 | } // namespace reagent 27 | -------------------------------------------------------------------------------- /serving/reagent/serving/operators/ActionValueScoring.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "reagent/serving/core/Headers.h" 4 | 5 | #include "reagent/serving/core/Operator.h" 6 | 7 | namespace reagent { 8 | class ActionValueScoring : public Operator { 9 | public: 10 | ActionValueScoring( 11 | const std::string& name, 12 | const std::string& planName, 13 | 
const StringStringMap& inputDepMap, 14 | const DecisionService* const decisionService) 15 | : Operator(name, planName, inputDepMap, decisionService) {} 16 | 17 | virtual ~ActionValueScoring() override = default; 18 | 19 | virtual OperatorData run( 20 | const DecisionRequest& request, 21 | const StringOperatorDataMap& namedInputs) override; 22 | 23 | StringDoubleMap 24 | runInternal(int modelId, int snapshotId, const DecisionRequest& request); 25 | }; 26 | } // namespace reagent 27 | -------------------------------------------------------------------------------- /serving/reagent/serving/operators/EpsilonGreedyRanker.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "reagent/serving/core/Headers.h" 4 | #include "reagent/serving/core/Operator.h" 5 | 6 | namespace reagent { 7 | class EpsilonGreedyRanker : public Operator { 8 | public: 9 | EpsilonGreedyRanker( 10 | const std::string& name, 11 | const std::string& planName, 12 | const StringStringMap& inputDepMap, 13 | const DecisionService* const decisionService) 14 | : Operator(name, planName, inputDepMap, decisionService) { 15 | int seed = std::chrono::system_clock::now().time_since_epoch().count(); 16 | generator_.seed(seed); 17 | } 18 | 19 | virtual ~EpsilonGreedyRanker() override = default; 20 | 21 | virtual OperatorData run( 22 | const DecisionRequest& request, 23 | const StringOperatorDataMap& namedInputs) override; 24 | 25 | virtual RankedActionList runInternal( 26 | const StringDoubleMap& input, 27 | double epsilon); 28 | 29 | protected: 30 | std::mt19937 generator_; 31 | }; 32 | } // namespace reagent 33 | -------------------------------------------------------------------------------- /serving/reagent/serving/operators/Expression.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "reagent/serving/core/Headers.h" 4 | #include "reagent/serving/core/Operator.h" 5 | 6 | namespace 
// Operator named "Frechet". Only the declaration is visible here; the
// behavior lives in the corresponding .cpp file.
// NOTE(review): presumably perturbs the input scores with Frechet-distributed
// noise controlled by `rho` and `gamma` -- confirm against the implementation.
class Frechet : public Operator {
 public:
  // Forwards all arguments unchanged to the Operator base class.
  Frechet(
      const std::string& name,
      const std::string& planName,
      const StringStringMap& inputDepMap,
      const DecisionService* const decisionService)
      : Operator(name, planName, inputDepMap, decisionService) {}

  virtual ~Frechet() override = default;

  // Operator interface: invoked with the request and the resolved named
  // inputs of this operator.
  virtual OperatorData run(
      const DecisionRequest& request,
      const StringOperatorDataMap& namedInputs) override;

  // Typed overload taking the value map, the two parameters `rho` and
  // `gamma`, and an explicit `seed`; returns a map of doubles keyed by string.
  virtual StringDoubleMap
  run(const StringDoubleMap& input, double rho, double gamma, int seed);
};
// Operator that exposes the request's own input to the rest of the plan:
// its run() (defined in InputFromRequest.cpp) returns `request.input` and
// ignores the named inputs.
class InputFromRequest : public Operator {
 public:
  // Forwards all arguments unchanged to the Operator base class.
  InputFromRequest(
      const std::string& name,
      const std::string& planName,
      const StringStringMap& inputDepMap,
      const DecisionService* const decisionService)
      : Operator(name, planName, inputDepMap, decisionService) {}

  virtual ~InputFromRequest() override = default;

  // Operator interface; see the class comment for the behavior.
  virtual OperatorData run(
      const DecisionRequest& request,
      const StringOperatorDataMap& namedInputs) override;
};
run(const StringDoubleMap& input); 24 | 25 | void giveFeedback( 26 | const Feedback& feedback, 27 | const StringOperatorDataMap& pastInputs, 28 | const OperatorData& pastOuptut) override; 29 | 30 | void giveFeedbackInternal( 31 | const Feedback& feedback, 32 | const StringOperatorDataMap& pastInputs, 33 | const StringDoubleMap& pastOuptut, 34 | const StringDoubleMap& targets); 35 | 36 | double getShift(const std::string& actionName); 37 | 38 | protected: 39 | inline std::string getParameterName(const std::string& configeratorPath) { 40 | return configeratorPath + std::string("/") + name_; 41 | } 42 | }; 43 | } // namespace reagent 44 | -------------------------------------------------------------------------------- /serving/reagent/serving/operators/Softmax.cpp: -------------------------------------------------------------------------------- 1 | #include "reagent/serving/operators/Softmax.h" 2 | 3 | #include "reagent/serving/core/OperatorFactory.h" 4 | 5 | namespace reagent { 6 | OperatorData Softmax::run( 7 | const DecisionRequest&, 8 | const StringOperatorDataMap& namedInputs) { 9 | const StringDoubleMap& input = 10 | std::get(namedInputs.at("values")); 11 | double temperature = std::get(namedInputs.at("temperature")); 12 | OperatorData ret = run(input, temperature); 13 | return ret; 14 | } 15 | 16 | StringDoubleMap Softmax::run(const StringDoubleMap& input, double temperature) { 17 | Eigen::ArrayXd v(input.size()); 18 | StringList names; 19 | for (auto& it : input) { 20 | v[names.size()] = it.second; 21 | names.emplace_back(it.first); 22 | } 23 | v -= v.maxCoeff(); 24 | v /= temperature; 25 | v = v.exp(); 26 | v /= v.sum(); 27 | StringDoubleMap retval; 28 | for (int a = 0; a < int(names.size()); a++) { 29 | retval[names[a]] = v[a]; 30 | } 31 | return retval; 32 | } 33 | 34 | REGISTER_OPERATOR(Softmax, "Softmax"); 35 | 36 | } // namespace reagent 37 | -------------------------------------------------------------------------------- 
// Operator that converts a map of scores into a softmax distribution.
// The implementation in Softmax.cpp subtracts the maximum score, divides by
// the temperature, exponentiates, and normalizes so the values sum to 1.
class Softmax : public Operator {
 public:
  // Forwards all arguments unchanged to the Operator base class.
  Softmax(
      const std::string& name,
      const std::string& planName,
      const StringStringMap& inputDepMap,
      const DecisionService* const decisionService)
      : Operator(name, planName, inputDepMap, decisionService) {}

  virtual ~Softmax() override = default;

  // Operator interface: reads the "values" and "temperature" named inputs
  // and delegates to the typed overload below.
  virtual OperatorData run(
      const DecisionRequest& request,
      const StringOperatorDataMap& namedInputs) override;

  // Typed overload: returns the softmax of `input` at the given temperature,
  // keyed by the same names as `input`.
  virtual StringDoubleMap run(const StringDoubleMap& input, double temperature);
};
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "reagent/serving/core/Headers.h" 4 | 5 | #include "reagent/serving/core/Operator.h" 6 | 7 | namespace reagent { 8 | class Ucb : public Operator { 9 | public: 10 | Ucb(const std::string& name, 11 | const std::string& planName, 12 | const StringStringMap& inputDepMap, 13 | const DecisionService* const decisionService) 14 | : Operator(name, planName, inputDepMap, decisionService) { 15 | int seed = std::chrono::system_clock::now().time_since_epoch().count(); 16 | generator_.seed(seed); 17 | } 18 | 19 | virtual ~Ucb() override = default; 20 | 21 | virtual OperatorData run( 22 | const DecisionRequest& request, 23 | const StringOperatorDataMap& namedInputs) override; 24 | 25 | RankedActionList runInternal( 26 | const DecisionRequest& request, 27 | const std::string& method); 28 | 29 | virtual void giveFeedback( 30 | const Feedback& feedback, 31 | const StringOperatorDataMap& pastInputs, 32 | const OperatorData& pastOuptut) override; 33 | 34 | double getArmExpectation(const std::string& armName); 35 | 36 | protected: 37 | std::mt19937 generator_; 38 | }; 39 | } // namespace reagent 40 | -------------------------------------------------------------------------------- /serving/reagent/serving/test/Expression_test.cpp: -------------------------------------------------------------------------------- 1 | #include "reagent/serving/test/TestHeaders.h" 2 | 3 | #include "reagent/serving/operators/Expression.h" 4 | 5 | namespace reagent { 6 | 7 | TEST(ExpressionTests, Simple) { 8 | auto service = makeTestDecisionService(); 9 | const auto PLAN_NAME = std::string("/"); 10 | StringOperatorDataMap namedInputs; 11 | 12 | namedInputs["equation"] = "(x*y)^z"; 13 | 14 | StringDoubleMap x = {{"e1", 2.0}, {"e2", 2.0}, {"e3", 2.0}}; 15 | namedInputs["x"] = x; 16 | StringDoubleMap y = {{"e1", 3.0}, {"e2", 3.0}, {"e3", 3.0}}; 17 | namedInputs["y"] = y; 18 | StringDoubleMap z = 
{{"e1", 1.0}, {"e2", 2.0}, {"e3", 3.0}}; 19 | namedInputs["z"] = z; 20 | 21 | Expression expression("expression", PLAN_NAME, {}, service.get()); 22 | 23 | StringDoubleMap expectedOutput = {{"e1", 6.0}, {"e2", 36.0}, {"e3", 216.0}}; 24 | EXPECT_SYMBOLTABLE_NEAR( 25 | std::get(expression.run(DecisionRequest(), namedInputs)), 26 | expectedOutput); 27 | } 28 | 29 | } // namespace reagent 30 | -------------------------------------------------------------------------------- /serving/reagent/serving/test/InputFromRequest_test.cpp: -------------------------------------------------------------------------------- 1 | #include "reagent/serving/test/TestHeaders.h" 2 | 3 | #include "reagent/serving/operators/InputFromRequest.h" 4 | 5 | namespace reagent { 6 | 7 | TEST(InputFromRequestTests, Simple) { 8 | auto service = makeTestDecisionService(); 9 | const auto PLAN_NAME = std::string("/"); 10 | const auto INPUT_DATA = 100; 11 | 12 | DecisionRequest request; 13 | OperatorData input = int64_t(INPUT_DATA); 14 | request.input = input; 15 | 16 | EXPECT_EQ( 17 | std::get( 18 | InputFromRequest("input_from_request", PLAN_NAME, {}, service.get()) 19 | .run(request, {})), 20 | INPUT_DATA); 21 | } 22 | 23 | } // namespace reagent 24 | -------------------------------------------------------------------------------- /serving/reagent/serving/test/SoftmaxRanker_test.cpp: -------------------------------------------------------------------------------- 1 | #include "reagent/serving/test/TestHeaders.h" 2 | 3 | #include "reagent/serving/operators/SoftmaxRanker.h" 4 | 5 | namespace reagent { 6 | 7 | TEST(SoftmaxRankerTests, SimpleSort) { 8 | auto service = makeTestDecisionService(); 9 | const auto PLAN_NAME = std::string("/"); 10 | StringOperatorDataMap namedInputs; 11 | 12 | StringDoubleMap values = {{"1", 1.0}, {"2", 1000.0}}; 13 | namedInputs["values"] = (values); 14 | namedInputs["temperature"] = (0.01); 15 | namedInputs["seed"] = (int64_t(1)); 16 | 17 | auto result = std::get( 18 | 
// Two equal scores at temperature 1 must produce a uniform distribution.
TEST(SoftmaxTests, OneUniform) {
  auto service = makeTestDecisionService();
  const auto PLAN_NAME = std::string("/");
  StringOperatorDataMap namedInputs;

  StringDoubleMap values = {{"1", 3.0}, {"2", 3.0}};
  namedInputs["values"] = (values);
  namedInputs["temperature"] = (1.0);

  StringDoubleMap expectedOutput = {{"1", 0.5}, {"2", 0.5}};
  // The operator returns its StringDoubleMap result wrapped in OperatorData;
  // unwrap that alternative before comparing.
  const OperatorData output =
      Softmax("softmax", PLAN_NAME, {}, service.get())
          .run(DecisionRequest(), namedInputs);
  EXPECT_SYMBOLTABLE_NEAR(std::get<StringDoubleMap>(output), expectedOutput);
}
namespace reagent 19 | -------------------------------------------------------------------------------- /serving/reagent/serving/test/TestHeaders.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "reagent/serving/core/DecisionService.h" 6 | #include "reagent/serving/core/Headers.h" 7 | #include "reagent/serving/core/Operator.h" 8 | 9 | namespace reagent { 10 | 11 | inline void EXPECT_SYMBOLTABLE_NEAR( 12 | const StringDoubleMap& st1, 13 | const StringDoubleMap& st2) { 14 | std::set keys1; 15 | std::transform( 16 | st1.begin(), st1.end(), std::inserter(keys1, keys1.end()), [](auto pair) { 17 | return pair.first; 18 | }); 19 | 20 | std::set keys2; 21 | std::transform( 22 | st2.begin(), st2.end(), std::inserter(keys2, keys2.end()), [](auto pair) { 23 | return pair.first; 24 | }); 25 | 26 | EXPECT_EQ(keys1, keys2); 27 | 28 | for (auto& it : st1) { 29 | EXPECT_NEAR(it.second, st2.find(it.first)->second, 1e-3); 30 | } 31 | } 32 | 33 | inline void EXPECT_RANKEDACTIONLIST_NEAR( 34 | const RankedActionList& st1, 35 | const RankedActionList& st2) { 36 | EXPECT_EQ(st1.size(), st2.size()); 37 | for (int a = 0; a < int(st1.size()); a++) { 38 | EXPECT_EQ(st1[a].name, st2[a].name); 39 | EXPECT_NEAR(st1[a].propensity, st2[a].propensity, 1e-3); 40 | } 41 | } 42 | 43 | std::shared_ptr makeTestDecisionService(); 44 | 45 | } // namespace reagent 46 | -------------------------------------------------------------------------------- /serving/requirements.txt: -------------------------------------------------------------------------------- 1 | python>=3.8 2 | -------------------------------------------------------------------------------- /serving/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ReAgent/9e707c093a00c10947fda883a5848daace53d76d/serving/scripts/__init__.py 
def readme():
    """Return the contents of ``README.md`` for use as the long description.

    The path is relative to the current working directory, which is the
    directory containing ``setup.py`` when the script is run in the usual way.
    The file is decoded as UTF-8 explicitly so the result does not depend on
    the locale of the machine running the build.
    """
    with open("README.md", encoding="utf-8") as f:
        return f.read()