├── .gitignore
├── Doxyfile
├── ExploreSample
├── App.config
├── ExploreOnlySample.cs
├── ExploreSample.csproj
├── Program.cs
└── Properties
│ └── AssemblyInfo.cs
├── LICENSE
├── Makefile
├── README.md
├── ReadMe.txt
├── SECURITY.md
├── clr
├── explore_clr.vcxproj
├── explore_clr.vcxproj.filters
├── explore_clr_wrapper.cpp
├── explore_clr_wrapper.h
├── explore_interface.h
└── explore_interop.h
├── explore.cpp
├── explore.sln
├── explore.vcxproj
├── explore.vcxproj.filters
├── explore_sample.cpp
├── mwt.chm
├── static
├── MWTExplorer.h
├── Makefile
├── explore.cpp
├── explore_static.vcxproj
├── utility.h
└── vw_explore.vcxproj.filters
└── tests
├── ExploreTests.csproj
├── ExploreTests.csproj.user
├── MWTExploreTests.cpp
├── MWTExploreTests.cs
├── MWTExploreTests.h
├── Properties
└── AssemblyInfo.cs
├── explore_tests.vcxproj
├── explore_tests.vcxproj.filters
├── stdafx.cpp
├── stdafx.h
└── targetver.h
/.gitignore:
--------------------------------------------------------------------------------
1 | # build folders
2 | **/x64
3 | **/obj
4 | **/Debug
5 | **/Release
6 | **/dll
7 | **/ipch
8 |
9 | # VS files
10 | *.opensdf
11 | *.suo
12 | *.sdf
13 | *.vcxproj.user
14 | *.so
15 |
16 | # doxygen output folders
17 | html
18 | latex
--------------------------------------------------------------------------------
/ExploreSample/App.config:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/ExploreSample/ExploreOnlySample.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using System.Linq;
4 | using System.Text;
5 | using MultiWorldTesting;
6 |
7 | namespace cs_test
8 | {
9 | class ExploreOnlySample
10 | {
11 | /// <summary>
12 | /// Example of a custom context.
13 | /// </summary>
14 | class MyContext { }
15 |
16 | /// <summary>
17 | /// Example of a custom recorder which implements IRecorder&lt;MyContext&gt;,
18 | /// declaring that this recorder only interacts with MyContext objects.
19 | /// </summary>
20 | class MyRecorder : IRecorder<MyContext>
21 | {
22 | public void Record(MyContext context, UInt32 action, float probability, string uniqueKey)
23 | {
24 | // Stores the tuple internally in a list that could be used later for other purposes.
25 | interactions.Add(new Interaction<MyContext>()
26 | {
27 | Context = context,
28 | Action = action,
29 | Probability = probability,
30 | UniqueKey = uniqueKey
31 | });
32 | }
33 |
34 | public List<Interaction<MyContext>> GetAllInteractions() // Returns the live internal list, not a copy.
35 | {
36 | return interactions;
37 | }
38 |
39 | private List<Interaction<MyContext>> interactions = new List<Interaction<MyContext>>();
40 | }
41 |
42 | /// <summary>
43 | /// Example of a custom policy which implements IPolicy&lt;MyContext&gt;,
44 | /// declaring that this policy only interacts with MyContext objects.
45 | /// </summary>
46 | class MyPolicy : IPolicy<MyContext>
47 | {
48 | public MyPolicy() : this(-1) { } // -1 is the index used when none is supplied.
49 |
50 | public MyPolicy(int index)
51 | {
52 | this.index = index;
53 | }
54 |
55 | public uint ChooseAction(MyContext context)
56 | {
57 | // Always returns the same action regardless of context
58 | return 5;
59 | }
60 |
61 | private int index; // Stored only; ChooseAction never reads it.
62 | }
63 |
64 | /// <summary>
65 | /// Example of a custom policy which implements IPolicy&lt;SimpleContext&gt;,
66 | /// declaring that this policy only interacts with SimpleContext objects.
67 | /// </summary>
68 | class StringPolicy : IPolicy<SimpleContext>
69 | {
70 | public uint ChooseAction(SimpleContext context)
71 | {
72 | // Always returns the same action regardless of context
73 | return 1;
74 | }
75 | }
76 |
77 | /// <summary>
78 | /// Example of a custom scorer which implements IScorer&lt;MyContext&gt;,
79 | /// declaring that this scorer only interacts with MyContext objects.
80 | /// </summary>
81 | class MyScorer : IScorer<MyContext>
82 | {
83 | public MyScorer(uint numActions)
84 | {
85 | this.numActions = numActions;
86 | }
87 | public List<float> ScoreActions(MyContext context) // Ignores the context: every action gets the same score.
88 | {
89 | return Enumerable.Repeat(1.0f / numActions, (int)numActions).ToList();
90 | }
91 | private uint numActions;
92 | }
93 |
94 | /// <summary>
95 | /// Represents a tuple &lt;context, action, probability, key&gt;.
96 | /// </summary>
97 | /// <typeparam name="Ctx">The Context type.</typeparam>
98 | struct Interaction<Ctx>
99 | {
100 | public Ctx Context;
101 | public uint Action;
102 | public float Probability;
103 | public string UniqueKey;
104 | }
105 |
106 | public static void Run() // Runs the exploration sample selected by the local exploration_type string.
107 | {
108 | string exploration_type = "greedy";
109 |
110 | if (exploration_type == "greedy")
111 | {
112 | // Initialize Epsilon-Greedy explore algorithm using built-in StringRecorder and SimpleContext types
113 | // NOTE(review): generic type arguments in this method were restored from context; confirm against the library API.
114 | // Creates a recorder of built-in StringRecorder type for string serialization
115 | StringRecorder<SimpleContext> recorder = new StringRecorder<SimpleContext>();
116 |
117 | // Creates an MwtExplorer instance using the recorder above
118 | MwtExplorer<SimpleContext> mwtt = new MwtExplorer<SimpleContext>("mwt", recorder);
119 |
120 | // Creates a policy that interacts with SimpleContext type
121 | StringPolicy policy = new StringPolicy();
122 |
123 | uint numActions = 10;
124 | float epsilon = 0.2f;
125 | // Creates an Epsilon-Greedy explorer using the specified settings
126 | EpsilonGreedyExplorer<SimpleContext> explorer = new EpsilonGreedyExplorer<SimpleContext>(policy, epsilon, numActions);
127 |
128 | // Creates a context of built-in SimpleContext type
129 | SimpleContext context = new SimpleContext(new Feature[] {
130 | new Feature() { Id = 1, Value = 0.5f },
131 | new Feature() { Id = 4, Value = 1.3f },
132 | new Feature() { Id = 9, Value = -0.5f },
133 | });
134 |
135 | // Performs exploration by passing an instance of the Epsilon-Greedy exploration algorithm into MwtExplorer
136 | // using a sample string to uniquely identify this event
137 | string uniqueKey = "eventid";
138 | uint action = mwtt.ChooseAction(explorer, uniqueKey, context);
139 |
140 | Console.WriteLine(recorder.GetRecording());
141 |
142 | return;
143 | }
144 | else if (exploration_type == "tau-first")
145 | {
146 | // Initialize Tau-First explore algorithm using custom Recorder, Policy & Context types
147 | MyRecorder recorder = new MyRecorder();
148 | MwtExplorer<MyContext> mwtt = new MwtExplorer<MyContext>("mwt", recorder);
149 |
150 | uint numActions = 10;
151 | uint tau = 0;
152 | MyPolicy policy = new MyPolicy();
153 | uint action = mwtt.ChooseAction(new TauFirstExplorer<MyContext>(policy, tau, numActions), "key", new MyContext());
154 | Console.WriteLine(String.Join(",", recorder.GetAllInteractions().Select(it => it.Action)));
155 | return;
156 | }
157 | else if (exploration_type == "bootstrap")
158 | {
159 | // Initialize Bootstrap explore algorithm using custom Recorder, Policy & Context types
160 | MyRecorder recorder = new MyRecorder();
161 | MwtExplorer<MyContext> mwtt = new MwtExplorer<MyContext>("mwt", recorder);
162 |
163 | uint numActions = 10;
164 | uint numbags = 2;
165 | MyPolicy[] policies = new MyPolicy[numbags];
166 | for (int i = 0; i < numbags; i++)
167 | {
168 | policies[i] = new MyPolicy(i * 2);
169 | }
170 | uint action = mwtt.ChooseAction(new BootstrapExplorer<MyContext>(policies, numActions), "key", new MyContext());
171 | Console.WriteLine(String.Join(",", recorder.GetAllInteractions().Select(it => it.Action)));
172 | return;
173 | }
174 | else if (exploration_type == "softmax")
175 | {
176 | // Initialize Softmax explore algorithm using custom Recorder, Scorer & Context types
177 | MyRecorder recorder = new MyRecorder();
178 | MwtExplorer<MyContext> mwtt = new MwtExplorer<MyContext>("mwt", recorder);
179 |
180 | uint numActions = 10;
181 | float lambda = 0.5f;
182 | MyScorer scorer = new MyScorer(numActions);
183 | uint action = mwtt.ChooseAction(new SoftmaxExplorer<MyContext>(scorer, lambda, numActions), "key", new MyContext());
184 |
185 | Console.WriteLine(String.Join(",", recorder.GetAllInteractions().Select(it => it.Action)));
186 | return;
187 | }
188 | else if (exploration_type == "generic")
189 | {
190 | // Initialize Generic explore algorithm using custom Recorder, Scorer & Context types
191 | MyRecorder recorder = new MyRecorder();
192 | MwtExplorer<MyContext> mwtt = new MwtExplorer<MyContext>("mwt", recorder);
193 |
194 | uint numActions = 10;
195 | MyScorer scorer = new MyScorer(numActions);
196 | uint action = mwtt.ChooseAction(new GenericExplorer<MyContext>(scorer, numActions), "key", new MyContext());
197 |
198 | Console.WriteLine(String.Join(",", recorder.GetAllInteractions().Select(it => it.Action)));
199 | return;
200 | }
201 | else
202 | {
203 | // Fail loudly instead of silently doing nothing for an unrecognized exploration type.
204 | throw new NotSupportedException("Unknown exploration type: " + exploration_type);
205 | }
206 | }
207 | }
208 | }
209 |
--------------------------------------------------------------------------------
/ExploreSample/ExploreSample.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Debug
6 | AnyCPU
7 | {7081D542-AE64-485D-9087-79194B958499}
8 | Exe
9 | Properties
10 | ExploreSample
11 | ExploreSample
12 | v4.5
13 | 512
14 |
15 |
16 | true
17 | bin\x86\Debug\
18 | DEBUG;TRACE
19 | full
20 | x86
21 | prompt
22 | MinimumRecommendedRules.ruleset
23 | true
24 |
25 |
26 | bin\x86\Release\
27 | TRACE
28 | true
29 | pdbonly
30 | x86
31 | prompt
32 | MinimumRecommendedRules.ruleset
33 | true
34 |
35 |
36 | true
37 | bin\x64\Debug\
38 | DEBUG;TRACE
39 | full
40 | x64
41 | prompt
42 | MinimumRecommendedRules.ruleset
43 | true
44 |
45 |
46 | bin\x64\Release\
47 | TRACE
48 | true
49 | pdbonly
50 | x64
51 | prompt
52 | MinimumRecommendedRules.ruleset
53 | true
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 | {8400da16-1f46-4a31-a126-bbe16f62bfd7}
75 | explore_clr
76 |
77 |
78 |
79 |
86 |
--------------------------------------------------------------------------------
/ExploreSample/Program.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using System.Linq;
4 | using System.Text;
5 | using MultiWorldTesting;
6 |
7 | namespace ExploreSample
8 | {
9 | class Program // Console entry point of the ExploreSample project.
10 | {
11 | public static void Main(string[] args) // args is not read.
12 | {
13 | cs_test.ExploreOnlySample.Run(); // Delegates all work to the exploration sample.
14 | }
15 | }
16 | }
17 |
--------------------------------------------------------------------------------
/ExploreSample/Properties/AssemblyInfo.cs:
--------------------------------------------------------------------------------
1 | using System.Reflection;
2 | using System.Runtime.CompilerServices;
3 | using System.Runtime.InteropServices;
4 |
5 | // General Information about an assembly is controlled through the following
6 | // set of attributes. Change these attribute values to modify the information
7 | // associated with an assembly.
8 | [assembly: AssemblyTitle("ExploreSample")]
9 | [assembly: AssemblyDescription("")]
10 | [assembly: AssemblyConfiguration("")]
11 | [assembly: AssemblyCompany("")]
12 | [assembly: AssemblyProduct("ExploreSample")]
13 | [assembly: AssemblyCopyright("Copyright © 2014")]
14 | [assembly: AssemblyTrademark("")]
15 | [assembly: AssemblyCulture("")]
16 |
17 | // Setting ComVisible to false makes the types in this assembly not visible
18 | // to COM components. If you need to access a type in this assembly from
19 | // COM, set the ComVisible attribute to true on that type.
20 | [assembly: ComVisible(false)]
21 |
22 | // The following GUID is for the ID of the typelib if this project is exposed to COM
23 | [assembly: Guid("767d4e7c-6acc-4b46-9eac-e86ab079625a")]
24 |
25 | // Version information for an assembly consists of the following four values:
26 | //
27 | // Major Version
28 | // Minor Version
29 | // Build Number
30 | // Revision
31 | //
32 | // You can specify all the values or you can default the Build and Revision Numbers
33 | // by using the '*' as shown below:
34 | // [assembly: AssemblyVersion("1.0.*")]
35 | [assembly: AssemblyVersion("1.0.0.0")]
36 | [assembly: AssemblyFileVersion("1.0.0.0")]
37 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright © Microsoft Corp 2012-2014, Yahoo! Inc. 2007-2012, and many
2 | individual contributors.
3 |
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 |
8 | modification, are permitted provided that the following conditions are met:
9 |
10 | * Redistributions of source code must retain the above copyright
11 |
12 | notice, this list of conditions and the following disclaimer.
13 |
14 | * Redistributions in binary form must reproduce the above copyright
15 |
16 | notice, this list of conditions and the following disclaimer in the
17 |
18 | documentation and/or other materials provided with the distribution.
19 |
20 | * Neither the name of the Microsoft Corp nor the
21 |
22 | names of its contributors may be used to endorse or promote products
23 |
24 | derived from this software without specific prior written permission.
25 |
26 |
27 |
28 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
29 |
30 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
31 |
32 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
33 |
34 | DISCLAIMED. IN NO EVENT SHALL <COPYRIGHT HOLDER> BE LIABLE FOR ANY
35 |
36 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
37 |
38 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
39 |
40 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
41 |
42 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
43 |
44 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
45 |
46 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
47 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | all: sample
2 |
3 | time: explore.cpp
4 | g++ -Wall -O3 explore.cpp -I static -I ../vowpalwabbit -std=c++0x
5 |
6 | sample: explore_sample.cpp
7 | g++ -Wall -O3 explore_sample.cpp -I static -I ../vowpalwabbit -std=c++0x
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Exploration Library
2 | =======
3 |
4 | [WARNING: This repository is deprecated. See https://github.com/multiworldtesting/explore-csharp for C# and https://github.com/multiworldtesting/explore-cpp for C++ instead]
5 |
6 | The exploration library addresses the ‘gathering the data’ aspect of machine learning rather than the ‘using the data’ aspect we are most familiar with. The primary goal here is to enable individuals (i.e. you) to gather the right data for using machine learning for interventions in a live system based on user feedback (click, dwell, correction, etc…). Empirically, gathering the right data has often made a substantial difference. Theoretically, we know it is required to disentangle causation from correlation effectively in general.
7 |
8 | This is the first version of the client-side exploration library. It includes the following exploration algorithms:
9 | - Epsilon Greedy
10 | - Tau First
11 | - Softmax
12 | - Bootstrap
13 | - Generic (users can specify custom weight for every action)
14 |
15 | This release supports C++ and C#.
16 |
17 | For sample usage, see:
18 |
19 | C++: https://github.com/multiworldtesting/explore/blob/master/explore_sample.cpp
20 |
21 | C#: https://github.com/multiworldtesting/explore/blob/master/ExploreSample/ExploreOnlySample.cs
22 |
--------------------------------------------------------------------------------
/ReadMe.txt:
--------------------------------------------------------------------------------
1 | This is the root of the exploration library and client-side decision service.
2 |
3 | explore_sample.cpp shows how to use the exploration library which is a
4 | header-only include in C++.
5 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/clr/explore_clr.vcxproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Debug
6 | Win32
7 |
8 |
9 | Debug
10 | x64
11 |
12 |
13 | Release
14 | Win32
15 |
16 |
17 | Release
18 | x64
19 |
20 |
21 |
22 | {8400DA16-1F46-4A31-A126-BBE16F62BFD7}
23 | Win32Proj
24 | vw_explore_clr_wrapper
25 | explore_clr
26 |
27 |
28 |
29 | DynamicLibrary
30 | true
31 | v120
32 | Unicode
33 | true
34 |
35 |
36 | DynamicLibrary
37 | true
38 | v120
39 | Unicode
40 | true
41 |
42 |
43 | DynamicLibrary
44 | false
45 | v120
46 | true
47 | Unicode
48 | true
49 |
50 |
51 | DynamicLibrary
52 | false
53 | v120
54 | true
55 | Unicode
56 | true
57 |
58 |
59 | c:\boost\x64\include\boost-1_56
60 | c:\boost\x64\lib
61 | ..\..\..\zlib-1.2.8
62 | $(ZlibIncludeDir)\contrib\vstudio\vc11\x64\ZlibStat$(Configuration)
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 | true
82 |
83 |
84 | true
85 |
86 |
87 | false
88 |
89 |
90 | false
91 |
92 |
93 |
94 | NotUsing
95 | Level3
96 | Disabled
97 | WIN32;_DEBUG;_WINDOWS;_USRDLL;VW_EXPLORE_CLR_WRAPPER_EXPORTS;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions)
98 | true
99 | ..\static;%(AdditionalIncludeDirectories)
100 | true
101 |
102 |
103 | Windows
104 | true
105 |
106 |
107 |
108 |
109 | NotUsing
110 | Level3
111 | Disabled
112 | WIN32;_DEBUG;_WINDOWS;_USRDLL;VW_EXPLORE_CLR_WRAPPER_EXPORTS;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions)
113 | true
114 | ..\static;%(AdditionalIncludeDirectories)
115 | true
116 |
117 |
118 | Windows
119 | true
120 |
121 |
122 |
123 |
124 | Level3
125 | NotUsing
126 | MaxSpeed
127 | true
128 | true
129 | WIN32;NDEBUG;_WINDOWS;_USRDLL;VW_EXPLORE_CLR_WRAPPER_EXPORTS;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions)
130 | true
131 | ..\static;%(AdditionalIncludeDirectories)
132 | true
133 |
134 |
135 | Windows
136 | true
137 | true
138 | true
139 |
140 |
141 |
142 |
143 | Level3
144 | NotUsing
145 | MaxSpeed
146 | true
147 | true
148 | WIN32;NDEBUG;_WINDOWS;_USRDLL;VW_EXPLORE_CLR_WRAPPER_EXPORTS;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions)
149 | true
150 | ..\static;%(AdditionalIncludeDirectories)
151 | true
152 |
153 |
154 | Windows
155 | true
156 | true
157 | true
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 | {ace47e98-488c-4cdf-b9f1-36337b2855ad}
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
--------------------------------------------------------------------------------
/clr/explore_clr.vcxproj.filters:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF}
6 | cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx
7 |
8 |
9 | {93995380-89BD-4b04-88EB-625FBE52EBFB}
10 | h;hpp;hxx;hm;inl;inc;xsd
11 |
12 |
13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01}
14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 | Header Files
23 |
24 |
25 | Header Files
26 |
27 |
28 | Header Files
29 |
30 |
31 |
32 |
33 | Source Files
34 |
35 |
36 |
--------------------------------------------------------------------------------
/clr/explore_clr_wrapper.cpp:
--------------------------------------------------------------------------------
1 | // explore_clr_wrapper.cpp : Defines the exported functions for the DLL application.
2 | //
3 |
4 | #define WIN32_LEAN_AND_MEAN
5 | #include
6 |
7 | #include "explore_clr_wrapper.h"
8 |
9 | using namespace System;
10 | using namespace System::Collections;
11 | using namespace System::Collections::Generic;
12 | using namespace System::Runtime::InteropServices;
13 | using namespace msclr::interop;
14 | using namespace NativeMultiWorldTesting;
15 |
16 | namespace MultiWorldTesting {
17 |
18 | }
19 |
--------------------------------------------------------------------------------
/clr/explore_clr_wrapper.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "explore_interop.h"
3 |
4 | /*!
5 | * \addtogroup MultiWorldTestingCsharp
6 | * @{
7 | */
8 | namespace MultiWorldTesting {
9 |
10 | /// <summary>
11 | /// The epsilon greedy exploration class.
12 | /// </summary>
13 | /// <remarks>
14 | /// This is a good choice if you have no idea which actions should be preferred.
15 | /// Epsilon greedy is also computationally cheap.
16 | /// </remarks>
17 | /// <typeparam name="Ctx">The Context type.</typeparam>
18 | generic // NOTE(review): the type-parameter list (presumably <class Ctx>) is missing from this listing - confirm against the repository.
19 | public ref class EpsilonGreedyExplorer : public IExplorer, public IConsumePolicy, public PolicyCallback
20 | {
21 | public:
22 | /// <summary>
23 | /// Initializes an epsilon greedy explorer with a fixed number of actions; intended to be used with the MwtExplorer.
24 | /// </summary>
25 | /// <param name="defaultPolicy">A default function which outputs an action given a context.</param>
26 | /// <param name="epsilon">The probability of a random exploration.</param>
27 | /// <param name="numActions">The number of actions to randomize over.</param>
28 | EpsilonGreedyExplorer(IPolicy^ defaultPolicy, float epsilon, UInt32 numActions)
29 | {
30 | this->defaultPolicy = defaultPolicy;
31 | m_explorer = new NativeMultiWorldTesting::EpsilonGreedyExplorer(*GetNativePolicy(), epsilon, (u32)numActions);
32 | }
33 |
34 | /// <summary>
35 | /// Initializes an epsilon greedy explorer with variable number of actions. Throws ArgumentException when Ctx does not implement IVariableActionContext.
36 | /// </summary>
37 | /// <param name="defaultPolicy">A default function which outputs an action given a context.</param>
38 | /// <param name="epsilon">The probability of a random exploration.</param>
39 | EpsilonGreedyExplorer(IPolicy^ defaultPolicy, float epsilon)
40 | {
41 | if (!(IVariableActionContext::typeid->IsAssignableFrom(Ctx::typeid)))
42 | {
43 | throw gcnew ArgumentException("The specified context type does not implement variable-action interface.", "Ctx");
44 | }
45 |
46 | this->defaultPolicy = defaultPolicy;
47 | m_explorer = new NativeMultiWorldTesting::EpsilonGreedyExplorer(*GetNativePolicy(), epsilon);
48 | }
49 |
50 | ~EpsilonGreedyExplorer() // Frees the native explorer. NOTE(review): no finalizer is visible in this class - confirm the native object cannot leak when the wrapper is never disposed.
51 | {
52 | delete m_explorer;
53 | }
54 |
55 | virtual void UpdatePolicy(IPolicy^ newPolicy) // Replaces the managed policy that InvokePolicyCallback consults.
56 | {
57 | this->defaultPolicy = newPolicy;
58 | }
59 |
60 | virtual void EnableExplore(bool explore) // Forwards the flag to the native explorer's Enable_Explore.
61 | {
62 | m_explorer->Enable_Explore(explore);
63 | }
64 |
65 | internal:
66 | virtual UInt32 InvokePolicyCallback(Ctx context, int index) override // index is ignored: there is a single default policy.
67 | {
68 | return defaultPolicy->ChooseAction(context);
69 | }
70 |
71 | NativeMultiWorldTesting::EpsilonGreedyExplorer* Get() // Returns the raw native pointer; this object still deletes it in its destructor.
72 | {
73 | return m_explorer;
74 | }
75 |
76 | private:
77 | IPolicy^ defaultPolicy;
78 | NativeMultiWorldTesting::EpsilonGreedyExplorer* m_explorer;
79 | };
80 |
81 | /// <summary>
82 | /// The tau-first exploration class.
83 | /// </summary>
84 | /// <remarks>
85 | /// The tau-first explorer collects precisely tau uniform random
86 | /// exploration events, and then uses the default policy.
87 | /// </remarks>
88 | /// <typeparam name="Ctx">The Context type.</typeparam>
89 | generic // NOTE(review): the type-parameter list (presumably <class Ctx>) is missing from this listing - confirm against the repository.
90 | public ref class TauFirstExplorer : public IExplorer, public IConsumePolicy, public PolicyCallback
91 | {
92 | public:
93 | /// <summary>
94 | /// Initializes a tau-first explorer with a fixed number of actions; intended to be used with the MwtExplorer.
95 | /// </summary>
96 | /// <param name="defaultPolicy">A default policy after randomization finishes.</param>
97 | /// <param name="tau">The number of events to be uniform over.</param>
98 | /// <param name="numActions">The number of actions to randomize over.</param>
99 | TauFirstExplorer(IPolicy^ defaultPolicy, UInt32 tau, UInt32 numActions)
100 | {
101 | this->defaultPolicy = defaultPolicy;
102 | m_explorer = new NativeMultiWorldTesting::TauFirstExplorer(*GetNativePolicy(), tau, (u32)numActions);
103 | }
104 |
105 | /// <summary>
106 | /// Initializes a tau-first explorer with variable number of actions. Throws ArgumentException when Ctx does not implement IVariableActionContext.
107 | /// </summary>
108 | /// <param name="defaultPolicy">A default policy after randomization finishes.</param>
109 | /// <param name="tau">The number of events to be uniform over.</param>
110 | TauFirstExplorer(IPolicy^ defaultPolicy, UInt32 tau)
111 | {
112 | if (!(IVariableActionContext::typeid->IsAssignableFrom(Ctx::typeid)))
113 | {
114 | throw gcnew ArgumentException("The specified context type does not implement variable-action interface.", "Ctx");
115 | }
116 |
117 | this->defaultPolicy = defaultPolicy;
118 | m_explorer = new NativeMultiWorldTesting::TauFirstExplorer(*GetNativePolicy(), tau);
119 | }
120 |
121 | virtual void UpdatePolicy(IPolicy^ newPolicy) // Replaces the managed policy that InvokePolicyCallback consults.
122 | {
123 | this->defaultPolicy = newPolicy;
124 | }
125 |
126 | virtual void EnableExplore(bool explore) // Forwards the flag to the native explorer's Enable_Explore.
127 | {
128 | m_explorer->Enable_Explore(explore);
129 | }
130 |
131 | ~TauFirstExplorer() // Frees the native explorer. NOTE(review): no finalizer is visible in this class - confirm the native object cannot leak when the wrapper is never disposed.
132 | {
133 | delete m_explorer;
134 | }
135 |
136 | internal:
137 | virtual UInt32 InvokePolicyCallback(Ctx context, int index) override // index is ignored: there is a single default policy.
138 | {
139 | return defaultPolicy->ChooseAction(context);
140 | }
141 |
142 | NativeMultiWorldTesting::TauFirstExplorer* Get() // Returns the raw native pointer; this object still deletes it in its destructor.
143 | {
144 | return m_explorer;
145 | }
146 |
147 | private:
148 | IPolicy^ defaultPolicy;
149 | NativeMultiWorldTesting::TauFirstExplorer* m_explorer;
150 | };
151 |
152 | /// <summary>
153 | /// The softmax exploration class.
154 | /// </summary>
155 | /// <remarks>
156 | /// In some cases, different actions have different scores, and you
157 | /// would prefer to choose actions with large scores. Softmax allows
158 | /// you to do that.
159 | /// </remarks>
160 | /// <typeparam name="Ctx">The Context type.</typeparam>
161 | generic // NOTE(review): the type-parameter list (presumably <class Ctx>) is missing from this listing - confirm against the repository.
162 | public ref class SoftmaxExplorer : public IExplorer, public IConsumeScorer, public ScorerCallback
163 | {
164 | public:
165 | /// <summary>
166 | /// Initializes a softmax explorer with a fixed number of actions; intended to be used with the MwtExplorer.
167 | /// </summary>
168 | /// <param name="defaultScorer">A function which outputs a score for each action.</param>
169 | /// <param name="lambda">lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.</param>
170 | /// <param name="numActions">The number of actions to randomize over.</param>
171 | SoftmaxExplorer(IScorer^ defaultScorer, float lambda, UInt32 numActions)
172 | {
173 | this->defaultScorer = defaultScorer;
174 | m_explorer = new NativeMultiWorldTesting::SoftmaxExplorer(*GetNativeScorer(), lambda, (u32)numActions);
175 | }
176 |
177 | /// <summary>
178 | /// Initializes a softmax explorer with variable number of actions. Throws ArgumentException when Ctx does not implement IVariableActionContext.
179 | /// </summary>
180 | /// <param name="defaultScorer">A function which outputs a score for each action.</param>
181 | /// <param name="lambda">lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.</param>
182 | SoftmaxExplorer(IScorer^ defaultScorer, float lambda)
183 | {
184 | if (!(IVariableActionContext::typeid->IsAssignableFrom(Ctx::typeid)))
185 | {
186 | throw gcnew ArgumentException("The specified context type does not implement variable-action interface.", "Ctx");
187 | }
188 |
189 | this->defaultScorer = defaultScorer;
190 | m_explorer = new NativeMultiWorldTesting::SoftmaxExplorer(*GetNativeScorer(), lambda);
191 | }
192 |
193 | virtual void UpdateScorer(IScorer^ newScorer) // Replaces the managed scorer that InvokeScorerCallback consults.
194 | {
195 | this->defaultScorer = newScorer;
196 | }
197 |
198 | virtual void EnableExplore(bool explore) // Forwards the flag to the native explorer's Enable_Explore.
199 | {
200 | m_explorer->Enable_Explore(explore);
201 | }
202 |
203 | ~SoftmaxExplorer() // Frees the native explorer. NOTE(review): no finalizer is visible in this class - confirm the native object cannot leak when the wrapper is never disposed.
204 | {
205 | delete m_explorer;
206 | }
207 |
208 | internal:
209 | virtual List^ InvokeScorerCallback(Ctx context) override // Delegates scoring to the current managed scorer.
210 | {
211 | return defaultScorer->ScoreActions(context);
212 | }
213 |
214 | NativeMultiWorldTesting::SoftmaxExplorer* Get() // Returns the raw native pointer; this object still deletes it in its destructor.
215 | {
216 | return m_explorer;
217 | }
218 |
219 | private:
220 | IScorer^ defaultScorer;
221 | NativeMultiWorldTesting::SoftmaxExplorer* m_explorer;
222 | };
223 |
224 | /// <summary>
225 | /// The generic exploration class.
226 | /// </summary>
227 | /// <remarks>
228 | /// GenericExplorer provides complete flexibility. You can create any
229 | /// distribution over actions desired, and it will draw from that.
230 | /// </remarks>
231 | /// <typeparam name="Ctx">The Context type.</typeparam>
232 | generic // NOTE(review): the type-parameter list (presumably <class Ctx>) is missing from this listing - confirm against the repository.
233 | public ref class GenericExplorer : public IExplorer, public IConsumeScorer, public ScorerCallback
234 | {
235 | public:
236 | /// <summary>
237 | /// Initializes a generic explorer with a fixed number of actions; intended to be used with the MwtExplorer.
238 | /// </summary>
239 | /// <param name="defaultScorer">A function which outputs the probability of each action.</param>
240 | /// <param name="numActions">The number of actions to randomize over.</param>
241 | GenericExplorer(IScorer^ defaultScorer, UInt32 numActions)
242 | {
243 | this->defaultScorer = defaultScorer;
244 | m_explorer = new NativeMultiWorldTesting::GenericExplorer(*GetNativeScorer(), (u32)numActions);
245 | }
246 |
247 | /// <summary>
248 | /// Initializes a generic explorer with variable number of actions. Throws ArgumentException when Ctx does not implement IVariableActionContext.
249 | /// </summary>
250 | /// <param name="defaultScorer">A function which outputs the probability of each action.</param>
251 | GenericExplorer(IScorer^ defaultScorer)
252 | {
253 | if (!(IVariableActionContext::typeid->IsAssignableFrom(Ctx::typeid)))
254 | {
255 | throw gcnew ArgumentException("The specified context type does not implement variable-action interface.", "Ctx");
256 | }
257 |
258 | this->defaultScorer = defaultScorer;
259 | m_explorer = new NativeMultiWorldTesting::GenericExplorer(*GetNativeScorer());
260 | }
261 |
262 | virtual void UpdateScorer(IScorer^ newScorer) // Replaces the managed scorer that InvokeScorerCallback consults.
263 | {
264 | this->defaultScorer = newScorer;
265 | }
266 |
267 | virtual void EnableExplore(bool explore) // Forwards the flag to the native explorer's Enable_Explore.
268 | {
269 | m_explorer->Enable_Explore(explore);
270 | }
271 |
272 | ~GenericExplorer() // Frees the native explorer. NOTE(review): no finalizer is visible in this class - confirm the native object cannot leak when the wrapper is never disposed.
273 | {
274 | delete m_explorer;
275 | }
276 |
277 | internal:
278 | virtual List^ InvokeScorerCallback(Ctx context) override // Delegates scoring to the current managed scorer.
279 | {
280 | return defaultScorer->ScoreActions(context);
281 | }
282 |
283 | NativeMultiWorldTesting::GenericExplorer* Get() // Returns the raw native pointer; this object still deletes it in its destructor.
284 | {
285 | return m_explorer;
286 | }
287 |
288 | private:
289 | IScorer^ defaultScorer;
290 | NativeMultiWorldTesting::GenericExplorer* m_explorer;
291 | };
292 |
293 | /// <summary>
294 | /// The bootstrap exploration class.
295 | /// </summary>
296 | /// <remarks>
297 | /// The Bootstrap explorer randomizes over the actions chosen by a set of
298 | /// default policies. This performs well statistically but can be
299 | /// computationally expensive.
300 | /// </remarks>
301 | /// <typeparam name="Ctx">The Context type.</typeparam>
302 | generic
303 | public ref class BootstrapExplorer : public IExplorer, public IConsumePolicies, public PolicyCallback
304 | {
305 | public:
306 | /// <summary>
307 | /// Initializes a bootstrap explorer with a fixed number of actions. Throws ArgumentNullException when defaultPolicies is null.
308 | /// </summary>
309 | /// <param name="defaultPolicies">A set of default policies to be uniform random over.</param>
310 | /// <param name="numActions">The number of actions to randomize over.</param>
311 | BootstrapExplorer(cli::array^>^ defaultPolicies, UInt32 numActions)
312 | {
313 | this->defaultPolicies = defaultPolicies;
314 | if (this->defaultPolicies == nullptr)
315 | {
316 | throw gcnew ArgumentNullException("defaultPolicies", "The specified array of default policy functions cannot be null."); // (paramName, message) overload
317 | }
318 |
319 | m_explorer = new NativeMultiWorldTesting::BootstrapExplorer(*GetNativePolicies((u32)defaultPolicies->Length), (u32)numActions);
320 | }
321 |
322 | /// <summary>
323 | /// Initializes a bootstrap explorer with variable number of actions. Throws ArgumentException when Ctx lacks the variable-action interface, ArgumentNullException when defaultPolicies is null.
324 | /// </summary>
325 | /// <param name="defaultPolicies">A set of default policies to be uniform random over.</param>
326 | BootstrapExplorer(cli::array^>^ defaultPolicies)
327 | {
328 | if (!(IVariableActionContext::typeid->IsAssignableFrom(Ctx::typeid)))
329 | {
330 | throw gcnew ArgumentException("The specified context type does not implement variable-action interface.", "Ctx");
331 | }
332 |
333 | this->defaultPolicies = defaultPolicies;
334 | if (this->defaultPolicies == nullptr)
335 | {
336 | throw gcnew ArgumentNullException("defaultPolicies", "The specified array of default policy functions cannot be null."); // (paramName, message) overload
337 | }
338 |
339 | m_explorer = new NativeMultiWorldTesting::BootstrapExplorer(*GetNativePolicies((u32)defaultPolicies->Length));
340 | }
341 |
342 | virtual void UpdatePolicy(cli::array^>^ newPolicies) // Replaces the managed policy array; the native policy list built in the constructor is not rebuilt here.
343 | {
344 | this->defaultPolicies = newPolicies;
345 | }
346 |
347 | virtual void EnableExplore(bool explore) // Forwards the flag to the native explorer's Enable_Explore.
348 | {
349 | m_explorer->Enable_Explore(explore);
350 | }
351 |
352 | ~BootstrapExplorer() // Frees the native explorer allocated by the constructor.
353 | {
354 | delete m_explorer;
355 | }
356 |
357 | internal:
358 | virtual UInt32 InvokePolicyCallback(Ctx context, int index) override // Called back from native code with the index of the policy to evaluate.
359 | {
360 | if (index < 0 || index >= defaultPolicies->Length)
361 | {
362 | throw gcnew InvalidDataException("Internal error: Index of interop bag is out of range.");
363 | }
364 | return defaultPolicies[index]->ChooseAction(context);
365 | }
366 |
367 | NativeMultiWorldTesting::BootstrapExplorer* Get() // Exposes the native explorer; MwtExplorer::ChooseAction dereferences it.
368 | {
369 | return m_explorer;
370 | }
371 |
372 | private:
373 | cli::array^>^ defaultPolicies;
374 | NativeMultiWorldTesting::BootstrapExplorer* m_explorer;
375 | };
376 |
377 | ///
378 | /// The top level MwtExplorer class. Using this makes sure that the
379 | /// right bits are recorded and good random actions are chosen.
380 | ///
381 | /// The Context type.
382 | generic
383 | public ref class MwtExplorer : public RecorderCallback
384 | {
385 | public:
386 | /// <summary>
387 | /// Constructor.
388 | /// </summary>
389 | /// <param name="appId">This should be unique to each experiment to avoid correlation bugs.</param>
390 | /// <param name="recorder">A user-specified class for recording the appropriate bits for use in evaluation and learning.</param>
391 | MwtExplorer(String^ appId, IRecorder^ recorder)
392 | {
393 | this->appId = appId; // passed to the native MwtExplorer as the salt on every ChooseAction call
394 | this->recorder = recorder; // receives each decision through InvokeRecorderCallback
395 | }
396 |
397 | /// <summary>
398 | /// Choose_Action should be drop-in replacement for any existing policy function. Throws ArgumentException when the explorer is not one of the explorer types provided by this library.
399 | /// </summary>
400 | /// <param name="explorer">An existing exploration algorithm (one of the above) which uses the default policy as a callback.</param>
401 | /// <param name="unique_key">A unique identifier for the experimental unit. This could be a user id, a session id, etc...</param>
402 | /// <param name="context">The context upon which a decision is made. See SimpleContext for an example.</param>
403 | /// <returns>An unsigned 32-bit integer representing the 1-based chosen action.</returns>
404 | UInt32 ChooseAction(IExplorer^ explorer, String^ unique_key, Ctx context)
405 | {
406 | String^ salt = this->appId;
407 | NativeMultiWorldTesting::MwtExplorer mwt(marshal_as(salt), *GetNativeRecorder());
408 |
409 | // Normal handles are sufficient here since native code will only hold references and not access the object's data
410 | // https://www.microsoftpressstore.com/articles/article.aspx?p=2224054&seqNum=4
411 | GCHandle selfHandle = GCHandle::Alloc(this);
412 | IntPtr selfPtr = (IntPtr)selfHandle;
413 | GCHandle contextHandle = GCHandle::Alloc(context);
414 | IntPtr contextPtr = (IntPtr)contextHandle;
415 | GCHandle explorerHandle = GCHandle::Alloc(explorer);
416 | IntPtr explorerPtr = (IntPtr)explorerHandle;
417 | try
418 | {
419 | NativeContext native_context(selfPtr.ToPointer(), explorerPtr.ToPointer(), contextPtr.ToPointer(), this->GetNumActionsCallback());
420 | u32 action = 0;
421 | if (explorer->GetType() == EpsilonGreedyExplorer::typeid)
422 | {
423 | EpsilonGreedyExplorer^ epsilonGreedyExplorer = (EpsilonGreedyExplorer^)explorer;
424 | action = mwt.Choose_Action(*epsilonGreedyExplorer->Get(), marshal_as(unique_key), native_context);
425 | }
426 | else if (explorer->GetType() == TauFirstExplorer::typeid)
427 | {
428 | TauFirstExplorer^ tauFirstExplorer = (TauFirstExplorer^)explorer;
429 | action = mwt.Choose_Action(*tauFirstExplorer->Get(), marshal_as(unique_key), native_context);
430 | }
431 | else if (explorer->GetType() == SoftmaxExplorer::typeid)
432 | {
433 | SoftmaxExplorer^ softmaxExplorer = (SoftmaxExplorer^)explorer;
434 | action = mwt.Choose_Action(*softmaxExplorer->Get(), marshal_as(unique_key), native_context);
435 | }
436 | else if (explorer->GetType() == GenericExplorer::typeid)
437 | {
438 | GenericExplorer^ genericExplorer = (GenericExplorer^)explorer;
439 | action = mwt.Choose_Action(*genericExplorer->Get(), marshal_as(unique_key), native_context);
440 | }
441 | else if (explorer->GetType() == BootstrapExplorer::typeid)
442 | {
443 | BootstrapExplorer^ bootstrapExplorer = (BootstrapExplorer^)explorer;
444 | action = mwt.Choose_Action(*bootstrapExplorer->Get(), marshal_as(unique_key), native_context);
445 | }
446 | else // an unrecognised IExplorer used to fall through and return 0, which is not a valid 1-based action
447 | {
448 | throw gcnew ArgumentException("The specified explorer type is not supported.", "explorer");
449 | }
450 | return action;
451 | }
452 | finally
453 | {
454 | if (explorerHandle.IsAllocated)
455 | {
456 | explorerHandle.Free();
457 | }
458 | if (contextHandle.IsAllocated)
459 | {
460 | contextHandle.Free();
461 | }
462 | if (selfHandle.IsAllocated)
463 | {
464 | selfHandle.Free();
465 | }
466 | }
467 | }
468 |
469 | internal:
470 | virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) override // Forwards a decision to the recorder supplied at construction.
471 | {
472 | recorder->Record(context, action, probability, unique_key);
473 | }
474 |
475 | private:
476 | IRecorder^ recorder;
477 | String^ appId;
478 | };
479 |
480 | /// <summary>
481 | /// Represents a feature in a sparse array.
482 | /// </summary>
483 | [StructLayout(LayoutKind::Sequential)]
484 | public value struct Feature
485 | {
486 | float Value; // copied first when SimpleContext builds its native feature list
487 | UInt32 Id; // copied second when SimpleContext builds its native feature list
488 | };
489 |
490 | /// <summary>
491 | /// A sample recorder class that converts the exploration tuple into string format.
492 | /// </summary>
493 | /// <typeparam name="Ctx">The Context type.</typeparam>
494 | generic where Ctx : IStringContext
495 | public ref class StringRecorder : public IRecorder, public ToStringCallback
496 | {
497 | public:
498 | StringRecorder()
499 | {
500 | m_string_recorder = new NativeMultiWorldTesting::StringRecorder();
501 | }
502 |
503 | ~StringRecorder() // Frees the native recorder allocated by the constructor.
504 | {
505 | delete m_string_recorder;
506 | }
507 |
508 | virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey) // Appends one exploration tuple to the native recorder.
509 | {
510 | // Normal handles are sufficient here since native code will only hold references and not access the object's data
511 | // https://www.microsoftpressstore.com/articles/article.aspx?p=2224054&seqNum=4
512 | GCHandle contextHandle = GCHandle::Alloc(context);
513 | try
514 | {
515 | IntPtr contextPtr = (IntPtr)contextHandle;
516 | NativeStringContext native_context(contextPtr.ToPointer(), GetCallback());
517 | m_string_recorder->Record(native_context, (u32)action, probability, marshal_as(uniqueKey));
518 | }
519 | finally // release the handle even if marshaling or the native recorder throws
520 | {
521 | contextHandle.Free();
522 | }
523 | }
524 |
525 | /// <summary>
526 | /// Gets the content of the recording so far as a string and clears internal content.
527 | /// </summary>
528 | /// <returns>A string with recording content.</returns>
529 | String^ GetRecording()
530 | {
531 | // Workaround for C++-CLI bug which does not allow default value for parameter
532 | return GetRecording(true);
533 | }
534 |
535 | /// <summary>
536 | /// Gets the content of the recording so far as a string and optionally clears internal content.
537 | /// </summary>
538 | /// <param name="flush">A boolean value indicating whether to clear the internal content.</param>
539 | /// <returns>A string with recording content.</returns>
540 | String^ GetRecording(bool flush)
541 | {
542 | return gcnew String(m_string_recorder->Get_Recording(flush).c_str());
543 | }
544 |
545 | private:
546 | NativeMultiWorldTesting::StringRecorder* m_string_recorder;
547 | };
548 |
549 | /// <summary>
550 | /// A sample context class that stores a vector of Features.
551 | /// </summary>
552 | public ref class SimpleContext : public IStringContext
553 | {
554 | public:
555 | SimpleContext(cli::array^ features) // Keeps the managed array and builds a native copy for the native SimpleContext.
556 | {
557 | Features = features;
558 | // TODO: add another constructor overload for native SimpleContext to avoid copying feature values
559 | m_features = new vector();
560 | for (int i = 0; i < features->Length; i++)
561 | {
562 | m_features->push_back({ features[i].Value, features[i].Id });
563 | }
564 |
565 | m_native_context = new NativeMultiWorldTesting::SimpleContext(*m_features);
566 | }
567 |
568 | String^ ToString() override // Returns the native context's string form.
569 | {
570 | return gcnew String(m_native_context->To_String().c_str());
571 | }
572 |
573 | ~SimpleContext()
574 | {
575 | delete m_native_context; // destroy first: it was constructed from *m_features
576 | delete m_features; // previously leaked: the vector is heap-allocated in the constructor
577 | }
578 |
579 | public:
580 | cli::array^ GetFeatures() { return Features; }
581 |
582 | internal:
583 | cli::array^ Features;
584 |
585 | private:
586 | vector* m_features;
587 | NativeMultiWorldTesting::SimpleContext* m_native_context;
588 | };
589 | }
590 |
591 | /*! @} End of Doxygen Groups*/
592 |
--------------------------------------------------------------------------------
/clr/explore_interface.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | using namespace System;
4 | using namespace System::Collections::Generic;
5 |
6 | /** \defgroup MultiWorldTestingCsharp
7 | \brief C# implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs
8 | */
9 |
10 | /*!
11 | * \addtogroup MultiWorldTestingCsharp
12 | * @{
13 | */
14 |
15 | //! Interface for C# version of Multiworld Testing library.
16 | //! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs
17 | namespace MultiWorldTesting {
18 |
19 | /// <summary>
20 | /// Represents a recorder that exposes a method to record exploration data based on generic contexts.
21 | /// </summary>
22 | /// <typeparam name="Ctx">The Context type.</typeparam>
23 | /// <remarks>
24 | /// Exploration data is specified as a set of tuples (context, action, probability, key) as described below. An
25 | /// application passes an IRecorder object to the @MwtExplorer constructor. See
26 | /// @StringRecorder for a sample IRecorder object.
27 | /// </remarks>
28 | generic
29 | public interface class IRecorder
30 | {
31 | public:
32 | /// <summary>
33 | /// Records the exploration data associated with a given decision.
34 | /// This implementation should be thread-safe if multithreading is needed.
35 | /// </summary>
36 | /// <param name="context">A user-defined context for the decision.</param>
37 | /// <param name="action">Chosen by an exploration algorithm given context.</param>
38 | /// <param name="probability">The probability of the chosen action given context.</param>
39 | /// <param name="uniqueKey">A user-defined identifier for the decision.</param>
40 | virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey) = 0;
41 | };
42 |
43 | /// <summary>
44 | /// Exposes a method for choosing an action given a generic context. IPolicy objects are
45 | /// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
46 | /// </summary>
47 | /// <typeparam name="Ctx">The Context type.</typeparam>
48 | generic
49 | public interface class IPolicy
50 | {
51 | public:
52 | /// <summary>
53 | /// Determines the action to take for a given context.
54 | /// This implementation should be thread-safe if multithreading is needed.
55 | /// </summary>
56 | /// <param name="context">A user-defined context for the decision.</param>
57 | /// <returns>Index of the action to take (1-based)</returns>
58 | virtual UInt32 ChooseAction(Ctx context) = 0;
59 | };
60 |
61 | /// <summary>
62 | /// Exposes a method for specifying a score (weight) for each action given a generic context.
63 | /// </summary>
64 | /// <typeparam name="Ctx">The Context type.</typeparam>
65 | generic
66 | public interface class IScorer
67 | {
68 | public:
69 | /// <summary>
70 | /// Determines the score of each action for a given context.
71 | /// This implementation should be thread-safe if multithreading is needed.
72 | /// </summary>
73 | /// <param name="context">A user-defined context for the decision.</param>
74 | /// <returns>Vector of scores indexed by action (1-based).</returns>
75 | virtual List^ ScoreActions(Ctx context) = 0;
76 | };
77 |
78 | /// <summary>
79 | /// Represents a context interface with variable number of actions which is
80 | /// enforced if exploration algorithm is initialized in variable number of actions mode.
81 | /// </summary>
82 | public interface class IVariableActionContext
83 | {
84 | public:
85 | /// <summary>
86 | /// Gets the number of actions for the current context.
87 | /// </summary>
88 | /// <returns>The number of actions available for the current context.</returns>
89 | virtual UInt32 GetNumberOfActions() = 0;
90 | };
91 |
92 | generic
93 | public interface class IExplorer // Common interface of the explorer wrappers accepted by MwtExplorer::ChooseAction.
94 | {
95 | public:
96 | virtual void EnableExplore(bool explore) = 0; // The explorer wrappers in this library forward the flag to the native explorer's Enable_Explore.
97 | };
98 |
99 | generic
100 | public interface class IConsumePolicy // Implemented by components that take a single policy.
101 | {
102 | public:
103 | virtual void UpdatePolicy(IPolicy^ newPolicy) = 0; // Presumably swaps the implementer's default policy for newPolicy - TODO confirm against implementers.
104 | };
105 |
106 | generic
107 | public interface class IConsumePolicies // Implemented by components that take an array of policies (e.g. BootstrapExplorer).
108 | {
109 | public:
110 | virtual void UpdatePolicy(cli::array^>^ newPolicies) = 0; // BootstrapExplorer replaces its stored policy array with newPolicies.
111 | };
112 |
113 | generic
114 | public interface class IConsumeScorer // Implemented by components that take a scorer (e.g. GenericExplorer).
115 | {
116 | public:
117 | virtual void UpdateScorer(IScorer^ newScorer) = 0; // GenericExplorer replaces its stored scorer with newScorer.
118 | };
119 |
120 | public interface class IStringContext // Generic constraint on Ctx for StringRecorder and ToStringCallback.
121 | {
122 | public:
123 | virtual String^ ToString() = 0; // ToStringCallback marshals this result into a native string.
124 | };
125 |
126 | }
127 |
128 | /*! @} End of Doxygen Groups*/
--------------------------------------------------------------------------------
/clr/explore_interop.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #define MANAGED_CODE
4 |
5 | #include "explore_interface.h"
6 | #include "MWTExplorer.h"
7 |
8 | #include
9 |
10 | using namespace System;
11 | using namespace System::Collections::Generic;
12 | using namespace System::IO;
13 | using namespace System::Runtime::InteropServices;
14 | using namespace System::Xml::Serialization;
15 | using namespace msclr::interop;
16 |
17 | namespace MultiWorldTesting {
18 |
19 | // Context callback: managed delegate and the matching native function signature for querying a context's number of actions
20 | private delegate UInt32 ClrContextGetNumActionsCallback(IntPtr contextPtr);
21 | typedef u32 Native_Context_Get_Num_Actions_Callback(void* context);
22 |
23 | // Policy callback: index selects a policy when several are registered (NativePolicy defaults it to -1)
24 | private delegate UInt32 ClrPolicyCallback(IntPtr explorerPtr, IntPtr contextPtr, int index);
25 | typedef u32 Native_Policy_Callback(void* explorer, void* context, int index);
26 |
27 | // Scorer callback: scores and size are out-parameters filled in by the managed side
28 | private delegate void ClrScorerCallback(IntPtr explorerPtr, IntPtr contextPtr, IntPtr scores, IntPtr size);
29 | typedef void Native_Scorer_Callback(void* explorer, void* context, float* scores[], u32* size);
30 |
31 | // Recorder callback: uniqueKey is a GCHandle to a managed String, passed as an opaque pointer
32 | private delegate void ClrRecorderCallback(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKey);
33 | typedef void Native_Recorder_Callback(void* mwt, void* context, u32 action, float probability, void* unique_key);
34 |
35 | // ToString callback. NOTE(review): the typedef names its first parameter 'explorer', but NativeStringContext passes the context handle there.
36 | private delegate void ClrToStringCallback(IntPtr contextPtr, IntPtr stringValue);
37 | typedef void Native_To_String_Callback(void* explorer, void* string_value);
38 |
39 | // NativeContext crosses the interop boundary carrying opaque pointers to the managed Mwt, Explorer and Context
40 | // objects, so that the Policy, Scorer and Recorder callbacks can be routed back to them.
41 | class NativeContext : public NativeMultiWorldTesting::IVariableActionContext
42 | {
43 | public:
44 | NativeContext(void* clr_mwt, void* clr_explorer, void* clr_context,
45 | Native_Context_Get_Num_Actions_Callback* callback_num_actions)
46 | : m_clr_mwt(clr_mwt),
47 | m_clr_context(clr_context),
48 | m_clr_explorer(clr_explorer),
49 | m_callback_num_actions(callback_num_actions)
50 | {
51 | // All state is captured by the initializer list above.
52 | }
53 |
54 | u32 Get_Number_Of_Actions()
55 | {
56 | return m_callback_num_actions(m_clr_context);
57 | }
58 |
59 | void* Get_Clr_Mwt()
60 | {
61 | return m_clr_mwt;
62 | }
63 |
64 | void* Get_Clr_Explorer()
65 | {
66 | return m_clr_explorer;
67 | }
68 |
69 | void* Get_Clr_Context()
70 | {
71 | return m_clr_context;
72 | }
73 |
74 | private:
75 | // Opaque pointers to the managed objects; never dereferenced by this class.
76 | void* m_clr_mwt;
77 | void* m_clr_context;
78 | void* m_clr_explorer;
79 | // Invoked with m_clr_context to obtain the action count.
80 | Native_Context_Get_Num_Actions_Callback* m_callback_num_actions;
81 | };
82 |
83 | class NativeStringContext
84 | {
85 | public:
86 | NativeStringContext(void* clr_context, Native_To_String_Callback* func)
87 | : m_clr_context(clr_context), m_func(func)
88 | {
89 | }
90 | // Asks the callback to write the context's string form into a native string and returns it.
91 | string To_String()
92 | {
93 | string rendered;
94 | m_func(m_clr_context, &rendered);
95 | return rendered;
96 | }
97 | private:
98 | void* m_clr_context;
99 | Native_To_String_Callback* const m_func;
100 | };
101 |
102 | // NativeRecorder receives Record calls from native code and reroutes them, through the callback, to the managed Recorder instance
103 | class NativeRecorder : public NativeMultiWorldTesting::IRecorder
104 | {
105 | public:
106 | NativeRecorder(Native_Recorder_Callback* native_func) : m_func(native_func)
107 | {
108 | }
109 |
110 | void Record(NativeContext& context, u32 action, float probability, string unique_key)
111 | {
112 | // Normal handles are sufficient here since native code will only hold references and not access the object's data
113 | // https://www.microsoftpressstore.com/articles/article.aspx?p=2224054&seqNum=4
114 | GCHandle uniqueKeyHandle = GCHandle::Alloc(gcnew String(unique_key.c_str())); // the key is copied into a managed String for the callback
115 | try
116 | {
117 | IntPtr uniqueKeyPtr = (IntPtr)uniqueKeyHandle;
118 |
119 | m_func(context.Get_Clr_Mwt(), context.Get_Clr_Context(), action, probability, uniqueKeyPtr.ToPointer());
120 | }
121 | finally // the handle is released whether or not the callback throws
122 | {
123 | if (uniqueKeyHandle.IsAllocated)
124 | {
125 | uniqueKeyHandle.Free();
126 | }
127 | }
128 | }
129 | private:
130 | Native_Recorder_Callback* const m_func;
131 | };
132 |
133 | // NativePolicy receives Choose_Action calls from native code and reroutes them, through the callback, to the managed Policy instance
134 | class NativePolicy : public NativeMultiWorldTesting::IPolicy
135 | {
136 | public:
137 | NativePolicy(Native_Policy_Callback* func, int index = -1) : m_func(func) // index defaults to -1 when no per-policy index is supplied
138 | {
139 | m_index = index;
140 | }
141 |
142 | u32 Choose_Action(NativeContext& context)
143 | {
144 | return m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), m_index);
145 | }
146 |
147 | private:
148 | Native_Policy_Callback* const m_func;
149 | int m_index; // forwarded unchanged to the callback on every call
150 | };
151 |
152 | class NativeScorer : public NativeMultiWorldTesting::IScorer // Reroutes native Score_Actions calls to the managed scorer through the callback.
153 | {
154 | public:
155 | NativeScorer(Native_Scorer_Callback* func) : m_func(func)
156 | {
157 | }
158 |
159 | vector Score_Actions(NativeContext& context)
160 | {
161 | float* scores = nullptr; // stays null when the callback produces no scores
162 | u32 num_scores = 0;
163 | try
164 | {
165 | m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), &scores, &num_scores);
166 |
167 | // It's ok if scores is null, vector will be empty
168 | vector scores_vector(scores, scores + num_scores);
169 |
170 | return scores_vector;
171 | }
172 | finally
173 | {
174 | delete[] scores; // the buffer is allocated with new[] by ScorerCallback::InteropInvoke
175 | }
176 | }
177 | private:
178 | Native_Scorer_Callback* const m_func;
179 | };
180 |
181 | // Triggers callback to the Context instance to obtain its number of actions
182 | generic
183 | public ref class ContextCallback
184 | {
185 | internal:
186 | ContextCallback() // Wraps InteropInvokeNumActions in a delegate and caches its native function pointer.
187 | {
188 | contextNumActionsCallback = gcnew ClrContextGetNumActionsCallback(&ContextCallback::InteropInvokeNumActions);
189 | IntPtr contextNumActionsCallbackPtr = Marshal::GetFunctionPointerForDelegate(contextNumActionsCallback);
190 | m_num_actions_callback = static_cast(contextNumActionsCallbackPtr.ToPointer());
191 | }
192 |
193 | Native_Context_Get_Num_Actions_Callback* GetNumActionsCallback()
194 | {
195 | return m_num_actions_callback;
196 | }
197 |
198 | static UInt32 InteropInvokeNumActions(IntPtr contextPtr) // contextPtr is a GCHandle converted to IntPtr
199 | {
200 | GCHandle contextHandle = (GCHandle)contextPtr;
201 |
202 | return ((IVariableActionContext^)contextHandle.Target)->GetNumberOfActions(); // the cast requires the context to implement IVariableActionContext
203 | }
204 |
205 | private:
206 | initonly ClrContextGetNumActionsCallback^ contextNumActionsCallback; // presumably held so the delegate behind the function pointer stays alive - confirm
207 |
208 | private:
209 | Native_Context_Get_Num_Actions_Callback* m_num_actions_callback;
210 | };
211 |
212 | // Triggers callback to the Policy instance to choose an action
213 | generic
214 | public ref class PolicyCallback abstract
215 | {
216 | internal:
217 | virtual UInt32 InvokePolicyCallback(Ctx context, int index) = 0; // implemented by the explorer wrapper; index is the value stored in the calling NativePolicy
218 |
219 | PolicyCallback() // Wraps InteropInvoke in a delegate and caches its native function pointer; native policies are created lazily.
220 | {
221 | policyCallback = gcnew ClrPolicyCallback(&PolicyCallback::InteropInvoke);
222 | IntPtr policyCallbackPtr = Marshal::GetFunctionPointerForDelegate(policyCallback);
223 | m_callback = static_cast(policyCallbackPtr.ToPointer());
224 | m_native_policy = nullptr;
225 | m_native_policies = nullptr;
226 | }
227 |
228 | ~PolicyCallback() // deleting a null pointer is a no-op, so unused members are safe here
229 | {
230 | delete m_native_policy;
231 | delete m_native_policies;
232 | }
233 |
234 | NativePolicy* GetNativePolicy() // Creates the single native policy on first use and returns the same instance afterwards.
235 | {
236 | if (m_native_policy == nullptr)
237 | {
238 | m_native_policy = new NativePolicy(m_callback);
239 | }
240 | return m_native_policy;
241 | }
242 |
243 | vector>>* GetNativePolicies(int count) // Builds count native policies on first use; later calls return that list and ignore count.
244 | {
245 | if (m_native_policies == nullptr)
246 | {
247 | m_native_policies = new vector>>();
248 | for (int i = 0; i < count; i++)
249 | {
250 | m_native_policies->push_back(unique_ptr>(new NativePolicy(m_callback, i))); // each policy carries its own index
251 | }
252 | }
253 |
254 | return m_native_policies;
255 | }
256 |
257 | static UInt32 InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, int index) // both pointers are GCHandles converted to IntPtr
258 | {
259 | GCHandle callbackHandle = (GCHandle)callbackPtr;
260 | PolicyCallback^ callback = (PolicyCallback^)callbackHandle.Target;
261 |
262 | GCHandle contextHandle = (GCHandle)contextPtr;
263 | Ctx context = (Ctx)contextHandle.Target;
264 |
265 | return callback->InvokePolicyCallback(context, index);
266 | }
267 |
268 | private:
269 | initonly ClrPolicyCallback^ policyCallback; // presumably held so the delegate behind m_callback stays alive - confirm
270 |
271 | private:
272 | NativePolicy* m_native_policy;
273 | vector>>* m_native_policies;
274 | Native_Policy_Callback* m_callback;
275 | };
276 |
277 | // Triggers callback to the Recorder instance to record interaction data
278 | generic
279 | public ref class RecorderCallback abstract : public ContextCallback
280 | {
281 | internal:
282 | virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) = 0; // implemented by MwtExplorer
283 |
284 | RecorderCallback() // Wraps InteropInvoke in a delegate and creates the native recorder that calls it.
285 | {
286 | recorderCallback = gcnew ClrRecorderCallback(&RecorderCallback::InteropInvoke);
287 | IntPtr recorderCallbackPtr = Marshal::GetFunctionPointerForDelegate(recorderCallback);
288 | Native_Recorder_Callback* callback = static_cast(recorderCallbackPtr.ToPointer());
289 | m_native_recorder = new NativeRecorder(callback);
290 | }
291 |
292 | ~RecorderCallback() // Frees the native recorder created by the constructor.
293 | {
294 | delete m_native_recorder;
295 | }
296 |
297 | NativeRecorder* GetNativeRecorder()
298 | {
299 | return m_native_recorder;
300 | }
301 |
302 | static void InteropInvoke(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKeyPtr) // the three pointers are GCHandles converted to IntPtr
303 | {
304 | GCHandle mwtHandle = (GCHandle)mwtPtr;
305 | RecorderCallback^ callback = (RecorderCallback^)mwtHandle.Target;
306 |
307 | GCHandle contextHandle = (GCHandle)contextPtr;
308 | Ctx context = (Ctx)contextHandle.Target;
309 |
310 | GCHandle uniqueKeyHandle = (GCHandle)uniqueKeyPtr;
311 | String^ uniqueKey = (String^)uniqueKeyHandle.Target;
312 |
313 | callback->InvokeRecorderCallback(context, action, probability, uniqueKey);
314 | }
315 |
316 | private:
317 | initonly ClrRecorderCallback^ recorderCallback; // presumably held so the delegate behind the native function pointer stays alive - confirm
318 |
319 | private:
320 | NativeRecorder* m_native_recorder;
321 | };
322 |
323 | // Triggers callback to the Scorer instance to score the actions for a context
324 | generic
325 | public ref class ScorerCallback abstract
326 | {
327 | internal:
328 | virtual List^ InvokeScorerCallback(Ctx context) = 0; // implemented by the explorer wrapper (e.g. GenericExplorer)
329 |
330 | ScorerCallback() // Wraps InteropInvoke in a delegate and creates the native scorer that calls it.
331 | {
332 | scorerCallback = gcnew ClrScorerCallback(&ScorerCallback::InteropInvoke);
333 | IntPtr scorerCallbackPtr = Marshal::GetFunctionPointerForDelegate(scorerCallback);
334 | Native_Scorer_Callback* callback = static_cast(scorerCallbackPtr.ToPointer());
335 | m_native_scorer = new NativeScorer(callback);
336 | }
337 |
338 | ~ScorerCallback() // Frees the native scorer created by the constructor.
339 | {
340 | delete m_native_scorer;
341 | }
342 |
343 | NativeScorer* GetNativeScorer()
344 | {
345 | return m_native_scorer;
346 | }
347 |
348 | static void InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, IntPtr scoresPtr, IntPtr sizePtr) // scoresPtr and sizePtr are out-parameters owned by the native caller
349 | {
350 | GCHandle callbackHandle = (GCHandle)callbackPtr;
351 | ScorerCallback^ callback = (ScorerCallback^)callbackHandle.Target;
352 |
353 | GCHandle contextHandle = (GCHandle)contextPtr;
354 | Ctx context = (Ctx)contextHandle.Target;
355 |
356 | List^ scoreList = callback->InvokeScorerCallback(context);
357 |
358 | if (scoreList == nullptr || scoreList->Count == 0)
359 | {
360 | return; // outputs are left untouched; NativeScorer pre-initializes them to null and 0
361 | }
362 |
363 | u32* num_scores = (u32*)sizePtr.ToPointer();
364 | *num_scores = (u32)scoreList->Count;
365 |
366 | float* scores = new float[*num_scores]; // released with delete[] by NativeScorer::Score_Actions
367 | for (u32 i = 0; i < *num_scores; i++)
368 | {
369 | scores[i] = scoreList[i];
370 | }
371 |
372 | float** native_scores = (float**)scoresPtr.ToPointer();
373 | *native_scores = scores;
374 | }
375 |
376 | private:
377 | initonly ClrScorerCallback^ scorerCallback; // presumably held so the delegate behind the native function pointer stays alive - confirm
378 |
379 | private:
380 | NativeScorer* m_native_scorer;
381 | };
382 |
383 | // Triggers callback to the Context instance to perform ToString() operation
384 | generic where Ctx : IStringContext
385 | public ref class ToStringCallback
386 | {
387 | internal:
388 | ToStringCallback() // Wraps InteropInvoke in a delegate and caches its native function pointer.
389 | {
390 | toStringCallback = gcnew ClrToStringCallback(&ToStringCallback::InteropInvoke);
391 | IntPtr toStringCallbackPtr = Marshal::GetFunctionPointerForDelegate(toStringCallback);
392 | m_callback = static_cast(toStringCallbackPtr.ToPointer());
393 | }
394 |
395 | Native_To_String_Callback* GetCallback()
396 | {
397 | return m_callback;
398 | }
399 |
400 | static void InteropInvoke(IntPtr contextPtr, IntPtr stringPtr) // contextPtr is a GCHandle converted to IntPtr; stringPtr points at a native string to overwrite
401 | {
402 | GCHandle contextHandle = (GCHandle)contextPtr;
403 | Ctx context = (Ctx)contextHandle.Target;
404 |
405 | string* out_string = (string*)stringPtr.ToPointer();
406 | *out_string = marshal_as(context->ToString());
407 | }
408 |
409 | private:
410 | initonly ClrToStringCallback^ toStringCallback; // presumably held so the delegate behind m_callback stays alive - confirm
411 |
412 | private:
413 | Native_To_String_Callback* m_callback;
414 | };
415 |
416 | }
--------------------------------------------------------------------------------
/explore.cpp:
--------------------------------------------------------------------------------
1 | // explore.cpp : Timing code to measure performance of MWT Explorer library
2 |
3 | #include "MWTExplorer.h"
4 | #include
5 | #include
6 | #include
7 |
8 | using namespace std;
9 | using namespace std::chrono;
10 |
11 | using namespace MultiWorldTesting;
12 |
13 | class MySimplePolicy : public IPolicy // Trivial default policy used by the timing run below.
14 | {
15 | public:
16 | u32 Choose_Action(SimpleContext& context) // the context is ignored
17 | {
18 | return (u32)1; // always picks action 1
19 | }
20 | };
21 |
22 | const u32 num_actions = 10;
23 |
24 | void Clock_Explore() // Times object construction and Choose_Action separately, excluding warm-up iterations, and prints the averages.
25 | {
26 | float epsilon = .2f;
27 | string unique_key = "key";
28 | int num_features = 1000;
29 | int num_iter = 10000;
30 | int num_warmup = 100;
31 | int num_interactions = 1;
32 |
33 | // pre-create features
34 | vector features;
35 | for (int i = 0; i < num_features; i++)
36 | {
37 | Feature f = {0.5f, (u32)(i + 1)}; // explicit conversions: a non-constant int to u32 is a narrowing conversion inside a braced initializer
38 | features.push_back(f);
39 | }
40 |
41 | long long time_init = 0, time_choose = 0;
42 | for (int iter = 0; iter < num_iter + num_warmup; iter++)
43 | {
44 | high_resolution_clock::time_point t1 = high_resolution_clock::now();
45 | StringRecorder recorder;
46 | MwtExplorer mwt("test", recorder);
47 | MySimplePolicy default_policy;
48 | EpsilonGreedyExplorer explorer(default_policy, epsilon, num_actions);
49 | high_resolution_clock::time_point t2 = high_resolution_clock::now();
50 | time_init += iter < num_warmup ? 0 : duration_cast(t2 - t1).count(); // warm-up iterations contribute nothing
51 |
52 | t1 = high_resolution_clock::now();
53 | SimpleContext appContext(features);
54 | for (int i = 0; i < num_interactions; i++)
55 | {
56 | mwt.Choose_Action(explorer, unique_key, appContext);
57 | }
58 | t2 = high_resolution_clock::now();
59 | time_choose += iter < num_warmup ? 0 : duration_cast(t2 - t1).count(); // includes the SimpleContext construction above
60 | }
61 |
62 | cout << "# iterations: " << num_iter << ", # interactions: " << num_interactions << ", # context features: " << num_features << endl;
63 | cout << "--- PER ITERATION ---" << endl;
64 | cout << "Init: " << (double)time_init / num_iter << " micro" << endl;
65 | cout << "Choose Action: " << (double)time_choose / (num_iter * num_interactions) << " micro" << endl;
66 | cout << "--- TOTAL TIME ---: " << (time_init + time_choose) << " micro" << endl;
67 | }
68 |
69 | int main(int argc, char* argv[]) // Entry point: runs the timing routine once; the arguments are not used.
70 | {
71 | Clock_Explore();
72 | return 0;
73 | }
74 |
--------------------------------------------------------------------------------
/explore.sln:
--------------------------------------------------------------------------------
1 |
2 | Microsoft Visual Studio Solution File, Format Version 12.00
3 | # Visual Studio 2013
4 | VisualStudioVersion = 12.0.30723.0
5 | MinimumVisualStudioVersion = 10.0.40219.1
6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "explore", "explore.vcxproj", "{FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}"
7 | EndProject
8 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "explore_static", "static\explore_static.vcxproj", "{ACE47E98-488C-4CDF-B9F1-36337B2855AD}"
9 | EndProject
10 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "explore_clr", "clr\explore_clr.vcxproj", "{8400DA16-1F46-4A31-A126-BBE16F62BFD7}"
11 | EndProject
12 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "explore_tests", "tests\explore_tests.vcxproj", "{5AE3AA40-BEB0-4979-8166-3B885172C430}"
13 | EndProject
14 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ExploreTests", "tests\ExploreTests.csproj", "{CB0C6B20-560C-4822-8EF6-DA999A64B542}"
15 | EndProject
16 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ExploreSample", "ExploreSample\ExploreSample.csproj", "{7081D542-AE64-485D-9087-79194B958499}"
17 | EndProject
18 | Global
19 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
20 | Debug|Win32 = Debug|Win32
21 | Debug|x64 = Debug|x64
22 | Release|Win32 = Release|Win32
23 | Release|x64 = Release|x64
24 | EndGlobalSection
25 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
26 | {FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Debug|Win32.ActiveCfg = Debug|Win32
27 | {FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Debug|Win32.Build.0 = Debug|Win32
28 | {FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Debug|x64.ActiveCfg = Debug|x64
29 | {FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Debug|x64.Build.0 = Debug|x64
30 | {FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Release|Win32.ActiveCfg = Release|Win32
31 | {FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Release|Win32.Build.0 = Release|Win32
32 | {FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Release|x64.ActiveCfg = Release|x64
33 | {FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Release|x64.Build.0 = Release|x64
34 | {ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Debug|Win32.ActiveCfg = Debug|Win32
35 | {ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Debug|Win32.Build.0 = Debug|Win32
36 | {ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Debug|x64.ActiveCfg = Debug|x64
37 | {ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Debug|x64.Build.0 = Debug|x64
38 | {ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Release|Win32.ActiveCfg = Release|Win32
39 | {ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Release|Win32.Build.0 = Release|Win32
40 | {ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Release|x64.ActiveCfg = Release|x64
41 | {ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Release|x64.Build.0 = Release|x64
42 | {8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Debug|Win32.ActiveCfg = Debug|Win32
43 | {8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Debug|Win32.Build.0 = Debug|Win32
44 | {8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Debug|x64.ActiveCfg = Debug|x64
45 | {8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Debug|x64.Build.0 = Debug|x64
46 | {8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Release|Win32.ActiveCfg = Release|Win32
47 | {8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Release|Win32.Build.0 = Release|Win32
48 | {8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Release|x64.ActiveCfg = Release|x64
49 | {8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Release|x64.Build.0 = Release|x64
50 | {5AE3AA40-BEB0-4979-8166-3B885172C430}.Debug|Win32.ActiveCfg = Debug|Win32
51 | {5AE3AA40-BEB0-4979-8166-3B885172C430}.Debug|Win32.Build.0 = Debug|Win32
52 | {5AE3AA40-BEB0-4979-8166-3B885172C430}.Debug|x64.ActiveCfg = Debug|x64
53 | {5AE3AA40-BEB0-4979-8166-3B885172C430}.Debug|x64.Build.0 = Debug|x64
54 | {5AE3AA40-BEB0-4979-8166-3B885172C430}.Release|Win32.ActiveCfg = Release|Win32
55 | {5AE3AA40-BEB0-4979-8166-3B885172C430}.Release|Win32.Build.0 = Release|Win32
56 | {5AE3AA40-BEB0-4979-8166-3B885172C430}.Release|x64.ActiveCfg = Release|x64
57 | {5AE3AA40-BEB0-4979-8166-3B885172C430}.Release|x64.Build.0 = Release|x64
58 | {CB0C6B20-560C-4822-8EF6-DA999A64B542}.Debug|Win32.ActiveCfg = Debug|x86
59 | {CB0C6B20-560C-4822-8EF6-DA999A64B542}.Debug|Win32.Build.0 = Debug|x86
60 | {CB0C6B20-560C-4822-8EF6-DA999A64B542}.Debug|x64.ActiveCfg = Debug|x64
61 | {CB0C6B20-560C-4822-8EF6-DA999A64B542}.Debug|x64.Build.0 = Debug|x64
62 | {CB0C6B20-560C-4822-8EF6-DA999A64B542}.Release|Win32.ActiveCfg = Release|x86
63 | {CB0C6B20-560C-4822-8EF6-DA999A64B542}.Release|Win32.Build.0 = Release|x86
64 | {CB0C6B20-560C-4822-8EF6-DA999A64B542}.Release|x64.ActiveCfg = Release|x64
65 | {CB0C6B20-560C-4822-8EF6-DA999A64B542}.Release|x64.Build.0 = Release|x64
66 | {7081D542-AE64-485D-9087-79194B958499}.Debug|Win32.ActiveCfg = Debug|x86
67 | {7081D542-AE64-485D-9087-79194B958499}.Debug|Win32.Build.0 = Debug|x86
68 | {7081D542-AE64-485D-9087-79194B958499}.Debug|x64.ActiveCfg = Debug|x64
69 | {7081D542-AE64-485D-9087-79194B958499}.Debug|x64.Build.0 = Debug|x64
70 | {7081D542-AE64-485D-9087-79194B958499}.Release|Win32.ActiveCfg = Release|x86
71 | {7081D542-AE64-485D-9087-79194B958499}.Release|Win32.Build.0 = Release|x86
72 | {7081D542-AE64-485D-9087-79194B958499}.Release|x64.ActiveCfg = Release|x64
73 | {7081D542-AE64-485D-9087-79194B958499}.Release|x64.Build.0 = Release|x64
74 | EndGlobalSection
75 | GlobalSection(SolutionProperties) = preSolution
76 | HideSolutionNode = FALSE
77 | EndGlobalSection
78 | EndGlobal
79 |
--------------------------------------------------------------------------------
/explore.vcxproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Debug
6 | Win32
7 |
8 |
9 | Debug
10 | x64
11 |
12 |
13 | Release
14 | Win32
15 |
16 |
17 | Release
18 | x64
19 |
20 |
21 |
22 | {FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}
23 | Win32Proj
24 | vw_explore
25 | explore
26 |
27 |
28 |
29 | Application
30 | true
31 | v120
32 | Unicode
33 |
34 |
35 | Application
36 | true
37 | v120
38 | Unicode
39 |
40 |
41 | Application
42 | false
43 | v120
44 | true
45 | Unicode
46 |
47 |
48 | Application
49 | false
50 | v120
51 | true
52 | Unicode
53 |
54 |
55 | c:\boost\x64\include\boost-1_56
56 | c:\boost\x64\lib
57 | ..\..\zlib-1.2.8
58 | $(ZlibIncludeDir)\contrib\vstudio\vc10\x64\ZlibStat$(Configuration)
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 | true
78 |
79 |
80 | true
81 |
82 |
83 | false
84 |
85 |
86 | false
87 |
88 |
89 |
90 |
91 |
92 | Level3
93 | Disabled
94 | WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions)
95 | static;
96 |
97 |
98 | Console
99 | true
100 |
101 |
102 |
103 |
104 |
105 |
106 | Level3
107 | Disabled
108 | WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions)
109 | static;
110 |
111 |
112 | Console
113 | true
114 |
115 |
116 |
117 |
118 | Level3
119 |
120 |
121 | MaxSpeed
122 | true
123 | true
124 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions)
125 | true
126 | static;
127 |
128 |
129 | Console
130 | true
131 | true
132 | true
133 |
134 |
135 |
136 |
137 | Level3
138 |
139 |
140 | MaxSpeed
141 | true
142 | true
143 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions)
144 | true
145 | static;
146 |
147 |
148 | Console
149 | true
150 | true
151 | true
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 | true
160 | true
161 | true
162 | true
163 |
164 |
165 |
166 |
167 |
168 | {ace47e98-488c-4cdf-b9f1-36337b2855ad}
169 |
170 |
171 |
172 |
173 |
174 |
--------------------------------------------------------------------------------
/explore.vcxproj.filters:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF}
6 | cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx
7 |
8 |
9 | {93995380-89BD-4b04-88EB-625FBE52EBFB}
10 | h;hpp;hxx;hm;inl;inc;xsd
11 |
12 |
13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01}
14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 | Source Files
23 |
24 |
25 | Source Files
26 |
27 |
28 |
--------------------------------------------------------------------------------
/explore_sample.cpp:
--------------------------------------------------------------------------------
1 | // explore_sample.cpp : Defines the entry point for the console application.
2 | //
3 |
4 | #include "MWTExplorer.h"
5 | #include
6 | #include
7 | #include
8 |
9 | using namespace std;
10 | using namespace std::chrono;
11 | using namespace MultiWorldTesting;
12 |
13 | /// Example of a custom context.
14 | class MyContext
15 | {
16 | // Intentionally has no members; the samples in this file only pass it around.
17 | };
18 |
19 | /// Example of a custom policy which implements the IPolicy interface,
20 | /// declaring that this policy only interacts with MyContext objects.
21 | class MyPolicy : public IPolicy
22 | {
23 | public:
24 | u32 Choose_Action(MyContext& context)
25 | {
26 | // Always returns action 1; the context argument is never read
27 | return (u32)1;
28 | }
29 | };
30 |
31 | /// Example of a custom policy which implements the IPolicy interface,
32 | /// declaring that this policy only interacts with SimpleContext objects.
33 | class MySimplePolicy : public IPolicy
34 | {
35 | public:
36 | u32 Choose_Action(SimpleContext& context)
37 | {
38 | // Always returns action 1; the context's features are never read
39 | return (u32)1;
40 | }
41 | };
42 |
43 | /// Example of a custom scorer which implements the IScorer interface,
44 | /// declaring that this scorer only interacts with MyContext objects.
45 | class MyScorer : public IScorer
46 | {
47 | public:
48 | MyScorer(u32 num_actions) : m_num_actions(num_actions)
49 | {
50 | // Nothing else to set up: the action count is the only state this scorer keeps.
51 | }
52 | vector Score_Actions(MyContext& context)
53 | {
54 | vector scores;
55 | for (size_t i = 0; i < m_num_actions; i++)
56 | {
57 | // Every action gets the same score (0.1), so no action is preferred over another.
58 | scores.push_back(.1f);
59 | }
60 | return scores;
61 | }
62 | private:
63 | u32 m_num_actions; // Length of the score vector returned by Score_Actions
64 | };
65 |
66 | ///
67 | /// Represents a tuple (context, action, probability, unique key).
68 | ///
69 | template
70 | struct MyInteraction
71 | {
72 | Ctx Context; // Held by value, i.e. a copy of the context it was built from
73 | u32 Action;
74 | float Probability;
75 | string Unique_Key;
76 | };
77 |
78 | /// Example of a custom recorder which implements the IRecorder interface,
79 | /// declaring that this recorder only interacts with MyContext objects.
80 | class MyRecorder : public IRecorder
81 | {
82 | public:
83 | virtual void Record(MyContext& context, u32 action, float probability, string unique_key)
84 | {
85 | // Appends the (context, action, probability, unique_key) tuple to an in-memory vector for later use.
86 | m_interactions.push_back({ context, action, probability, unique_key });
87 | }
88 | private:
89 | vector> m_interactions; // Grows by one entry per Record call; nothing in this sample reads it back
90 | };
91 |
92 | int main(int argc, char* argv[])
93 | {
94 | if (argc < 2) // the exploration type is a required first argument
95 | {
96 | cerr << "arguments: {greedy,tau-first,bootstrap,softmax,generic}" << endl;
97 | exit(1);
98 | }
99 |
100 | // Dispatch on argv[1]: each branch builds a recorder, an MwtExplorer and one explorer, then chooses a single action
101 | if (strcmp(argv[1], "greedy") == 0)
102 | {
103 | // Initialize Epsilon-Greedy explore algorithm using MySimplePolicy
104 |
105 | // Creates a recorder of built-in StringRecorder type for string serialization
106 | StringRecorder recorder;
107 |
108 | // Creates a policy that interacts with SimpleContext type
109 | MySimplePolicy default_policy;
110 |
111 | // Creates an MwtExplorer instance using the recorder above
112 | MwtExplorer mwt("appid", recorder);
113 |
114 | u32 num_actions = 10;
115 | float epsilon = .2f;
116 | // Creates an Epsilon-Greedy explorer using the specified settings
117 | EpsilonGreedyExplorer explorer(default_policy, epsilon, num_actions);
118 |
119 | // Creates a context of built-in SimpleContext type; each entry is { Value, Id }
120 | vector features;
121 | features.push_back({ 0.5f, 1 });
122 | features.push_back({ 1.3f, 11 });
123 | features.push_back({ -.95f, 413 });
124 |
125 | SimpleContext context(features);
126 |
127 | // Performs exploration by passing an instance of the Epsilon-Greedy exploration algorithm into MwtExplorer
128 | // using a sample string to uniquely identify this event
129 | string unique_key = "eventid";
130 | u32 action = mwt.Choose_Action(explorer, unique_key, context);
131 |
132 | cout << "Chosen action = " << action << endl;
133 | cout << "Exploration record = " << recorder.Get_Recording(); // Get_Recording() defaults to flush = true, so this also clears the recorder
134 | }
135 | else if (strcmp(argv[1], "tau-first") == 0)
136 | {
137 | // Initialize Tau-First explore algorithm using MyPolicy
138 | MyRecorder recorder;
139 | MwtExplorer mwt("appid", recorder);
140 |
141 | int num_actions = 10; // NOTE(review): int here, u32 in the greedy and softmax branches — confirm which the explorer expects
142 | u32 tau = 5;
143 | MyPolicy default_policy;
144 | TauFirstExplorer explorer(default_policy, tau, num_actions);
145 | MyContext ctx;
146 | string unique_key = "eventid";
147 | u32 action = mwt.Choose_Action(explorer, unique_key, ctx);
148 |
149 | cout << "action = " << action << endl;
150 | }
151 | else if (strcmp(argv[1], "bootstrap") == 0)
152 | {
153 | // Initialize Bootstrap explore algorithm using MyPolicy
154 | MyRecorder recorder;
155 | MwtExplorer mwt("appid", recorder);
156 |
157 | u32 num_bags = 2;
158 |
159 | // Create a vector of unique_ptr-owned default policies, one MyPolicy per bag
160 | vector>> policy_functions;
161 | for (size_t i = 0; i < num_bags; i++)
162 | {
163 | policy_functions.push_back(unique_ptr>(new MyPolicy()));
164 | }
165 | int num_actions = 10;
166 | BootstrapExplorer explorer(policy_functions, num_actions);
167 | MyContext ctx;
168 | string unique_key = "eventid";
169 | u32 action = mwt.Choose_Action(explorer, unique_key, ctx);
170 |
171 | cout << "action = " << action << endl;
172 | }
173 | else if (strcmp(argv[1], "softmax") == 0)
174 | {
175 | // Initialize Softmax explore algorithm using MyScorer
176 | MyRecorder recorder;
177 | MwtExplorer mwt("salt", recorder); // NOTE(review): every other branch passes "appid" as the app id — confirm "salt" is intentional
178 |
179 | u32 num_actions = 10;
180 | MyScorer scorer(num_actions);
181 | float lambda = 0.5f;
182 | SoftmaxExplorer explorer(scorer, lambda, num_actions);
183 |
184 | MyContext ctx;
185 | string unique_key = "eventid";
186 | u32 action = mwt.Choose_Action(explorer, unique_key, ctx);
187 |
188 | cout << "action = " << action << endl;
189 | }
190 | else if (strcmp(argv[1], "generic") == 0)
191 | {
192 | // Initialize Generic explore algorithm using MyScorer
193 | MyRecorder recorder;
194 | MwtExplorer mwt("appid", recorder);
195 |
196 | int num_actions = 10;
197 | MyScorer scorer(num_actions);
198 | GenericExplorer explorer(scorer, num_actions);
199 | MyContext ctx;
200 | string unique_key = "eventid";
201 | u32 action = mwt.Choose_Action(explorer, unique_key, ctx);
202 |
203 | cout << "action = " << action << endl;
204 | }
205 | else
206 | {
207 | cerr << "unknown exploration type: " << argv[1] << endl;
208 | exit(1);
209 | }
210 |
211 | return 0;
212 | }
213 |
--------------------------------------------------------------------------------
/mwt.chm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/mwt-ds-explore/6e3a7124d7adbd51434a28b1d48d0d397e1fd3fd/mwt.chm
--------------------------------------------------------------------------------
/static/MWTExplorer.h:
--------------------------------------------------------------------------------
1 | //
2 | // Main interface for clients of the Multiworld testing (MWT) service.
3 | //
4 |
5 | #pragma once
6 |
7 | #include
8 | #include
9 | #include
10 | #include
11 | #include
12 | #include
13 | #include
14 | #include
15 | #include
16 | #include
17 |
18 | #ifdef MANAGED_CODE
19 | #define PORTING_INTERFACE public
20 | #define MWT_NAMESPACE namespace NativeMultiWorldTesting
21 | #else
22 | #define PORTING_INTERFACE private
23 | #define MWT_NAMESPACE namespace MultiWorldTesting
24 | #endif
25 |
26 | using namespace std;
27 |
28 | #include "utility.h"
29 |
30 | /** \defgroup MultiWorldTestingCpp
31 | \brief C++ implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp
32 | */
33 |
34 | /*!
35 | * \addtogroup MultiWorldTestingCpp
36 | * @{
37 | */
38 |
39 | //! Interface for C++ version of Multiworld Testing library.
40 | //! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp
41 | MWT_NAMESPACE {
42 |
43 | // Forward declarations
44 | template
45 | class IRecorder;
46 | template
47 | class IExplorer;
48 |
49 | ///
50 | /// The top-level MwtExplorer class. Using this enables principled and efficient exploration
51 | /// over a set of possible actions, and ensures that the right bits are recorded.
52 | ///
53 | template
54 | class MwtExplorer
55 | {
56 | public:
57 | ///
58 | /// Constructor. Keeps a reference to the recorder (it must outlive this object) and hashes app_id into a numeric id.
59 | ///
60 | /// @param app_id This should be unique to your experiment or you risk nasty correlation bugs.
61 | /// @param recorder A user-specified class for recording the appropriate bits for use in evaluation and learning.
62 | ///
63 | MwtExplorer(std::string app_id, IRecorder& recorder) : m_recorder(recorder)
64 | {
65 | m_app_id = HashUtils::Compute_Id_Hash(app_id);
66 | }
67 |
68 | ///
69 | /// Chooses an action by invoking an underlying exploration algorithm. This should be a
70 | /// drop-in replacement for any existing policy function.
71 | ///
72 | /// @param explorer An existing exploration algorithm (one of the below) which uses the default policy as a callback.
73 | /// @param unique_key A unique identifier for the experimental unit. This could be a user id, a session id, etc.
74 | /// @param context The context upon which a decision is made. See SimpleContext below for an example.
75 | /// @returns The action chosen by the explorer.
76 | u32 Choose_Action(IExplorer& explorer, string unique_key, Ctx& context)
77 | {
78 | u64 seed = HashUtils::Compute_Id_Hash(unique_key);
79 | // The seed handed to the explorer is the key hash plus the app id hash.
80 | std::tuple action_probability_log_tuple = explorer.Choose_Action(seed + m_app_id, context);
81 |
82 | u32 action = std::get<0>(action_probability_log_tuple);
83 | float prob = std::get<1>(action_probability_log_tuple);
84 | // The third tuple element is the explorer's "record this decision" flag.
85 | if (std::get<2>(action_probability_log_tuple))
86 | {
87 | m_recorder.Record(context, action, prob, unique_key);
88 | }
89 |
90 | return action;
91 | }
92 |
93 | private:
94 | u64 m_app_id; // Hash of the app_id string passed to the constructor
95 | IRecorder& m_recorder; // Not owned; supplied by the caller
96 | };
97 |
98 | ///
99 | /// Exposes a method to record exploration data based on generic contexts. Exploration data
100 | /// is specified as a set of (context, action, probability, unique key) tuples as described below. An
101 | /// application passes an IRecorder object to the @MwtExplorer constructor. See
102 | /// @StringRecorder for a sample IRecorder object.
103 | ///
104 | template
105 | class IRecorder
106 | {
107 | public:
108 | ///
109 | /// Records the exploration data associated with a given decision.
110 | /// This implementation should be thread-safe if multithreading is needed.
111 | ///
112 | /// @param context A user-defined context for the decision
113 | /// @param action The action chosen by an exploration algorithm given context
114 | /// @param probability The probability the exploration algorithm chose said action
115 | /// @param unique_key A user-defined unique identifier for the decision
116 | ///
117 | virtual void Record(Ctx& context, u32 action, float probability, string unique_key) = 0;
118 | virtual ~IRecorder() { } // virtual so implementations can be destroyed through an IRecorder pointer
119 | };
120 |
121 | ///
122 | /// Exposes a method to choose an action given a generic context, and obtain the relevant
123 | /// exploration bits. Invokes IPolicy::Choose_Action internally. Do not implement this
124 | /// interface yourself: instead, use the various exploration algorithms below, which
125 | /// implement it for you.
126 | ///
127 | template
128 | class IExplorer
129 | {
130 | public:
131 | ///
132 | /// Determines the action to take and the probability with which it was chosen, for a
133 | /// given context.
134 | ///
135 | /// @param salted_seed A PRG seed derived from the user-provided unique key (MwtExplorer adds its app id hash to the key hash)
136 | /// @param context A user-defined context for the decision
137 | /// @returns The action to take, the probability it was chosen, and a flag indicating
138 | /// whether to record this decision
139 | ///
140 | virtual std::tuple Choose_Action(u64 salted_seed, Ctx& context) = 0;
141 | virtual void Enable_Explore(bool explore) = 0; // presumably switches exploration on or off; implementations are outside this view — TODO confirm
142 | virtual ~IExplorer() { }
143 | };
144 |
145 | ///
146 | /// Exposes a method to choose an action given a generic context. IPolicy objects are
147 | /// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
148 | ///
149 | template
150 | class IPolicy
151 | {
152 | public:
153 | ///
154 | /// Determines the action to take for a given context.
155 | /// Implementations should be thread-safe if multithreading is needed.
156 | ///
157 | /// @param context A user-defined context for the decision
158 | /// @returns The action to take (1-based index)
159 | ///
160 | virtual u32 Choose_Action(Ctx& context) = 0;
161 | virtual ~IPolicy() { } // virtual so implementations can be destroyed through an IPolicy pointer
162 | };
163 |
164 | ///
165 | /// Exposes a method for specifying a score (weight) for each action given a generic context.
166 | ///
167 | template
168 | class IScorer
169 | {
170 | public:
171 | ///
172 | /// Determines the score of each action for a given context.
173 | /// Implementations should be thread-safe if multithreading is needed.
174 | ///
175 | /// @param context A user-defined context for the decision
176 | /// @returns A vector of scores indexed by action (1-based)
177 | ///
178 | virtual vector Score_Actions(Ctx& context) = 0;
179 | virtual ~IScorer() { } // virtual so implementations can be destroyed through an IScorer pointer
180 | };
181 |
182 | ///
183 | /// Represents a context interface with a variable number of actions. Get_Variable_Number_Of_Actions
184 | /// queries it when the configured number of actions is UINT_MAX (variable number of actions mode).
185 | ///
186 | class IVariableActionContext
187 | {
188 | public:
189 | ///
190 | /// Gets the number of actions for the current context.
191 | ///
192 | /// @returns The number of actions available for the current context.
193 | ///
194 | virtual u32 Get_Number_Of_Actions() = 0;
195 | virtual ~IVariableActionContext() { }
196 | };
197 | /// Interface for objects that accept a replacement IPolicy through Update_Policy; presumably implemented by explorers outside this view — TODO confirm.
198 | template
199 | class IConsumePolicy
200 | {
201 | public:
202 | virtual void Update_Policy(IPolicy& new_policy) = 0;
203 | virtual ~IConsumePolicy() { }
204 | };
205 | /// Interface for objects that accept a replacement vector of policies through Update_Policy; presumably the multi-policy counterpart of IConsumePolicy — TODO confirm.
206 | template
207 | class IConsumePolicies
208 | {
209 | public:
210 | virtual void Update_Policy(vector>>& new_policy_functions) = 0;
211 | virtual ~IConsumePolicies() { }
212 | };
213 | /// Interface for objects that accept a replacement IScorer through Update_Scorer. NOTE(review): the parameter is named new_policy although it is an IScorer.
214 | template
215 | class IConsumeScorer
216 | {
217 | public:
218 | virtual void Update_Scorer(IScorer& new_policy) = 0;
219 | virtual ~IConsumeScorer() { }
220 | };
221 |
222 | ///
223 | /// A sample recorder class that converts the exploration tuple into string format.
224 | ///
225 | template
226 | struct StringRecorder : public IRecorder
227 | {
228 | void Record(Ctx& context, u32 action, float probability, string unique_key)
229 | {
230 | // Appends one line of the form "action unique_key probability | context" to the recording.
231 | m_recording.append(to_string((unsigned long long)action));
232 | m_recording.append(" ", 1);
233 | m_recording.append(unique_key);
234 | m_recording.append(" ", 1);
235 | // Probability is written as integer part, '.', then five decimal digits; further digits are truncated, not rounded.
236 | char prob_str[10] = { 0 };
237 | int x = (int)probability;
238 | int d = (int)(fabs(probability - x) * 100000);
239 | sprintf_s(prob_str, 10 * sizeof(char), "%d.%05d", x, d);
240 | m_recording.append(prob_str);
241 |
242 | m_recording.append(" | ", 3);
243 | m_recording.append(context.To_String()); // Implicitly enforces a To_String() API on the context type
244 | m_recording.append("\n");
245 | }
246 |
247 | // Gets the content of the recording so far as a string; clears the internal content unless flush is false.
248 | string Get_Recording(bool flush = true)
249 | {
250 | if (!flush)
251 | {
252 | return m_recording;
253 | }
254 | string recording = m_recording;
255 | m_recording.clear();
256 | return recording;
257 | }
258 |
259 | private:
260 | string m_recording; // Concatenation of every line appended by Record since the last flush
261 | };
262 |
263 | ///
264 | /// Represents a feature in a sparse array.
265 | ///
266 | struct Feature
267 | {
268 | float Value;
269 | u32 Id;
270 | // Equality compares Id only (Value is ignored). Const-qualified so that const Features can be compared too.
271 | bool operator==(const Feature& other_feature) const
272 | {
273 | return Id == other_feature.Id;
274 | }
275 | };
276 |
277 | ///
278 | /// A sample context class that stores a vector of Features.
279 | /// Only a reference to the caller's vector is kept, so that vector must outlive this object.
280 | class SimpleContext
281 | {
282 | public:
283 | SimpleContext(vector& features) :
284 | m_features(features)
285 | { }
286 | // Returns the referenced feature vector itself, not a copy.
287 | vector& Get_Features()
288 | {
289 | return m_features;
290 | }
291 | // Serializes the features in vector order as space-separated "id:value" entries.
292 | string To_String()
293 | {
294 | string out_string;
295 | const size_t buffer_len = 35;
296 | char feature_str[buffer_len] = { 0 };
297 | for (size_t i = 0; i < m_features.size(); i++)
298 | {
299 | int chars;
300 | if (i == 0)
301 | {
302 | chars = sprintf_s(feature_str, buffer_len, "%u:", m_features[i].Id); // %u: Id is a u32 (assumed unsigned 32-bit), %d would print large ids as negative
303 | }
304 | else
305 | {
306 | chars = sprintf_s(feature_str, buffer_len, " %u:", m_features[i].Id); // every entry after the first is preceded by a space
307 | }
308 | NumberUtils::print_float(feature_str + chars, buffer_len-chars, m_features[i].Value);
309 | out_string.append(feature_str);
310 | }
311 | return out_string;
312 | }
313 |
314 | private:
315 | vector& m_features;
316 | };
317 | /// Returns default_num_actions unless it equals UINT_MAX, in which case the count is read from the context (std::invalid_argument if that count is below 1).
318 | template
319 | static u32 Get_Variable_Number_Of_Actions(Ctx& context, u32 default_num_actions)
320 | {
321 | u32 num_actions = default_num_actions;
322 | if (num_actions == UINT_MAX) // UINT_MAX is the marker for "ask the context"
323 | {
324 | num_actions = ((IVariableActionContext*)(&context))->Get_Number_Of_Actions(); // C-style cast: nothing here checks that Ctx derives from IVariableActionContext
325 | if (num_actions < 1)
326 | {
327 | throw std::invalid_argument("Number of actions must be at least 1.");
328 | }
329 | }
330 | return num_actions;
331 | }
332 |
333 | ///
334 | /// The epsilon greedy exploration algorithm. This is a good choice if you have no idea
335 | /// which actions should be preferred. Epsilon greedy is also computationally cheap.
336 | ///
337 | template