├── .gitignore ├── Doxyfile ├── ExploreSample ├── App.config ├── ExploreOnlySample.cs ├── ExploreSample.csproj ├── Program.cs └── Properties │ └── AssemblyInfo.cs ├── LICENSE ├── Makefile ├── README.md ├── ReadMe.txt ├── SECURITY.md ├── clr ├── explore_clr.vcxproj ├── explore_clr.vcxproj.filters ├── explore_clr_wrapper.cpp ├── explore_clr_wrapper.h ├── explore_interface.h └── explore_interop.h ├── explore.cpp ├── explore.sln ├── explore.vcxproj ├── explore.vcxproj.filters ├── explore_sample.cpp ├── mwt.chm ├── static ├── MWTExplorer.h ├── Makefile ├── explore.cpp ├── explore_static.vcxproj ├── utility.h └── vw_explore.vcxproj.filters └── tests ├── ExploreTests.csproj ├── ExploreTests.csproj.user ├── MWTExploreTests.cpp ├── MWTExploreTests.cs ├── MWTExploreTests.h ├── Properties └── AssemblyInfo.cs ├── explore_tests.vcxproj ├── explore_tests.vcxproj.filters ├── stdafx.cpp ├── stdafx.h └── targetver.h /.gitignore: -------------------------------------------------------------------------------- 1 | # build folders 2 | **/x64 3 | **/obj 4 | **/Debug 5 | **/Release 6 | **/dll 7 | **/ipch 8 | 9 | # VS files 10 | *.opensdf 11 | *.suo 12 | *.sdf 13 | *.vcxproj.user 14 | *.so 15 | 16 | # doxygen output folders 17 | html 18 | latex -------------------------------------------------------------------------------- /ExploreSample/App.config: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /ExploreSample/ExploreOnlySample.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Linq; 4 | using System.Text; 5 | using MultiWorldTesting; 6 | 7 | namespace cs_test 8 | { 9 | class ExploreOnlySample 10 | { 11 | /// 12 | /// Example of a custom context. 
13 | /// 14 | class MyContext { } 15 | 16 | /// 17 | /// Example of a custom recorder which implements the IRecorder, 18 | /// declaring that this recorder only interacts with MyContext objects. 19 | /// 20 | class MyRecorder : IRecorder 21 | { 22 | public void Record(MyContext context, UInt32 action, float probability, string uniqueKey) 23 | { 24 | // Stores the tuple internally in a vector that could be used later for other purposes. 25 | interactions.Add(new Interaction() 26 | { 27 | Context = context, 28 | Action = action, 29 | Probability = probability, 30 | UniqueKey = uniqueKey 31 | }); 32 | } 33 | 34 | public List> GetAllInteractions() 35 | { 36 | return interactions; 37 | } 38 | 39 | private List> interactions = new List>(); 40 | } 41 | 42 | /// 43 | /// Example of a custom policy which implements the IPolicy, 44 | /// declaring that this policy only interacts with MyContext objects. 45 | /// 46 | class MyPolicy : IPolicy 47 | { 48 | public MyPolicy() : this(-1) { } 49 | 50 | public MyPolicy(int index) 51 | { 52 | this.index = index; 53 | } 54 | 55 | public uint ChooseAction(MyContext context) 56 | { 57 | // Always returns the same action regardless of context 58 | return 5; 59 | } 60 | 61 | private int index; 62 | } 63 | 64 | /// 65 | /// Example of a custom policy which implements the IPolicy, 66 | /// declaring that this policy only interacts with SimpleContext objects. 67 | /// 68 | class StringPolicy : IPolicy 69 | { 70 | public uint ChooseAction(SimpleContext context) 71 | { 72 | // Always returns the same action regardless of context 73 | return 1; 74 | } 75 | } 76 | 77 | /// 78 | /// Example of a custom scorer which implements the IScorer, 79 | /// declaring that this scorer only interacts with MyContext objects. 
80 | /// 81 | class MyScorer : IScorer 82 | { 83 | public MyScorer(uint numActions) 84 | { 85 | this.numActions = numActions; 86 | } 87 | public List ScoreActions(MyContext context) 88 | { 89 | return Enumerable.Repeat(1.0f / numActions, (int)numActions).ToList(); 90 | } 91 | private uint numActions; 92 | } 93 | 94 | /// 95 | /// Represents a tuple (context, action, probability, unique key). 96 | /// 97 | /// The Context type. 98 | struct Interaction 99 | { 100 | public Ctx Context; 101 | public uint Action; 102 | public float Probability; 103 | public string UniqueKey; 104 | } 105 | 106 | public static void Run() 107 | { 108 | string exploration_type = "greedy"; 109 | 110 | if (exploration_type == "greedy") 111 | { 112 | // Initialize Epsilon-Greedy explore algorithm using built-in StringRecorder and SimpleContext types 113 | 114 | // Creates a recorder of built-in StringRecorder type for string serialization 115 | StringRecorder recorder = new StringRecorder(); 116 | 117 | // Creates an MwtExplorer instance using the recorder above 118 | MwtExplorer mwtt = new MwtExplorer("mwt", recorder); 119 | 120 | // Creates a policy that interacts with SimpleContext type 121 | StringPolicy policy = new StringPolicy(); 122 | 123 | uint numActions = 10; 124 | float epsilon = 0.2f; 125 | // Creates an Epsilon-Greedy explorer using the specified settings 126 | EpsilonGreedyExplorer explorer = new EpsilonGreedyExplorer(policy, epsilon, numActions); 127 | 128 | // Creates a context of built-in SimpleContext type 129 | SimpleContext context = new SimpleContext(new Feature[] { 130 | new Feature() { Id = 1, Value = 0.5f }, 131 | new Feature() { Id = 4, Value = 1.3f }, 132 | new Feature() { Id = 9, Value = -0.5f }, 133 | }); 134 | 135 | // Performs exploration by passing an instance of the Epsilon-Greedy exploration algorithm into MwtExplorer 136 | // using a sample string to uniquely identify this event 137 | string uniqueKey = "eventid"; 138 | uint action = mwtt.ChooseAction(explorer, uniqueKey, context); 139 | 140 | 
Console.WriteLine(recorder.GetRecording()); 141 | 142 | return; 143 | } 144 | else if (exploration_type == "tau-first") 145 | { 146 | // Initialize Tau-First explore algorithm using custom Recorder, Policy & Context types 147 | MyRecorder recorder = new MyRecorder(); 148 | MwtExplorer mwtt = new MwtExplorer("mwt", recorder); 149 | 150 | uint numActions = 10; 151 | uint tau = 0; 152 | MyPolicy policy = new MyPolicy(); 153 | uint action = mwtt.ChooseAction(new TauFirstExplorer(policy, tau, numActions), "key", new MyContext()); 154 | Console.WriteLine(String.Join(",", recorder.GetAllInteractions().Select(it => it.Action))); 155 | return; 156 | } 157 | else if (exploration_type == "bootstrap") 158 | { 159 | // Initialize Bootstrap explore algorithm using custom Recorder, Policy & Context types 160 | MyRecorder recorder = new MyRecorder(); 161 | MwtExplorer mwtt = new MwtExplorer("mwt", recorder); 162 | 163 | uint numActions = 10; 164 | uint numbags = 2; 165 | MyPolicy[] policies = new MyPolicy[numbags]; 166 | for (int i = 0; i < numbags; i++) 167 | { 168 | policies[i] = new MyPolicy(i * 2); 169 | } 170 | uint action = mwtt.ChooseAction(new BootstrapExplorer(policies, numActions), "key", new MyContext()); 171 | Console.WriteLine(String.Join(",", recorder.GetAllInteractions().Select(it => it.Action))); 172 | return; 173 | } 174 | else if (exploration_type == "softmax") 175 | { 176 | // Initialize Softmax explore algorithm using custom Recorder, Scorer & Context types 177 | MyRecorder recorder = new MyRecorder(); 178 | MwtExplorer mwtt = new MwtExplorer("mwt", recorder); 179 | 180 | uint numActions = 10; 181 | float lambda = 0.5f; 182 | MyScorer scorer = new MyScorer(numActions); 183 | uint action = mwtt.ChooseAction(new SoftmaxExplorer(scorer, lambda, numActions), "key", new MyContext()); 184 | 185 | Console.WriteLine(String.Join(",", recorder.GetAllInteractions().Select(it => it.Action))); 186 | return; 187 | } 188 | else if (exploration_type == "generic") 189 | { 190 | 
// Initialize Generic explore algorithm using custom Recorder, Scorer & Context types 191 | MyRecorder recorder = new MyRecorder(); 192 | MwtExplorer mwtt = new MwtExplorer("mwt", recorder); 193 | 194 | uint numActions = 10; 195 | MyScorer scorer = new MyScorer(numActions); 196 | uint action = mwtt.ChooseAction(new GenericExplorer(scorer, numActions), "key", new MyContext()); 197 | 198 | Console.WriteLine(String.Join(",", recorder.GetAllInteractions().Select(it => it.Action))); 199 | return; 200 | } 201 | else 202 | { 203 | // Report an unrecognized exploration type instead of silently doing nothing. 204 | Console.Error.WriteLine("Unknown exploration type: " + exploration_type); 205 | } 206 | } 207 | } 208 | } 209 | -------------------------------------------------------------------------------- /ExploreSample/ExploreSample.csproj: -------------------------------------------------------------------------------- 1 | ﻿ 2 | 3 | 4 | 5 | Debug 6 | AnyCPU 7 | {7081D542-AE64-485D-9087-79194B958499} 8 | Exe 9 | Properties 10 | ExploreSample 11 | ExploreSample 12 | v4.5 13 | 512 14 | 15 | 16 | true 17 | bin\x86\Debug\ 18 | DEBUG;TRACE 19 | full 20 | x86 21 | prompt 22 | MinimumRecommendedRules.ruleset 23 | true 24 | 25 | 26 | bin\x86\Release\ 27 | TRACE 28 | true 29 | pdbonly 30 | x86 31 | prompt 32 | MinimumRecommendedRules.ruleset 33 | true 34 | 35 | 36 | true 37 | bin\x64\Debug\ 38 | DEBUG;TRACE 39 | full 40 | x64 41 | prompt 42 | MinimumRecommendedRules.ruleset 43 | true 44 | 45 | 46 | bin\x64\Release\ 47 | TRACE 48 | true 49 | pdbonly 50 | x64 51 | prompt 52 | MinimumRecommendedRules.ruleset 53 | true 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | {8400da16-1f46-4a31-a126-bbe16f62bfd7} 75 | explore_clr 76 | 77 | 78 | 79 | 86 | -------------------------------------------------------------------------------- /ExploreSample/Program.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Linq; 4 | using System.Text; 5 | using MultiWorldTesting; 6 | 
7 | namespace ExploreSample 8 | { 9 | class Program 10 | { 11 | public static void Main(string[] args) 12 | { 13 | cs_test.ExploreOnlySample.Run(); 14 | } 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /ExploreSample/Properties/AssemblyInfo.cs: -------------------------------------------------------------------------------- 1 | using System.Reflection; 2 | using System.Runtime.CompilerServices; 3 | using System.Runtime.InteropServices; 4 | 5 | // General Information about an assembly is controlled through the following 6 | // set of attributes. Change these attribute values to modify the information 7 | // associated with an assembly. 8 | [assembly: AssemblyTitle("ExploreSample")] 9 | [assembly: AssemblyDescription("")] 10 | [assembly: AssemblyConfiguration("")] 11 | [assembly: AssemblyCompany("")] 12 | [assembly: AssemblyProduct("ExploreSample")] 13 | [assembly: AssemblyCopyright("Copyright © 2014")] 14 | [assembly: AssemblyTrademark("")] 15 | [assembly: AssemblyCulture("")] 16 | 17 | // Setting ComVisible to false makes the types in this assembly not visible 18 | // to COM components. If you need to access a type in this assembly from 19 | // COM, set the ComVisible attribute to true on that type. 
20 | [assembly: ComVisible(false)] 21 | 22 | // The following GUID is for the ID of the typelib if this project is exposed to COM 23 | [assembly: Guid("767d4e7c-6acc-4b46-9eac-e86ab079625a")] 24 | 25 | // Version information for an assembly consists of the following four values: 26 | // 27 | // Major Version 28 | // Minor Version 29 | // Build Number 30 | // Revision 31 | // 32 | // You can specify all the values or you can default the Build and Revision Numbers 33 | // by using the '*' as shown below: 34 | // [assembly: AssemblyVersion("1.0.*")] 35 | [assembly: AssemblyVersion("1.0.0.0")] 36 | [assembly: AssemblyFileVersion("1.0.0.0")] 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright © Microsoft Corp 2012-2014, Yahoo! Inc. 2007-2012, and many 2 | individual contributors. 3 | 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright 11 | 12 | notice, this list of conditions and the following disclaimer. 13 | 14 | * Redistributions in binary form must reproduce the above copyright 15 | 16 | notice, this list of conditions and the following disclaimer in the 17 | 18 | documentation and/or other materials provided with the distribution. 19 | 20 | * Neither the name of the Microsoft Corp nor the 21 | 22 | names of its contributors may be used to endorse or promote products 23 | 24 | derived from this software without specific prior written permission. 
25 | 26 | 27 | 28 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 29 | 30 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 31 | 32 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 33 | 34 | DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY 35 | 36 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 37 | 38 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 39 | 40 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 41 | 42 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 43 | 44 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 45 | 46 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 47 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: sample 2 | 3 | time: explore.cpp 4 | g++ -Wall -O3 explore.cpp -I static -I ../vowpalwabbit -std=c++0x 5 | 6 | sample: explore_sample.cpp 7 | g++ -Wall -O3 explore_sample.cpp -I static -I ../vowpalwabbit -std=c++0x 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Exploration Library 2 | ======= 3 | 4 | [WARNING: This repository is deprecated. See https://github.com/multiworldtesting/explore-csharp for C# and https://github.com/multiworldtesting/explore-cpp for C++ instead] 5 | 6 | The exploration library addresses the ‘gathering the data’ aspect of machine learning rather than the ‘using the data’ aspect we are most familiar with. The primary goal here is to enable individuals (i.e. you) to gather the right data for using machine learning for interventions in a live system based on user feedback (click, dwell, correction, etc…). 
Empirically, gathering the right data has often made a substantial difference. Theoretically, we know it is required to disentangle causation from correlation effectively in general. 7 | 8 | First version of client-side exploration library that includes the following exploration algorithms: 9 | - Epsilon Greedy 10 | - Tau First 11 | - Softmax 12 | - Bootstrap 13 | - Generic (users can specify custom weight for every action) 14 | 15 | This release supports C++ and C#. 16 | 17 | For sample usage, see: 18 | 19 | C++: https://github.com/multiworldtesting/explore/blob/master/explore_sample.cpp 20 | 21 | C#: https://github.com/multiworldtesting/explore/blob/master/ExploreSample/ExploreOnlySample.cs 22 | -------------------------------------------------------------------------------- /ReadMe.txt: -------------------------------------------------------------------------------- 1 | This is the root of the exploration library and client-side decision service. 2 | 3 | explore_sample.cpp shows how to use the exploration library which is a 4 | header-only include in C++. 5 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 
8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 
36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /clr/explore_clr.vcxproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | Debug 6 | Win32 7 | 8 | 9 | Debug 10 | x64 11 | 12 | 13 | Release 14 | Win32 15 | 16 | 17 | Release 18 | x64 19 | 20 | 21 | 22 | {8400DA16-1F46-4A31-A126-BBE16F62BFD7} 23 | Win32Proj 24 | vw_explore_clr_wrapper 25 | explore_clr 26 | 27 | 28 | 29 | DynamicLibrary 30 | true 31 | v120 32 | Unicode 33 | true 34 | 35 | 36 | DynamicLibrary 37 | true 38 | v120 39 | Unicode 40 | true 41 | 42 | 43 | DynamicLibrary 44 | false 45 | v120 46 | true 47 | Unicode 48 | true 49 | 50 | 51 | DynamicLibrary 52 | false 53 | v120 54 | true 55 | Unicode 56 | true 57 | 58 | 59 | c:\boost\x64\include\boost-1_56 60 | c:\boost\x64\lib 61 | ..\..\..\zlib-1.2.8 62 | $(ZlibIncludeDir)\contrib\vstudio\vc11\x64\ZlibStat$(Configuration) 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | true 82 | 83 | 84 | true 85 | 86 | 87 | false 88 | 89 | 90 | false 91 | 92 | 93 | 94 | NotUsing 95 | Level3 96 | Disabled 97 | WIN32;_DEBUG;_WINDOWS;_USRDLL;VW_EXPLORE_CLR_WRAPPER_EXPORTS;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) 98 | true 99 | ..\static;%(AdditionalIncludeDirectories) 100 | true 101 | 102 | 103 | Windows 104 | true 105 | 106 | 107 | 108 | 109 | NotUsing 110 | Level3 111 | Disabled 112 | WIN32;_DEBUG;_WINDOWS;_USRDLL;VW_EXPLORE_CLR_WRAPPER_EXPORTS;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) 113 | true 114 | ..\static;%(AdditionalIncludeDirectories) 115 | true 116 | 117 | 118 | Windows 119 | true 120 | 121 | 122 | 123 | 124 | Level3 125 | NotUsing 126 | MaxSpeed 127 | true 128 | true 129 | 
WIN32;NDEBUG;_WINDOWS;_USRDLL;VW_EXPLORE_CLR_WRAPPER_EXPORTS;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) 130 | true 131 | ..\static;%(AdditionalIncludeDirectories) 132 | true 133 | 134 | 135 | Windows 136 | true 137 | true 138 | true 139 | 140 | 141 | 142 | 143 | Level3 144 | NotUsing 145 | MaxSpeed 146 | true 147 | true 148 | WIN32;NDEBUG;_WINDOWS;_USRDLL;VW_EXPLORE_CLR_WRAPPER_EXPORTS;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) 149 | true 150 | ..\static;%(AdditionalIncludeDirectories) 151 | true 152 | 153 | 154 | Windows 155 | true 156 | true 157 | true 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | {ace47e98-488c-4cdf-b9f1-36337b2855ad} 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | -------------------------------------------------------------------------------- /clr/explore_clr.vcxproj.filters: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF} 6 | cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx 7 | 8 | 9 | {93995380-89BD-4b04-88EB-625FBE52EBFB} 10 | h;hpp;hxx;hm;inl;inc;xsd 11 | 12 | 13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} 14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | Header Files 23 | 24 | 25 | Header Files 26 | 27 | 28 | Header Files 29 | 30 | 31 | 32 | 33 | Source Files 34 | 35 | 36 | -------------------------------------------------------------------------------- /clr/explore_clr_wrapper.cpp: -------------------------------------------------------------------------------- 1 | // vw_explore_clr_wrapper.cpp : Defines the exported functions for the DLL application. 
2 | // 3 | 4 | #define WIN32_LEAN_AND_MEAN 5 | #include 6 | 7 | #include "explore_clr_wrapper.h" 8 | 9 | using namespace System; 10 | using namespace System::Collections; 11 | using namespace System::Collections::Generic; 12 | using namespace System::Runtime::InteropServices; 13 | using namespace msclr::interop; 14 | using namespace NativeMultiWorldTesting; 15 | 16 | namespace MultiWorldTesting { 17 | 18 | } 19 | -------------------------------------------------------------------------------- /clr/explore_clr_wrapper.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "explore_interop.h" 3 | 4 | /*! 5 | * \addtogroup MultiWorldTestingCsharp 6 | * @{ 7 | */ 8 | namespace MultiWorldTesting { 9 | 10 | /// 11 | /// The epsilon greedy exploration class. 12 | /// 13 | /// 14 | /// This is a good choice if you have no idea which actions should be preferred. 15 | /// Epsilon greedy is also computationally cheap. 16 | /// 17 | /// The Context type. 18 | generic 19 | public ref class EpsilonGreedyExplorer : public IExplorer, public IConsumePolicy, public PolicyCallback 20 | { 21 | public: 22 | /// 23 | /// Initializes an epsilon greedy explorer with a fixed number of actions, for use with the MwtExplorer. 24 | /// 25 | /// A default function which outputs an action given a context. 26 | /// The probability of a random exploration. 27 | /// The number of actions to randomize over. 28 | EpsilonGreedyExplorer(IPolicy^ defaultPolicy, float epsilon, UInt32 numActions) 29 | { 30 | this->defaultPolicy = defaultPolicy; 31 | m_explorer = new NativeMultiWorldTesting::EpsilonGreedyExplorer(*GetNativePolicy(), epsilon, (u32)numActions); 32 | } 33 | 34 | /// 35 | /// Initializes an epsilon greedy explorer with variable number of actions. 36 | /// 37 | /// A default function which outputs an action given a context. 38 | /// The probability of a random exploration. 
39 | EpsilonGreedyExplorer(IPolicy^ defaultPolicy, float epsilon) 40 | { 41 | if (!(IVariableActionContext::typeid->IsAssignableFrom(Ctx::typeid))) 42 | { 43 | throw gcnew ArgumentException("The specified context type does not implement variable-action interface.", "Ctx"); 44 | } 45 | 46 | this->defaultPolicy = defaultPolicy; 47 | m_explorer = new NativeMultiWorldTesting::EpsilonGreedyExplorer(*GetNativePolicy(), epsilon); 48 | } 49 | 50 | ~EpsilonGreedyExplorer() 51 | { 52 | delete m_explorer; 53 | } 54 | 55 | virtual void UpdatePolicy(IPolicy^ newPolicy) 56 | { 57 | this->defaultPolicy = newPolicy; 58 | } 59 | 60 | virtual void EnableExplore(bool explore) 61 | { 62 | m_explorer->Enable_Explore(explore); 63 | } 64 | 65 | internal: 66 | virtual UInt32 InvokePolicyCallback(Ctx context, int index) override 67 | { 68 | return defaultPolicy->ChooseAction(context); 69 | } 70 | 71 | NativeMultiWorldTesting::EpsilonGreedyExplorer* Get() 72 | { 73 | return m_explorer; 74 | } 75 | 76 | private: 77 | IPolicy^ defaultPolicy; 78 | NativeMultiWorldTesting::EpsilonGreedyExplorer* m_explorer; 79 | }; 80 | 81 | /// 82 | /// The tau-first exploration class. 83 | /// 84 | /// 85 | /// The tau-first explorer collects precisely tau uniform random 86 | /// exploration events, and then uses the default policy. 87 | /// 88 | /// The Context type. 89 | generic 90 | public ref class TauFirstExplorer : public IExplorer, public IConsumePolicy, public PolicyCallback 91 | { 92 | public: 93 | /// 94 | /// Initializes a tau-first explorer with a fixed number of actions, for use with the MwtExplorer. 95 | /// 96 | /// A default policy after randomization finishes. 97 | /// The number of events to be uniform over. 98 | /// The number of actions to randomize over. 
99 | TauFirstExplorer(IPolicy^ defaultPolicy, UInt32 tau, UInt32 numActions) 100 | { 101 | this->defaultPolicy = defaultPolicy; 102 | m_explorer = new NativeMultiWorldTesting::TauFirstExplorer(*GetNativePolicy(), tau, (u32)numActions); 103 | } 104 | 105 | /// 106 | /// Initializes a tau-first explorer with variable number of actions. 107 | /// 108 | /// A default policy after randomization finishes. 109 | /// The number of events to be uniform over. 110 | TauFirstExplorer(IPolicy^ defaultPolicy, UInt32 tau) 111 | { 112 | if (!(IVariableActionContext::typeid->IsAssignableFrom(Ctx::typeid))) 113 | { 114 | throw gcnew ArgumentException("The specified context type does not implement variable-action interface.", "Ctx"); 115 | } 116 | 117 | this->defaultPolicy = defaultPolicy; 118 | m_explorer = new NativeMultiWorldTesting::TauFirstExplorer(*GetNativePolicy(), tau); 119 | } 120 | 121 | virtual void UpdatePolicy(IPolicy^ newPolicy) 122 | { 123 | this->defaultPolicy = newPolicy; 124 | } 125 | 126 | virtual void EnableExplore(bool explore) 127 | { 128 | m_explorer->Enable_Explore(explore); 129 | } 130 | 131 | ~TauFirstExplorer() 132 | { 133 | delete m_explorer; 134 | } 135 | 136 | internal: 137 | virtual UInt32 InvokePolicyCallback(Ctx context, int index) override 138 | { 139 | return defaultPolicy->ChooseAction(context); 140 | } 141 | 142 | NativeMultiWorldTesting::TauFirstExplorer* Get() 143 | { 144 | return m_explorer; 145 | } 146 | 147 | private: 148 | IPolicy^ defaultPolicy; 149 | NativeMultiWorldTesting::TauFirstExplorer* m_explorer; 150 | }; 151 | 152 | /// 153 | /// The softmax exploration class. 154 | /// 155 | /// 156 | /// In some cases, different actions have different scores, and you 157 | /// would prefer to choose actions with large scores. Softmax allows 158 | /// you to do that. 159 | /// 160 | /// The Context type. 
161 | generic 162 | public ref class SoftmaxExplorer : public IExplorer, public IConsumeScorer, public ScorerCallback 163 | { 164 | public: 165 | /// 166 | /// Initializes a softmax explorer with a fixed number of actions, for use with the MwtExplorer. 167 | /// 168 | /// A function which outputs a score for each action. 169 | /// lambda = 0 implies uniform distribution. Large lambda is equivalent to a max. 170 | /// The number of actions to randomize over. 171 | SoftmaxExplorer(IScorer^ defaultScorer, float lambda, UInt32 numActions) 172 | { 173 | this->defaultScorer = defaultScorer; 174 | m_explorer = new NativeMultiWorldTesting::SoftmaxExplorer(*GetNativeScorer(), lambda, (u32)numActions); 175 | } 176 | 177 | /// 178 | /// Initializes a softmax explorer with variable number of actions. 179 | /// 180 | /// A function which outputs a score for each action. 181 | /// lambda = 0 implies uniform distribution. Large lambda is equivalent to a max. 182 | SoftmaxExplorer(IScorer^ defaultScorer, float lambda) 183 | { 184 | if (!(IVariableActionContext::typeid->IsAssignableFrom(Ctx::typeid))) 185 | { 186 | throw gcnew ArgumentException("The specified context type does not implement variable-action interface.", "Ctx"); 187 | } 188 | 189 | this->defaultScorer = defaultScorer; 190 | m_explorer = new NativeMultiWorldTesting::SoftmaxExplorer(*GetNativeScorer(), lambda); 191 | } 192 | 193 | virtual void UpdateScorer(IScorer^ newScorer) 194 | { 195 | this->defaultScorer = newScorer; 196 | } 197 | 198 | virtual void EnableExplore(bool explore) 199 | { 200 | m_explorer->Enable_Explore(explore); 201 | } 202 | 203 | ~SoftmaxExplorer() 204 | { 205 | delete m_explorer; 206 | } 207 | 208 | internal: 209 | virtual List^ InvokeScorerCallback(Ctx context) override 210 | { 211 | return defaultScorer->ScoreActions(context); 212 | } 213 | 214 | NativeMultiWorldTesting::SoftmaxExplorer* Get() 215 | { 216 | return m_explorer; 217 | } 218 | 219 | private: 220 | IScorer^ defaultScorer; 221 | 
NativeMultiWorldTesting::SoftmaxExplorer* m_explorer; 222 | }; 223 | 224 | /// 225 | /// The generic exploration class. 226 | /// 227 | /// 228 | /// GenericExplorer provides complete flexibility. You can create any 229 | /// distribution over actions desired, and it will draw from that. 230 | /// 231 | /// The Context type. 232 | generic 233 | public ref class GenericExplorer : public IExplorer, public IConsumeScorer, public ScorerCallback 234 | { 235 | public: 236 | /// 237 | /// Initializes a generic explorer with a fixed number of actions, for use with the MwtExplorer. 238 | /// 239 | /// A function which outputs the probability of each action. 240 | /// The number of actions to randomize over. 241 | GenericExplorer(IScorer^ defaultScorer, UInt32 numActions) 242 | { 243 | this->defaultScorer = defaultScorer; 244 | m_explorer = new NativeMultiWorldTesting::GenericExplorer(*GetNativeScorer(), (u32)numActions); 245 | } 246 | 247 | /// 248 | /// Initializes a generic explorer with variable number of actions. 249 | /// 250 | /// A function which outputs the probability of each action. 
251 | GenericExplorer(IScorer^ defaultScorer) 252 | { 253 | if (!(IVariableActionContext::typeid->IsAssignableFrom(Ctx::typeid))) 254 | { 255 | throw gcnew ArgumentException("The specified context type does not implement variable-action interface.", "Ctx"); 256 | } 257 | 258 | this->defaultScorer = defaultScorer; 259 | m_explorer = new NativeMultiWorldTesting::GenericExplorer(*GetNativeScorer()); 260 | } 261 | 262 | virtual void UpdateScorer(IScorer^ newScorer) 263 | { 264 | this->defaultScorer = newScorer; 265 | } 266 | 267 | virtual void EnableExplore(bool explore) 268 | { 269 | m_explorer->Enable_Explore(explore); 270 | } 271 | 272 | ~GenericExplorer() 273 | { 274 | delete m_explorer; 275 | } 276 | 277 | internal: 278 | virtual List^ InvokeScorerCallback(Ctx context) override 279 | { 280 | return defaultScorer->ScoreActions(context); 281 | } 282 | 283 | NativeMultiWorldTesting::GenericExplorer* Get() 284 | { 285 | return m_explorer; 286 | } 287 | 288 | private: 289 | IScorer^ defaultScorer; 290 | NativeMultiWorldTesting::GenericExplorer* m_explorer; 291 | }; 292 | 293 | /// 294 | /// The bootstrap exploration class. 295 | /// 296 | /// 297 | /// The Bootstrap explorer randomizes over the actions chosen by a set of 298 | /// default policies. This performs well statistically but can be 299 | /// computationally expensive. 300 | /// 301 | /// The Context type. 302 | generic 303 | public ref class BootstrapExplorer : public IExplorer, public IConsumePolicies, public PolicyCallback 304 | { 305 | public: 306 | /// 307 | /// The constructor is the only public member, because this should be used with the MwtExplorer. 308 | /// 309 | /// A set of default policies to be uniform random over. 310 | /// The number of actions to randomize over. 
311 | BootstrapExplorer(cli::array^>^ defaultPolicies, UInt32 numActions) 312 | { 313 | this->defaultPolicies = defaultPolicies; 314 | if (this->defaultPolicies == nullptr) 315 | { 316 | throw gcnew ArgumentNullException("The specified array of default policy functions cannot be null."); 317 | } 318 | 319 | m_explorer = new NativeMultiWorldTesting::BootstrapExplorer(*GetNativePolicies((u32)defaultPolicies->Length), (u32)numActions); 320 | } 321 | 322 | /// 323 | /// Initializes a bootstrap explorer with variable number of actions. 324 | /// 325 | /// A set of default policies to be uniform random over. 326 | BootstrapExplorer(cli::array^>^ defaultPolicies) 327 | { 328 | if (!(IVariableActionContext::typeid->IsAssignableFrom(Ctx::typeid))) 329 | { 330 | throw gcnew ArgumentException("The specified context type does not implement variable-action interface.", "Ctx"); 331 | } 332 | 333 | this->defaultPolicies = defaultPolicies; 334 | if (this->defaultPolicies == nullptr) 335 | { 336 | throw gcnew ArgumentNullException("The specified array of default policy functions cannot be null."); 337 | } 338 | 339 | m_explorer = new NativeMultiWorldTesting::BootstrapExplorer(*GetNativePolicies((u32)defaultPolicies->Length)); 340 | } 341 | 342 | virtual void UpdatePolicy(cli::array^>^ newPolicies) 343 | { 344 | this->defaultPolicies = newPolicies; 345 | } 346 | 347 | virtual void EnableExplore(bool explore) 348 | { 349 | m_explorer->Enable_Explore(explore); 350 | } 351 | 352 | ~BootstrapExplorer() 353 | { 354 | delete m_explorer; 355 | } 356 | 357 | internal: 358 | virtual UInt32 InvokePolicyCallback(Ctx context, int index) override 359 | { 360 | if (index < 0 || index >= defaultPolicies->Length) 361 | { 362 | throw gcnew InvalidDataException("Internal error: Index of interop bag is out of range."); 363 | } 364 | return defaultPolicies[index]->ChooseAction(context); 365 | } 366 | 367 | NativeMultiWorldTesting::BootstrapExplorer* Get() 368 | { 369 | return m_explorer; 370 | } 371 | 
372 | private: 373 | cli::array^>^ defaultPolicies; 374 | NativeMultiWorldTesting::BootstrapExplorer* m_explorer; 375 | }; 376 | 377 | /// 378 | /// The top level MwtExplorer class. Using this makes sure that the 379 | /// right bits are recorded and good random actions are chosen. 380 | /// 381 | /// The Context type. 382 | generic 383 | public ref class MwtExplorer : public RecorderCallback 384 | { 385 | public: 386 | /// 387 | /// Constructor. 388 | /// 389 | /// This should be unique to each experiment to avoid correlation bugs. 390 | /// A user-specified class for recording the appropriate bits for use in evaluation and learning. 391 | MwtExplorer(String^ appId, IRecorder^ recorder) 392 | { 393 | this->appId = appId; 394 | this->recorder = recorder; 395 | } 396 | 397 | /// 398 | /// Choose_Action should be drop-in replacement for any existing policy function. 399 | /// 400 | /// An existing exploration algorithm (one of the above) which uses the default policy as a callback. 401 | /// A unique identifier for the experimental unit. This could be a user id, a session id, etc... 402 | /// The context upon which a decision is made. See SimpleContext above for an example. 403 | /// An unsigned 32-bit integer representing the 1-based chosen action. 
404 | UInt32 ChooseAction(IExplorer^ explorer, String^ unique_key, Ctx context) 405 | { 406 | String^ salt = this->appId; 407 | NativeMultiWorldTesting::MwtExplorer mwt(marshal_as(salt), *GetNativeRecorder()); 408 | 409 | // Normal handles are sufficient here since native code will only hold references and not access the object's data 410 | // https://www.microsoftpressstore.com/articles/article.aspx?p=2224054&seqNum=4 411 | GCHandle selfHandle = GCHandle::Alloc(this); 412 | IntPtr selfPtr = (IntPtr)selfHandle; 413 | 414 | GCHandle contextHandle = GCHandle::Alloc(context); 415 | IntPtr contextPtr = (IntPtr)contextHandle; 416 | 417 | GCHandle explorerHandle = GCHandle::Alloc(explorer); 418 | IntPtr explorerPtr = (IntPtr)explorerHandle; 419 | 420 | try 421 | { 422 | NativeContext native_context(selfPtr.ToPointer(), explorerPtr.ToPointer(), contextPtr.ToPointer(), 423 | this->GetNumActionsCallback()); 424 | u32 action = 0; 425 | if (explorer->GetType() == EpsilonGreedyExplorer::typeid) 426 | { 427 | EpsilonGreedyExplorer^ epsilonGreedyExplorer = (EpsilonGreedyExplorer^)explorer; 428 | action = mwt.Choose_Action(*epsilonGreedyExplorer->Get(), marshal_as(unique_key), native_context); 429 | } 430 | else if (explorer->GetType() == TauFirstExplorer::typeid) 431 | { 432 | TauFirstExplorer^ tauFirstExplorer = (TauFirstExplorer^)explorer; 433 | action = mwt.Choose_Action(*tauFirstExplorer->Get(), marshal_as(unique_key), native_context); 434 | } 435 | else if (explorer->GetType() == SoftmaxExplorer::typeid) 436 | { 437 | SoftmaxExplorer^ softmaxExplorer = (SoftmaxExplorer^)explorer; 438 | action = mwt.Choose_Action(*softmaxExplorer->Get(), marshal_as(unique_key), native_context); 439 | } 440 | else if (explorer->GetType() == GenericExplorer::typeid) 441 | { 442 | GenericExplorer^ genericExplorer = (GenericExplorer^)explorer; 443 | action = mwt.Choose_Action(*genericExplorer->Get(), marshal_as(unique_key), native_context); 444 | } 445 | else if (explorer->GetType() == 
BootstrapExplorer::typeid) 446 | { 447 | BootstrapExplorer^ bootstrapExplorer = (BootstrapExplorer^)explorer; 448 | action = mwt.Choose_Action(*bootstrapExplorer->Get(), marshal_as(unique_key), native_context); 449 | } 450 | return action; 451 | } 452 | finally 453 | { 454 | if (explorerHandle.IsAllocated) 455 | { 456 | explorerHandle.Free(); 457 | } 458 | if (contextHandle.IsAllocated) 459 | { 460 | contextHandle.Free(); 461 | } 462 | if (selfHandle.IsAllocated) 463 | { 464 | selfHandle.Free(); 465 | } 466 | } 467 | } 468 | 469 | internal: 470 | virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) override 471 | { 472 | recorder->Record(context, action, probability, unique_key); 473 | } 474 | 475 | private: 476 | IRecorder^ recorder; 477 | String^ appId; 478 | }; 479 | 480 | /// 481 | /// Represents a feature in a sparse array. 482 | /// 483 | [StructLayout(LayoutKind::Sequential)] 484 | public value struct Feature 485 | { 486 | float Value; 487 | UInt32 Id; 488 | }; 489 | 490 | /// 491 | /// A sample recorder class that converts the exploration tuple into string format. 492 | /// 493 | /// The Context type. 
494 | generic where Ctx : IStringContext 495 | public ref class StringRecorder : public IRecorder, public ToStringCallback 496 | { 497 | public: 498 | StringRecorder() 499 | { 500 | m_string_recorder = new NativeMultiWorldTesting::StringRecorder(); 501 | } 502 | 503 | ~StringRecorder() 504 | { 505 | delete m_string_recorder; 506 | } 507 | /* Wraps the managed context in a GCHandle, passes it to the native string recorder together with the action, probability and key, and releases the handle afterwards. */ 508 | virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey) 509 | { 510 | // Normal handles are sufficient here since native code will only hold references and not access the object's data 511 | // https://www.microsoftpressstore.com/articles/article.aspx?p=2224054&seqNum=4 512 | GCHandle contextHandle = GCHandle::Alloc(context); 513 | try { IntPtr contextPtr = (IntPtr)contextHandle; 514 | 515 | NativeStringContext native_context(contextPtr.ToPointer(), GetCallback()); 516 | m_string_recorder->Record(native_context, (u32)action, probability, marshal_as(uniqueKey)); 517 | } finally { /* always release the handle, even if the native recorder or the ToString callback throws */ 518 | if (contextHandle.IsAllocated) { contextHandle.Free(); } } 519 | } 520 | 521 | /// 522 | /// Gets the content of the recording so far as a string and clears internal content. 523 | /// 524 | /// 525 | /// A string with recording content. 526 | /// 527 | String^ GetRecording() 528 | { 529 | // Workaround for C++-CLI bug which does not allow default value for parameter 530 | return GetRecording(true); 531 | } 532 | 533 | /// 534 | /// Gets the content of the recording so far as a string and optionally clears internal content. 535 | /// 536 | /// A boolean value indicating whether to clear the internal content. 537 | /// 538 | /// A string with recording content. 539 | /// 540 | String^ GetRecording(bool flush) 541 | { 542 | return gcnew String(m_string_recorder->Get_Recording(flush).c_str()); 543 | } 544 | 545 | private: 546 | NativeMultiWorldTesting::StringRecorder* m_string_recorder; 547 | }; 548 | 549 | /// 550 | /// A sample context class that stores a vector of Features. 
551 | /// 552 | public ref class SimpleContext : public IStringContext 553 | { 554 | public: 555 | SimpleContext(cli::array^ features) 556 | { 557 | Features = features; 558 | 559 | // TODO: add another constructor overload for native SimpleContext to avoid copying feature values 560 | m_features = new vector(); 561 | for (int i = 0; i < features->Length; i++) 562 | { 563 | m_features->push_back({ features[i].Value, features[i].Id }); 564 | } 565 | 566 | m_native_context = new NativeMultiWorldTesting::SimpleContext(*m_features); 567 | } 568 | 569 | String^ ToString() override 570 | { 571 | return gcnew String(m_native_context->To_String().c_str()); 572 | } 573 | /* Releases the native context and the native copy of the features allocated by the constructor. */ 574 | ~SimpleContext() 575 | { 576 | /* The native context was constructed from *m_features, so it is released first; the feature vector was previously never freed. Pointers are nulled so a repeated dispose is harmless. */ delete m_native_context; m_native_context = nullptr; delete m_features; m_features = nullptr; 577 | } 578 | 579 | public: 580 | cli::array^ GetFeatures() { return Features; } 581 | 582 | internal: 583 | cli::array^ Features; 584 | 585 | private: 586 | vector* m_features; 587 | NativeMultiWorldTesting::SimpleContext* m_native_context; 588 | }; 589 | } 590 | 591 | /*! @} End of Doxygen Groups*/ 592 | -------------------------------------------------------------------------------- /clr/explore_interface.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | using namespace System; 4 | using namespace System::Collections::Generic; 5 | 6 | /** \defgroup MultiWorldTestingCsharp 7 | \brief C# implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs 8 | */ 9 | 10 | /*! 11 | * \addtogroup MultiWorldTestingCsharp 12 | * @{ 13 | */ 14 | 15 | //! Interface for C# version of Multiworld Testing library. 16 | //! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs 17 | namespace MultiWorldTesting { 18 | 19 | /// 20 | /// Represents a recorder that exposes a method to record exploration data based on generic contexts. 21 | /// 22 | /// The Context type. 
23 | /// 24 | /// Exploration data is specified as a set of tuples (context, action, probability, key) as described below. An 25 | /// application passes an IRecorder object to the @MwtExplorer constructor. See 26 | /// @StringRecorder for a sample IRecorder object. 27 | /// 28 | generic 29 | public interface class IRecorder 30 | { 31 | public: 32 | /// 33 | /// Records the exploration data associated with a given decision. 34 | /// This implementation should be thread-safe if multithreading is needed. 35 | /// 36 | /// A user-defined context for the decision. 37 | /// The action chosen by an exploration algorithm given the context. 38 | /// The probability of the chosen action given context. 39 | /// A user-defined identifier for the decision. 40 | virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey) = 0; 41 | }; 42 | 43 | /// 44 | /// Exposes a method for choosing an action given a generic context. IPolicy objects are 45 | /// passed to (and invoked by) exploration algorithms to specify the default policy behavior. 46 | /// 47 | /// The Context type. 48 | generic 49 | public interface class IPolicy 50 | { 51 | public: 52 | /// 53 | /// Determines the action to take for a given context. 54 | /// This implementation should be thread-safe if multithreading is needed. 55 | /// 56 | /// A user-defined context for the decision. 57 | /// Index of the action to take (1-based) 58 | virtual UInt32 ChooseAction(Ctx context) = 0; 59 | }; 60 | 61 | /// 62 | /// Exposes a method for specifying a score (weight) for each action given a generic context. 63 | /// 64 | /// The Context type. 65 | generic 66 | public interface class IScorer 67 | { 68 | public: 69 | /// 70 | /// Determines the score of each action for a given context. 71 | /// This implementation should be thread-safe if multithreading is needed. 72 | /// 73 | /// A user-defined context for the decision. 74 | /// Vector of scores indexed by action (1-based). 
75 | virtual List^ ScoreActions(Ctx context) = 0; 76 | }; 77 | 78 | /// 79 | /// Represents a context interface with variable number of actions which is 80 | /// enforced if exploration algorithm is initialized in variable number of actions mode. 81 | /// 82 | public interface class IVariableActionContext 83 | { 84 | public: 85 | /// 86 | /// Gets the number of actions for the current context. 87 | /// 88 | /// The number of actions available for the current context. 89 | virtual UInt32 GetNumberOfActions() = 0; 90 | }; 91 | 92 | generic 93 | public interface class IExplorer 94 | { 95 | public: 96 | virtual void EnableExplore(bool explore) = 0; 97 | }; 98 | 99 | generic 100 | public interface class IConsumePolicy 101 | { 102 | public: 103 | virtual void UpdatePolicy(IPolicy^ newPolicy) = 0; 104 | }; 105 | 106 | generic 107 | public interface class IConsumePolicies 108 | { 109 | public: 110 | virtual void UpdatePolicy(cli::array^>^ newPolicies) = 0; 111 | }; 112 | 113 | generic 114 | public interface class IConsumeScorer 115 | { 116 | public: 117 | virtual void UpdateScorer(IScorer^ newScorer) = 0; 118 | }; 119 | 120 | public interface class IStringContext 121 | { 122 | public: 123 | virtual String^ ToString() = 0; 124 | }; 125 | 126 | } 127 | 128 | /*! 
@} End of Doxygen Groups*/ -------------------------------------------------------------------------------- /clr/explore_interop.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #define MANAGED_CODE 4 | 5 | #include "explore_interface.h" 6 | #include "MWTExplorer.h" 7 | 8 | #include 9 | 10 | using namespace System; 11 | using namespace System::Collections::Generic; 12 | using namespace System::IO; 13 | using namespace System::Runtime::InteropServices; 14 | using namespace System::Xml::Serialization; 15 | using namespace msclr::interop; 16 | 17 | namespace MultiWorldTesting { 18 | 19 | // Context callback 20 | private delegate UInt32 ClrContextGetNumActionsCallback(IntPtr contextPtr); 21 | typedef u32 Native_Context_Get_Num_Actions_Callback(void* context); 22 | 23 | // Policy callback 24 | private delegate UInt32 ClrPolicyCallback(IntPtr explorerPtr, IntPtr contextPtr, int index); 25 | typedef u32 Native_Policy_Callback(void* explorer, void* context, int index); 26 | 27 | // Scorer callback 28 | private delegate void ClrScorerCallback(IntPtr explorerPtr, IntPtr contextPtr, IntPtr scores, IntPtr size); 29 | typedef void Native_Scorer_Callback(void* explorer, void* context, float* scores[], u32* size); 30 | 31 | // Recorder callback 32 | private delegate void ClrRecorderCallback(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKey); 33 | typedef void Native_Recorder_Callback(void* mwt, void* context, u32 action, float probability, void* unique_key); 34 | 35 | // ToString callback 36 | private delegate void ClrToStringCallback(IntPtr contextPtr, IntPtr stringValue); 37 | typedef void Native_To_String_Callback(void* explorer, void* string_value); 38 | 39 | // NativeContext travels through interop space and contains instances of Mwt, Explorer, Context 40 | // used for triggering callback for Policy, Scorer, Recorder 41 | class NativeContext : public 
NativeMultiWorldTesting::IVariableActionContext 42 | { 43 | public: 44 | NativeContext(void* clr_mwt, void* clr_explorer, void* clr_context, 45 | Native_Context_Get_Num_Actions_Callback* callback_num_actions) 46 | { 47 | m_clr_mwt = clr_mwt; 48 | m_clr_explorer = clr_explorer; 49 | m_clr_context = clr_context; 50 | 51 | m_callback_num_actions = callback_num_actions; 52 | } 53 | 54 | u32 Get_Number_Of_Actions() 55 | { 56 | return m_callback_num_actions(m_clr_context); 57 | } 58 | 59 | void* Get_Clr_Mwt() 60 | { 61 | return m_clr_mwt; 62 | } 63 | 64 | void* Get_Clr_Context() 65 | { 66 | return m_clr_context; 67 | } 68 | 69 | void* Get_Clr_Explorer() 70 | { 71 | return m_clr_explorer; 72 | } 73 | 74 | private: 75 | void* m_clr_mwt; 76 | void* m_clr_context; 77 | void* m_clr_explorer; 78 | 79 | private: 80 | Native_Context_Get_Num_Actions_Callback* m_callback_num_actions; 81 | }; 82 | 83 | class NativeStringContext 84 | { 85 | public: 86 | NativeStringContext(void* clr_context, Native_To_String_Callback* func) : m_func(func) 87 | { 88 | m_clr_context = clr_context; 89 | } 90 | 91 | string To_String() 92 | { 93 | string value; 94 | m_func(m_clr_context, &value); 95 | return value; 96 | } 97 | private: 98 | void* m_clr_context; 99 | Native_To_String_Callback* const m_func; 100 | }; 101 | 102 | // NativeRecorder listens to callback event and reroute it to the managed Recorder instance 103 | class NativeRecorder : public NativeMultiWorldTesting::IRecorder 104 | { 105 | public: 106 | NativeRecorder(Native_Recorder_Callback* native_func) : m_func(native_func) 107 | { 108 | } 109 | 110 | void Record(NativeContext& context, u32 action, float probability, string unique_key) 111 | { 112 | // Normal handles are sufficient here since native code will only hold references and not access the object's data 113 | // https://www.microsoftpressstore.com/articles/article.aspx?p=2224054&seqNum=4 114 | GCHandle uniqueKeyHandle = GCHandle::Alloc(gcnew String(unique_key.c_str())); 115 | try 
116 | { 117 | IntPtr uniqueKeyPtr = (IntPtr)uniqueKeyHandle; 118 | 119 | m_func(context.Get_Clr_Mwt(), context.Get_Clr_Context(), action, probability, uniqueKeyPtr.ToPointer()); 120 | } 121 | finally 122 | { 123 | if (uniqueKeyHandle.IsAllocated) 124 | { 125 | uniqueKeyHandle.Free(); 126 | } 127 | } 128 | } 129 | private: 130 | Native_Recorder_Callback* const m_func; 131 | }; 132 | 133 | // NativePolicy listens to callback event and reroute it to the managed Policy instance 134 | class NativePolicy : public NativeMultiWorldTesting::IPolicy 135 | { 136 | public: 137 | NativePolicy(Native_Policy_Callback* func, int index = -1) : m_func(func) 138 | { 139 | m_index = index; 140 | } 141 | 142 | u32 Choose_Action(NativeContext& context) 143 | { 144 | return m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), m_index); 145 | } 146 | 147 | private: 148 | Native_Policy_Callback* const m_func; 149 | int m_index; 150 | }; 151 | 152 | class NativeScorer : public NativeMultiWorldTesting::IScorer 153 | { 154 | public: 155 | NativeScorer(Native_Scorer_Callback* func) : m_func(func) 156 | { 157 | } 158 | 159 | vector Score_Actions(NativeContext& context) 160 | { 161 | float* scores = nullptr; 162 | u32 num_scores = 0; 163 | try 164 | { 165 | m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), &scores, &num_scores); 166 | 167 | // It's ok if scores is null, vector will be empty 168 | vector scores_vector(scores, scores + num_scores); 169 | 170 | return scores_vector; 171 | } 172 | finally 173 | { 174 | delete[] scores; 175 | } 176 | } 177 | private: 178 | Native_Scorer_Callback* const m_func; 179 | }; 180 | 181 | // Triggers callback to the Context instance 182 | generic 183 | public ref class ContextCallback 184 | { 185 | internal: 186 | ContextCallback() 187 | { 188 | contextNumActionsCallback = gcnew ClrContextGetNumActionsCallback(&ContextCallback::InteropInvokeNumActions); 189 | IntPtr contextNumActionsCallbackPtr = 
Marshal::GetFunctionPointerForDelegate(contextNumActionsCallback); 190 | m_num_actions_callback = static_cast(contextNumActionsCallbackPtr.ToPointer()); 191 | } 192 | 193 | Native_Context_Get_Num_Actions_Callback* GetNumActionsCallback() 194 | { 195 | return m_num_actions_callback; 196 | } 197 | 198 | static UInt32 InteropInvokeNumActions(IntPtr contextPtr) 199 | { 200 | GCHandle contextHandle = (GCHandle)contextPtr; 201 | 202 | return ((IVariableActionContext^)contextHandle.Target)->GetNumberOfActions(); 203 | } 204 | 205 | private: 206 | initonly ClrContextGetNumActionsCallback^ contextNumActionsCallback; 207 | 208 | private: 209 | Native_Context_Get_Num_Actions_Callback* m_num_actions_callback; 210 | }; 211 | 212 | // Triggers callback to the Policy instance to choose an action 213 | generic 214 | public ref class PolicyCallback abstract 215 | { 216 | internal: 217 | virtual UInt32 InvokePolicyCallback(Ctx context, int index) = 0; 218 | 219 | PolicyCallback() 220 | { 221 | policyCallback = gcnew ClrPolicyCallback(&PolicyCallback::InteropInvoke); 222 | IntPtr policyCallbackPtr = Marshal::GetFunctionPointerForDelegate(policyCallback); 223 | m_callback = static_cast(policyCallbackPtr.ToPointer()); 224 | m_native_policy = nullptr; 225 | m_native_policies = nullptr; 226 | } 227 | 228 | ~PolicyCallback() 229 | { 230 | delete m_native_policy; 231 | delete m_native_policies; 232 | } 233 | 234 | NativePolicy* GetNativePolicy() 235 | { 236 | if (m_native_policy == nullptr) 237 | { 238 | m_native_policy = new NativePolicy(m_callback); 239 | } 240 | return m_native_policy; 241 | } 242 | 243 | vector>>* GetNativePolicies(int count) 244 | { 245 | if (m_native_policies == nullptr) 246 | { 247 | m_native_policies = new vector>>(); 248 | for (int i = 0; i < count; i++) 249 | { 250 | m_native_policies->push_back(unique_ptr>(new NativePolicy(m_callback, i))); 251 | } 252 | } 253 | 254 | return m_native_policies; 255 | } 256 | 257 | static UInt32 InteropInvoke(IntPtr callbackPtr, 
IntPtr contextPtr, int index) 258 | { 259 | GCHandle callbackHandle = (GCHandle)callbackPtr; 260 | PolicyCallback^ callback = (PolicyCallback^)callbackHandle.Target; 261 | 262 | GCHandle contextHandle = (GCHandle)contextPtr; 263 | Ctx context = (Ctx)contextHandle.Target; 264 | 265 | return callback->InvokePolicyCallback(context, index); 266 | } 267 | 268 | private: 269 | initonly ClrPolicyCallback^ policyCallback; 270 | 271 | private: 272 | NativePolicy* m_native_policy; 273 | vector>>* m_native_policies; 274 | Native_Policy_Callback* m_callback; 275 | }; 276 | 277 | // Triggers callback to the Recorder instance to record interaction data 278 | generic 279 | public ref class RecorderCallback abstract : public ContextCallback 280 | { 281 | internal: 282 | virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) = 0; 283 | 284 | RecorderCallback() 285 | { 286 | recorderCallback = gcnew ClrRecorderCallback(&RecorderCallback::InteropInvoke); 287 | IntPtr recorderCallbackPtr = Marshal::GetFunctionPointerForDelegate(recorderCallback); 288 | Native_Recorder_Callback* callback = static_cast(recorderCallbackPtr.ToPointer()); 289 | m_native_recorder = new NativeRecorder(callback); 290 | } 291 | 292 | ~RecorderCallback() 293 | { 294 | delete m_native_recorder; 295 | } 296 | 297 | NativeRecorder* GetNativeRecorder() 298 | { 299 | return m_native_recorder; 300 | } 301 | 302 | static void InteropInvoke(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKeyPtr) 303 | { 304 | GCHandle mwtHandle = (GCHandle)mwtPtr; 305 | RecorderCallback^ callback = (RecorderCallback^)mwtHandle.Target; 306 | 307 | GCHandle contextHandle = (GCHandle)contextPtr; 308 | Ctx context = (Ctx)contextHandle.Target; 309 | 310 | GCHandle uniqueKeyHandle = (GCHandle)uniqueKeyPtr; 311 | String^ uniqueKey = (String^)uniqueKeyHandle.Target; 312 | 313 | callback->InvokeRecorderCallback(context, action, probability, uniqueKey); 314 | } 315 
| 316 | private: 317 | initonly ClrRecorderCallback^ recorderCallback; 318 | 319 | private: 320 | NativeRecorder* m_native_recorder; 321 | }; 322 | 323 | // Triggers callback to the Scorer instance to score the actions for a context 324 | generic 325 | public ref class ScorerCallback abstract 326 | { 327 | internal: 328 | virtual List^ InvokeScorerCallback(Ctx context) = 0; 329 | 330 | ScorerCallback() 331 | { 332 | scorerCallback = gcnew ClrScorerCallback(&ScorerCallback::InteropInvoke); 333 | IntPtr scorerCallbackPtr = Marshal::GetFunctionPointerForDelegate(scorerCallback); 334 | Native_Scorer_Callback* callback = static_cast(scorerCallbackPtr.ToPointer()); 335 | m_native_scorer = new NativeScorer(callback); 336 | } 337 | 338 | ~ScorerCallback() 339 | { 340 | delete m_native_scorer; 341 | } 342 | 343 | NativeScorer* GetNativeScorer() 344 | { 345 | return m_native_scorer; 346 | } 347 | 348 | static void InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, IntPtr scoresPtr, IntPtr sizePtr) 349 | { 350 | GCHandle callbackHandle = (GCHandle)callbackPtr; 351 | ScorerCallback^ callback = (ScorerCallback^)callbackHandle.Target; 352 | 353 | GCHandle contextHandle = (GCHandle)contextPtr; 354 | Ctx context = (Ctx)contextHandle.Target; 355 | 356 | List^ scoreList = callback->InvokeScorerCallback(context); 357 | 358 | if (scoreList == nullptr || scoreList->Count == 0) 359 | { 360 | return; 361 | } 362 | 363 | u32* num_scores = (u32*)sizePtr.ToPointer(); 364 | *num_scores = (u32)scoreList->Count; 365 | 366 | float* scores = new float[*num_scores]; 367 | for (u32 i = 0; i < *num_scores; i++) 368 | { 369 | scores[i] = scoreList[i]; 370 | } 371 | 372 | float** native_scores = (float**)scoresPtr.ToPointer(); 373 | *native_scores = scores; 374 | } 375 | 376 | private: 377 | initonly ClrScorerCallback^ scorerCallback; 378 | 379 | private: 380 | NativeScorer* m_native_scorer; 381 | }; 382 | 383 | // Triggers callback to the Context instance to perform ToString() operation 384 | generic 
where Ctx : IStringContext 385 | public ref class ToStringCallback 386 | { 387 | internal: 388 | ToStringCallback() 389 | { 390 | toStringCallback = gcnew ClrToStringCallback(&ToStringCallback::InteropInvoke); 391 | IntPtr toStringCallbackPtr = Marshal::GetFunctionPointerForDelegate(toStringCallback); 392 | m_callback = static_cast(toStringCallbackPtr.ToPointer()); 393 | } 394 | 395 | Native_To_String_Callback* GetCallback() 396 | { 397 | return m_callback; 398 | } 399 | 400 | static void InteropInvoke(IntPtr contextPtr, IntPtr stringPtr) 401 | { 402 | GCHandle contextHandle = (GCHandle)contextPtr; 403 | Ctx context = (Ctx)contextHandle.Target; 404 | 405 | string* out_string = (string*)stringPtr.ToPointer(); 406 | *out_string = marshal_as(context->ToString()); 407 | } 408 | 409 | private: 410 | initonly ClrToStringCallback^ toStringCallback; 411 | 412 | private: 413 | Native_To_String_Callback* m_callback; 414 | }; 415 | 416 | } -------------------------------------------------------------------------------- /explore.cpp: -------------------------------------------------------------------------------- 1 | // explore.cpp : Timing code to measure performance of MWT Explorer library 2 | 3 | #include "MWTExplorer.h" 4 | #include 5 | #include 6 | #include 7 | 8 | using namespace std; 9 | using namespace std::chrono; 10 | 11 | using namespace MultiWorldTesting; 12 | 13 | class MySimplePolicy : public IPolicy 14 | { 15 | public: 16 | u32 Choose_Action(SimpleContext& context) 17 | { 18 | return (u32)1; 19 | } 20 | }; 21 | 22 | const u32 num_actions = 10; 23 | 24 | void Clock_Explore() 25 | { 26 | float epsilon = .2f; 27 | string unique_key = "key"; 28 | int num_features = 1000; 29 | int num_iter = 10000; 30 | int num_warmup = 100; 31 | int num_interactions = 1; 32 | 33 | // pre-create features 34 | vector features; 35 | for (int i = 0; i < num_features; i++) 36 | { 37 | Feature f = {0.5, i+1}; 38 | features.push_back(f); 39 | } 40 | 41 | long long time_init = 0, 
time_choose = 0; 42 | for (int iter = 0; iter < num_iter + num_warmup; iter++) 43 | { 44 | high_resolution_clock::time_point t1 = high_resolution_clock::now(); 45 | StringRecorder recorder; 46 | MwtExplorer mwt("test", recorder); 47 | MySimplePolicy default_policy; 48 | EpsilonGreedyExplorer explorer(default_policy, epsilon, num_actions); 49 | high_resolution_clock::time_point t2 = high_resolution_clock::now(); 50 | time_init += iter < num_warmup ? 0 : duration_cast(t2 - t1).count(); 51 | 52 | t1 = high_resolution_clock::now(); 53 | SimpleContext appContext(features); 54 | for (int i = 0; i < num_interactions; i++) 55 | { 56 | mwt.Choose_Action(explorer, unique_key, appContext); 57 | } 58 | t2 = high_resolution_clock::now(); 59 | time_choose += iter < num_warmup ? 0 : duration_cast(t2 - t1).count(); 60 | } 61 | 62 | cout << "# iterations: " << num_iter << ", # interactions: " << num_interactions << ", # context features: " << num_features << endl; 63 | cout << "--- PER ITERATION ---" << endl; 64 | cout << "Init: " << (double)time_init / num_iter << " micro" << endl; 65 | cout << "Choose Action: " << (double)time_choose / (num_iter * num_interactions) << " micro" << endl; 66 | cout << "--- TOTAL TIME ---: " << (time_init + time_choose) << " micro" << endl; 67 | } 68 | 69 | int main(int argc, char* argv[]) 70 | { 71 | Clock_Explore(); 72 | return 0; 73 | } 74 | -------------------------------------------------------------------------------- /explore.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio 2013 4 | VisualStudioVersion = 12.0.30723.0 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "explore", "explore.vcxproj", "{FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}" 7 | EndProject 8 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "explore_static", "static\explore_static.vcxproj", 
"{ACE47E98-488C-4CDF-B9F1-36337B2855AD}" 9 | EndProject 10 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "explore_clr", "clr\explore_clr.vcxproj", "{8400DA16-1F46-4A31-A126-BBE16F62BFD7}" 11 | EndProject 12 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "explore_tests", "tests\explore_tests.vcxproj", "{5AE3AA40-BEB0-4979-8166-3B885172C430}" 13 | EndProject 14 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ExploreTests", "tests\ExploreTests.csproj", "{CB0C6B20-560C-4822-8EF6-DA999A64B542}" 15 | EndProject 16 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ExploreSample", "ExploreSample\ExploreSample.csproj", "{7081D542-AE64-485D-9087-79194B958499}" 17 | EndProject 18 | Global 19 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 20 | Debug|Win32 = Debug|Win32 21 | Debug|x64 = Debug|x64 22 | Release|Win32 = Release|Win32 23 | Release|x64 = Release|x64 24 | EndGlobalSection 25 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 26 | {FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Debug|Win32.ActiveCfg = Debug|Win32 27 | {FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Debug|Win32.Build.0 = Debug|Win32 28 | {FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Debug|x64.ActiveCfg = Debug|x64 29 | {FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Debug|x64.Build.0 = Debug|x64 30 | {FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Release|Win32.ActiveCfg = Release|Win32 31 | {FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Release|Win32.Build.0 = Release|Win32 32 | {FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Release|x64.ActiveCfg = Release|x64 33 | {FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Release|x64.Build.0 = Release|x64 34 | {ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Debug|Win32.ActiveCfg = Debug|Win32 35 | {ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Debug|Win32.Build.0 = Debug|Win32 36 | {ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Debug|x64.ActiveCfg = Debug|x64 37 | {ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Debug|x64.Build.0 = Debug|x64 38 | {ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Release|Win32.ActiveCfg = 
Release|Win32 39 | {ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Release|Win32.Build.0 = Release|Win32 40 | {ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Release|x64.ActiveCfg = Release|x64 41 | {ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Release|x64.Build.0 = Release|x64 42 | {8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Debug|Win32.ActiveCfg = Debug|Win32 43 | {8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Debug|Win32.Build.0 = Debug|Win32 44 | {8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Debug|x64.ActiveCfg = Debug|x64 45 | {8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Debug|x64.Build.0 = Debug|x64 46 | {8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Release|Win32.ActiveCfg = Release|Win32 47 | {8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Release|Win32.Build.0 = Release|Win32 48 | {8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Release|x64.ActiveCfg = Release|x64 49 | {8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Release|x64.Build.0 = Release|x64 50 | {5AE3AA40-BEB0-4979-8166-3B885172C430}.Debug|Win32.ActiveCfg = Debug|Win32 51 | {5AE3AA40-BEB0-4979-8166-3B885172C430}.Debug|Win32.Build.0 = Debug|Win32 52 | {5AE3AA40-BEB0-4979-8166-3B885172C430}.Debug|x64.ActiveCfg = Debug|x64 53 | {5AE3AA40-BEB0-4979-8166-3B885172C430}.Debug|x64.Build.0 = Debug|x64 54 | {5AE3AA40-BEB0-4979-8166-3B885172C430}.Release|Win32.ActiveCfg = Release|Win32 55 | {5AE3AA40-BEB0-4979-8166-3B885172C430}.Release|Win32.Build.0 = Release|Win32 56 | {5AE3AA40-BEB0-4979-8166-3B885172C430}.Release|x64.ActiveCfg = Release|x64 57 | {5AE3AA40-BEB0-4979-8166-3B885172C430}.Release|x64.Build.0 = Release|x64 58 | {CB0C6B20-560C-4822-8EF6-DA999A64B542}.Debug|Win32.ActiveCfg = Debug|x86 59 | {CB0C6B20-560C-4822-8EF6-DA999A64B542}.Debug|Win32.Build.0 = Debug|x86 60 | {CB0C6B20-560C-4822-8EF6-DA999A64B542}.Debug|x64.ActiveCfg = Debug|x64 61 | {CB0C6B20-560C-4822-8EF6-DA999A64B542}.Debug|x64.Build.0 = Debug|x64 62 | {CB0C6B20-560C-4822-8EF6-DA999A64B542}.Release|Win32.ActiveCfg = Release|x86 63 | {CB0C6B20-560C-4822-8EF6-DA999A64B542}.Release|Win32.Build.0 = Release|x86 64 | 
{CB0C6B20-560C-4822-8EF6-DA999A64B542}.Release|x64.ActiveCfg = Release|x64 65 | {CB0C6B20-560C-4822-8EF6-DA999A64B542}.Release|x64.Build.0 = Release|x64 66 | {7081D542-AE64-485D-9087-79194B958499}.Debug|Win32.ActiveCfg = Debug|x86 67 | {7081D542-AE64-485D-9087-79194B958499}.Debug|Win32.Build.0 = Debug|x86 68 | {7081D542-AE64-485D-9087-79194B958499}.Debug|x64.ActiveCfg = Debug|x64 69 | {7081D542-AE64-485D-9087-79194B958499}.Debug|x64.Build.0 = Debug|x64 70 | {7081D542-AE64-485D-9087-79194B958499}.Release|Win32.ActiveCfg = Release|x86 71 | {7081D542-AE64-485D-9087-79194B958499}.Release|Win32.Build.0 = Release|x86 72 | {7081D542-AE64-485D-9087-79194B958499}.Release|x64.ActiveCfg = Release|x64 73 | {7081D542-AE64-485D-9087-79194B958499}.Release|x64.Build.0 = Release|x64 74 | EndGlobalSection 75 | GlobalSection(SolutionProperties) = preSolution 76 | HideSolutionNode = FALSE 77 | EndGlobalSection 78 | EndGlobal 79 | -------------------------------------------------------------------------------- /explore.vcxproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | Debug 6 | Win32 7 | 8 | 9 | Debug 10 | x64 11 | 12 | 13 | Release 14 | Win32 15 | 16 | 17 | Release 18 | x64 19 | 20 | 21 | 22 | {FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4} 23 | Win32Proj 24 | vw_explore 25 | explore 26 | 27 | 28 | 29 | Application 30 | true 31 | v120 32 | Unicode 33 | 34 | 35 | Application 36 | true 37 | v120 38 | Unicode 39 | 40 | 41 | Application 42 | false 43 | v120 44 | true 45 | Unicode 46 | 47 | 48 | Application 49 | false 50 | v120 51 | true 52 | Unicode 53 | 54 | 55 | c:\boost\x64\include\boost-1_56 56 | c:\boost\x64\lib 57 | ..\..\zlib-1.2.8 58 | $(ZlibIncludeDir)\contrib\vstudio\vc10\x64\ZlibStat$(Configuration) 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | true 78 | 79 | 80 | true 81 | 82 | 83 | false 84 | 85 | 86 | false 87 | 88 | 89 | 90 | 91 | 92 | Level3 93 | Disabled 94 | 
WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions) 95 | static; 96 | 97 | 98 | Console 99 | true 100 | 101 | 102 | 103 | 104 | 105 | 106 | Level3 107 | Disabled 108 | WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions) 109 | static; 110 | 111 | 112 | Console 113 | true 114 | 115 | 116 | 117 | 118 | Level3 119 | 120 | 121 | MaxSpeed 122 | true 123 | true 124 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 125 | true 126 | static; 127 | 128 | 129 | Console 130 | true 131 | true 132 | true 133 | 134 | 135 | 136 | 137 | Level3 138 | 139 | 140 | MaxSpeed 141 | true 142 | true 143 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 144 | true 145 | static; 146 | 147 | 148 | Console 149 | true 150 | true 151 | true 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | true 160 | true 161 | true 162 | true 163 | 164 | 165 | 166 | 167 | 168 | {ace47e98-488c-4cdf-b9f1-36337b2855ad} 169 | 170 | 171 | 172 | 173 | 174 | -------------------------------------------------------------------------------- /explore.vcxproj.filters: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF} 6 | cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx 7 | 8 | 9 | {93995380-89BD-4b04-88EB-625FBE52EBFB} 10 | h;hpp;hxx;hm;inl;inc;xsd 11 | 12 | 13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} 14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | Source Files 23 | 24 | 25 | Source Files 26 | 27 | 28 | -------------------------------------------------------------------------------- /explore_sample.cpp: -------------------------------------------------------------------------------- 1 | // vw_explore.cpp : Defines the entry point for the console application. 
2 | // 3 | 4 | #include "MWTExplorer.h" 5 | #include 6 | #include 7 | #include 8 | 9 | using namespace std; 10 | using namespace std::chrono; 11 | using namespace MultiWorldTesting; 12 | 13 | /// Example of a custom context. 14 | class MyContext 15 | { 16 | 17 | }; 18 | 19 | /// Example of a custom policy which implements the IPolicy, 20 | /// declaring that this policy only interacts with MyContext objects. 21 | class MyPolicy : public IPolicy 22 | { 23 | public: 24 | u32 Choose_Action(MyContext& context) 25 | { 26 | // Always returns the same action regardless of context 27 | return (u32)1; 28 | } 29 | }; 30 | 31 | /// Example of a custom policy which implements the IPolicy, 32 | /// declaring that this policy only interacts with SimpleContext objects. 33 | class MySimplePolicy : public IPolicy 34 | { 35 | public: 36 | u32 Choose_Action(SimpleContext& context) 37 | { 38 | // Always returns the same action regardless of context 39 | return (u32)1; 40 | } 41 | }; 42 | 43 | /// Example of a custom scorer which implements the IScorer, 44 | /// declaring that this scorer only interacts with MyContext objects. 45 | class MyScorer : public IScorer 46 | { 47 | public: 48 | MyScorer(u32 num_actions) : m_num_actions(num_actions) 49 | { 50 | 51 | } 52 | vector Score_Actions(MyContext& context) 53 | { 54 | vector scores; 55 | for (size_t i = 0; i < m_num_actions; i++) 56 | { 57 | // Gives every action the same score (which results in a uniform distribution). 58 | scores.push_back(.1f); 59 | } 60 | return scores; 61 | } 62 | private: 63 | u32 m_num_actions; 64 | }; 65 | 66 | /// 67 | /// Represents a tuple . 68 | /// 69 | template 70 | struct MyInteraction 71 | { 72 | Ctx Context; 73 | u32 Action; 74 | float Probability; 75 | string Unique_Key; 76 | }; 77 | 78 | /// Example of a custom recorder which implements the IRecorder, 79 | /// declaring that this recorder only interacts with MyContext objects. 
80 | class MyRecorder : public IRecorder 81 | { 82 | public: 83 | virtual void Record(MyContext& context, u32 action, float probability, string unique_key) 84 | { 85 | // Stores the tuple internally in a vector that could be used later for other purposes. 86 | m_interactions.push_back({ context, action, probability, unique_key }); 87 | } 88 | private: 89 | vector> m_interactions; 90 | }; 91 | 92 | int main(int argc, char* argv[]) 93 | { 94 | if (argc < 2) 95 | { 96 | cerr << "arguments: {greedy,tau-first,bootstrap,softmax,generic}" << endl; 97 | exit(1); 98 | } 99 | 100 | // Arguments for individual explorers 101 | if (strcmp(argv[1], "greedy") == 0) 102 | { 103 | // Initialize Epsilon-Greedy explore algorithm using MyPolicy 104 | 105 | // Creates a recorder of built-in StringRecorder type for string serialization 106 | StringRecorder recorder; 107 | 108 | // Creates a policy that interacts with SimpleContext type 109 | MySimplePolicy default_policy; 110 | 111 | // Creates an MwtExplorer instance using the recorder above 112 | MwtExplorer mwt("appid", recorder); 113 | 114 | u32 num_actions = 10; 115 | float epsilon = .2f; 116 | // Creates an Epsilon-Greedy explorer using the specified settings 117 | EpsilonGreedyExplorer explorer(default_policy, epsilon, num_actions); 118 | 119 | // Creates a context of built-in SimpleContext type 120 | vector features; 121 | features.push_back({ 0.5f, 1 }); 122 | features.push_back({ 1.3f, 11 }); 123 | features.push_back({ -.95f, 413 }); 124 | 125 | SimpleContext context(features); 126 | 127 | // Performs exploration by passing an instance of the Epsilon-Greedy exploration algorithm into MwtExplorer 128 | // using a sample string to uniquely identify this event 129 | string unique_key = "eventid"; 130 | u32 action = mwt.Choose_Action(explorer, unique_key, context); 131 | 132 | cout << "Chosen action = " << action << endl; 133 | cout << "Exploration record = " << recorder.Get_Recording(); 134 | } 135 | else if (strcmp(argv[1], 
"tau-first") == 0) 136 | { 137 | // Initialize Tau-First explore algorithm using MyPolicy 138 | MyRecorder recorder; 139 | MwtExplorer mwt("appid", recorder); 140 | 141 | int num_actions = 10; 142 | u32 tau = 5; 143 | MyPolicy default_policy; 144 | TauFirstExplorer explorer(default_policy, tau, num_actions); 145 | MyContext ctx; 146 | string unique_key = "eventid"; 147 | u32 action = mwt.Choose_Action(explorer, unique_key, ctx); 148 | 149 | cout << "action = " << action << endl; 150 | } 151 | else if (strcmp(argv[1], "bootstrap") == 0) 152 | { 153 | // Initialize Bootstrap explore algorithm using MyPolicy 154 | MyRecorder recorder; 155 | MwtExplorer mwt("appid", recorder); 156 | 157 | u32 num_bags = 2; 158 | 159 | // Create a vector of smart pointers to default policies using the built-in type PolicyPtr 160 | vector>> policy_functions; 161 | for (size_t i = 0; i < num_bags; i++) 162 | { 163 | policy_functions.push_back(unique_ptr>(new MyPolicy())); 164 | } 165 | int num_actions = 10; 166 | BootstrapExplorer explorer(policy_functions, num_actions); 167 | MyContext ctx; 168 | string unique_key = "eventid"; 169 | u32 action = mwt.Choose_Action(explorer, unique_key, ctx); 170 | 171 | cout << "action = " << action << endl; 172 | } 173 | else if (strcmp(argv[1], "softmax") == 0) 174 | { 175 | // Initialize Softmax explore algorithm using MyScorer 176 | MyRecorder recorder; 177 | MwtExplorer mwt("salt", recorder); 178 | 179 | u32 num_actions = 10; 180 | MyScorer scorer(num_actions); 181 | float lambda = 0.5f; 182 | SoftmaxExplorer explorer(scorer, lambda, num_actions); 183 | 184 | MyContext ctx; 185 | string unique_key = "eventid"; 186 | u32 action = mwt.Choose_Action(explorer, unique_key, ctx); 187 | 188 | cout << "action = " << action << endl; 189 | } 190 | else if (strcmp(argv[1], "generic") == 0) 191 | { 192 | // Initialize Generic explore algorithm using MyScorer 193 | MyRecorder recorder; 194 | MwtExplorer mwt("appid", recorder); 195 | 196 | int num_actions = 10; 
197 | MyScorer scorer(num_actions); 198 | GenericExplorer explorer(scorer, num_actions); 199 | MyContext ctx; 200 | string unique_key = "eventid"; 201 | u32 action = mwt.Choose_Action(explorer, unique_key, ctx); 202 | 203 | cout << "action = " << action << endl; 204 | } 205 | else 206 | { 207 | cerr << "unknown exploration type: " << argv[1] << endl; 208 | exit(1); 209 | } 210 | 211 | return 0; 212 | } 213 | -------------------------------------------------------------------------------- /mwt.chm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mwt-ds-explore/6e3a7124d7adbd51434a28b1d48d0d397e1fd3fd/mwt.chm -------------------------------------------------------------------------------- /static/MWTExplorer.h: -------------------------------------------------------------------------------- 1 | // 2 | // Main interface for clients of the Multiworld testing (MWT) service. 3 | // 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #ifdef MANAGED_CODE 19 | #define PORTING_INTERFACE public 20 | #define MWT_NAMESPACE namespace NativeMultiWorldTesting 21 | #else 22 | #define PORTING_INTERFACE private 23 | #define MWT_NAMESPACE namespace MultiWorldTesting 24 | #endif 25 | 26 | using namespace std; 27 | 28 | #include "utility.h" 29 | 30 | /** \defgroup MultiWorldTestingCpp 31 | \brief C++ implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp 32 | */ 33 | 34 | /*! 35 | * \addtogroup MultiWorldTestingCpp 36 | * @{ 37 | */ 38 | 39 | //! Interface for C++ version of Multiworld Testing library. 40 | //! 
For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp 41 | MWT_NAMESPACE { 42 | 43 | // Forward declarations 44 | template 45 | class IRecorder; 46 | template 47 | class IExplorer; 48 | 49 | /// 50 | /// The top-level MwtExplorer class. Using this enables principled and efficient exploration 51 | /// over a set of possible actions, and ensures that the right bits are recorded. 52 | /// 53 | template 54 | class MwtExplorer 55 | { 56 | public: 57 | /// 58 | /// Constructor 59 | /// 60 | /// @param appid This should be unique to your experiment or you risk nasty correlation bugs. 61 | /// @param recorder A user-specified class for recording the appropriate bits for use in evaluation and learning. 62 | /// 63 | MwtExplorer(std::string app_id, IRecorder& recorder) : m_recorder(recorder) 64 | { 65 | m_app_id = HashUtils::Compute_Id_Hash(app_id); 66 | } 67 | 68 | /// 69 | /// Chooses an action by invoking an underlying exploration algorithm. This should be a 70 | /// drop-in replacement for any existing policy function. 71 | /// 72 | /// @param explorer An existing exploration algorithm (one of the below) which uses the default policy as a callback. 73 | /// @param unique_key A unique identifier for the experimental unit. This could be a user id, a session id, etc.. 74 | /// @param context The context upon which a decision is made. See SimpleContext below for an example. 
75 | /// 76 | u32 Choose_Action(IExplorer& explorer, string unique_key, Ctx& context) 77 | { 78 | u64 seed = HashUtils::Compute_Id_Hash(unique_key); 79 | 80 | std::tuple action_probability_log_tuple = explorer.Choose_Action(seed + m_app_id, context); 81 | 82 | u32 action = std::get<0>(action_probability_log_tuple); 83 | float prob = std::get<1>(action_probability_log_tuple); 84 | 85 | if (std::get<2>(action_probability_log_tuple)) 86 | { 87 | m_recorder.Record(context, action, prob, unique_key); 88 | } 89 | 90 | return action; 91 | } 92 | 93 | private: 94 | u64 m_app_id; 95 | IRecorder& m_recorder; 96 | }; 97 | 98 | /// 99 | /// Exposes a method to record exploration data based on generic contexts. Exploration data 100 | /// is specified as a set of tuples as described below. An 101 | /// application passes an IRecorder object to the @MwtExplorer constructor. See 102 | /// @StringRecorder for a sample IRecorder object. 103 | /// 104 | template 105 | class IRecorder 106 | { 107 | public: 108 | /// 109 | /// Records the exploration data associated with a given decision. 110 | /// This implementation should be thread-safe if multithreading is needed. 111 | /// 112 | /// @param context A user-defined context for the decision 113 | /// @param action The action chosen by an exploration algorithm given context 114 | /// @param probability The probability the exploration algorithm chose said action 115 | /// @param unique_key A user-defined unique identifer for the decision 116 | /// 117 | virtual void Record(Ctx& context, u32 action, float probability, string unique_key) = 0; 118 | virtual ~IRecorder() { } 119 | }; 120 | 121 | /// 122 | /// Exposes a method to choose an action given a generic context, and obtain the relevant 123 | /// exploration bits. Invokes IPolicy::Choose_Action internally. Do not implement this 124 | /// interface yourself: instead, use the various exploration algorithms below, which 125 | /// implement it for you. 
126 | /// 127 | template 128 | class IExplorer 129 | { 130 | public: 131 | /// 132 | /// Determines the action to take and the probability with which it was chosen, for a 133 | /// given context. 134 | /// 135 | /// @param salted_seed A PRG seed based on a unique id information provided by the user 136 | /// @param context A user-defined context for the decision 137 | /// @returns The action to take, the probability it was chosen, and a flag indicating 138 | /// whether to record this decision 139 | /// 140 | virtual std::tuple Choose_Action(u64 salted_seed, Ctx& context) = 0; 141 | virtual void Enable_Explore(bool explore) = 0; 142 | virtual ~IExplorer() { } 143 | }; 144 | 145 | /// 146 | /// Exposes a method to choose an action given a generic context. IPolicy objects are 147 | /// passed to (and invoked by) exploration algorithms to specify the default policy behavior. 148 | /// 149 | template 150 | class IPolicy 151 | { 152 | public: 153 | /// 154 | /// Determines the action to take for a given context. 155 | /// This implementation should be thread-safe if multithreading is needed. 156 | /// 157 | /// @param context A user-defined context for the decision 158 | /// @returns The action to take (1-based index) 159 | /// 160 | virtual u32 Choose_Action(Ctx& context) = 0; 161 | virtual ~IPolicy() { } 162 | }; 163 | 164 | /// 165 | /// Exposes a method for specifying a score (weight) for each action given a generic context. 166 | /// 167 | template 168 | class IScorer 169 | { 170 | public: 171 | /// 172 | /// Determines the score of each action for a given context. 173 | /// This implementation should be thread-safe if multithreading is needed. 
174 | /// 175 | /// @param context A user-defined context for the decision 176 | /// @returns A vector of scores indexed by action (1-based) 177 | /// 178 | virtual vector Score_Actions(Ctx& context) = 0; 179 | virtual ~IScorer() { } 180 | }; 181 | 182 | /// 183 | /// Represents a context interface with variable number of actions which is 184 | /// enforced if exploration algorithm is initialized in variable number of actions mode. 185 | /// 186 | class IVariableActionContext 187 | { 188 | public: 189 | /// 190 | /// Gets the number of actions for the current context. 191 | /// 192 | /// @returns The number of actions available for the current context. 193 | /// 194 | virtual u32 Get_Number_Of_Actions() = 0; 195 | virtual ~IVariableActionContext() { } 196 | }; 197 | 198 | template 199 | class IConsumePolicy 200 | { 201 | public: 202 | virtual void Update_Policy(IPolicy& new_policy) = 0; 203 | virtual ~IConsumePolicy() { } 204 | }; 205 | 206 | template 207 | class IConsumePolicies 208 | { 209 | public: 210 | virtual void Update_Policy(vector>>& new_policy_functions) = 0; 211 | virtual ~IConsumePolicies() { } 212 | }; 213 | 214 | template 215 | class IConsumeScorer 216 | { 217 | public: 218 | virtual void Update_Scorer(IScorer& new_policy) = 0; 219 | virtual ~IConsumeScorer() { } 220 | }; 221 | 222 | /// 223 | /// A sample recorder class that converts the exploration tuple into string format. 
224 | /// 225 | template 226 | struct StringRecorder : public IRecorder 227 | { 228 | void Record(Ctx& context, u32 action, float probability, string unique_key) 229 | { 230 | // Implicitly enforce To_String() API on the context 231 | m_recording.append(to_string((unsigned long long)action)); 232 | m_recording.append(" ", 1); 233 | m_recording.append(unique_key); 234 | m_recording.append(" ", 1); 235 | 236 | char prob_str[10] = { 0 }; 237 | int x = (int)probability; 238 | int d = (int)(fabs(probability - x) * 100000); 239 | sprintf_s(prob_str, 10 * sizeof(char), "%d.%05d", x, d); 240 | m_recording.append(prob_str); 241 | 242 | m_recording.append(" | ", 3); 243 | m_recording.append(context.To_String()); 244 | m_recording.append("\n"); 245 | } 246 | 247 | // Gets the content of the recording so far as a string and optionally clears internal content. 248 | string Get_Recording(bool flush = true) 249 | { 250 | if (!flush) 251 | { 252 | return m_recording; 253 | } 254 | string recording = m_recording; 255 | m_recording.clear(); 256 | return recording; 257 | } 258 | 259 | private: 260 | string m_recording; 261 | }; 262 | 263 | /// 264 | /// Represents a feature in a sparse array. 265 | /// 266 | struct Feature 267 | { 268 | float Value; 269 | u32 Id; 270 | 271 | bool operator==(Feature other_feature) 272 | { 273 | return Id == other_feature.Id; 274 | } 275 | }; 276 | 277 | /// 278 | /// A sample context class that stores a vector of Features. 
279 | /// 280 | class SimpleContext 281 | { 282 | public: 283 | SimpleContext(vector& features) : 284 | m_features(features) 285 | { } 286 | 287 | vector& Get_Features() 288 | { 289 | return m_features; 290 | } 291 | 292 | string To_String() 293 | { 294 | string out_string; 295 | const size_t strlen = 35; 296 | char feature_str[strlen] = { 0 }; 297 | for (size_t i = 0; i < m_features.size(); i++) 298 | { 299 | int chars; 300 | if (i == 0) 301 | { 302 | chars = sprintf_s(feature_str, strlen, "%d:", m_features[i].Id); 303 | } 304 | else 305 | { 306 | chars = sprintf_s(feature_str, strlen, " %d:", m_features[i].Id); 307 | } 308 | NumberUtils::print_float(feature_str + chars, strlen-chars, m_features[i].Value); 309 | out_string.append(feature_str); 310 | } 311 | return out_string; 312 | } 313 | 314 | private: 315 | vector& m_features; 316 | }; 317 | 318 | template 319 | static u32 Get_Variable_Number_Of_Actions(Ctx& context, u32 default_num_actions) 320 | { 321 | u32 num_actions = default_num_actions; 322 | if (num_actions == UINT_MAX) 323 | { 324 | num_actions = ((IVariableActionContext*)(&context))->Get_Number_Of_Actions(); 325 | if (num_actions < 1) 326 | { 327 | throw std::invalid_argument("Number of actions must be at least 1."); 328 | } 329 | } 330 | return num_actions; 331 | } 332 | 333 | /// 334 | /// The epsilon greedy exploration algorithm. This is a good choice if you have no idea 335 | /// which actions should be preferred. Epsilon greedy is also computationally cheap. 336 | /// 337 | template 338 | class EpsilonGreedyExplorer : public IExplorer, public IConsumePolicy 339 | { 340 | public: 341 | /// 342 | /// The constructor is the only public member, because this should be used with the MwtExplorer. 343 | /// 344 | /// @param default_policy A default function which outputs an action given a context. 345 | /// @param epsilon The probability of a random exploration. 346 | /// @param num_actions The number of actions to randomize over. 
347 | /// 348 | EpsilonGreedyExplorer(IPolicy& default_policy, float epsilon, u32 num_actions) : 349 | m_default_policy(default_policy), m_epsilon(epsilon), m_num_actions(num_actions), m_explore(true) 350 | { 351 | if (m_num_actions < 1) 352 | { 353 | throw std::invalid_argument("Number of actions must be at least 1."); 354 | } 355 | 356 | if (m_epsilon < 0 || m_epsilon > 1) 357 | { 358 | throw std::invalid_argument("Epsilon must be between 0 and 1."); 359 | } 360 | } 361 | 362 | /// 363 | /// Initializes an epsilon greedy explorer with variable number of actions. 364 | /// 365 | /// @param default_policy A default function which outputs an action given a context. 366 | /// @param epsilon The probability of a random exploration. 367 | /// 368 | EpsilonGreedyExplorer(IPolicy& default_policy, float epsilon) : 369 | m_default_policy(default_policy), m_epsilon(epsilon), m_num_actions(UINT_MAX), m_explore(true) 370 | { 371 | if (m_epsilon < 0 || m_epsilon > 1) 372 | { 373 | throw std::invalid_argument("Epsilon must be between 0 and 1."); 374 | } 375 | static_assert(std::is_base_of::value, "The provided context does not implement variable-action interface."); 376 | } 377 | 378 | void Update_Policy(IPolicy& new_policy) 379 | { 380 | m_default_policy = new_policy; 381 | } 382 | 383 | void Enable_Explore(bool explore) 384 | { 385 | m_explore = explore; 386 | } 387 | 388 | private: 389 | std::tuple Choose_Action(u64 salted_seed, Ctx& context) 390 | { 391 | u32 num_actions = ::Get_Variable_Number_Of_Actions(context, m_num_actions); 392 | 393 | PRG::prg random_generator(salted_seed); 394 | 395 | // Invoke the default policy function to get the action 396 | u32 chosen_action = m_default_policy.Choose_Action(context); 397 | 398 | if (chosen_action == 0 || chosen_action > num_actions) 399 | { 400 | throw std::invalid_argument("Action chosen by default policy is not within valid range."); 401 | } 402 | 403 | float epsilon = m_explore ? 
m_epsilon : 0.f; 404 | 405 | float action_probability = 0.f; 406 | float base_probability = epsilon / num_actions; // uniform probability 407 | 408 | // TODO: check this random generation 409 | if (random_generator.Uniform_Unit_Interval() < 1.f - epsilon) 410 | { 411 | action_probability = 1.f - epsilon + base_probability; 412 | } 413 | else 414 | { 415 | // Get uniform random action ID 416 | u32 actionId = random_generator.Uniform_Int(1, num_actions); 417 | 418 | if (actionId == chosen_action) 419 | { 420 | // IF it matches the one chosen by the default policy 421 | // then increase the probability 422 | action_probability = 1.f - epsilon + base_probability; 423 | } 424 | else 425 | { 426 | // Otherwise it's just the uniform probability 427 | action_probability = base_probability; 428 | } 429 | chosen_action = actionId; 430 | } 431 | 432 | return std::tuple(chosen_action, action_probability, true); 433 | } 434 | 435 | private: 436 | IPolicy& m_default_policy; 437 | const float m_epsilon; 438 | bool m_explore; 439 | const u32 m_num_actions; 440 | }; 441 | 442 | /// 443 | /// In some cases, different actions have a different scores, and you would prefer to 444 | /// choose actions with large scores. Softmax allows you to do that. 445 | /// 446 | template 447 | class SoftmaxExplorer : public IExplorer, public IConsumeScorer 448 | { 449 | public: 450 | /// 451 | /// The constructor is the only public member, because this should be used with the MwtExplorer. 452 | /// 453 | /// @param default_scorer A function which outputs a score for each action. 454 | /// @param lambda lambda = 0 implies uniform distribution. Large lambda is equivalent to a max. 455 | /// @param num_actions The number of actions to randomize over. 
456 | /// 457 | SoftmaxExplorer(IScorer& default_scorer, float lambda, u32 num_actions) : 458 | m_default_scorer(default_scorer), m_lambda(lambda), m_num_actions(num_actions), m_explore(true) 459 | { 460 | if (m_num_actions < 1) 461 | { 462 | throw std::invalid_argument("Number of actions must be at least 1."); 463 | } 464 | } 465 | 466 | /// 467 | /// Initializes a softmax explorer with variable number of actions. 468 | /// 469 | /// @param default_scorer A function which outputs a score for each action. 470 | /// @param lambda lambda = 0 implies uniform distribution. Large lambda is equivalent to a max. 471 | /// 472 | SoftmaxExplorer(IScorer& default_scorer, float lambda) : 473 | m_default_scorer(default_scorer), m_lambda(lambda), m_num_actions(UINT_MAX), m_explore(true) 474 | { 475 | static_assert(std::is_base_of::value, "The provided context does not implement variable-action interface."); 476 | } 477 | 478 | void Update_Scorer(IScorer& new_scorer) 479 | { 480 | m_default_scorer = new_scorer; 481 | } 482 | 483 | void Enable_Explore(bool explore) 484 | { 485 | m_explore = explore; 486 | } 487 | 488 | private: 489 | std::tuple Choose_Action(u64 salted_seed, Ctx& context) 490 | { 491 | u32 num_actions = ::Get_Variable_Number_Of_Actions(context, m_num_actions); 492 | 493 | PRG::prg random_generator(salted_seed); 494 | 495 | // Invoke the default scorer function 496 | vector scores = m_default_scorer.Score_Actions(context); 497 | u32 num_scores = (u32)scores.size(); 498 | if (num_scores != num_actions) 499 | { 500 | throw std::invalid_argument("The number of scores returned by the scorer must equal number of actions"); 501 | } 502 | 503 | u32 i = 0; 504 | 505 | float max_score = -FLT_MAX; 506 | for (i = 0; i < num_scores; i++) 507 | { 508 | if (max_score < scores[i]) 509 | { 510 | max_score = scores[i]; 511 | } 512 | } 513 | 514 | float action_probability = 0.f; 515 | u32 action_index = 0; 516 | if (m_explore) 517 | { 518 | // Create a normalized exponential 
distribution based on the returned scores 519 | for (i = 0; i < num_scores; i++) 520 | { 521 | scores[i] = exp(m_lambda * (scores[i] - max_score)); 522 | } 523 | 524 | // Create a discrete_distribution based on the returned weights. This class handles the 525 | // case where the sum of the weights is < or > 1, by normalizing agains the sum. 526 | float total = 0.f; 527 | for (size_t i = 0; i < num_scores; i++) 528 | total += scores[i]; 529 | 530 | float draw = random_generator.Uniform_Unit_Interval(); 531 | 532 | float sum = 0.f; 533 | action_probability = 0.f; 534 | action_index = num_scores - 1; 535 | for (u32 i = 0; i < num_scores; i++) 536 | { 537 | scores[i] = scores[i] / total; 538 | sum += scores[i]; 539 | if (sum > draw) 540 | { 541 | action_index = i; 542 | action_probability = scores[i]; 543 | break; 544 | } 545 | } 546 | } 547 | else 548 | { 549 | float max_score = 0.f; 550 | for (size_t i = 0; i < num_scores; i++) 551 | { 552 | if (max_score < scores[i]) 553 | { 554 | max_score = scores[i]; 555 | action_index = (u32)i; 556 | } 557 | } 558 | action_probability = 1.f; // Set to 1 since we always pick the highest one. 559 | } 560 | 561 | // action id is one-based 562 | return std::tuple(action_index + 1, action_probability, true); 563 | } 564 | 565 | private: 566 | IScorer& m_default_scorer; 567 | bool m_explore; 568 | const float m_lambda; 569 | const u32 m_num_actions; 570 | }; 571 | 572 | /// 573 | /// GenericExplorer provides complete flexibility. You can create any 574 | /// distribution over actions desired, and it will draw from that. 575 | /// 576 | template 577 | class GenericExplorer : public IExplorer, public IConsumeScorer 578 | { 579 | public: 580 | /// 581 | /// The constructor is the only public member, because this should be used with the MwtExplorer. 582 | /// 583 | /// @param default_scorer A function which outputs the probability of each action. 584 | /// @param num_actions The number of actions to randomize over. 
585 | /// 586 | GenericExplorer(IScorer& default_scorer, u32 num_actions) : 587 | m_default_scorer(default_scorer), m_num_actions(num_actions), m_explore(true) 588 | { 589 | if (m_num_actions < 1) 590 | { 591 | throw std::invalid_argument("Number of actions must be at least 1."); 592 | } 593 | } 594 | 595 | /// 596 | /// Initializes a generic explorer with variable number of actions. 597 | /// 598 | /// @param default_scorer A function which outputs the probability of each action. 599 | /// 600 | GenericExplorer(IScorer& default_scorer) : 601 | m_default_scorer(default_scorer), m_num_actions(UINT_MAX), m_explore(true) 602 | { 603 | static_assert(std::is_base_of::value, "The provided context does not implement variable-action interface."); 604 | } 605 | 606 | void Update_Scorer(IScorer& new_scorer) 607 | { 608 | m_default_scorer = new_scorer; 609 | } 610 | 611 | void Enable_Explore(bool explore) 612 | { 613 | m_explore = explore; 614 | } 615 | 616 | private: 617 | std::tuple Choose_Action(u64 salted_seed, Ctx& context) 618 | { 619 | u32 num_actions = ::Get_Variable_Number_Of_Actions(context, m_num_actions); 620 | 621 | PRG::prg random_generator(salted_seed); 622 | 623 | // Invoke the default scorer function 624 | vector weights = m_default_scorer.Score_Actions(context); 625 | u32 num_weights = (u32)weights.size(); 626 | if (num_weights != num_actions) 627 | { 628 | throw std::invalid_argument("The number of weights returned by the scorer must equal number of actions"); 629 | } 630 | 631 | // Create a discrete_distribution based on the returned weights. This class handles the 632 | // case where the sum of the weights is < or > 1, by normalizing agains the sum. 
633 | float total = 0.f; 634 | for (size_t i = 0; i < num_weights; i++) 635 | { 636 | if (weights[i] < 0) 637 | { 638 | throw std::invalid_argument("Scores must be non-negative."); 639 | } 640 | total += weights[i]; 641 | } 642 | if (total == 0) 643 | { 644 | throw std::invalid_argument("At least one score must be positive."); 645 | } 646 | 647 | float draw = random_generator.Uniform_Unit_Interval(); 648 | 649 | float sum = 0.f; 650 | float action_probability = 0.f; 651 | u32 action_index = num_weights - 1; 652 | for (u32 i = 0; i < num_weights; i++) 653 | { 654 | weights[i] = weights[i] / total; 655 | sum += weights[i]; 656 | if (sum > draw) 657 | { 658 | action_index = i; 659 | action_probability = weights[i]; 660 | break; 661 | } 662 | } 663 | 664 | // action id is one-based 665 | return std::tuple(action_index + 1, action_probability, true); 666 | } 667 | 668 | private: 669 | IScorer& m_default_scorer; 670 | bool m_explore; 671 | const u32 m_num_actions; 672 | }; 673 | 674 | /// 675 | /// The tau-first explorer collects exactly tau uniform random exploration events, and then 676 | /// uses the default policy thereafter. 677 | /// 678 | template 679 | class TauFirstExplorer : public IExplorer, public IConsumePolicy 680 | { 681 | public: 682 | 683 | /// 684 | /// The constructor is the only public member, because this should be used with the MwtExplorer. 685 | /// 686 | /// @param default_policy A default policy after randomization finishes. 687 | /// @param tau The number of events to be uniform over. 688 | /// @param num_actions The number of actions to randomize over. 
689 | /// 690 | TauFirstExplorer(IPolicy& default_policy, u32 tau, u32 num_actions) : 691 | m_default_policy(default_policy), m_tau(tau), m_num_actions(num_actions), m_explore(true) 692 | { 693 | if (m_num_actions < 1) 694 | { 695 | throw std::invalid_argument("Number of actions must be at least 1."); 696 | } 697 | } 698 | 699 | /// 700 | /// Initializes a tau-first explorer with variable number of actions. 701 | /// 702 | /// @param default_policy A default policy after randomization finishes. 703 | /// @param tau The number of events to be uniform over. 704 | /// 705 | TauFirstExplorer(IPolicy& default_policy, u32 tau) : 706 | m_default_policy(default_policy), m_tau(tau), m_num_actions(UINT_MAX), m_explore(true) 707 | { 708 | static_assert(std::is_base_of::value, "The provided context does not implement variable-action interface."); 709 | } 710 | 711 | void Update_Policy(IPolicy& new_policy) 712 | { 713 | m_default_policy = new_policy; 714 | } 715 | 716 | void Enable_Explore(bool explore) 717 | { 718 | m_explore = explore; 719 | } 720 | 721 | private: 722 | std::tuple Choose_Action(u64 salted_seed, Ctx& context) 723 | { 724 | u32 num_actions = ::Get_Variable_Number_Of_Actions(context, m_num_actions); 725 | 726 | PRG::prg random_generator(salted_seed); 727 | 728 | u32 chosen_action = 0; 729 | float action_probability = 0.f; 730 | bool log_action; 731 | 732 | if (m_tau && m_explore) 733 | { 734 | m_tau--; 735 | u32 actionId = random_generator.Uniform_Int(1, num_actions); 736 | action_probability = 1.f / num_actions; 737 | chosen_action = actionId; 738 | log_action = true; 739 | } 740 | else 741 | { 742 | // Invoke the default policy function to get the action 743 | chosen_action = m_default_policy.Choose_Action(context); 744 | 745 | if (chosen_action == 0 || chosen_action > num_actions) 746 | { 747 | throw std::invalid_argument("Action chosen by default policy is not within valid range."); 748 | } 749 | 750 | action_probability = 1.f; 751 | log_action = false; 752 
| } 753 | 754 | return std::tuple(chosen_action, action_probability, log_action); 755 | } 756 | 757 | private: 758 | IPolicy& m_default_policy; 759 | bool m_explore; 760 | u32 m_tau; 761 | const u32 m_num_actions; 762 | }; 763 | 764 | /// 765 | /// The Bootstrap explorer randomizes over the actions chosen by a set of default policies. 766 | /// This performs well statistically but can be computationally expensive. 767 | /// 768 | template 769 | class BootstrapExplorer : public IExplorer, public IConsumePolicies 770 | { 771 | public: 772 | /// 773 | /// The constructor is the only public member, because this should be used with the MwtExplorer. 774 | /// 775 | /// @param default_policy_functions A set of default policies to be uniform random over. 776 | /// The policy pointers must be valid throughout the lifetime of this explorer. 777 | /// @param num_actions The number of actions to randomize over. 778 | /// 779 | BootstrapExplorer(vector>>& default_policy_functions, u32 num_actions) : 780 | m_default_policy_functions(default_policy_functions), 781 | m_num_actions(num_actions), m_explore(true), m_bags((u32)default_policy_functions.size()) 782 | { 783 | if (m_num_actions < 1) 784 | { 785 | throw std::invalid_argument("Number of actions must be at least 1."); 786 | } 787 | 788 | if (m_bags < 1) 789 | { 790 | throw std::invalid_argument("Number of bags must be at least 1."); 791 | } 792 | } 793 | 794 | /// 795 | /// Initializes a bootstrap explorer with variable number of actions. 796 | /// 797 | /// @param default_policy_functions A set of default policies to be uniform random over. 798 | /// The policy pointers must be valid throughout the lifetime of this explorer. 
799 | /// 800 | BootstrapExplorer(vector>>& default_policy_functions) : 801 | m_default_policy_functions(default_policy_functions), 802 | m_num_actions(UINT_MAX), m_explore(true), m_bags((u32)default_policy_functions.size()) 803 | { 804 | if (m_bags < 1) 805 | { 806 | throw std::invalid_argument("Number of bags must be at least 1."); 807 | } 808 | 809 | static_assert(std::is_base_of::value, "The provided context does not implement variable-action interface."); 810 | } 811 | 812 | void Update_Policy(vector>>& new_policy_functions) 813 | { 814 | m_default_policy_functions = move(new_policy_functions); 815 | } 816 | 817 | void Enable_Explore(bool explore) 818 | { 819 | m_explore = explore; 820 | } 821 | 822 | private: 823 | std::tuple Choose_Action(u64 salted_seed, Ctx& context) 824 | { 825 | u32 num_actions = ::Get_Variable_Number_Of_Actions(context, m_num_actions); 826 | 827 | PRG::prg random_generator(salted_seed); 828 | 829 | // Select bag 830 | u32 chosen_bag = random_generator.Uniform_Int(0, m_bags - 1); 831 | 832 | // Invoke the default policy function to get the action 833 | u32 chosen_action = 0; 834 | float action_probability = 0.f; 835 | 836 | if (m_explore) 837 | { 838 | u32 action_from_bag = 0; 839 | vector actions_selected; 840 | for (size_t i = 0; i < num_actions; i++) 841 | { 842 | actions_selected.push_back(0); 843 | } 844 | 845 | // Invoke the default policy function to get the action 846 | for (u32 current_bag = 0; current_bag < m_bags; current_bag++) 847 | { 848 | // TODO: can VW predict for all bags on one call? 
(returning all actions at once) 849 | // if we trigger into VW passing an index to invoke bootstrap scoring, and if VW model changes while we are doing so, 850 | // we could end up calling the wrong bag 851 | action_from_bag = m_default_policy_functions[current_bag]->Choose_Action(context); 852 | 853 | if (action_from_bag == 0 || action_from_bag > num_actions) 854 | { 855 | throw std::invalid_argument("Action chosen by default policy is not within valid range."); 856 | } 857 | 858 | if (current_bag == chosen_bag) 859 | { 860 | chosen_action = action_from_bag; 861 | } 862 | //this won't work if actions aren't 0 to Count 863 | actions_selected[action_from_bag - 1]++; // action id is one-based 864 | } 865 | action_probability = (float)actions_selected[chosen_action - 1] / m_bags; // action id is one-based 866 | } 867 | else 868 | { 869 | chosen_action = m_default_policy_functions[0]->Choose_Action(context); 870 | action_probability = 1.f; 871 | } 872 | 873 | return std::tuple(chosen_action, action_probability, true); 874 | } 875 | 876 | private: 877 | vector>>& m_default_policy_functions; 878 | bool m_explore; 879 | const u32 m_bags; 880 | const u32 m_num_actions; 881 | }; 882 | 883 | } // End namespace MultiWorldTestingCpp 884 | /*! @} End of Doxygen Groups*/ 885 | -------------------------------------------------------------------------------- /static/Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | cd ..; $(MAKE) 3 | -------------------------------------------------------------------------------- /static/explore.cpp: -------------------------------------------------------------------------------- 1 | // vw_explore.cpp : Defines the entry point for the console application. 
2 | // 3 | 4 | #include "MwtExplorer.h" 5 | -------------------------------------------------------------------------------- /static/explore_static.vcxproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | Debug 6 | Win32 7 | 8 | 9 | Debug 10 | x64 11 | 12 | 13 | Release 14 | Win32 15 | 16 | 17 | Release 18 | x64 19 | 20 | 21 | 22 | {ACE47E98-488C-4CDF-B9F1-36337B2855AD} 23 | Win32Proj 24 | vw_explore_static 25 | explore_static 26 | 27 | 28 | 29 | StaticLibrary 30 | true 31 | v120 32 | Unicode 33 | 34 | 35 | StaticLibrary 36 | true 37 | v120 38 | Unicode 39 | 40 | 41 | StaticLibrary 42 | false 43 | v120 44 | true 45 | Unicode 46 | 47 | 48 | StaticLibrary 49 | false 50 | v120 51 | true 52 | Unicode 53 | 54 | 55 | c:\boost\x64\include\boost-1_56 56 | ..\..\..\zlib-1.2.8 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | true 76 | 77 | 78 | true 79 | 80 | 81 | false 82 | 83 | 84 | false 85 | 86 | 87 | 88 | 89 | 90 | Level3 91 | Disabled 92 | WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) 93 | ProgramDatabase 94 | 95 | 96 | Console 97 | true 98 | 99 | 100 | 101 | 102 | Level3 103 | Disabled 104 | WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) 105 | 106 | 107 | Console 108 | true 109 | 110 | 111 | 112 | 113 | Level3 114 | 115 | 116 | MaxSpeed 117 | true 118 | true 119 | WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) 120 | true 121 | 122 | 123 | Console 124 | true 125 | true 126 | true 127 | 128 | 129 | 130 | 131 | Level3 132 | MaxSpeed 133 | true 134 | true 135 | WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) 136 | true 137 | 138 | 139 | Console 140 | true 141 | true 142 | true 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | -------------------------------------------------------------------------------- /static/utility.h: -------------------------------------------------------------------------------- 1 | 
/*******************************************************************/ 2 | // Classes declared in this file are intended for internal use only. 3 | /*******************************************************************/ 4 | 5 | #pragma once 6 | #include 7 | #include /* defines size_t */ 8 | 9 | #ifdef WIN32 10 | typedef unsigned __int64 u64; 11 | typedef unsigned __int32 u32; 12 | typedef unsigned __int16 u16; 13 | typedef unsigned __int8 u8; 14 | typedef signed __int64 i64; 15 | typedef signed __int32 i32; 16 | typedef signed __int16 i16; 17 | typedef signed __int8 i8; 18 | // cross-platform float to_string 19 | #else 20 | typedef uint64_t u64; 21 | typedef uint32_t u32; 22 | typedef uint16_t u16; 23 | typedef uint8_t u8; 24 | typedef int64_t i64; 25 | typedef int32_t i32; 26 | typedef int16_t i16; 27 | typedef int8_t i8; 28 | // cross-platform float to_string 29 | #define sprintf_s snprintf 30 | #endif 31 | 32 | typedef unsigned char byte; 33 | 34 | #include 35 | #include 36 | #include 37 | 38 | /*! 39 | * \addtogroup MultiWorldTestingCpp 40 | * @{ 41 | */ 42 | 43 | MWT_NAMESPACE { 44 | 45 | // 46 | // MurmurHash3, by Austin Appleby 47 | // 48 | // Originals at: 49 | // http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.cpp 50 | // http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.h 51 | // 52 | // Notes: 53 | // 1) this code assumes we can read a 4-byte value from any address 54 | // without crashing (i.e non aligned access is supported). This is 55 | // not a problem on Intel/x86/AMD64 machines (including new Macs) 56 | // 2) It produces different results on little-endian and big-endian machines. 57 | // 58 | //----------------------------------------------------------------------------- 59 | // MurmurHash3 was written by Austin Appleby, and is placed in the public 60 | // domain. The author hereby disclaims copyright to this source code. 
61 | 62 | // Note - The x86 and x64 versions do _not_ produce the same results, as the 63 | // algorithms are optimized for their respective platforms. You can still 64 | // compile and run any of them on any platform, but your performance with the 65 | // non-native version will be less than optimal. 66 | //----------------------------------------------------------------------------- 67 | 68 | // Platform-specific functions and macros 69 | #if defined(_MSC_VER) // Microsoft Visual Studio 70 | # include 71 | 72 | # include 73 | # define ROTL32(x,y) _rotl(x,y) 74 | # define BIG_CONSTANT(x) (x) 75 | 76 | #else // Other compilers 77 | # include /* defines uint32_t etc */ 78 | 79 | inline uint32_t rotl32(uint32_t x, int8_t r) 80 | { 81 | return (x << r) | (x >> (32 - r)); 82 | } 83 | 84 | # define ROTL32(x,y) rotl32(x,y) 85 | # define BIG_CONSTANT(x) (x##LLU) 86 | 87 | #endif // !defined(_MSC_VER) 88 | 89 | struct murmur_hash { 90 | 91 | //----------------------------------------------------------------------------- 92 | // Block read - if your platform needs to do endian-swapping or can only 93 | // handle aligned reads, do the conversion here 94 | private: 95 | static inline uint32_t getblock(const uint32_t * p, int i) 96 | { 97 | return p[i]; 98 | } 99 | 100 | //----------------------------------------------------------------------------- 101 | // Finalization mix - force all bits of a hash block to avalanche 102 | 103 | static inline uint32_t fmix(uint32_t h) 104 | { 105 | h ^= h >> 16; 106 | h *= 0x85ebca6b; 107 | h ^= h >> 13; 108 | h *= 0xc2b2ae35; 109 | h ^= h >> 16; 110 | 111 | return h; 112 | } 113 | 114 | //----------------------------------------------------------------------------- 115 | public: 116 | uint32_t uniform_hash(const void * key, size_t len, uint32_t seed) 117 | { 118 | const uint8_t * data = (const uint8_t*)key; 119 | const int nblocks = (int)len / 4; 120 | 121 | uint32_t h1 = seed; 122 | 123 | const uint32_t c1 = 0xcc9e2d51; 124 | const 
uint32_t c2 = 0x1b873593; 125 | 126 | // --- body 127 | const uint32_t * blocks = (const uint32_t *)(data + nblocks * 4); 128 | 129 | for (int i = -nblocks; i; i++) { 130 | uint32_t k1 = getblock(blocks, i); 131 | 132 | k1 *= c1; 133 | k1 = ROTL32(k1, 15); 134 | k1 *= c2; 135 | 136 | h1 ^= k1; 137 | h1 = ROTL32(h1, 13); 138 | h1 = h1 * 5 + 0xe6546b64; 139 | } 140 | 141 | // --- tail 142 | const uint8_t * tail = (const uint8_t*)(data + nblocks * 4); 143 | 144 | uint32_t k1 = 0; 145 | 146 | switch (len & 3) { 147 | case 3: k1 ^= tail[2] << 16; 148 | case 2: k1 ^= tail[1] << 8; 149 | case 1: k1 ^= tail[0]; 150 | k1 *= c1; k1 = ROTL32(k1, 15); k1 *= c2; h1 ^= k1; 151 | } 152 | 153 | // --- finalization 154 | h1 ^= len; 155 | 156 | return fmix(h1); 157 | } 158 | }; 159 | 160 | class HashUtils 161 | { 162 | public: 163 | static u64 Compute_Id_Hash(const std::string& unique_id) 164 | { 165 | size_t ret = 0; 166 | const char *p = unique_id.c_str(); 167 | while (*p != '\0') 168 | if (*p >= '0' && *p <= '9') 169 | ret = 10 * ret + *(p++) - '0'; 170 | else 171 | { 172 | murmur_hash foo; 173 | return foo.uniform_hash(unique_id.c_str(), unique_id.size(), 0); 174 | } 175 | return ret; 176 | } 177 | }; 178 | 179 | const size_t max_int = 100000; 180 | const float max_float = max_int; 181 | const float min_float = 0.00001f; 182 | const size_t max_digits = (size_t) roundf((float) (-log(min_float) / log(10.))); 183 | 184 | class NumberUtils 185 | { 186 | public: 187 | template 188 | static void print_mantissa(char*& begin, float f) 189 | { // helper for print_float 190 | char values[10]; 191 | size_t v = (size_t)f; 192 | size_t digit = 0; 193 | size_t first_nonzero = 0; 194 | for (size_t max = 1; max <= v; ++digit) 195 | { 196 | size_t max_next = max * 10; 197 | char v_mod = (char) (v % max_next / max); 198 | if (!trailing_zeros && v_mod != '\0' && first_nonzero == 0) 199 | first_nonzero = digit; 200 | values[digit] = '0' + v_mod; 201 | max = max_next; 202 | } 203 | if 
(!trailing_zeros) 204 | for (size_t i = max_digits; i > digit; i--) 205 | *begin++ = '0'; 206 | while (digit > first_nonzero) 207 | *begin++ = values[--digit]; 208 | } 209 | 210 | static void print_float(char* begin, size_t size, float f) 211 | { 212 | bool sign = false; 213 | if (f < 0.f) 214 | sign = true; 215 | float unsigned_f = fabsf(f); 216 | if (unsigned_f < max_float && unsigned_f > min_float) 217 | { 218 | if (sign) 219 | *begin++ = '-'; 220 | print_mantissa(begin, unsigned_f); 221 | unsigned_f -= (size_t)unsigned_f; 222 | unsigned_f *= max_int; 223 | if (unsigned_f >= 1.f) 224 | { 225 | *begin++ = '.'; 226 | print_mantissa(begin, unsigned_f); 227 | } 228 | } 229 | else if (unsigned_f == 0.) 230 | *begin++ = '0'; 231 | else 232 | { 233 | sprintf_s(begin, size, "%g", f); 234 | return; 235 | } 236 | *begin = '\0'; 237 | return; 238 | } 239 | }; 240 | 241 | //A quick implementation similar to drand48 for cross-platform compatibility 242 | namespace PRG { 243 | const uint64_t a = 0xeece66d5deece66dULL; 244 | const uint64_t c = 2147483647; 245 | 246 | const int bias = 127 << 23; 247 | 248 | union int_float { 249 | int32_t i; 250 | float f; 251 | }; 252 | 253 | struct prg { 254 | private: 255 | uint64_t v; 256 | public: 257 | prg() { v = c; } 258 | prg(uint64_t initial) { v = initial; } 259 | 260 | float merand48(uint64_t& initial) 261 | { 262 | initial = a * initial + c; 263 | int_float temp; 264 | temp.i = ((initial >> 25) & 0x7FFFFF) | bias; 265 | return temp.f - 1; 266 | } 267 | 268 | float Uniform_Unit_Interval() 269 | { 270 | return merand48(v); 271 | } 272 | 273 | uint32_t Uniform_Int(uint32_t low, uint32_t high) 274 | { 275 | merand48(v); 276 | uint32_t ret = low + ((v >> 25) % (high - low + 1)); 277 | return ret; 278 | } 279 | }; 280 | } 281 | } 282 | /*! 
@} End of Doxygen Groups*/ 283 | -------------------------------------------------------------------------------- /static/vw_explore.vcxproj.filters: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF} 6 | cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx 7 | 8 | 9 | {93995380-89BD-4b04-88EB-625FBE52EBFB} 10 | h;hh;hpp;hxx;hm;inl;inc;xsd 11 | 12 | 13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} 14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | Header Files 23 | 24 | 25 | Header Files 26 | 27 | 28 | Source Files 29 | 30 | 31 | Source Files 32 | 33 | 34 | Source Files 35 | 36 | 37 | Source Files 38 | 39 | 40 | Source Files 41 | 42 | 43 | 44 | 45 | Source Files 46 | 47 | 48 | Source Files 49 | 50 | 51 | -------------------------------------------------------------------------------- /tests/ExploreTests.csproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | Debug 5 | AnyCPU 6 | {CB0C6B20-560C-4822-8EF6-DA999A64B542} 7 | Library 8 | Properties 9 | ExploreTests 10 | ExploreTests 11 | v4.5 12 | 512 13 | {3AC096D0-A1C2-E12C-1390-A8335801FDAB};{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC} 14 | 10.0 15 | $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion) 16 | $(ProgramFiles)\Common Files\microsoft shared\VSTT\$(VisualStudioVersion)\UITestExtensionPackages 17 | False 18 | UnitTest 19 | 20 | 21 | true 22 | bin\x64\Debug\ 23 | DEBUG;TRACE 24 | full 25 | x64 26 | prompt 27 | MinimumRecommendedRules.ruleset 28 | 29 | 30 | bin\x64\Release\ 31 | TRACE 32 | true 33 | pdbonly 34 | x64 35 | prompt 36 | MinimumRecommendedRules.ruleset 37 | 38 | 39 | true 40 | bin\x86\Debug\ 41 | DEBUG;TRACE 42 | full 43 | x86 44 | prompt 45 | MinimumRecommendedRules.ruleset 46 | 47 | 48 | bin\x86\Release\ 49 | TRACE 50 | true 51 | pdbonly 52 | x86 53 | prompt 54 | 
MinimumRecommendedRules.ruleset 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | {8400da16-1f46-4a31-a126-bbe16f62bfd7} 78 | explore_clr 79 | 80 | 81 | 82 | 83 | 84 | 85 | False 86 | 87 | 88 | False 89 | 90 | 91 | False 92 | 93 | 94 | False 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 112 | -------------------------------------------------------------------------------- /tests/ExploreTests.csproj.user: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | true 5 | 6 | 7 | true 8 | 9 | -------------------------------------------------------------------------------- /tests/MWTExploreTests.cs: --------------------------------------------------------------------------------
using System;
using System.Runtime.InteropServices;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MultiWorldTesting;
using System.Collections.Generic;
using System.Linq;

namespace ExploreTests
{
    [TestClass]
    public class MWTExploreTests
    {
        /*
        ** C# Tests do not need to be as extensive as those for C++. These tests should ensure
        ** the interactions between managed and native code are as expected.
        */
        [TestMethod]
        public void EpsilonGreedy()
        {
            uint numActions = 10;
            float epsilon = 0f;
            var policy = new TestPolicy<TestContext>();
            var testContext = new TestContext();
            var explorer = new EpsilonGreedyExplorer<TestContext>(policy, epsilon, numActions);

            EpsilonGreedyWithContext(numActions, testContext, policy, explorer);
        }

        [TestMethod]
        public void EpsilonGreedyFixedActionUsingVariableActionInterface()
        {
            uint numActions = 10;
            float epsilon = 0f;
            var policy = new TestPolicy<TestVarContext>();
            var testContext = new TestVarContext(numActions);
            var explorer = new EpsilonGreedyExplorer<TestVarContext>(policy, epsilon);

            EpsilonGreedyWithContext(numActions, testContext, policy, explorer);
        }

        /// <summary>
        /// With epsilon = 0 the explorer must always return the policy's action, and every
        /// decision made while exploration is enabled must reach the recorder.
        /// </summary>
        private static void EpsilonGreedyWithContext<TContext>(uint numActions, TContext testContext, TestPolicy<TContext> policy, IExplorer<TContext> explorer)
            where TContext : TestContext
        {
            string uniqueKey = "ManagedTestId";
            var recorder = new TestRecorder<TContext>();
            var mwtt = new MwtExplorer<TContext>("mwt", recorder);
            testContext.Id = 100;

            uint expectedAction = policy.ChooseAction(testContext);

            uint chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext);
            Assert.AreEqual(expectedAction, chosenAction);

            chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext);
            Assert.AreEqual(expectedAction, chosenAction);

            var interactions = recorder.GetAllInteractions();
            Assert.AreEqual(2, interactions.Count);

            Assert.AreEqual(testContext.Id, interactions[0].Context.Id);

            // Verify that policy action is chosen all the time
            explorer.EnableExplore(false);
            for (int i = 0; i < 1000; i++)
            {
                chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext);
                Assert.AreEqual(expectedAction, chosenAction);
            }
        }

        [TestMethod]
        public void TauFirst()
        {
            uint numActions = 10;
            uint tau = 0;
            var testContext = new TestContext() { Id = 100 };
            var policy = new TestPolicy<TestContext>();
            var explorer = new TauFirstExplorer<TestContext>(policy, tau, numActions);

            TauFirstWithContext(numActions, testContext, policy, explorer);
        }

        [TestMethod]
        public void TauFirstFixedActionUsingVariableActionInterface()
        {
            uint numActions = 10;
            uint tau = 0;
            var testContext = new TestVarContext(numActions) { Id = 100 };
            var policy = new TestPolicy<TestVarContext>();
            var explorer = new TauFirstExplorer<TestVarContext>(policy, tau);

            TauFirstWithContext(numActions, testContext, policy, explorer);
        }

        /// <summary>
        /// With tau = 0 no exploration happens: the policy's action is returned and
        /// nothing is recorded.
        /// </summary>
        private static void TauFirstWithContext<TContext>(uint numActions, TContext testContext, TestPolicy<TContext> policy, IExplorer<TContext> explorer)
            where TContext : TestContext
        {
            string uniqueKey = "ManagedTestId";

            var recorder = new TestRecorder<TContext>();
            var mwtt = new MwtExplorer<TContext>("mwt", recorder);

            uint expectedAction = policy.ChooseAction(testContext);

            uint chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext);
            Assert.AreEqual(expectedAction, chosenAction);

            var interactions = recorder.GetAllInteractions();
            Assert.AreEqual(0, interactions.Count);

            // Verify that policy action is chosen all the time
            explorer.EnableExplore(false);
            for (int i = 0; i < 1000; i++)
            {
                chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext);
                Assert.AreEqual(expectedAction, chosenAction);
            }
        }

        [TestMethod]
        public void Bootstrap()
        {
            uint numActions = 10;
            uint numbags = 2;
            var testContext1 = new TestContext() { Id = 99 };
            var testContext2 = new TestContext() { Id = 100 };

            var policies = new TestPolicy<TestContext>[numbags];
            for (int i = 0; i < numbags; i++)
            {
                policies[i] = new TestPolicy<TestContext>(i * 2);
            }
            var explorer = new BootstrapExplorer<TestContext>(policies, numActions);

            BootstrapWithContext(numActions, testContext1, testContext2, policies, explorer);
        }

        [TestMethod]
        public void BootstrapFixedActionUsingVariableActionInterface()
        {
            uint numActions = 10;
            uint numbags = 2;
            var testContext1 = new TestVarContext(numActions) { Id = 99 };
            var testContext2 = new TestVarContext(numActions) { Id = 100 };

            var policies = new TestPolicy<TestVarContext>[numbags];
            for (int i = 0; i < numbags; i++)
            {
                policies[i] = new TestPolicy<TestVarContext>(i * 2);
            }
            var explorer = new BootstrapExplorer<TestVarContext>(policies);

            BootstrapWithContext(numActions, testContext1, testContext2, policies, explorer);
        }

        /// <summary>
        /// All bags hold the same constant policy, so whichever bag is drawn the chosen
        /// action equals that policy's action; both decisions must be recorded in order.
        /// </summary>
        private static void BootstrapWithContext<TContext>(uint numActions, TContext testContext1, TContext testContext2, TestPolicy<TContext>[] policies, IExplorer<TContext> explorer)
            where TContext : TestContext
        {
            string uniqueKey = "ManagedTestId";

            var recorder = new TestRecorder<TContext>();
            var mwtt = new MwtExplorer<TContext>("mwt", recorder);

            uint expectedAction = policies[0].ChooseAction(testContext1);

            uint chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext1);
            Assert.AreEqual(expectedAction, chosenAction);

            chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext2);
            Assert.AreEqual(expectedAction, chosenAction);

            var interactions = recorder.GetAllInteractions();
            Assert.AreEqual(2, interactions.Count);

            Assert.AreEqual(testContext1.Id, interactions[0].Context.Id);
            Assert.AreEqual(testContext2.Id, interactions[1].Context.Id);

            // Verify that policy action is chosen all the time
            explorer.EnableExplore(false);
            for (int i = 0; i < 1000; i++)
            {
                chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext1);
                Assert.AreEqual(expectedAction, chosenAction);
            }
        }

        [TestMethod]
        public void Softmax()
        {
            uint numActions = 10;
            float lambda = 0.5f;
            uint numActionsCover = 100;
            float C = 5;
            var scorer = new TestScorer<TestContext>(numActions);
            var explorer = new SoftmaxExplorer<TestContext>(scorer, lambda, numActions);

            // Number of draws: numActions * ln(numActions) plus a C-scaled margin.
            uint numDecisions = (uint)(numActions * Math.Log(numActions * 1.0) + Math.Log(numActionsCover * 1.0 / numActions) * C * numActions);
            var contexts = new TestContext[numDecisions];
            for (int i = 0; i < numDecisions; i++)
            {
                contexts[i] = new TestContext { Id = i };
            }

            SoftmaxWithContext(numActions, explorer, contexts);
        }

        [TestMethod]
        public void SoftmaxFixedActionUsingVariableActionInterface()
        {
            uint numActions = 10;
            float lambda = 0.5f;
            uint numActionsCover = 100;
            float C = 5;
            var scorer = new TestScorer<TestVarContext>(numActions);
            var explorer = new SoftmaxExplorer<TestVarContext>(scorer, lambda);

            uint numDecisions = (uint)(numActions * Math.Log(numActions * 1.0) + Math.Log(numActionsCover * 1.0 / numActions) * C * numActions);
            var contexts = new TestVarContext[numDecisions];
            for (int i = 0; i < numDecisions; i++)
            {
                contexts[i] = new TestVarContext(numActions) { Id = i };
            }

            SoftmaxWithContext(numActions, explorer, contexts);
        }

        /// <summary>
        /// With uniform scores every action must be picked at least once over the given
        /// contexts, and the recorder must see the contexts in call order.
        /// </summary>
        private static void SoftmaxWithContext<TContext>(uint numActions, IExplorer<TContext> explorer, TContext[] contexts)
            where TContext : TestContext
        {
            var recorder = new TestRecorder<TContext>();
            var mwtt = new MwtExplorer<TContext>("mwt", recorder);

            uint[] actions = new uint[numActions];

            Random rand = new Random();
            for (uint i = 0; i < contexts.Length; i++)
            {
                uint chosenAction = mwtt.ChooseAction(explorer, rand.NextDouble().ToString(), contexts[i]);
                actions[chosenAction - 1]++; // action id is one-based
            }

            for (uint i = 0; i < numActions; i++)
            {
                Assert.IsTrue(actions[i] > 0);
            }

            var interactions = recorder.GetAllInteractions();
            Assert.AreEqual(contexts.Length, interactions.Count);

            for (int i = 0; i < contexts.Length; i++)
            {
                Assert.AreEqual(i, interactions[i].Context.Id);
            }
        }

        [TestMethod]
        public void SoftmaxScores()
        {
            uint numActions = 10;
            float lambda = 0.5f;
            var recorder = new TestRecorder<TestContext>();
            var scorer = new TestScorer<TestContext>(numActions, uniform: false);

            var mwtt = new MwtExplorer<TestContext>("mwt", recorder);
            var explorer = new SoftmaxExplorer<TestContext>(scorer, lambda, numActions);

            Random rand = new Random();
            mwtt.ChooseAction(explorer, rand.NextDouble().ToString(), new TestContext() { Id = 100 });
            mwtt.ChooseAction(explorer, rand.NextDouble().ToString(), new TestContext() { Id = 101 });
            mwtt.ChooseAction(explorer, rand.NextDouble().ToString(), new TestContext() { Id = 102 });

            var interactions = recorder.GetAllInteractions();

            Assert.AreEqual(3, interactions.Count);

            for (int i = 0; i < interactions.Count; i++)
            {
                // Scores are not equal therefore probabilities should not be uniform.
                // MSTest expects (notExpected, actual) in that order.
                Assert.AreNotEqual(1.0f / numActions, interactions[i].Probability);
                Assert.AreEqual(100 + i, interactions[i].Context.Id);
            }

            // Find the action with the highest score (one-based id).
            TestContext context = new TestContext { Id = 100 };
            List<float> scores = scorer.ScoreActions(context);
            float maxScore = 0;
            uint highestScoreAction = 0;
            for (int i = 0; i < scores.Count; i++)
            {
                if (maxScore < scores[i])
                {
                    maxScore = scores[i];
                    highestScoreAction = (uint)i + 1;
                }
            }

            // Verify that the highest-scoring action is chosen all the time once
            // exploration is disabled.
            explorer.EnableExplore(false);
            for (int i = 0; i < 1000; i++)
            {
                uint chosenAction = mwtt.ChooseAction(explorer, rand.NextDouble().ToString(), new TestContext() { Id = i });
                Assert.AreEqual(highestScoreAction, chosenAction);
            }
        }

        [TestMethod]
        public void Generic()
        {
            uint numActions = 10;
            var scorer = new TestScorer<TestContext>(numActions);
            var testContext = new TestContext() { Id = 100 };
            var explorer = new GenericExplorer<TestContext>(scorer, numActions);

            GenericWithContext(numActions, testContext, explorer);
        }

        [TestMethod]
        public void GenericFixedActionUsingVariableActionInterface()
        {
            uint numActions = 10;
            var scorer = new TestScorer<TestVarContext>(numActions);
            var testContext = new TestVarContext(numActions) { Id = 100 };
            var explorer = new GenericExplorer<TestVarContext>(scorer);

            GenericWithContext(numActions, testContext, explorer);
        }

        /// <summary>
        /// A single decision must yield a valid one-based action and exactly one
        /// recorded interaction carrying the same context.
        /// </summary>
        private static void GenericWithContext<TContext>(uint numActions, TContext testContext, IExplorer<TContext> explorer)
            where TContext : TestContext
        {
            string uniqueKey = "ManagedTestId";
            var recorder = new TestRecorder<TContext>();

            var mwtt = new MwtExplorer<TContext>("mwt", recorder);

            uint chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext);
            Assert.IsTrue(chosenAction >= 1 && chosenAction <= numActions);

            var interactions = recorder.GetAllInteractions();
            Assert.AreEqual(1, interactions.Count);
            Assert.AreEqual(testContext.Id, interactions[0].Context.Id);
        }

        /// <summary>
        /// Every explorer built without an explicit action count must reject a context
        /// that does not implement IVariableActionContext with an ArgumentException
        /// naming the "ctx" parameter.
        /// </summary>
        [TestMethod]
        public void UsageBadVariableActionContext()
        {
            int numExceptionsCaught = 0;
            int numExceptionsExpected = 5;

            var tryCatchArgumentException = (Action<Action>)((action) =>
            {
                try
                {
                    action();
                }
                catch (ArgumentException ex)
                {
                    // Ordinal comparison: culture-independent, and safe when ParamName is null.
                    if (string.Equals(ex.ParamName, "ctx", StringComparison.OrdinalIgnoreCase))
                    {
                        numExceptionsCaught++;
                    }
                }
            });

            tryCatchArgumentException(() =>
            {
                var mwt = new MwtExplorer<TestContext>("test", new TestRecorder<TestContext>());
                var policy = new TestPolicy<TestContext>();
                var explorer = new EpsilonGreedyExplorer<TestContext>(policy, 0.2f);
                mwt.ChooseAction(explorer, "key", new TestContext());
            });
            tryCatchArgumentException(() =>
            {
                var mwt = new MwtExplorer<TestContext>("test", new TestRecorder<TestContext>());
                var policy = new TestPolicy<TestContext>();
                var explorer = new TauFirstExplorer<TestContext>(policy, 10);
                mwt.ChooseAction(explorer, "key", new TestContext());
            });
            tryCatchArgumentException(() =>
            {
                var mwt = new MwtExplorer<TestContext>("test", new TestRecorder<TestContext>());
                var policies = new TestPolicy<TestContext>[2];
                for (int i = 0; i < 2; i++)
                {
                    policies[i] = new TestPolicy<TestContext>(i * 2);
                }
                var explorer = new BootstrapExplorer<TestContext>(policies);
                mwt.ChooseAction(explorer, "key", new TestContext());
            });
            tryCatchArgumentException(() =>
            {
                var mwt = new MwtExplorer<TestContext>("test", new TestRecorder<TestContext>());
                var scorer = new TestScorer<TestContext>(10);
                var explorer = new SoftmaxExplorer<TestContext>(scorer, 0.5f);
                mwt.ChooseAction(explorer, "key", new TestContext());
            });
            tryCatchArgumentException(() =>
            {
                var mwt = new MwtExplorer<TestContext>("test", new TestRecorder<TestContext>());
                var scorer = new TestScorer<TestContext>(10);
                var explorer = new GenericExplorer<TestContext>(scorer);
                mwt.ChooseAction(explorer, "key", new TestContext());
            });

            Assert.AreEqual(numExceptionsExpected, numExceptionsCaught);
        }

        [TestInitialize]
        public void TestInitialize()
        {
        }

        [TestCleanup]
        public void TestCleanup()
        {
        }
    }

    /// <summary>One recorded (context, action, probability, key) tuple.</summary>
    struct TestInteraction<Ctx>
    {
        public Ctx Context;
        public UInt32 Action;
        public float Probability;
        public string UniqueKey;
    }

    /// <summary>Minimal context carrying only an identifier.</summary>
    class TestContext
    {
        private int id;

        public int Id
        {
            get { return id; }
            set { id = value; }
        }
    }

    /// <summary>Context that reports its own number of actions.</summary>
    class TestVarContext : TestContext, IVariableActionContext
    {
        public TestVarContext(uint numberOfActions)
        {
            NumberOfActions = numberOfActions;
        }

        public uint GetNumberOfActions()
        {
            return NumberOfActions;
        }

        public uint NumberOfActions { get; set; }
    }

    /// <summary>Recorder that keeps every interaction in memory, in call order.</summary>
    class TestRecorder<Ctx> : IRecorder<Ctx>
    {
        public void Record(Ctx context, UInt32 action, float probability, string uniqueKey)
        {
            interactions.Add(new TestInteraction<Ctx>()
            {
                Context = context,
                Action = action,
                Probability = probability,
                UniqueKey = uniqueKey
            });
        }

        public List<TestInteraction<Ctx>> GetAllInteractions()
        {
            return interactions;
        }

        private List<TestInteraction<Ctx>> interactions = new List<TestInteraction<Ctx>>();
    }

    /// <summary>Policy that returns action 5 for every context; the index is only stored.</summary>
    class TestPolicy<TContext> : IPolicy<TContext>
    {
        public TestPolicy() : this(-1) { }

        public TestPolicy(int index)
        {
            this.index = index;
        }

        public uint ChooseAction(TContext context)
        {
            return 5;
        }

        private int index;
    }

    /// <summary>Policy over SimpleContext that always returns action 1.</summary>
    class TestSimplePolicy : IPolicy<SimpleContext>
    {
        public uint ChooseAction(SimpleContext context)
        {
            return 1;
        }
    }

    /// <summary>Policy over SimpleContext that always returns action 1.</summary>
    class StringPolicy : IPolicy<SimpleContext>
    {
        public uint ChooseAction(SimpleContext context)
        {
            return 1;
        }
    }

    /// <summary>
    /// Scorer returning either uniform scores (1 / numActions each) or the increasing
    /// scores 1, 2, ..., numActions.
    /// </summary>
    class TestScorer<Ctx> : IScorer<Ctx>
    {
        public TestScorer(uint numActions, bool uniform = true)
        {
            this.uniform = uniform;
            this.numActions = numActions;
        }

        public List<float> ScoreActions(Ctx context)
        {
            if (uniform)
            {
                return Enumerable.Repeat(1.0f / numActions, (int)numActions).ToList();
            }

            return Enumerable.Range(1, (int)numActions).Select(i => (float)i).ToList();
        }

        private uint numActions;
        private bool uniform;
    }
}
-------------------------------------------------------------------------------- /tests/MWTExploreTests.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "MWTExplorer.h" 4 | #include "utility.h" 5 | #include 6 | #include 7 | #include 8 | 9 | using
namespace MultiWorldTesting; 10 | 11 | class TestContext 12 | { 13 | 14 | }; 15 | 16 | class TestVarContext : public TestContext, public IVariableActionContext 17 | { 18 | public: 19 | TestVarContext(u32 num_actions) 20 | { 21 | m_num_actions = num_actions; 22 | } 23 | 24 | u32 Get_Number_Of_Actions() 25 | { 26 | return m_num_actions; 27 | } 28 | 29 | private: 30 | u32 m_num_actions; 31 | }; 32 | 33 | template 34 | struct TestInteraction 35 | { 36 | Ctx& Context; 37 | u32 Action; 38 | float Probability; 39 | string Unique_Key; 40 | }; 41 | 42 | template 43 | class TestPolicy : public IPolicy 44 | { 45 | public: 46 | TestPolicy(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { } 47 | u32 Choose_Action(TContext& context) 48 | { 49 | return m_params % m_num_actions + 1; // action id is one-based 50 | } 51 | private: 52 | int m_params; 53 | int m_num_actions; 54 | }; 55 | 56 | template 57 | class TestScorer : public IScorer 58 | { 59 | public: 60 | TestScorer(int params, int num_actions, bool uniform = true) : 61 | m_params(params), m_num_actions(num_actions), m_uniform(uniform) 62 | { } 63 | 64 | vector Score_Actions(TContext& context) 65 | { 66 | vector scores; 67 | if (m_uniform) 68 | { 69 | for (int i = 0; i < m_num_actions; i++) 70 | { 71 | scores.push_back((float)m_params); 72 | } 73 | } 74 | else 75 | { 76 | for (int i = 0; i < m_num_actions; i++) 77 | { 78 | scores.push_back((float)m_params + i); 79 | } 80 | } 81 | return scores; 82 | } 83 | private: 84 | int m_params; 85 | int m_num_actions; 86 | bool m_uniform; 87 | }; 88 | 89 | class FixedScorer : public IScorer 90 | { 91 | public: 92 | FixedScorer(int num_actions, int value) : 93 | m_num_actions(num_actions), m_value(value) 94 | { } 95 | 96 | vector Score_Actions(TestContext& context) 97 | { 98 | vector scores; 99 | for (int i = 0; i < m_num_actions; i++) 100 | { 101 | scores.push_back((float)m_value); 102 | } 103 | return scores; 104 | } 105 | private: 106 | int m_num_actions; 
107 | int m_value; 108 | }; 109 | 110 | class TestSimpleScorer : public IScorer 111 | { 112 | public: 113 | TestSimpleScorer(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { } 114 | vector Score_Actions(SimpleContext& context) 115 | { 116 | vector scores; 117 | for (int i = 0; i < m_num_actions; i++) 118 | { 119 | scores.push_back((float)m_params); 120 | } 121 | return scores; 122 | } 123 | private: 124 | int m_params; 125 | int m_num_actions; 126 | }; 127 | 128 | class TestSimplePolicy : public IPolicy 129 | { 130 | public: 131 | TestSimplePolicy(int params, int num_actions) : m_params(params), m_num_actions(num_actions) { } 132 | u32 Choose_Action(SimpleContext& context) 133 | { 134 | return m_params % m_num_actions + 1; // action id is one-based 135 | } 136 | private: 137 | int m_params; 138 | int m_num_actions; 139 | }; 140 | 141 | class TestSimpleRecorder : public IRecorder 142 | { 143 | public: 144 | virtual void Record(SimpleContext& context, u32 action, float probability, string unique_key) 145 | { 146 | m_interactions.push_back({ context, action, probability, unique_key }); 147 | } 148 | 149 | vector> Get_All_Interactions() 150 | { 151 | return m_interactions; 152 | } 153 | 154 | private: 155 | vector> m_interactions; 156 | }; 157 | 158 | // Return action outside valid range 159 | class TestBadPolicy : public IPolicy 160 | { 161 | public: 162 | u32 Choose_Action(TestContext& context) 163 | { 164 | return 100; 165 | } 166 | }; 167 | 168 | template 169 | class TestRecorder : public IRecorder 170 | { 171 | public: 172 | virtual void Record(TContext& context, u32 action, float probability, string unique_key) 173 | { 174 | m_interactions.push_back({ context, action, probability, unique_key }); 175 | } 176 | 177 | vector> Get_All_Interactions() 178 | { 179 | return m_interactions; 180 | } 181 | 182 | private: 183 | vector> m_interactions; 184 | }; 185 | -------------------------------------------------------------------------------- 
/tests/Properties/AssemblyInfo.cs: -------------------------------------------------------------------------------- 1 | using System.Reflection; 2 | using System.Runtime.CompilerServices; 3 | using System.Runtime.InteropServices; 4 | 5 | // General Information about an assembly is controlled through the following 6 | // set of attributes. Change these attribute values to modify the information 7 | // associated with an assembly. 8 | [assembly: AssemblyTitle("ExploreTests")] 9 | [assembly: AssemblyDescription("")] 10 | [assembly: AssemblyConfiguration("")] 11 | [assembly: AssemblyCompany("")] 12 | [assembly: AssemblyProduct("ExploreTests")] 13 | [assembly: AssemblyCopyright("Copyright © 2014")] 14 | [assembly: AssemblyTrademark("")] 15 | [assembly: AssemblyCulture("")] 16 | 17 | // Setting ComVisible to false makes the types in this assembly not visible 18 | // to COM components. If you need to access a type in this assembly from 19 | // COM, set the ComVisible attribute to true on that type. 
20 | [assembly: ComVisible(false)] 21 | 22 | // The following GUID is for the ID of the typelib if this project is exposed to COM 23 | [assembly: Guid("26b5aea9-84b9-4627-a0c6-8d33dd3e5035")] 24 | 25 | // Version information for an assembly consists of the following four values: 26 | // 27 | // Major Version 28 | // Minor Version 29 | // Build Number 30 | // Revision 31 | // 32 | // You can specify all the values or you can default the Build and Revision Numbers 33 | // by using the '*' as shown below: 34 | // [assembly: AssemblyVersion("1.0.*")] 35 | [assembly: AssemblyVersion("1.0.0.0")] 36 | [assembly: AssemblyFileVersion("1.0.0.0")] 37 | -------------------------------------------------------------------------------- /tests/explore_tests.vcxproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | Debug 6 | Win32 7 | 8 | 9 | Debug 10 | x64 11 | 12 | 13 | Release 14 | Win32 15 | 16 | 17 | Release 18 | x64 19 | 20 | 21 | 22 | {5AE3AA40-BEB0-4979-8166-3B885172C430} 23 | Win32Proj 24 | vw_explore_tests 25 | explore_tests 26 | 27 | 28 | 29 | DynamicLibrary 30 | true 31 | v120 32 | Unicode 33 | false 34 | 35 | 36 | DynamicLibrary 37 | true 38 | v120 39 | Unicode 40 | false 41 | 42 | 43 | DynamicLibrary 44 | false 45 | v120 46 | true 47 | Unicode 48 | false 49 | 50 | 51 | DynamicLibrary 52 | false 53 | v120 54 | true 55 | Unicode 56 | false 57 | 58 | 59 | c:\boost\x64\include\boost-1_56 60 | c:\boost\x64\lib 61 | ..\..\..\zlib-1.2.8 62 | $(ZlibIncludeDir)\contrib\vstudio\vc10\x64\ZlibStat$(Configuration) 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | true 82 | 83 | 84 | true 85 | 86 | 87 | 88 | 89 | 90 | NotUsing 91 | Level3 92 | Disabled 93 | ..\static;$(VCInstallDir)UnitTest\include;%(AdditionalIncludeDirectories) 94 | WIN32;_DEBUG;%(PreprocessorDefinitions) 95 | true 96 | 97 | 98 | Windows 99 | true 100 | $(VCInstallDir)UnitTest\lib;%(AdditionalLibraryDirectories) 101 
| 102 | 103 | 104 | 105 | 106 | NotUsing 107 | Level3 108 | Disabled 109 | ..\static;$(VCInstallDir)UnitTest\include;%(AdditionalIncludeDirectories) 110 | WIN32;_DEBUG;%(PreprocessorDefinitions) 111 | true 112 | 113 | 114 | Windows 115 | true 116 | $(VCInstallDir)UnitTest\lib;%(AdditionalLibraryDirectories) 117 | 118 | 119 | 120 | 121 | 122 | Level3 123 | NotUsing 124 | MaxSpeed 125 | true 126 | true 127 | ..\static;$(VCInstallDir)UnitTest\include;%(AdditionalIncludeDirectories) 128 | WIN32;NDEBUG;%(PreprocessorDefinitions) 129 | true 130 | 131 | 132 | Windows 133 | true 134 | true 135 | true 136 | $(VCInstallDir)UnitTest\lib;%(AdditionalLibraryDirectories) 137 | 138 | 139 | 140 | 141 | 142 | Level3 143 | NotUsing 144 | MaxSpeed 145 | true 146 | true 147 | ..\static;$(VCInstallDir)UnitTest\include;%(AdditionalIncludeDirectories) 148 | WIN32;NDEBUG;%(PreprocessorDefinitions) 149 | true 150 | 151 | 152 | Windows 153 | true 154 | true 155 | true 156 | $(VCInstallDir)UnitTest\lib;%(AdditionalLibraryDirectories) 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | -------------------------------------------------------------------------------- /tests/explore_tests.vcxproj.filters: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF} 6 | cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx 7 | 8 | 9 | {93995380-89BD-4b04-88EB-625FBE52EBFB} 10 | h;hpp;hxx;hm;inl;inc;xsd 11 | 12 | 13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} 14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms 15 | 16 | 17 | 18 | 19 | Header Files 20 | 21 | 22 | 23 | 24 | Source Files 25 | 26 | 27 | -------------------------------------------------------------------------------- /tests/stdafx.cpp: -------------------------------------------------------------------------------- 1 | // stdafx.cpp : source file that includes just the standard includes 2 | // 
vw_explore_tests.pch will be the pre-compiled header 3 | // stdafx.obj will contain the pre-compiled type information 4 | 5 | #include "stdafx.h" 6 | 7 | // TODO: reference any additional headers you need in STDAFX.H 8 | // and not in this file 9 | -------------------------------------------------------------------------------- /tests/stdafx.h: -------------------------------------------------------------------------------- 1 | // stdafx.h : include file for standard system include files, 2 | // or project specific include files that are used frequently, but 3 | // are changed infrequently 4 | // 5 | 6 | #pragma once 7 | 8 | #include "targetver.h" 9 | #include 10 | #include 11 | 12 | // Headers for CppUnitTest 13 | #include "CppUnitTest.h" 14 | 15 | // TODO: reference additional headers your program requires here 16 | #define TEST_CPP 17 | -------------------------------------------------------------------------------- /tests/targetver.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // Including SDKDDKVer.h defines the highest available Windows platform. 4 | 5 | // If you wish to build your application for a previous Windows platform, include WinSDKVer.h and 6 | // set the _WIN32_WINNT macro to the platform you wish to support before including SDKDDKVer.h. 7 | 8 | #include 9 | --------------------------------------------------------------------------------