├── README.md ├── src └── MonteCarlo │ ├── MonteCarlo.csproj │ └── MonteCarloTreeSearch.cs ├── examples ├── ConnectFour │ ├── ConnectFour │ │ ├── ConnectFour.csproj │ │ ├── Program.cs │ │ └── ConnectFourState.cs │ └── ConnectFourTests │ │ ├── TestData │ │ ├── GetResult_Columns.txt │ │ ├── GetResult_Rows.txt │ │ ├── GetResult_DiagonalTopBottom.txt │ │ └── GetResult_DiagonalBottomTop.txt │ │ ├── ConnectFourTests.csproj │ │ └── ConnectFourStateTest.cs └── TicTacToe │ ├── TicTacToe │ ├── TicTacToe.csproj │ ├── Program.cs │ └── TicTacToe.cs │ └── TicTacToeTests │ ├── TicTacToeTests.csproj │ └── TicTacToeStateTest.cs ├── LICENSE ├── .gitignore └── montecarlo.sln /README.md: -------------------------------------------------------------------------------- 1 | # mcts 2 | 3 | Exploration with a simple generic Monte Carlo Tree Search algorithm in C# 4 | -------------------------------------------------------------------------------- /src/MonteCarlo/MonteCarlo.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | netstandard2.0 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /examples/ConnectFour/ConnectFour/ConnectFour.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Exe 5 | netcoreapp2.0 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /examples/ConnectFour/ConnectFourTests/TestData/GetResult_Columns.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | X 4 | X 5 | OX 6 | OOX 7 | ------- 8 | 9 | 10 | X 11 | XO 12 | XO 13 | XO 14 | ------- 15 | X 16 | X 17 | X 18 | X 19 | OX 20 | OOO 21 | ------- 22 | X 23 | X 24 | X 25 | XOX 26 | OXOX 27 | XOOOX 28 | ------- 29 | O 30 | XO 31 | XX 32 | XO 33 | OXOXX 34 | XXOOXOX 35 | ------- -------------------------------------------------------------------------------- 
/examples/TicTacToe/TicTacToe/TicTacToe.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Exe 5 | netcoreapp2.0 6 | TicTacToe.Program 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /examples/TicTacToe/TicTacToeTests/TicTacToeTests.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | netcoreapp2.0 5 | 6 | false 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /examples/ConnectFour/ConnectFourTests/TestData/GetResult_Rows.txt: -------------------------------------------------------------------------------- 1 | XXXX 2 | XOXOXOX 3 | OXOXOXO 4 | OXOXOXO 5 | XOXOXOX 6 | OXOXOXO 7 | --- ^ top left 8 | 9 | 10 | 11 | 12 | OOO 13 | XXXX 14 | --- ^ bottom left 15 | 16 | 17 | 18 | 19 | O O 20 | OXXXX 21 | --- ^ bottom middle 22 | 23 | 24 | 25 | 26 | O O 27 | OXXXX 28 | --- ^ bottom middle 29 | 30 | 31 | O 32 | O 33 | O 34 | XXXX 35 | --- ^ bottom right 36 | XXXX 37 | XOXOXOX 38 | OXOXOXO 39 | OXOXOXO 40 | XOXOXOX 41 | OXOXOXO 42 | --- ^ top middle 43 | XXXX 44 | XOXOXOX 45 | OXOXOXO 46 | OXOXOXO 47 | XOXOXOX 48 | OXOXOXO 49 | --- ^ top right 50 | 51 | 52 | 53 | OO 54 | XXXXO 55 | OXOXOXO 56 | --- ^ middle middle -------------------------------------------------------------------------------- /examples/ConnectFour/ConnectFourTests/TestData/GetResult_DiagonalTopBottom.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | X 4 | X 5 | X 6 | X 7 | --- bottom left 8 | 9 | 10 | X 11 | X 12 | X 13 | X 14 | --- bottom middle 15 | 16 | 17 | X 18 | X 19 | X 20 | X 21 | --- bottom right 22 | 23 | X 24 | X 25 | X 26 | X 27 | 28 | --- middle left 29 | 30 | X 31 | X 32 | X 33 | X 34 | 35 | --- middle middle 36 | 37 | X 38 | X 39 | X 40 | X 41 | 42 | --- middle right 43 | X 44 
| X 45 | X 46 | X 47 | 48 | 49 | --- top left 50 | X 51 | X 52 | X 53 | X 54 | 55 | 56 | --- top middle 57 | X 58 | X 59 | X 60 | X 61 | 62 | 63 | --- top right -------------------------------------------------------------------------------- /examples/ConnectFour/ConnectFourTests/TestData/GetResult_DiagonalBottomTop.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | X 4 | XXO 5 | XOXO 6 | XXOOO 7 | --- bottom left 8 | 9 | 10 | X 11 | XO 12 | XXO 13 | XOOX 14 | --- bottom right 15 | 16 | 17 | X 18 | XO 19 | XXO 20 | XOOX 21 | --- bottom middle 22 | 23 | X 24 | XX 25 | XXO 26 | XOOX 27 | XOXOXOX 28 | --- middle left 29 | 30 | X 31 | XO 32 | XOX 33 | XOXO 34 | XOXOXOX 35 | --- middle middle 36 | 37 | X 38 | XO 39 | XXO 40 | XOOX 41 | XOXOXOX 42 | --- middle right 43 | X 44 | XXO 45 | XOXO 46 | XXOOO 47 | XOXOXOX 48 | XOXOXOX 49 | --- top left 50 | X 51 | XO 52 | XXO 53 | XOOX 54 | XOXOXOX 55 | XOXOXOX 56 | --- top right 57 | X 58 | XO 59 | XOX 60 | XOXO 61 | XOXOXOX 62 | XOXOXOX 63 | --- top middle -------------------------------------------------------------------------------- /examples/ConnectFour/ConnectFourTests/ConnectFourTests.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | netcoreapp2.0 5 | 6 | false 7 | 8 | 9 | 10 | 11 | Always 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Daniel McDonald 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | 
copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /examples/TicTacToe/TicTacToe/Program.cs: -------------------------------------------------------------------------------- 1 | using MonteCarlo; 2 | using System; 3 | using System.Linq; 4 | 5 | namespace TicTacToe 6 | { 7 | class Program 8 | { 9 | private const string SEPARATOR = "______"; 10 | static void Main(string[] args) // console loop: the human plays X, then the search result's first ranked action is applied 11 | { 12 | var game = new TicTacToeState(); 13 | while(game.Actions.Any()) 14 | { 15 | Console.WriteLine($"CurrentPlayer: {game.CurrentPlayer}"); 16 | Console.WriteLine(game); 17 | Console.WriteLine(SEPARATOR); 18 | 19 | var position = -1; 20 | while (!game.Actions.Any(a => a.Position == position)) // re-prompt until the key names a free square; rejects non-digits, '9' and occupied squares 21 | { 22 | Console.WriteLine("Choose a free space: (0-8)"); 23 | 24 | var input = Console.ReadKey(); 25 | if (!int.TryParse(input.KeyChar.ToString(), out position)) position = -1; // TryParse writes 0 on failure, which would otherwise be played as square 0 26 | } 27 | 28 | Console.WriteLine(); 29 | 30 | game.ApplyAction(new TicTacToeAction(position, TicTacToePlayer.X)); 31 | var computer = MonteCarloTreeSearch.GetTopActions(game, 50000).ToList(); 32 | Console.WriteLine(SEPARATOR); 33 | if (computer.Count > 0) 34 | { 35 | Console.WriteLine("Computer's ranked plays:"); 36 | foreach (var a in computer) 37 | 
Console.WriteLine($"\t{a.Action}\t{a.NumWins}/{a.NumRuns} ({a.NumWins / a.NumRuns})"); 38 | game.ApplyAction(computer[0].Action); 39 | } 40 | 41 | position = -1; 42 | } 43 | 44 | Console.WriteLine(SEPARATOR); 45 | Console.WriteLine(game.ToString()); 46 | Console.WriteLine("Game Over"); 47 | Console.ReadKey(); 48 | } 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /examples/ConnectFour/ConnectFour/Program.cs: -------------------------------------------------------------------------------- 1 | using MonteCarlo; 2 | using System; 3 | using System.Linq; 4 | 5 | namespace ConnectFour 6 | { 7 | class Program 8 | { 9 | private const string SEPARATOR = "-------"; 10 | static void Main(string[] args) // console loop: the human picks a column, then the search result's first ranked action is applied 11 | { 12 | var game = new ConnectFourState(); 13 | while (game.Actions.Any()) 14 | { 15 | Console.WriteLine($"CurrentPlayer: {game.CurrentPlayer}"); 16 | Console.WriteLine(game); 17 | Console.WriteLine("0123456"); 18 | Console.WriteLine(SEPARATOR); 19 | 20 | var position = -1; 21 | while(position < 0 || position > 6) // re-prompt until a digit key in 0..6 is pressed 22 | { 23 | Console.WriteLine("Choose a free space: (0-6)"); 24 | 25 | var input = Console.ReadKey(); 26 | if (!int.TryParse(input.KeyChar.ToString(), out position)) position = -1; // TryParse writes 0 on failure, which would otherwise be played as column 0 27 | } 28 | 29 | Console.WriteLine(); 30 | game.ApplyAction(new ConnectFourAction(position)); 31 | var computer = MonteCarloTreeSearch.GetTopActions(game, 50000, 1000).ToList(); 32 | 33 | Console.WriteLine(SEPARATOR); 34 | if (computer.Count > 0) 35 | { 36 | Console.WriteLine("Computer's ranked plays:"); 37 | foreach (var a in computer) 38 | Console.WriteLine($"\t{a.Action}\t{a.NumWins}/{a.NumRuns} ({a.NumWins / a.NumRuns})"); 39 | game.ApplyAction(computer[0].Action); 40 | } 41 | 42 | position = -1; 43 | } 44 | 45 | Console.WriteLine(SEPARATOR); 46 | Console.WriteLine(game.ToString()); 47 | Console.WriteLine("Game Over"); 48 | Console.ReadKey(); 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- 
/examples/TicTacToe/TicTacToe/TicTacToe.cs: -------------------------------------------------------------------------------- 1 | using MonteCarlo; 2 | using System; 3 | using System.Collections.Generic; 4 | using System.Linq; 5 | using System.Text; 6 | 7 | namespace TicTacToe 8 | { 9 | public struct TicTacToeAction : IAction 10 | { 11 | public TicTacToeAction(int position, TicTacToePlayer player) 12 | { 13 | if (position < 0 || position > 8) 14 | throw new ArgumentException("position must be between 0 and 8, inclusive", nameof(position)); 15 | 16 | Position = position; 17 | Player = player; 18 | } 19 | 20 | public int Position { get; private set; } 21 | public TicTacToePlayer Player { get; private set; } 22 | 23 | public override string ToString() 24 | { 25 | return $"{Player}: {Position}"; 26 | } 27 | } 28 | 29 | public class TicTacToePlayer : IPlayer 30 | { 31 | private TicTacToePlayer(bool isX) 32 | { 33 | IsX = isX; 34 | } 35 | 36 | public bool IsX { get; } 37 | 38 | public TicTacToePlayer NextPlayer => IsX ? O : X; 39 | 40 | public override string ToString() 41 | { 42 | return IsX ? 
"X" : "O"; 43 | } 44 | 45 | public static TicTacToePlayer X = new TicTacToePlayer(true); 46 | public static TicTacToePlayer O = new TicTacToePlayer(false); 47 | } 48 | 49 | public class TicTacToeState : IState 50 | { 51 | private IList board; 52 | 53 | public TicTacToeState() : this(new TicTacToePlayer[9], TicTacToePlayer.X) { } 54 | 55 | public TicTacToeState(TicTacToePlayer[] board, TicTacToePlayer currentPlayer) 56 | { 57 | this.board = board; 58 | CurrentPlayer = currentPlayer; 59 | } 60 | 61 | public TicTacToePlayer CurrentPlayer { get; private set; } 62 | 63 | public IList Actions // one action per empty square for the current player; empty once either player has a winning line 64 | { 65 | get 66 | { 67 | if (hasWinner(TicTacToePlayer.X) || hasWinner(TicTacToePlayer.O)) 68 | return new TicTacToeAction[0]; 69 | 70 | return board 71 | .Take(9) 72 | .Select((player, position) => new { player, position }) 73 | .Where(o => o.player == null) 74 | .Select((o) => new TicTacToeAction(o.position, CurrentPlayer)) 75 | .ToList(); 76 | } 77 | } 78 | 79 | public void ApplyAction(TicTacToeAction action) 80 | { 81 | board[action.Position] = action.Player; 82 | CurrentPlayer = CurrentPlayer.NextPlayer; 83 | } 84 | 85 | public IState Clone() 86 | { 87 | return new TicTacToeState(board.ToArray(), CurrentPlayer); 88 | } 89 | 90 | private static int[][] winningCombos = new[] // every winning line on the 3x3 board: 3 rows, 3 columns, 2 diagonals 91 | { 92 | new [] {0, 1, 2}, 93 | new [] {0, 4, 8}, 94 | new [] {0, 3, 6}, 95 | new [] {1, 4, 7}, 96 | new [] {2, 4, 6}, 97 | new [] {2, 5, 8}, 98 | new [] {3, 4, 5}, new [] {6, 7, 8} // bottom row was missing, so completing 6-7-8 was never detected as a win 99 | }; 100 | 101 | private bool hasWinner(TicTacToePlayer forPlayer) 102 | { 103 | return winningCombos.Any(c => c.All(i => board[i] != null && board[i].IsX == forPlayer.IsX)); 104 | } 105 | 106 | public double GetResult(TicTacToePlayer forPlayer) 107 | { 108 | return hasWinner(forPlayer) ? 1 109 | : hasWinner(forPlayer.NextPlayer) ? 
0 110 | : 0.5; 111 | } 112 | 113 | public override string ToString() 114 | { 115 | var sb = new StringBuilder(); 116 | for (var rowOffset = 0; rowOffset < 9; rowOffset += 3) 117 | { 118 | for (var col = 0; col < 3; col++) 119 | { 120 | var inputValue = rowOffset + col; 121 | var player = board[inputValue]; 122 | sb.Append(player == null ? inputValue.ToString() : player.ToString()); 123 | if (col < 2) 124 | sb.Append("|"); 125 | } 126 | if (rowOffset < 8) 127 | sb.Append("\n-----\n"); 128 | } 129 | 130 | return sb.ToString(); 131 | } 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /examples/TicTacToe/TicTacToeTests/TicTacToeStateTest.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Linq; 4 | using TicTacToe; 5 | using Xunit; 6 | 7 | namespace TicTacToeTests 8 | { 9 | public class TicTacToeStateTest 10 | { 11 | private static void AssertActionsEqual(IList expected, IList actual) 12 | { 13 | Assert.Equal(expected.Count, actual.Count); 14 | foreach(var a in expected) 15 | Assert.True(actual.Any(aa => aa.Player == a.Player && aa.Position == a.Position)); 16 | foreach (var aa in actual) 17 | Assert.True(expected.Any(a => a.Player == aa.Player && a.Position == aa.Position)); 18 | } 19 | 20 | [Fact] 21 | public void ToString_DoesntLookTooStupid() 22 | { 23 | var expected = "0|1|2\n-----\nX|X|X\n-----\nO|O|8\n-----\n"; 24 | var actual = new TicTacToeState(new[] { 25 | null, null, null, 26 | TicTacToePlayer.X, TicTacToePlayer.X, TicTacToePlayer.X, 27 | TicTacToePlayer.O, TicTacToePlayer.O, null }, 28 | TicTacToePlayer.O) 29 | .ToString(); 30 | Xunit.Assert.Equal(expected, actual); 31 | } 32 | 33 | [Fact] 34 | public void Actions_EmptyBoard_9PossibleActionsForXPlayer() 35 | { 36 | var expected = Enumerable.Range(0, 9).Select(i => new TicTacToeAction(i, TicTacToePlayer.X)).ToList(); 37 | var actual = new 
TicTacToeState().Actions.ToList(); 38 | 39 | AssertActionsEqual(expected, actual); 40 | } 41 | 42 | [Fact] 43 | public void Actions_ExcludesOptionsAlreadyPlayed() 44 | { 45 | var expected = new[] { 2, 3, 4, 7, 8 }.Select(position => new TicTacToeAction(position, TicTacToePlayer.O)).ToList(); 46 | var actual = new TicTacToeState(new[] { 47 | TicTacToePlayer.X, TicTacToePlayer.O, null, 48 | null, null, TicTacToePlayer.X, 49 | TicTacToePlayer.O, null, null 50 | }, 51 | TicTacToePlayer.O) 52 | .Actions.ToList(); 53 | 54 | AssertActionsEqual(expected, actual); 55 | } 56 | 57 | [Fact] 58 | public void Actions_NoneWhenWinnerOnBoard() 59 | { 60 | var expected = new TicTacToeAction[0]; 61 | var actual = new TicTacToeState(new[] { TicTacToePlayer.X, TicTacToePlayer.X, TicTacToePlayer.X, null, null, null, null, null, null }, TicTacToePlayer.O).Actions.ToList(); 62 | AssertActionsEqual(expected, actual); 63 | } 64 | 65 | [Fact] 66 | public void Actions_DebugThisSetup() 67 | { 68 | var state = new TicTacToeState(new[] { 69 | TicTacToePlayer.X, null, null, 70 | null, null, null, 71 | null, null, null}, TicTacToePlayer.O); 72 | var actual = state.Actions; 73 | var expected = new[] { 1, 2, 3, 4, 5, 6, 7, 8 }.Select(position => new TicTacToeAction(position, TicTacToePlayer.O)).ToList(); 74 | AssertActionsEqual(expected, actual); 75 | } 76 | 77 | [Fact] 78 | public void ApplyAction_ModifiesAvailableActions() 79 | { 80 | var game = new TicTacToeState(); 81 | 82 | for(var i = 9; i>2; i--) 83 | { 84 | Assert.Equal(i, game.Actions.Count); 85 | game.ApplyAction(game.Actions[0]); 86 | } 87 | } 88 | 89 | [Fact] 90 | public void GetResult_BoardHasWin_DifferentValueForEachPlayer() 91 | { 92 | var game = new TicTacToeState(new[] 93 | { 94 | TicTacToePlayer.X, TicTacToePlayer.X, TicTacToePlayer.X, 95 | TicTacToePlayer.O, TicTacToePlayer.O, null, 96 | TicTacToePlayer.O, null, null 97 | }, TicTacToePlayer.X); 98 | 99 | Assert.Equal(1, game.GetResult(TicTacToePlayer.X)); 100 | Assert.Equal(0, 
game.GetResult(TicTacToePlayer.O)); 101 | } 102 | 103 | [Fact] 104 | public void GetResult_BoardIsTie_EqualsHalf() 105 | { 106 | var game = new TicTacToeState(new[] 107 | { 108 | TicTacToePlayer.X, TicTacToePlayer.X, TicTacToePlayer.O, 109 | TicTacToePlayer.O, TicTacToePlayer.X, TicTacToePlayer.X, 110 | TicTacToePlayer.X, TicTacToePlayer.O, TicTacToePlayer.O 111 | }, TicTacToePlayer.O); 112 | 113 | Assert.Equal(0.5, game.GetResult(TicTacToePlayer.X)); 114 | Assert.Equal(0.5, game.GetResult(TicTacToePlayer.O)); 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /examples/ConnectFour/ConnectFourTests/ConnectFourStateTest.cs: -------------------------------------------------------------------------------- 1 | using ConnectFour; 2 | using Microsoft.Extensions.FileProviders; 3 | using Microsoft.VisualStudio.TestTools.UnitTesting; 4 | using System; 5 | using System.Collections.Generic; 6 | using System.IO; 7 | using System.Linq; 8 | using System.Reflection; 9 | 10 | namespace ConnectFourTests 11 | { 12 | [TestClass] 13 | public class ConnectFourStateTest 14 | { 15 | private static ConnectFourPlayer X = ConnectFourPlayer.X; 16 | private static ConnectFourPlayer O = ConnectFourPlayer.O; 17 | private static ConnectFourPlayer _ = null; 18 | 19 | private IEnumerable GetStatesFromData(string dataFileName) 20 | { 21 | var fileProvider = new EmbeddedFileProvider(Assembly.GetAssembly(GetType())); 22 | 23 | var boards = fileProvider.ReadAllText(dataFileName) 24 | .SplitOnNewLine() 25 | .Batch(ConnectFourState.NumRows + 1); 26 | 27 | foreach(var rows in boards) 28 | { 29 | var board = rows 30 | .Reverse() // reverse the order because the bottom-left position is actually the first element in the board state 31 | .Skip(1) // skip the delimiter row at the bottom of the board (i.e. "-------") 32 | .Select(row => row.PadRight(ConnectFourState.NumCols)) 33 | .SelectMany(s => s.Select(c => c == 'X' ? X : c == 'O' ? 
O : _)) 34 | .ToArray(); 35 | 36 | var state = new ConnectFourState(board); 37 | yield return state; 38 | } 39 | } 40 | 41 | [TestMethod] 42 | public void GetResult_Rows() 43 | { 44 | var boards = GetStatesFromData("TestData.GetResult_Rows.txt"); 45 | foreach(var state in boards) 46 | { 47 | Assert.AreEqual(1, state.GetResult(X), state.ToString()); 48 | Assert.AreEqual(0, state.GetResult(O), state.ToString()); 49 | Assert.IsFalse(state.Actions.Any()); 50 | } 51 | } 52 | 53 | [TestMethod] 54 | public void GetResult_Columns() 55 | { 56 | var boards = GetStatesFromData("TestData.GetResult_Columns.txt"); 57 | 58 | foreach(var state in boards) 59 | { 60 | Assert.AreEqual(1, state.GetResult(X), state.ToString()); 61 | Assert.AreEqual(0, state.GetResult(O), state.ToString()); 62 | Assert.IsFalse(state.Actions.Any()); 63 | } 64 | } 65 | 66 | [TestMethod] 67 | public void GetResult_DiagonalBottomTop() 68 | { 69 | var boards = GetStatesFromData("TestData.GetResult_DiagonalBottomTop.txt"); 70 | 71 | foreach (var state in boards) 72 | { 73 | Assert.AreEqual(1, state.GetResult(X), state.ToString()); 74 | Assert.AreEqual(0, state.GetResult(O), state.ToString()); 75 | Assert.IsFalse(state.Actions.Any()); 76 | } 77 | } 78 | 79 | [TestMethod] 80 | public void GetResult_DiagonalTopBottom() 81 | { 82 | var boards = GetStatesFromData("TestData.GetResult_DiagonalTopBottom.txt"); 83 | 84 | foreach (var state in boards) 85 | { 86 | Assert.AreEqual(1, state.GetResult(X), state.ToString()); 87 | Assert.AreEqual(0, state.GetResult(O), state.ToString()); 88 | Assert.IsFalse(state.Actions.Any()); 89 | } 90 | } 91 | 92 | [TestMethod] 93 | public void FullRun_Simple() 94 | { 95 | var state = new ConnectFourState(); 96 | foreach (var position in new[] { 0, 0, 1, 1, 2, 2, 3 }) 97 | state.ApplyAction(new ConnectFourAction(position)); 98 | 99 | Assert.IsFalse(state.Actions.Any()); 100 | Assert.AreEqual(1, state.GetResult(X)); 101 | Assert.AreEqual(0, state.GetResult(O)); 102 | } 103 | } 104 | 105 | 
public static class ExtensionMethods 106 | { 107 | public static IEnumerable SplitOnNewLine(this string source) 108 | { 109 | return source.Split(new[] { Environment.NewLine }, StringSplitOptions.None); 110 | } 111 | 112 | public static IEnumerable> Batch(this IEnumerable source, int batchSize) 113 | { 114 | var batch = new List(batchSize); 115 | foreach(var t in source) 116 | { 117 | batch.Add(t); 118 | if(batch.Count == batchSize) 119 | { 120 | yield return batch; 121 | batch = new List(batchSize); 122 | } 123 | } 124 | 125 | if (batch.Any()) 126 | yield return batch; 127 | } 128 | 129 | public static string ReadAllText(this IFileProvider fileProvider, string fileName) 130 | { 131 | var file = fileProvider.GetFileInfo(fileName); 132 | using (var stream = file.CreateReadStream()) 133 | using (var reader = new StreamReader(stream)) 134 | return reader.ReadToEnd(); 135 | } 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 
3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.suo 8 | *.user 9 | *.userosscache 10 | *.sln.docstates 11 | 12 | # User-specific files (MonoDevelop/Xamarin Studio) 13 | *.userprefs 14 | 15 | # Build results 16 | [Dd]ebug/ 17 | [Dd]ebugPublic/ 18 | [Rr]elease/ 19 | [Rr]eleases/ 20 | x64/ 21 | x86/ 22 | bld/ 23 | [Bb]in/ 24 | [Oo]bj/ 25 | [Ll]og/ 26 | 27 | # Visual Studio 2015 cache/options directory 28 | .vs/ 29 | # Uncomment if you have tasks that create the project's static files in wwwroot 30 | #wwwroot/ 31 | 32 | # MSTest test Results 33 | [Tt]est[Rr]esult*/ 34 | [Bb]uild[Ll]og.* 35 | 36 | # NUNIT 37 | *.VisualState.xml 38 | TestResult.xml 39 | 40 | # Build Results of an ATL Project 41 | [Dd]ebugPS/ 42 | [Rr]eleasePS/ 43 | dlldata.c 44 | 45 | # .NET Core 46 | project.lock.json 47 | project.fragment.lock.json 48 | artifacts/ 49 | **/Properties/launchSettings.json 50 | 51 | *_i.c 52 | *_p.c 53 | *_i.h 54 | *.ilk 55 | *.meta 56 | *.obj 57 | *.pch 58 | *.pdb 59 | *.pgc 60 | *.pgd 61 | *.rsp 62 | *.sbr 63 | *.tlb 64 | *.tli 65 | *.tlh 66 | *.tmp 67 | *.tmp_proj 68 | *.log 69 | *.vspscc 70 | *.vssscc 71 | .builds 72 | *.pidb 73 | *.svclog 74 | *.scc 75 | 76 | # Chutzpah Test files 77 | _Chutzpah* 78 | 79 | # Visual C++ cache files 80 | ipch/ 81 | *.aps 82 | *.ncb 83 | *.opendb 84 | *.opensdf 85 | *.sdf 86 | *.cachefile 87 | *.VC.db 88 | *.VC.VC.opendb 89 | 90 | # Visual Studio profiler 91 | *.psess 92 | *.vsp 93 | *.vspx 94 | *.sap 95 | 96 | # TFS 2012 Local Workspace 97 | $tf/ 98 | 99 | # Guidance Automation Toolkit 100 | *.gpState 101 | 102 | # ReSharper is a .NET coding add-in 103 | _ReSharper*/ 104 | *.[Rr]e[Ss]harper 105 | *.DotSettings.user 106 | 107 | # JustCode is a .NET coding add-in 108 | .JustCode 109 | 110 | # TeamCity is a build add-in 111 | _TeamCity* 112 | 113 | # DotCover is a Code Coverage Tool 114 | *.dotCover 115 | 116 | # Visual Studio code coverage results 117 
| *.coverage 118 | *.coveragexml 119 | 120 | # NCrunch 121 | _NCrunch_* 122 | .*crunch*.local.xml 123 | nCrunchTemp_* 124 | 125 | # MightyMoose 126 | *.mm.* 127 | AutoTest.Net/ 128 | 129 | # Web workbench (sass) 130 | .sass-cache/ 131 | 132 | # Installshield output folder 133 | [Ee]xpress/ 134 | 135 | # DocProject is a documentation generator add-in 136 | DocProject/buildhelp/ 137 | DocProject/Help/*.HxT 138 | DocProject/Help/*.HxC 139 | DocProject/Help/*.hhc 140 | DocProject/Help/*.hhk 141 | DocProject/Help/*.hhp 142 | DocProject/Help/Html2 143 | DocProject/Help/html 144 | 145 | # Click-Once directory 146 | publish/ 147 | 148 | # Publish Web Output 149 | *.[Pp]ublish.xml 150 | *.azurePubxml 151 | # TODO: Comment the next line if you want to checkin your web deploy settings 152 | # but database connection strings (with potential passwords) will be unencrypted 153 | *.pubxml 154 | *.publishproj 155 | 156 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 157 | # checkin your Azure Web App publish settings, but sensitive information contained 158 | # in these scripts will be unencrypted 159 | PublishScripts/ 160 | 161 | # NuGet Packages 162 | *.nupkg 163 | # The packages folder can be ignored because of Package Restore 164 | **/packages/* 165 | # except build/, which is used as an MSBuild target. 
166 | !**/packages/build/ 167 | # Uncomment if necessary however generally it will be regenerated when needed 168 | #!**/packages/repositories.config 169 | # NuGet v3's project.json files produces more ignorable files 170 | *.nuget.props 171 | *.nuget.targets 172 | 173 | # Microsoft Azure Build Output 174 | csx/ 175 | *.build.csdef 176 | 177 | # Microsoft Azure Emulator 178 | ecf/ 179 | rcf/ 180 | 181 | # Windows Store app package directories and files 182 | AppPackages/ 183 | BundleArtifacts/ 184 | Package.StoreAssociation.xml 185 | _pkginfo.txt 186 | 187 | # Visual Studio cache files 188 | # files ending in .cache can be ignored 189 | *.[Cc]ache 190 | # but keep track of directories ending in .cache 191 | !*.[Cc]ache/ 192 | 193 | # Others 194 | ClientBin/ 195 | ~$* 196 | *~ 197 | *.dbmdl 198 | *.dbproj.schemaview 199 | *.jfm 200 | *.pfx 201 | *.publishsettings 202 | orleans.codegen.cs 203 | 204 | # Since there are multiple workflows, uncomment next line to ignore bower_components 205 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 206 | #bower_components/ 207 | 208 | # RIA/Silverlight projects 209 | Generated_Code/ 210 | 211 | # Backup & report files from converting an old project file 212 | # to a newer Visual Studio version. 
Backup files are not needed, 213 | # because we have git ;-) 214 | _UpgradeReport_Files/ 215 | Backup*/ 216 | UpgradeLog*.XML 217 | UpgradeLog*.htm 218 | 219 | # SQL Server files 220 | *.mdf 221 | *.ldf 222 | *.ndf 223 | 224 | # Business Intelligence projects 225 | *.rdl.data 226 | *.bim.layout 227 | *.bim_*.settings 228 | 229 | # Microsoft Fakes 230 | FakesAssemblies/ 231 | 232 | # GhostDoc plugin setting file 233 | *.GhostDoc.xml 234 | 235 | # Node.js Tools for Visual Studio 236 | .ntvs_analysis.dat 237 | node_modules/ 238 | 239 | # Typescript v1 declaration files 240 | typings/ 241 | 242 | # Visual Studio 6 build log 243 | *.plg 244 | 245 | # Visual Studio 6 workspace options file 246 | *.opt 247 | 248 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 249 | *.vbw 250 | 251 | # Visual Studio LightSwitch build output 252 | **/*.HTMLClient/GeneratedArtifacts 253 | **/*.DesktopClient/GeneratedArtifacts 254 | **/*.DesktopClient/ModelManifest.xml 255 | **/*.Server/GeneratedArtifacts 256 | **/*.Server/ModelManifest.xml 257 | _Pvt_Extensions 258 | 259 | # Paket dependency manager 260 | .paket/paket.exe 261 | paket-files/ 262 | 263 | # FAKE - F# Make 264 | .fake/ 265 | 266 | # JetBrains Rider 267 | .idea/ 268 | *.sln.iml 269 | 270 | # CodeRush 271 | .cr/ 272 | 273 | # Python Tools for Visual Studio (PTVS) 274 | __pycache__/ 275 | *.pyc 276 | 277 | # Cake - Uncomment if you are using it 278 | # tools/** 279 | # !tools/packages.config 280 | 281 | # Telerik's JustMock configuration file 282 | *.jmconfig 283 | 284 | # BizTalk build output 285 | *.btp.cs 286 | *.btm.cs 287 | *.odx.cs 288 | *.xsd.cs 289 | -------------------------------------------------------------------------------- /montecarlo.sln: -------------------------------------------------------------------------------- 1 | 2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio 15 4 | VisualStudioVersion = 15.0.26730.12 5 | MinimumVisualStudioVersion 
= 15.0.26124.0 6 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MonteCarlo", "src\MonteCarlo\MonteCarlo.csproj", "{1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}" 7 | EndProject 8 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TicTacToe", "examples\TicTacToe\TicTacToe\TicTacToe.csproj", "{C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}" 9 | EndProject 10 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TicTacToeTests", "examples\TicTacToe\TicTacToeTests\TicTacToeTests.csproj", "{62924053-8DE0-458D-AEB7-F8B46F9FBAB3}" 11 | EndProject 12 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ConnectFour", "examples\ConnectFour\ConnectFour\ConnectFour.csproj", "{9858D70A-05AA-4761-99D3-36168BA3C632}" 13 | EndProject 14 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ConnectFourTests", "examples\ConnectFour\ConnectFourTests\ConnectFourTests.csproj", "{59408B67-4DEC-4FEF-8F8D-A2D552213017}" 15 | EndProject 16 | Global 17 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 18 | Debug|Any CPU = Debug|Any CPU 19 | Debug|x64 = Debug|x64 20 | Debug|x86 = Debug|x86 21 | Release|Any CPU = Release|Any CPU 22 | Release|x64 = Release|x64 23 | Release|x86 = Release|x86 24 | EndGlobalSection 25 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 26 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 27 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Debug|Any CPU.Build.0 = Debug|Any CPU 28 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Debug|x64.ActiveCfg = Debug|Any CPU 29 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Debug|x64.Build.0 = Debug|Any CPU 30 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Debug|x86.ActiveCfg = Debug|Any CPU 31 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Debug|x86.Build.0 = Debug|Any CPU 32 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Release|Any CPU.ActiveCfg = Release|Any CPU 33 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Release|Any CPU.Build.0 = Release|Any CPU 34 | 
{1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Release|x64.ActiveCfg = Release|Any CPU 35 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Release|x64.Build.0 = Release|Any CPU 36 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Release|x86.ActiveCfg = Release|Any CPU 37 | {1A69C40F-DCE9-4DEA-8A49-7EA227707E9A}.Release|x86.Build.0 = Release|Any CPU 38 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 39 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Debug|Any CPU.Build.0 = Debug|Any CPU 40 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Debug|x64.ActiveCfg = Debug|Any CPU 41 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Debug|x64.Build.0 = Debug|Any CPU 42 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Debug|x86.ActiveCfg = Debug|Any CPU 43 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Debug|x86.Build.0 = Debug|Any CPU 44 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Release|Any CPU.ActiveCfg = Release|Any CPU 45 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Release|Any CPU.Build.0 = Release|Any CPU 46 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Release|x64.ActiveCfg = Release|Any CPU 47 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Release|x64.Build.0 = Release|Any CPU 48 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Release|x86.ActiveCfg = Release|Any CPU 49 | {C1A9F4F9-6A35-4692-ABCF-2E8FEB95CCA8}.Release|x86.Build.0 = Release|Any CPU 50 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 51 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Debug|Any CPU.Build.0 = Debug|Any CPU 52 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Debug|x64.ActiveCfg = Debug|Any CPU 53 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Debug|x64.Build.0 = Debug|Any CPU 54 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Debug|x86.ActiveCfg = Debug|Any CPU 55 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Debug|x86.Build.0 = Debug|Any CPU 56 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Release|Any CPU.ActiveCfg = Release|Any CPU 57 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Release|Any CPU.Build.0 = Release|Any CPU 58 | 
{62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Release|x64.ActiveCfg = Release|Any CPU 59 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Release|x64.Build.0 = Release|Any CPU 60 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Release|x86.ActiveCfg = Release|Any CPU 61 | {62924053-8DE0-458D-AEB7-F8B46F9FBAB3}.Release|x86.Build.0 = Release|Any CPU 62 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 63 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Debug|Any CPU.Build.0 = Debug|Any CPU 64 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Debug|x64.ActiveCfg = Debug|Any CPU 65 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Debug|x64.Build.0 = Debug|Any CPU 66 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Debug|x86.ActiveCfg = Debug|Any CPU 67 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Debug|x86.Build.0 = Debug|Any CPU 68 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Release|Any CPU.ActiveCfg = Release|Any CPU 69 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Release|Any CPU.Build.0 = Release|Any CPU 70 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Release|x64.ActiveCfg = Release|Any CPU 71 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Release|x64.Build.0 = Release|Any CPU 72 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Release|x86.ActiveCfg = Release|Any CPU 73 | {9858D70A-05AA-4761-99D3-36168BA3C632}.Release|x86.Build.0 = Release|Any CPU 74 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 75 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Debug|Any CPU.Build.0 = Debug|Any CPU 76 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Debug|x64.ActiveCfg = Debug|Any CPU 77 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Debug|x64.Build.0 = Debug|Any CPU 78 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Debug|x86.ActiveCfg = Debug|Any CPU 79 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Debug|x86.Build.0 = Debug|Any CPU 80 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Release|Any CPU.ActiveCfg = Release|Any CPU 81 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Release|Any CPU.Build.0 = Release|Any CPU 82 | 
{59408B67-4DEC-4FEF-8F8D-A2D552213017}.Release|x64.ActiveCfg = Release|Any CPU 83 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Release|x64.Build.0 = Release|Any CPU 84 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Release|x86.ActiveCfg = Release|Any CPU 85 | {59408B67-4DEC-4FEF-8F8D-A2D552213017}.Release|x86.Build.0 = Release|Any CPU 86 | EndGlobalSection 87 | GlobalSection(SolutionProperties) = preSolution 88 | HideSolutionNode = FALSE 89 | EndGlobalSection 90 | GlobalSection(ExtensibilityGlobals) = postSolution 91 | SolutionGuid = {BB81C8F8-976A-4679-BF19-A0D29029D70D} 92 | EndGlobalSection 93 | EndGlobal 94 | -------------------------------------------------------------------------------- /src/MonteCarlo/MonteCarloTreeSearch.cs: --------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;

namespace MonteCarlo
{
    /// <summary>Marker interface for a move that can be applied to a game state.</summary>
    public interface IAction {}

    /// <summary>Marker interface for a participant in the game.</summary>
    public interface IPlayer {}

    /// <summary>
    /// A mutable game position that the search can copy, advance and score.
    /// </summary>
    /// <typeparam name="TPlayer">Type identifying the players.</typeparam>
    /// <typeparam name="TAction">Type identifying the moves.</typeparam>
    public interface IState<TPlayer, TAction>
    {
        /// <summary>Returns a copy that the search may mutate without affecting this instance.</summary>
        IState<TPlayer, TAction> Clone();

        /// <summary>The player whose turn it is in this position.</summary>
        TPlayer CurrentPlayer { get; }

        /// <summary>The moves available in this position. The search treats an empty list as "game over".</summary>
        IList<TAction> Actions { get; }

        /// <summary>Plays <paramref name="action"/> and advances this state in place.</summary>
        void ApplyAction(TAction action);

        /// <summary>
        /// Scores a finished game from the point of view of <paramref name="forPlayer"/>.
        /// The search sums these values as "wins", so larger must mean better for that player.
        /// </summary>
        double GetResult(TPlayer forPlayer);
    }

    /// <summary>Entry points for running a Monte Carlo Tree Search (UCT) from a given state.</summary>
    public class MonteCarloTreeSearch
    {
        /// <summary>One position in the search tree together with its visit statistics.</summary>
        private class Node<TPlayer, TAction> : IMctsNode<TAction>
            where TPlayer : IPlayer
            where TAction : IAction
        {
            // UCT exploration weight; sqrt(2) gives the same value as the textbook sqrt(2 * ln(N) / n) term.
            private static readonly double ExplorationConstant = Math.Sqrt(2);

            public Node(IState<TPlayer, TAction> state, TAction action = default(TAction), Node<TPlayer, TAction> parent = null)
            {
                Parent = parent;
                Player = state.CurrentPlayer;
                State = state;
                Action = action;

                // Snapshot the legal moves now. A child is handed the rollout's working state, which
                // keeps being played to the end of the game after this constructor returns; reading
                // state.Actions later would make every expanded node look terminal and stop the
                // tree from ever growing past its second level.
                Actions = state.Actions.ToList();
                UntriedActions = new HashSet<TAction>(Actions);
            }

            public Node<TPlayer, TAction> Parent { get; }

            public IList<Node<TPlayer, TAction>> Children { get; } = new List<Node<TPlayer, TAction>>();

            public int NumRuns { get; set; }

            public double NumWins { get; set; }

            /// <summary>The player to move in this node's position.</summary>
            public TPlayer Player { get; }

            /// <summary>
            /// The state object this node was created from. Only the root's state is read after
            /// construction (BuildTree clones it); for child nodes this is the rollout's working
            /// state and must not be relied upon.
            /// </summary>
            public IState<TPlayer, TAction> State { get; }

            /// <summary>The move that led from <see cref="Parent"/> to this node (default for the root).</summary>
            public TAction Action { get; }

            /// <summary>Moves from this position that do not have a child node yet.</summary>
            public ISet<TAction> UntriedActions { get; }

            /// <summary>All moves that were legal in this position when the node was created.</summary>
            public IList<TAction> Actions { get; }

            /// <summary>Average result of the runs through this node (NaN before the first run).</summary>
            public double ExploitationValue => NumWins / NumRuns;

            /// <summary>UCT exploration bonus; zero for the root, which has no parent to compare against.</summary>
            public double ExplorationValue =>
                Parent == null ? 0 : ExplorationConstant * Math.Sqrt(Math.Log(Parent.NumRuns) / NumRuns);

            private double UCT => ExploitationValue + ExplorationValue;

            /// <summary>Returns the child with the highest UCT score.</summary>
            public Node<TPlayer, TAction> SelectChild()
            {
                return Children.MaxElementBy(e => e.UCT);
            }

            /// <summary>Creates the child reached by <paramref name="action"/> and marks that action as tried.</summary>
            public Node<TPlayer, TAction> AddChild(TAction action, IState<TPlayer, TAction> state)
            {
                var child = new Node<TPlayer, TAction>(state, action, this);
                UntriedActions.Remove(action);
                Children.Add(child);

                return child;
            }

            /// <summary>
            /// Repeats select / expand / simulate / backpropagate for as long as
            /// <paramref name="shouldContinue"/> returns true.
            /// </summary>
            /// <param name="shouldContinue">
            /// Called before every iteration with the number of completed iterations and the elapsed
            /// milliseconds.
            /// </param>
            public void BuildTree(Func<int, long, bool> shouldContinue)
            {
                var iterations = 0;
                var timer = Stopwatch.StartNew();
                while (shouldContinue(iterations++, timer.ElapsedMilliseconds))
                {
                    var node = this;
                    var state = State.Clone();

                    // select: walk down through fully expanded, non-terminal nodes
                    while (node.UntriedActions.Count == 0 && node.Actions.Count > 0)
                    {
                        node = node.SelectChild();
                        state.ApplyAction(node.Action);
                    }

                    // expand: add one new child if this node still has untried moves
                    if (node.UntriedActions.Count > 0)
                    {
                        var action = node.UntriedActions.RandomChoice();
                        state.ApplyAction(action);
                        node = node.AddChild(action, state);
                    }

                    // simulate: random playout until no moves remain
                    while (state.Actions.Count > 0)
                    {
                        state.ApplyAction(state.Actions.RandomChoice());
                    }

                    // backpropagate: a node's statistics are compared by its parent when the parent
                    // picks a move, so score each node for the player who made that choice. Scoring
                    // everything for the root player would make the opponent pick moves that are
                    // good for the root player.
                    while (node != null)
                    {
                        var chooser = node.Parent != null ? node.Parent.Player : node.Player;
                        node.NumRuns++;
                        node.NumWins += state.GetResult(chooser);
                        node = node.Parent;
                    }
                }
            }

            public override string ToString()
            {
                return $"{NumWins}/{NumRuns}: ({ExploitationValue}/{ExplorationValue}={UCT})";
            }
        }

        /// <summary>Searches for at most <paramref name="maxIterations"/> iterations with no time limit.</summary>
        public static IEnumerable<IMctsNode<TAction>> GetTopActions<TPlayer, TAction>(IState<TPlayer, TAction> state, int maxIterations) where TPlayer : IPlayer where TAction : IAction
        {
            return GetTopActions(state, maxIterations, long.MaxValue);
        }

        /// <summary>Searches for at most <paramref name="timeBudget"/> milliseconds with no iteration limit.</summary>
        public static IEnumerable<IMctsNode<TAction>> GetTopActions<TPlayer, TAction>(IState<TPlayer, TAction> state, long timeBudget) where TPlayer : IPlayer where TAction : IAction
        {
            return GetTopActions(state, int.MaxValue, timeBudget);
        }

        /// <summary>
        /// Runs the search from <paramref name="state"/> and returns the moves available there,
        /// most visited first.
        /// </summary>
        /// <param name="state">Position to search from; it is cloned, never mutated.</param>
        /// <param name="maxIterations">Upper bound on the number of search iterations.</param>
        /// <param name="timeBudget">Upper bound on the search time, in milliseconds.</param>
        /// <exception cref="ArgumentNullException"><paramref name="state"/> is null.</exception>
        public static IEnumerable<IMctsNode<TAction>> GetTopActions<TPlayer, TAction>(IState<TPlayer, TAction> state, int maxIterations, long timeBudget) where TPlayer : IPlayer where TAction : IAction
        {
            if (state == null)
            {
                throw new ArgumentNullException(nameof(state));
            }

            var root = new Node<TPlayer, TAction>(state);
            root.BuildTree((numIterations, elapsedMs) => numIterations < maxIterations && elapsedMs < timeBudget);

            // Materialise so callers can enumerate the ranking repeatedly without re-sorting.
            return root.Children
                .OrderByDescending(n => n.NumRuns)
                .ToList();
        }
    }

    /// <summary>Read-only view of a searched move and its statistics.</summary>
    public interface IMctsNode<TAction> where TAction : IAction
    {
        /// <summary>The move this node represents.</summary>
        TAction Action { get; }

        /// <summary>How many search iterations passed through this move.</summary>
        int NumRuns { get; }

        /// <summary>Sum of the results of those iterations.</summary>
        double NumWins { get; }
    }

    internal static class CollectionExtensions
    {
        // NOTE(review): System.Random instances are not thread-safe; this shared instance
        // assumes the search runs on one thread at a time.
        private static readonly Random _random = new Random();

        /// <summary>
        /// Returns a uniformly chosen element of <paramref name="source"/>, using
        /// <paramref name="random"/> when supplied and a shared generator otherwise.
        /// Throws if the collection is empty.
        /// </summary>
        public static T RandomChoice<T>(this ICollection<T> source, Random random = null)
        {
            var i = (random ?? _random).Next(source.Count);
            return source.ElementAt(i);
        }

        /// <summary>
        /// Returns the first element with the largest <paramref name="selector"/> value, or
        /// default(T) when the sequence is empty or no value exceeds double.MinValue (NaN never does).
        /// </summary>
        public static T MaxElementBy<T>(this IEnumerable<T> source, Func<T, double> selector)
        {
            var currentMaxElement = default(T);
            var currentMaxValue = double.MinValue;

            foreach (var element in source)
            {
                var value = selector(element);
                if (currentMaxValue < value)
                {
                    currentMaxValue = value;
                    currentMaxElement = element;
                }
            }

            return currentMaxElement;
        }
    }
}
-------------------------------------------------------------------------------- /examples/ConnectFour/ConnectFour/ConnectFourState.cs: --------------------------------------------------------------------------------
using MonteCarlo;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace ConnectFour
{
    /// <summary>One of the two Connect Four players; only the instances <see cref="X"/> and <see cref="O"/> exist.</summary>
    public class ConnectFourPlayer : IPlayer
    {
        public static readonly ConnectFourPlayer X = new ConnectFourPlayer("X");
        public static readonly ConnectFourPlayer O = new ConnectFourPlayer("O");

        private ConnectFourPlayer(string name)
        {
            Name = name;
        }

        public string Name { get; }

        /// <summary>The opponent of this player.</summary>
        public ConnectFourPlayer NextPlayer => this == X ? O : X;

        public override string ToString()
        {
            return Name;
        }
    }

    /// <summary>Dropping a piece into a column (0 is the leftmost column).</summary>
    public struct ConnectFourAction : IAction, IEquatable<ConnectFourAction>
    {
        public ConnectFourAction(int column)
        {
            Column = column;
        }

        public int Column { get; }

        // Explicit equality avoids the reflection-based ValueType.Equals when actions are stored
        // in hash sets or removed from lists.
        public bool Equals(ConnectFourAction other)
        {
            return Column == other.Column;
        }

        public override bool Equals(object obj)
        {
            return obj is ConnectFourAction && Equals((ConnectFourAction)obj);
        }

        public override int GetHashCode()
        {
            return Column;
        }

        public override string ToString()
        {
            return Column.ToString();
        }
    }

    /// <summary>A Connect Four position on a <see cref="NumRows"/> x <see cref="NumCols"/> board.</summary>
    public class ConnectFourState : IState<ConnectFourPlayer, ConnectFourAction>
    {
        public const int NumRows = 6;
        public const int NumCols = 7;

        // Shared "no moves" result handed out once the game has a winner.
        private static readonly ConnectFourAction[] NoActions = new ConnectFourAction[0];

        // Every run of four cells on an empty board. The arrays are only ever read, so all
        // states can share them; each state keeps its own list of the runs still open.
        private static readonly IList<int[]> DefaultWinningRuns = BuildDefaultWinningRuns();

        // board is a NumRows * NumCols array, with the 0th index being the bottom left and the last value being the top right
        private readonly ConnectFourPlayer[] board;
        // to save time, the columns that still have room are tracked here instead of being recomputed
        private readonly IList<ConnectFourAction> availableActions;
        // runs of four that can still be completed by one player; pruned as moves are applied
        private readonly IList<int[]> availableWinningRuns;
        // set as soon as a completed run is found; null while the game is open or drawn
        private ConnectFourPlayer winner;

        /// <summary>Creates an empty board with X to move.</summary>
        public ConnectFourState()
        {
            board = new ConnectFourPlayer[NumRows * NumCols];
            availableActions = Enumerable.Range(0, NumCols).Select(i => new ConnectFourAction(i)).ToList();
            availableWinningRuns = GetDefaultAvailableWinningRuns();
            CurrentPlayer = ConnectFourPlayer.X;
        }

        /// <summary>
        /// Creates a state from an existing position.
        /// </summary>
        /// <param name="board">
        /// Cells in row-major order starting at the bottom left; expected to hold
        /// NumRows * NumCols entries. The array is used directly, not copied.
        /// </param>
        /// <param name="currentPlayer">Player to move; X when omitted.</param>
        /// <exception cref="ArgumentNullException"><paramref name="board"/> is null.</exception>
        public ConnectFourState(ConnectFourPlayer[] board, ConnectFourPlayer currentPlayer = null)
        {
            if (board == null)
            {
                throw new ArgumentNullException(nameof(board));
            }

            this.board = board;

            // A column is playable while its cell in the TOP ROW is empty. (The row index was
            // previously derived from NumCols, which addressed a row that does not exist and left
            // the state without any moves.)
            availableActions = GetRow(NumRows - 1)
                .Select((p, i) => new { p, i })
                .Where(o => o.p == null)
                .Select(o => new ConnectFourAction(o.i))
                .ToList();

            availableWinningRuns = GetDefaultAvailableWinningRuns();
            CurrentPlayer = currentPlayer ?? ConnectFourPlayer.X;

            // Replay the occupied cells through the pruning step so that blocked runs are
            // dropped and an already completed run is recorded as the winner.
            for (var position = 0; position < board.Length; position++)
            {
                if (board[position] != null)
                {
                    PruneAvailableWinningRuns(position);
                }
            }
        }

        /// <summary>Copy constructor: duplicates the board, move list, open runs, turn and winner.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="toClone"/> is null.</exception>
        public ConnectFourState(ConnectFourState toClone)
        {
            if (toClone == null)
            {
                throw new ArgumentNullException(nameof(toClone));
            }

            board = toClone.board.ToArray();
            availableActions = new List<ConnectFourAction>(toClone.availableActions);
            availableWinningRuns = new List<int[]>(toClone.availableWinningRuns);
            CurrentPlayer = toClone.CurrentPlayer;
            // Without this a copy of a finished game would report itself as still in progress.
            winner = toClone.winner;
        }

        public ConnectFourPlayer CurrentPlayer { get; private set; }

        /// <summary>Columns that can still be played; empty once a player has won or the board is full.</summary>
        public IList<ConnectFourAction> Actions => winner == null ? availableActions : NoActions;

        public IState<ConnectFourPlayer, ConnectFourAction> Clone()
        {
            return new ConnectFourState(this);
        }

        /// <summary>Drops the current player's piece into the given column and passes the turn.</summary>
        /// <exception cref="ArgumentOutOfRangeException">The column is not on the board.</exception>
        /// <exception cref="ArgumentException">The column is already full.</exception>
        public void ApplyAction(ConnectFourAction action)
        {
            var column = action.Column;
            if (column < 0 || column >= NumCols)
            {
                throw new ArgumentOutOfRangeException(nameof(action), column, $"Column must be between 0 and {NumCols - 1}");
            }

            // Find the lowest empty cell in the column.
            var row = 0;
            var foundEmptyCell = false;
            foreach (var occupant in GetCol(column))
            {
                if (occupant == null)
                {
                    foundEmptyCell = true;
                    break;
                }
                row++;
            }

            if (!foundEmptyCell)
            {
                throw new ArgumentException($"Column {action.Column} is already full", nameof(action));
            }

            // if this was the top row, then remove this column as an available action
            if (row >= NumRows - 1)
            {
                availableActions.Remove(action);
            }

            var actionPosition = GetBoardIndex(row, column);
            board[actionPosition] = CurrentPlayer;
            CurrentPlayer = CurrentPlayer.NextPlayer;
            PruneAvailableWinningRuns(actionPosition);
        }

        /// <summary>Returns 1 if <paramref name="forPlayer"/> won, 0.5 for a draw and 0 otherwise.</summary>
        /// <exception cref="InvalidOperationException">Moves are still available.</exception>
        public double GetResult(ConnectFourPlayer forPlayer)
        {
            if (Actions.Any())
                throw new InvalidOperationException("Game isn't over yet");

            if (winner == forPlayer) return 1;
            if (winner == null) return 0.5;
            return 0;
        }

        /// <summary>Enumerates the cells of a row from left to right (row 0 is the bottom row).</summary>
        public IEnumerable<ConnectFourPlayer> GetRow(int rowNum)
        {
            var startIndex = NumCols * rowNum;
            var endIndex = startIndex + NumCols;
            for (var i = startIndex; i < endIndex && i < board.Length; i++)
                yield return board[i];
        }

        /// <summary>Enumerates the cells of a column from bottom to top.</summary>
        public IEnumerable<ConnectFourPlayer> GetCol(int colNum)
        {
            for (var i = colNum; i < board.Length; i += NumCols)
                yield return board[i];
        }

        /// <summary>Renders the board top row first, one character per cell, blank for empty cells.</summary>
        public override string ToString()
        {
            var sb = new StringBuilder();
            for (var rowNum = NumRows - 1; rowNum >= 0; rowNum--)
            {
                sb.AppendJoin("", GetRow(rowNum).Select(c => c == null ? " " : c.ToString().Substring(0, 1)));
                if (rowNum > 0)
                    sb.Append("\n");
            }
            return sb.ToString();
        }

        /// <summary>
        /// Converts a row, column coordinate into a flat index in the board
        /// </summary>
        /// <param name="row">Zero-based row, counted from the bottom.</param>
        /// <param name="col">Zero-based column, counted from the left.</param>
        /// <returns>The index of that cell in the board array.</returns>
        private static int GetBoardIndex(int row, int col)
        {
            return row * NumCols + col;
        }

        /// <summary>Returns a fresh, prunable list of every run of four on the board.</summary>
        private static IList<int[]> GetDefaultAvailableWinningRuns()
        {
            return new List<int[]>(DefaultWinningRuns);
        }

        /// <summary>Enumerates every horizontal, vertical and diagonal run of four cells.</summary>
        private static IList<int[]> BuildDefaultWinningRuns()
        {
            var runs = new List<int[]>();

            for (var row = 0; row < NumRows; row++)
            {
                var canRunTop = row < NumRows - 3;
                for (var col = 0; col < NumCols; col++)
                {
                    var canRunRight = col < NumCols - 3;
                    if (canRunRight)
                    {
                        var horizontalRun = Enumerable.Range(col, 4).Select(c => GetBoardIndex(row, c)).ToArray();
                        runs.Add(horizontalRun);
                    }
                    if (canRunTop)
                    {
                        var verticalRun = Enumerable.Range(row, 4).Select(r => GetBoardIndex(r, col)).ToArray();
                        runs.Add(verticalRun);
                    }
                    if (canRunRight && canRunTop)
                    {
                        var diagonalBottomTopRun = Enumerable.Range(0, 4).Select(i => GetBoardIndex(row + i, col + i)).ToArray();
                        runs.Add(diagonalBottomTopRun);
                    }
                    if (canRunRight && row > 2)
                    {
                        var diagonalTopBottomRun = Enumerable.Range(0, 4).Select(i => GetBoardIndex(row - i, col + i)).ToArray();
                        runs.Add(diagonalTopBottomRun);
                    }
                }
            }

            return runs;
        }

        /// <summary>
        /// Re-examines the open runs through <paramref name="forPosition"/> after that cell was
        /// filled: records a winner if a run is complete, and drops runs that now hold both players.
        /// </summary>
        private void PruneAvailableWinningRuns(int forPosition)
        {
            // Walk backwards so entries can be removed in place without a second pass.
            for (var i = availableWinningRuns.Count - 1; i >= 0; i--)
            {
                var run = availableWinningRuns[i];
                if (Array.IndexOf(run, forPosition) < 0)
                {
                    continue;
                }

                var boardValues = run.ToBoardValues(board);
                var runWinner = boardValues.GetWinner();
                if (runWinner != null)
                {
                    // Game over: no run needs tracking any more.
                    winner = runWinner;
                    availableWinningRuns.Clear();
                    return;
                }

                if (!boardValues.CanWin())
                {
                    availableWinningRuns.RemoveAt(i);
                }
            }
        }
    }

    public static class ExtensionMethods
    {
        /// <summary>Lazily maps board indices to the players occupying those cells (null for empty).</summary>
        public static IEnumerable<ConnectFourPlayer> ToBoardValues(this IEnumerable<int> positions, ConnectFourPlayer[] board)
        {
            return positions.Select(pos => board[pos]);
        }

        /// <summary>
        /// Returns the player occupying every cell of the run, or null if any cell is empty,
        /// the cells differ, or the run is empty.
        /// </summary>
        public static ConnectFourPlayer GetWinner(this IEnumerable<ConnectFourPlayer> source)
        {
            ConnectFourPlayer candidate = null;
            var sawFirstCell = false;

            foreach (var cell in source)
            {
                if (!sawFirstCell)
                {
                    sawFirstCell = true;
                    candidate = cell;
                }
                if (cell == null || cell != candidate)
                    return null;
            }

            return candidate;
        }

        /// <summary>Returns true while the run contains pieces of at most one player.</summary>
        public static bool CanWin(this IEnumerable<ConnectFourPlayer> source)
        {
            ConnectFourPlayer firstFoundPlayer = null;

            foreach (var cell in source)
            {
                if (cell == null)
                    continue;

                if (firstFoundPlayer == null)
                    firstFoundPlayer = cell;
                if (cell != firstFoundPlayer)
                    return false;
            }

            return true;
        }
    }
}
--------------------------------------------------------------------------------