├── README.md
├── src
│   └── MonteCarlo
│       ├── MonteCarlo.csproj
│       └── MonteCarloTreeSearch.cs
├── examples
│   ├── ConnectFour
│   │   ├── ConnectFour
│   │   │   ├── ConnectFour.csproj
│   │   │   ├── Program.cs
│   │   │   └── ConnectFourState.cs
│   │   └── ConnectFourTests
│   │       ├── TestData
│   │       │   ├── GetResult_Columns.txt
│   │       │   ├── GetResult_Rows.txt
│   │       │   ├── GetResult_DiagonalTopBottom.txt
│   │       │   └── GetResult_DiagonalBottomTop.txt
│   │       ├── ConnectFourTests.csproj
│   │       └── ConnectFourStateTest.cs
│   └── TicTacToe
│       ├── TicTacToe
│       │   ├── TicTacToe.csproj
│       │   ├── Program.cs
│       │   └── TicTacToe.cs
│       └── TicTacToeTests
│           ├── TicTacToeTests.csproj
│           └── TicTacToeStateTest.cs
├── LICENSE
├── .gitignore
└── montecarlo.sln
/README.md:
--------------------------------------------------------------------------------
1 | # mcts
2 |
3 | An exploration of a simple, generic Monte Carlo Tree Search (MCTS) algorithm in C#.
4 |
--------------------------------------------------------------------------------
/src/MonteCarlo/MonteCarlo.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | netstandard2.0
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/examples/ConnectFour/ConnectFour/ConnectFour.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Exe
5 | netcoreapp2.0
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/examples/ConnectFour/ConnectFourTests/TestData/GetResult_Columns.txt:
--------------------------------------------------------------------------------
1 |
2 |
3 | X
4 | X
5 | OX
6 | OOX
7 | -------
8 |
9 |
10 | X
11 | XO
12 | XO
13 | XO
14 | -------
15 | X
16 | X
17 | X
18 | X
19 | OX
20 | OOO
21 | -------
22 | X
23 | X
24 | X
25 | XOX
26 | OXOX
27 | XOOOX
28 | -------
29 | O
30 | XO
31 | XX
32 | XO
33 | OXOXX
34 | XXOOXOX
35 | -------
--------------------------------------------------------------------------------
/examples/TicTacToe/TicTacToe/TicTacToe.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Exe
5 | netcoreapp2.0
6 | TicTacToe.Program
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/examples/TicTacToe/TicTacToeTests/TicTacToeTests.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | netcoreapp2.0
5 |
6 | false
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/examples/ConnectFour/ConnectFourTests/TestData/GetResult_Rows.txt:
--------------------------------------------------------------------------------
1 | XXXX
2 | XOXOXOX
3 | OXOXOXO
4 | OXOXOXO
5 | XOXOXOX
6 | OXOXOXO
7 | --- ^ top left
8 |
9 |
10 |
11 |
12 | OOO
13 | XXXX
14 | --- ^ bottom left
15 |
16 |
17 |
18 |
19 | O O
20 | OXXXX
21 | --- ^ bottom middle
22 |
23 |
24 |
25 |
26 | O O
27 | OXXXX
28 | --- ^ bottom middle
29 |
30 |
31 | O
32 | O
33 | O
34 | XXXX
35 | --- ^ bottom right
36 | XXXX
37 | XOXOXOX
38 | OXOXOXO
39 | OXOXOXO
40 | XOXOXOX
41 | OXOXOXO
42 | --- ^ top middle
43 | XXXX
44 | XOXOXOX
45 | OXOXOXO
46 | OXOXOXO
47 | XOXOXOX
48 | OXOXOXO
49 | --- ^ top right
50 |
51 |
52 |
53 | OO
54 | XXXXO
55 | OXOXOXO
56 | --- ^ middle middle
--------------------------------------------------------------------------------
/examples/ConnectFour/ConnectFourTests/TestData/GetResult_DiagonalTopBottom.txt:
--------------------------------------------------------------------------------
1 |
2 |
3 | X
4 | X
5 | X
6 | X
7 | --- bottom left
8 |
9 |
10 | X
11 | X
12 | X
13 | X
14 | --- bottom middle
15 |
16 |
17 | X
18 | X
19 | X
20 | X
21 | --- bottom right
22 |
23 | X
24 | X
25 | X
26 | X
27 |
28 | --- middle left
29 |
30 | X
31 | X
32 | X
33 | X
34 |
35 | --- middle middle
36 |
37 | X
38 | X
39 | X
40 | X
41 |
42 | --- middle right
43 | X
44 | X
45 | X
46 | X
47 |
48 |
49 | --- top left
50 | X
51 | X
52 | X
53 | X
54 |
55 |
56 | --- top middle
57 | X
58 | X
59 | X
60 | X
61 |
62 |
63 | --- top right
--------------------------------------------------------------------------------
/examples/ConnectFour/ConnectFourTests/TestData/GetResult_DiagonalBottomTop.txt:
--------------------------------------------------------------------------------
1 |
2 |
3 | X
4 | XXO
5 | XOXO
6 | XXOOO
7 | --- bottom left
8 |
9 |
10 | X
11 | XO
12 | XXO
13 | XOOX
14 | --- bottom right
15 |
16 |
17 | X
18 | XO
19 | XXO
20 | XOOX
21 | --- bottom middle
22 |
23 | X
24 | XX
25 | XXO
26 | XOOX
27 | XOXOXOX
28 | --- middle left
29 |
30 | X
31 | XO
32 | XOX
33 | XOXO
34 | XOXOXOX
35 | --- middle middle
36 |
37 | X
38 | XO
39 | XXO
40 | XOOX
41 | XOXOXOX
42 | --- middle right
43 | X
44 | XXO
45 | XOXO
46 | XXOOO
47 | XOXOXOX
48 | XOXOXOX
49 | --- top left
50 | X
51 | XO
52 | XXO
53 | XOOX
54 | XOXOXOX
55 | XOXOXOX
56 | --- top right
57 | X
58 | XO
59 | XOX
60 | XOXO
61 | XOXOXOX
62 | XOXOXOX
63 | --- top middle
--------------------------------------------------------------------------------
/examples/ConnectFour/ConnectFourTests/ConnectFourTests.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | netcoreapp2.0
5 |
6 | false
7 |
8 |
9 |
10 |
11 | Always
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Daniel McDonald
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/examples/TicTacToe/TicTacToe/Program.cs:
--------------------------------------------------------------------------------
1 | using MonteCarlo;
2 | using System;
3 | using System.Linq;
4 |
namespace TicTacToe
{
    /// <summary>
    /// Console game: the human plays X, then the Monte Carlo tree search replies with the
    /// top-ranked move, until the state reports no legal moves.
    /// </summary>
    class Program
    {
        private const string SEPARATOR = "______";

        static void Main(string[] args)
        {
            var game = new TicTacToeState();
            while (game.Actions.Any())
            {
                Console.WriteLine($"CurrentPlayer: {game.CurrentPlayer}");
                Console.WriteLine(game);
                Console.WriteLine(SEPARATOR);

                var position = ReadFreePosition(game);

                Console.WriteLine();

                game.ApplyAction(new TicTacToeAction(position, TicTacToePlayer.X));
                var computer = MonteCarloTreeSearch.GetTopActions(game, 50000).ToList();
                Console.WriteLine(SEPARATOR);
                if (computer.Count > 0)
                {
                    Console.WriteLine("Computer's ranked plays:");
                    // NOTE(review): if NumWins and NumRuns are both integral types this ratio
                    // truncates to 0 or 1 - confirm their types in MonteCarloTreeSearch.
                    foreach (var a in computer)
                        Console.WriteLine($"\t{a.Action}\t{a.NumWins}/{a.NumRuns} ({a.NumWins / a.NumRuns})");
                    game.ApplyAction(computer[0].Action);
                }
            }

            Console.WriteLine(SEPARATOR);
            Console.WriteLine(game.ToString());
            Console.WriteLine("Game Over");
            Console.ReadKey();
        }

        /// <summary>
        /// Prompts until the user presses a digit naming an empty square, then returns that square.
        /// </summary>
        /// <remarks>
        /// int.TryParse sets its out value to 0 when parsing fails, so its return value must be
        /// checked; otherwise any non-digit key would play square 0. The square must also be one
        /// of <paramref name="game"/>'s legal actions, which rejects occupied squares (ApplyAction
        /// would overwrite them) and '9' (rejected by TicTacToeAction's constructor).
        /// </remarks>
        /// <param name="game">The current game; its Actions define which squares are free.</param>
        /// <returns>A square index that appears among <c>game.Actions</c>.</returns>
        private static int ReadFreePosition(TicTacToeState game)
        {
            while (true)
            {
                Console.WriteLine("Choose a free space: (0-8)");

                var input = Console.ReadKey();
                int position;
                if (int.TryParse(input.KeyChar.ToString(), out position)
                    && game.Actions.Any(a => a.Position == position))
                    return position;
            }
        }
    }
}
51 |
--------------------------------------------------------------------------------
/examples/ConnectFour/ConnectFour/Program.cs:
--------------------------------------------------------------------------------
1 | using MonteCarlo;
2 | using System;
3 | using System.Linq;
4 |
namespace ConnectFour
{
    /// <summary>
    /// Console game: the human picks a column, then the Monte Carlo tree search replies with
    /// the top-ranked move, until the state reports no legal moves.
    /// </summary>
    class Program
    {
        private const string SEPARATOR = "-------";

        static void Main(string[] args)
        {
            var game = new ConnectFourState();
            while (game.Actions.Any())
            {
                Console.WriteLine($"CurrentPlayer: {game.CurrentPlayer}");
                Console.WriteLine(game);
                Console.WriteLine("0123456");
                Console.WriteLine(SEPARATOR);

                var position = ReadColumn();

                Console.WriteLine();
                game.ApplyAction(new ConnectFourAction(position));
                var computer = MonteCarloTreeSearch.GetTopActions(game, 50000, 1000).ToList();

                Console.WriteLine(SEPARATOR);
                if (computer.Count > 0)
                {
                    Console.WriteLine("Computer's ranked plays:");
                    // NOTE(review): if NumWins and NumRuns are both integral types this ratio
                    // truncates to 0 or 1 - confirm their types in MonteCarloTreeSearch.
                    foreach (var a in computer)
                        Console.WriteLine($"\t{a.Action}\t{a.NumWins}/{a.NumRuns} ({a.NumWins / a.NumRuns})");
                    game.ApplyAction(computer[0].Action);
                }
            }

            Console.WriteLine(SEPARATOR);
            Console.WriteLine(game.ToString());
            Console.WriteLine("Game Over");
            Console.ReadKey();
        }

        /// <summary>
        /// Prompts until the user presses a digit from 0 to 6, then returns it as a column index.
        /// </summary>
        /// <remarks>
        /// int.TryParse sets its out value to 0 when parsing fails, so without checking its return
        /// value any non-digit key would silently select column 0.
        /// NOTE(review): a column that is already full is not rejected here; ConnectFourAction's
        /// members are not visible in this listing - confirm how to test membership in game.Actions.
        /// </remarks>
        /// <returns>A column index in the range 0-6.</returns>
        private static int ReadColumn()
        {
            while (true)
            {
                Console.WriteLine("Choose a free space: (0-6)");

                var input = Console.ReadKey();
                int column;
                if (int.TryParse(input.KeyChar.ToString(), out column) && column >= 0 && column <= 6)
                    return column;
            }
        }
    }
}
52 |
--------------------------------------------------------------------------------
/examples/TicTacToe/TicTacToe/TicTacToe.cs:
--------------------------------------------------------------------------------
1 | using MonteCarlo;
2 | using System;
3 | using System.Collections.Generic;
4 | using System.Linq;
5 | using System.Text;
6 |
7 | namespace TicTacToe
8 | {
/// <summary>
/// A single tic-tac-toe move: <see cref="Player"/> marks square <see cref="Position"/>.
/// </summary>
public struct TicTacToeAction : IAction
{
    /// <summary>Creates a move for <paramref name="player"/> on square <paramref name="position"/>.</summary>
    /// <param name="position">Square index, 0 to 8 inclusive.</param>
    /// <param name="player">The player making the move.</param>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="position"/> is outside 0-8. This derives from ArgumentException, so callers
    /// catching ArgumentException are unaffected.
    /// </exception>
    public TicTacToeAction(int position, TicTacToePlayer player)
    {
        if (position < 0 || position > 8)
            throw new ArgumentOutOfRangeException(nameof(position), position, "position must be between 0 and 8, inclusive");

        Position = position;
        Player = player;
    }

    /// <summary>The targeted square, 0-8.</summary>
    public int Position { get; }

    /// <summary>The player placing a mark.</summary>
    public TicTacToePlayer Player { get; }

    /// <summary>Formats the move as "Player: Position", e.g. "X: 4".</summary>
    public override string ToString() => $"{Player}: {Position}";
}
28 |
/// <summary>
/// One of the two tic-tac-toe players. The constructor is private, so <see cref="X"/> and
/// <see cref="O"/> are the only instances.
/// </summary>
public class TicTacToePlayer : IPlayer
{
    private TicTacToePlayer(bool isX)
    {
        IsX = isX;
    }

    /// <summary>True for player X, false for player O.</summary>
    public bool IsX { get; }

    /// <summary>The opponent of this player.</summary>
    public TicTacToePlayer NextPlayer => IsX ? O : X;

    /// <summary>Returns "X" or "O".</summary>
    public override string ToString() => IsX ? "X" : "O";

    // readonly: other code compares players with == (reference equality), and NextPlayer reads
    // these fields, so reassigning either singleton would silently break both.
    public static readonly TicTacToePlayer X = new TicTacToePlayer(true);
    public static readonly TicTacToePlayer O = new TicTacToePlayer(false);
}
48 |
/// <summary>
/// Mutable tic-tac-toe position: nine squares plus the player whose turn it is.
/// </summary>
/// <remarks>
/// Squares are indexed 0-8 row by row from the top-left (the layout ToString renders); a null
/// entry marks an empty square.
/// NOTE(review): the generic arguments of IState appear to be missing from this listing -
/// confirm against the MonteCarlo project.
/// </remarks>
public class TicTacToeState : IState
{
    private readonly IList<TicTacToePlayer> board;

    /// <summary>Creates an empty board with X to move.</summary>
    public TicTacToeState() : this(new TicTacToePlayer[9], TicTacToePlayer.X) { }

    /// <summary>Creates a state over <paramref name="board"/>, which is used directly (not copied).</summary>
    /// <param name="board">Squares 0-8; null means empty. Only the first nine entries are used when listing moves.</param>
    /// <param name="currentPlayer">The player who moves next.</param>
    public TicTacToeState(TicTacToePlayer[] board, TicTacToePlayer currentPlayer)
    {
        this.board = board;
        CurrentPlayer = currentPlayer;
    }

    /// <summary>The player who makes the next move.</summary>
    public TicTacToePlayer CurrentPlayer { get; private set; }

    /// <summary>
    /// Legal moves for <see cref="CurrentPlayer"/>: one per empty square, or none once either
    /// player has completed a line.
    /// </summary>
    public IList<TicTacToeAction> Actions
    {
        get
        {
            if (hasWinner(TicTacToePlayer.X) || hasWinner(TicTacToePlayer.O))
                return new TicTacToeAction[0];

            return board
                .Take(9)
                .Select((player, position) => new { player, position })
                .Where(o => o.player == null)
                .Select(o => new TicTacToeAction(o.position, CurrentPlayer))
                .ToList();
        }
    }

    /// <summary>
    /// Places the action's player on its square and passes the turn to the other player.
    /// The square is not checked for occupancy.
    /// </summary>
    public void ApplyAction(TicTacToeAction action)
    {
        board[action.Position] = action.Player;
        CurrentPlayer = CurrentPlayer.NextPlayer;
    }

    /// <summary>Returns a copy with its own board array; the player instances are shared.</summary>
    public IState Clone()
    {
        return new TicTacToeState(board.ToArray(), CurrentPlayer);
    }

    // Every row, column and diagonal must be listed: a missing line (previously the bottom row
    // {6, 7, 8}) is invisible to hasWinner, so Actions and GetResult would ignore wins on it.
    private static readonly int[][] winningCombos =
    {
        // rows
        new[] { 0, 1, 2 },
        new[] { 3, 4, 5 },
        new[] { 6, 7, 8 },
        // columns
        new[] { 0, 3, 6 },
        new[] { 1, 4, 7 },
        new[] { 2, 5, 8 },
        // diagonals
        new[] { 0, 4, 8 },
        new[] { 2, 4, 6 },
    };

    // True when every square of some line is held by forPlayer (players compared via IsX).
    private bool hasWinner(TicTacToePlayer forPlayer)
    {
        return winningCombos.Any(c => c.All(i => board[i] != null && board[i].IsX == forPlayer.IsX));
    }

    /// <summary>
    /// Scores the position for <paramref name="forPlayer"/>: 1 if they hold a complete line,
    /// 0 if their opponent does, otherwise 0.5 (a draw, or a game that is not finished yet).
    /// </summary>
    public double GetResult(TicTacToePlayer forPlayer)
    {
        return hasWinner(forPlayer) ? 1
            : hasWinner(forPlayer.NextPlayer) ? 0
            : 0.5;
    }

    /// <summary>
    /// Renders the board as three rows of "a|b|c", showing the square index for empty squares.
    /// Each row, the last one included, is followed by "\n-----\n".
    /// </summary>
    public override string ToString()
    {
        var sb = new StringBuilder();
        for (var rowOffset = 0; rowOffset < 9; rowOffset += 3)
        {
            for (var col = 0; col < 3; col++)
            {
                var square = rowOffset + col;
                var player = board[square];
                sb.Append(player == null ? square.ToString() : player.ToString());
                if (col < 2)
                    sb.Append("|");
            }

            // Appended after every row; TicTacToeStateTest.ToString_DoesntLookTooStupid expects
            // the trailing separator after the last row as well.
            sb.Append("\n-----\n");
        }

        return sb.ToString();
    }
}
133 | }
134 |
--------------------------------------------------------------------------------
/examples/TicTacToe/TicTacToeTests/TicTacToeStateTest.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using System.Linq;
4 | using TicTacToe;
5 | using Xunit;
6 |
7 | namespace TicTacToeTests
8 | {
/// <summary>
/// xUnit tests for <see cref="TicTacToeState"/>: board rendering, legal-move generation,
/// move application and result scoring.
/// </summary>
public class TicTacToeStateTest
{
    // Two moves are the same when the same player targets the same square.
    private static bool SameMove(TicTacToeAction left, TicTacToeAction right)
    {
        return left.Player == right.Player && left.Position == right.Position;
    }

    // Order-insensitive comparison of two move lists: equal sizes, and every move in each list
    // has a matching move in the other.
    private static void AssertActionsEqual(IList<TicTacToeAction> expected, IList<TicTacToeAction> actual)
    {
        Assert.Equal(expected.Count, actual.Count);
        foreach (var wanted in expected)
            Assert.True(actual.Any(candidate => SameMove(candidate, wanted)));
        foreach (var produced in actual)
            Assert.True(expected.Any(candidate => SameMove(candidate, produced)));
    }

    [Fact]
    public void ToString_DoesntLookTooStupid()
    {
        var state = new TicTacToeState(new[]
            {
                null, null, null,
                TicTacToePlayer.X, TicTacToePlayer.X, TicTacToePlayer.X,
                TicTacToePlayer.O, TicTacToePlayer.O, null
            },
            TicTacToePlayer.O);

        var rendered = state.ToString();

        Assert.Equal("0|1|2\n-----\nX|X|X\n-----\nO|O|8\n-----\n", rendered);
    }

    [Fact]
    public void Actions_EmptyBoard_9PossibleActionsForXPlayer()
    {
        var expected = new List<TicTacToeAction>();
        for (var square = 0; square < 9; square++)
            expected.Add(new TicTacToeAction(square, TicTacToePlayer.X));

        var actual = new TicTacToeState().Actions.ToList();

        AssertActionsEqual(expected, actual);
    }

    [Fact]
    public void Actions_ExcludesOptionsAlreadyPlayed()
    {
        var openSquares = new[] { 2, 3, 4, 7, 8 };
        var expected = openSquares
            .Select(square => new TicTacToeAction(square, TicTacToePlayer.O))
            .ToList();
        var board = new[]
        {
            TicTacToePlayer.X, TicTacToePlayer.O, null,
            null, null, TicTacToePlayer.X,
            TicTacToePlayer.O, null, null
        };

        var actual = new TicTacToeState(board, TicTacToePlayer.O).Actions.ToList();

        AssertActionsEqual(expected, actual);
    }

    [Fact]
    public void Actions_NoneWhenWinnerOnBoard()
    {
        var expected = Array.Empty<TicTacToeAction>();
        var board = new[]
        {
            TicTacToePlayer.X, TicTacToePlayer.X, TicTacToePlayer.X,
            null, null, null,
            null, null, null
        };

        var actual = new TicTacToeState(board, TicTacToePlayer.O).Actions.ToList();

        AssertActionsEqual(expected, actual);
    }

    [Fact]
    public void Actions_DebugThisSetup()
    {
        var state = new TicTacToeState(new[]
            {
                TicTacToePlayer.X, null, null,
                null, null, null,
                null, null, null
            },
            TicTacToePlayer.O);

        var actual = state.Actions;
        var expected = Enumerable.Range(1, 8)
            .Select(square => new TicTacToeAction(square, TicTacToePlayer.O))
            .ToList();

        AssertActionsEqual(expected, actual);
    }

    [Fact]
    public void ApplyAction_ModifiesAvailableActions()
    {
        var game = new TicTacToeState();

        var remaining = 9;
        while (remaining > 2)
        {
            Assert.Equal(remaining, game.Actions.Count);
            game.ApplyAction(game.Actions[0]);
            remaining--;
        }
    }

    [Fact]
    public void GetResult_BoardHasWin_DifferentValueForEachPlayer()
    {
        var board = new[]
        {
            TicTacToePlayer.X, TicTacToePlayer.X, TicTacToePlayer.X,
            TicTacToePlayer.O, TicTacToePlayer.O, null,
            TicTacToePlayer.O, null, null
        };
        var game = new TicTacToeState(board, TicTacToePlayer.X);

        Assert.Equal(1, game.GetResult(TicTacToePlayer.X));
        Assert.Equal(0, game.GetResult(TicTacToePlayer.O));
    }

    [Fact]
    public void GetResult_BoardIsTie_EqualsHalf()
    {
        var board = new[]
        {
            TicTacToePlayer.X, TicTacToePlayer.X, TicTacToePlayer.O,
            TicTacToePlayer.O, TicTacToePlayer.X, TicTacToePlayer.X,
            TicTacToePlayer.X, TicTacToePlayer.O, TicTacToePlayer.O
        };
        var game = new TicTacToeState(board, TicTacToePlayer.O);

        Assert.Equal(0.5, game.GetResult(TicTacToePlayer.X));
        Assert.Equal(0.5, game.GetResult(TicTacToePlayer.O));
    }
}
117 | }
118 |
--------------------------------------------------------------------------------
/examples/ConnectFour/ConnectFourTests/ConnectFourStateTest.cs:
--------------------------------------------------------------------------------
1 | using ConnectFour;
2 | using Microsoft.Extensions.FileProviders;
3 | using Microsoft.VisualStudio.TestTools.UnitTesting;
4 | using System;
5 | using System.Collections.Generic;
6 | using System.IO;
7 | using System.Linq;
8 | using System.Reflection;
9 |
10 | namespace ConnectFourTests
11 | {
/// <summary>
/// MSTest tests for <see cref="ConnectFourState"/>. Most cases load boards from embedded
/// TestData text files, where every board is expected to be a win for X.
/// </summary>
[TestClass]
public class ConnectFourStateTest
{
    private static ConnectFourPlayer X = ConnectFourPlayer.X;
    private static ConnectFourPlayer O = ConnectFourPlayer.O;
    private static ConnectFourPlayer _ = null;

    // Maps one character of a data file to a cell: 'X' and 'O' are players, anything else is empty.
    private static ConnectFourPlayer ToCell(char symbol)
    {
        if (symbol == 'X')
            return X;
        if (symbol == 'O')
            return O;
        return _;
    }

    /// <summary>
    /// Parses the embedded resource <paramref name="dataFileName"/> into board states. Each
    /// board is NumRows text rows (top row first) followed by one delimiter row; rows shorter
    /// than NumCols are space-padded, so their missing right-hand cells are empty.
    /// </summary>
    private IEnumerable<ConnectFourState> GetStatesFromData(string dataFileName)
    {
        var provider = new EmbeddedFileProvider(Assembly.GetAssembly(GetType()));
        var chunks = provider.ReadAllText(dataFileName)
            .SplitOnNewLine()
            .Batch(ConnectFourState.NumRows + 1);

        foreach (var chunk in chunks)
        {
            // The state's cell array starts at the bottom-left, so the rows are flipped; after the
            // flip the delimiter row (e.g. "-------") comes first and is dropped.
            var cells = chunk
                .Reverse()
                .Skip(1)
                .Select(line => line.PadRight(ConnectFourState.NumCols))
                .SelectMany(line => line.Select(symbol => ToCell(symbol)))
                .ToArray();

            yield return new ConnectFourState(cells);
        }
    }

    // Every board in the data file must score 1 for X, 0 for O, and offer no further moves.
    private void AssertEveryBoardIsWonByX(string dataFileName)
    {
        foreach (var state in GetStatesFromData(dataFileName))
        {
            Assert.AreEqual(1, state.GetResult(X), state.ToString());
            Assert.AreEqual(0, state.GetResult(O), state.ToString());
            Assert.IsFalse(state.Actions.Any());
        }
    }

    [TestMethod]
    public void GetResult_Rows()
    {
        AssertEveryBoardIsWonByX("TestData.GetResult_Rows.txt");
    }

    [TestMethod]
    public void GetResult_Columns()
    {
        AssertEveryBoardIsWonByX("TestData.GetResult_Columns.txt");
    }

    [TestMethod]
    public void GetResult_DiagonalBottomTop()
    {
        AssertEveryBoardIsWonByX("TestData.GetResult_DiagonalBottomTop.txt");
    }

    [TestMethod]
    public void GetResult_DiagonalTopBottom()
    {
        AssertEveryBoardIsWonByX("TestData.GetResult_DiagonalTopBottom.txt");
    }

    [TestMethod]
    public void FullRun_Simple()
    {
        var state = new ConnectFourState();
        var columns = new[] { 0, 0, 1, 1, 2, 2, 3 };
        foreach (var column in columns)
            state.ApplyAction(new ConnectFourAction(column));

        Assert.IsFalse(state.Actions.Any());
        Assert.AreEqual(1, state.GetResult(X));
        Assert.AreEqual(0, state.GetResult(O));
    }
}
104 |
/// <summary>
/// Helpers the Connect Four tests use to load and slice their embedded text data.
/// </summary>
public static class ExtensionMethods
{
    // Every common line terminator. "\r\n" must come first: at each position String.Split uses
    // the first separator, in array order, that matches there, so "\r" listed first would leave
    // an empty string between CRLF-terminated lines.
    private static readonly string[] LineSeparators = { "\r\n", "\n", "\r" };

    /// <summary>
    /// Splits <paramref name="source"/> into lines, accepting CRLF, LF or CR terminators.
    /// </summary>
    /// <remarks>
    /// Splitting on Environment.NewLine alone misparses files whose line endings differ from the
    /// current OS: on Linux a CRLF file keeps a trailing '\r' on every row, which the board
    /// parser then reads as an extra cell; on Windows an LF file is not split at all.
    /// Empty entries are kept, matching the original StringSplitOptions.None behavior.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
    public static IEnumerable<string> SplitOnNewLine(this string source)
    {
        if (source == null)
            throw new ArgumentNullException(nameof(source));

        return source.Split(LineSeparators, StringSplitOptions.None);
    }

    /// <summary>
    /// Groups <paramref name="source"/> into consecutive chunks of <paramref name="batchSize"/>
    /// items. The final chunk holds any leftovers and may be shorter; no empty chunk is produced.
    /// </summary>
    public static IEnumerable<IEnumerable<T>> Batch<T>(this IEnumerable<T> source, int batchSize)
    {
        var batch = new List<T>(batchSize);
        foreach (var item in source)
        {
            batch.Add(item);
            if (batch.Count == batchSize)
            {
                yield return batch;
                // A fresh list, so a chunk already handed to the caller is never mutated.
                batch = new List<T>(batchSize);
            }
        }

        if (batch.Count > 0)
            yield return batch;
    }

    /// <summary>
    /// Reads the entire file <paramref name="fileName"/> from <paramref name="fileProvider"/> as text,
    /// disposing the stream and reader afterwards.
    /// </summary>
    public static string ReadAllText(this IFileProvider fileProvider, string fileName)
    {
        var file = fileProvider.GetFileInfo(fileName);
        using (var stream = file.CreateReadStream())
        using (var reader = new StreamReader(stream))
            return reader.ReadToEnd();
    }
}
137 | }
138 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ## Ignore Visual Studio temporary files, build results, and
2 | ## files generated by popular Visual Studio add-ons.
3 | ##
4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
5 |
6 | # User-specific files
7 | *.suo
8 | *.user
9 | *.userosscache
10 | *.sln.docstates
11 |
12 | # User-specific files (MonoDevelop/Xamarin Studio)
13 | *.userprefs
14 |
15 | # Build results
16 | [Dd]ebug/
17 | [Dd]ebugPublic/
18 | [Rr]elease/
19 | [Rr]eleases/
20 | x64/
21 | x86/
22 | bld/
23 | [Bb]in/
24 | [Oo]bj/
25 | [Ll]og/
26 |
27 | # Visual Studio 2015 cache/options directory
28 | .vs/
29 | # Uncomment if you have tasks that create the project's static files in wwwroot
30 | #wwwroot/
31 |
32 | # MSTest test Results
33 | [Tt]est[Rr]esult*/
34 | [Bb]uild[Ll]og.*
35 |
36 | # NUNIT
37 | *.VisualState.xml
38 | TestResult.xml
39 |
40 | # Build Results of an ATL Project
41 | [Dd]ebugPS/
42 | [Rr]eleasePS/
43 | dlldata.c
44 |
45 | # .NET Core
46 | project.lock.json
47 | project.fragment.lock.json
48 | artifacts/
49 | **/Properties/launchSettings.json
50 |
51 | *_i.c
52 | *_p.c
53 | *_i.h
54 | *.ilk
55 | *.meta
56 | *.obj
57 | *.pch
58 | *.pdb
59 | *.pgc
60 | *.pgd
61 | *.rsp
62 | *.sbr
63 | *.tlb
64 | *.tli
65 | *.tlh
66 | *.tmp
67 | *.tmp_proj
68 | *.log
69 | *.vspscc
70 | *.vssscc
71 | .builds
72 | *.pidb
73 | *.svclog
74 | *.scc
75 |
76 | # Chutzpah Test files
77 | _Chutzpah*
78 |
79 | # Visual C++ cache files
80 | ipch/
81 | *.aps
82 | *.ncb
83 | *.opendb
84 | *.opensdf
85 | *.sdf
86 | *.cachefile
87 | *.VC.db
88 | *.VC.VC.opendb
89 |
90 | # Visual Studio profiler
91 | *.psess
92 | *.vsp
93 | *.vspx
94 | *.sap
95 |
96 | # TFS 2012 Local Workspace
97 | $tf/
98 |
99 | # Guidance Automation Toolkit
100 | *.gpState
101 |
102 | # ReSharper is a .NET coding add-in
103 | _ReSharper*/
104 | *.[Rr]e[Ss]harper
105 | *.DotSettings.user
106 |
107 | # JustCode is a .NET coding add-in
108 | .JustCode
109 |
110 | # TeamCity is a build add-in
111 | _TeamCity*
112 |
113 | # DotCover is a Code Coverage Tool
114 | *.dotCover
115 |
116 | # Visual Studio code coverage results
117 | *.coverage
118 | *.coveragexml
119 |
120 | # NCrunch
121 | _NCrunch_*
122 | .*crunch*.local.xml
123 | nCrunchTemp_*
124 |
125 | # MightyMoose
126 | *.mm.*
127 | AutoTest.Net/
128 |
129 | # Web workbench (sass)
130 | .sass-cache/
131 |
132 | # Installshield output folder
133 | [Ee]xpress/
134 |
135 | # DocProject is a documentation generator add-in
136 | DocProject/buildhelp/
137 | DocProject/Help/*.HxT
138 | DocProject/Help/*.HxC
139 | DocProject/Help/*.hhc
140 | DocProject/Help/*.hhk
141 | DocProject/Help/*.hhp
142 | DocProject/Help/Html2
143 | DocProject/Help/html
144 |
145 | # Click-Once directory
146 | publish/
147 |
148 | # Publish Web Output
149 | *.[Pp]ublish.xml
150 | *.azurePubxml
151 | # TODO: Comment the next line if you want to checkin your web deploy settings
152 | # but database connection strings (with potential passwords) will be unencrypted
153 | *.pubxml
154 | *.publishproj
155 |
156 | # Microsoft Azure Web App publish settings. Comment the next line if you want to
157 | # checkin your Azure Web App publish settings, but sensitive information contained
158 | # in these scripts will be unencrypted
159 | PublishScripts/
160 |
161 | # NuGet Packages
162 | *.nupkg
163 | # The packages folder can be ignored because of Package Restore
164 | **/packages/*
165 | # except build/, which is used as an MSBuild target.
166 | !**/packages/build/
167 | # Uncomment if necessary however generally it will be regenerated when needed
168 | #!**/packages/repositories.config
169 | # NuGet v3's project.json files produces more ignorable files
170 | *.nuget.props
171 | *.nuget.targets
172 |
173 | # Microsoft Azure Build Output
174 | csx/
175 | *.build.csdef
176 |
177 | # Microsoft Azure Emulator
178 | ecf/
179 | rcf/
180 |
181 | # Windows Store app package directories and files
182 | AppPackages/
183 | BundleArtifacts/
184 | Package.StoreAssociation.xml
185 | _pkginfo.txt
186 |
187 | # Visual Studio cache files
188 | # files ending in .cache can be ignored
189 | *.[Cc]ache
190 | # but keep track of directories ending in .cache
191 | !*.[Cc]ache/
192 |
193 | # Others
194 | ClientBin/
195 | ~$*
196 | *~
197 | *.dbmdl
198 | *.dbproj.schemaview
199 | *.jfm
200 | *.pfx
201 | *.publishsettings
202 | orleans.codegen.cs
203 |
204 | # Since there are multiple workflows, uncomment next line to ignore bower_components
205 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
206 | #bower_components/
207 |
208 | # RIA/Silverlight projects
209 | Generated_Code/
210 |
211 | # Backup & report files from converting an old project file
212 | # to a newer Visual Studio version. Backup files are not needed,
213 | # because we have git ;-)
214 | _UpgradeReport_Files/
215 | Backup*/
216 | UpgradeLog*.XML
217 | UpgradeLog*.htm
218 |
219 | # SQL Server files
220 | *.mdf
221 | *.ldf
222 | *.ndf
223 |
224 | # Business Intelligence projects
225 | *.rdl.data
226 | *.bim.layout
227 | *.bim_*.settings
228 |
229 | # Microsoft Fakes
230 | FakesAssemblies/
231 |
232 | # GhostDoc plugin setting file
233 | *.GhostDoc.xml
234 |
235 | # Node.js Tools for Visual Studio
236 | .ntvs_analysis.dat
237 | node_modules/
238 |
239 | # Typescript v1 declaration files
240 | typings/
241 |
242 | # Visual Studio 6 build log
243 | *.plg
244 |
245 | # Visual Studio 6 workspace options file
246 | *.opt
247 |
248 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
249 | *.vbw
250 |
251 | # Visual Studio LightSwitch build output
252 | **/*.HTMLClient/GeneratedArtifacts
253 | **/*.DesktopClient/GeneratedArtifacts
254 | **/*.DesktopClient/ModelManifest.xml
255 | **/*.Server/GeneratedArtifacts
256 | **/*.Server/ModelManifest.xml
257 | _Pvt_Extensions
258 |
259 | # Paket dependency manager
260 | .paket/paket.exe
261 | paket-files/
262 |
263 | # FAKE - F# Make
264 | .fake/
265 |
266 | # JetBrains Rider
267 | .idea/
268 | *.sln.iml
269 |
270 | # CodeRush
271 | .cr/
272 |
273 | # Python Tools for Visual Studio (PTVS)
274 | __pycache__/
275 | *.pyc
276 |
277 | # Cake - Uncomment if you are using it
278 | # tools/**
279 | # !tools/packages.config
280 |
281 | # Telerik's JustMock configuration file
282 | *.jmconfig
283 |
284 | # BizTalk build output
285 | *.btp.cs
286 | *.btm.cs
287 | *.odx.cs
288 | *.xsd.cs
289 |
--------------------------------------------------------------------------------
/montecarlo.sln:
--------------------------------------------------------------------------------
1 |
2 | Microsoft Visual Studio Solution File, Format Version 12.00
3 | # Visual Studio 15
4 | VisualStudioVersion = 15.0.26730.12
5 | MinimumVisualStudioVersion = 15.0.26124.0
6 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MonteCarlo", "src\MonteCarlo\MonteCarlo.csproj", "{1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}"
7 | EndProject
8 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TicTacToe", "examples\TicTacToe\TicTacToe\TicTacToe.csproj", "{C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}"
9 | EndProject
10 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TicTacToeTests", "examples\TicTacToe\TicTacToeTests\TicTacToeTests.csproj", "{62924053-8DE0-458D-AEB7-F8B46F9FBAB3}"
11 | EndProject
12 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ConnectFour", "examples\ConnectFour\ConnectFour\ConnectFour.csproj", "{9858D70A-05AA-4761-99D3-36168BA3C632}"
13 | EndProject
14 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ConnectFourTests", "examples\ConnectFour\ConnectFourTests\ConnectFourTests.csproj", "{59408B67-4DEC-4FEF-8F8D-A2D552213017}"
15 | EndProject
16 | Global
17 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
18 | Debug|Any CPU = Debug|Any CPU
19 | Debug|x64 = Debug|x64
20 | Debug|x86 = Debug|x86
21 | Release|Any CPU = Release|Any CPU
22 | Release|x64 = Release|x64
23 | Release|x86 = Release|x86
24 | EndGlobalSection
25 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
26 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
27 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Debug|Any CPU.Build.0 = Debug|Any CPU
28 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Debug|x64.ActiveCfg = Debug|Any CPU
29 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Debug|x64.Build.0 = Debug|Any CPU
30 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Debug|x86.ActiveCfg = Debug|Any CPU
31 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Debug|x86.Build.0 = Debug|Any CPU
32 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Release|Any CPU.ActiveCfg = Release|Any CPU
33 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Release|Any CPU.Build.0 = Release|Any CPU
34 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Release|x64.ActiveCfg = Release|Any CPU
35 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Release|x64.Build.0 = Release|Any CPU
36 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Release|x86.ActiveCfg = Release|Any CPU
37 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Release|x86.Build.0 = Release|Any CPU
38 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
39 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Debug|Any CPU.Build.0 = Debug|Any CPU
40 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Debug|x64.ActiveCfg = Debug|Any CPU
41 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Debug|x64.Build.0 = Debug|Any CPU
42 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Debug|x86.ActiveCfg = Debug|Any CPU
43 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Debug|x86.Build.0 = Debug|Any CPU
44 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Release|Any CPU.ActiveCfg = Release|Any CPU
45 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Release|Any CPU.Build.0 = Release|Any CPU
46 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Release|x64.ActiveCfg = Release|Any CPU
47 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Release|x64.Build.0 = Release|Any CPU
48 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Release|x86.ActiveCfg = Release|Any CPU
49 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Release|x86.Build.0 = Release|Any CPU
50 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
51 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Debug|Any CPU.Build.0 = Debug|Any CPU
52 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Debug|x64.ActiveCfg = Debug|Any CPU
53 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Debug|x64.Build.0 = Debug|Any CPU
54 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Debug|x86.ActiveCfg = Debug|Any CPU
55 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Debug|x86.Build.0 = Debug|Any CPU
56 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Release|Any CPU.ActiveCfg = Release|Any CPU
57 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Release|Any CPU.Build.0 = Release|Any CPU
58 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Release|x64.ActiveCfg = Release|Any CPU
59 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Release|x64.Build.0 = Release|Any CPU
60 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Release|x86.ActiveCfg = Release|Any CPU
61 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Release|x86.Build.0 = Release|Any CPU
62 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
63 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Debug|Any CPU.Build.0 = Debug|Any CPU
64 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Debug|x64.ActiveCfg = Debug|Any CPU
65 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Debug|x64.Build.0 = Debug|Any CPU
66 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Debug|x86.ActiveCfg = Debug|Any CPU
67 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Debug|x86.Build.0 = Debug|Any CPU
68 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Release|Any CPU.ActiveCfg = Release|Any CPU
69 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Release|Any CPU.Build.0 = Release|Any CPU
70 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Release|x64.ActiveCfg = Release|Any CPU
71 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Release|x64.Build.0 = Release|Any CPU
72 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Release|x86.ActiveCfg = Release|Any CPU
73 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Release|x86.Build.0 = Release|Any CPU
74 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
75 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Debug|Any CPU.Build.0 = Debug|Any CPU
76 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Debug|x64.ActiveCfg = Debug|Any CPU
77 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Debug|x64.Build.0 = Debug|Any CPU
78 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Debug|x86.ActiveCfg = Debug|Any CPU
79 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Debug|x86.Build.0 = Debug|Any CPU
80 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Release|Any CPU.ActiveCfg = Release|Any CPU
81 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Release|Any CPU.Build.0 = Release|Any CPU
82 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Release|x64.ActiveCfg = Release|Any CPU
83 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Release|x64.Build.0 = Release|Any CPU
84 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Release|x86.ActiveCfg = Release|Any CPU
85 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Release|x86.Build.0 = Release|Any CPU
86 | EndGlobalSection
87 | GlobalSection(SolutionProperties) = preSolution
88 | HideSolutionNode = FALSE
89 | EndGlobalSection
90 | GlobalSection(ExtensibilityGlobals) = postSolution
91 | SolutionGuid = {BB81C8F8-976A-4679-BF19-A0D29029D70D}
92 | EndGlobalSection
93 | EndGlobal
94 |
--------------------------------------------------------------------------------
/src/MonteCarlo/MonteCarloTreeSearch.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using System.Diagnostics;
4 | using System.Linq;
5 |
6 | namespace MonteCarlo
7 | {
/// <summary>
/// Marker for the move type searched by <see cref="MonteCarloTreeSearch"/>. It declares no members.
/// </summary>
public interface IAction
{
}

/// <summary>
/// Marker for the player type searched by <see cref="MonteCarloTreeSearch"/>. It declares no members.
/// </summary>
public interface IPlayer
{
}
11 |
/// <summary>
/// A game position that <see cref="MonteCarloTreeSearch"/> can explore.
/// </summary>
/// <typeparam name="TPlayer">Type identifying a player.</typeparam>
/// <typeparam name="TAction">Type of a move.</typeparam>
public interface IState<TPlayer, TAction>
{
    /// <summary>
    /// Returns a copy of this position. The search clones the position before every playout and then
    /// mutates the copy, so the copy must not share mutable data with the original.
    /// </summary>
    IState<TPlayer, TAction> Clone();

    /// <summary>The player whose turn it is.</summary>
    TPlayer CurrentPlayer { get; }

    /// <summary>
    /// The legal moves in this position. The search treats an empty list as the end of the game.
    /// </summary>
    IList<TAction> Actions { get; }

    /// <summary>Plays <paramref name="action"/> in place, advancing this position.</summary>
    void ApplyAction(TAction action);

    /// <summary>
    /// Scores a finished game for <paramref name="forPlayer"/>. The search sums these values into win
    /// totals and divides by visit counts, so results are expected to lie in [0, 1]
    /// (for example 1 for a win, 0.5 for a draw, 0 for a loss).
    /// </summary>
    double GetResult(TPlayer forPlayer);
}
24 |
/// <summary>
/// Generic Monte Carlo Tree Search using the UCT (UCB1 applied to trees) selection rule.
/// </summary>
public class MonteCarloTreeSearch
{
    /// <summary>
    /// A node of the search tree. It remembers the action that led to it, the player to move at it, the
    /// actions that have not been expanded yet, and the playout statistics accumulated through it.
    /// </summary>
    private class Node<TPlayer, TAction> : IMctsNode<TAction> where TPlayer : IPlayer where TAction : IAction
    {
        // UCT exploration constant; sqrt(2) is the usual choice for playout results in [0, 1].
        private static readonly double ExplorationConstant = Math.Sqrt(2);

        /// <summary>
        /// Creates a node for <paramref name="state"/>.
        /// </summary>
        /// <param name="state">
        /// The position at this node. It is only read here (player to move and legal actions). The node keeps
        /// no reference to it, so the caller may keep mutating the same object afterwards.
        /// </param>
        /// <param name="action">The action that leads from <paramref name="parent"/> to this node.</param>
        /// <param name="parent">The parent node, or null for the root.</param>
        public Node(IState<TPlayer, TAction> state, TAction action = default(TAction), Node<TPlayer, TAction> parent = null)
        {
            Parent = parent;
            Player = state.CurrentPlayer;
            Action = action;
            UntriedActions = new HashSet<TAction>(state.Actions);
        }

        public Node<TPlayer, TAction> Parent { get; }

        public IList<Node<TPlayer, TAction>> Children { get; } = new List<Node<TPlayer, TAction>>();

        /// <summary>Number of playouts that passed through this node.</summary>
        public int NumRuns { get; set; }

        /// <summary>
        /// Sum of playout results credited to this node. Results are scored for the player who chose
        /// <see cref="Action"/> (the parent's player to move). For the root, they are scored for its own player to move.
        /// </summary>
        public double NumWins { get; set; }

        /// <summary>The player to move in the position this node represents.</summary>
        public TPlayer Player { get; }

        /// <summary>The action that leads from the parent to this node (default for the root).</summary>
        public TAction Action { get; }

        /// <summary>Legal actions at this node that do not have a child yet.</summary>
        public ISet<TAction> UntriedActions { get; }

        public double ExploitationValue => NumWins / NumRuns;

        // The root has no parent visit count to compare against, so it gets no exploration bonus
        // (this also keeps ToString from dereferencing a null Parent).
        public double ExplorationValue => Parent == null
            ? 0
            : ExplorationConstant * Math.Sqrt(Math.Log(Parent.NumRuns) / NumRuns);

        private double UCT => ExploitationValue + ExplorationValue;

        /// <summary>Returns the child with the highest UCT score.</summary>
        public Node<TPlayer, TAction> SelectChild()
        {
            return Children.MaxElementBy(e => e.UCT);
        }

        /// <summary>
        /// Adds a child for <paramref name="action"/>, whose resulting position is <paramref name="state"/>,
        /// and marks the action as tried.
        /// </summary>
        public Node<TPlayer, TAction> AddChild(TAction action, IState<TPlayer, TAction> state)
        {
            var child = new Node<TPlayer, TAction>(state, action, this);
            UntriedActions.Remove(action);
            Children.Add(child);

            return child;
        }

        /// <summary>
        /// Runs select / expand / simulate / backpropagate iterations from this node while
        /// <paramref name="shouldContinue"/> returns true.
        /// </summary>
        /// <param name="rootState">The position this node represents. It is cloned for every iteration and never mutated.</param>
        /// <param name="shouldContinue">Receives the number of completed iterations and the elapsed milliseconds.</param>
        public void BuildTree(IState<TPlayer, TAction> rootState, Func<int, long, bool> shouldContinue)
        {
            var iterations = 0;
            var timer = Stopwatch.StartNew();
            while (shouldContinue(iterations++, timer.ElapsedMilliseconds))
            {
                var node = this;
                var state = rootState.Clone();

                // Select: descend through fully expanded nodes. A node with neither untried actions nor
                // children is terminal, so the descent stops there.
                while (node.UntriedActions.Count == 0 && node.Children.Count > 0)
                {
                    node = node.SelectChild();
                    state.ApplyAction(node.Action);
                }

                // Expand: add one child for a random untried action.
                if (node.UntriedActions.Count > 0)
                {
                    var action = node.UntriedActions.RandomChoice();
                    state.ApplyAction(action);
                    node = node.AddChild(action, state);
                }

                // Simulate: play random moves until the game ends. This mutates `state`, which is why nodes
                // must not keep a reference to the state they were created from.
                while (state.Actions.Count > 0)
                    state.ApplyAction(state.Actions.RandomChoice());

                // Backpropagate: credit each node for the player who chose the move into it, so that
                // SelectChild maximises the result for whoever is actually choosing at each level.
                for (var n = node; n != null; n = n.Parent)
                {
                    n.NumRuns++;
                    n.NumWins += state.GetResult(n.Parent != null ? n.Parent.Player : n.Player);
                }
            }
        }

        public override string ToString()
        {
            return $"{NumWins}/{NumRuns}: ({ExploitationValue}/{ExplorationValue}={UCT})";
        }
    }

    /// <summary>
    /// Searches from <paramref name="state"/> for at most <paramref name="maxIterations"/> iterations.
    /// </summary>
    /// <returns>The root's children (one per explored action), most visited first.</returns>
    public static IEnumerable<IMctsNode<TAction>> GetTopActions<TPlayer, TAction>(IState<TPlayer, TAction> state, int maxIterations) where TPlayer : IPlayer where TAction : IAction
    {
        return GetTopActions(state, maxIterations, long.MaxValue);
    }

    /// <summary>
    /// Searches from <paramref name="state"/> for at most <paramref name="timeBudget"/> milliseconds.
    /// </summary>
    /// <returns>The root's children (one per explored action), most visited first.</returns>
    public static IEnumerable<IMctsNode<TAction>> GetTopActions<TPlayer, TAction>(IState<TPlayer, TAction> state, long timeBudget) where TPlayer : IPlayer where TAction : IAction
    {
        return GetTopActions(state, int.MaxValue, timeBudget);
    }

    /// <summary>
    /// Searches from <paramref name="state"/> until either <paramref name="maxIterations"/> iterations have
    /// run or <paramref name="timeBudget"/> milliseconds have elapsed. <paramref name="state"/> is not modified.
    /// </summary>
    /// <returns>The root's children (one per explored action), most visited first.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="state"/> is null.</exception>
    public static IEnumerable<IMctsNode<TAction>> GetTopActions<TPlayer, TAction>(IState<TPlayer, TAction> state, int maxIterations, long timeBudget) where TPlayer : IPlayer where TAction : IAction
    {
        if (state == null)
            throw new ArgumentNullException(nameof(state));

        var root = new Node<TPlayer, TAction>(state);
        root.BuildTree(state, (numIterations, elapsedMs) => numIterations < maxIterations && elapsedMs < timeBudget);
        return root.Children
            .OrderByDescending(n => n.NumRuns);
    }
}
140 |
/// <summary>
/// Read-only view of a search-tree node, as handed out by <see cref="MonteCarloTreeSearch"/>.GetTopActions.
/// </summary>
/// <typeparam name="TAction">The move type of the searched game.</typeparam>
public interface IMctsNode<TAction>
    where TAction : IAction
{
    /// <summary>The action that leads from the parent node to this node.</summary>
    TAction Action { get; }

    /// <summary>How many playouts passed through this node.</summary>
    int NumRuns { get; }

    /// <summary>
    /// Sum of the playout results credited to this node. For the nodes returned by GetTopActions these are
    /// results for the player to move in the searched position.
    /// </summary>
    double NumWins { get; }
}
149 |
/// <summary>
/// Small collection helpers used by the tree search.
/// </summary>
internal static class CollectionExtensions
{
    // Shared generator used when the caller does not pass one. System.Random is not thread-safe, so
    // concurrent callers should pass their own instance.
    private static readonly Random _random = new Random();

    /// <summary>
    /// Picks a uniformly random element of <paramref name="source"/>.
    /// </summary>
    /// <param name="source">The collection to choose from; must not be empty.</param>
    /// <param name="random">Optional generator (e.g. a seeded one for reproducible runs); a shared instance is used when null.</param>
    /// <returns>One element of <paramref name="source"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="source"/> is empty.</exception>
    public static T RandomChoice<T>(this ICollection<T> source, Random random = null)
    {
        if (source == null)
            throw new ArgumentNullException(nameof(source));
        if (source.Count == 0)
            throw new ArgumentException("Cannot choose an element from an empty collection.", nameof(source));

        var i = (random ?? _random).Next(source.Count);
        // ElementAt is O(1) for IList<T> and walks the sequence otherwise (e.g. for a HashSet<T>).
        return source.ElementAt(i);
    }

    /// <summary>
    /// Returns the element of <paramref name="source"/> with the largest <paramref name="selector"/> value.
    /// Ties keep the earliest element.
    /// </summary>
    /// <returns>The best element, or <c>default(T)</c> when <paramref name="source"/> is empty.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="selector"/> is null.</exception>
    public static T MaxElementBy<T>(this IEnumerable<T> source, Func<T, double> selector)
    {
        if (source == null)
            throw new ArgumentNullException(nameof(source));
        if (selector == null)
            throw new ArgumentNullException(nameof(selector));

        var found = false;
        var currentMaxElement = default(T);
        var currentMaxValue = double.NaN;

        foreach (var element in source)
        {
            var value = selector(element);
            // The first element is always taken, so a non-empty source never yields default(T) even when every
            // score is -infinity. A NaN best is replaced by any real score so it cannot hide better candidates.
            if (!found
                || value > currentMaxValue
                || (double.IsNaN(currentMaxValue) && !double.IsNaN(value)))
            {
                found = true;
                currentMaxValue = value;
                currentMaxElement = element;
            }
        }

        return currentMaxElement;
    }
}
178 | }
179 |
--------------------------------------------------------------------------------
/examples/ConnectFour/ConnectFour/ConnectFourState.cs:
--------------------------------------------------------------------------------
1 | using MonteCarlo;
2 | using System;
3 | using System.Collections.Generic;
4 | using System.Linq;
5 | using System.Text;
6 |
7 | namespace ConnectFour
8 | {
/// <summary>
/// One of the two Connect Four players. Only the two instances <see cref="X"/> and <see cref="O"/> exist
/// (the constructor is private), so players are compared by reference.
/// </summary>
public class ConnectFourPlayer : IPlayer
{
    /// <summary>The player who moves first on a new board.</summary>
    public static readonly ConnectFourPlayer X = new ConnectFourPlayer("X");

    /// <summary>The other player.</summary>
    public static readonly ConnectFourPlayer O = new ConnectFourPlayer("O");

    private ConnectFourPlayer(string name) => Name = name;

    /// <summary>Display name of the player ("X" or "O").</summary>
    public string Name { get; }

    /// <summary>The opponent of this player.</summary>
    public ConnectFourPlayer NextPlayer => ReferenceEquals(this, X) ? O : X;

    public override string ToString() => Name;
}
28 |
/// <summary>
/// A Connect Four move: dropping a disc into <see cref="Column"/>.
/// </summary>
public struct ConnectFourAction : IAction
{
    /// <summary>Creates a move into <paramref name="column"/>.</summary>
    public ConnectFourAction(int column)
    {
        Column = column;
    }

    /// <summary>Zero-based column index, counted from the left edge of the board.</summary>
    public int Column { get; }

    public override string ToString() => Column.ToString();
}
43 |
/// <summary>
/// A Connect Four position. It holds a <see cref="NumRows"/> x <see cref="NumCols"/> board, the player to move,
/// and the winner once someone has connected four.
/// </summary>
public class ConnectFourState : IState<ConnectFourPlayer, ConnectFourAction>
{
    public const int NumRows = 6;
    public const int NumCols = 7;

    // Flat NumRows * NumCols array: index 0 is the bottom-left cell and the last index is the top-right cell
    // (see GetBoardIndex). Empty cells are null.
    private readonly ConnectFourPlayer[] board;
    // Playable columns, kept up to date by ApplyAction instead of rescanning the board.
    private readonly IList<ConnectFourAction> availableActions;
    // Lines of four cells that a single player could still complete; pruned as discs are placed.
    private readonly IList<int[]> availableWinningRuns;
    // The player who connected four, or null while nobody has.
    private ConnectFourPlayer winner;

    /// <summary>Creates an empty board with <see cref="ConnectFourPlayer.X"/> to move.</summary>
    public ConnectFourState()
    {
        board = new ConnectFourPlayer[NumRows * NumCols];
        availableActions = Enumerable.Range(0, NumCols).Select(i => new ConnectFourAction(i)).ToList();
        availableWinningRuns = GetDefaultAvailableWinningRuns();
        CurrentPlayer = ConnectFourPlayer.X;
    }

    /// <summary>
    /// Creates a state from an existing board layout.
    /// </summary>
    /// <param name="board">Flat board laid out as described for <see cref="GetBoardIndex"/>. The array is used directly, not copied.</param>
    /// <param name="currentPlayer">The player to move; defaults to <see cref="ConnectFourPlayer.X"/>.</param>
    /// <exception cref="ArgumentNullException"><paramref name="board"/> is null.</exception>
    public ConnectFourState(ConnectFourPlayer[] board, ConnectFourPlayer currentPlayer = null)
    {
        this.board = board ?? throw new ArgumentNullException(nameof(board));

        // A column is playable while its top cell (row NumRows - 1) is still empty.
        availableActions = GetRow(NumRows - 1)
            .Select((p, i) => new { p, i })
            .Where(o => o.p == null)
            .Select(o => new ConnectFourAction(o.i))
            .ToList();

        availableWinningRuns = GetDefaultAvailableWinningRuns();
        CurrentPlayer = currentPlayer ?? ConnectFourPlayer.X;

        // Replay every occupied cell through the pruning logic so winner and availableWinningRuns match the board.
        for (var i = 0; i < board.Length; i++)
        {
            if (board[i] != null)
                PruneAvailableWinningRuns(i);
        }
    }

    /// <summary>Creates an independent copy of <paramref name="toClone"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="toClone"/> is null.</exception>
    public ConnectFourState(ConnectFourState toClone)
    {
        if (toClone == null)
            throw new ArgumentNullException(nameof(toClone));

        board = toClone.board.ToArray();
        availableActions = new List<ConnectFourAction>(toClone.availableActions);
        // The run arrays themselves are never mutated, so sharing them between copies is safe.
        availableWinningRuns = new List<int[]>(toClone.availableWinningRuns);
        CurrentPlayer = toClone.CurrentPlayer;
        // Without this, a copy of a finished game would report moves again or be scored as a draw.
        winner = toClone.winner;
    }

    /// <summary>The player whose turn it is.</summary>
    public ConnectFourPlayer CurrentPlayer { get; private set; }

    /// <summary>Playable columns, or an empty list once a player has connected four.</summary>
    public IList<ConnectFourAction> Actions => winner == null ? availableActions : (IList<ConnectFourAction>)Array.Empty<ConnectFourAction>();

    public IState<ConnectFourPlayer, ConnectFourAction> Clone()
    {
        return new ConnectFourState(this);
    }

    /// <summary>
    /// Drops a disc for <see cref="CurrentPlayer"/> into the column of <paramref name="action"/> and passes the turn.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">The column is outside 0..NumCols-1.</exception>
    /// <exception cref="ArgumentException">The column is already full.</exception>
    public void ApplyAction(ConnectFourAction action)
    {
        // Without this check, column NumCols would silently walk into column 0 of the next row.
        if (action.Column < 0 || action.Column >= NumCols)
            throw new ArgumentOutOfRangeException(nameof(action), action.Column, $"Column must be between 0 and {NumCols - 1}.");

        // The disc lands in the lowest empty cell of the column.
        var row = GetCol(action.Column)
            .Select((p, i) => new { p, i })
            .SkipWhile(o => o.p != null)
            .FirstOrDefault()?.i;

        if (row == null)
            throw new ArgumentException($"Column {action.Column} is already full", nameof(action));

        // Filling the top cell makes the column unplayable.
        if (row >= NumRows - 1)
            RemoveAvailableColumn(action.Column);

        var actionPosition = GetBoardIndex(row.Value, action.Column);
        board[actionPosition] = CurrentPlayer;
        CurrentPlayer = CurrentPlayer.NextPlayer;
        PruneAvailableWinningRuns(actionPosition);
    }

    /// <summary>
    /// Scores a finished game. Returns 1 if <paramref name="forPlayer"/> won, 0.5 if nobody won (full board),
    /// and 0 otherwise.
    /// </summary>
    /// <exception cref="InvalidOperationException">The game still has moves available.</exception>
    public double GetResult(ConnectFourPlayer forPlayer)
    {
        if (Actions.Any())
            throw new InvalidOperationException("Game isn't over yet");

        if (winner == forPlayer) return 1;
        if (winner == null) return 0.5;
        return 0;
    }

    /// <summary>Cells of row <paramref name="rowNum"/> (0 = bottom), from left to right.</summary>
    public IEnumerable<ConnectFourPlayer> GetRow(int rowNum)
    {
        var startIndex = NumCols * rowNum;
        var endIndex = startIndex + NumCols;
        for (var i = startIndex; i < endIndex && i < board.Length; i++)
            yield return board[i];
    }

    /// <summary>Cells of column <paramref name="colNum"/> (0 = left), from bottom to top.</summary>
    public IEnumerable<ConnectFourPlayer> GetCol(int colNum)
    {
        for (var i = colNum; i < board.Length; i += NumCols)
            yield return board[i];
    }

    /// <summary>Renders the board top row first, one character per cell (space for empty), rows separated by '\n'.</summary>
    public override string ToString()
    {
        var sb = new StringBuilder();
        for (var rowNum = NumRows - 1; rowNum >= 0; rowNum--)
        {
            sb.AppendJoin("", GetRow(rowNum).Select(c => c == null ? " " : c.ToString().Substring(0, 1)));
            if (rowNum > 0)
                sb.Append("\n");
        }
        return sb.ToString();
    }

    /// <summary>
    /// Converts a row, column coordinate into a flat index in the board.
    /// </summary>
    /// <param name="row">Zero-based row, counted from the bottom.</param>
    /// <param name="col">Zero-based column, counted from the left.</param>
    /// <returns>The index of the cell in <see cref="board"/>.</returns>
    private int GetBoardIndex(int row, int col)
    {
        return row * NumCols + col;
    }

    // Removes the action for the given column if it is present, and never touches any other column.
    private void RemoveAvailableColumn(int column)
    {
        for (var i = 0; i < availableActions.Count; i++)
        {
            if (availableActions[i].Column == column)
            {
                availableActions.RemoveAt(i);
                return;
            }
        }
    }

    // Re-examines every still-open run through forPosition. A run that is now complete records the winner and
    // ends the game. A run that now holds discs of both players can never be completed, so it is dropped.
    private void PruneAvailableWinningRuns(int forPosition)
    {
        var toPrune = new List<int[]>();
        foreach (var run in availableWinningRuns)
        {
            if (!run.Contains(forPosition))
                continue;

            var boardValues = run.ToBoardValues(board);
            var runWinner = boardValues.GetWinner();
            if (runWinner != null)
            {
                winner = runWinner;
                // Nothing else can be won; returning right away avoids touching the enumerator after Clear.
                availableWinningRuns.Clear();
                return;
            }

            if (!boardValues.CanWin())
                toPrune.Add(run);
        }
        toPrune.ForEach(run => availableWinningRuns.Remove(run));
    }

    // Enumerates every horizontal, vertical and diagonal line of four cells on an empty board.
    private IList<int[]> GetDefaultAvailableWinningRuns()
    {
        var runs = new List<int[]>();

        for (var row = 0; row < NumRows; row++)
        {
            var canRunTop = row < NumRows - 3;
            for (var col = 0; col < NumCols; col++)
            {
                var canRunRight = col < NumCols - 3;
                if (canRunRight)
                {
                    var horizontalRun = Enumerable.Range(col, 4).Select(c => GetBoardIndex(row, c)).ToArray();
                    runs.Add(horizontalRun);
                }
                if (canRunTop)
                {
                    var verticalRun = Enumerable.Range(row, 4).Select(r => GetBoardIndex(r, col)).ToArray();
                    runs.Add(verticalRun);
                }
                if (canRunRight && canRunTop)
                {
                    var diagonalBottomTopRun = Enumerable.Range(0, 4).Select(i => GetBoardIndex(row + i, col + i)).ToArray();
                    runs.Add(diagonalBottomTopRun);
                }
                if (canRunRight && row > 2)
                {
                    var diagonalTopBottomRun = Enumerable.Range(0, 4).Select(i => GetBoardIndex(row - i, col + i)).ToArray();
                    runs.Add(diagonalTopBottomRun);
                }
            }
        }

        return runs;
    }
}
231 |
/// <summary>
/// Helpers for inspecting runs of Connect Four board cells.
/// </summary>
public static class ExtensionMethods
{
    /// <summary>
    /// Maps each board index in <paramref name="positions"/> to the player occupying it (null for an empty cell).
    /// The mapping is deferred and is re-evaluated every time the result is enumerated.
    /// </summary>
    public static IEnumerable<ConnectFourPlayer> ToBoardValues(this IEnumerable<int> positions, ConnectFourPlayer[] board) =>
        positions.Select(index => board[index]);

    /// <summary>
    /// Returns the player who occupies every cell of <paramref name="source"/>. Returns null if the sequence is
    /// empty, contains an empty cell, or mixes players.
    /// </summary>
    public static ConnectFourPlayer GetWinner(this IEnumerable<ConnectFourPlayer> source)
    {
        using (var cells = source.GetEnumerator())
        {
            if (!cells.MoveNext())
                return null;

            var candidate = cells.Current;
            if (candidate == null)
                return null;

            while (cells.MoveNext())
            {
                if (cells.Current != candidate)
                    return null;
            }

            return candidate;
        }
    }

    /// <summary>
    /// True when all occupied cells of <paramref name="source"/> belong to the same player. Only that player
    /// could still complete the run. An all-empty run also counts.
    /// </summary>
    public static bool CanWin(this IEnumerable<ConnectFourPlayer> source)
    {
        ConnectFourPlayer owner = null;

        foreach (var cell in source)
        {
            if (cell == null)
                continue;

            if (owner == null)
            {
                owner = cell;
                continue;
            }

            if (cell != owner)
                return false;
        }

        return true;
    }
}
274 | }
275 |
--------------------------------------------------------------------------------