├── .clang-format ├── .clang-tidy ├── .github ├── codeql │ └── codeql-config.yml └── workflows │ ├── asan.yml │ ├── build_docs.yml │ ├── build_nuget.yml │ ├── build_python_wheels.yml │ ├── build_rlclientlib.yml │ ├── build_vw_bp.yml │ ├── codeql-analysis.yml │ ├── daily_integration.yml │ ├── dotnet_nugets.yml │ ├── e2e_testing.yml │ ├── lint.yml │ ├── run_benchmarks.yml │ └── vcpkg_build.yml ├── .gitignore ├── .gitmodules ├── .scripts ├── linux │ ├── clang-format.sh │ └── run-clang-tidy.sh └── version_number.py ├── .travis.yml ├── CMakeLists.txt ├── CMakePresets.json ├── LICENSE.txt ├── README.md ├── benchmarks ├── CMakeLists.txt ├── README.md ├── benchmark_cb_v2.cc ├── benchmark_ccb.cc ├── benchmark_common.cc ├── benchmark_common.h ├── benchmark_init.cc ├── benchmark_main.cc └── models │ ├── cb_explore_adf_half.m │ ├── cb_explore_adf_large.m │ └── cb_explore_adf_small.m ├── bindings ├── cs │ ├── .config │ │ └── dotnet-tools.json │ ├── CMakeLists.txt │ ├── common │ │ └── codegen │ │ │ ├── InternalsVisibleToTest.tt │ │ │ └── TextTemplate.targets │ ├── rl.net.cli.test │ │ ├── CMakeLists.txt │ │ ├── CleanupContainer.cs │ │ ├── LoopUsageTest.cs │ │ ├── MockSender.cs │ │ ├── ReplayStepProviderTest.cs │ │ ├── SenderExtensibilityCBLoopTest.cs │ │ ├── SenderExtensibilityTest.cs │ │ ├── TempFileDisposable.cs │ │ ├── TestBase.cs │ │ ├── UnicodeTest.cs │ │ └── rl.net.cli.test.csproj │ ├── rl.net.cli │ │ ├── App.config │ │ ├── BasicUsageCommand.cs │ │ ├── CMakeLists.txt │ │ ├── CommandBase.cs │ │ ├── EntryPoints.cs │ │ ├── Helpers.cs │ │ ├── InternalsVisibleToTest.tt │ │ ├── PerfTestCommand.cs │ │ ├── PerfTestStepProvider.cs │ │ ├── Person.cs │ │ ├── Properties │ │ │ └── launchSettings.json │ │ ├── README.md │ │ ├── RLDriver.cs │ │ ├── RLSimulator.cs │ │ ├── ReplayCommand.cs │ │ ├── ReplayStepProvider.cs │ │ ├── RobotJoint.cs │ │ ├── RunSimulatorCommand.cs │ │ ├── Statistics.cs │ │ ├── StatisticsCalculator.cs │ │ ├── Stats.cs │ │ ├── StatsReplayCommand.cs │ │ └── 
rl.net.cli.csproj │ ├── rl.net.native │ │ ├── CMakeLists.txt │ │ ├── binding_sender.cc │ │ ├── binding_sender.h │ │ ├── binding_static_model.cc │ │ ├── binding_static_model.h │ │ ├── binding_tracer.cc │ │ ├── binding_tracer.h │ │ ├── packages.config │ │ ├── rl.net.api_status.cc │ │ ├── rl.net.api_status.h │ │ ├── rl.net.azure_factories.cc │ │ ├── rl.net.azure_factories.h │ │ ├── rl.net.buffer.cc │ │ ├── rl.net.buffer.h │ │ ├── rl.net.ca_loop.cc │ │ ├── rl.net.ca_loop.h │ │ ├── rl.net.cb_loop.cc │ │ ├── rl.net.cb_loop.h │ │ ├── rl.net.ccb_loop.cc │ │ ├── rl.net.ccb_loop.h │ │ ├── rl.net.config.cc │ │ ├── rl.net.config.h │ │ ├── rl.net.continuous_action_response.cc │ │ ├── rl.net.continuous_action_response.h │ │ ├── rl.net.decision_response.cc │ │ ├── rl.net.decision_response.h │ │ ├── rl.net.episode_state.cc │ │ ├── rl.net.episode_state.h │ │ ├── rl.net.factory_context.cc │ │ ├── rl.net.factory_context.h │ │ ├── rl.net.live_model.cc │ │ ├── rl.net.live_model.h │ │ ├── rl.net.loop_context.h │ │ ├── rl.net.multi_slot_response.cc │ │ ├── rl.net.multi_slot_response.h │ │ ├── rl.net.multi_slot_response_detailed.cc │ │ ├── rl.net.multi_slot_response_detailed.h │ │ ├── rl.net.native.cc │ │ ├── rl.net.native.h │ │ ├── rl.net.native.vcxproj │ │ ├── rl.net.ranking_response.cc │ │ ├── rl.net.ranking_response.h │ │ ├── rl.net.slates_loop.cc │ │ ├── rl.net.slates_loop.h │ │ ├── rl.net.slot_ranking.cc │ │ └── rl.net.slot_ranking.h │ └── rl.net │ │ ├── ActionFlags.cs │ │ ├── ApiStatus.cs │ │ ├── AsyncSender.cs │ │ ├── CALoop.cs │ │ ├── CBLoop.cs │ │ ├── CCBLoop.cs │ │ ├── CMakeLists.txt │ │ ├── Configuration.cs │ │ ├── ContinuousActionResponse.cs │ │ ├── DecisionResponse.cs │ │ ├── EpisodeState.cs │ │ ├── FactoryContext.cs │ │ ├── ILoop.cs │ │ ├── ISender.cs │ │ ├── InternalsVisibleToTest.tt │ │ ├── LiveModel.cs │ │ ├── LiveModelThreadSafe.cs │ │ ├── MultiSlotResponse.cs │ │ ├── MultiSlotResponseDetailed.cs │ │ ├── Native │ │ ├── ErrorCallback.cs │ │ ├── GCHandleLifetime.cs │ │ 
├── Global.cs │ │ ├── NativeImports.cs │ │ ├── NativeObject.cs │ │ ├── SenderAdapter.cs │ │ └── StringExtensions.cs │ │ ├── NativeCallbacks.cs │ │ ├── OAuthCredentialProvider.cs │ │ ├── RLException.cs │ │ ├── RLLibLogUtils.cs │ │ ├── RankingResponse.cs │ │ ├── SharedBuffer.cs │ │ ├── SlatesLoop.cs │ │ ├── SlotRanking.cs │ │ ├── TraceLogEventArgs.cs │ │ └── rl.net.csproj └── python │ ├── CMakeLists.txt │ ├── README.md │ ├── docs │ ├── Makefile │ ├── conf.py │ ├── constants.rst │ ├── index.rst │ ├── migration_guide.rst │ └── rl_client.rst │ ├── py_api.cc │ └── test │ ├── log_test_driver.py │ ├── unit_test.py │ └── verify_logs.py ├── cmake ├── DetectCXXStandard.cmake ├── Modules │ ├── FindDotnet.cmake │ ├── FindFlatbuffers.cmake │ ├── FindOnnxRuntime.cmake │ └── FlatbufferUtils.cmake └── platforms │ └── win32.cmake ├── custom-triplets ├── x64-windows-static-md-v141.cmake └── x64-windows-v141.cmake ├── doc ├── cpp │ ├── .gitignore │ ├── Doxyfile │ ├── api_config.dox │ ├── api_context_format.dox │ ├── api_error_codes.dox │ ├── build.dox │ ├── compression.dox │ ├── mainpage.dox │ └── rl-loop.GIF └── readme.md ├── examples ├── CMakeLists.txt ├── basic_usage_cpp │ ├── CMakeLists.txt │ ├── basic_usage_cpp.cc │ ├── basic_usage_cpp.h │ └── client.json ├── onnx │ ├── CMakeLists.txt │ ├── onnx_example.cc │ └── readme.md ├── override_interface │ ├── CMakeLists.txt │ ├── client.json │ └── override_interface.cc ├── python │ ├── basic_usage.py │ └── rl_sim.py ├── rl_sim_cpp │ ├── CMakeLists.txt │ ├── README.md │ ├── client.json │ ├── current │ ├── main.cc │ ├── person.cc │ ├── person.h │ ├── rand48.h │ ├── rl_sim.cc │ ├── rl_sim.h │ ├── rl_sim_cpp.h │ ├── robot_joint.cc │ ├── robot_joint.h │ ├── simulation_stats.h │ └── targetver.h └── test_cpp │ ├── CMakeLists.txt │ ├── experiment_controller.cc │ ├── experiment_controller.h │ ├── main.cc │ ├── model.vw │ ├── options.cc │ ├── options.h │ ├── scripts │ ├── perf_test.cmd │ └── perf_test.sh │ ├── test_data_provider.cc │ ├── 
test_data_provider.h │ ├── test_loop.cc │ └── test_loop.h ├── ext_libs ├── date │ ├── CMakeLists.txt │ └── date.h ├── ext_libs.cmake ├── fakeit │ ├── CMakeLists.txt │ └── fakeit │ │ └── fakeit.hpp └── string-view-lite │ ├── CMakeLists.txt │ └── nonstd │ └── string_view.hpp ├── external_parser ├── CMakeLists.txt ├── README.md ├── event_processors │ ├── joined_event.h │ ├── loop.h │ ├── metadata.h │ ├── reward.h │ ├── timestamp_helper.cc │ ├── timestamp_helper.h │ └── typed_events.h ├── joiners │ ├── example_joiner.cc │ ├── example_joiner.h │ ├── i_joiner.h │ ├── multistep_example_joiner.cc │ └── multistep_example_joiner.h ├── log_converter.cc ├── log_converter.h ├── lru_dedup_cache.cc ├── lru_dedup_cache.h ├── main.cc ├── metrics │ └── metrics.h ├── parse_example_binary.cc ├── parse_example_binary.h ├── parse_example_converter.cc ├── parse_example_converter.h ├── parse_example_external.cc ├── parse_example_external.h ├── unit_tests │ ├── CMakeLists.txt │ ├── main.cc │ ├── test_client_and_enqueued_time.cc │ ├── test_common.cc │ ├── test_common.h │ ├── test_example_joiner.cc │ ├── test_files │ │ ├── README.md │ │ ├── client_time │ │ │ ├── cb_v2_client_time.fb │ │ │ └── f-reward_3obs_v2_client_time.fb │ │ ├── fb_events │ │ │ ├── ca_v2.fb │ │ │ ├── ca_v2_size_2.fb │ │ │ ├── cb_v2.fb │ │ │ ├── cb_v2_dedup.fb │ │ │ ├── cb_v2_size_2.fb │ │ │ ├── cb_v2_size_5_apprentice.fb │ │ │ ├── ccb-baseline-loopinteractions_v2.fb │ │ │ ├── ccb-baseline-loopobservations_v2.fb │ │ │ ├── ccb_v2.fb │ │ │ ├── f-reward_v2.fb │ │ │ ├── f-reward_v2_size_2.fb │ │ │ ├── f-reward_v2_size_5_apprentice.fb │ │ │ ├── fi-reward_v2.fb │ │ │ └── invalid-cb_v2.fb │ │ ├── invalid_joined_logs │ │ │ ├── bad_event_in_joined_event.log │ │ │ ├── bad_magic.log │ │ │ ├── bad_version.log │ │ │ ├── corrupt_joined_payload.log │ │ │ ├── dedup_payload_missing.log │ │ │ ├── empty_msg_hdr.log │ │ │ ├── incomplete_checkpoint_info.log │ │ │ ├── interaction_with_no_observation.log │ │ │ ├── invalid_cb_context.log │ │ │ 
├── no_interaction_but_with_observation.log │ │ │ ├── no_msg_hdr.log │ │ │ └── one_invalid_msg_type.log │ │ ├── reward_functions │ │ │ ├── ca │ │ │ │ ├── ca_v2.fb │ │ │ │ └── f-reward_3obs_v2.fb │ │ │ ├── cb │ │ │ │ ├── cb_apprentice_match_baseline_v2.fb │ │ │ │ ├── cb_v2.fb │ │ │ │ └── f-reward_3obs_v2.fb │ │ │ ├── ccb │ │ │ │ ├── ccb-apprentice-baseline-match_v2.fb │ │ │ │ ├── ccb-apprentice-baseline-not-match_v2.fb │ │ │ │ ├── ccb-with-slot-id_v2.fb │ │ │ │ ├── ccb_v2.fb │ │ │ │ ├── fi-out-of-bound-reward_v2.fb │ │ │ │ ├── fi-reward_v2.fb │ │ │ │ ├── fmix-reward_v2.fb │ │ │ │ └── fs-reward_v2.fb │ │ │ └── slates │ │ │ │ ├── fi-reward_v2.fb │ │ │ │ └── slates_v2.fb │ │ ├── skip_learn │ │ │ ├── ca │ │ │ │ ├── deferred_action_with_activation.fb │ │ │ │ ├── deferred_action_with_activation_deduped.fb │ │ │ │ ├── deferred_action_without_activation.fb │ │ │ │ ├── deferred_action_without_activation_deduped.fb │ │ │ │ └── mixed_deferred_action_events.fb │ │ │ ├── cb │ │ │ │ ├── deferred_action_with_activation.fb │ │ │ │ ├── deferred_action_with_activation_deduped.fb │ │ │ │ ├── deferred_action_without_activation.fb │ │ │ │ ├── deferred_action_without_activation_deduped.fb │ │ │ │ └── mixed_deferred_action_events.fb │ │ │ ├── ccb │ │ │ │ ├── deferred_action_with_activation.fb │ │ │ │ ├── deferred_action_with_activation_deduped.fb │ │ │ │ ├── deferred_action_without_activation.fb │ │ │ │ ├── deferred_action_without_activation_deduped.fb │ │ │ │ └── mixed_deferred_action_events.fb │ │ │ └── slates │ │ │ │ ├── deferred_action_with_activation.fb │ │ │ │ ├── deferred_action_without_activation.fb │ │ │ │ └── mixed_deferred_action_events.fb │ │ ├── test_outputs │ │ │ └── .gitignore │ │ └── valid_joined_logs │ │ │ ├── average_reward_100_interactions.fb │ │ │ ├── average_reward_100_interactions.json │ │ │ ├── ca_loop_mixed_skip_learn.fb │ │ │ ├── ca_loop_mixed_skip_learn.json │ │ │ ├── ca_loop_simple.fb │ │ │ ├── ca_loop_simple.json │ │ │ ├── ca_loop_simple_e2e.log │ │ │ ├── 
ca_loop_skip_learn_e2e.log │ │ │ ├── ca_mixed_deferred_action_events_20.log │ │ │ ├── cb_apprentice_5.log │ │ │ ├── cb_dedup_compressed.log │ │ │ ├── cb_deferred_actions_w_activations_and_apprentice_10.fb │ │ │ ├── cb_deferred_actions_w_activations_and_apprentice_10.json │ │ │ ├── cb_joined_with_pdrop_05.fb │ │ │ ├── cb_joined_with_pdrop_1.fb │ │ │ ├── cb_simple.log │ │ │ ├── ccb_apprentice_5.log │ │ │ ├── ccb_deferred_actions_w_activations_and_apprentice_20.fb │ │ │ ├── ccb_deferred_actions_w_activations_and_apprentice_20.json │ │ │ ├── ccb_simple.log │ │ │ ├── ccb_sum_reward_100_interactions.fb │ │ │ ├── ccb_sum_reward_100_interactions.json │ │ │ ├── ccb_w_slot_id.log │ │ │ ├── ccb_w_various_outcomes.log │ │ │ ├── multistep_2_episodes.fb │ │ │ ├── multistep_3_deferred_episodes.fb │ │ │ ├── multistep_unordered_episodes.fb │ │ │ ├── rcrrmr.fb │ │ │ ├── rrcr.fb │ │ │ ├── slates_average_reward_100_interactions.fb │ │ │ ├── slates_average_reward_100_interactions.json │ │ │ ├── slates_deferred_actions_w_activations_10.fb │ │ │ ├── slates_deferred_actions_w_activations_10.json │ │ │ └── slates_simple.log │ ├── test_log_converter.cc │ ├── test_lru_dedup_cache.cc │ ├── test_metrics.cc │ ├── test_reward_functions.cc │ ├── test_skip_learn.cc │ ├── test_timestamp_helper.cc │ ├── test_vw_binary_parser.cc │ └── test_vw_external_parser.cc ├── utils.cc └── utils.h ├── include ├── action_flags.h ├── api_status.h ├── azure_credentials_provider.h ├── config_utility.h ├── configuration.h ├── constants.h ├── container_iterator.h ├── continuous_action_response.h ├── data_buffer.h ├── decision_response.h ├── err_constants.h ├── error_callback_fn.h ├── errors_data.h ├── factory_resolver.h ├── future_compat.h ├── internal_constants.h ├── learning_mode.h ├── live_model.h ├── loop_apis │ ├── README.md │ ├── base_loop.h │ ├── ca_loop.h │ ├── cb_loop.h │ ├── ccb_loop.h │ ├── multistep_loop.h │ └── slates_loop.h ├── model_mgmt.h ├── multi_slot_response.h ├── multi_slot_response_detailed.h ├── 
multistep.h ├── oauth_callback_fn.h ├── object_factory.h ├── personalization.h ├── ranking_response.h ├── rl_string_view.h ├── sender.h ├── slot_ranking.h ├── str_util.h └── trace_logger.h ├── nuget ├── CMakeLists.txt ├── CreateNugetPackage.cmake ├── dotnet │ ├── rl.net.nuspec │ ├── rl.net.props │ ├── rl.net.targets │ └── test │ │ ├── client.json │ │ ├── dotnetcore_nuget_test.csproj │ │ └── nuget_test.cs ├── nuget.exe ├── rlclientlib.nuspec.in ├── rlclientlib.targets.in └── test │ ├── main.cc │ └── test_rl_nuget.vcxproj ├── rlclientlib ├── CMakeLists.txt ├── api_status.cc ├── azure_factories.cc ├── azure_factories.h ├── base_loop.cc ├── ca_loop.cc ├── cb_loop.cc ├── ccb_loop.cc ├── console_tracer.cc ├── console_tracer.h ├── constants.cc ├── continuous_action_response.cc ├── decision_response.cc ├── dedup.cc ├── dedup.h ├── dedup_internals.h ├── error_callback_fn.cc ├── extensions │ ├── CMakeLists.txt │ └── onnx │ │ ├── CMakeLists.txt │ │ ├── include │ │ └── onnx_extension.h │ │ └── src │ │ ├── onnx_extension.cc │ │ ├── onnx_input.cc │ │ ├── onnx_input.h │ │ ├── onnx_model.cc │ │ ├── onnx_model.h │ │ ├── tensor_parser.cc │ │ └── tensor_parser.h ├── factory_resolver.cc ├── federation │ ├── federated_client.h │ └── joined_log_provider.h ├── generic_event.cc ├── generic_event.h ├── learning_mode.cc ├── live_model.cc ├── live_model_impl.cc ├── live_model_impl.h ├── logger │ ├── async_batcher.h │ ├── endian.cc │ ├── endian.h │ ├── event_logger.cc │ ├── event_logger.h │ ├── event_queue.h │ ├── file │ │ ├── file_logger.cc │ │ └── file_logger.h │ ├── flatbuffer_allocator.cc │ ├── flatbuffer_allocator.h │ ├── http_transport_client.h │ ├── logger_extensions.cc │ ├── logger_extensions.h │ ├── logger_facade.cc │ ├── logger_facade.h │ ├── message_sender.h │ ├── message_type.h │ ├── preamble.cc │ ├── preamble.h │ ├── preamble_sender.cc │ └── preamble_sender.h ├── model_mgmt │ ├── data_callback_fn.cc │ ├── data_callback_fn.h │ ├── empty_data_transport.cc │ ├── 
empty_data_transport.h │ ├── file_model_loader.cc │ ├── file_model_loader.h │ ├── model_downloader.cc │ ├── model_downloader.h │ ├── model_mgmt.cc │ ├── restapi_data_transport.cc │ ├── restapi_data_transport.h │ ├── restapi_data_transport_oauth.cc │ └── restapi_data_transport_oauth.h ├── moving_queue.h ├── multi_slot_response.cc ├── multi_slot_response_detailed.cc ├── multistep.cc ├── multistep_loop.cc ├── ranking_event.cc ├── ranking_event.h ├── ranking_response.cc ├── sampling.cc ├── sampling.h ├── schema │ ├── v1 │ │ ├── DecisionRankingEvent.fbs │ │ ├── Metadata.fbs │ │ ├── OutcomeEvent.fbs │ │ ├── RankingEvent.fbs │ │ └── SlatesEvent.fbs │ └── v2 │ │ ├── CaEvent.fbs │ │ ├── CbEvent.fbs │ │ ├── DedupInfo.fbs │ │ ├── Event.fbs │ │ ├── FileFormat.fbs │ │ ├── LearningModeType.fbs │ │ ├── Metadata.fbs │ │ ├── MultiSlotEvent.fbs │ │ ├── MultiStepEvent.fbs │ │ ├── OutcomeEvent.fbs │ │ └── ProblemType.fbs ├── serialization │ ├── fb_serializer.h │ ├── json_serializer.h │ ├── payload_serializer.cc │ └── payload_serializer.h ├── slates_loop.cc ├── slot_ranking.cc ├── time_helper.cc ├── time_helper.h ├── trace_logger.cc ├── utility │ ├── api_header_token.h │ ├── config_helper.cc │ ├── config_helper.h │ ├── config_utility.cc │ ├── configuration.cc │ ├── context_helper.cc │ ├── context_helper.h │ ├── data_buffer.cc │ ├── data_buffer_streambuf.cc │ ├── data_buffer_streambuf.h │ ├── eventhub_http_authorization.cc │ ├── eventhub_http_authorization.h │ ├── header_authorization.cc │ ├── header_authorization.h │ ├── http_client.cc │ ├── http_client.h │ ├── http_helper.cc │ ├── http_helper.h │ ├── interruptable_sleeper.h │ ├── object_pool.h │ ├── periodic_background_proc.h │ ├── stl_container_adapter.cc │ ├── stl_container_adapter.h │ ├── str_util.cc │ ├── versioned_object_pool.h │ ├── watchdog.cc │ └── watchdog.h └── vw_model │ ├── pdf_model.cc │ ├── pdf_model.h │ ├── safe_vw.cc │ ├── safe_vw.h │ ├── vw_model.cc │ └── vw_model.h ├── setup.py ├── templates ├── README.md └── 
create-loop.json ├── test_tools ├── e2e_testing │ ├── base_files │ │ ├── input │ │ │ └── multistep │ │ │ │ ├── episode.fbs │ │ │ │ ├── interaction.fbs │ │ │ │ └── observation.fbs │ │ └── output │ │ │ └── multistep │ │ │ └── vw_out.log │ ├── compare_serialized_examples.py │ ├── evaluate_result.py │ ├── multistep_client.json │ └── requirements.txt ├── example_gen │ ├── CMakeLists.txt │ └── example_gen.cc ├── joiner │ ├── CMakeLists.txt │ ├── main.cc │ ├── sample_data │ │ ├── interaction.fb.data │ │ └── observation.fb.data │ ├── text_converter.cc │ └── text_converter.h ├── log_parser │ ├── data.py │ ├── joiner.py │ ├── log_gen.py │ ├── parser.py │ ├── parsing_example.ipynb │ └── test_gen.py ├── onnx_pytorch │ ├── adapters │ │ └── pytorch.py │ ├── common │ │ ├── __init__.py │ │ ├── parser.py │ │ └── types.py │ └── pytorch_train.py ├── reproduce_model.py ├── sender_test │ ├── CMakeLists.txt │ ├── main.cc │ ├── test_loop.cc │ └── test_loop.h └── stdin2rllib │ └── main.cc ├── unit_test ├── CMakeLists.txt ├── async_batcher_test.cc ├── common_test_utils.h ├── configuration_test.cc ├── data.h ├── data_buffer_test.cc ├── data_callback_test.cc ├── dedup_test.cc ├── err_callback_test.cc ├── event_queue_test.cc ├── explore_test.cc ├── extensions │ ├── CMakeLists.txt │ └── onnx │ │ ├── CMakeLists.txt │ │ ├── global_fixture.h │ │ ├── main.cc │ │ ├── mnist_data │ │ ├── data_generator.py │ │ └── mnist_model.onnx │ │ ├── mnist_inference_test.cc │ │ ├── mock_helpers.cc │ │ ├── mock_helpers.h │ │ ├── tensor_notation_test.cc │ │ ├── test_data.h.in │ │ └── test_helpers.h ├── factory_test.cc ├── fb_serializer_test.cc ├── file_logger_test.cc ├── header_auth_test.cc ├── http_client_test.cc ├── http_transport_client_test.cc ├── interaction.txt ├── json_context_parse_test.cc ├── json_serializer_test.cc ├── learning_mode_test.cc ├── live_model_test.cc ├── live_model_test_legacy.cc ├── main.cc ├── mock_http_client.cc ├── mock_http_client.h ├── mock_util.cc ├── mock_util.h ├── model_mgmt_test.cc 
├── moving_queue_test.cc ├── multi_slot_response_detailed_test.cc ├── multistep_test.cc ├── object_pool_test.cc ├── observation.txt ├── outcome.json ├── payload_serializer_test.cc ├── preamble_test.cc ├── ranking_context.json ├── ranking_response_test.cc ├── safe_vw_test.cc ├── serializer.cc ├── sleeper_test.cc ├── slot_ranking_test.cc ├── status_builder_test.cc ├── str_util_test.cc ├── time_tests.cc ├── trace_logger_test.cc └── watchdog_test.cc └── vcpkg.json /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: Google 2 | AccessModifierOffset: -2 3 | AlignAfterOpenBracket: DontAlign 4 | AlignOperands: false 5 | AllowShortBlocksOnASingleLine: true 6 | AllowShortCaseLabelsOnASingleLine: false 7 | AllowShortFunctionsOnASingleLine: All 8 | AllowShortIfStatementsOnASingleLine: true 9 | AllowShortLoopsOnASingleLine: true 10 | BreakBeforeBraces: Allman 11 | BreakConstructorInitializersBeforeComma: true 12 | ColumnLimit: 120 13 | SortIncludes: true 14 | IndentPPDirectives: AfterHash 15 | PointerAlignment: Left 16 | DerivePointerAlignment: false 17 | IncludeCategories: 18 | # Boost unit test must be included first if it is used 19 | # or else there will be a compile error 20 | - Regex: '' 21 | Priority: 0 22 | # First block is local directory includes 23 | - Regex: '"[[:alnum:]._\/]+"' 24 | Priority: 1 25 | # Second block is system includes with .h suffix. Usually dependencies. 
26 | - Regex: '<[[:alnum:]._\/]+\.h>' 27 | Priority: 2 28 | # STL system deps 29 | - Regex: '<[[:alnum:]._\/]+>' 30 | Priority: 3 31 | # Catch all 32 | - Regex: '.*' 33 | Priority: 4 34 | -------------------------------------------------------------------------------- /.clang-tidy: -------------------------------------------------------------------------------- 1 | { 2 | "Checks": "-*,readability-*,modernize-*,performance-*,cppcoreguidelines-pro-type-member-init,cppcoreguidelines-init-variables,-modernize-use-trailing-return-type,-readability-uppercase-literal-suffix,-readability-container-data-pointer", 3 | "FormatStyle": "file", 4 | "WarningsAsErrors": "-*,performance-*,modernize-use-using,readability-braces-around-statements", 5 | "CheckOptions": [ 6 | { 7 | "key": "performance-move-const-arg.CheckTriviallyCopyableMove", 8 | "value": "0" 9 | }, 10 | { 11 | "key":"readability-identifier-length.IgnoredVariableNames", 12 | "value": "t" 13 | } 14 | ] 15 | } 16 | -------------------------------------------------------------------------------- /.github/codeql/codeql-config.yml: -------------------------------------------------------------------------------- 1 | name: "CodeQL config" 2 | 3 | # Specify the paths to ignore 4 | paths-ignore: 5 | - 'bindings/python/docs/' 6 | - 'ext_libs/' 7 | 8 | # Specify the minimum disk space required 9 | min-disk-free: 2048 # Adjust as necessary. This value is in MB. 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.o 2 | *.a 3 | *.out 4 | 5 | .DS_Store 6 | 7 | # build folders 8 | **/x64 9 | **/obj 10 | **/Debug 11 | **/Release 12 | **/bin 13 | 14 | # VS files 15 | *.vcxproj.user 16 | *.csproj.user 17 | *.VC.db 18 | .vs 19 | # VS CMake configuration 20 | CMakeSettings.json 21 | 22 | # Ignore NuGet Packages 23 | *.nupkg 24 | # Ignore the packages folder 25 | **/packages 26 | # Ignore nuget_staging directory used for building the nuget package 27 | nuget_staging/ 28 | 29 | build 30 | ext_deps 31 | /rlclientlib/generated 32 | /external_parser/generated 33 | /.vscode 34 | /bindings/cs/rl.net/InternalsVisibleToTest.cs 35 | /bindings/cs/rl.net.cli/InternalsVisibleToTest.cs 36 | /test_tools/log_parser/reinforcement_learning/ 37 | _build 38 | .idea 39 | dist/ 40 | rl_client.egg-info/ 41 | 42 | .cache -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "ext_libs/vowpal_wabbit"] 2 | path = ext_libs/vowpal_wabbit 3 | url = https://github.com/VowpalWabbit/vowpal_wabbit.git 4 | branch = master 5 | [submodule "ext_libs/zstd"] 6 | path = ext_libs/zstd 7 | url = https://github.com/facebook/zstd.git 8 | branch = master 9 | [submodule "ext_libs/pybind11"] 10 | path = ext_libs/pybind11 11 | url = ../../pybind/pybind11 12 | branch = stable 13 | [submodule "ext_libs/cpprestsdk"] 14 | path = ext_libs/cpprestsdk 15 | url = https://github.com/microsoft/cpprestsdk.git 16 | branch = master 17 | [submodule "ext_libs/openssl"] 18 | path = ext_libs/openssl 19 | url = https://github.com/openssl/openssl.git 20 | [submodule "ext_libs/vcpkg"] 21 | path = ext_libs/vcpkg 22 | url = https://github.com/microsoft/vcpkg.git 23 | 
-------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: required 2 | os: linux 3 | services: 4 | - docker 5 | 6 | git: 7 | submodules: true 8 | 9 | before_install: 10 | - docker pull vowpalwabbit/travis-base 11 | script: 12 | - docker run -a STDOUT -v `pwd`:/reinforcement_learning -t vowpalwabbit/travis-base /bin/bash -c "cd /reinforcement_learning && chmod +x ./build-linux.sh && ./build-linux.sh" 13 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2012 - present Microsoft Corporation 4 | 5 | All rights reserved. 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 
-------------------------------------------------------------------------------- /benchmarks/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(all_sources 2 | benchmark_cb_v2.cc 3 | benchmark_ccb.cc 4 | benchmark_common.cc 5 | benchmark_init.cc 6 | benchmark_main.cc 7 | ) 8 | 9 | add_executable(rl_benchmarks 10 | ${all_sources} 11 | ) 12 | 13 | find_package(benchmark REQUIRED) 14 | 15 | # Add the include directories from the rlclientlib target so the benchmarks can use them 16 | target_include_directories(rl_benchmarks PRIVATE $) 17 | target_link_libraries(rl_benchmarks PRIVATE rlclientlib benchmark::benchmark) 18 | 19 | # Communicate that dependencies are being statically linked (RL_STATIC_DEPS) 20 | if(RL_STATIC_DEPS) 21 | target_compile_definitions(rl_benchmarks PRIVATE RL_STATIC_DEPS) 22 | endif() 23 | 24 | add_test(rl_benchmarks rl_benchmarks) -------------------------------------------------------------------------------- /benchmarks/README.md: -------------------------------------------------------------------------------- 1 | Install Google Benchmark: https://github.com/google/benchmark 2 | 3 | By default, Google Benchmark is built in Debug mode, so you may want to specify Release mode when building/installing it. 4 | 5 | build rl benchmarks: 6 | 7 | ``` 8 | cmake -S .
-B build -DRL_BUILD_BENCHMARKS=ON 9 | cmake --build build --target rl_benchmarks -j $(nproc) 10 | ``` 11 | 12 | run rl benchmarks: 13 | 14 | ``` 15 | ./benchmarks/rl_benchmarks 16 | ``` -------------------------------------------------------------------------------- /benchmarks/benchmark_common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | 6 | class prng // pseudo-random number generator seeded with initial_seed; used by the decision generators below 7 | { 8 | uint64_t val; 9 | 10 | public: 11 | prng(uint64_t initial_seed); 12 | uint64_t next_uint(); 13 | }; 14 | 15 | class cb_decision_gen // builds CB benchmark example strings via gen_example() 16 | { 17 | int shared_features, action_features, actions_per_decision; 18 | std::vector actions_set; 19 | prng rand; 20 | bool passthrough; 21 | 22 | public: 23 | cb_decision_gen(int shared_features, int action_features, int actions_per_decision, int total_actions, 24 | int initial_seed, bool passthrough); 25 | 26 | std::string gen_example(); 27 | }; 28 | 29 | class ccb_decision_gen // builds CCB benchmark example strings via gen_example() 30 | { 31 | int shared_features_size; // size of set of possible features to choose from 32 | int shared_features_count; // actual number of features per example 33 | int action_features_size; 34 | int action_features_count; 35 | int actions_per_example; 36 | int slots_per_example; 37 | 38 | std::vector actions_set; 39 | prng rand; 40 | 41 | public: 42 | ccb_decision_gen(int shared_features_size, int shared_features_count, int action_features_size, 43 | int action_features_count, int actions_per_example, int slots_per_example, int total_actions, int initial_seed); 44 | 45 | std::string gen_example(); 46 | }; 47 | -------------------------------------------------------------------------------- /benchmarks/benchmark_main.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | BENCHMARK_MAIN(); -------------------------------------------------------------------------------- /benchmarks/models/cb_explore_adf_half.m:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/benchmarks/models/cb_explore_adf_half.m -------------------------------------------------------------------------------- /benchmarks/models/cb_explore_adf_large.m: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/benchmarks/models/cb_explore_adf_large.m -------------------------------------------------------------------------------- /benchmarks/models/cb_explore_adf_small.m: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/benchmarks/models/cb_explore_adf_small.m -------------------------------------------------------------------------------- /bindings/cs/.config/dotnet-tools.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": 1, 3 | "isRoot": true, 4 | "tools": { 5 | "dotnet-t4": { 6 | "version": "2.0.5", 7 | "commands": [ 8 | "t4" 9 | ] 10 | } 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /bindings/cs/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/Modules/") 2 | include(FindDotnet) 3 | 4 | # note: this change was made since building with Ninja does not add suffixes 5 | # but, using the VS generator does. rl.net uses dllimport to load rlnetnative. 6 | # This is a workaround to make sure the correct dll is used. 
7 | if (WIN32 AND CMAKE_GENERATOR MATCHES "Visual Studio") 8 | set(CMAKE_DEBUG_POSTFIX "") 9 | endif() 10 | 11 | add_subdirectory(rl.net.native) 12 | add_subdirectory(rl.net) 13 | add_subdirectory(rl.net.cli) 14 | add_subdirectory(rl.net.cli.test) 15 | -------------------------------------------------------------------------------- /bindings/cs/common/codegen/InternalsVisibleToTest.tt: -------------------------------------------------------------------------------- 1 | <#@ template language="C#" hostspecific="true" #> 2 | <#@ assembly name="System.Core" #> 3 | <#@ import namespace="System.Linq" #> 4 | <#@ import namespace="System.Text" #> 5 | <#@ import namespace="System" #> 6 | <#@ import namespace="System.IO" #> 7 | <#@ import namespace="System.Collections.Generic" #> 8 | <#@ output extension=".cs" #> 9 | using System.Reflection; 10 | using System.Runtime.CompilerServices; 11 | using System.Runtime.InteropServices; 12 | 13 | <# 14 | try{ 15 | string snRequired = this.Host.ResolveParameterValue("","","SNRequired"); 16 | string publicKey = this.Host.ResolveParameterValue("","","PublicKey"); 17 | 18 | if(ParseBool(snRequired,false) && !string.IsNullOrWhiteSpace(publicKey)){ 19 | #>[assembly: InternalsVisibleTo("rl.net.cli.test, PublicKey=<#=publicKey#>")]<# 20 | } 21 | else{ 22 | #>[assembly: InternalsVisibleTo("rl.net.cli.test")]<# 23 | } 24 | } 25 | catch(Exception) { 26 | #>[assembly: InternalsVisibleTo( "rl.net.cli.test" )]<# 27 | } 28 | #> 29 | <#+ 30 | public bool ParseBool(string boolval, bool defval) // returns defval when boolval is blank or not a valid bool 31 | { 32 | if(string.IsNullOrWhiteSpace(boolval)) 33 | return defval; 34 | bool retval; 35 | if(!bool.TryParse(boolval,out retval)) return defval; // TryParse writes false on failure, so fall back to defval explicitly 36 | return retval; 37 | } 38 | #> -------------------------------------------------------------------------------- /bindings/cs/rl.net.cli.test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set (RL_NET_CLI_TEST_SOURCES 2 | CleanupContainer.cs 3 | LoopUsageTest.cs 4 |
MockSender.cs 5 | ReplayStepProviderTest.cs 6 | SenderExtensibilityTest.cs 7 | SenderExtensibilityCBLoopTest.cs 8 | TempFileDisposable.cs 9 | TestBase.cs 10 | UnicodeTest.cs 11 | ) 12 | 13 | if (rlclientlib_DOTNET_USE_MSPROJECT) 14 | include_external_msproject(rl.net.cli.test ${CMAKE_CURRENT_SOURCE_DIR}/rl.net.cli.test.csproj rl.net.cli) 15 | else() 16 | # No need to add the other two targets to ALL, because the rl.net.cli.test target will chain-build them 17 | add_custom_target(rl.net.cli.test ALL 18 | COMMAND ${DOTNET_COMMAND} build ${CMAKE_CURRENT_SOURCE_DIR} -o $ -v m --nologo --no-dependencies /clp:NoSummary 19 | COMMENT Building rl.net.cli.test 20 | DEPENDS rl.net.cli 21 | SOURCES ${RL_NET_CLI_TEST_SOURCES}) 22 | endif() 23 | 24 | # TODO: Enable TRX test logging 25 | add_test( 26 | NAME rl.net.cli.test 27 | COMMAND 28 | ${DOTNET_COMMAND} test $/rl.net.cli.test.dll --Platform:x64 --InIsolation "--logger:console;verbosity=default" 29 | ) -------------------------------------------------------------------------------- /bindings/cs/rl.net.cli.test/CleanupContainer.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Threading; 4 | 5 | namespace Rl.Net.Cli.Test 6 | { 7 | public sealed class CleanupContainer : IDisposable // collects cleanup actions and runs them in LIFO order on Dispose 8 | { 9 | private Stack cleanupStack = new Stack(); 10 | 11 | public void Dispose() 12 | { 13 | Stack localCleanupStack = Interlocked.Exchange(ref cleanupStack, new Stack()); // swap in an empty stack so the container can be reused after Dispose 14 | while (localCleanupStack.TryPop(out Action action)) 15 | { 16 | try 17 | { 18 | action?.Invoke(); 19 | } 20 | catch 21 | { 22 | // Suppress errors on TestCleanup. If we want to detect state errors on 23 | // cleanup, it should be an explicit part of a test.
24 | } 25 | } 26 | } 27 | 28 | public void Add(IDisposable disposable) // registers disposable.Dispose as a cleanup action 29 | { 30 | this.cleanupStack.Push(disposable.Dispose); 31 | } 32 | 33 | public void Add(Action action) 34 | { 35 | this.cleanupStack.Push(action); 36 | } 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /bindings/cs/rl.net.cli.test/TempFileDisposable.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.IO; 3 | 4 | namespace Rl.Net.Cli.Test 5 | { 6 | internal sealed class TempFileDisposable : IDisposable // creates a temp file via Path.GetTempFileName and deletes it (best effort) on Dispose 7 | { 8 | public TempFileDisposable() 9 | { 10 | this.Path = System.IO.Path.GetTempFileName(); 11 | } 12 | 13 | public string Path 14 | { 15 | get; 16 | private set; 17 | } 18 | 19 | public void Dispose() 20 | { 21 | try 22 | { 23 | if (File.Exists(this.Path)) 24 | { 25 | File.Delete(this.Path); 26 | } 27 | 28 | if (Directory.Exists(this.Path)) // presumably covers tests that replace the file with a directory - confirm 29 | { 30 | Directory.Delete(this.Path, recursive: true); 31 | } 32 | } 33 | catch 34 | { 35 | // TestCleanup is best-efforts 36 | } 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /bindings/cs/rl.net.cli.test/TestBase.cs: -------------------------------------------------------------------------------- 1 | using Microsoft.VisualStudio.TestTools.UnitTesting; 2 | using Newtonsoft.Json; 3 | using Newtonsoft.Json.Linq; 4 | using Rl.Net; 5 | 6 | namespace Rl.Net.Cli.Test 7 | { 8 | public abstract class TestBase // gives each test a CleanupContainer that the [TestCleanup] method disposes 9 | { 10 | protected CleanupContainer TestCleanup 11 | { 12 | get; 13 | private set; 14 | } = new CleanupContainer(); 15 | 16 | [TestCleanup] 17 | public void CleanupTest() 18 | { 19 | this.TestCleanup.Dispose(); 20 | } 21 | } 22 | } -------------------------------------------------------------------------------- /bindings/cs/rl.net.cli/App.config: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | 6 |
-------------------------------------------------------------------------------- /bindings/cs/rl.net.cli/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set (RL_NET_CLI_SOURCES 2 | BasicUsageCommand.cs 3 | CommandBase.cs 4 | EntryPoints.cs 5 | Helpers.cs 6 | InternalsVisibleToTest.tt 7 | PerfTestCommand.cs 8 | PerfTestStepProvider.cs 9 | Person.cs 10 | ReplayCommand.cs 11 | ReplayStepProvider.cs 12 | RLDriver.cs 13 | RLSimulator.cs 14 | RobotJoint.cs 15 | RunSimulatorCommand.cs 16 | Statistics.cs 17 | StatisticsCalculator.cs 18 | ) 19 | 20 | if (rlclientlib_DOTNET_USE_MSPROJECT) 21 | include_external_msproject(rl.net.cli ${CMAKE_CURRENT_SOURCE_DIR}/rl.net.cli.csproj rl.net) 22 | else() 23 | add_custom_target(rl.net.cli 24 | COMMAND ${DOTNET_COMMAND} build ${CMAKE_CURRENT_SOURCE_DIR} -o $ -v m --nologo --no-dependencies /clp:NoSummary 25 | COMMENT Building rl.net.cli 26 | DEPENDS rl.net 27 | SOURCES ${RL_NET_CLI_SOURCES}) 28 | endif() -------------------------------------------------------------------------------- /bindings/cs/rl.net.cli/EntryPoints.cs: -------------------------------------------------------------------------------- 1 | using CommandLine; 2 | using System; 3 | using System.Collections.Generic; 4 | using System.IO; 5 | 6 | namespace Rl.Net.Cli 7 | { 8 | static class EntryPoints 9 | { 10 | public static void Main(string[] args) 11 | { 12 | Parser.Default.ParseArguments 13 | (args) 14 | .WithParsed(command => command.Run()); 15 | } 16 | 17 | public static IEnumerable LazyReadLines(this TextReader textReader) 18 | { 19 | string line; 20 | while ((line = textReader.ReadLine()) != null) 21 | { 22 | if (string.Empty == line.Trim()) 23 | { 24 | continue; 25 | } 26 | 27 | yield return line; 28 | } 29 | } 30 | } 31 | } -------------------------------------------------------------------------------- /bindings/cs/rl.net.cli/InternalsVisibleToTest.tt: 
-------------------------------------------------------------------------------- 1 | <#@include file="..\common\codegen\InternalsVisibleToTest.tt" #> -------------------------------------------------------------------------------- /bindings/cs/rl.net.cli/Properties/launchSettings.json: -------------------------------------------------------------------------------- 1 | { 2 | "profiles": { 3 | "rl.net.cli": { 4 | "commandName": "Project", 5 | "nativeDebugging": true 6 | } 7 | } 8 | } -------------------------------------------------------------------------------- /bindings/cs/rl.net.cli/ReplayCommand.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.IO; 4 | using CommandLine; 5 | 6 | namespace Rl.Net.Cli 7 | { 8 | [Verb("replay", HelpText = "Replay existing log")] 9 | class ReplayCommand : CommandBase 10 | { 11 | [Option(longName: "log", HelpText = "path to the log file to replay", Required = true)] 12 | public string LogPath { get; set; } 13 | 14 | [Option(longName: "sleep", HelpText = "sleep interval in milliseconds", Required = false, Default = 100)] 15 | public int SleepIntervalMs { get; set; } 16 | 17 | public override void Run() 18 | { 19 | LiveModel liveModel = Helpers.CreateLiveModelOrExit(this.ConfigPath); 20 | RLDriver rlDriver = new RLDriver(liveModel, loopKind: this.GetLoopKind()); 21 | rlDriver.StepInterval = TimeSpan.FromMilliseconds(this.SleepIntervalMs); 22 | 23 | using (TextReader textReader = File.OpenText(this.LogPath)) 24 | { 25 | IEnumerable dsJsonLines = textReader.LazyReadLines(); 26 | ReplayStepProvider stepProvider = new ReplayStepProvider(dsJsonLines); 27 | 28 | rlDriver.Run(stepProvider); 29 | } 30 | } 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /bindings/cs/rl.net.cli/RunSimulatorCommand.cs: -------------------------------------------------------------------------------- 1 | 
using System; 2 | using CommandLine; 3 | 4 | namespace Rl.Net.Cli 5 | { 6 | [Verb("simulator", HelpText = "Run simulator")] 7 | class RunSimulatorCommand : CommandBase 8 | { 9 | [Option(longName: "sleep", HelpText = "sleep interval in milliseconds", Required = false, Default = 1000)] 10 | public int SleepIntervalMs { get; set; } 11 | 12 | [Option(longName: "steps", HelpText = "Amount of steps", Required = false, Default = SimulatorStepProvider.InfinitySteps)] 13 | public int Steps { get; set; } 14 | 15 | public override void Run() 16 | { 17 | LiveModel liveModel = Helpers.CreateLiveModelOrExit(this.ConfigPath); 18 | 19 | RLSimulator rlSim = new RLSimulator(liveModel, loopKind: this.GetLoopKind()); 20 | rlSim.StepInterval = TimeSpan.FromMilliseconds(this.SleepIntervalMs); 21 | rlSim.OnError += (sender, apiStatus) => Helpers.WriteStatusAndExit(apiStatus); 22 | rlSim.Run(this.Steps); 23 | } 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /bindings/cs/rl.net.native/binding_sender.cc: -------------------------------------------------------------------------------- 1 | #include "binding_sender.h" 2 | 3 | #include 4 | 5 | using namespace reinforcement_learning; 6 | 7 | namespace rl_net_native 8 | { 9 | int binding_sender::init(const reinforcement_learning::utility::configuration& config, api_status* status) 10 | { 11 | return this->vtable.init(managed_handle, status); 12 | } 13 | 14 | int binding_sender::v_send(const buffer& data, api_status* status) 15 | { 16 | size_t length = data->buffer_filled_size(); 17 | if (length > INT32_MAX) 18 | { 19 | RETURN_ERROR_LS(trace_logger, status, background_queue_overflow) 20 | << "ISender only supports chunks of up to " << INT32_MAX << " in size."; 21 | } 22 | 23 | return this->vtable.send(managed_handle, &data, status); 24 | } 25 | 26 | binding_sender::~binding_sender() { this->vtable.release(managed_handle); } 27 | } // namespace rl_net_native 28 | 
-------------------------------------------------------------------------------- /bindings/cs/rl.net.native/binding_static_model.cc: -------------------------------------------------------------------------------- 1 | #include "binding_static_model.h" 2 | 3 | using namespace rl_net_native; 4 | using namespace reinforcement_learning; 5 | 6 | binding_static_model::binding_static_model(const char* vw_model, const size_t len) : vw_model(vw_model), len(len) {} 7 | 8 | int binding_static_model::get_data(model_transport::model_data& data, reinforcement_learning::api_status* status) 9 | { 10 | return data.set_data(vw_model, len); 11 | } 12 | -------------------------------------------------------------------------------- /bindings/cs/rl.net.native/binding_static_model.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "err_constants.h" 4 | #include "model_mgmt.h" 5 | 6 | #include 7 | 8 | namespace model_transport = reinforcement_learning::model_management; 9 | 10 | namespace rl_net_native 11 | { 12 | 13 | namespace constants 14 | { 15 | const char* const BINDING_DATA_TRANSPORT = "BINDING_DATA_TRANSPORT"; 16 | } 17 | class binding_static_model : public model_transport::i_data_transport 18 | { 19 | public: 20 | binding_static_model(const char* vw_model, const size_t len); 21 | int get_data(model_transport::model_data& data, reinforcement_learning::api_status* status = nullptr) override; 22 | 23 | private: 24 | const char* vw_model; 25 | const size_t len; 26 | }; 27 | } // namespace rl_net_native -------------------------------------------------------------------------------- /bindings/cs/rl.net.native/binding_tracer.cc: -------------------------------------------------------------------------------- 1 | #include "binding_tracer.h" 2 | 3 | namespace rl_net_native 4 | { 5 | binding_tracer::binding_tracer(loop_context& _context) : context(_context) {} 6 | 7 | void binding_tracer::log(int log_level, const 
std::string& msg) 8 | { 9 | if (log_level < this->log_level) { return; } 10 | if (context.trace_logger_callback != nullptr) { context.trace_logger_callback(log_level, msg.c_str()); } 11 | } 12 | 13 | void binding_tracer::set_level(int log_level) { this->log_level = log_level; } 14 | 15 | } // namespace rl_net_native 16 | -------------------------------------------------------------------------------- /bindings/cs/rl.net.native/binding_tracer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "rl.net.loop_context.h" 3 | #include "trace_logger.h" 4 | 5 | namespace rl_net_native 6 | { 7 | class binding_tracer : public reinforcement_learning::i_trace 8 | { 9 | public: 10 | // Inherited via i_trace 11 | binding_tracer(loop_context& _context); 12 | void log(int log_level, const std::string& msg) override; 13 | void set_level(int log_level) override; 14 | 15 | private: 16 | loop_context& context; 17 | int log_level = 0; 18 | }; 19 | } // namespace rl_net_native 20 | -------------------------------------------------------------------------------- /bindings/cs/rl.net.native/packages.config: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /bindings/cs/rl.net.native/rl.net.api_status.cc: -------------------------------------------------------------------------------- 1 | #include "rl.net.api_status.h" 2 | 3 | API reinforcement_learning::api_status* CreateApiStatus() { return new reinforcement_learning::api_status(); } 4 | 5 | API void DeleteApiStatus(reinforcement_learning::api_status* status) { delete status; } 6 | 7 | API const char* GetApiStatusErrorMessage(reinforcement_learning::api_status* status) { return status->get_error_msg(); } 8 | 9 | API int GetApiStatusErrorCode(reinforcement_learning::api_status* status) { return status->get_error_code(); } 10 | 11 | API void 
UpdateApiStatusSafe(reinforcement_learning::api_status* status, int error_code, const char* message) 12 | { 13 | // api_status takes a copy of the message string coming in, since it has no way to enforce that its callers 14 | // do not deallocate the buffer after calling try_update. 15 | reinforcement_learning::api_status::try_update(status, error_code, message); 16 | } 17 | 18 | API void ClearApiStatusSafe(reinforcement_learning::api_status* status) 19 | { 20 | reinforcement_learning::api_status::try_clear(status); 21 | } 22 | -------------------------------------------------------------------------------- /bindings/cs/rl.net.native/rl.net.api_status.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "rl.net.native.h" 4 | 5 | // Global exports 6 | extern "C" 7 | { 8 | // NOTE: THIS IS NOT POLYMORPHISM SAFE! 9 | API reinforcement_learning::api_status* CreateApiStatus(); 10 | API void DeleteApiStatus(reinforcement_learning::api_status* status); 11 | 12 | // TODO: We should think about how to avoid extra string copies; ideally, err constants 13 | // should be able to be shared between native/managed, but not clear if this is possible 14 | // right now. 
15 | API const char* GetApiStatusErrorMessage(reinforcement_learning::api_status* status); 16 | API int GetApiStatusErrorCode(reinforcement_learning::api_status* status); 17 | 18 | API void UpdateApiStatusSafe(reinforcement_learning::api_status* status, int error_code, const char* message); 19 | API void ClearApiStatusSafe(reinforcement_learning::api_status* status); 20 | } 21 | -------------------------------------------------------------------------------- /bindings/cs/rl.net.native/rl.net.buffer.cc: -------------------------------------------------------------------------------- 1 | #include "rl.net.buffer.h" 2 | 3 | using namespace rl_net_native; 4 | 5 | API const buffer* CloneBufferSharedPointer(const buffer* original) { return new buffer(*original); } 6 | 7 | API void ReleaseBufferSharedPointer(const buffer* buffer) { delete buffer; } 8 | 9 | API const unsigned char* GetSharedBufferBegin(const buffer* buffer) { return (*buffer)->preamble_begin(); } 10 | 11 | API const size_t GetSharedBufferLength(const buffer* buffer) { return (*buffer)->buffer_filled_size(); } 12 | -------------------------------------------------------------------------------- /bindings/cs/rl.net.native/rl.net.buffer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "api_status.h" 4 | #include "rl.net.native.h" 5 | #include "sender.h" 6 | 7 | namespace rl_net_native 8 | { 9 | using buffer = std::shared_ptr; 10 | using error_context = reinforcement_learning::error_callback_fn; 11 | } // namespace rl_net_native 12 | 13 | extern "C" 14 | { 15 | API void ReleaseBufferSharedPointer(const rl_net_native::buffer* buffer); 16 | 17 | API const rl_net_native::buffer* CloneBufferSharedPointer(const rl_net_native::buffer* original); 18 | 19 | API const unsigned char* GetSharedBufferBegin(const rl_net_native::buffer* buffer); 20 | API const size_t GetSharedBufferLength(const rl_net_native::buffer* buffer); 21 | } 22 | 
-------------------------------------------------------------------------------- /bindings/cs/rl.net.native/rl.net.config.h: -------------------------------------------------------------------------------- 1 | #include "rl.net.native.h" 2 | 3 | #pragma once 4 | 5 | // Exports 6 | extern "C" 7 | { 8 | // NOTE: THIS IS NOT POLYMORPHISM SAFE! 9 | API reinforcement_learning::utility::configuration* CreateConfig(); 10 | API void DeleteConfig(reinforcement_learning::utility::configuration* config); 11 | 12 | API int LoadConfigurationFromJson(const int json_length, const char* json_value, 13 | reinforcement_learning::utility::configuration* config, reinforcement_learning::api_status* status = nullptr); 14 | 15 | API void ConfigurationSet( 16 | reinforcement_learning::utility::configuration* config, const char* name, const char* value); 17 | API const char* ConfigurationGet( 18 | reinforcement_learning::utility::configuration* config, const char* name, const char* defVal); 19 | 20 | // Enumerate? 
21 | } 22 | -------------------------------------------------------------------------------- /bindings/cs/rl.net.native/rl.net.continuous_action_response.cc: -------------------------------------------------------------------------------- 1 | #include "rl.net.continuous_action_response.h" 2 | 3 | API reinforcement_learning::continuous_action_response* CreateContinuousActionResponse() 4 | { 5 | return new reinforcement_learning::continuous_action_response(); 6 | } 7 | 8 | API void DeleteContinuousActionResponse(reinforcement_learning::continuous_action_response* response) 9 | { 10 | delete response; 11 | } 12 | 13 | API const char* GetContinuousActionEventId(reinforcement_learning::continuous_action_response* response) 14 | { 15 | return response->get_event_id(); 16 | } 17 | 18 | API const char* GetContinuousActionModelId(reinforcement_learning::continuous_action_response* response) 19 | { 20 | return response->get_model_id(); 21 | } 22 | 23 | API float GetContinuousActionChosenAction(reinforcement_learning::continuous_action_response* response) 24 | { 25 | return response->get_chosen_action(); 26 | } 27 | 28 | API float GetContinuousActionChosenActionPdfValue(reinforcement_learning::continuous_action_response* response) 29 | { 30 | return response->get_chosen_action_pdf_value(); 31 | } -------------------------------------------------------------------------------- /bindings/cs/rl.net.native/rl.net.continuous_action_response.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "rl.net.native.h" 4 | 5 | // Global exports 6 | extern "C" 7 | { 8 | // NOTE: THIS IS NOT POLYMORPHISM SAFE! 
9 | API reinforcement_learning::continuous_action_response* CreateContinuousActionResponse(); 10 | API void DeleteContinuousActionResponse(reinforcement_learning::continuous_action_response* response); 11 | 12 | // TODO: We should think about how to avoid extra string copies; ideally, err constants 13 | // should be able to be shared between native/managed, but not clear if this is possible 14 | // right now. 15 | API const char* GetContinuousActionEventId(reinforcement_learning::continuous_action_response* response); 16 | API const char* GetContinuousActionModelId(reinforcement_learning::continuous_action_response* response); 17 | 18 | API float GetContinuousActionChosenAction(reinforcement_learning::continuous_action_response* response); 19 | API float GetContinuousActionChosenActionPdfValue(reinforcement_learning::continuous_action_response* response); 20 | } -------------------------------------------------------------------------------- /bindings/cs/rl.net.native/rl.net.episode_state.cc: -------------------------------------------------------------------------------- 1 | #include "rl.net.episode_state.h" 2 | 3 | API reinforcement_learning::episode_state* CreateEpisodeState(const char* episodeId) 4 | { 5 | return new reinforcement_learning::episode_state(episodeId); 6 | } 7 | 8 | API void DeleteEpisodeState(reinforcement_learning::episode_state* episode_state) { delete episode_state; } 9 | 10 | API const char* GetEpisodeId(reinforcement_learning::episode_state* episode_state) 11 | { 12 | return episode_state->get_episode_id(); 13 | } 14 | 15 | API int UpdateEpisodeHistory(reinforcement_learning::episode_state* episode_state, const char* event_id, 16 | const char* previous_event_id, const char* context, const reinforcement_learning::ranking_response& resp, 17 | reinforcement_learning::api_status* error) 18 | { 19 | return episode_state->update(event_id, previous_event_id, context, resp, error); 20 | } 
-------------------------------------------------------------------------------- /bindings/cs/rl.net.native/rl.net.episode_state.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "rl.net.native.h" 4 | 5 | #include 6 | #include 7 | 8 | // Global exports 9 | extern "C" 10 | { 11 | // NOTE: THIS IS NOT POLYMORPHISM SAFE! 12 | API reinforcement_learning::episode_state* CreateEpisodeState(const char* episodeId); 13 | API void DeleteEpisodeState(reinforcement_learning::episode_state* episode_state); 14 | 15 | API const char* GetEpisodeId(reinforcement_learning::episode_state* episode_state); 16 | API int UpdateEpisodeHistory(reinforcement_learning::episode_state* episode_state, const char* event_id, 17 | const char* previous_event_id, const char* context, const reinforcement_learning::ranking_response& resp, 18 | reinforcement_learning::api_status* error = nullptr); 19 | } -------------------------------------------------------------------------------- /bindings/cs/rl.net.native/rl.net.factory_context.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "binding_sender.h" 4 | #include "binding_static_model.h" 5 | #include "factory_resolver.h" 6 | #include "model_mgmt.h" 7 | #include "rl.net.native.h" 8 | 9 | #include 10 | 11 | typedef struct factory_context 12 | { 13 | reinforcement_learning::trace_logger_factory_t* trace_logger_factory; 14 | reinforcement_learning::time_provider_factory_t* time_provider_factory; 15 | reinforcement_learning::sender_factory_t* sender_factory; 16 | reinforcement_learning::data_transport_factory_t* data_transport_factory; 17 | reinforcement_learning::model_factory_t* model_factory; 18 | } factory_context_t; 19 | 20 | extern "C" 21 | { 22 | API factory_context_t* CreateFactoryContext(); 23 | API factory_context_t* CreateFactoryContextWithStaticModel(const char* vw_model, const size_t len); 24 | API void 
DeleteFactoryContext(factory_context_t* context); 25 | API void SetFactoryContextBindingSenderFactory( 26 | factory_context_t* context, rl_net_native::sender_create_fn create_fn, rl_net_native::sender_vtable_t vtable); 27 | } 28 | -------------------------------------------------------------------------------- /bindings/cs/rl.net.native/rl.net.loop_context.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "rl.net.factory_context.h" 4 | 5 | namespace rl_net_native 6 | { 7 | namespace constants 8 | { 9 | const char* const BINDING_TRACE_LOGGER = "BINDING_TRACE_LOGGER"; 10 | } 11 | 12 | typedef void (*trace_logger_callback_t)(int log_level, const char* msg); 13 | } // namespace rl_net_native 14 | 15 | typedef struct loop_context 16 | { 17 | // callback funtion to user when there is background error. 18 | rl_net_native::background_error_callback_t background_error_callback; 19 | // callback funtion to user for trace log. 20 | rl_net_native::trace_logger_callback_t trace_logger_callback; 21 | // A trace log factory instance holder of one loop instance for binding calls. 22 | reinforcement_learning::trace_logger_factory_t* trace_logger_factory; 23 | } loop_context_t; -------------------------------------------------------------------------------- /bindings/cs/rl.net.native/rl.net.multi_slot_response.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "rl.net.native.h" 4 | 5 | // TODO: Make the underlying iterator more ammenable to P/Invoke projection 6 | class multi_slot_enumerator_adapter; 7 | 8 | // Global exports 9 | extern "C" 10 | { 11 | // NOTE: THIS IS NOT POLYMORPHISM SAFE! 
12 | API int GetSlotEntryActionId(reinforcement_learning::slot_entry* slot); 13 | API float GetSlotEntryProbability(reinforcement_learning::slot_entry* slot); 14 | 15 | API reinforcement_learning::multi_slot_response* CreateMultiSlotResponse(); 16 | API void DeleteMultiSlotResponse(reinforcement_learning::multi_slot_response* multi_slot); 17 | 18 | API size_t GetMultiSlotSize(reinforcement_learning::multi_slot_response* multi_slot); 19 | 20 | API const char* GetMultiSlotModelId(reinforcement_learning::multi_slot_response* multi_slot); 21 | API const char* GetMultiSlotEventId(reinforcement_learning::multi_slot_response* multi_slot); 22 | 23 | API multi_slot_enumerator_adapter* CreateMultiSlotEnumeratorAdapter( 24 | reinforcement_learning::multi_slot_response* multi_slot); 25 | API void DeleteMultiSlotEnumeratorAdapter(multi_slot_enumerator_adapter* adapter); 26 | 27 | API int MultiSlotEnumeratorInit(multi_slot_enumerator_adapter* adapter); 28 | API int MultiSlotEnumeratorMoveNext(multi_slot_enumerator_adapter* adapter); 29 | API reinforcement_learning::slot_entry const* GetMultiSlotEnumeratorCurrent(multi_slot_enumerator_adapter* adapter); 30 | } -------------------------------------------------------------------------------- /bindings/cs/rl.net.native/rl.net.native.cc: -------------------------------------------------------------------------------- 1 | #include "rl.net.native.h" 2 | 3 | using namespace reinforcement_learning::error_code; 4 | 5 | API const char* LookupMessageForErrorCode(int error_code) 6 | { 7 | #define ERROR_CODE_DEFINITION(code, name, msg) \ 8 | case code: \ 9 | return name##_s; 10 | 11 | switch (error_code) 12 | { 13 | case 0: 14 | return "Success."; 15 | default: 16 | return unknown_s; 17 | 18 | #include "errors_data.h" 19 | } 20 | 21 | #undef ERROR_CODE_DEFINITION 22 | } 23 | -------------------------------------------------------------------------------- /bindings/cs/rl.net.native/rl.net.native.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ca_loop.h" 4 | #include "cb_loop.h" 5 | #include "ccb_loop.h" 6 | #include "config_utility.h" 7 | #include "live_model.h" 8 | #include "slates_loop.h" 9 | 10 | #include 11 | #include 12 | 13 | #if defined(_MSC_VER) 14 | // Microsoft 15 | # define API __declspec(dllexport) 16 | #elif defined(__GNUC__) 17 | // GCC 18 | # define API __attribute__((visibility("default"))) 19 | #else 20 | // do nothing and hope for the best? 21 | # define API 22 | # pragma warning Unknown dynamic link import / export semantics. 23 | #endif 24 | 25 | namespace rl_net_native 26 | { 27 | typedef void (*background_error_callback_t)(const reinforcement_learning::api_status&); 28 | } 29 | 30 | extern "C" 31 | { 32 | API const char* LookupMessageForErrorCode(int code); 33 | } 34 | -------------------------------------------------------------------------------- /bindings/cs/rl.net.native/rl.net.slot_ranking.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "rl.net.native.h" 4 | 5 | class slot_enumerator_adapter; 6 | 7 | // Global Exports 8 | extern "C" 9 | { 10 | // NOTE: THIS IS NOT POLYMORPHISM SAFE 11 | API reinforcement_learning::slot_ranking* CreateSlotRanking(); 12 | API void DeleteSlotRanking(reinforcement_learning::slot_ranking* slot); 13 | 14 | // TODO: We should think about how to avoid extra string copies; ideally, err constants 15 | // should be able to be shared between native/managed, but not clear if this is possible 16 | // right now. 
17 | API const char* GetSlotId(reinforcement_learning::slot_ranking* slot); 18 | 19 | API size_t GetSlotActionCount(reinforcement_learning::slot_ranking* slot); 20 | 21 | API int GetSlotChosenAction(reinforcement_learning::slot_ranking* slot, size_t* action_id, 22 | reinforcement_learning::api_status* status = nullptr); 23 | 24 | API slot_enumerator_adapter* CreateSlotEnumeratorAdapter(reinforcement_learning::slot_ranking* slot); 25 | API void DeleteSlotEnumeratorAdapter(slot_enumerator_adapter* adapter); 26 | 27 | API int SlotEnumeratorInit(slot_enumerator_adapter* adapter); 28 | API int SlotEnumeratorMoveNext(slot_enumerator_adapter* adapter); 29 | API reinforcement_learning::action_prob_d GetSlotEnumeratorCurrent(slot_enumerator_adapter* adapter); 30 | } 31 | -------------------------------------------------------------------------------- /bindings/cs/rl.net/ActionFlags.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | 3 | namespace Rl.Net 4 | { 5 | [Flags] 6 | public enum ActionFlags : uint 7 | { 8 | Default = 0, 9 | Deferred = 1, 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /bindings/cs/rl.net/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(RL_NET_SOURCES 2 | Native/ErrorCallback.cs 3 | Native/GCHandleLifetime.cs 4 | Native/Global.cs 5 | Native/NativeImports.cs 6 | Native/NativeObject.cs 7 | Native/SenderAdapter.cs 8 | Native/StringExtensions.cs 9 | 10 | ActionFlags.cs 11 | ApiStatus.cs 12 | AsyncSender.cs 13 | OAuthCredentialProvider.cs 14 | CALoop.cs 15 | CBLoop.cs 16 | CCBLoop.cs 17 | Configuration.cs 18 | ContinuousActionResponse.cs 19 | DecisionResponse.cs 20 | FactoryContext.cs 21 | InternalsVisibleToTest.tt 22 | ILoop.cs 23 | ISender.cs 24 | LiveModel.cs 25 | LiveModelThreadSafe.cs 26 | MultiSlotResponse.cs 27 | MultiSlotResponseDetailed.cs 28 | NativeCallbacks.cs 29 | RankingResponse.cs 
30 | RLException.cs 31 | RLLibLogUtils.cs 32 | SharedBuffer.cs 33 | SlatesLoop.cs 34 | SlotRanking.cs 35 | TraceLogEventArgs.cs 36 | ) 37 | 38 | if (rlclientlib_DOTNET_USE_MSPROJECT) 39 | include_external_msproject(rl.net ${CMAKE_CURRENT_SOURCE_DIR}/rl.net.csproj rlnetnative) 40 | else() 41 | add_custom_target(rl.net 42 | COMMAND ${DOTNET_COMMAND} build ${CMAKE_CURRENT_SOURCE_DIR} -o $ -v n --nologo --no-dependencies /clp:NoSummary --configuration "$<$:Debug>$<$:Release>$<$:Release>" 43 | COMMENT Building rl.net 44 | DEPENDS rlnetnative 45 | SOURCES ${RL_NET_SOURCES}) 46 | endif() -------------------------------------------------------------------------------- /bindings/cs/rl.net/ILoop.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | 3 | namespace Rl.Net 4 | { 5 | public interface ILoop 6 | { 7 | bool TryInit(ApiStatus apiStatus = null); 8 | void Init(); 9 | bool TryRefreshModel(ApiStatus apiStatus = null); 10 | void RefreshModel(); 11 | event EventHandler BackgroundError; 12 | event EventHandler TraceLoggerEvent; 13 | } 14 | } -------------------------------------------------------------------------------- /bindings/cs/rl.net/ISender.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Runtime.InteropServices; 3 | 4 | using Rl.Net.Native; 5 | 6 | namespace Rl.Net 7 | { 8 | public interface ISender 9 | { 10 | void Init(ApiStatus status); 11 | 12 | void Send(SharedBuffer buffer, ApiStatus status); 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /bindings/cs/rl.net/InternalsVisibleToTest.tt: -------------------------------------------------------------------------------- 1 | <#@include file="..\common\codegen\InternalsVisibleToTest.tt" #> 2 | -------------------------------------------------------------------------------- /bindings/cs/rl.net/Native/ErrorCallback.cs: 
-------------------------------------------------------------------------------- 1 | using System; 2 | using System.Threading; 3 | using System.Runtime.InteropServices; 4 | using System.Text; 5 | 6 | namespace Rl.Net.Native 7 | { 8 | internal delegate void error_fn(IntPtr error_context, IntPtr status); 9 | 10 | public class ErrorCallback 11 | { 12 | private IntPtr error_context; 13 | private error_fn callback; 14 | 15 | internal ErrorCallback(error_fn callback, IntPtr error_context) 16 | { 17 | this.callback = callback; 18 | this.error_context = error_context; 19 | } 20 | 21 | public void Invoke(ApiStatus status) 22 | { 23 | if (status != null) 24 | { 25 | this.callback(this.error_context, status.DangerousGetHandle()); 26 | GC.KeepAlive(status); 27 | } 28 | } 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /bindings/cs/rl.net/Native/Global.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Threading; 3 | using System.Runtime.InteropServices; 4 | using System.Text; 5 | 6 | namespace Rl.Net.Native { 7 | internal static partial class NativeMethods 8 | { 9 | public static readonly Encoding StringEncoding = Encoding.UTF8; 10 | public static readonly Func StringMarshallingFunc = StringExtensions.PtrToStringUtf8; 11 | 12 | public const int SuccessStatus = 0; // See err_constants.h 13 | public const int OpaqueBindingError = 39; // See err_contants.h 14 | 15 | public static IntPtr ToNativeHandleOrNullptrDangerous(this NativeObject nativeObject) where TObject : NativeObject 16 | { 17 | if (nativeObject == null) 18 | { 19 | return IntPtr.Zero; 20 | } 21 | 22 | return nativeObject.DangerousGetHandle(); 23 | } 24 | 25 | [DllImport(NativeImports.RLNETNATIVE)] 26 | public static extern IntPtr LookupMessageForErrorCode(int error_code); 27 | 28 | public static string MarshalMessageForErrorCode(int error_code) 29 | { 30 | IntPtr nativeMessage = 
LookupMessageForErrorCode(error_code); 31 | return StringMarshallingFunc(nativeMessage); 32 | } 33 | 34 | public static readonly string OpaqueBindingErrorMessage = MarshalMessageForErrorCode(OpaqueBindingError); 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /bindings/cs/rl.net/Native/NativeImports.cs: -------------------------------------------------------------------------------- 1 | namespace Rl.Net.Native { 2 | internal static class NativeImports { 3 | // NOTE: RLNETNATIVE for debug and release are the same, 4 | // but this is a placeholder for future changes. 5 | #if DEBUG 6 | internal const string RLNETNATIVE = "rlnetnative"; 7 | #else 8 | internal const string RLNETNATIVE = "rlnetnative"; 9 | #endif 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /bindings/cs/rl.net/NativeCallbacks.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | 4 | namespace Rl.Net 5 | { 6 | namespace Native 7 | { 8 | internal static partial class NativeMethods 9 | { 10 | public delegate void managed_background_error_callback_t(IntPtr apiStatus); 11 | public delegate void managed_trace_callback_t(int logLevel, IntPtr msgUtf8Ptr); 12 | public delegate int managed_oauth_callback_t(IntPtr scopes, IntPtr tokenOutPtr, IntPtr unixTimestamp); 13 | public delegate void managed_oauth_callback_t_complete_t(IntPtr tokenStringToFree, int errorCode); 14 | } 15 | } 16 | } -------------------------------------------------------------------------------- /bindings/cs/rl.net/RLLibLogUtils.cs: -------------------------------------------------------------------------------- 1 | namespace Rl.Net 2 | { 3 | // These values are defined in https://github.com/VowpalWabbit/reinforcement_learning/blob/master/include/trace_logger.h 4 | public enum RLLogLevel 5 | { 6 | LEVEL_DEBUG = -10, 7 | LEVEL_INFO = 0, 8 | 
LEVEL_WARNING = 10, 9 | LEVEL_ERROR = 20, 10 | } 11 | 12 | } 13 | -------------------------------------------------------------------------------- /bindings/cs/rl.net/TraceLogEventArgs.cs: -------------------------------------------------------------------------------- 1 | namespace Rl.Net 2 | { 3 | public class TraceLogEventArgs 4 | { 5 | public TraceLogEventArgs(RLLogLevel logLevel, string msg) 6 | { 7 | LogLevel = logLevel; 8 | Message = msg; 9 | } 10 | 11 | public RLLogLevel LogLevel { get; } 12 | public string Message { get; } 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /bindings/python/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | pybind11_add_module(rl_client py_api.cc) 2 | target_link_libraries(rl_client PRIVATE rlclientlib) 3 | -------------------------------------------------------------------------------- /bindings/python/README.md: -------------------------------------------------------------------------------- 1 | # Python Bindings 2 | 3 | ## Build + Install 4 | 5 | Commands are relative to repo root. 6 | ```bash 7 | python setup.py install 8 | 9 | # Or, if vcpkg used for deps 10 | python setup.py --cmake-options="-DCMAKE_TOOLCHAIN_FILE=/path_to_vcpkg_root/scripts/buildsystems/vcpkg.cmake" install 11 | ``` 12 | 13 | - For Ubuntu 20.04, Python 3.8 a recommended vcpkg version is: `Release 2020.06, commit 6185aa7` 14 | 15 | ## Usage 16 | 17 | After successful installation, an example is in [`examples/python/basic_usage.py`](../../examples/python/basic_usage.py). 18 | -------------------------------------------------------------------------------- /bindings/python/docs/constants.rst: -------------------------------------------------------------------------------- 1 | rl_client.constants Reference 2 | ============================= 3 | 4 | .. 
autoclass:: rl_client.constants 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /bindings/python/docs/index.rst: -------------------------------------------------------------------------------- 1 | rl_client Documentation 2 | ======================= 3 | 4 | Contents: 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | 9 | rl_client 10 | constants 11 | migration_guide 12 | 13 | -------------------------------------------------------------------------------- /bindings/python/docs/rl_client.rst: -------------------------------------------------------------------------------- 1 | rl_client Reference 2 | =================== 3 | 4 | .. automodule:: rl_client 5 | :members: 6 | :undoc-members: 7 | :exclude-members: constants 8 | -------------------------------------------------------------------------------- /cmake/DetectCXXStandard.cmake: -------------------------------------------------------------------------------- 1 | include(CheckCXXCompilerFlag) 2 | 3 | function(DetectCXXStandard OUTPUT_VAR) 4 | if(NOT MSVC) 5 | set(CXX_STANDARD20_FLAG "-std=c++20") 6 | set(CXX_STANDARD17_FLAG "-std=c++17") 7 | set(CXX_STANDARD14_FLAG "-std=c++14") 8 | set(CXX_STANDARD11_FLAG "-std=c++11") 9 | elseif(MSVC) 10 | set(CXX_STANDARD20_FLAG "/std:c++20") 11 | set(CXX_STANDARD17_FLAG "/std:c++17") 12 | set(CXX_STANDARD14_FLAG "/std:c++14") 13 | endif() 14 | 15 | # TODO use VERSION_GREATER_EQUAL in cmake 3.7+ 16 | check_cxx_compiler_flag(${CXX_STANDARD20_FLAG} HAS_CXX20_FLAG) 17 | if (HAS_CXX20_FLAG AND ((${CMAKE_VERSION} VERSION_EQUAL "3.12.0") OR (${CMAKE_VERSION} VERSION_GREATER "3.12.0"))) 18 | set(${OUTPUT_VAR} 20 PARENT_SCOPE) 19 | return() 20 | endif() 21 | 22 | check_cxx_compiler_flag(${CXX_STANDARD17_FLAG} HAS_CXX17_FLAG) 23 | if (HAS_CXX17_FLAG AND ((${CMAKE_VERSION} VERSION_EQUAL "3.8.0") OR (${CMAKE_VERSION} VERSION_GREATER "3.8.0"))) 24 | set(${OUTPUT_VAR} 17 PARENT_SCOPE) 25 | return() 26 | endif() 27 | 28 | 
check_cxx_compiler_flag(${CXX_STANDARD14_FLAG} HAS_CXX14_FLAG) 29 | if (HAS_CXX14_FLAG) 30 | set(${OUTPUT_VAR} 14 PARENT_SCOPE) 31 | return() 32 | endif() 33 | 34 | set(${OUTPUT_VAR} 11 PARENT_SCOPE) 35 | endfunction() 36 | -------------------------------------------------------------------------------- /cmake/Modules/FindDotnet.cmake: -------------------------------------------------------------------------------- 1 | find_program(DOTNET_COMMAND "dotnet" REQUIRED) 2 | 3 | if(WIN32) 4 | find_program(DOTNET_T4_COMMAND "t4" REQUIRED) 5 | endif() 6 | -------------------------------------------------------------------------------- /custom-triplets/x64-windows-static-md-v141.cmake: -------------------------------------------------------------------------------- 1 | set(VCPKG_TARGET_ARCHITECTURE x64) 2 | set(VCPKG_CRT_LINKAGE dynamic) 3 | set(VCPKG_LIBRARY_LINKAGE static) 4 | set(VCPKG_PLATFORM_TOOLSET v141) -------------------------------------------------------------------------------- /custom-triplets/x64-windows-v141.cmake: -------------------------------------------------------------------------------- 1 | set(VCPKG_TARGET_ARCHITECTURE x64) 2 | set(VCPKG_CRT_LINKAGE dynamic) 3 | set(VCPKG_LIBRARY_LINKAGE dynamic) 4 | set(VCPKG_PLATFORM_TOOLSET v141) -------------------------------------------------------------------------------- /doc/cpp/.gitignore: -------------------------------------------------------------------------------- 1 | html -------------------------------------------------------------------------------- /doc/cpp/Doxyfile: -------------------------------------------------------------------------------- 1 | # Doxyfile 1.8.13 2 | 3 | PROJECT_NAME = "Reinforcement Learning" 4 | PROJECT_NUMBER = 1.1 5 | 6 | # Input files 7 | INPUT = mainpage.dox build.dox api_config.dox api_error_codes.dox api_context_format.dox compression.dox 8 | INPUT += ../../include/live_model.h ../../include/ranking_response.h ../../include/api_status.h ../../include/configuration.h 
../../include/err_constants.h 9 | EXAMPLE_PATH = ../../examples/basic_usage_cpp ../../examples/rl_sim_cpp ../../include/err_constants.h ../../include/errors_data.h 10 | EXAMPLE_PATH += ../../examples/override_interface 11 | 12 | DOXYFILE_ENCODING = UTF-8 13 | 14 | # Don't generate index so we can use our custom one for navigating between projects 15 | DISABLE_INDEX = YES 16 | GENERATE_TREEVIEW = YES 17 | 18 | GENERATE_LATEX = NO 19 | 20 | TAB_SIZE = 2 21 | 22 | # Suppress build warnings when Doxygen tries to use Graphviz 23 | HAVE_DOT = NO 24 | -------------------------------------------------------------------------------- /doc/cpp/api_error_codes.dox: -------------------------------------------------------------------------------- 1 | /*! \page api_error_codes API Error Codes 2 | 3 | All RL Inference API functions return error codes. 4 | 5 | The following is the list of all error definitions: 6 | \snippet include/errors_data.h Error Definitions 7 | 8 | Each error definition is surfaced in C++ by the following: 9 | \snippet include/err_constants.h Error Generator 10 | 11 | */ 12 | -------------------------------------------------------------------------------- /doc/cpp/build.dox: -------------------------------------------------------------------------------- 1 | /*!
\page Build 2 | Building RL Inference Library 3 | ------------------------------ 4 | 5 | Linux: 6 | ------ 7 | - Prerequisites: 8 | + g++ 4.9 or higher 9 | + cpprestsdk (https://github.com/Microsoft/cpprestsdk) 10 | + boost 11 | - Make targets: 12 | + rlclientlib (API) 13 | + rltest (Unit tests) 14 | + _rl_python (Python Bindings) 15 | 16 | Windows: 17 | -------- 18 | - Prerequisites: 19 | + cpprestsdk (https://github.com/Microsoft/cpprestsdk) 20 | + boost 21 | - Visual Studio Solution: 22 | + vowpalwabbit/vw.sln 23 | - Visual Studio Projects 24 | + reinforcement_learning/rlclientlib 25 | + reinforcement_learning/unit_test 26 | + reinforcement_learning/examples/basic_usage_cpp 27 | + reinforcement_learning/examples/rl_sim_cpp 28 | */ 29 | -------------------------------------------------------------------------------- /doc/cpp/compression.dox: -------------------------------------------------------------------------------- 1 | /*! \page compression Payload compression 2 | 3 | RLlib has a way to efficiently compress specific payloads. 4 | 5 | In some cases it's desirable to trade client-side CPU for a reduced payload size if the backend messaging system is a bottleneck. 6 | We support the following compression mechanisms: 7 | 8 | - Per-decision compression using zstandard (https://zstd.net). 9 | - Per-batch dictionary-based action deduplication for CB, CCB and Slates problems using the DSJSON input format. 10 | 11 | When both are enabled, it's possible to get upwards of a 20x compression ratio for specific workloads. 12 | 13 | 14 | Zstandard compression 15 | --------------------- 16 | 17 | We use zstandard as it was empirically found to provide the best tradeoff between compression ratio and CPU time. 18 | RLlib uses the default compression level 1. 19 | 20 | Dictionary-based action deduplication 21 | ------------------------------------ 22 | 23 | In some workloads, there's a small set of actions used across all decisions.
In those cases, this scheme will exploit this source 24 | of redundancy and produce significantly smaller payloads. Care should be taken, as it will cause a minor size expansion in the case 25 | that no action is ever seen more than once - quite possible in ADF settings where features are computed at runtime. 26 | 27 | 28 | To enable it, set the following configuration key: 29 | 30 | ~~~~~ 31 | cfg::configuration cc; 32 | cc.set("XXXXX", "YYYYY"); 33 | ~~~~~ 34 | */ -------------------------------------------------------------------------------- /doc/cpp/rl-loop.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/doc/cpp/rl-loop.GIF -------------------------------------------------------------------------------- /doc/readme.md: -------------------------------------------------------------------------------- 1 | # Generate C++ Documentation 2 | 3 | ```sh 4 | cd cpp 5 | doxygen 6 | # Docs will be generated to html/ 7 | ``` -------------------------------------------------------------------------------- /examples/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(basic_usage_cpp) 2 | add_subdirectory(override_interface) 3 | add_subdirectory(rl_sim_cpp) 4 | add_subdirectory(test_cpp) 5 | 6 | if (rlclientlib_BUILD_ONNXRUNTIME_EXTENSION) 7 | add_subdirectory(onnx) 8 | endif() 9 | -------------------------------------------------------------------------------- /examples/basic_usage_cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_executable(basic_usage_cpp.out 2 | basic_usage_cpp.cc 3 | ) 4 | 5 | target_link_libraries(basic_usage_cpp.out PRIVATE rlclientlib) 6 | 7 | if(RL_LINK_AZURE_LIBS) 8 | target_compile_definitions(basic_usage_cpp.out PRIVATE LINK_AZURE_LIBS) 9 | find_package(azure-identity-cpp CONFIG
REQUIRED) 10 | target_link_libraries(basic_usage_cpp.out PRIVATE Azure::azure-identity) 11 | endif() 12 | -------------------------------------------------------------------------------- /examples/basic_usage_cpp/basic_usage_cpp.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @brief Simple RL Inference API sample * 3 | * 4 | * @file basic_usage_cpp.h 5 | * @author Rajan Chari et al 6 | * @date 2018-07-18 7 | */ 8 | #pragma once 9 | 10 | #include "ca_loop.h" 11 | #include "cb_loop.h" 12 | #include "ccb_loop.h" 13 | #include "config_utility.h" 14 | #include "multistep_loop.h" 15 | #include "slates_loop.h" 16 | 17 | #include 18 | #include 19 | 20 | // Namespace manipulation for brevity 21 | namespace r = reinforcement_learning; 22 | namespace u = r::utility; 23 | namespace cfg = u::config; 24 | namespace err = r::error_code; 25 | 26 | int basic_usage_cb(); 27 | int basic_usage_ca(); 28 | int basic_usage_ccb(); 29 | int basic_usage_slates(); 30 | int basic_usage_multistep(); 31 | 32 | int load_file(const std::string& file_name, std::string& config_str); 33 | int load_config_from_json(const std::string& file_name, u::configuration& config); 34 | -------------------------------------------------------------------------------- /examples/basic_usage_cpp/client.json: -------------------------------------------------------------------------------- 1 | { 2 | "ApplicationID": "", 3 | "EventHubInteractionConnectionString": "Endpoint=sb://.servicebus.windows.net/;SharedAccessKeyName=RootManageSharedAccessKey;SharedAccessKey=;EntityPath=interaction", 4 | "EventHubObservationConnectionString": "Endpoint=sb://.servicebus.windows.net/;SharedAccessKeyName=RootManageSharedAccessKey;SharedAccessKey=;EntityPath=observation", 5 | "IsExplorationEnabled": true, 6 | "ModelBlobUri": "https://.blob.core.windows.net/mwt-models/current?sv=2017-07-29&sr=b&sig=&st=2018-06-26T09%3A00%3A55Z&se=2028-06-26T09%3A01%3A55Z&sp=r", 7 | 
"InitialExplorationEpsilon": 1.0 8 | } 9 | -------------------------------------------------------------------------------- /examples/onnx/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/Modules/") 2 | find_package(OnnxRuntime REQUIRED) 3 | 4 | add_executable(onnx_example 5 | onnx_example.cc 6 | ) 7 | 8 | file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mnist_data/) 9 | 10 | set(ONNX_EXTENSION_TEST_RESOURCE_FILES 11 | ${CMAKE_CURRENT_LIST_DIR}/../../unit_test/extensions/onnx/mnist_data/mnist_model.onnx 12 | ) 13 | 14 | add_custom_command( 15 | TARGET onnx_example POST_BUILD 16 | COMMAND ${CMAKE_COMMAND} -E copy_if_different 17 | ${ONNX_EXTENSION_TEST_RESOURCE_FILES} 18 | ${CMAKE_CURRENT_BINARY_DIR}/mnist_data/ 19 | ) 20 | 21 | target_link_libraries(onnx_example PRIVATE rlclientlib-onnx) 22 | 23 | add_custom_command(TARGET onnx_example POST_BUILD 24 | COMMAND ${CMAKE_COMMAND} -E copy 25 | $ 26 | $ 27 | ) 28 | -------------------------------------------------------------------------------- /examples/onnx/readme.md: -------------------------------------------------------------------------------- 1 | # ONNX example 2 | 3 | This is example runs RLClientLib as BYOM using an ONNX model for MNIST and logs the interactions and observations to files. 4 | 5 | ## Generate `mnist_test_data.txt` 6 | 7 | ```sh 8 | # Install mnist python lib 9 | python3 -m pip install python-mnist 10 | 11 | # Run installed script to pull mnist data to ./data 12 | mnist_get_data.sh 13 | 14 | python3 ../../unit_test/extensions/onnx/mnist_data/data_generator.py 15 | ``` 16 | 17 | ## Build example 18 | 19 | ```sh 20 | mkdir build 21 | cd build 22 | cmake .. 
-Drlclientlib_BUILD_ONNXRUNTIME_EXTENSION=On 23 | make onnx_example -j $(nproc) 24 | ``` 25 | 26 | ## Run example 27 | 28 | ```sh 29 | ./examples/onnx/onnx_example ../examples/onnx/mnist_test_data.txt ./examples/onnx/mnist_data/mnist_model.onnx 30 | ``` 31 | 32 | This produces two files in the current directory `observation.fb.data` and `interaction.fb.data`. 33 | -------------------------------------------------------------------------------- /examples/override_interface/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_executable(override_interface.out 2 | override_interface.cc 3 | ) 4 | 5 | target_link_libraries(override_interface.out PRIVATE rlclientlib) 6 | 7 | if(RL_LINK_AZURE_LIBS) 8 | target_compile_definitions(override_interface.out PRIVATE LINK_AZURE_LIBS) 9 | find_package(azure-identity-cpp CONFIG REQUIRED) 10 | target_link_libraries(override_interface.out PRIVATE Azure::azure-identity) 11 | endif() 12 | -------------------------------------------------------------------------------- /examples/override_interface/client.json: -------------------------------------------------------------------------------- 1 | { 2 | "ApplicationID": "", 3 | "EventHubInteractionConnectionString": "Endpoint=sb://.servicebus.windows.net/;SharedAccessKeyName=RootManageSharedAccessKey;SharedAccessKey=;EntityPath=interaction", 4 | "EventHubObservationConnectionString": "Endpoint=sb://.servicebus.windows.net/;SharedAccessKeyName=RootManageSharedAccessKey;SharedAccessKey=;EntityPath=observation", 5 | "IsExplorationEnabled": true, 6 | "ModelBlobUri": "https://.blob.core.windows.net/mwt-models/current?sv=2017-07-29&sr=b&sig=&st=2018-06-26T09%3A00%3A55Z&se=2028-06-26T09%3A01%3A55Z&sp=r", 7 | "InitialExplorationEpsilon": 1.0 8 | } 9 | -------------------------------------------------------------------------------- /examples/rl_sim_cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 
1 | set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_LIST_DIR}/../../cmake/Modules/") 2 | 3 | set(RL_SIM_SOURCES 4 | main.cc 5 | person.cc 6 | robot_joint.cc 7 | rl_sim.cc 8 | ) 9 | 10 | add_executable(rl_sim_cpp.out 11 | ${RL_SIM_SOURCES} 12 | ) 13 | 14 | target_link_libraries(rl_sim_cpp.out PRIVATE Boost::program_options rlclientlib) 15 | 16 | if(RL_LINK_AZURE_LIBS) 17 | target_compile_definitions(rl_sim_cpp.out PRIVATE LINK_AZURE_LIBS) 18 | find_package(azure-identity-cpp CONFIG REQUIRED) 19 | target_link_libraries(rl_sim_cpp.out PRIVATE Azure::azure-identity) 20 | endif() 21 | -------------------------------------------------------------------------------- /examples/rl_sim_cpp/client.json: -------------------------------------------------------------------------------- 1 | { 2 | "ApplicationID": "", 3 | "EventHubInteractionConnectionString": "Endpoint=sb://.servicebus.windows.net/;SharedAccessKeyName=RootManageSharedAccessKey;SharedAccessKey=;EntityPath=interaction", 4 | "EventHubObservationConnectionString": "Endpoint=sb://.servicebus.windows.net/;SharedAccessKeyName=RootManageSharedAccessKey;SharedAccessKey=;EntityPath=observation", 5 | "IsExplorationEnabled": true, 6 | "ModelBlobUri": "https://.blob.core.windows.net/mwt-models/current?sv=2017-07-29&sr=b&sig=&st=2018-06-26T09%3A00%3A55Z&se=2028-06-26T09%3A01%3A55Z&sp=r", 7 | "InitialExplorationEpsilon": 1.0, 8 | "model.source": "FILE_MODEL_DATA" 9 | } 10 | -------------------------------------------------------------------------------- /examples/rl_sim_cpp/current: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/examples/rl_sim_cpp/current -------------------------------------------------------------------------------- /examples/rl_sim_cpp/person.cc: -------------------------------------------------------------------------------- 1 | #include "person.h" 2 | 
3 | #include "rand48.h" 4 | 5 | #include 6 | #include 7 | 8 | person::person(std::string id, std::string major, std::string hobby, std::string fav_char, topic_prob& p) 9 | : _id(std::move(id)) 10 | , _major{std::move(major)} 11 | , _hobby{std::move(hobby)} 12 | , _favorite_character{std::move(fav_char)} 13 | , _topic_click_probability{p} 14 | { 15 | } 16 | 17 | person::~person() = default; 18 | 19 | std::string person::get_features() const 20 | { 21 | std::ostringstream oss; 22 | oss << R"("GUser":{)"; 23 | oss << R"("id":")" << _id << R"(",)"; 24 | oss << R"("major":")" << _major << R"(",)"; 25 | oss << R"("hobby":")" << _hobby << R"(",)"; 26 | oss << R"("favorite_character":")" << _favorite_character; 27 | oss << R"("})"; 28 | return oss.str(); 29 | } 30 | 31 | float person::get_outcome(const std::string& chosen_action, uint64_t& random_seed) 32 | { 33 | float const norm_draw_val = rand48(random_seed); 34 | float const click_prob = _topic_click_probability[chosen_action]; 35 | if (norm_draw_val <= click_prob) { return 1.0f; } 36 | return 0.0f; 37 | } 38 | 39 | std::string person::id() const { return _id; } 40 | -------------------------------------------------------------------------------- /examples/rl_sim_cpp/person.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | /** 6 | * @brief Represents a person initiating an interaction 7 | */ 8 | class person 9 | { 10 | public: 11 | //! 
Collection type of outcome probability for given action 12 | using topic_prob = std::unordered_map; 13 | /** 14 | * @brief Construct a new person 15 | * 16 | * @param id Unique id for a person 17 | * @param major Person feature (major) 18 | * @param hobby Person feature (hobby) 19 | * @param fav_char Person feature (fav_char) 20 | * @param topicprob Probability of outcome for a given topic 21 | **/ 22 | person(std::string id, std::string major, std::string hobby, std::string fav_char, topic_prob& p); 23 | ~person(); 24 | 25 | //! Get person features as a json string 26 | std::string get_features() const; 27 | //! Get the outcome for a topic. Use probability to randomly assign a outcome 28 | float get_outcome(const std::string& chosen_action, uint64_t& random_seed); 29 | //! Get the person's id 30 | std::string id() const; 31 | 32 | private: 33 | const std::string _id; 34 | const std::string _major; 35 | const std::string _hobby; 36 | const std::string _favorite_character; 37 | topic_prob _topic_click_probability; 38 | }; 39 | -------------------------------------------------------------------------------- /examples/rl_sim_cpp/rand48.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | #include 5 | 6 | // merand48 copied from VW. Stable random number generator for reproducibility 7 | constexpr uint64_t CONSTANT_A = 0xeece66d5deece66dULL; 8 | constexpr uint64_t CONSTANT_C = 2147483647; 9 | 10 | constexpr int BIAS = 127 << 23; 11 | 12 | // int-ified version of merand48... 
13 | inline float rand48(uint64_t& initial) 14 | { 15 | static_assert( 16 | sizeof(int32_t) == sizeof(float), "Floats and int32_ts are converted between, they must be the same size."); 17 | initial = CONSTANT_A * initial + CONSTANT_C; 18 | int32_t temp = ((initial >> 25) & 0x7FFFFF) | BIAS; 19 | return reinterpret_cast(temp) - 1; 20 | } -------------------------------------------------------------------------------- /examples/rl_sim_cpp/rl_sim_cpp.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "config_utility.h" 4 | #include "live_model.h" 5 | #include "rl_sim.h" 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | // Namespace manipulation for brevity 13 | namespace r = reinforcement_learning; 14 | namespace u = r::utility; 15 | namespace cfg = u::config; 16 | namespace err = r::error_code; 17 | namespace po = boost::program_options; 18 | 19 | // Forward declare functions 20 | po::variables_map process_cmd_line(const int argc, char** argv); 21 | bool is_help(const po::variables_map& vm); 22 | -------------------------------------------------------------------------------- /examples/rl_sim_cpp/robot_joint.cc: -------------------------------------------------------------------------------- 1 | #include "robot_joint.h" 2 | 3 | #include "rand48.h" 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | joint::joint(std::string id, float temp, float ang_velocity, float load, friction_prob& outcome_probs) 10 | : _id(std::move(id)), _temp(temp), _angular_velocity(ang_velocity), _load(load), _outcome_probability(outcome_probs) 11 | { 12 | } 13 | 14 | std::string joint::get_features() 15 | { 16 | std::ostringstream oss; 17 | oss << R"("id":")" << _id << R"(",)"; 18 | oss << R"(")" << _temp << R"(":)" << 1 << R"(,)"; 19 | oss << R"(")" << _angular_velocity << R"(":)" << 1 << R"(,)"; 20 | oss << R"(")" << _load << R"(":)" << 1; 21 | return oss.str(); 22 | } 23 | 24 | float 
joint::get_outcome(float observed_friction, uint64_t& random_seed) 25 | { 26 | float const norm_draw_val = rand48(random_seed); 27 | float click_prob = 0.; 28 | 29 | // figure out which bucket from our pre-set frictions the observed_friction 30 | // falls into to and get it's probability 31 | for (auto fp : _outcome_probability) 32 | { 33 | if (observed_friction >= fp.first) { click_prob = fp.second; } 34 | } 35 | if (norm_draw_val <= click_prob) { return 1.0f; } 36 | return 0.0f; 37 | } 38 | 39 | std::string joint::id() const { return _id; } -------------------------------------------------------------------------------- /examples/rl_sim_cpp/robot_joint.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | 6 | class joint 7 | { 8 | public: 9 | // Collection type of outcome probability for a given friction range 10 | using friction_prob = std::map; 11 | joint(std::string id, float temp, float ang_velocity, float load, friction_prob& outcome_probs); 12 | 13 | std::string get_features(); 14 | float get_outcome(float observed_friction, uint64_t& random_seed); 15 | std::string id() const; 16 | 17 | private: 18 | const std::string _id; 19 | const float _temp; 20 | const float _angular_velocity; 21 | const float _load; 22 | friction_prob _outcome_probability; 23 | }; -------------------------------------------------------------------------------- /examples/rl_sim_cpp/simulation_stats.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "str_util.h" 3 | 4 | #include 5 | 6 | template 7 | class simulation_stats 8 | { 9 | public: 10 | simulation_stats() = default; 11 | ~simulation_stats() = default; 12 | void record(const std::string& id, T chosen_action, const float outcome) 13 | { 14 | auto& action_stats = _action_stats[std::make_pair(id, chosen_action)]; 15 | if (outcome > 0.00001f) ++action_stats.first; 16 | 
++action_stats.second; 17 | auto& item_count = _item_stats[id]; 18 | ++item_count; 19 | ++_total_events; 20 | } 21 | 22 | std::string get_stats(const std::string& id, T chosen_action) 23 | { 24 | auto& action_stats = _action_stats[std::make_pair(id, chosen_action)]; 25 | auto& item_count = _item_stats[id]; 26 | 27 | return u::concat("wins: ", action_stats.first, ", out_of: ", action_stats.second, ", total: ", item_count); 28 | } 29 | 30 | int count() const { return _total_events; } 31 | 32 | private: 33 | std::map, std::pair> _action_stats; 34 | std::map _item_stats; 35 | int _total_events = 0; 36 | }; -------------------------------------------------------------------------------- /examples/rl_sim_cpp/targetver.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // Including SDKDDKVer.h defines the highest available Windows platform. 4 | 5 | // If you wish to build your application for a previous Windows platform, include WinSDKVer.h and 6 | // set the _WIN32_WINNT macro to the platform you wish to support before including SDKDDKVer.h. 
7 | 8 | #include 9 | -------------------------------------------------------------------------------- /examples/test_cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_executable(rl_test.out 2 | main.cc 3 | options.cc 4 | experiment_controller.cc 5 | test_data_provider.cc 6 | test_loop.cc 7 | ) 8 | 9 | # TODO remove internal header dependency of this example 10 | target_include_directories(rl_test.out PRIVATE $) 11 | 12 | target_link_libraries(rl_test.out PRIVATE Boost::program_options rlclientlib) 13 | 14 | if(RL_LINK_AZURE_LIBS) 15 | target_compile_definitions(rl_test.out PRIVATE LINK_AZURE_LIBS) 16 | find_package(azure-identity-cpp CONFIG REQUIRED) 17 | target_link_libraries(rl_test.out PRIVATE Azure::azure-identity) 18 | endif() 19 | -------------------------------------------------------------------------------- /examples/test_cpp/main.cc: -------------------------------------------------------------------------------- 1 | #include "options.h" 2 | #include "test_loop.h" 3 | 4 | #include 5 | #include 6 | 7 | int run_test_instance(size_t index, const boost::program_options::variables_map& vm) 8 | { 9 | test_loop loop(index, vm); 10 | 11 | if (!loop.init()) 12 | { 13 | std::cerr << "Test loop haven't initialized properly." 
<< std::endl; 14 | return -1; 15 | } 16 | 17 | loop.run(); 18 | 19 | return 0; 20 | } 21 | 22 | int main(int argc, char** argv) 23 | { 24 | try 25 | { 26 | const auto vm = process_cmd_line(argc, argv); 27 | if (is_help(vm)) { return 0; } 28 | 29 | const size_t num_instances = vm["instances"].as(); 30 | std::vector instances; 31 | for (size_t i = 0; i < num_instances; ++i) { instances.emplace_back(&run_test_instance, i, vm); } 32 | for (size_t i = 0; i < num_instances; ++i) { instances[i].join(); } 33 | } 34 | catch (const std::exception& e) 35 | { 36 | std::cout << "Error: " << e.what() << std::endl; 37 | return -1; 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /examples/test_cpp/model.vw: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/examples/test_cpp/model.vw -------------------------------------------------------------------------------- /examples/test_cpp/options.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | boost::program_options::variables_map process_cmd_line(const int argc, char** argv); 7 | bool is_help(const boost::program_options::variables_map& vm); 8 | void throw_if_conflicting( 9 | const boost::program_options::variables_map& vm, const std::string& first, const std::string& second); 10 | -------------------------------------------------------------------------------- /examples/test_cpp/scripts/perf_test.cmd: -------------------------------------------------------------------------------- 1 | test_cpp.exe -j client.json -t 1 -n 10000 -x 1 -a 1 -e win_perf -f -p 2 | test_cpp.exe -j client.json -t 1 -n 10000 -x 1 -a 10 -e win_perf -f -p 3 | test_cpp.exe -j client.json -t 1 -n 10000 -x 1 -a 100 -e win_perf -f -p 4 | test_cpp.exe -j client.json -t 1 -n 1000 -x 1 -a 
1000 -e win_perf -f -p 5 | REM NOTE(review): no error check between runs (perf_test.sh uses set -e), and test_cpp.exe differs from the rl_test.out CMake target -- confirm the binary name 6 | test_cpp.exe -j client.json -t 1 -n 10000 -x 10 -a 1 -e win_perf -f -p 7 | test_cpp.exe -j client.json -t 1 -n 10000 -x 10 -a 10 -e win_perf -f -p 8 | test_cpp.exe -j client.json -t 1 -n 1000 -x 10 -a 100 -e win_perf -f -p 9 | test_cpp.exe -j client.json -t 1 -n 1000 -x 10 -a 1000 -e win_perf -f -p 10 | 11 | test_cpp.exe -j client.json -t 1 -n 10000 -x 100 -a 1 -e win_perf -f -p 12 | test_cpp.exe -j client.json -t 1 -n 1000 -x 100 -a 10 -e win_perf -f -p 13 | test_cpp.exe -j client.json -t 1 -n 1000 -x 100 -a 100 -e win_perf -f -p 14 | test_cpp.exe -j client.json -t 1 -n 1000 -x 100 -a 1000 -e win_perf -f -p 15 | 16 | test_cpp.exe -j client.json -t 1 -n 1000 -x 1000 -a 1 -e win_perf -f -p 17 | test_cpp.exe -j client.json -t 1 -n 1000 -x 1000 -a 10 -e win_perf -f -p 18 | test_cpp.exe -j client.json -t 1 -n 1000 -x 1000 -a 100 -e win_perf -f -p 19 | test_cpp.exe -j client.json -t 1 -n 100 -x 1000 -a 1000 -e win_perf -f -p 20 | -------------------------------------------------------------------------------- /examples/test_cpp/scripts/perf_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | # set -e: abort the whole sweep on the first failing run 5 | ./rl_test.out -j client.json -t 1 -n 10000 -x 1 -a 1 -e linux_perf -f -p 6 | ./rl_test.out -j client.json -t 1 -n 10000 -x 1 -a 10 -e linux_perf -f -p 7 | ./rl_test.out -j client.json -t 1 -n 10000 -x 1 -a 100 -e linux_perf -f -p 8 | ./rl_test.out -j client.json -t 1 -n 1000 -x 1 -a 1000 -e linux_perf -f -p 9 | 10 | ./rl_test.out -j client.json -t 1 -n 10000 -x 10 -a 1 -e linux_perf -f -p 11 | ./rl_test.out -j client.json -t 1 -n 10000 -x 10 -a 10 -e linux_perf -f -p 12 | ./rl_test.out -j client.json -t 1 -n 1000 -x 10 -a 100 -e linux_perf -f -p 13 | ./rl_test.out -j client.json -t 1 -n 1000 -x 10 -a 1000 -e linux_perf -f -p 14 | 15 | ./rl_test.out -j client.json -t 1 -n 10000 -x 100 -a 1 -e linux_perf -f -p 16 | ./rl_test.out -j client.json -t 1 -n 1000 -x 100 -a 10 -e
linux_perf -f -p 17 | ./rl_test.out -j client.json -t 1 -n 1000 -x 100 -a 100 -e linux_perf -f -p 18 | ./rl_test.out -j client.json -t 1 -n 1000 -x 100 -a 1000 -e linux_perf -f -p 19 | 20 | ./rl_test.out -j client.json -t 1 -n 1000 -x 1000 -a 1 -e linux_perf -f -p 21 | ./rl_test.out -j client.json -t 1 -n 1000 -x 1000 -a 10 -e linux_perf -f -p 22 | ./rl_test.out -j client.json -t 1 -n 1000 -x 1000 -a 100 -e linux_perf -f -p 23 | ./rl_test.out -j client.json -t 1 -n 100 -x 1000 -a 1000 -e linux_perf -f -p 24 | -------------------------------------------------------------------------------- /ext_libs/date/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(date INTERFACE) # INTERFACE target: compiles nothing, only exports this directory as an include path 2 | target_include_directories(date INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) 3 | -------------------------------------------------------------------------------- /ext_libs/fakeit/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(fakeit INTERFACE) # INTERFACE target: compiles nothing, only exports this directory as an include path 2 | target_include_directories(fakeit INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) 3 | -------------------------------------------------------------------------------- /ext_libs/string-view-lite/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(string_view_lite INTERFACE) 2 | target_include_directories( 3 | string_view_lite SYSTEM INTERFACE 4 | $ 5 | $ 6 | ) 7 | 8 | install( 9 | DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/nonstd/ 10 | DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/nonstd/ 11 | ) 12 | -------------------------------------------------------------------------------- /external_parser/event_processors/metadata.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // 
FileFormat_generated.h is included for the payload type and encoding enums 4 | #include "generated/v2/FileFormat_generated.h" 5 | #include "timestamp_helper.h" 6 | 7 |
namespace v2 = reinforcement_learning::messages::flatbuff::v2; 8 | 9 | namespace metadata 10 | { 11 | // used both for interactions and observations 12 | struct event_metadata_info 13 | { 14 | std::string app_id; 15 | v2::PayloadType payload_type; 16 | float pass_probability; 17 | v2::EventEncoding event_encoding; 18 | std::string event_id; 19 | v2::LearningModeType learning_mode; 20 | }; 21 | } // namespace metadata 22 | -------------------------------------------------------------------------------- /external_parser/event_processors/timestamp_helper.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "date.h" 4 | #include "generated/v2/Metadata_generated.h" 5 | #include "vw/io/logger.h" 6 | 7 | #include 8 | 9 | namespace v2 = reinforcement_learning::messages::flatbuff::v2; 10 | using TimePoint = std::chrono::time_point; 11 | TimePoint timestamp_to_chrono(const v2::TimeStamp& ts); 12 | bool is_empty_timestamp(const v2::TimeStamp& ts); 13 | TimePoint get_enqueued_time(const v2::TimeStamp* enqueued_time_utc, const v2::TimeStamp* client_time_utc, 14 | bool use_client_time, VW::io::logger& logger); -------------------------------------------------------------------------------- /external_parser/log_converter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "joiners/example_joiner.h" 4 | #include "vw/io/logger.h" 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | namespace v2 = reinforcement_learning::messages::flatbuff::v2; 14 | 15 | namespace log_converter 16 | { 17 | void build_json(std::ofstream& outfile, joined_event::joined_event& je, VW::io::logger& logger); 18 | void build_cb_json(std::ofstream& outfile, joined_event::joined_event& je, VW::io::logger& logger); 19 | void build_ccb_json(std::ofstream& outfile, joined_event::joined_event& je, VW::io::logger& logger); 20 | void 
build_ca_json(std::ofstream& outfile, joined_event::joined_event& je, VW::io::logger& logger); 21 | void build_slates_json(std::ofstream& outfile, joined_event::joined_event& je, VW::io::logger& logger); 22 | ; 23 | } // namespace log_converter 24 | -------------------------------------------------------------------------------- /external_parser/lru_dedup_cache.cc: -------------------------------------------------------------------------------- 1 | #include "lru_dedup_cache.h" 2 | 3 | void lru_dedup_cache::add(uint64_t dedup_id, VW::example* ex) 4 | { 5 | dedup_examples.emplace(dedup_id, ex); 6 | lru.push_front(dedup_id); 7 | lru_pos.emplace(dedup_id, lru.begin()); 8 | } 9 | 10 | void lru_dedup_cache::update(uint64_t dedup_id) 11 | { 12 | // existing move to front 13 | auto position = lru_pos[dedup_id]; 14 | lru.erase(position); 15 | lru.push_front(dedup_id); 16 | lru_pos[dedup_id] = lru.begin(); 17 | } 18 | 19 | void lru_dedup_cache::clear_after(uint64_t first_id, release_example_f release_example, void* context) 20 | { 21 | // erase the rest 22 | auto iter = lru_pos[first_id]; 23 | // point to the element right after 24 | iter++; 25 | auto first_pos = iter; 26 | while (iter != lru.end()) 27 | { 28 | auto dedup_id = *iter; 29 | lru_pos.erase(dedup_id); 30 | release_example(context, dedup_examples[dedup_id]); 31 | dedup_examples.erase(dedup_id); 32 | iter++; 33 | } 34 | lru.erase(first_pos, lru.end()); 35 | } 36 | 37 | void lru_dedup_cache::clear(release_example_f release_example, void* context) 38 | { 39 | for (auto& dedup_item : dedup_examples) { release_example(context, dedup_item.second); } 40 | dedup_examples.clear(); 41 | lru_pos.clear(); 42 | lru.clear(); 43 | } 44 | 45 | bool lru_dedup_cache::exists(uint64_t dedup_id) { return dedup_examples.find(dedup_id) != dedup_examples.end(); } -------------------------------------------------------------------------------- /external_parser/metrics/metrics.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "event_processors/timestamp_helper.h" 4 | 5 | namespace metrics 6 | { 7 | struct joiner_metrics 8 | { 9 | size_t number_of_skipped_events = 0; 10 | float sum_cost_original = 0.f; 11 | TimePoint last_event_timestamp = TimePoint(); 12 | TimePoint first_event_timestamp = TimePoint(); 13 | std::string first_event_id = ""; 14 | std::string last_event_id = ""; 15 | }; 16 | } // namespace metrics 17 | -------------------------------------------------------------------------------- /external_parser/parse_example_converter.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) by respective owners including Yahoo!, Microsoft, and 2 | // individual contributors. All rights reserved. Released under a BSD (revised) 3 | // license as described in the file LICENSE. 4 | 5 | #include "parse_example_converter.h" 6 | 7 | #include "vw/core/example.h" 8 | #include "vw/core/global_data.h" 9 | 10 | namespace VW 11 | { 12 | namespace external 13 | { 14 | binary_json_converter::binary_json_converter(std::unique_ptr&& joiner, VW::io::logger logger) 15 | : parser(logger), _parser(std::move(joiner), logger) 16 | { 17 | } 18 | 19 | binary_json_converter::~binary_json_converter() = default; 20 | 21 | bool binary_json_converter::parse_examples(VW::workspace* all, io_buf& io_buf, VW::multi_ex& examples) 22 | { 23 | while (_parser.parse_examples(all, io_buf, examples)) 24 | { 25 | // drain: keep converting until the binary parser returns false 26 | } 27 | // always report "no examples" so vw will not learn; it just exits 28 | return false; 29 | } 30 | 31 | void binary_json_converter::persist_metrics(metric_sink&) 32 | { 33 | // intentionally empty: this converter persists no metrics (open question: do we want metrics here?)
34 | } 35 | 36 | } // namespace external 37 | } // namespace VW -------------------------------------------------------------------------------- /external_parser/parse_example_converter.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) by respective owners including Yahoo!, Microsoft, and 2 | // individual contributors. All rights reserved. Released under a BSD (revised) 3 | // license as described in the file LICENSE. 4 | 5 | #pragma once 6 | 7 | #include "joiners/i_joiner.h" 8 | #include "parse_example_binary.h" 9 | #include "parse_example_external.h" 10 | 11 | namespace VW 12 | { 13 | namespace external 14 | { 15 | class binary_json_converter : public parser 16 | { 17 | public: 18 | binary_json_converter(std::unique_ptr&& joiner, VW::io::logger logger); // taking ownership of joiner 19 | ~binary_json_converter(); 20 | bool parse_examples(VW::workspace* all, io_buf& io_buf, VW::multi_ex& examples) override; 21 | void persist_metrics(metric_sink& metrics_sink) override; 22 | 23 | private: 24 | binary_parser _parser; 25 | }; 26 | } // namespace external 27 | } // namespace VW -------------------------------------------------------------------------------- /external_parser/unit_tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(Boost_USE_STATIC_LIBS ON) 2 | find_package(Boost COMPONENTS unit_test_framework filesystem system program_options thread REQUIRED) 3 | 4 | set(TEST_SOURCES 5 | test_vw_external_parser.cc 6 | test_vw_binary_parser.cc 7 | test_example_joiner.cc 8 | test_reward_functions.cc 9 | main.cc 10 | test_common.cc 11 | test_lru_dedup_cache.cc 12 | test_timestamp_helper.cc 13 | test_log_converter.cc 14 | test_skip_learn.cc 15 | test_metrics.cc 16 | test_client_and_enqueued_time.cc 17 | ) 18 | 19 | add_executable(binary_parser_unit_tests ${TEST_SOURCES}) 20 | 21 | # Add the include directories from vw target for testing 22 |
target_include_directories(binary_parser_unit_tests 23 | PRIVATE 24 | $ 25 | ) 26 | 27 | target_link_libraries(binary_parser_unit_tests 28 | PRIVATE 29 | rl_binary_parser 30 | Boost::unit_test_framework 31 | Boost::system 32 | Boost::filesystem 33 | ) 34 | 35 | add_test(NAME binary_parser_unit_tests COMMAND binary_parser_unit_tests -- ${CMAKE_CURRENT_LIST_DIR}/test_files/) 36 | -------------------------------------------------------------------------------- /external_parser/unit_tests/main.cc: -------------------------------------------------------------------------------- 1 | #define BOOST_TEST_MODULE Main 2 | #include 3 | -------------------------------------------------------------------------------- /external_parser/unit_tests/test_common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "generated/v2/FileFormat_generated.h" 6 | #include "generated/v2/Metadata_generated.h" 7 | #include "vw/core/v_array.h" 8 | #include "vw/core/vw.h" 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | namespace v2 = reinforcement_learning::messages::flatbuff::v2; 16 | constexpr float FLOAT_TOL = 0.0001f; 17 | 18 | // learn/predict isn't called in the unit test but cleanup examples 19 | // expects shared pred to be set for slates 20 | void set_slates_label(VW::multi_ex& examples); 21 | 22 | void clear_examples(VW::multi_ex& examples, VW::workspace* vw); 23 | 24 | void set_buffer_as_vw_input(const std::vector& buffer, VW::workspace* vw); 25 | 26 | std::vector read_file(const std::string& file_name); 27 | 28 | std::string get_test_files_location(); 29 | 30 | std::vector wrap_into_joined_events( 31 | std::vector& buffer, std::vector& detached_buffers); 32 | -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/client_time/cb_v2_client_time.fb:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/client_time/cb_v2_client_time.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/client_time/f-reward_3obs_v2_client_time.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/client_time/f-reward_3obs_v2_client_time.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/fb_events/ca_v2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/fb_events/ca_v2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/fb_events/ca_v2_size_2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/fb_events/ca_v2_size_2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/fb_events/cb_v2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/fb_events/cb_v2.fb -------------------------------------------------------------------------------- 
/external_parser/unit_tests/test_files/fb_events/cb_v2_dedup.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/fb_events/cb_v2_dedup.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/fb_events/cb_v2_size_2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/fb_events/cb_v2_size_2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/fb_events/cb_v2_size_5_apprentice.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/fb_events/cb_v2_size_5_apprentice.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/fb_events/ccb-baseline-loopinteractions_v2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/fb_events/ccb-baseline-loopinteractions_v2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/fb_events/ccb-baseline-loopobservations_v2.fb: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/fb_events/ccb-baseline-loopobservations_v2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/fb_events/ccb_v2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/fb_events/ccb_v2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/fb_events/f-reward_v2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/fb_events/f-reward_v2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/fb_events/f-reward_v2_size_2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/fb_events/f-reward_v2_size_2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/fb_events/f-reward_v2_size_5_apprentice.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/fb_events/f-reward_v2_size_5_apprentice.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/fb_events/fi-reward_v2.fb: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/fb_events/fi-reward_v2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/fb_events/invalid-cb_v2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/fb_events/invalid-cb_v2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/invalid_joined_logs/bad_event_in_joined_event.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/invalid_joined_logs/bad_event_in_joined_event.log -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/invalid_joined_logs/bad_magic.log: -------------------------------------------------------------------------------- 1 | VWFC -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/invalid_joined_logs/bad_version.log: -------------------------------------------------------------------------------- 1 | VWFB -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/invalid_joined_logs/corrupt_joined_payload.log: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/invalid_joined_logs/corrupt_joined_payload.log -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/invalid_joined_logs/dedup_payload_missing.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/invalid_joined_logs/dedup_payload_missing.log -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/invalid_joined_logs/empty_msg_hdr.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/invalid_joined_logs/empty_msg_hdr.log -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/invalid_joined_logs/incomplete_checkpoint_info.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/invalid_joined_logs/incomplete_checkpoint_info.log -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/invalid_joined_logs/interaction_with_no_observation.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/invalid_joined_logs/interaction_with_no_observation.log 
-------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/invalid_joined_logs/invalid_cb_context.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/invalid_joined_logs/invalid_cb_context.log -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/invalid_joined_logs/no_interaction_but_with_observation.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/invalid_joined_logs/no_interaction_but_with_observation.log -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/invalid_joined_logs/no_msg_hdr.log: -------------------------------------------------------------------------------- 1 | VWFB -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/invalid_joined_logs/one_invalid_msg_type.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/invalid_joined_logs/one_invalid_msg_type.log -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/reward_functions/ca/ca_v2.fb: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/reward_functions/ca/ca_v2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/reward_functions/ca/f-reward_3obs_v2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/reward_functions/ca/f-reward_3obs_v2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/reward_functions/cb/cb_apprentice_match_baseline_v2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/reward_functions/cb/cb_apprentice_match_baseline_v2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/reward_functions/cb/cb_v2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/reward_functions/cb/cb_v2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/reward_functions/cb/f-reward_3obs_v2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/reward_functions/cb/f-reward_3obs_v2.fb -------------------------------------------------------------------------------- 
/external_parser/unit_tests/test_files/reward_functions/ccb/ccb-apprentice-baseline-match_v2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/reward_functions/ccb/ccb-apprentice-baseline-match_v2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/reward_functions/ccb/ccb-apprentice-baseline-not-match_v2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/reward_functions/ccb/ccb-apprentice-baseline-not-match_v2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/reward_functions/ccb/ccb-with-slot-id_v2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/reward_functions/ccb/ccb-with-slot-id_v2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/reward_functions/ccb/ccb_v2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/reward_functions/ccb/ccb_v2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/reward_functions/ccb/fi-out-of-bound-reward_v2.fb: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/reward_functions/ccb/fi-out-of-bound-reward_v2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/reward_functions/ccb/fi-reward_v2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/reward_functions/ccb/fi-reward_v2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/reward_functions/ccb/fmix-reward_v2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/reward_functions/ccb/fmix-reward_v2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/reward_functions/ccb/fs-reward_v2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/reward_functions/ccb/fs-reward_v2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/reward_functions/slates/fi-reward_v2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/reward_functions/slates/fi-reward_v2.fb -------------------------------------------------------------------------------- 
/external_parser/unit_tests/test_files/reward_functions/slates/slates_v2.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/reward_functions/slates/slates_v2.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/skip_learn/ca/deferred_action_with_activation.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/skip_learn/ca/deferred_action_with_activation.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/skip_learn/ca/deferred_action_with_activation_deduped.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/skip_learn/ca/deferred_action_with_activation_deduped.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/skip_learn/ca/deferred_action_without_activation.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/skip_learn/ca/deferred_action_without_activation.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/skip_learn/ca/deferred_action_without_activation_deduped.fb: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/skip_learn/ca/deferred_action_without_activation_deduped.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/skip_learn/ca/mixed_deferred_action_events.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/skip_learn/ca/mixed_deferred_action_events.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/skip_learn/cb/deferred_action_with_activation.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/skip_learn/cb/deferred_action_with_activation.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/skip_learn/cb/deferred_action_with_activation_deduped.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/skip_learn/cb/deferred_action_with_activation_deduped.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/skip_learn/cb/deferred_action_without_activation.fb: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/skip_learn/cb/deferred_action_without_activation.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/skip_learn/cb/deferred_action_without_activation_deduped.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/skip_learn/cb/deferred_action_without_activation_deduped.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/skip_learn/cb/mixed_deferred_action_events.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/skip_learn/cb/mixed_deferred_action_events.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/skip_learn/ccb/deferred_action_with_activation.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/skip_learn/ccb/deferred_action_with_activation.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/skip_learn/ccb/deferred_action_with_activation_deduped.fb: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/skip_learn/ccb/deferred_action_with_activation_deduped.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/skip_learn/ccb/deferred_action_without_activation.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/skip_learn/ccb/deferred_action_without_activation.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/skip_learn/ccb/deferred_action_without_activation_deduped.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/skip_learn/ccb/deferred_action_without_activation_deduped.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/skip_learn/ccb/mixed_deferred_action_events.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/skip_learn/ccb/mixed_deferred_action_events.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/skip_learn/slates/deferred_action_with_activation.fb: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/skip_learn/slates/deferred_action_with_activation.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/skip_learn/slates/deferred_action_without_activation.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/skip_learn/slates/deferred_action_without_activation.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/skip_learn/slates/mixed_deferred_action_events.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/skip_learn/slates/mixed_deferred_action_events.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/test_outputs/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/average_reward_100_interactions.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/average_reward_100_interactions.fb -------------------------------------------------------------------------------- 
/external_parser/unit_tests/test_files/valid_joined_logs/ca_loop_mixed_skip_learn.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/ca_loop_mixed_skip_learn.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/ca_loop_mixed_skip_learn.json: -------------------------------------------------------------------------------- 1 | {"_label_ca":{"cost":-1.5,"pdf_value":0.0005050505278632045,"action":1.014871597290039},"Timestamp":"2021-08-25T15:36:54.000000Z","Version":"1","EventId":"91f71c8","c":{"RobotJoint1":{"friction":78}},"VWState":{"m":"N/A"},"_skipLearn":true} 2 | {"_label_ca":{"cost":-1.5,"pdf_value":0.4755050539970398,"action":12.464624404907227},"Timestamp":"2021-08-25T15:36:54.000000Z","Version":"1","EventId":"75d50657","c":{"RobotJoint1":{"friction":78}},"VWState":{"m":"N/A"}} 3 | {"_label_ca":{"cost":-1.5,"pdf_value":0.4755050539970398,"action":12.43958568572998},"Timestamp":"2021-08-25T15:36:54.000000Z","Version":"1","EventId":"e28a9ae6","c":{"RobotJoint1":{"friction":78}},"VWState":{"m":"N/A"}} 4 | -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/ca_loop_simple.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/ca_loop_simple.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/ca_loop_simple.json: -------------------------------------------------------------------------------- 1 | 
{"_label_ca":{"cost":-1.5,"pdf_value":0.0005050505278632045,"action":1.014871597290039},"Timestamp":"2021-08-24T14:38:15.000000Z","Version":"1","EventId":"91f71c8","c":{"RobotJoint1":{"friction":78}},"VWState":{"m":"N/A"}} 2 | {"_label_ca":{"cost":-1.5,"pdf_value":0.4755050539970398,"action":12.464624404907227},"Timestamp":"2021-08-24T14:38:15.000000Z","Version":"1","EventId":"75d50657","c":{"RobotJoint1":{"friction":78}},"VWState":{"m":"N/A"}} 3 | {"_label_ca":{"cost":-1.5,"pdf_value":0.4755050539970398,"action":12.43958568572998},"Timestamp":"2021-08-24T14:38:15.000000Z","Version":"1","EventId":"e28a9ae6","c":{"RobotJoint1":{"friction":78}},"VWState":{"m":"N/A"}} 4 | -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/ca_loop_simple_e2e.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/ca_loop_simple_e2e.log -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/ca_loop_skip_learn_e2e.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/ca_loop_skip_learn_e2e.log -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/ca_mixed_deferred_action_events_20.log: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/ca_mixed_deferred_action_events_20.log -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/cb_apprentice_5.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/cb_apprentice_5.log -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/cb_dedup_compressed.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/cb_dedup_compressed.log -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/cb_deferred_actions_w_activations_and_apprentice_10.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/cb_deferred_actions_w_activations_and_apprentice_10.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/cb_joined_with_pdrop_05.fb: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/cb_joined_with_pdrop_05.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/cb_joined_with_pdrop_1.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/cb_joined_with_pdrop_1.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/cb_simple.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/cb_simple.log -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/ccb_apprentice_5.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/ccb_apprentice_5.log -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/ccb_deferred_actions_w_activations_and_apprentice_20.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/ccb_deferred_actions_w_activations_and_apprentice_20.fb 
-------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/ccb_simple.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/ccb_simple.log -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/ccb_sum_reward_100_interactions.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/ccb_sum_reward_100_interactions.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/ccb_w_slot_id.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/ccb_w_slot_id.log -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/ccb_w_various_outcomes.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/ccb_w_various_outcomes.log -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/multistep_2_episodes.fb: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/multistep_2_episodes.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/multistep_3_deferred_episodes.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/multistep_3_deferred_episodes.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/multistep_unordered_episodes.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/multistep_unordered_episodes.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/rcrrmr.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/rcrrmr.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/rrcr.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/rrcr.fb -------------------------------------------------------------------------------- 
/external_parser/unit_tests/test_files/valid_joined_logs/slates_average_reward_100_interactions.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/slates_average_reward_100_interactions.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/slates_deferred_actions_w_activations_10.fb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/slates_deferred_actions_w_activations_10.fb -------------------------------------------------------------------------------- /external_parser/unit_tests/test_files/valid_joined_logs/slates_simple.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/external_parser/unit_tests/test_files/valid_joined_logs/slates_simple.log -------------------------------------------------------------------------------- /external_parser/utils.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) by respective owners including Yahoo!, Microsoft, and 2 | // individual contributors. All rights reserved. Released under a BSD (revised) 3 | // license as described in the file LICENSE. 
4 | 5 | #include 6 | #ifndef _WIN32 7 | # define _stricmp strcasecmp 8 | #endif 9 | 10 | namespace VW 11 | { 12 | namespace external 13 | { 14 | bool stricmp(const char* first, const char* second) { return _stricmp(first, second); } 15 | } // namespace external 16 | } // namespace VW -------------------------------------------------------------------------------- /external_parser/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) by respective owners including Yahoo!, Microsoft, and 2 | // individual contributors. All rights reserved. Released under a BSD (revised) 3 | // license as described in the file LICENSE. 4 | 5 | #pragma once 6 | #include "vw/core/memory.h" 7 | #include "vw/core/vw.h" 8 | 9 | namespace VW 10 | { 11 | namespace external 12 | { 13 | bool stricmp(const char* first, const char* second); 14 | 15 | template 16 | bool str_to_enum( 17 | const std::string& str, const std::map enum_map, const enum_t default_value, enum_t& result) 18 | { 19 | for (auto p : enum_map) 20 | { 21 | if (!stricmp(p.first, str.c_str())) 22 | { 23 | result = p.second; 24 | return true; 25 | } 26 | } 27 | result = default_value; 28 | return false; 29 | } 30 | 31 | } // namespace external 32 | } // namespace VW 33 | -------------------------------------------------------------------------------- /include/action_flags.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @brief action_flags definition. 
3 | * 4 | * @file action_flags.h 5 | * @author Rajan Chari et al 6 | * @date 2018-07-18 7 | */ 8 | #pragma once 9 | 10 | namespace reinforcement_learning 11 | { 12 | enum action_flags 13 | { 14 | DEFAULT = 0, 15 | DEFERRED = 1 16 | }; 17 | } 18 | -------------------------------------------------------------------------------- /include/config_utility.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace reinforcement_learning 5 | { 6 | namespace utility 7 | { 8 | class configuration; 9 | } 10 | class api_status; 11 | class i_trace; 12 | } // namespace reinforcement_learning 13 | 14 | namespace reinforcement_learning 15 | { 16 | namespace utility 17 | { 18 | namespace config 19 | { 20 | std::string load_config_json(); 21 | int create_from_json( 22 | const std::string& config_json, configuration& cc, i_trace* trace = nullptr, api_status* = nullptr); 23 | } // namespace config 24 | } // namespace utility 25 | } // namespace reinforcement_learning 26 | -------------------------------------------------------------------------------- /include/err_constants.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @brief Definition of all API error return codes and descriptions 3 | * 4 | * @file err_constants.h 5 | * @author Rajan Chari et al 6 | * @date 2018-07-18 7 | */ 8 | #pragma once 9 | 10 | //! [Error Generator] 11 | #define ERROR_CODE_DEFINITION(code, name, message) \ 12 | namespace reinforcement_learning \ 13 | { \ 14 | namespace error_code \ 15 | { \ 16 | const int name = code; \ 17 | char const* const name##_s = message; \ 18 | } \ 19 | } 20 | //! 
[Error Generator] 21 | 22 | #include "errors_data.h" 23 | 24 | namespace reinforcement_learning 25 | { 26 | namespace error_code 27 | { 28 | // Success code 29 | const int success = 0; 30 | } // namespace error_code 31 | } // namespace reinforcement_learning 32 | 33 | namespace reinforcement_learning 34 | { 35 | namespace error_code 36 | { 37 | char const* const unknown_s = "Unexpected error."; 38 | } 39 | } // namespace reinforcement_learning 40 | 41 | #undef ERROR_CODE_DEFINITION 42 | -------------------------------------------------------------------------------- /include/internal_constants.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace reinforcement_learning 4 | { 5 | namespace config_constants 6 | { 7 | const char* const EPISODE = "episode"; 8 | const char* const INTERACTION = "interaction"; 9 | const char* const OBSERVATION = "observation"; 10 | const char* const CONFIG_SECTION = "config.section"; 11 | const char* const EH_HOST = ".eventhub.host"; 12 | const char* const EH_NAME = ".eventhub.name"; 13 | const char* const EH_KEY_NAME = ".eventhub.keyname"; 14 | const char* const EH_KEY = ".eventhub.key"; 15 | } // namespace config_constants 16 | } // namespace reinforcement_learning 17 | -------------------------------------------------------------------------------- /include/learning_mode.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @brief learning_mode definition. 
3 | * 4 | * @file learning_mode.h 5 | * @author Chenxi Zhao 6 | * @date 2020-02-13 7 | */ 8 | #pragma once 9 | 10 | namespace reinforcement_learning 11 | { 12 | enum learning_mode 13 | { 14 | ONLINE = 0, 15 | APPRENTICE = 1, 16 | LOGGINGONLY = 2, 17 | }; 18 | 19 | namespace learning 20 | { 21 | learning_mode to_learning_mode(const char* learning_mode); 22 | } 23 | } // namespace reinforcement_learning 24 | -------------------------------------------------------------------------------- /include/loop_apis/README.md: -------------------------------------------------------------------------------- 1 | # RL Client Library Loop APIs 2 | These public apis split up the old live_model interface into various loop types, ensuring that only relevant calls can be made on a given loop. 3 | -------------------------------------------------------------------------------- /include/personalization.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "api_status.h" 3 | #include "config_utility.h" 4 | #include "configuration.h" 5 | #include "live_model.h" 6 | #include "ranking_response.h" 7 | 8 | namespace personalization = reinforcement_learning; 9 | -------------------------------------------------------------------------------- /include/rl_string_view.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "nonstd/string_view.hpp" 4 | 5 | namespace reinforcement_learning 6 | { 7 | using string_view = nonstd::string_view; 8 | } 9 | -------------------------------------------------------------------------------- /include/sender.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "configuration.h" 3 | #include "data_buffer.h" 4 | 5 | #include 6 | namespace reinforcement_learning 7 | { 8 | class api_status; 9 | class i_sender 10 | { 11 | public: 12 | using buffer = std::shared_ptr; 13 | virtual int 
init(const utility::configuration& config, api_status* status) = 0; 14 | 15 | // For mocking in unit tests, buffer& data may be initialized with nullptr 16 | // Disable UBSan here to prevent generating an error 17 | #ifdef RL_USE_UBSAN 18 | __attribute__((no_sanitize("undefined"))) 19 | #endif 20 | int send(const buffer& data, api_status* status = nullptr) 21 | { 22 | return v_send(data, status); 23 | } 24 | 25 | virtual ~i_sender() = default; 26 | 27 | protected: 28 | virtual int v_send(const buffer& data, api_status* status = nullptr) = 0; 29 | }; 30 | } // namespace reinforcement_learning 31 | -------------------------------------------------------------------------------- /include/str_util.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace reinforcement_learning 6 | { 7 | namespace utility 8 | { 9 | template 10 | std::string concat(std::ostringstream& os, const Last& last) 11 | { 12 | os << last; 13 | return os.str(); 14 | } 15 | 16 | template 17 | std::string concat(std::ostringstream& os, const First& first, const Rest&... rest) 18 | { 19 | os << first; 20 | return concat(os, rest...); 21 | } 22 | 23 | template 24 | std::string concat(const First& first, const Rest&... 
rest) 25 | { 26 | std::ostringstream os; 27 | return concat(os, first, rest...); 28 | } 29 | 30 | class str_util 31 | { 32 | public: 33 | static std::string& to_lower(std::string& sval); 34 | static std::string& ltrim(std::string& sval); 35 | static std::string& rtrim(std::string& sval); 36 | static std::string& trim(std::string& sval); 37 | }; 38 | } // namespace utility 39 | } // namespace reinforcement_learning 40 | -------------------------------------------------------------------------------- /nuget/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Check that variables are set 2 | if(NOT DEFINED RL_NUGET_PACKAGE_NAME OR NOT DEFINED RL_NUGET_PACKAGE_VERSION OR NOT DEFINED NATIVE_NUGET_PLATFORM_TAG) 3 | message(FATAL_ERROR "When building Nuget package, must define variables: RL_NUGET_PACKAGE_NAME, RL_NUGET_PACKAGE_VERSION, NATIVE_NUGET_PLATFORM_TAG") 4 | endif() 5 | 6 | # Generate NuGet package specification file from template 7 | configure_file(rlclientlib.nuspec.in rlclientlib.nuspec @ONLY) 8 | install(FILES ${CMAKE_CURRENT_BINARY_DIR}/rlclientlib.nuspec DESTINATION ./) 9 | 10 | # Generate the .targets file from template 11 | configure_file(rlclientlib.targets.in rlclientlib.targets @ONLY) 12 | install(FILES ${CMAKE_CURRENT_BINARY_DIR}/rlclientlib.targets DESTINATION ./ RENAME ${RL_NUGET_PACKAGE_NAME}-v${MSVC_TOOLSET_VERSION}-${CMAKE_BUILD_TYPE}-${NATIVE_NUGET_PLATFORM_TAG}.targets) 13 | 14 | # Build package 15 | install(SCRIPT CreateNugetPackage.cmake) -------------------------------------------------------------------------------- /nuget/CreateNugetPackage.cmake: -------------------------------------------------------------------------------- 1 | find_program( 2 | nuget_exe 3 | NAMES "nuget.exe" "NuGet.exe" 4 | HINTS "${CMAKE_BINARY_DIR}/nuget/" "${CMAKE_BINARY_DIR}/../nuget/" "${CMAKE_BINARY_DIR}/../../nuget/" "${CMAKE_BINARY_DIR}/../../../nuget/" 5 | NO_CACHE 6 | REQUIRED 7 | ) 8 | 9 | message("Creating 
Nuget package...") 10 | message("Path to nuget.exe: ${nuget_exe}") 11 | message("Working directory: ${CMAKE_INSTALL_PREFIX}") 12 | 13 | execute_process( 14 | COMMAND "${nuget_exe}" pack rlclientlib.nuspec 15 | WORKING_DIRECTORY "${CMAKE_INSTALL_PREFIX}" 16 | RESULT_VARIABLE return_code 17 | ) 18 | 19 | if(return_code) 20 | message(FATAL_ERROR "Failed to build Nuget package: nuget.exe returned ${return_code}") 21 | endif() -------------------------------------------------------------------------------- /nuget/dotnet/rl.net.props: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /nuget/dotnet/rl.net.targets: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /nuget/dotnet/test/client.json: -------------------------------------------------------------------------------- 1 | { 2 | "ApplicationID": "testclient", 3 | "interaction.sender.implementation": "INTERACTION_FILE_SENDER", 4 | "observation.sender.implementation": "OBSERVATION_FILE_SENDER", 5 | "interaction.file.name": "interaction.fbs", 6 | "observation.file.name": "observation.fbs", 7 | "IsExplorationEnabled": true, 8 | "InitialExplorationEpsilon": 1.0, 9 | "LearningMode": "Online", 10 | "model.source": "FILE_MODEL_DATA", 11 | "protocol.version":"2" 12 | } 13 | -------------------------------------------------------------------------------- /nuget/dotnet/test/dotnetcore_nuget_test.csproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | Exe 5 | net6.0 6 | enable 7 | enable 8 | x64 9 | x64 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /nuget/nuget.exe: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/nuget/nuget.exe -------------------------------------------------------------------------------- /nuget/rlclientlib.nuspec.in: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | @RL_NUGET_PACKAGE_NAME@-v@MSVC_TOOLSET_VERSION@-@CMAKE_BUILD_TYPE@-@NATIVE_NUGET_PLATFORM_TAG@ 5 | @RL_NUGET_PACKAGE_VERSION@ 6 | Reinforcement Learning Client Library - Static Build 7 | John Langford et al 8 | MIT 9 | RLClientLib is an Interaction-side integration library for Reinforcement Learning loops: Predict, Log, Learn, Update 10 | Copyright (c) by respective owners including Yahoo!, Microsoft, and individual contributors. 11 | native 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /nuget/test/main.cc: -------------------------------------------------------------------------------- 1 | #include "config_utility.h" 2 | #include "live_model.h" 3 | 4 | #include 5 | 6 | namespace r = reinforcement_learning; 7 | namespace u = reinforcement_learning::utility; 8 | 9 | int main() 10 | { 11 | u::configuration config; 12 | auto rl = std::make_unique(config); 13 | 14 | std::cout << "Successfully ran rlclientlib test program" << std::endl; 15 | return 0; 16 | } -------------------------------------------------------------------------------- /rlclientlib/api_status.cc: -------------------------------------------------------------------------------- 1 | #include "api_status.h" 2 | 3 | #include "trace_logger.h" 4 | 5 | namespace reinforcement_learning 6 | { 7 | int api_status::get_error_code() const { return _error_code; } 8 | 9 | const char* api_status::get_error_msg() const { return _error_msg.c_str(); } 10 | 11 | api_status::api_status() : _error_code(0) {} 12 | 13 | // static helper: update the status if needed (i.e. 
if it is not null) 14 | void api_status::try_update(api_status* status, const int new_code, const char* new_msg) 15 | { 16 | if (status != nullptr) 17 | { 18 | status->_error_code = new_code; 19 | status->_error_msg = new_msg; 20 | } 21 | } 22 | 23 | void api_status::try_clear(api_status* status) 24 | { 25 | if (status != nullptr) 26 | { 27 | status->_error_code = 0; 28 | status->_error_msg.clear(); 29 | } 30 | } 31 | 32 | status_builder::status_builder(i_trace* trace, api_status* status, const int code) 33 | : _code{code}, _status{status}, _trace{trace} 34 | { 35 | if (enable_logging()) { _os << "(ERR:" << _code << ")"; } 36 | } 37 | 38 | status_builder::~status_builder() 39 | { 40 | if (_status != nullptr) { api_status::try_update(_status, _code, _os.str().c_str()); } 41 | if (_trace != nullptr) { _trace->log(0, _os.str()); } 42 | } 43 | 44 | status_builder::operator int() const { return _code; } 45 | 46 | bool status_builder::enable_logging() const { return _status != nullptr || _trace != nullptr; } 47 | } // namespace reinforcement_learning 48 | -------------------------------------------------------------------------------- /rlclientlib/azure_factories.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "oauth_callback_fn.h" 3 | 4 | #include 5 | 6 | namespace reinforcement_learning 7 | { 8 | void register_azure_factories(); 9 | 10 | void register_azure_oauth_factories(oauth_callback_t& callback); 11 | 12 | using azure_cred_provider_cb_wrapper_t = std::unique_ptr; 13 | } // namespace reinforcement_learning 14 | -------------------------------------------------------------------------------- /rlclientlib/console_tracer.cc: -------------------------------------------------------------------------------- 1 | #include "console_tracer.h" 2 | 3 | #include "str_util.h" 4 | 5 | #include 6 | 7 | namespace reinforcement_learning 8 | { 9 | void console_tracer::log(int log_level, const std::string& msg) 10 | { 
11 | if (log_level < _log_level) { return; } 12 | std::cout << details::get_log_level_string(log_level) << ": " << msg << std::endl; 13 | } 14 | 15 | void console_tracer::set_level(int log_level) { _log_level = log_level; } 16 | } // namespace reinforcement_learning 17 | -------------------------------------------------------------------------------- /rlclientlib/console_tracer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "trace_logger.h" 3 | 4 | namespace reinforcement_learning 5 | { 6 | class console_tracer : public i_trace 7 | { 8 | public: 9 | // Inherited via i_trace 10 | void log(int log_level, const std::string& msg) override; 11 | void set_level(int log_level) override; 12 | 13 | private: 14 | int _log_level; 15 | }; 16 | } // namespace reinforcement_learning 17 | -------------------------------------------------------------------------------- /rlclientlib/constants.cc: -------------------------------------------------------------------------------- 1 | #include "constants.h" 2 | 3 | namespace reinforcement_learning 4 | { 5 | namespace value 6 | { 7 | #ifdef USE_AZURE_FACTORIES 8 | const char* const DEFAULT_EPISODE_SENDER = EPISODE_EH_SENDER; 9 | const char* const DEFAULT_OBSERVATION_SENDER = OBSERVATION_EH_SENDER; 10 | const char* const DEFAULT_INTERACTION_SENDER = INTERACTION_EH_SENDER; 11 | const char* const DEFAULT_DATA_TRANSPORT = AZURE_STORAGE_BLOB; 12 | const char* const DEFAULT_TIME_PROVIDER = NULL_TIME_PROVIDER; 13 | #else 14 | const char* const DEFAULT_EPISODE_SENDER = EPISODE_FILE_SENDER; 15 | const char* const DEFAULT_OBSERVATION_SENDER = OBSERVATION_FILE_SENDER; 16 | const char* const DEFAULT_INTERACTION_SENDER = INTERACTION_FILE_SENDER; 17 | const char* const DEFAULT_DATA_TRANSPORT = NO_MODEL_DATA; 18 | const char* const DEFAULT_TIME_PROVIDER = CLOCK_TIME_PROVIDER; 19 | #endif 20 | 21 | const char* get_default_episode_sender() { return DEFAULT_EPISODE_SENDER; } 22 | 23 | 
const char* get_default_observation_sender() { return DEFAULT_OBSERVATION_SENDER; } 24 | 25 | const char* get_default_interaction_sender() { return DEFAULT_INTERACTION_SENDER; } 26 | 27 | const char* get_default_data_transport() { return DEFAULT_DATA_TRANSPORT; } 28 | 29 | const char* get_default_time_provider() { return DEFAULT_TIME_PROVIDER; } 30 | } // namespace value 31 | } // namespace reinforcement_learning 32 | -------------------------------------------------------------------------------- /rlclientlib/dedup.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "configuration.h" 3 | #include "logger/logger_facade.h" 4 | 5 | namespace reinforcement_learning 6 | { 7 | bool should_use_dedup_logger_extension(const utility::configuration& config, const char* section); 8 | 9 | std::unique_ptr create_dedup_logger_extension( 10 | const utility::configuration& config, const char* section, std::unique_ptr time_provider); 11 | } // namespace reinforcement_learning 12 | -------------------------------------------------------------------------------- /rlclientlib/error_callback_fn.cc: -------------------------------------------------------------------------------- 1 | #include "error_callback_fn.h" 2 | 3 | using namespace std; 4 | 5 | namespace reinforcement_learning 6 | { 7 | void error_callback_fn::report_error(api_status& s) 8 | { 9 | if (!_fn) { return; } 10 | 11 | lock_guard lock(_mutex); 12 | if (_fn) 13 | { 14 | try 15 | { 16 | _fn(s); 17 | } 18 | catch (...) 
19 | { 20 | // Error handler is throwing so can't call it again 21 | } 22 | } 23 | } 24 | 25 | } // namespace reinforcement_learning 26 | -------------------------------------------------------------------------------- /rlclientlib/extensions/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if (rlclientlib_BUILD_ONNXRUNTIME_EXTENSION) 2 | message("Building RLClientLib Extension: OnnxRuntime") 3 | add_subdirectory(onnx) 4 | endif() -------------------------------------------------------------------------------- /rlclientlib/extensions/onnx/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/Modules/") 2 | find_package(OnnxRuntime REQUIRED) 3 | find_package(cpprestsdk REQUIRED) 4 | 5 | SET(ONNX_EXTENSION_SOURCES 6 | src/onnx_model.cc 7 | src/onnx_extension.cc 8 | src/onnx_input.cc 9 | src/tensor_parser.cc 10 | ) 11 | 12 | SET(ONNX_EXTENSION_PUBLIC_HEADERS 13 | include/onnx_extension.h 14 | ) 15 | 16 | SET(ONNX_EXTENSION_HEADERS 17 | src/onnx_model.h 18 | src/onnx_input.h 19 | src/tensor_parser.h 20 | src/tensor_notation.h 21 | ) 22 | 23 | #source_group("Sources" FILES ${ONNX_EXTENSION_SOURCES}) 24 | #source_group("Public headers" FILES ${ONNX_EXTENSION_PUBLIC_HEADERS}) 25 | #source_group("Private headers" FILES ${ONNX_EXTENSION_HEADERS}) 26 | 27 | add_library(rlclientlib-onnx ${ONNX_EXTENSION_SOURCES} ${ONNX_EXTENSION_PUBLIC_HEADERS} ${ONNX_EXTENSION_PRIVATE_HEADERS}) 28 | 29 | set_target_properties(rlclientlib-onnx PROPERTIES POSITION_INDEPENDENT_CODE ON) 30 | 31 | target_include_directories(rlclientlib-onnx 32 | PUBLIC 33 | ${CMAKE_CURRENT_SOURCE_DIR}/include 34 | PRIVATE 35 | ${CMAKE_CURRENT_SOURCE_DIR}/src 36 | ) 37 | 38 | target_link_libraries(rlclientlib-onnx PUBLIC rlclientlib onnxruntime) -------------------------------------------------------------------------------- 
/rlclientlib/extensions/onnx/include/onnx_extension.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace reinforcement_learning 4 | { 5 | namespace onnx 6 | { 7 | void register_onnx_factory(); 8 | } 9 | } // namespace reinforcement_learning 10 | 11 | // Constants 12 | 13 | namespace reinforcement_learning 14 | { 15 | namespace name 16 | { 17 | // TODO: Explore and expose useful configuration settings here 18 | const char* const ONNX_USE_UNSTRUCTURED_INPUT = "onnx.use_unstructured_input"; 19 | const char* const ONNX_OUTPUT_NAME = "onnx.output_name"; 20 | } // namespace name 21 | } // namespace reinforcement_learning 22 | 23 | namespace reinforcement_learning 24 | { 25 | namespace value 26 | { 27 | const char* const ONNXRUNTIME_MODEL = "ONNXRUNTIME"; 28 | } 29 | } // namespace reinforcement_learning -------------------------------------------------------------------------------- /rlclientlib/extensions/onnx/src/onnx_extension.cc: -------------------------------------------------------------------------------- 1 | #include "onnx_extension.h" 2 | 3 | #include "api_status.h" 4 | #include "configuration.h" 5 | #include "constants.h" 6 | #include "err_constants.h" 7 | #include "factory_resolver.h" 8 | #include "model_mgmt.h" 9 | #include "onnx_model.h" 10 | 11 | namespace m = reinforcement_learning::model_management; 12 | namespace u = reinforcement_learning::utility; 13 | 14 | namespace reinforcement_learning 15 | { 16 | namespace onnx 17 | { 18 | int create_onnx_model( 19 | std::unique_ptr& retval, const u::configuration& config, i_trace* trace_logger, api_status* status) 20 | { 21 | const char* app_id = config.get(name::APP_ID, ""); 22 | const char* output_name = config.get(name::ONNX_OUTPUT_NAME, nullptr); 23 | if (output_name == nullptr) 24 | { 25 | RETURN_ERROR_LS(trace_logger, status, inference_configuration_error) 26 | << "Output name is not provided in the configuration."; 27 | } 28 | 29 | bool 
use_unstructured_input = config.get_bool(name::ONNX_USE_UNSTRUCTURED_INPUT, false); 30 | 31 | retval.reset(new onnx_model(trace_logger, app_id, output_name, use_unstructured_input)); 32 | 33 | return error_code::success; 34 | }; 35 | 36 | void register_onnx_factory() { model_factory.register_type(value::ONNXRUNTIME_MODEL, create_onnx_model); } 37 | } // namespace onnx 38 | } // namespace reinforcement_learning -------------------------------------------------------------------------------- /rlclientlib/federation/joined_log_provider.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "api_status.h" 4 | #include "future_compat.h" 5 | #include "vw/io/io_adapter.h" 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | namespace reinforcement_learning 12 | { 13 | struct i_joined_log_batch 14 | { 15 | virtual ~i_joined_log_batch() = default; 16 | 17 | /// Returns next chunk of batch. chunk_reader will be nullptr when then batch is complete. 18 | RL_ATTR(nodiscard) 19 | virtual int next(std::unique_ptr& chunk_reader, api_status* status = nullptr) = 0; 20 | }; 21 | 22 | /** 23 | * @brief This interface allows polling access to logged event data. 24 | */ 25 | struct i_joined_log_provider 26 | { 27 | virtual ~i_joined_log_provider() = default; 28 | 29 | /// Runs the join operation and returns the resulting batch which can be consumed. 30 | /// The format of the data returned in the batch it implementation dependent. 
31 | RL_ATTR(nodiscard) 32 | virtual int invoke_join(std::unique_ptr& batch, api_status* status = nullptr) = 0; 33 | }; 34 | } // namespace reinforcement_learning 35 | -------------------------------------------------------------------------------- /rlclientlib/learning_mode.cc: -------------------------------------------------------------------------------- 1 | #include "learning_mode.h" 2 | 3 | #include "constants.h" 4 | 5 | #include 6 | 7 | // portability fun 8 | #ifndef _WIN32 9 | # define _stricmp strcasecmp 10 | #endif 11 | 12 | namespace reinforcement_learning 13 | { 14 | namespace learning 15 | { 16 | learning_mode to_learning_mode(const char* learning_mode) 17 | { 18 | if (_stricmp(learning_mode, value::LEARNING_MODE_APPRENTICE) == 0) { return APPRENTICE; } 19 | if (_stricmp(learning_mode, value::LEARNING_MODE_ONLINE) == 0) { return ONLINE; } 20 | if (_stricmp(learning_mode, value::LEARNING_MODE_LOGGINGONLY) == 0) { return LOGGINGONLY; } 21 | else { return ONLINE; } 22 | } 23 | } // namespace learning 24 | } // namespace reinforcement_learning 25 | -------------------------------------------------------------------------------- /rlclientlib/logger/endian.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace reinforcement_learning 5 | { 6 | namespace logger 7 | { 8 | struct endian 9 | { 10 | static bool is_big_endian(void); 11 | static std::uint32_t htonl(uint32_t host_l); 12 | static std::uint16_t htons(uint16_t host_l); 13 | static std::uint32_t ntohl(uint32_t net_l); 14 | static std::uint16_t ntohs(uint16_t net_s); 15 | }; 16 | } // namespace logger 17 | } // namespace reinforcement_learning -------------------------------------------------------------------------------- /rlclientlib/logger/file/file_logger.cc: -------------------------------------------------------------------------------- 1 | #include "file_logger.h" 2 | 3 | #include "api_status.h" 4 | #include 
"err_constants.h" 5 | 6 | #include 7 | 8 | #include 9 | #include 10 | namespace reinforcement_learning 11 | { 12 | namespace logger 13 | { 14 | namespace file 15 | { 16 | file_logger::file_logger(std::string file_name, i_trace* trace) : _file_name(std::move(file_name)), _trace(trace) {} 17 | 18 | int file_logger::init(const utility::configuration& config, api_status* status) 19 | { 20 | _file.exceptions(std::ifstream::failbit | std::ifstream::badbit); 21 | try 22 | { 23 | _file.open(_file_name, std::ios::binary); 24 | } 25 | catch (const std::ios_base::failure& e) 26 | { 27 | RETURN_ERROR_LS(_trace, status, file_open_error) << " File:" << _file_name << " Error:" << e.what(); 28 | } 29 | return error_code::success; 30 | } 31 | 32 | int file_logger::v_send(const buffer& data, api_status* status) 33 | { 34 | try 35 | { 36 | _file.write(reinterpret_cast(data->preamble_begin()), data->buffer_filled_size()); 37 | _file.flush(); 38 | } 39 | catch (const std::ios_base::failure& e) 40 | { 41 | RETURN_ERROR_LS(_trace, status, file_open_error) << " File:" << _file_name << " Error:" << e.what(); 42 | } 43 | return error_code::success; 44 | } 45 | } // namespace file 46 | } // namespace logger 47 | } // namespace reinforcement_learning 48 | -------------------------------------------------------------------------------- /rlclientlib/logger/file/file_logger.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "sender.h" 3 | 4 | #include 5 | 6 | namespace reinforcement_learning 7 | { 8 | class i_trace; 9 | } 10 | 11 | namespace reinforcement_learning 12 | { 13 | namespace logger 14 | { 15 | namespace file 16 | { 17 | class file_logger : public i_sender 18 | { 19 | public: 20 | explicit file_logger(std::string file_name, i_trace*); 21 | int init(const utility::configuration& config, api_status* status) override; 22 | 23 | file_logger(const file_logger&) = delete; 24 | file_logger(file_logger&&) = delete; 25 | 
file_logger& operator=(const file_logger&) = delete; 26 | file_logger& operator=(file_logger&&) = delete; 27 | 28 | protected: 29 | int v_send(const buffer& data, reinforcement_learning::api_status* status) override; 30 | std::string _file_name; 31 | i_trace* _trace; 32 | std::ofstream _file; 33 | }; 34 | } // namespace file 35 | } // namespace logger 36 | } // namespace reinforcement_learning 37 | -------------------------------------------------------------------------------- /rlclientlib/logger/flatbuffer_allocator.cc: -------------------------------------------------------------------------------- 1 | #include "flatbuffer_allocator.h" 2 | namespace reinforcement_learning 3 | { 4 | flatbuffer_allocator::flatbuffer_allocator(utility::data_buffer& data_buffer) : _buffer(data_buffer) {} 5 | 6 | uint8_t* flatbuffer_allocator::allocate(size_t size) 7 | { 8 | _buffer.resize_body_region(size); 9 | return _buffer.body_begin(); 10 | } 11 | 12 | void flatbuffer_allocator::deallocate(uint8_t* p, size_t size) 13 | { 14 | // Nothing to do. 
Buffer cleanup will happen in data_buffer 15 | } 16 | 17 | uint8_t* flatbuffer_allocator::reallocate_downward( 18 | uint8_t* old_p, size_t old_size, size_t new_size, size_t in_use_back, size_t in_use_front) 19 | { 20 | assert(new_size > old_size); // vector_downward only grows 21 | _buffer.resize_body_region(new_size); 22 | // implementation is almost identical with memcpy_downward from 23 | // https://github.com/google/flatbuffers/blob/master/include/flatbuffers/flatbuffers.h, but since we are staying at 24 | // the same chunk, we 1) cannot use memcpy which has undefined behavior on copying data between overlapping regions 25 | // (https://en.cppreference.com/w/cpp/string/byte/memcpy) 2) do not need to copy head 26 | memmove(_buffer.body_begin() + new_size - in_use_back, _buffer.body_begin() + old_size - in_use_back, in_use_back); 27 | return _buffer.body_begin(); 28 | } 29 | } // namespace reinforcement_learning 30 | -------------------------------------------------------------------------------- /rlclientlib/logger/flatbuffer_allocator.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "data_buffer.h" 3 | #include "flatbuffers/flatbuffers.h" 4 | 5 | namespace reinforcement_learning 6 | { 7 | class flatbuffer_allocator : public flatbuffers::Allocator 8 | { 9 | public: 10 | flatbuffer_allocator(utility::data_buffer& data_buffer); 11 | 12 | uint8_t* allocate(size_t size) override; 13 | void deallocate(uint8_t* p, size_t size) override; 14 | uint8_t* reallocate_downward( 15 | uint8_t* old_p, size_t old_size, size_t new_size, size_t in_use_back, size_t in_use_front) override; 16 | 17 | private: 18 | utility::data_buffer& _buffer; 19 | }; 20 | } // namespace reinforcement_learning 21 | -------------------------------------------------------------------------------- /rlclientlib/logger/message_sender.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | 
#include 4 | #include 5 | namespace reinforcement_learning 6 | { 7 | class api_status; 8 | 9 | namespace utility 10 | { 11 | class data_buffer; 12 | } 13 | 14 | namespace logger 15 | { 16 | class i_message_sender 17 | { 18 | public: 19 | using buffer = std::shared_ptr; 20 | virtual ~i_message_sender() = default; 21 | virtual int send(const uint16_t msg_type, const buffer& db, api_status* status = nullptr) = 0; 22 | virtual int init(api_status* status = nullptr) = 0; 23 | }; 24 | } // namespace logger 25 | } // namespace reinforcement_learning 26 | -------------------------------------------------------------------------------- /rlclientlib/logger/message_type.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace reinforcement_learning 6 | { 7 | namespace logger 8 | { 9 | struct message_type 10 | { 11 | using const_int = const uint16_t; 12 | 13 | // WARNING ! 14 | // Do not reuse message ids. This can cause 15 | // incompatibility with services that exist in the system. 16 | // 17 | // Message ids should only be added to this list! 18 | 19 | // Flat buffer ranking event collection message 20 | static const_int UNKNOWN = 0; 21 | static const_int fb_ranking_event_collection = 22 | 1; // Deprecated. It is replaced by fb_ranking_learning_mode_event_collection = 9. 23 | static const_int fb_outcome_event_collection = 2; 24 | static const_int json_ranking_event_collection = 3; 25 | static const_int json_outcome_event_collection = 4; 26 | static const_int fb_outcome_event = 5; 27 | static const_int fb_interaction_event = 6; // Deprecated. It is replaced by fb_interaction_learning_mode_event = 10. 
28 | static const_int fb_decision_event = 7; 29 | static const_int fb_decision_event_collection = 8; 30 | static const_int fb_ranking_learning_mode_event_collection = 9; 31 | static const_int fb_interaction_learning_mode_event = 10; 32 | static const_int fb_slates_event = 11; 33 | static const_int fb_slates_event_collection = 12; 34 | static const_int fb_generic_event_collection = 13; 35 | }; 36 | } // namespace logger 37 | } // namespace reinforcement_learning 38 | -------------------------------------------------------------------------------- /rlclientlib/logger/preamble.cc: -------------------------------------------------------------------------------- 1 | #include "preamble.h" 2 | 3 | #include "endian.h" 4 | 5 | namespace reinforcement_learning 6 | { 7 | namespace logger 8 | { 9 | bool preamble::write_to_bytes(uint8_t* buffer, size_t buffersz) const 10 | { 11 | if (buffersz < size()) { return false; } 12 | 13 | buffer[0] = reserved; 14 | buffer[1] = version; 15 | auto* p_type = reinterpret_cast(buffer + 2); 16 | *p_type = endian::htons(msg_type); 17 | auto* p_size = reinterpret_cast(buffer + 4); 18 | *p_size = endian::htonl(msg_size); 19 | return true; 20 | } 21 | 22 | bool preamble::read_from_bytes(uint8_t* buffer, size_t buffersz) 23 | { 24 | if (buffersz < size()) { return false; } 25 | 26 | reserved = buffer[0]; 27 | version = buffer[1]; 28 | auto* p_type = reinterpret_cast(buffer + 2); 29 | msg_type = endian::ntohs(*p_type); 30 | auto* p_size = reinterpret_cast(buffer + 4); 31 | msg_size = endian::ntohl(*p_size); 32 | return true; 33 | } 34 | 35 | } // namespace logger 36 | } // namespace reinforcement_learning -------------------------------------------------------------------------------- /rlclientlib/logger/preamble.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace reinforcement_learning 6 | { 7 | namespace logger 8 | { 9 | struct preamble 10 | { 11 | 
uint8_t reserved = 0; 12 | uint8_t version = 0; 13 | uint16_t msg_type = 0; 14 | uint32_t msg_size = 0; 15 | 16 | bool write_to_bytes(uint8_t* buffer, size_t buffersz) const; 17 | bool read_from_bytes(uint8_t* buffer, size_t buffersz); 18 | constexpr static uint32_t size() { return 8; }; 19 | }; 20 | } // namespace logger 21 | } // namespace reinforcement_learning 22 | -------------------------------------------------------------------------------- /rlclientlib/logger/preamble_sender.cc: -------------------------------------------------------------------------------- 1 | #include "preamble_sender.h" 2 | 3 | #include "api_status.h" 4 | #include "preamble.h" 5 | 6 | namespace reinforcement_learning 7 | { 8 | namespace logger 9 | { 10 | struct preamble; 11 | 12 | preamble_message_sender::preamble_message_sender(std::unique_ptr sender) : _sender(std::move(sender)) {} 13 | 14 | int preamble_message_sender::send(const uint16_t msg_type, const buffer& db, api_status* status) 15 | { 16 | // Set the preamble for this message 17 | preamble pre; 18 | pre.msg_type = msg_type; 19 | pre.msg_size = static_cast(db->body_filled_size()); 20 | if (!pre.write_to_bytes(db->preamble_begin(), db->preamble_size())) 21 | { 22 | RETURN_ERROR_LS(nullptr, status, preamble_error) << " Write error."; 23 | } 24 | // Send message with preamble 25 | return _sender->send(db, status); 26 | } 27 | 28 | int preamble_message_sender::init(api_status* status) { return error_code::success; } 29 | } // namespace logger 30 | } // namespace reinforcement_learning 31 | -------------------------------------------------------------------------------- /rlclientlib/logger/preamble_sender.h: -------------------------------------------------------------------------------- 1 | #include "data_buffer.h" 2 | #include "message_sender.h" 3 | #include "sender.h" 4 | 5 | namespace reinforcement_learning 6 | { 7 | namespace logger 8 | { 9 | class preamble_message_sender : public i_message_sender 10 | { 11 | public: 12 | 
explicit preamble_message_sender(std::unique_ptr); 13 | int send(const uint16_t msg_type, const buffer& db, api_status* status) override; 14 | int init(api_status* status) override; 15 | 16 | private: 17 | std::unique_ptr _sender; 18 | }; 19 | } // namespace logger 20 | } // namespace reinforcement_learning 21 | -------------------------------------------------------------------------------- /rlclientlib/model_mgmt/data_callback_fn.cc: -------------------------------------------------------------------------------- 1 | #include "data_callback_fn.h" 2 | 3 | #include "err_constants.h" 4 | 5 | #include 6 | 7 | #include 8 | 9 | namespace reinforcement_learning 10 | { 11 | namespace model_management 12 | { 13 | int model_management::data_callback_fn::report_data(const model_data& data, i_trace* trace, api_status* status) 14 | { 15 | if (!_fn) { RETURN_ERROR_LS(trace, status, data_callback_not_set); } 16 | 17 | // Need not be thread safe since this is only called from one thread. 18 | try 19 | { 20 | _fn(data); 21 | return error_code::success; 22 | } 23 | catch (const std::exception& ex) 24 | { 25 | RETURN_ERROR_LS(trace, status, data_callback_exception) << ex.what(); 26 | } 27 | catch (...) 
28 | { 29 | RETURN_ERROR_LS(trace, status, data_callback_exception) << "Unknown exception"; 30 | } 31 | } 32 | 33 | } // namespace model_management 34 | } // namespace reinforcement_learning 35 | -------------------------------------------------------------------------------- /rlclientlib/model_mgmt/data_callback_fn.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "model_mgmt.h" 3 | 4 | #include 5 | 6 | namespace reinforcement_learning 7 | { 8 | class i_trace; 9 | namespace model_management 10 | { 11 | class data_callback_fn 12 | { 13 | public: 14 | using data_fn = std::function; 15 | 16 | int report_data(const model_data& data, i_trace* trace, api_status* status = nullptr); 17 | 18 | // Typed constructor 19 | template 20 | using data_fn_ptr = void (*)(const model_data&, DataCntxt*); 21 | 22 | template 23 | explicit data_callback_fn(data_fn_ptr fn, DataCntxt* context) 24 | { 25 | if (fn != nullptr) { _fn = std::bind(fn, std::placeholders::_1, context); } 26 | } 27 | 28 | ~data_callback_fn() = default; 29 | 30 | data_callback_fn(const data_callback_fn&) = delete; 31 | data_callback_fn(data_callback_fn&&) = delete; 32 | data_callback_fn& operator=(const data_callback_fn&) = delete; 33 | data_callback_fn& operator=(data_callback_fn&&) = delete; 34 | 35 | private: 36 | data_fn _fn; 37 | }; 38 | 39 | } // namespace model_management 40 | } // namespace reinforcement_learning 41 | -------------------------------------------------------------------------------- /rlclientlib/model_mgmt/empty_data_transport.cc: -------------------------------------------------------------------------------- 1 | #include "empty_data_transport.h" 2 | 3 | #include "api_status.h" 4 | #include "factory_resolver.h" 5 | 6 | namespace u = reinforcement_learning::utility; 7 | 8 | namespace reinforcement_learning 9 | { 10 | namespace model_management 11 | { 12 | int empty_data_transport::get_data(model_data& ret, api_status* status) 13 | 
{ 14 | ret.increment_refresh_count(); 15 | 16 | return error_code::success; 17 | } 18 | } // namespace model_management 19 | } // namespace reinforcement_learning 20 | -------------------------------------------------------------------------------- /rlclientlib/model_mgmt/empty_data_transport.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "model_mgmt.h" 3 | 4 | namespace reinforcement_learning 5 | { 6 | namespace model_management 7 | { 8 | class empty_data_transport : public i_data_transport 9 | { 10 | public: 11 | int get_data(model_data& ret, api_status* status) override; 12 | 13 | private: 14 | }; 15 | } // namespace model_management 16 | } // namespace reinforcement_learning 17 | -------------------------------------------------------------------------------- /rlclientlib/model_mgmt/file_model_loader.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "model_mgmt.h" 3 | namespace reinforcement_learning 4 | { 5 | class i_trace; 6 | } 7 | 8 | namespace reinforcement_learning 9 | { 10 | namespace model_management 11 | { 12 | class file_model_loader : public i_data_transport 13 | { 14 | public: 15 | file_model_loader(std::string file_name, bool file_must_exist, i_trace* trace_logger); 16 | int init(api_status* status = nullptr); 17 | int get_data(model_data& data, api_status* status = nullptr) override; 18 | 19 | private: 20 | int get_file_modified_time(time_t& file_time, api_status* status) const; 21 | 22 | private: 23 | std::string _file_name; 24 | bool _file_must_exist; 25 | i_trace* _trace; 26 | time_t _last_modified = 0; 27 | size_t _datasz{}; 28 | }; 29 | 30 | } // namespace model_management 31 | } // namespace reinforcement_learning 32 | -------------------------------------------------------------------------------- /rlclientlib/model_mgmt/model_downloader.cc: 
-------------------------------------------------------------------------------- 1 | #include "model_downloader.h" 2 | 3 | #include "api_status.h" 4 | 5 | namespace reinforcement_learning 6 | { 7 | namespace model_management 8 | { 9 | model_downloader::model_downloader(i_data_transport* ptrans, data_callback_fn* pdata_cb, i_trace* trace) 10 | : _ptrans(ptrans), _pdata_cb(pdata_cb), _trace(trace) 11 | { 12 | } 13 | 14 | model_downloader::model_downloader(model_downloader&& temp) noexcept 15 | { 16 | _ptrans = temp._ptrans; 17 | temp._ptrans = nullptr; 18 | _pdata_cb = temp._pdata_cb; 19 | temp._pdata_cb = nullptr; 20 | _trace = temp._trace; 21 | temp._trace = nullptr; 22 | } 23 | 24 | model_downloader& model_downloader::operator=(model_downloader&& temp) noexcept 25 | { 26 | if (&temp != this) 27 | { 28 | _ptrans = temp._ptrans; 29 | temp._ptrans = nullptr; 30 | _pdata_cb = temp._pdata_cb; 31 | temp._pdata_cb = nullptr; 32 | _trace = temp._trace; 33 | temp._trace = nullptr; 34 | } 35 | return *this; 36 | } 37 | 38 | int model_downloader::run_iteration(api_status* status) const 39 | { 40 | model_data md; 41 | RETURN_IF_FAIL(_ptrans->get_data(md, status)); 42 | 43 | const auto scode = _pdata_cb->report_data(md, _trace, status); 44 | return scode; 45 | } 46 | } // namespace model_management 47 | } // namespace reinforcement_learning 48 | -------------------------------------------------------------------------------- /rlclientlib/model_mgmt/model_downloader.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "data_callback_fn.h" 3 | namespace reinforcement_learning 4 | { 5 | class error_callback_fn; 6 | } 7 | 8 | namespace reinforcement_learning 9 | { 10 | namespace model_management 11 | { 12 | class model_downloader 13 | { 14 | public: 15 | model_downloader(i_data_transport* ptrans, data_callback_fn* pdata_cb, i_trace* trace); 16 | model_downloader(model_downloader&& temp) noexcept; 17 | model_downloader& 
operator=(model_downloader&& temp) noexcept; 18 | 19 | int run_iteration(api_status* status) const; 20 | 21 | private: 22 | // Lifetime of pointers managed by user of this class 23 | i_data_transport* _ptrans = nullptr; 24 | data_callback_fn* _pdata_cb = nullptr; 25 | i_trace* _trace = nullptr; 26 | }; 27 | } // namespace model_management 28 | } // namespace reinforcement_learning 29 | -------------------------------------------------------------------------------- /rlclientlib/model_mgmt/model_mgmt.cc: -------------------------------------------------------------------------------- 1 | #include "model_mgmt.h" 2 | 3 | namespace reinforcement_learning 4 | { 5 | namespace model_management 6 | { 7 | 8 | char* model_data::data() { return _data.data(); } 9 | const char* model_data::data() const { return _data.data(); } 10 | 11 | void model_data::increment_refresh_count() { ++_refresh_count; } 12 | 13 | size_t model_data::data_sz() const { return _data.size(); } 14 | 15 | uint32_t model_data::refresh_count() const { return _refresh_count; } 16 | 17 | void model_data::data_sz(const size_t fillsz) { _data.resize(fillsz); } 18 | 19 | char* model_data::alloc(const size_t desired) 20 | { 21 | _data.clear(); // clear first so the resize below value-initializes every byte 22 | _data.resize(desired); 23 | return _data.data(); 24 | } 25 | 26 | void model_data::free() { _data.clear(); } // NOTE(review): clear() keeps the allocated capacity 27 | 28 | int model_data::set_data(const char* vw_model, size_t len) 29 | { 30 | if (vw_model == nullptr || len == 0) { return reinforcement_learning::error_code::static_model_load_error; } 31 | 32 | char* buffer = this->alloc(len); 33 | if (buffer == nullptr) { return reinforcement_learning::error_code::static_model_load_error; } 34 | 35 | memcpy(buffer, vw_model, len); 36 | this->data_sz(len); 37 | 38 | return reinforcement_learning::error_code::success; 39 | } 40 | 41 | } // namespace model_management 42 | } // namespace reinforcement_learning 43 | -------------------------------------------------------------------------------- /rlclientlib/moving_queue.h:
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace reinforcement_learning 7 | { 8 | // a moving non-concurrent queue 9 | template 10 | class moving_queue 11 | { 12 | using queue_t = std::queue; 13 | 14 | queue_t _queue; 15 | 16 | public: 17 | void pop(T* item) // leaves *item untouched when the queue is empty 18 | { 19 | if (!_queue.empty()) 20 | { 21 | *item = std::move(_queue.front()); 22 | _queue.pop(); 23 | } 24 | } 25 | 26 | void push(T& item) { push(std::move(item)); } // moves out of the caller's object 27 | 28 | void push(T&& item) { _queue.push(std::forward(item)); } 29 | 30 | // number of queued items (exact: this queue is not concurrent) 31 | size_t size() const { return _queue.size(); } 32 | }; 33 | } // namespace reinforcement_learning 34 | -------------------------------------------------------------------------------- /rlclientlib/schema/v1/DecisionRankingEvent.fbs: -------------------------------------------------------------------------------- 1 | include "Metadata.fbs"; 2 | 3 | namespace reinforcement_learning.messages.flatbuff; 4 | 5 | table SlotEvent { 6 | decision_slot_id:string; 7 | action_ids:[uint32]; // ranked action ids 8 | probabilities:[float]; // probabilities 9 | } 10 | 11 | table DecisionEvent { 12 | context:[ubyte]; // context json 13 | slots:[SlotEvent]; // the collection of individual slots 14 | model_id:string; // model id 15 | pass_probability:float; // Probability of event surviving throttling operation 16 | deferred_action:bool = false; // delayed activation flag 17 | meta:Metadata; // contains metadata like timestamp 18 | } 19 | 20 | // Collection of ranking events 21 | table DecisionEventBatch { 22 | events:[DecisionEvent]; 23 | } 24 | 25 | root_type DecisionEventBatch; 26 | -------------------------------------------------------------------------------- /rlclientlib/schema/v1/Metadata.fbs: -------------------------------------------------------------------------------- 1 | namespace reinforcement_learning.messages.flatbuff; 2 | 3 | struct TimeStamp { 4 |
year:uint16; 5 | month:uint8; 6 | day:uint8; 7 | hour:uint8; 8 | minute:uint8; 9 | second:uint8; 10 | subsecond:uint32; // NOTE(review): unit is not stated in this schema - confirm 11 | } 12 | 13 | table Metadata { 14 | client_time_utc:TimeStamp; 15 | app_id:string; 16 | } -------------------------------------------------------------------------------- /rlclientlib/schema/v1/OutcomeEvent.fbs: -------------------------------------------------------------------------------- 1 | // EventHubInteraction Schema used by FlatBuffer 2 | include "Metadata.fbs"; 3 | 4 | namespace reinforcement_learning.messages.flatbuff; 5 | 6 | table StringEvent { 7 | value:string; 8 | } 9 | 10 | table NumericEvent { 11 | value:float; 12 | } 13 | 14 | table ActionTakenEvent { 15 | value:bool = false; 16 | } 17 | 18 | union OutcomeEvent { StringEvent, NumericEvent, ActionTakenEvent } // the outcome kind held in OutcomeEventHolder.the_event 19 | 20 | table OutcomeEventHolder { 21 | event_id:string; 22 | pass_probability:float; // Probability of event surviving throttling operation 23 | the_event:OutcomeEvent; 24 | meta:Metadata; 25 | } 26 | 27 | table OutcomeEventBatch { 28 | events:[OutcomeEventHolder]; 29 | } 30 | 31 | root_type OutcomeEventBatch; 32 | -------------------------------------------------------------------------------- /rlclientlib/schema/v1/RankingEvent.fbs: -------------------------------------------------------------------------------- 1 | // EventHubInteraction Schema used by FlatBuffer 2 | include "Metadata.fbs"; 3 | 4 | namespace reinforcement_learning.messages.flatbuff; 5 | 6 | enum LearningModeType : ubyte { Online, Apprentice, LoggingOnly } 7 | 8 | table RankingEvent { 9 | event_id:string; // event ID 10 | deferred_action:bool = false; // delayed activation flag 11 | action_ids:[uint64]; // action IDs 12 | context:[ubyte]; // context 13 | probabilities:[float]; // probabilities 14 | model_id:string; // model ID 15 | pass_probability:float; // Probability of event surviving throttling operation 16 | meta:Metadata; 17 | learning_mode:LearningModeType; // decision mode used to determine rank behavior 18 | } 19 |
20 | // Collection of Ranking events 21 | table RankingEventBatch { 22 | events:[RankingEvent]; 23 | } 24 | 25 | root_type RankingEventBatch; 26 | -------------------------------------------------------------------------------- /rlclientlib/schema/v1/SlatesEvent.fbs: -------------------------------------------------------------------------------- 1 | // EventHubInteraction Schema used by FlatBuffer 2 | include "Metadata.fbs"; 3 | 4 | namespace reinforcement_learning.messages.flatbuff; 5 | 6 | table SlatesSlotEvent { 7 | action_ids:[uint32]; // ranked action ids 8 | probabilities:[float]; // probabilities 9 | } 10 | 11 | table SlatesEvent { 12 | event_id:string; // event ID 13 | context:[ubyte]; // context 14 | slots:[SlatesSlotEvent]; // actions and probabilities 15 | model_id:string; // model ID 16 | pass_probability:float; // Probability of event surviving throttling operation 17 | deferred_action:bool = false; // delayed activation flag 18 | meta:Metadata; // contains metadata like timestamp 19 | } 20 | 21 | // Collection of slate events 22 | table SlatesEventBatch { 23 | events:[SlatesEvent]; 24 | } 25 | 26 | root_type SlatesEventBatch; -------------------------------------------------------------------------------- /rlclientlib/schema/v2/CaEvent.fbs: -------------------------------------------------------------------------------- 1 | // EventHubInteraction Schema used by FlatBuffer 2 | include "LearningModeType.fbs"; 3 | 4 | namespace reinforcement_learning.messages.flatbuff.v2; 5 | 6 | table CaEvent { 7 | deferred_action:bool = false; // delayed activation flag 8 | action:float; // continuous action 9 | context:[ubyte]; // context 10 | pdf_value:float; // pdf_value at chosen location 11 | model_id:string; // model ID 12 | learning_mode:LearningModeType; // decision mode used to determine rank behavior 13 | } 14 | 15 | root_type CaEvent; 16 | -------------------------------------------------------------------------------- /rlclientlib/schema/v2/CbEvent.fbs:
-------------------------------------------------------------------------------- 1 | // EventHubInteraction Schema used by FlatBuffer 2 | include "LearningModeType.fbs"; 3 | 4 | namespace reinforcement_learning.messages.flatbuff.v2; 5 | 6 | table CbEvent { 7 | deferred_action:bool = false; // delayed activation flag 8 | action_ids:[uint64]; // action IDs 9 | context:[ubyte]; // context 10 | probabilities:[float]; // probabilities 11 | model_id:string; // model ID 12 | learning_mode:LearningModeType; // decision mode used to determine rank behavior 13 | } 14 | 15 | root_type CbEvent; 16 | -------------------------------------------------------------------------------- /rlclientlib/schema/v2/DedupInfo.fbs: -------------------------------------------------------------------------------- 1 | namespace reinforcement_learning.messages.flatbuff.v2; 2 | 3 | table DedupInfo { 4 | ids: [ulong]; // NOTE(review): presumably paired index-wise with values - confirm 5 | values: [string]; 6 | } 7 | 8 | root_type DedupInfo; 9 | -------------------------------------------------------------------------------- /rlclientlib/schema/v2/Event.fbs: -------------------------------------------------------------------------------- 1 | // EventHubInteraction Schema used by FlatBuffer 2 | include "Metadata.fbs"; 3 | 4 | namespace reinforcement_learning.messages.flatbuff.v2; 5 | 6 | table Event { 7 | meta:Metadata; 8 | payload:[ubyte]; // payload 9 | } 10 | 11 | table BatchMetadata { 12 | content_encoding: string; // valid values: IDENTITY and DEDUP 13 | original_event_count: uint64; 14 | } 15 | 16 | table SerializedEvent { 17 | payload:[ubyte]; // serialized Event objects 18 | } 19 | 20 | table EventBatch { 21 | events:[SerializedEvent]; 22 | metadata: BatchMetadata; 23 | } 24 | 25 | root_type EventBatch; 26 | -------------------------------------------------------------------------------- /rlclientlib/schema/v2/FileFormat.fbs: -------------------------------------------------------------------------------- 1 | include "Event.fbs"; 2 | include "LearningModeType.fbs"; 3 | include
"ProblemType.fbs"; 4 | 5 | namespace reinforcement_learning.messages.flatbuff.v2; 6 | enum RewardFunctionType : ubyte { Earliest, Average, Median, Sum, Min, Max } 7 | 8 | table JoinedEvent { 9 | event: [ubyte]; 10 | timestamp: TimeStamp; 11 | } 12 | 13 | table JoinedPayload { 14 | events: [JoinedEvent]; 15 | } 16 | 17 | table KeyValue { 18 | key: string; 19 | value: string; 20 | } 21 | 22 | table FileHeader { 23 | join_time: TimeStamp; 24 | properties: [KeyValue]; 25 | } 26 | 27 | table CheckpointInfo { 28 | reward_function_type: RewardFunctionType; 29 | default_reward: float; 30 | learning_mode_config: LearningModeType; 31 | problem_type_config: ProblemType; 32 | use_client_time: bool; 33 | } 34 | 35 | root_type FileHeader; // NOTE(review): flatc honours only the last root_type in a schema (JoinedPayload here) - confirm how FileHeader/CheckpointInfo are read 36 | root_type CheckpointInfo; 37 | root_type JoinedPayload; 38 | -------------------------------------------------------------------------------- /rlclientlib/schema/v2/LearningModeType.fbs: -------------------------------------------------------------------------------- 1 | // LearningMode Schema used by FlatBuffer 2 | namespace reinforcement_learning.messages.flatbuff.v2; 3 | 4 | enum LearningModeType : ubyte { Online, Apprentice, LoggingOnly } 5 | -------------------------------------------------------------------------------- /rlclientlib/schema/v2/Metadata.fbs: -------------------------------------------------------------------------------- 1 | namespace reinforcement_learning.messages.flatbuff.v2; 2 | 3 | enum PayloadType : ubyte { CB, CCB, Slates, Outcome, CA, DedupInfo, MultiStep, Episode } 4 | enum EventEncoding: ubyte { Identity, Zstd } 5 | 6 | struct TimeStamp { 7 | year:uint16; 8 | month:uint8; 9 | day:uint8; 10 | hour:uint8; 11 | minute:uint8; 12 | second:uint8; 13 | subsecond:uint32; 14 | } 15 | 16 | table Metadata { 17 | id:string; 18 | client_time_utc:TimeStamp; 19 | app_id:string; 20 | payload_type:PayloadType; 21 | pass_probability:float; // Probability of event surviving throttling operation 22 | encoding: EventEncoding; 23
| } 24 | -------------------------------------------------------------------------------- /rlclientlib/schema/v2/MultiSlotEvent.fbs: -------------------------------------------------------------------------------- 1 | // EventHubInteraction Schema used by FlatBuffer 2 | include "LearningModeType.fbs"; 3 | 4 | namespace reinforcement_learning.messages.flatbuff.v2; 5 | 6 | table SlotEvent { 7 | action_ids:[uint32]; // ranked action ids 8 | probabilities:[float]; // probabilities 9 | id:string; // id for slot 10 | } 11 | 12 | // this event covers both ccb and slates events from v1 schema 13 | table MultiSlotEvent { 14 | context:[ubyte]; // context 15 | slots:[SlotEvent]; // actions and probabilities 16 | model_id:string; // model ID 17 | deferred_action:bool = false; // delayed activation flag 18 | baseline_actions:[int]; // baseline actions for apprentice mode 19 | learning_mode:LearningModeType; // decision mode used to determine rank behavior 20 | } 21 | 22 | root_type MultiSlotEvent; 23 | -------------------------------------------------------------------------------- /rlclientlib/schema/v2/MultiStepEvent.fbs: -------------------------------------------------------------------------------- 1 | // EventHubInteraction Schema used by FlatBuffer 2 | namespace reinforcement_learning.messages.flatbuff.v2; 3 | 4 | table MultiStepEvent { 5 | event_id:string; 6 | previous_id:string; // NOTE(review): presumably the event_id of the preceding step - confirm 7 | 8 | action_ids:[uint64]; // action IDs 9 | context:[ubyte]; // context 10 | probabilities:[float]; // probabilities 11 | model_id:string; // model ID 12 | 13 | deferred_action:bool = false; // delayed activation flag 14 | } 15 | 16 | table EpisodeEvent { 17 | episode_id:string; 18 | } 19 | 20 | root_type MultiStepEvent; // NOTE(review): flatc honours only the last root_type in a schema (EpisodeEvent here) - confirm 21 | root_type EpisodeEvent; 22 | -------------------------------------------------------------------------------- /rlclientlib/schema/v2/OutcomeEvent.fbs: -------------------------------------------------------------------------------- 1 | // EventHubInteraction Schema used by FlatBuffer 2 |
namespace reinforcement_learning.messages.flatbuff.v2; 3 | 4 | // must be a table: FlatBuffers unions cannot hold a bare float 5 | table NumericOutcome { 6 | value: float; 7 | } 8 | 9 | union OutcomeValue { 10 | numeric: NumericOutcome, 11 | literal: string 12 | } 13 | 14 | table NumericIndex { 15 | index: int; 16 | } 17 | 18 | union IndexValue { 19 | numeric: NumericIndex, 20 | literal: string 21 | } 22 | 23 | // both value and index are optional 24 | // OutcomeEvent can be used to indicate activation 25 | table OutcomeEvent { 26 | value: OutcomeValue; 27 | index: IndexValue; 28 | action_taken: bool = false; 29 | } 30 | 31 | root_type OutcomeEvent; 32 | -------------------------------------------------------------------------------- /rlclientlib/schema/v2/ProblemType.fbs: -------------------------------------------------------------------------------- 1 | // Problem Type Schema used by FlatBuffer 2 | namespace reinforcement_learning.messages.flatbuff.v2; 3 | 4 | enum ProblemType : ubyte { UNKNOWN, CB, CCB, SLATES, CA, MULTISTEP } 5 | -------------------------------------------------------------------------------- /rlclientlib/serialization/payload_serializer.cc: -------------------------------------------------------------------------------- 1 | #include "payload_serializer.h" 2 | 3 | namespace reinforcement_learning 4 | { 5 | using namespace messages::flatbuff; 6 | namespace logger 7 | { 8 | int get_learning_mode(learning_mode mode_in, v2::LearningModeType& mode_out, api_status* status) // map the client learning_mode onto the v2 flatbuffer enum 9 | { 10 | switch (mode_in) 11 | { 12 | case APPRENTICE: 13 | mode_out = v2::LearningModeType_Apprentice; 14 | return error_code::success; 15 | case ONLINE: 16 | mode_out = v2::LearningModeType_Online; 17 | return error_code::success; 18 | case LOGGINGONLY: 19 | mode_out = v2::LearningModeType_LoggingOnly; 20 | return error_code::success; 21 | default: // mode_out is left unchanged; the error is reported via status 22 | return report_error(status, error_code::unsupported_learning_mode, error_code::unsupported_learning_mode_s); 23
| } 24 | } 25 | } // namespace logger 26 | } // namespace reinforcement_learning 27 | -------------------------------------------------------------------------------- /rlclientlib/utility/context_helper.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "api_status.h" 3 | #include "rl_string_view.h" 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | namespace reinforcement_learning 10 | { 11 | class i_trace; 12 | namespace utility 13 | { 14 | //! This struct collects all sorts of relevant data we need about a context json. 15 | struct ContextInfo 16 | { 17 | //! Each pair is the start offset and length of a JSON object, i.e. the range covers '{' to '}' 18 | typedef std::vector> index_vector_t; 19 | 20 | //! The index to each element in the _multi array 21 | index_vector_t actions; 22 | //! The index to each element in the _slots array 23 | index_vector_t slots; 24 | }; 25 | 26 | int get_event_ids(string_view context, std::map& event_ids, i_trace* trace, api_status* status); // NOTE(review): unlike the two functions below, trace/status have no defaults 27 | int get_context_info(string_view context, ContextInfo& info, i_trace* trace = nullptr, api_status* status = nullptr); 28 | int get_slot_ids(string_view context, const ContextInfo::index_vector_t& slots, std::map& slot_ids, 29 | i_trace* trace = nullptr, api_status* status = nullptr); 30 | } // namespace utility 31 | } // namespace reinforcement_learning 32 | -------------------------------------------------------------------------------- /rlclientlib/utility/data_buffer_streambuf.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | namespace reinforcement_learning 4 | { 5 | namespace utility 6 | { 7 | class data_buffer; 8 | /** 9 | * \brief A streambuf class that is backed by data_buffer. streambuf can be passed to 10 | * std::ostream and used to serialize using << and >> operators. 11 | * This is used while serializing json into data_buffer.
12 | * 13 | * See https://en.cppreference.com/w/cpp/io/basic_streambuf for additional details on streambuf 14 | */ 15 | class data_buffer_streambuf : public std::streambuf 16 | { 17 | public: 18 | explicit data_buffer_streambuf(data_buffer*); 19 | int_type overflow(int_type) override; 20 | int sync() override; // std::streambuf::sync returns int (0 on success, -1 on failure), not a character 21 | void finalize(); 22 | ~data_buffer_streambuf() override; 23 | 24 | private: 25 | data_buffer* _db; 26 | const size_t GROW_BY = 2048; 27 | bool _finalized = false; 28 | }; 29 | } // namespace utility 30 | } // namespace reinforcement_learning 31 | -------------------------------------------------------------------------------- /rlclientlib/utility/header_authorization.cc: -------------------------------------------------------------------------------- 1 | #include "header_authorization.h" 2 | 3 | namespace reinforcement_learning 4 | { 5 | int header_authorization::init(const utility::configuration& config, api_status* status, i_trace* trace) 6 | { 7 | const auto* api_key = config.get(name::HTTP_API_KEY, nullptr); 8 | if (api_key == nullptr) { RETURN_ERROR(trace, status, http_api_key_not_provided); } 9 | _api_key = api_key; 10 | #ifdef _WIN32 11 | _http_api_header_key_name = ::utility::conversions::utf8_to_utf16( 12 | config.get(name::HTTP_API_HEADER_KEY_NAME, value::HTTP_API_DEFAULT_HEADER_KEY_NAME)); 13 | #else 14 | _http_api_header_key_name = config.get(name::HTTP_API_HEADER_KEY_NAME, value::HTTP_API_DEFAULT_HEADER_KEY_NAME); 15 | #endif 16 | return error_code::success; 17 | } 18 | 19 | int header_authorization::insert_authorization_header(http_headers& headers, api_status* status, i_trace* trace) 20 | { 21 | if (_api_key.empty()) { RETURN_ERROR(trace, status, http_api_key_not_provided); } 22 | headers.add(_http_api_header_key_name, _api_key.c_str()); 23 | return error_code::success; 24 | } 25 | } // namespace reinforcement_learning 26 | --------------------------------------------------------------------------------
/rlclientlib/utility/header_authorization.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "api_status.h" 4 | #include "configuration.h" 5 | #include "constants.h" 6 | 7 | #include 8 | 9 | using namespace web::http; // NOTE(review): a using-directive in a header leaks web::http into every includer 10 | 11 | namespace reinforcement_learning 12 | { 13 | class header_authorization 14 | { 15 | public: 16 | header_authorization() = default; 17 | ~header_authorization() = default; 18 | 19 | int init(const utility::configuration& config, api_status* status, i_trace* trace); 20 | int insert_authorization_header(http_headers& headers, api_status* status, i_trace* trace); // fails with http_api_key_not_provided if init() stored no API key 21 | 22 | header_authorization(const header_authorization&) = delete; 23 | header_authorization(header_authorization&&) = delete; 24 | header_authorization& operator=(const header_authorization&) = delete; 25 | header_authorization& operator=(header_authorization&&) = delete; 26 | 27 | private: 28 | std::string _api_key; 29 | http_headers::key_type _http_api_header_key_name; 30 | }; 31 | } // namespace reinforcement_learning 32 | -------------------------------------------------------------------------------- /rlclientlib/utility/http_client.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "api_status.h" 3 | #include "configuration.h" 4 | 5 | #include 6 | 7 | namespace reinforcement_learning 8 | { 9 | class i_http_client 10 | { 11 | public: 12 | typedef web::http::http_request request_t; 13 | typedef pplx::task response_t; 14 | typedef web::http::method method_t; 15 | 16 | public: 17 | virtual ~i_http_client() = default; 18 | 19 | virtual response_t request(method_t) = 0; 20 | virtual response_t request(request_t) = 0; 21 | 22 | virtual const std::string& get_url() const = 0; 23 | }; 24 | 25 | int create_http_client( 26 | const char* url, const utility::configuration& cfg, i_http_client** client, api_status* status = nullptr); 27 | } // namespace
reinforcement_learning 28 | -------------------------------------------------------------------------------- /rlclientlib/utility/http_helper.cc: -------------------------------------------------------------------------------- 1 | #include "http_helper.h" 2 | 3 | #include "constants.h" 4 | 5 | #include 6 | 7 | namespace reinforcement_learning 8 | { 9 | namespace utility 10 | { 11 | web::http::client::http_client_config get_http_config(const utility::configuration& cfg) 12 | { 13 | web::http::client::http_client_config config; 14 | 15 | // The default is to validate certificates. 16 | config.set_validate_certificates(!cfg.get_bool(name::HTTP_CLIENT_DISABLE_CERT_VALIDATION, false)); 17 | auto timeout = cfg.get_int(name::HTTP_CLIENT_TIMEOUT, 30); 18 | // Valid values are 1-30 seconds; anything outside that range falls back to 30. 19 | if (timeout < 1 || timeout > 30) { timeout = 30; } 20 | config.set_timeout(std::chrono::seconds(timeout)); 21 | return config; 22 | } 23 | } // namespace utility 24 | } // namespace reinforcement_learning 25 | -------------------------------------------------------------------------------- /rlclientlib/utility/http_helper.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "configuration.h" 3 | 4 | #include 5 | 6 | namespace reinforcement_learning 7 | { 8 | namespace utility 9 | { 10 | web::http::client::http_client_config get_http_config(const utility::configuration& cfg); 11 | 12 | } // namespace utility 13 | } // namespace reinforcement_learning -------------------------------------------------------------------------------- /rlclientlib/utility/interruptable_sleeper.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | 6 | namespace reinforcement_learning 7 | { 8 | namespace utility 9 | { 10 | // one-shot interruptible sleeper: _interrupt is never reset, so after interrupt() every sleep() returns immediately 11 | class interruptable_sleeper 12 | { 13 | std::condition_variable _cv; 14 | std::mutex _mutex; 15 | bool _interrupt = false; 16 |
17 | public: 18 | // waits until interrupt() is called or the specified time passes 19 | // returns true if timeout expired, false if sleep was interrupted 20 | template 21 | bool sleep(const std::chrono::duration& timeout_duration); 22 | // unblock every sleeping thread; the interrupt is permanent 23 | void interrupt(); 24 | }; 25 | 26 | inline void interruptable_sleeper::interrupt() 27 | { 28 | { 29 | std::unique_lock lock(_mutex); 30 | _interrupt = true; 31 | } 32 | _cv.notify_all(); // _interrupt is sticky, so wake all waiters rather than just one 33 | } 34 | 35 | /* 36 | * Sleep returns true if timeout expires and returns false if sleep was interrupted. 37 | */ 38 | template 39 | bool interruptable_sleeper::sleep(const std::chrono::duration& timeout_duration) 40 | { 41 | std::unique_lock lock(_mutex); 42 | return !(_cv.wait_for(lock, timeout_duration, [this]() { return _interrupt; })); 43 | } 44 | } // namespace utility 45 | } // namespace reinforcement_learning -------------------------------------------------------------------------------- /rlclientlib/utility/stl_container_adapter.cc: -------------------------------------------------------------------------------- 1 | #include "stl_container_adapter.h" 2 | 3 | #include 4 | namespace reinforcement_learning 5 | { 6 | namespace utility 7 | { 8 | stl_container_adapter::stl_container_adapter(data_buffer* db) : _db(db) {} 9 | 10 | size_t stl_container_adapter::size() const { return _db->buffer_filled_size(); } 11 | 12 | const stl_container_adapter::value_type& stl_container_adapter::operator[](size_t idx) const 13 | { 14 | assert(idx < size()); 15 | return *(_db->preamble_begin() + idx); 16 | } 17 | 18 | void stl_container_adapter::resize(size_t /*unused*/) 19 | { 20 | assert(false); // Resize is not supported (a silent no-op when NDEBUG compiles the assert out).
21 | } 22 | 23 | #ifdef _WIN32 24 | stdext::checked_array_iterator stl_container_adapter::begin() const 25 | { 26 | return {_db->preamble_begin(), _db->buffer_filled_size()}; 27 | } 28 | #else 29 | stl_container_adapter::value_type* stl_container_adapter::begin() const { return _db->preamble_begin(); } 30 | #endif 31 | } // namespace utility 32 | } // namespace reinforcement_learning 33 | -------------------------------------------------------------------------------- /rlclientlib/utility/stl_container_adapter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "data_buffer.h" 3 | 4 | #include 5 | 6 | namespace reinforcement_learning 7 | { 8 | namespace utility 9 | { 10 | class stl_container_adapter 11 | { 12 | public: 13 | using value_type = data_buffer::value_type; 14 | 15 | explicit stl_container_adapter(data_buffer* db); 16 | stl_container_adapter(const stl_container_adapter& rhs) = default; 17 | stl_container_adapter(stl_container_adapter&& rhs) noexcept = default; 18 | size_t size() const; 19 | const value_type& operator[](size_t idx) const; 20 | static void resize(size_t); 21 | 22 | #ifdef _WIN32 23 | stdext::checked_array_iterator begin() const; 24 | #else 25 | value_type* begin() const; 26 | #endif 27 | 28 | protected: 29 | data_buffer* _db; 30 | }; 31 | } // namespace utility 32 | } // namespace reinforcement_learning 33 | -------------------------------------------------------------------------------- /rlclientlib/utility/str_util.cc: -------------------------------------------------------------------------------- 1 | #include "str_util.h" 2 | 3 | #include 4 | #include 5 | 6 | using namespace std; 7 | 8 | // Adapted from source: https://stackoverflow.com/a/217605/7964431 9 | 10 | namespace reinforcement_learning 11 | { 12 | namespace utility 13 | { 14 | string& str_util::to_lower(string& sval) 15 | { 16 | transform(sval.begin(), sval.end(), sval.begin(), [](unsigned char ch) -> char { return std::tolower(ch); }); // go through unsigned char: tolower on a negative char value is undefined behavior 17 | return sval; 18 | } 19
| 20 | string& str_util::ltrim(std::string& sval) 21 | { 22 | sval.erase( 23 | sval.begin(), std::find_if(sval.begin(), sval.end(), [](unsigned char ch) { return std::isspace(ch) == 0; })); 24 | return sval; 25 | } 26 | 27 | string& str_util::rtrim(std::string& sval) 28 | { 29 | sval.erase(std::find_if(sval.rbegin(), sval.rend(), [](unsigned char ch) { return std::isspace(ch) == 0; }).base(), 30 | sval.end()); 31 | return sval; 32 | } 33 | 34 | std::string& str_util::trim(std::string& sval) { return ltrim(rtrim(sval)); } 35 | 36 | } // namespace utility 37 | } // namespace reinforcement_learning 38 | -------------------------------------------------------------------------------- /templates/README.md: -------------------------------------------------------------------------------- 1 | # Create Azure Personalizer Loop 2 | 3 | ## Introduction 4 | - Overview: https://azure.microsoft.com/en-us/products/ai-services/ai-personalizer 5 | - What is Personalizer? https://learn.microsoft.com/en-us/azure/ai-services/personalizer/what-is-personalizer 6 | 7 | ## Prerequisites and steps to create a Personalizer Loop: 8 | 1. Create an Azure account: https://azure.microsoft.com/en-us/free/ 9 | 2. Contact us to add your subscription ID to the allowlist 10 | 3. Click to deploy the loop: [![Create Personalizer Loop](https://aka.ms/deploytoazurebutton)](https://portal.azure.com/#create/Microsoft.Template/uri/https%3A%2F%2Fraw.githubusercontent.com%2FVowpalWabbit%2Freinforcement_learning%2Fmaster%2Ftemplates%2Fcreate-loop.json) 11 | 12 | ## Learn More 13 | - Personalizer documentation: https://learn.microsoft.com/en-us/azure/ai-services/personalizer/ 14 | - Need help?
https://learn.microsoft.com/en-us/answers/tags/219/azure-personalizer -------------------------------------------------------------------------------- /templates/create-loop.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#", 3 | "contentVersion": "1.0.0.0", 4 | "parameters": { 5 | "personalizer_name": { 6 | "type": "String" 7 | }, 8 | "personalizer_region": { 9 | "type": "String" 10 | } 11 | }, 12 | "variables": {}, 13 | "resources": [ 14 | { 15 | "type": "Microsoft.CognitiveServices/accounts", 16 | "apiVersion": "2023-05-01", 17 | "name": "[parameters('personalizer_name')]", 18 | "location": "[parameters('personalizer_region')]", 19 | "sku": { 20 | "name": "S0" 21 | }, 22 | "kind": "Personalizer", 23 | "identity": { 24 | "type": "None" 25 | }, 26 | "properties": { 27 | "customSubDomainName": "[parameters('personalizer_name')]", 28 | "networkAcls": { 29 | "defaultAction": "Allow", 30 | "virtualNetworkRules": [], 31 | "ipRules": [] 32 | }, 33 | "publicNetworkAccess": "Enabled" 34 | } 35 | } 36 | ] 37 | } -------------------------------------------------------------------------------- /test_tools/e2e_testing/base_files/input/multistep/episode.fbs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/test_tools/e2e_testing/base_files/input/multistep/episode.fbs -------------------------------------------------------------------------------- /test_tools/e2e_testing/base_files/input/multistep/interaction.fbs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/test_tools/e2e_testing/base_files/input/multistep/interaction.fbs 
-------------------------------------------------------------------------------- /test_tools/e2e_testing/base_files/input/multistep/observation.fbs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/test_tools/e2e_testing/base_files/input/multistep/observation.fbs -------------------------------------------------------------------------------- /test_tools/e2e_testing/multistep_client.json: -------------------------------------------------------------------------------- 1 | { 2 | "ApplicationID": "multistep_e2e", 3 | "interaction.sender.implementation": "INTERACTION_FILE_SENDER", 4 | "observation.sender.implementation": "OBSERVATION_FILE_SENDER", 5 | "episode.sender.implementation": "EPISODE_FILE_SENDER", 6 | "interaction.file.name": "interaction.fbs", 7 | "observation.file.name": "observation.fbs", 8 | "episode.file.name": "episode.fbs", 9 | "IsExplorationEnabled": true, 10 | "InitialExplorationEpsilon": 1.0, 11 | "LearningMode": "Online", 12 | "model.source": "FILE_MODEL_DATA", 13 | "protocol.version":"2", 14 | "model.vw.initial_command_line": "--cb_explore_adf --epsilon 0.2 --power_t 0 -l 0.001 --cb_type mtr -q ::" 15 | } 16 | -------------------------------------------------------------------------------- /test_tools/e2e_testing/requirements.txt: -------------------------------------------------------------------------------- 1 | flatbuffers 2 | numpy 3 | zstd -------------------------------------------------------------------------------- /test_tools/example_gen/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_executable(example_gen 2 | example_gen.cc 3 | ) 4 | target_link_libraries(example_gen PRIVATE Boost::program_options rlclientlib) 5 | 6 | if(RL_LINK_AZURE_LIBS) 7 | target_compile_definitions(example_gen PRIVATE LINK_AZURE_LIBS) 8 | 
find_package(azure-identity-cpp CONFIG REQUIRED) 9 | target_link_libraries(example_gen PRIVATE Azure::azure-identity) 10 | endif() 11 | -------------------------------------------------------------------------------- /test_tools/joiner/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_executable(joiner.out 2 | main.cc 3 | text_converter.cc 4 | ) 5 | 6 | target_include_directories(joiner.out PRIVATE ${FLATBUFFERS_INCLUDE_DIR}) 7 | target_link_libraries(joiner.out PRIVATE Boost::program_options rlclientlib) 8 | -------------------------------------------------------------------------------- /test_tools/joiner/sample_data/interaction.fb.data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/test_tools/joiner/sample_data/interaction.fb.data -------------------------------------------------------------------------------- /test_tools/joiner/sample_data/observation.fb.data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/test_tools/joiner/sample_data/observation.fb.data -------------------------------------------------------------------------------- /test_tools/joiner/text_converter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace reinforcement_learning 6 | { 7 | namespace joiner 8 | { 9 | void convert_to_text(const std::vector& files); 10 | } 11 | } // namespace reinforcement_learning -------------------------------------------------------------------------------- /test_tools/onnx_pytorch/common/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/test_tools/onnx_pytorch/common/__init__.py -------------------------------------------------------------------------------- /test_tools/onnx_pytorch/common/types.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class Problem(Enum): 5 | CB = (1,) 6 | MultiClass = 2 7 | -------------------------------------------------------------------------------- /test_tools/sender_test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_executable(sender_test 2 | main.cc 3 | test_loop.cc 4 | ) 5 | 6 | # Sender test uses internal headers from the rlclientlib target 7 | target_include_directories(sender_test PRIVATE $) 8 | 9 | target_link_libraries(sender_test PRIVATE Boost::program_options rlclientlib) 10 | 11 | if(RL_LINK_AZURE_LIBS) 12 | target_compile_definitions(sender_test PRIVATE LINK_AZURE_LIBS) 13 | find_package(azure-identity-cpp CONFIG REQUIRED) 14 | target_link_libraries(sender_test PRIVATE Azure::azure-identity) 15 | endif() 16 | -------------------------------------------------------------------------------- /test_tools/sender_test/main.cc: -------------------------------------------------------------------------------- 1 | #include "test_loop.h" 2 | 3 | #include 4 | 5 | namespace po = boost::program_options; 6 | 7 | bool is_help(const po::variables_map& vm) { return vm.count("help") > 0; } 8 | 9 | po::variables_map process_cmd_line(const int argc, char** argv) 10 | { 11 | po::options_description desc("Options"); 12 | desc.add_options()("help", "produce help message")("json_config,j", 13 | po::value()->default_value("client.json"), 14 | "JSON file with config information for hosted RL loop")("message_size,s", po::value()->default_value(100), 15 | "Message size in Kb")("message_count,n", po::value()->default_value(1000000), "Amount of 
messages")( 16 | "threads,t", po::value()->default_value(1)); 17 | 18 | po::variables_map vm; 19 | store(parse_command_line(argc, argv, desc), vm); 20 | 21 | if (is_help(vm)) { std::cout << desc << std::endl; } 22 | 23 | return vm; 24 | } 25 | 26 | int main(int argc, char** argv) 27 | { 28 | try 29 | { 30 | const auto vm = process_cmd_line(argc, argv); 31 | if (is_help(vm)) { return 0; } 32 | 33 | test_loop loop(vm); 34 | if (!loop.init()) 35 | { 36 | std::cerr << "Test loop haven't initialized properly." << std::endl; 37 | return -1; 38 | } 39 | for (int i = 0; i < 5; ++i) { loop.run(); } 40 | } 41 | catch (const std::exception& e) 42 | { 43 | std::cout << "Error: " << e.what() << std::endl; 44 | return -1; 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /test_tools/sender_test/test_loop.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "configuration.h" 3 | #include "sender.h" 4 | 5 | #include 6 | 7 | class test_loop 8 | { 9 | public: 10 | test_loop(const boost::program_options::variables_map& vm); 11 | bool init(); 12 | void run(); 13 | 14 | private: 15 | static int load_file(const std::string& file_name, std::string& config_str); 16 | static int load_config_from_json(const std::string& file_name, reinforcement_learning::utility::configuration& config, 17 | reinforcement_learning::api_status* status); 18 | std::string get_message(size_t i) const; 19 | void init_messages(); 20 | 21 | private: 22 | const size_t _message_size; 23 | const size_t _message_count; 24 | const size_t _threads; 25 | const std::string _json_config; 26 | std::unique_ptr _sender; 27 | std::string _message; 28 | }; 29 | -------------------------------------------------------------------------------- /unit_test/common_test_utils.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "rl_string_view.h" 4 | 5 | #include 6 | 
7 | bool is_invoked_with(const std::string& arg) 8 | { 9 | for (size_t i = 0; i < boost::unit_test::framework::master_test_suite().argc; i++) 10 | { 11 | if (reinforcement_learning::string_view(boost::unit_test::framework::master_test_suite().argv[i]).find(arg) != 12 | std::string::npos) 13 | { 14 | return true; 15 | } 16 | } 17 | return false; 18 | } -------------------------------------------------------------------------------- /unit_test/extensions/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if (rlclientlib_BUILD_ONNXRUNTIME_EXTENSION) 2 | add_subdirectory(onnx) 3 | endif() -------------------------------------------------------------------------------- /unit_test/extensions/onnx/global_fixture.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace Ort 6 | { 7 | struct MemoryInfo; 8 | } 9 | 10 | struct GlobalConfig 11 | { 12 | GlobalConfig(); 13 | ~GlobalConfig() = default; 14 | 15 | static GlobalConfig*& instance(); 16 | const Ort::MemoryInfo& get_memory_info(); 17 | 18 | private: 19 | std::unique_ptr TestMemoryInfo; 20 | }; 21 | -------------------------------------------------------------------------------- /unit_test/extensions/onnx/main.cc: -------------------------------------------------------------------------------- 1 | #define BOOST_TEST_MODULE Main 2 | #include 3 | 4 | #include "global_fixture.h" 5 | #include "onnx_extension.h" 6 | 7 | #include 8 | 9 | #include 10 | 11 | GlobalConfig::GlobalConfig() 12 | { 13 | instance() = this; 14 | reinforcement_learning::onnx::register_onnx_factory(); 15 | } 16 | 17 | GlobalConfig*& GlobalConfig::instance() 18 | { 19 | static GlobalConfig* s_inst = nullptr; 20 | return s_inst; 21 | } 22 | 23 | const Ort::MemoryInfo& GlobalConfig::get_memory_info() 24 | { 25 | if (TestMemoryInfo == nullptr) 26 | { 27 | TestMemoryInfo = std::unique_ptr( 28 | new 
Ort::MemoryInfo(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault).release())); 29 | } 30 | return *TestMemoryInfo; 31 | } 32 | 33 | BOOST_GLOBAL_FIXTURE(GlobalConfig); 34 | -------------------------------------------------------------------------------- /unit_test/extensions/onnx/mnist_data/mnist_model.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/unit_test/extensions/onnx/mnist_data/mnist_model.onnx -------------------------------------------------------------------------------- /unit_test/extensions/onnx/mock_helpers.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "factory_resolver.h" 4 | #include "sender.h" 5 | 6 | #ifdef __GNUG__ 7 | 8 | // Fakeit does not work with GCC's devirtualization 9 | // which is enabled with -O2 (the default) or higher. 10 | # pragma GCC optimize("no-devirtualize") 11 | 12 | #endif 13 | 14 | #include 15 | #include 16 | 17 | std::unique_ptr> get_mock_sender(int send_return_code); 18 | std::unique_ptr get_mock_sender_factory( 19 | fakeit::Mock* mock_observation_sender, 20 | fakeit::Mock* mock_interaction_sender); -------------------------------------------------------------------------------- /unit_test/extensions/onnx/test_data.h.in: -------------------------------------------------------------------------------- 1 | #define TEST_DATA(item_path) \ 2 | "@rltestonnx_DATA_ROOT@" "@rltestonnx_SYSTEM_PATH_SEPARATOR@" #item_path 3 | -------------------------------------------------------------------------------- /unit_test/file_logger_test.cc: -------------------------------------------------------------------------------- 1 | #ifdef STAND_ALONE 2 | # define BOOST_TEST_MODULE Main 3 | #endif 4 | 5 | #include "logger/file/file_logger.h" 6 | #include 7 | 8 | #include "err_constants.h" 9 | 10 | #include 11 | 12 | 
namespace rl = reinforcement_learning; 13 | namespace rlog = reinforcement_learning::logger; 14 | namespace rerr = reinforcement_learning::error_code; 15 | namespace rutil = reinforcement_learning::utility; 16 | 17 | bool file_exists(const std::string& file) 18 | { 19 | std::ifstream f(file); 20 | return f.good(); 21 | } 22 | 23 | BOOST_AUTO_TEST_CASE(file_logger_test) 24 | { 25 | const std::string file("file_logger_test"); 26 | { 27 | if (file_exists(file)) remove(file.c_str()); 28 | 29 | BOOST_CHECK(!file_exists(file)); 30 | 31 | rlog::file::file_logger logger(file, nullptr); 32 | rutil::configuration config; 33 | BOOST_CHECK_EQUAL(logger.init(config, nullptr), rerr::success); 34 | const auto buff = rl::i_sender::buffer(new rutil::data_buffer()); 35 | logger.send(buff); 36 | } 37 | 38 | BOOST_CHECK(file_exists(file)); 39 | remove(file.c_str()); 40 | } -------------------------------------------------------------------------------- /unit_test/interaction.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/unit_test/interaction.txt -------------------------------------------------------------------------------- /unit_test/json_serializer_test.cc: -------------------------------------------------------------------------------- 1 | #ifdef STAND_ALONE 2 | # define BOOST_TEST_MODULE Main 3 | #endif 4 | 5 | #include 6 | 7 | BOOST_AUTO_TEST_CASE(json_serializer_ranking_event_single) {} 8 | 9 | BOOST_AUTO_TEST_CASE(json_serializer_outcome_event_single_string) {} 10 | BOOST_AUTO_TEST_CASE(json_serializer_outcome_event_single_numeric) {} 11 | BOOST_AUTO_TEST_CASE(json_serializer_outcome_event_single_action_taken) {} 12 | 13 | BOOST_AUTO_TEST_CASE(json_serializer_ranking_event_collection) {} 14 | BOOST_AUTO_TEST_CASE(json_serializer_outcome_event_collection) {} 15 | 16 | 
BOOST_AUTO_TEST_CASE(json_serializer_outcome_event_collection_mixed_types) {} -------------------------------------------------------------------------------- /unit_test/main.cc: -------------------------------------------------------------------------------- 1 | #define BOOST_TEST_MODULE Main 2 | #include 3 | 4 | // wait for CMake build 5 | 6 | // #include 7 | 8 | // struct GlobalConfig { 9 | // ~GlobalConfig() 10 | // { 11 | // // fix memory leak: https://rt.openssl.org/Ticket/Display.html?id=2561&user=guest&pass=guest 12 | // SSL_COMP_free_compression_methods(); 13 | // } 14 | // }; 15 | 16 | // BOOST_GLOBAL_FIXTURE(GlobalConfig); 17 | -------------------------------------------------------------------------------- /unit_test/mock_http_client.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "utility/http_client.h" 3 | 4 | #include 5 | 6 | using namespace web; 7 | using namespace http; 8 | 9 | class mock_http_client : public reinforcement_learning::i_http_client 10 | { 11 | public: 12 | using response_fn = void(const http_request&, http_response&); 13 | 14 | mock_http_client(const std::string& url); 15 | 16 | virtual const std::string& get_url() const override; 17 | 18 | virtual response_t request(method_t method) override; 19 | virtual response_t request(request_t request) override; 20 | 21 | void set_responder(const http::method&, const std::function& custom_responder); 22 | 23 | private: 24 | static void handle_get(const http_request& message, http_response& resp); 25 | static void handle_put(const http_request& message, http_response& resp); 26 | static void handle_post(const http_request& message, http_response& resp); 27 | static void handle_delete(const http_request& message, http_response& resp); 28 | static void handle_head(const http_request& message, http_response& resp); 29 | 30 | private: 31 | const std::string _url; 32 | std::map> _responders; 33 | }; 34 | 
-------------------------------------------------------------------------------- /unit_test/observation.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VowpalWabbit/reinforcement_learning/84d54903df53dde642606118c83009961cfdeb6e/unit_test/observation.txt -------------------------------------------------------------------------------- /unit_test/outcome.json: -------------------------------------------------------------------------------- 1 | { 2 | "EventId": "b94a280e32024acb9a4fa12b058157d3", 3 | "v": "1.0" 4 | } 5 | -------------------------------------------------------------------------------- /unit_test/ranking_context.json: -------------------------------------------------------------------------------- 1 | { 2 | "Version": "1", 3 | "EventId": "b94a280e32024acb9a4fa12b058157d3", 4 | "a": [ 5 | 2, 6 | 1 7 | ], 8 | "c": { 9 | "User": { 10 | "_age": 22 11 | }, 12 | "Geo": { 13 | "country": "United States", 14 | "state": "California", 15 | "city": "Anaheim" 16 | }, 17 | "_multi": [ 18 | { 19 | "_tag": "cmplx$http://www.complex.com/style/2017/06/kid-puts-together-hypebeast-pop-up-book-for-art-class" 20 | }, 21 | { 22 | "_tag": "cmplx$http://www.complex.com/sports/2017/06/floyd-mayweather-will-beat-conor-mcgregor" 23 | } 24 | ] 25 | }, 26 | "p": [ 27 | 0.814285755, 28 | 0.0142857144 29 | ], 30 | "VWState": { 31 | "m": "680ec362b798463eaf64489efaa0d7b1" 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /unit_test/sleeper_test.cc: -------------------------------------------------------------------------------- 1 | #ifdef STAND_ALONE 2 | # define BOOST_TEST_MODULE Main 3 | #endif 4 | #include 5 | 6 | #include "utility/interruptable_sleeper.h" 7 | 8 | #include 9 | 10 | namespace u = reinforcement_learning::utility; 11 | 12 | BOOST_AUTO_TEST_CASE(sleeper_interrupt) 13 | { 14 | u::interruptable_sleeper sleeper; 15 | std::thread t( 16 | [&]() 17 | { 18 
| // test interruption 19 | const auto start = std::chrono::system_clock::now(); 20 | sleeper.sleep(std::chrono::milliseconds(5000)); 21 | const auto stop = std::chrono::system_clock::now(); 22 | const auto diff = std::chrono::duration_cast(stop - start); 23 | BOOST_CHECK(diff <= std::chrono::milliseconds(100)); 24 | }); 25 | 26 | std::this_thread::sleep_for(std::chrono::milliseconds(10)); 27 | 28 | sleeper.interrupt(); 29 | t.join(); 30 | } 31 | 32 | BOOST_AUTO_TEST_CASE(sleeper_sleep) 33 | { 34 | u::interruptable_sleeper sleeper; 35 | std::thread t( 36 | [&]() 37 | { 38 | // test interruption 39 | const auto start = std::chrono::system_clock::now(); 40 | sleeper.sleep(std::chrono::milliseconds(100)); 41 | const auto stop = std::chrono::system_clock::now(); 42 | const auto diff = std::chrono::duration_cast(stop - start); 43 | BOOST_CHECK(diff >= std::chrono::milliseconds(80)); 44 | }); 45 | t.join(); 46 | } -------------------------------------------------------------------------------- /unit_test/status_builder_test.cc: -------------------------------------------------------------------------------- 1 | #ifdef STAND_ALONE 2 | # define BOOST_TEST_MODULE Main 3 | #endif 4 | 5 | #include 6 | 7 | #include "api_status.h" 8 | 9 | namespace err = reinforcement_learning::error_code; 10 | 11 | int testfn() 12 | { 13 | reinforcement_learning::api_status s; 14 | RETURN_ERROR_LS(nullptr, &s, create_fn_exception) << "Error msg: " << 5; 15 | } 16 | 17 | BOOST_AUTO_TEST_CASE(status_builder_usage) 18 | { 19 | const auto scode = testfn(); 20 | BOOST_CHECK_EQUAL(scode, err::create_fn_exception); 21 | } 22 | -------------------------------------------------------------------------------- /unit_test/str_util_test.cc: -------------------------------------------------------------------------------- 1 | #ifdef STAND_ALONE 2 | # define BOOST_TEST_MODULE Main 3 | #endif 4 | 5 | #include "str_util.h" 6 | #include 7 | 8 | using namespace reinforcement_learning::utility; 9 | using namespace 
std; 10 | 11 | BOOST_AUTO_TEST_CASE(str_functions) 12 | { 13 | string tval = " TRUE "; 14 | str_util::to_lower(tval); 15 | BOOST_CHECK_EQUAL(tval, " true "); 16 | str_util::ltrim(tval); 17 | BOOST_CHECK_EQUAL(tval, "true "); 18 | str_util::rtrim(tval); 19 | BOOST_CHECK_EQUAL(tval, "true"); 20 | 21 | tval = " FALSE "; 22 | str_util::trim(tval); 23 | BOOST_CHECK_EQUAL(tval, "FALSE"); 24 | } 25 | -------------------------------------------------------------------------------- /vcpkg.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://raw.githubusercontent.com/microsoft/vcpkg/master/scripts/vcpkg.schema.json", 3 | "name": "reinforcement-learning", 4 | "version": "1.0.2", 5 | "builtin-baseline": "f30434939d5516ce764c549ab04e3d23d312180a", 6 | "dependencies": [ 7 | "boost-align", 8 | "boost-asio", 9 | "boost-date-time", 10 | "boost-filesystem", 11 | "boost-interprocess", 12 | "boost-math", 13 | "boost-program-options", 14 | "boost-regex", 15 | "boost-system", 16 | "boost-test", 17 | "boost-thread", 18 | "boost-uuid", 19 | "cpprestsdk", 20 | "flatbuffers", 21 | "fmt", 22 | "openssl", 23 | "rapidjson", 24 | "spdlog", 25 | "zlib" 26 | ], 27 | "overrides": [ 28 | {"name": "cpprestsdk", "version": "2.10.18"}, 29 | {"name": "flatbuffers", "version": "23.1.21"}, 30 | {"name": "fmt", "version": "9.1.0"}, 31 | {"name": "spdlog", "version": "1.11.0"}, 32 | {"name": "zlib", "version": "1.2.13"} 33 | ], 34 | "features": { 35 | "benchmarks": { 36 | "description": "Build Benchmarks", 37 | "dependencies": [{"name":"benchmark", "version>=":"1.7.1"}] 38 | }, 39 | "azurelibs": { 40 | "description": "Build Azure-specific code", 41 | "dependencies": [{"name":"azure-identity-cpp"}] 42 | } 43 | } 44 | } 45 | --------------------------------------------------------------------------------