├── image ├── d1.png ├── d2.png ├── d3.png ├── d4.png ├── d5.png ├── d6.png ├── d7.png ├── d8.png └── d9.png ├── src ├── AI │ ├── AI.class │ ├── UCTAI.class │ ├── Transition.class │ ├── AlphaBetaAI.class │ ├── QLearningAI.class │ └── UCTTreeNode.class ├── Main.class ├── Game │ ├── GUI.class │ ├── Board.class │ ├── Chess.class │ ├── GUI$1.class │ ├── Player.class │ ├── InitFrame$1.class │ ├── InitFrame.class │ ├── Playground.class │ └── Playground$1.class ├── IOHandler │ └── IOHandler.class ├── NeuralNetwork │ ├── Activator.class │ ├── VectorDNN.class │ ├── NetworkLayer.class │ ├── VectorMatrix.class │ └── NeuralNetwork.class ├── Player.java ├── Main.java ├── Transition.java ├── AI.java ├── Activator.java ├── Chess.java ├── UCTTreeNode.java ├── InitFrame.java ├── IOHandler.java ├── AlphaBetaAI.java ├── Playground.java ├── VectorMatrix.java ├── NetworkLayer.java ├── GUI.java ├── UCTAI.java ├── Board.java ├── VectorDNN.java ├── NeuralNetwork.java └── QLearningAI.java └── README.md /image/d1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/image/d1.png -------------------------------------------------------------------------------- /image/d2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/image/d2.png -------------------------------------------------------------------------------- /image/d3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/image/d3.png -------------------------------------------------------------------------------- /image/d4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/image/d4.png 
-------------------------------------------------------------------------------- /image/d5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/image/d5.png -------------------------------------------------------------------------------- /image/d6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/image/d6.png -------------------------------------------------------------------------------- /image/d7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/image/d7.png -------------------------------------------------------------------------------- /image/d8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/image/d8.png -------------------------------------------------------------------------------- /image/d9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/image/d9.png -------------------------------------------------------------------------------- /src/AI/AI.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/AI/AI.class -------------------------------------------------------------------------------- /src/Main.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/Main.class -------------------------------------------------------------------------------- /src/AI/UCTAI.class: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/AI/UCTAI.class -------------------------------------------------------------------------------- /src/Game/GUI.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/Game/GUI.class -------------------------------------------------------------------------------- /src/Game/Board.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/Game/Board.class -------------------------------------------------------------------------------- /src/Game/Chess.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/Game/Chess.class -------------------------------------------------------------------------------- /src/Game/GUI$1.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/Game/GUI$1.class -------------------------------------------------------------------------------- /src/Game/Player.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/Game/Player.class -------------------------------------------------------------------------------- /src/AI/Transition.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/AI/Transition.class -------------------------------------------------------------------------------- /src/AI/AlphaBetaAI.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/AI/AlphaBetaAI.class 
-------------------------------------------------------------------------------- /src/AI/QLearningAI.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/AI/QLearningAI.class -------------------------------------------------------------------------------- /src/AI/UCTTreeNode.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/AI/UCTTreeNode.class -------------------------------------------------------------------------------- /src/Game/InitFrame$1.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/Game/InitFrame$1.class -------------------------------------------------------------------------------- /src/Game/InitFrame.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/Game/InitFrame.class -------------------------------------------------------------------------------- /src/Game/Playground.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/Game/Playground.class -------------------------------------------------------------------------------- /src/Game/Playground$1.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/Game/Playground$1.class -------------------------------------------------------------------------------- /src/IOHandler/IOHandler.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/IOHandler/IOHandler.class 
-------------------------------------------------------------------------------- /src/NeuralNetwork/Activator.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/NeuralNetwork/Activator.class -------------------------------------------------------------------------------- /src/NeuralNetwork/VectorDNN.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/NeuralNetwork/VectorDNN.class -------------------------------------------------------------------------------- /src/NeuralNetwork/NetworkLayer.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/NeuralNetwork/NetworkLayer.class -------------------------------------------------------------------------------- /src/NeuralNetwork/VectorMatrix.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/NeuralNetwork/VectorMatrix.class -------------------------------------------------------------------------------- /src/NeuralNetwork/NeuralNetwork.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaofengmarco/JavaOthello/HEAD/src/NeuralNetwork/NeuralNetwork.class -------------------------------------------------------------------------------- /src/Player.java: -------------------------------------------------------------------------------- 1 | package Game; 2 | public class Player 3 | { 4 | public int hold; 5 | public Player(int h) 6 | { 7 | hold = h; 8 | } 9 | } -------------------------------------------------------------------------------- /src/Main.java: -------------------------------------------------------------------------------- 1 | public
class Main 2 | { 3 | public static void main(String[] args) 4 | { 5 | Game.GUI game = new Game.GUI(); 6 | game.play(); 7 | game.setVisible(false); 8 | game.dispose(); 9 | } 10 | } -------------------------------------------------------------------------------- /src/Transition.java: -------------------------------------------------------------------------------- 1 | package AI; 2 | public class Transition 3 | { 4 | public int[] s1 = new int[64], s2 = new int[64]; /* allocated here because the constructor copies 64 board cells into each; left null, every construction threw NullPointerException */ 5 | public int action; 6 | public double reward; 7 | public Transition(int[] a1, int[] a2, int a, double r) 8 | { 9 | for (int i = 0; i < 64; i++) 10 | { 11 | s1[i] = a1[i]; 12 | s2[i] = a2[i]; 13 | } 14 | action = a; 15 | reward = r; 16 | } 17 | } -------------------------------------------------------------------------------- /src/AI.java: -------------------------------------------------------------------------------- 1 | package AI; 2 | import Game.*; 3 | public class AI 4 | { 5 | public Board table; 6 | public int hold; 7 | public AI(int h) 8 | { 9 | hold = h; 10 | table = new Board(); 11 | } 12 | public AI(int[] x, int h) 13 | { 14 | table = new Board(x); 15 | hold = h; 16 | } 17 | public void setBoard(int[] x) 18 | { 19 | table = new Board(x); 20 | } 21 | public int move() 22 | { 23 | int[] steps; /* filled by nextSteps below: steps[0] is the move count, steps[1..steps[0]] the moves */ 24 | int step; 25 | steps = table.nextSteps(hold); 26 | if (steps[0] > 0) 27 | { 28 | if (steps[0] == 1) 29 | step = steps[1]; 30 | else 31 | step = steps[(int)(Math.random() * steps[0]) + 1]; 32 | table.set(step / 8 + 1, step % 8 + 1, hold); 33 | return step; 34 | } 35 | return -1; 36 | } 37 | } -------------------------------------------------------------------------------- /src/Activator.java: -------------------------------------------------------------------------------- 1 | package NeuralNetwork; 2 | public class Activator 3 | { 4 | public static double Relu(double a) 5 | { 6 | if (a > 0) 7 | return a; 8 | return 0; 9 | } 10 | public static double dRelu(double a) 11 | { 12 | if (a > 0) 13 | return 1; 14 | return 0;
15 | } 16 | public static double LRelu(double rate, double a) 17 | { 18 | if (a > 0) 19 | return a; 20 | return rate * a; 21 | } 22 | public static double dLRelu(double rate, double a) 23 | { 24 | if (a > 0) 25 | return 1; 26 | return rate; 27 | } 28 | public static double Sigmoid(double a) 29 | { 30 | double ans = 1.0 / (1.0 + Math.exp(-a)); 31 | return ans; 32 | } 33 | public static double dSigmoid(double a) /* expects the activated value a = Sigmoid(x): the derivative is Sigmoid(x) * (1 - Sigmoid(x)) */ 34 | { 35 | double ans = a * (1.0 - a); 36 | return ans; 37 | } 38 | } -------------------------------------------------------------------------------- /src/Chess.java: -------------------------------------------------------------------------------- 1 | package Game; 2 | import javax.swing.*; 3 | import java.awt.*; 4 | public class Chess extends JPanel 5 | { 6 | public int color = 0; /* 0 empty; 1 / -1 filled / hollow piece; 2 / -2 the same piece with a red / green highlight border */ 7 | public Chess() 8 | { 9 | setBackground(Color.white); 10 | } 11 | @Override public void paintComponent(Graphics g) 12 | { 13 | Graphics2D g2 = (Graphics2D) g; 14 | super.paintComponent(g); 15 | g.drawRect(0, 0, getWidth(), getHeight()); 16 | if ((color == 1) || (color == 2)) 17 | g.fillOval(0, 0, getWidth(), getHeight()); 18 | else if ((color == -1) || (color == -2)) 19 | g.drawOval(0, 0, getWidth(), getHeight()); 20 | if ((color == 2) || (color == -2)) 21 | { 22 | if (color == -2) 23 | g2.setColor(Color.GREEN); 24 | else 25 | g2.setColor(Color.RED); 26 | g2.setStroke(new BasicStroke(6.0f)); 27 | g.drawRect(0, 0, getWidth(), getHeight()); 28 | g2.setStroke(new BasicStroke(1.0f)); 29 | g2.setColor(Color.BLACK); 30 | } 31 | } 32 | } -------------------------------------------------------------------------------- /src/UCTTreeNode.java: -------------------------------------------------------------------------------- 1 | package AI; 2 | import java.util.ArrayList; 3 | import Game.*; 4 | public class UCTTreeNode 5 | { 6 | public ArrayList children = new ArrayList(); 7 | public UCTTreeNode parent = null; 8 | public int[] data = new int[64]; 9 | public int expand = 0; 10 | public int[] next =
new int[65]; 11 | public int hold; 12 | 13 | public UCTTreeNode(int[] data, int h) 14 | { 15 | for (int i = 0; i < 64; i++) 16 | this.data[i] = data[i]; 17 | hold = h; /* must be set here: addChild(int[]) passes hold to the child, and it was otherwise always 0. NOTE(review): children inherit this node's hold; confirm UCTAI does not expect the opponent (-hold) */ next = Board.searchNext(this.data, h); 18 | } 19 | 20 | public boolean equals(Object o) 21 | { 22 | if (o == null) 23 | return false; 24 | if (o == this) 25 | return true; 26 | if (getClass() != o.getClass()) 27 | return false; 28 | UCTTreeNode a = (UCTTreeNode) o; /* NOTE(review): every node owns its own children list, so the two checks below can only succeed when o == this; equality is effectively identity */ 29 | if (children == a.children) 30 | return true; 31 | if ((children == a.children) && (parent == a.parent) && (data == a.data) && (hold == a.hold)) 32 | return true; 33 | return false; 34 | } 35 | public int hashCode() 36 | { 37 | return data.hashCode(); /* identity hash of this node's own array, consistent with the identity-based equals */ 38 | } 39 | 40 | public void setParent(UCTTreeNode parent) 41 | { 42 | this.parent = parent; 43 | } 44 | public void addChild(UCTTreeNode child) 45 | { 46 | child.setParent(this); 47 | this.children.add(child); 48 | } 49 | public UCTTreeNode addChild(int[] data) 50 | { 51 | UCTTreeNode child = new UCTTreeNode(data, hold); 52 | child.setParent(this); 53 | this.children.add(child); 54 | return child; 55 | } 56 | 57 | public boolean isRoot() { 58 | return (this.parent == null); 59 | } 60 | 61 | public boolean isLeaf() { 62 | if(this.children.size() == 0) 63 | return true; 64 | else 65 | return false; 66 | } 67 | } -------------------------------------------------------------------------------- /src/InitFrame.java: -------------------------------------------------------------------------------- 1 | package Game; 2 | import java.awt.*; 3 | import java.awt.event.*; 4 | import javax.swing.*; 5 | import java.util.*; 6 | import java.io.*; 7 | public class InitFrame extends JFrame implements MouseListener 8 | { 9 | ImageIcon icon1, icon2; 10 | JPanel p1, p2, big; 11 | public int phold = 0, ahold = 0; 12 | public boolean ok = false; 13 | public InitFrame() 14 | { 15 | super("Black or White?"); 16 | p1 = new JPanel(); 17 | p2 = new JPanel(); 18 | big = new JPanel(); 19 | p1.setBackground(Color.black); 20 |
p2.setBackground(Color.white); 21 | big.setLayout(new GridLayout(1, 2, 1, 1)); 22 | big.add(p1); 23 | big.add(p2); 24 | big.setBounds(0, 0, 400, 200); 25 | add(big); 26 | addWindowListener(new WindowAdapter() 27 | { 28 | @Override 29 | public void windowClosing(WindowEvent e) 30 | { 31 | if ((phold == 0) && (ahold == 0)) 32 | { 33 | phold = 1; 34 | ahold = -1; 35 | ok = true; 36 | } 37 | } 38 | }); 39 | addMouseListener(this); 40 | setSize(400, 235); 41 | setVisible(true); 42 | setFocusable(true); 43 | setResizable(false); 44 | } 45 | public void mousePressed(MouseEvent event) 46 | { 47 | int x1, y1, tx, ty, width, length; 48 | tx = event.getX(); 49 | ty = event.getY(); 50 | Point end = SwingUtilities.convertPoint(big, big.getWidth(), big.getHeight(), this); 51 | Point top = SwingUtilities.convertPoint(big, 0, 0, this); 52 | width = (end.x - top.x) / 2; 53 | length = end.y - top.y; 54 | if ((tx >= top.x) && (ty >= top.y) && (tx <= end.x) && (ty <= end.y)) 55 | { 56 | x1 = (tx - top.x) / width + 1; 57 | y1 = (ty - top.y) / length + 1; 58 | //System.out.printf("%d %d", x1, y1); 59 | if (x1 == 1) 60 | { 61 | phold = 1; 62 | ahold = -1; 63 | ok = true; 64 | setVisible(false); 65 | } 66 | else 67 | { 68 | phold = -1; 69 | ahold = 1; 70 | ok = true; 71 | setVisible(false); 72 | } 73 | } 74 | } 75 | public void mouseReleased(MouseEvent event){} 76 | public void mouseClicked(MouseEvent event){} 77 | public void mouseEntered(MouseEvent event){} 78 | public void mouseExited(MouseEvent event){} 79 | /*public static void main(String[] args) 80 | { 81 | InitFrame frame = new InitFrame(); 82 | }*/ 83 | } -------------------------------------------------------------------------------- /src/IOHandler.java: -------------------------------------------------------------------------------- 1 | package IOHandler; 2 | import java.io.*; 3 | import java.util.*; 4 | public class IOHandler 5 | { 6 | private FileInputStream fi; 7 | private FileOutputStream fo; 8 | private ObjectInputStream
oi; 9 | private ObjectOutputStream oo; 10 | private String fileName; 11 | private File f; 12 | public IOHandler(String name) 13 | { 14 | fileName = name; 15 | } 16 | public void openFileRead() 17 | { 18 | f = new File(fileName); 19 | if (!f.exists()) 20 | { 21 | try 22 | { 23 | f.createNewFile(); 24 | } 25 | catch(IOException e){} 26 | } 27 | try 28 | { 29 | fi = new FileInputStream(fileName); 30 | oi = new ObjectInputStream(fi); 31 | } 32 | catch(IOException e){} 33 | } 34 | public void openFileWrite() 35 | { 36 | f = new File(fileName); 37 | if (!f.exists()) 38 | { 39 | try 40 | { 41 | f.createNewFile(); 42 | } 43 | catch(IOException e){} 44 | } 45 | try 46 | { 47 | fo = new FileOutputStream(fileName); 48 | oo = new ObjectOutputStream(fo); 49 | } 50 | catch(IOException e){} 51 | } 52 | public void closeAll() 53 | { 54 | if (fi != null) 55 | { 56 | try 57 | { 58 | if (oi != null) oi.close(); /* oi stays null when the ObjectInputStream header could not be read, e.g. for the empty file openFileRead creates */ 59 | fi.close(); 60 | oi = null; 61 | fi = null; 62 | } 63 | catch(IOException e){} 64 | } 65 | if (fo != null) 66 | { 67 | try 68 | { 69 | if (oo != null) oo.close(); /* oo stays null if the ObjectOutputStream could not be created */ 70 | fo.close(); 71 | oo = null; 72 | fo = null; 73 | } 74 | catch(IOException e){} 75 | } 76 | } 77 | public void flush() 78 | { 79 | try 80 | { 81 | oo.flush(); 82 | } 83 | catch(IOException e){} 84 | } 85 | public void writeDouble(double d) 86 | { 87 | try 88 | { 89 | oo.writeDouble(d); 90 | } 91 | catch (IOException e){} 92 | } 93 | public double readDouble() /* returns the next double, or Double.MIN_NORMAL as an end-of-data sentinel */ 94 | { 95 | double d = 0; if (oi == null) { closeAll(); return Double.MIN_NORMAL; } /* stream never opened (e.g. empty file): report end-of-data like the EOFException path instead of throwing NullPointerException */ 96 | try 97 | { 98 | d = oi.readDouble(); 99 | } 100 | catch (EOFException e) 101 | { 102 | closeAll(); 103 | return Double.MIN_NORMAL; 104 | } 105 | catch (IOException e){} 106 | return d; 107 | } 108 | /* 109 | public static void main(String[] args) 110 | { 111 | IOHandler handler = new IOHandler("test.txt"); 112 | handler.openFileWrite(); 113 | handler.writeDouble(0.21); 114 | handler.writeDouble(0.22); 115 | handler.flush(); 116 | handler.closeAll(); 117 | double a, b; 118 | handler.openFileRead(); 119 | a = handler.readDouble(); 120 | b =
handler.readDouble(); 121 | System.out.printf("%2f %2f\n", a, b); 122 | handler.closeAll(); 123 | }*/ 124 | } -------------------------------------------------------------------------------- /src/AlphaBetaAI.java: -------------------------------------------------------------------------------- 1 | package AI; 2 | import AI.*; 3 | import Game.*; 4 | public class AlphaBetaAI extends AI 5 | { 6 | private double[] value = new double[65]; /* 1-based square weights: value[k] scores board cell k - 1, so the corners are 1, 8, 57 and 64 */ 7 | public AlphaBetaAI(int h) 8 | { 9 | super(h); initValues(); } /* weight table shared by both constructors; AlphaBetaAI(int[], int) used to leave every weight at 0 */ private void initValues() { 10 | for (int i = 1; i <= 64; i++) 11 | value[i] = 0; 12 | value[1] = 100; 13 | value[8] = 100; 14 | value[57] = 100; 15 | value[64] = 100; 16 | value[2] = -2; 17 | value[7] = -2; 18 | value[9] = -2; 19 | value[10] = -2; 20 | value[15] = -2; 21 | value[16] = -2; 22 | value[49] = -2; 23 | value[50] = -2; 24 | value[58] = -2; 25 | value[55] = -2; 26 | value[56] = -2; 27 | value[63] = -2; 28 | } 29 | public AlphaBetaAI(int x[], int h) 30 | { 31 | super(x, h); initValues(); 32 | } 33 | private double max(double a, double b) 34 | { 35 | if (a >= b) 36 | return a; 37 | return b; 38 | } 39 | private double min(double a, double b) 40 | { 41 | if (a <= b) 42 | return a; 43 | return b; 44 | } 45 | private double AlphaBeta(int s, int[] pt, int depth, double a, double b, int h) /* h has just played s, so -h moves next; the score is from hold's side: hold maximises, the opponent minimises */ 46 | { 47 | int[] moves = new int[65]; 48 | pt = Board.nextState(pt, s, h); 49 | moves = Board.searchNext(pt, -h); /* replies for the side to move (-h); each child below plays them as -h */ 50 | if ((depth == 0) || (moves[0] == 0)) 51 | { 52 | double sum = 0; 53 | int hands = Board.calcHand(pt); 54 | for (int i = 0; i < 64; i++) /* pt is 0-based while value is 1-based */ 55 | { 56 | if (pt[i] == hold) 57 | sum += value[i + 1] + Math.pow(1.2, Math.max(35, hands) - 34) * 0.09; 58 | else if (pt[i] == -hold) 59 | sum -= value[i + 1] + Math.pow(1.2, Math.max(35, hands) - 34) * 0.09; 60 | } 61 | if (Board.terminal(pt)) sum *= 50; 62 | return sum; 63 | } 64 | if (-h == hold) /* hold to move: maximise over every reply, stopping early on a cutoff */ 65 | { 66 | for (int i = 1; i <= moves[0]; i++) 67 | { 68 | a = max(a, AlphaBeta(moves[i], pt, depth - 1, a, b, -h)); 69 | if (b <= a) 70 | break; 71 | } 72 | return a; 73 | } 74 | else 75 | { 76 |
for (int i = 1; i <= moves[0]; i++) 77 | { 78 | b = min(b, AlphaBeta(moves[i], pt, depth - 1, a, b, -h)); 79 | if (b <= a) 80 | break; 81 | } 82 | return b; 83 | } 84 | 85 | } 86 | @Override 87 | public int move() 88 | { 89 | int[] steps = new int[65]; 90 | int step = 1; 91 | double max = -2147483647, temp; 92 | int[] s = new int[64]; 93 | int depth = 4; 94 | int hands = Board.calcHand(table.getTable()); 95 | steps = table.nextSteps(hold); 96 | int[] nexts = new int[65]; 97 | if (steps[0] > 0) 98 | { 99 | if (steps[0] == 1) 100 | step = steps[1]; 101 | else 102 | { 103 | for (int i = 1; i <= 8; i++) 104 | for (int j = 1; j <= 8; j++) 105 | s[(i - 1) * 8 + j - 1] = table.bigTable[i][j]; 106 | for (int i = 1; i <= steps[0]; i++) 107 | { 108 | temp = AlphaBeta(steps[i], s, depth, -2147483647, 2147483647, hold); 109 | if (temp > max) 110 | { 111 | max = temp; 112 | nexts[0] = 1; 113 | nexts[1] = steps[i]; 114 | } 115 | else if (temp == max) 116 | { 117 | nexts[0]++; 118 | nexts[nexts[0]] = steps[i]; 119 | } 120 | } 121 | if (nexts[0] == 1) 122 | step = nexts[1]; 123 | else 124 | step = nexts[(int)(Math.random() * nexts[0]) + 1]; 125 | } 126 | table.set(step / 8 + 1, step % 8 + 1, hold); 127 | return step; 128 | } 129 | return -1; 130 | } 131 | } -------------------------------------------------------------------------------- /src/Playground.java: -------------------------------------------------------------------------------- 1 | package Game; 2 | import Game.*; 3 | import AI.*; 4 | import java.awt.*; 5 | import java.awt.event.*; 6 | import javax.swing.*; 7 | public class Playground extends JFrame 8 | { 9 | public Chess[][] p; 10 | public JPanel big; 11 | private UCTAI ai; 12 | private AlphaBetaAI ab; 13 | private Board table = new Board(); 14 | private int now = 1; 15 | public Playground() 16 | { 17 | super("Othello"); 18 | p = new Chess[8][8]; 19 | big = new JPanel() 20 | { 21 | public void paintComponent(Graphics g) 22 | { 23 | super.paintComponent(g);
24 | for (int i = 0; i < 8; i++) 25 | for (int j = 0; j < 8; j++) 26 | { 27 | p[i][j].color = table.bigTable[i + 1][j + 1]; 28 | p[i][j].repaint(); 29 | if ((table.bigTable[i + 1][j + 1] == -2) || (table.bigTable[i + 1][j + 1] == 2)) 30 | table.bigTable[i + 1][j + 1] /= 2; 31 | } 32 | } 33 | }; 34 | big.setLayout(new GridLayout(8, 8, 2, 2)); 35 | big.setBackground(Color.white); 36 | for (int i = 0; i < 8; i++) 37 | for (int j = 0; j < 8; j++) 38 | { 39 | p[i][j] = new Chess(); 40 | p[i][j].color = table.bigTable[i + 1][j + 1]; 41 | big.add(p[i][j]); 42 | } 43 | add(big); 44 | setBackground(Color.white); 45 | setSize(600, 600); 46 | setResizable(false); 47 | } 48 | public void init() 49 | { 50 | int kk = (int)(Math.random() * 2); 51 | kk = (int)Math.pow(-1, kk); 52 | ab = new AlphaBetaAI(kk); 53 | ai = new UCTAI(-kk); 54 | ai.table = table; 55 | ab.table = table; 56 | String ss1 = (kk == 1)?"black":"white"; 57 | String ss2 = (kk == 1)?"white":"black"; 58 | System.out.printf("AlphaBetaAI is %s\n", ss1); 59 | System.out.printf("UCT AI is %s\n", ss2); 60 | big.repaint(); 61 | big.setFocusable(true); 62 | setVisible(true); 63 | } 64 | public void checkWinner() 65 | { 66 | int b = 0, a = 0; 67 | for (int i = 1; i <= 8; i++) 68 | for (int j = 1; j <= 8; j++) 69 | if (Integer.signum(table.bigTable[i][j]) == ai.hold) /* a +/-2 cell is the highlighted last move (halved later by paintComponent) but still belongs to that side */ 70 | a++; 71 | else if (Integer.signum(table.bigTable[i][j]) == ab.hold) 72 | b++; 73 | if (b > a) 74 | System.out.println("AlphaBetaAI wins!"); 75 | else if (a > b) 76 | System.out.println("UCT AI wins!"); 77 | else 78 | System.out.println("Draw!"); 79 | } 80 | public void play() 81 | { 82 | int step1, count1 = 0; 83 | init(); 84 | while (true) 85 | { 86 | big.repaint(); 87 | try 88 | { 89 | Thread.sleep(200); 90 | } 91 | catch (Exception e){} 92 | if (now == ai.hold) 93 | { 94 | ai.table = new Board(table); 95 | step1 = ai.move(); 96 | //System.out.printf("ai: %d\n", step1); 97 | if (step1 >= 0) 98 | { 99 | table = new Board(ai.table); 100 | count1 = 0; 101 | } 102 | else 103 | { 104 | if
(count1 > 0) 105 | { 106 | checkWinner(); 107 | break; 108 | } 109 | count1++; 110 | } 111 | now = ab.hold; 112 | } 113 | else if (now == ab.hold) 114 | { 115 | ab.table = new Board(table); 116 | step1 = ab.move(); 117 | //System.out.printf("ab: %d\n", step1); 118 | if (step1 >= 0) 119 | { 120 | table = new Board(ab.table); 121 | count1 = 0; 122 | } 123 | else 124 | { 125 | if (count1 > 0) 126 | { 127 | checkWinner(); 128 | break; 129 | } 130 | count1++; 131 | } 132 | now = ai.hold; 133 | } 134 | } 135 | } 136 | } -------------------------------------------------------------------------------- /src/VectorMatrix.java: -------------------------------------------------------------------------------- 1 | package NeuralNetwork; 2 | import jcuda.*; 3 | import jcuda.jcublas.*; 4 | import jcuda.runtime.*; 5 | public class VectorMatrix 6 | { 7 | public static double dot(double[] a, double[] b, int n) 8 | { 9 | double ans; 10 | JCublas.cublasInit(); 11 | Pointer pa = new Pointer(); 12 | Pointer pb = new Pointer(); 13 | JCublas.cublasAlloc(n, Sizeof.DOUBLE, pa); 14 | JCublas.cublasAlloc(n, Sizeof.DOUBLE, pb); 15 | JCublas.cublasSetVector(n, Sizeof.DOUBLE, Pointer.to(a), 1, pa, 1); 16 | JCublas.cublasSetVector(n, Sizeof.DOUBLE, Pointer.to(b), 1, pb, 1); 17 | ans = JCublas.cublasDdot(n, pa, 1, pb, 1); 18 | JCublas.cublasFree(pa); 19 | JCublas.cublasFree(pb); 20 | JCublas.cublasShutdown(); 21 | return ans; 22 | } 23 | public static double[] mulMatrixVector(double[] a, int n, double[] b, int m, char op) 24 | { 25 | double[] ans = new double[n]; 26 | JCublas.cublasInit(); 27 | Pointer da = new Pointer(); 28 | Pointer db = new Pointer(); 29 | Pointer dc = new Pointer(); 30 | JCublas.cublasAlloc(n * m, Sizeof.DOUBLE, da); 31 | JCublas.cublasAlloc(m, Sizeof.DOUBLE, db); 32 | JCublas.cublasAlloc(n, Sizeof.DOUBLE, dc); 33 | if (op == 't') 34 | JCublas.cublasSetMatrix(n, m, Sizeof.DOUBLE, Pointer.to(a), n, da, n); 35 | else 36 | JCublas.cublasSetMatrix(m, n, Sizeof.DOUBLE,
Pointer.to(a), m, da, m); 37 | JCublas.cublasSetVector(m, Sizeof.DOUBLE, Pointer.to(b), 1, db, 1); 38 | JCublas.cublasSetVector(n, Sizeof.DOUBLE, Pointer.to(ans), 1, dc, 1); 39 | JCublas.cublasDgemv(op, n, m, 1, da, n, db, 1, 0, dc, 1); 40 | JCublas.cublasGetVector(n, Sizeof.DOUBLE, dc, 1, Pointer.to(ans), 1); 41 | JCublas.cublasFree(da); 42 | JCublas.cublasFree(db); 43 | JCublas.cublasFree(dc); 44 | JCublas.cublasShutdown(); 45 | return ans; 46 | } 47 | public static double[] mulVectorVector(double[] a, int n, double[] b, int m) 48 | { 49 | double[] ans = new double[n * m]; 50 | JCublas.cublasInit(); 51 | Pointer da = new Pointer(); 52 | Pointer db = new Pointer(); 53 | Pointer dc = new Pointer(); 54 | JCublas.cublasAlloc(n, Sizeof.DOUBLE, da); 55 | JCublas.cublasAlloc(m, Sizeof.DOUBLE, db); 56 | JCublas.cublasAlloc(n * m, Sizeof.DOUBLE, dc); 57 | JCublas.cublasSetVector(n, Sizeof.DOUBLE, Pointer.to(a), 1, da, 1); 58 | JCublas.cublasSetVector(m, Sizeof.DOUBLE, Pointer.to(b), 1, db, 1); JCublas.cublasSetVector(n * m, Sizeof.DOUBLE, Pointer.to(ans), 1, dc, 1); /* cublasDger accumulates (A += x * y^T) and device allocations are not cleared, so dc is zeroed from the all-zero ans first */ 59 | JCublas.cublasDger(m, n, 1, db, 1, da, 1, dc, m); 60 | JCublas.cublasGetMatrix(m, n, Sizeof.DOUBLE, dc, m, Pointer.to(ans), m); 61 | JCublas.cublasFree(da); 62 | JCublas.cublasFree(db); 63 | JCublas.cublasFree(dc); 64 | JCublas.cublasShutdown(); 65 | return ans; 66 | } 67 | public static double[] addMatrix(double[] a, double[] b, int n, int m, double alpha, double beta) 68 | { 69 | double[] ans = new double[n * m]; 70 | JCublas.cublasInit(); 71 | Pointer da = new Pointer(); 72 | Pointer db = new Pointer(); 73 | Pointer dc = new Pointer(); 74 | JCuda.cudaMalloc(da, n * m * Sizeof.DOUBLE); 75 | JCuda.cudaMalloc(db, n * m * Sizeof.DOUBLE); 76 | JCuda.cudaMalloc(dc, n * m * Sizeof.DOUBLE); 77 | JCuda.cudaMemcpy(da, Pointer.to(a), n * m * Sizeof.DOUBLE, cudaMemcpyKind.cudaMemcpyHostToDevice); 78 | JCuda.cudaMemcpy(db, Pointer.to(b), n * m * Sizeof.DOUBLE, cudaMemcpyKind.cudaMemcpyHostToDevice); 79 | JCublas2.cublasSetMatrix(m, n, Sizeof.DOUBLE, Pointer.to(a), m, da, m); 80 |
JCublas2.cublasSetMatrix(m, n, Sizeof.DOUBLE, Pointer.to(b), m, db, m); 81 | cublasHandle handle = new cublasHandle(); 82 | JCublas2.cublasCreate(handle); 83 | JCublas2.cublasDgeam(handle, cublasOperation.CUBLAS_OP_N, cublasOperation.CUBLAS_OP_N, m, n, Pointer.to(new double[]{alpha}), da, m, Pointer.to(new double[]{beta}), db, m, dc, m); 84 | JCublas2.cublasGetMatrix(n, m, Sizeof.DOUBLE, dc, n, Pointer.to(ans), n); JCublas2.cublasDestroy(handle); /* release the handle created above instead of leaking one per call */ 85 | JCuda.cudaFree(da); 86 | JCuda.cudaFree(db); 87 | JCuda.cudaFree(dc); /* these buffers came from cudaMalloc, so they are released with cudaFree */ 88 | JCublas.cublasShutdown(); 89 | return ans; 90 | } 91 | } -------------------------------------------------------------------------------- /src/NetworkLayer.java: -------------------------------------------------------------------------------- 1 | package NeuralNetwork; 2 | import NeuralNetwork.*; 3 | public class NetworkLayer 4 | { 5 | public double[] out, delta, in, b, w, w_grad, b_grad; 6 | public int pre_nodeNum = 1, nodeNum = 1; 7 | public double rate = 0.25; 8 | public NetworkLayer(int n1, int n2) 9 | { 10 | pre_nodeNum = n1; 11 | nodeNum = n2; 12 | 13 | in = new double[n1]; 14 | delta = new double[n1]; 15 | 16 | b = new double[n2]; 17 | out = new double[n2]; 18 | w = new double[n2 * n1]; 19 | w_grad = new double[n2 * n1]; 20 | b_grad = new double[n2]; 21 | } 22 | public void forward(double[] a, int act_type) 23 | { 24 | in = a; 25 | 26 | //W * x 27 | out = VectorMatrix.mulMatrixVector(w, nodeNum, in, pre_nodeNum, 't'); 28 | 29 | //Activate 30 | if (act_type > 0) 31 | { 32 | if (act_type == 1) 33 | { 34 | for (int i = 0; i < nodeNum; i++) 35 | out[i] = Activator.Relu(out[i] + b[i]); 36 | } 37 | else if (act_type == 2) 38 | { 39 | for (int i = 0; i < nodeNum; i++) 40 | out[i] = Activator.LRelu(rate, out[i] + b[i]); 41 | } 42 | else if (act_type == 3) 43 | { 44 | for (int i = 0; i < nodeNum; i++) 45 | out[i] = Activator.Sigmoid(out[i] + b[i]); 46 | } 47 | } 48 | else 49 | { 50 | for (int i = 0; i < nodeNum; i++) 51 | out[i] += b[i]; 52 | } 53 | } 54 | public
void backward(double[] pre_delta, int act_type, boolean calc_delta) 55 | { 56 | if (calc_delta) 57 | { 58 | double[] deInput = new double[pre_nodeNum]; 59 | if (act_type == 0) 60 | { 61 | for (int i = 0; i < pre_nodeNum; i++) 62 | deInput[i] = 1; 63 | } 64 | else if (act_type == 1) 65 | { 66 | for (int i = 0; i < pre_nodeNum; i++) 67 | deInput[i] = Activator.dRelu(in[i]); 68 | } 69 | else if (act_type == 2) 70 | { 71 | for (int i = 0; i < pre_nodeNum; i++) 72 | deInput[i] = Activator.dLRelu(rate, in[i]); 73 | } 74 | else if (act_type == 3) 75 | { 76 | for (int i = 0; i < pre_nodeNum; i++) 77 | deInput[i] = Activator.dSigmoid(in[i]); 78 | } 79 | 80 | // delta = (W^T * pre_delta) * deOut 81 | delta = VectorMatrix.mulMatrixVector(w, pre_nodeNum, pre_delta, nodeNum, 'n'); 82 | for (int i = 0; i < pre_nodeNum; i++) 83 | delta[i] *= deInput[i]; 84 | } 85 | 86 | w_grad = VectorMatrix.mulVectorVector(pre_delta, nodeNum, in, pre_nodeNum); 87 | b_grad = pre_delta; 88 | } 89 | public void backwardOut(double[] t, int act_type, int act_type2) 90 | { 91 | double[] d = new double[nodeNum]; 92 | 93 | if (act_type == 0) 94 | { 95 | for (int i = 0; i < nodeNum; i++) 96 | d[i] = t[i] - out[i]; 97 | } 98 | else if (act_type == 1) 99 | { 100 | for (int i = 0; i < nodeNum; i++) 101 | d[i] = Activator.dRelu(out[i]) * (t[i] - out[i]); 102 | } 103 | else if (act_type == 2) 104 | { 105 | for (int i = 0; i < nodeNum; i++) 106 | d[i] = Activator.dLRelu(rate, out[i]) * (t[i] - out[i]); 107 | } 108 | else if (act_type == 3) 109 | { 110 | for (int i = 0; i < nodeNum; i++) 111 | d[i] = Activator.dSigmoid(out[i]) * (t[i] - out[i]); 112 | } 113 | 114 | w_grad = VectorMatrix.mulVectorVector(d, nodeNum, in, pre_nodeNum); 115 | b_grad = d; 116 | 117 | //for (int j = 0; j < nodeNum * pre_nodeNum; j++) 118 | //System.out.printf("%f ", w_grad[j]); 119 | //System.out.println(); 120 | 121 | double[] deInput = new double[pre_nodeNum]; 122 | if (act_type2 == 0) 123 | { 124 | for (int i = 0; i < pre_nodeNum; 
i++) 125 | deInput[i] = 1; 126 | } 127 | else if (act_type2 == 1) 128 | { 129 | for (int i = 0; i < pre_nodeNum; i++) 130 | deInput[i] = Activator.dRelu(in[i]); 131 | } 132 | else if (act_type2 == 2) 133 | { 134 | for (int i = 0; i < pre_nodeNum; i++) 135 | deInput[i] = Activator.dLRelu(rate, in[i]); 136 | } 137 | else if (act_type2 == 3) 138 | { 139 | for (int i = 0; i < pre_nodeNum; i++) 140 | deInput[i] = Activator.dSigmoid(in[i]); 141 | } 142 | 143 | // delta = (W^T * d) .* deInput, where deInput = f'(in) 144 | delta = VectorMatrix.mulMatrixVector(w, pre_nodeNum, d, nodeNum, 'n'); 145 | for (int i = 0; i < pre_nodeNum; i++) 146 | delta[i] *= deInput[i]; 147 | 148 | //for (int j = 0; j < nodeNum; j++) 149 | //System.out.printf("%f ", delta[j]); 150 | //System.out.println(); 151 | } 152 | } -------------------------------------------------------------------------------- /src/GUI.java: -------------------------------------------------------------------------------- 1 | package Game; 2 | import java.awt.*; 3 | import java.awt.event.*; 4 | import javax.swing.*; 5 | import Game.*; 6 | import AI.*; 7 | public class GUI extends JFrame implements MouseListener 8 | { 9 | public Chess[][] p; 10 | public JPanel big; 11 | private Player player; 12 | private UCTAI ai; 13 | private Board table = new Board(); 14 | private volatile int now = 1; 15 | private volatile boolean lock = true, done = false; // written by mousePressed on the Swing event thread and busy-waited on by play(); volatile makes the hand-off (and the board update before it) visible across threads 16 | public GUI() 17 | { 18 | super("Othello"); 19 | p = new Chess[8][8]; 20 | big = new JPanel() 21 | { 22 | public void paintComponent(Graphics g) 23 | { 24 | super.paintComponent(g); 25 | for (int i = 0; i < 8; i++) 26 | for (int j = 0; j < 8; j++) 27 | { 28 | if (table.bigTable[i + 1][j + 1] * ai.hold == -2) 29 | table.bigTable[i + 1][j + 1] /= 2; 30 | p[i][j].color = table.bigTable[i + 1][j + 1]; 31 | p[i][j].repaint(); 32 | if ((table.bigTable[i + 1][j + 1] == -2) || (table.bigTable[i + 1][j + 1] == 2)) 33 | table.bigTable[i + 1][j + 1] /= 2; 34 | } 35 | } 36 | }; 37 | big.setLayout(new GridLayout(8, 8, 2, 2)); 
38 | big.setBackground(Color.white); 39 | for (int i = 0; i < 8; i++) 40 | for (int j = 0; j < 8; j++) 41 | { 42 | p[i][j] = new Chess(); 43 | p[i][j].color = table.bigTable[i + 1][j + 1]; 44 | big.add(p[i][j]); 45 | } 46 | add(big); 47 | addMouseListener(this); 48 | setBackground(Color.white); 49 | setSize(600, 600); 50 | setResizable(false); 51 | } 52 | public void init() 53 | { 54 | InitFrame frame = new InitFrame(); 55 | while (!frame.ok) 56 | { 57 | try 58 | { 59 | Thread.sleep(10); 60 | } 61 | catch (Exception e){} 62 | } 63 | try 64 | { 65 | Thread.sleep(200); 66 | } 67 | catch (Exception e){} 68 | player = new Player(frame.phold); 69 | ai = new UCTAI(frame.ahold); 70 | ai.table = table; 71 | setVisible(true); 72 | big.repaint(); 73 | big.setFocusable(true); 74 | } 75 | public void mousePressed(MouseEvent event) 76 | { 77 | int x1, y1, tx, ty, width, length; 78 | if (lock) return; 79 | if (now == player.hold) 80 | { 81 | tx = event.getX(); 82 | ty = event.getY(); 83 | Point end = SwingUtilities.convertPoint(big, big.getWidth(), big.getHeight(), this); 84 | Point top = SwingUtilities.convertPoint(big, 0, 0, this); 85 | width = (end.x - top.x) / 8; 86 | length = (end.y - top.y) / 8; 87 | if ((tx >= top.x) && (ty >= top.y) && (tx <= end.x) && (ty <= end.y)) 88 | { 89 | x1 = (tx - top.x) / width + 1; 90 | y1 = (ty - top.y) / length + 1; 91 | if (table.checkStep(y1, x1, player.hold)) 92 | { 93 | table.set(y1, x1, player.hold); 94 | ai.table = new Board(table); 95 | now = ai.hold; 96 | lock = true; 97 | done = true; 98 | //big.repaint(); 99 | } 100 | } 101 | } 102 | } 103 | public void mouseReleased(MouseEvent event){} 104 | public void mouseClicked(MouseEvent event){} 105 | public void mouseEntered(MouseEvent event){} 106 | public void mouseExited(MouseEvent event){} 107 | public void checkWinner() 108 | { // Prints the result; Integer.signum also counts the +/-2 marker that Board.set() leaves on the most recently placed stone. 109 | int playerStones = 0, aiStones = 0; 110 | for (int i = 1; i <= 8; i++) 111 | for (int j = 1; j <= 8; j++) 112 | if (Integer.signum(table.bigTable[i][j]) == ai.hold) 113 | aiStones++; 114 | else
if (Integer.signum(table.bigTable[i][j]) == player.hold) 115 | playerStones++; 116 | if (playerStones > aiStones) 117 | System.out.println("Congratulation! You win!"); 118 | else if (aiStones > playerStones) 119 | System.out.println("Sorry! AI wins!"); 120 | else 121 | System.out.println("Draw!"); 122 | } 123 | public void play() 124 | { 125 | int step1, count1 = 0; 126 | int[] moves; 127 | init(); 128 | while (true) 129 | { 130 | //big.repaint(); 131 | 132 | if (now == ai.hold) 133 | { 134 | ai.table = new Board(table); 135 | step1 = ai.move(); 136 | if (step1 >= 0) 137 | { 138 | table = new Board(ai.table); 139 | count1 = 0; 140 | } 141 | else 142 | { 143 | if (count1 > 0) 144 | { 145 | checkWinner(); 146 | break; 147 | } 148 | count1++; 149 | } 150 | now = player.hold; 151 | } 152 | else if (now == player.hold) 153 | { 154 | big.repaint(); 155 | moves = table.nextSteps(player.hold); 156 | if ((count1 > 0) && (moves[0] == 0)) 157 | { 158 | checkWinner(); 159 | break; 160 | } 161 | else if (moves[0] == 0) 162 | { 163 | count1++; 164 | } 165 | else 166 | { 167 | done = false; // reset BEFORE unlocking: otherwise a click handled in between sets done and is then overwritten, and the loop below never ends 168 | lock = false; 169 | while (!done) 170 | { 171 | try 172 | { 173 | Thread.sleep(10); 174 | } 175 | catch (Exception e){} 176 | } 177 | count1 = 0; 178 | } 179 | now = ai.hold; 180 | } 181 | } 182 | } 183 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Othello 2 | 3 | ![image](https://github.com/qiaofengmarco/JavaOthello/raw/master/image/d1.png) 4 | 5 | ![image](https://github.com/qiaofengmarco/JavaOthello/raw/master/image/d2.png) 6 | 7 | ![image](https://github.com/qiaofengmarco/JavaOthello/raw/master/image/d8.png) 8 | 9 | ![image](https://github.com/qiaofengmarco/JavaOthello/raw/master/image/d9.png) 10 | 11 | ## Dependencies: 12 | + Google Guava version 23.0 13 | + JCuda 0.8.0 (Mainly JCublas...) 
14 | 15 | ## Important Notes: 16 | + I didn't deploy my QLearningAI in the game because I encountered problems with NeuralNetwork. 17 | + To be honest, I don't know what is going wrong in my NeuralNetwork.java, NetworkLayer.java and VectorDNN.java. 18 | + However, my UCTAI and alpha-beta pruning AI work correctly, and the whole game framework is runnable. 19 | + Please contact me if you find anything wrong, especially in the files mentioned above. Thanks! 20 | 21 | ## 2017.9.14 Update: 22 | 1. Added the TrainFrame class. 23 | 2. Added an alpha-beta pruning minimax AI, referring to: 24 | + http://www.cnblogs.com/pangxiaodong/archive/2011/05/26/2058864.html 25 | 26 | ## 2017.9.19 Update: 27 | 1. Added the NeuralNetwork class, referring to: 28 | + http://ufldl.stanford.edu/tutorial/supervised/MultiLayerNeuralNetworks/ 29 | + http://ufldl.stanford.edu/wiki/index.php/%E5%8F%8D%E5%90%91%E4%BC%A0%E5%AF%BC%E7%AE%97%E6%B3%95 30 | + https://www.zybuluo.com/hanbingtao/note/476663 31 | + http://shuokay.com/2016/06/11/optimization/ 32 | 2. Added the IOHandler class. 33 | 34 | ## 2017.9.21 Update: 35 | 1. Corrected mistakes in NeuralNetwork.java and rewrote the backward function. 36 | 2. Finished QLearningAI.java and debugged it. 37 | + http://blog.csdn.net/songrotek/article/details/50917286 38 | 39 | ![image](https://github.com/qiaofengmarco/JavaOthello/raw/master/image/d3.png) 40 | 41 | ![image](https://github.com/qiaofengmarco/JavaOthello/raw/master/image/d4.png) 42 | 43 | 3. Added many static helper functions to Board.java. 44 | 45 | ## 2017.9.22 Update: 46 | 1. Corrected further mistakes in NeuralNetwork.java and added some comments. 47 | 2. Rewrote the backward function, correcting my earlier misunderstanding of minibatch SGD, referring to: 48 | + http://blog.csdn.net/u014595019/article/details/52989301 49 | 3. Finished debugging QLearningAI and rewrote some special cases. 50 | 4. Corrected the method used for updating minibatches. 51 | 5. Fixed some small bugs. 
52 | 6. Renamed TrainFrame.java to Playground.java. 53 | 54 | ## 2017.9.23 Update: 55 | 1. Fixed numerous subtle bugs in AlphaBetaAI. (Probably caused by Java methods receiving references to their object arguments and therefore modifying the caller's objects.) 56 | 2. Fixed bugs in the game-state logic of GUI.java and Playground.java. 57 | 3. Adjusted some default parameters. 58 | 59 | ## 2017.9.24 Update: 60 | 1. Finished the UCT AI, referring to: 61 | + http://blog.csdn.net/Dinosoft/article/details/50893291 62 | + http://www.jianshu.com/p/d011baff6b64 63 | + http://mcts.ai/pubs/mcts-survey-master.pdf 64 | 2. Adjusted some parameters in AlphaBetaAI.java. 65 | 3. Fixed bugs in Board.java and added new static helper functions. 66 | 67 | ![image](https://github.com/qiaofengmarco/JavaOthello/raw/master/image/d5.png) 68 | 69 | ## 2017.9.30 Update: 70 | 1. Translated the previous README.md into English. (Please forgive my non-native phrasing...) 71 | 72 | ## 2017.10.8 Update: 73 | 1. Wrote a vector-based DNN, VectorDNN.java. (Not finished yet...) 74 | 75 | ![image](https://github.com/qiaofengmarco/JavaOthello/raw/master/image/d7.png) 76 | 77 | 2. Used the JCublas API for matrix and vector computations (VectorMatrix.java). 78 | + http://www.jcuda.org/jcuda/jcublas/doc/index.html 79 | + http://blog.csdn.net/cqupt0901/article/details/26393857 80 | + http://www.jcuda.org/jcuda/JCuda.html#Runtime 81 | + http://www.jcuda.org/jcuda/jcublas/JCublas.html 82 | + https://devtalk.nvidia.com/default/topic/1005058/cuda-programming-and-performance/jcuda-matrices-addition/ 83 | 3. Wrapped up the activation functions (Activator.java). 84 | 85 | ## 2017.10.9 Update: 86 | 1. Improved the GUI. 87 | 88 | ![image](https://github.com/qiaofengmarco/JavaOthello/raw/master/image/d6.png) 89 | 90 | 2. Normalized the code formatting. 91 | 3. Fixed a bug in VectorMatrix.java. 92 | 93 | ## 2017.10.10 Update: 94 | 1. Improved the GUI. 
95 | 96 | ![image](https://github.com/qiaofengmarco/JavaOthello/raw/master/image/d8.png) 97 | 98 | ![image](https://github.com/qiaofengmarco/JavaOthello/raw/master/image/d9.png) 99 | 100 | ## 2017.10.11 Update: 101 | 1. Organized the files. 102 | 2. To do next:fix VectorDNN...QAQ 103 | 3. Important references: 104 | + https://aijunbai.github.io/publications/USTC07-Bai.pdf 105 | 106 | **Modified on 2017.10.11.** 107 | -------------------------------------------------------------------------------- /src/UCTAI.java: -------------------------------------------------------------------------------- 1 | package AI; 2 | import AI.*; 3 | import Game.*; 4 | import java.util.*; 5 | public class UCTAI extends AI 6 | { 7 | private Map<UCTTreeNode, Double> Q = new HashMap<UCTTreeNode, Double>(); // summed rollout reward per tree node (filled by BackupNegaMax) 8 | private Map<UCTTreeNode, Integer> N = new HashMap<UCTTreeNode, Integer>(); // visit count per tree node 9 | public UCTAI(int h) 10 | { 11 | super(h); 12 | } 13 | public UCTAI(int x[], int h) 14 | { 15 | super(x, h); 16 | } 17 | private UCTTreeNode TreePolicy(UCTTreeNode v) 18 | { 19 | while (!Board.terminal(v.data)) 20 | { 21 | if (v.next[0] == -1) 22 | return null; 23 | if (v.expand < v.next[0]) 24 | return Expand(v); 25 | else 26 | v = BestChild(v, 1 / Math.sqrt(2)); 27 | } 28 | return v; 29 | } 30 | private UCTTreeNode Expand(UCTTreeNode v) 31 | { 32 | if (v.next[0] == 0) return v; 33 | v.expand++; 34 | int action = v.next[v.expand]; 35 | int[] next = Board.nextState(v.data, action, hold); 36 | return v.addChild(next); 37 | } 38 | private UCTTreeNode BestChild(UCTTreeNode v, double c) 39 | { // UCB1 selection (random tie-break). v.next is 1-based (next[0] = move count) while v.children is 0-based, so move i maps to children.get(i - 1). 40 | double max = -2147483647, temp = 0, qv, nv, nvthis; 41 | UCTTreeNode[] anstemp = new UCTTreeNode[v.children.size()]; 42 | UCTTreeNode ans = null; 43 | int count = 0; 44 | if ((v.isLeaf()) || (v.next[0] == 0)) return null; 45 | if (v.next[0] == 1) return v.children.get(0); 46 | ans = v.children.get(0); 47 | nvthis = N.get(v).doubleValue(); 48 | for (int i = 1; i <= v.next[0]; i++) 49 | { 50 | qv = Q.get(v.children.get(i - 1)).doubleValue(); 51 | nv = N.get(v.children.get(i - 1)).doubleValue(); 52 | 
temp = qv / nv + c * Math.sqrt(2 * Math.log(nvthis) / nv); 53 | if (temp > max) 54 | { 55 | max = temp; 56 | count = 1; 57 | anstemp[0] = v.children.get(i - 1); 58 | } 59 | else if (temp == max) 60 | { 61 | count++; 62 | anstemp[count - 1] = v.children.get(i - 1); 63 | } 64 | } 65 | if (count == 1) 66 | ans = anstemp[0]; 67 | else 68 | ans = anstemp[(int)(Math.random() * count)]; 69 | return ans; 70 | } 71 | private int BestChildAction(UCTTreeNode v, double c) 72 | { 73 | double max = -2147483647, temp = 0, qv, nv, nvthis; 74 | int[] anstemp = new int[64]; 75 | int ans; 76 | int count = 0; 77 | if ((v.isLeaf()) || (v.next[0] == 0)) return -1; 78 | if (v.next[0] == 1) return v.next[1]; 79 | ans = v.next[1]; 80 | nvthis = N.get(v).doubleValue(); 81 | for (int i = 1; i <= v.next[0]; i++) 82 | { 83 | qv = Q.get(v.children.get(i - 1)).doubleValue(); 84 | nv = N.get(v.children.get(i - 1)).doubleValue(); 85 | temp = qv / nv + c * Math.sqrt(2 * Math.log(nvthis) / nv); 86 | if (temp > max) 87 | { 88 | max = temp; 89 | count = 1; 90 | anstemp[0] = v.next[i]; 91 | } 92 | else if (temp == max) 93 | { 94 | count++; 95 | anstemp[count - 1] = v.next[i]; 96 | } 97 | } 98 | if (count == 1) 99 | ans = anstemp[0]; 100 | else 101 | ans = anstemp[(int)(Math.random() * count)]; 102 | return ans; 103 | } 104 | private double DefaultPolicy(int[] s) 105 | { // Random playout from s to the end of the game: +100 if hold wins, -100 if it loses, 0 on a draw. 106 | int action = 0, now = -hold; // states handed in come from Expand(), i.e. right after `hold` moved, so the opponent moves first 107 | int[] next1, next2; 108 | while (!Board.terminal(s)) 109 | { 110 | if (now == hold) 111 | { 112 | next1 = Board.searchNext(s, hold); 113 | if (next1[0] > 0) 114 | { 115 | action = next1[(int)(Math.random() * next1[0]) + 1]; 116 | s = Board.nextState(s, action, hold); 117 | } 118 | now = -hold; 119 | } 120 | else 121 | { 122 | next2 = Board.searchNext(s, -hold); 123 | if (next2[0] > 0) 124 | { 125 | action = next2[(int)(Math.random() * next2[0]) + 1]; 126 | s = Board.nextState(s, action, -hold); 127 | } 128 | now = hold; 129 | } 130 | } 131 | return 100 * Board.getWinner(s) * hold; 132 | } 133 | private 
void BackupNegaMax(UCTTreeNode v, double delta) 134 | { 135 | while (v != null) 136 | { 137 | if (N.containsKey(v)) 138 | N.put(v, new Integer(N.get(v).intValue() + 1)); 139 | else 140 | N.put(v, new Integer(1)); 141 | if (Q.containsKey(v)) 142 | Q.put(v, new Double(Q.get(v).doubleValue() + delta)); 143 | else 144 | Q.put(v, new Double(delta)); 145 | delta = -delta; 146 | v = v.parent; 147 | } 148 | } 149 | private int UCTSearch(int[] s) 150 | { 151 | UCTTreeNode v = new UCTTreeNode(s, hold); 152 | UCTTreeNode v1; 153 | double delta; 154 | if (v.next[0] == 0) return -1; 155 | while (v.expand < v.next[0]) 156 | { 157 | v1 = TreePolicy(v); 158 | if (v1 == null) break; 159 | //System.out.printf("%d %d\n", v.expand, v.next[0]); 160 | //System.out.println("aaa"); 161 | delta = DefaultPolicy(v1.data); 162 | //System.out.println("bbb"); 163 | BackupNegaMax(v1, delta); 164 | } 165 | //System.out.println("ccc"); 166 | return BestChildAction(v, 0); 167 | } 168 | @Override 169 | public int move() 170 | { 171 | int[] steps = new int[65]; 172 | int step; 173 | steps = table.nextSteps(hold); 174 | if (steps[0] > 0) 175 | { 176 | if (steps[0] == 1) 177 | step = steps[1]; 178 | else 179 | step = UCTSearch(table.getTable()); 180 | table.set(step / 8 + 1, step % 8 + 1, hold); 181 | return step; 182 | } 183 | return -1; 184 | } 185 | } -------------------------------------------------------------------------------- /src/Board.java: -------------------------------------------------------------------------------- 1 | package Game; 2 | public class Board 3 | { 4 | public int[][] bigTable; 5 | public int bhand = 2, whand = 2, hand = 4; 6 | private int[] MoveX = {0, 0, 0, -1, 1, -1, 1, -1, 1}; 7 | private int[] MoveY = {0, -1, 1, -1, -1, 0, 0, 1, 1}; 8 | public Board() 9 | { 10 | bigTable = new int[9][]; 11 | for (int i = 0; i <= 8; i++) 12 | { 13 | bigTable[i] = new int[9]; 14 | for (int j = 1; j <= 8; j++) 15 | bigTable[i][j] = 0; 16 | } 17 | bigTable[4][4] = -1; 18 | bigTable[4][5] = 
1; 19 | bigTable[5][4] = 1; 20 | bigTable[5][5] = -1; 21 | } 22 | public Board(int[] s) 23 | { // Rebuilds a board from a 64-cell row-major array (the layout getTable() produces) and recounts the stones. 24 | bigTable = new int[9][9]; bhand = 0; whand = 0; hand = 0; // the field initialisers (2, 2, 4) describe the opening position, not s 25 | for (int i = 1; i <= 8; i++) 26 | for (int j = 1; j <= 8; j++) 27 | { 28 | bigTable[i][j] = s[(i - 1) * 8 + j - 1]; 29 | if (bigTable[i][j] > 0) // 1, or 2 for the last-placed stone (set() stores who * 2) 30 | { 31 | bhand++; 32 | hand++; 33 | } 34 | else if (bigTable[i][j] < 0) 35 | { 36 | whand++; 37 | hand++; 38 | } 39 | } 40 | } 41 | public Board(Board b) 42 | { 43 | bigTable = new int[9][9]; 44 | for (int i = 1; i <= 8; i++) 45 | for (int j = 1; j <= 8; j++) 46 | bigTable[i][j] = b.bigTable[i][j]; 47 | hand = b.hand; 48 | bhand = b.bhand; 49 | whand = b.whand; 50 | } 51 | public static int[] getInit() 52 | { 53 | int[] ans = new int[64]; 54 | for (int i = 0; i < 64; i++) 55 | ans[i] = 0; 56 | ans[27] = -1; 57 | ans[28] = 1; 58 | ans[35] = 1; 59 | ans[36] = -1; 60 | return ans; 61 | } 62 | public static int[] nextState(int[] s1, int action, int hold) 63 | { 64 | Board t = new Board(s1); 65 | t.set(action / 8 + 1, action % 8 + 1, hold); 66 | return t.getTable(); 67 | } 68 | public static int[] searchNext(int[] s1, int hold) 69 | { 70 | Board t = new Board(s1); 71 | return t.nextSteps(hold); 72 | } 73 | public static int calcHand(int[] s) 74 | { 75 | int total = 0; 76 | for (int i = 0; i < 64; i++) 77 | if (s[i] != 0) 78 | total++; 79 | return total; 80 | } 81 | public static int getWinner(int[] s) 82 | { 83 | int ans = -1; 84 | int a = 0, b = 0; 85 | for (int i = 0; i < 64; i++) 86 | if ((s[i] == 1) || (s[i] == 2)) 87 | a++; 88 | else if ((s[i] == -1) || (s[i] == -2)) 89 | b++; 90 | if (a > b) 91 | return 1; 92 | else if (a < b) 93 | return -1; 94 | return 0; 95 | } 96 | public int[] getTable() 97 | { 98 | int[] ans = new int[64]; 99 | for (int i = 1; i <= 8; i++) 100 | for (int j = 1; j <= 8; j++) 101 | ans[(i - 1) * 8 + j - 1] = bigTable[i][j]; 102 | return ans; 103 | } 104 | public void set(int x, int y, int who) 105 | { // Places a stone for who at (x, y), flips captured runs in all 8 directions, and marks the new stone as 2 * who. 106 | int x1, y1, count = 0; 107 | if ((x < 1) || (x > 8) || (y < 
1) || (y > 8)) return; 108 | for (int i = 1; i <= 8; i++) 109 | { 110 | count = 0; 111 | x1 = x + MoveX[i]; 112 | y1 = y + MoveY[i]; 113 | count++; 114 | if ((x1 < 1) || (x1 > 8) || (y1 < 1) || (y1 > 8)) continue; 115 | if ((bigTable[x1][y1] != -who) && (bigTable[x1][y1] != -2 * who)) continue; 116 | while ((bigTable[x1][y1] == -who) || (bigTable[x1][y1] == -2 * who)) 117 | { 118 | x1 += MoveX[i]; 119 | y1 += MoveY[i]; 120 | count++; 121 | if ((x1 < 1) || (x1 > 8) || (y1 < 1) || (y1 > 8)) break; 122 | if (bigTable[x1][y1] == 0) break; 123 | if ((bigTable[x1][y1] == who) || (bigTable[x1][y1] == 2 * who)) 124 | { 125 | for (int j = 1; j <= count - 1; j++) 126 | { 127 | bigTable[x + j * MoveX[i]][y + j * MoveY[i]] = who; if (who == 1) { bhand++; whand--; } else if (who == -1) { whand++; bhand--; } // a flipped stone changes colour, so move it between the per-colour counts (hand, the total, is unchanged) 128 | //System.out.printf("%d, %d; %d, %d\n", x, y, x + j * MoveX[i], y + j * MoveY[i]); 129 | } 130 | break; 131 | } 132 | } 133 | } 134 | bigTable[x][y] = who * 2; 135 | hand++; 136 | if (who == 1) 137 | bhand++; 138 | else if (who == -1) 139 | whand++; 140 | } 141 | public boolean checkStep(int x, int y, int who) 142 | { 143 | int x1, y1; 144 | if ((x < 1) || (x > 8) || (y < 1) || (y > 8)) return false; 145 | if (bigTable[x][y] != 0) return false; 146 | for (int i = 1; i <= 8; i++) 147 | { 148 | x1 = x + MoveX[i]; 149 | y1 = y + MoveY[i]; 150 | if ((x1 < 1) || (x1 > 8) || (y1 < 1) || (y1 > 8)) continue; 151 | if ((bigTable[x1][y1] != -who) && (bigTable[x1][y1] != -2 * who)) continue; 152 | while ((bigTable[x1][y1] == -who) || (bigTable[x1][y1] == -2 * who)) 153 | { 154 | x1 += MoveX[i]; 155 | y1 += MoveY[i]; 156 | if (((x1 < 1) || (x1 > 8) || (y1 < 1) || (y1 > 8))) break; 157 | if (bigTable[x1][y1] == 0) break; 158 | if ((bigTable[x1][y1] == who) || (bigTable[x1][y1] == 2 * who)) 159 | { 160 | return true; 161 | } 162 | } 163 | } 164 | return false; 165 | } 166 | public int[] nextSteps(int who) 167 | { 168 | int x1, y1; 169 | int[] ans = new int[65]; 170 | for (int i = 0; i < 65; i++) 171 | ans[i] = 0; 172 | for (int i = 1; i <= 8; i++) 173 | for 
(int j = 1; j <= 8; j++) 174 | if (checkStep(i, j, who)) 175 | { 176 | ans[0]++; 177 | ans[ans[0]] = (i - 1) * 8 + j - 1; 178 | } 179 | return ans; 180 | } 181 | public static boolean terminal(int[] s) 182 | { // True when the board is full or neither side has a legal move. 183 | if (Board.calcHand(s) == 64) return true; 184 | Board b = new Board(s); 185 | int[] a1, b1; 186 | a1 = b.nextSteps(1); 187 | if (a1[0] > 0) return false; 188 | b1 = b.nextSteps(-1); 189 | if (b1[0] > 0) return false; 190 | return true; 191 | } 192 | } -------------------------------------------------------------------------------- /src/VectorDNN.java: -------------------------------------------------------------------------------- 1 | package NeuralNetwork; 2 | import java.io.*; 3 | import java.util.*; 4 | import NeuralNetwork.*; 5 | import IOHandler.*; 6 | public class VectorDNN 7 | { 8 | private NetworkLayer[] layer; 9 | private int layers = 1; 10 | private double learning_rate = 0.01; 11 | private String path; 12 | private boolean finish = false; 13 | public VectorDNN(String name) 14 | { 15 | double kk; 16 | path = name; 17 | layer = new NetworkLayer[4]; 18 | 19 | layer[0] = new NetworkLayer(64, 16); 20 | layer[1] = new NetworkLayer(16, 16); 21 | layer[2] = new NetworkLayer(16, 16); 22 | layer[3] = new NetworkLayer(16, 1); 23 | layers = 4; 24 | 25 | File f = new File(path); 26 | if (!f.exists()) 27 | { 28 | try 29 | { 30 | f.createNewFile(); 31 | } 32 | catch(IOException e){ System.err.println("Could not create weight file " + path + ": " + e); } 33 | IOHandler handler = new IOHandler(path); 34 | handler.openFileWrite(); 35 | for (int i = 0; i < layers; i++) 36 | { 37 | for (int j = 0; j < layer[i].pre_nodeNum * layer[i].nodeNum; j++) 38 | { 39 | kk = Math.random() * 0.04 + 0.001; 40 | layer[i].w[j] = kk; 41 | handler.writeDouble(layer[i].w[j]); 42 | } 43 | for (int j = 0; j < layer[i].nodeNum; j++) 44 | { 45 | kk = Math.random() * 0.005 + 0.0001; 46 | layer[i].b[j] = kk; 47 | handler.writeDouble(layer[i].b[j]); 48 | } 49 | } 50 | handler.flush(); 51 | handler.closeAll(); 52 | } 53 | else 54 | load(); 55 | } 56 | public 
double[] forward(double[] x) 57 | { 58 | double[] ans; 59 | layer[0].forward(x, 3); 60 | for (int i = 1; i < layers; i++) 61 | layer[i].forward(layer[i - 1].out, 3); 62 | ans = layer[layers - 1].out; 63 | return ans; 64 | } 65 | public void backward(double[] x, double[] t) 66 | { 67 | forward(x); 68 | layer[layers - 1].backwardOut(t, 3, 3); 69 | for (int i = layers - 2; i >= 1; i--) 70 | layer[i].backward(layer[i + 1].delta, 3, true); 71 | layer[0].backward(layer[1].delta, 3, false); 72 | } 73 | public double evaluate(double[][] x, double[][] tag, int size) 74 | { 75 | double error = 0; 76 | double[] predict; 77 | for (int i = 0; i < size; i++) 78 | { 79 | predict = forward(x[i]); 80 | for (int j = 0; j < layer[layers - 1].nodeNum; j++) 81 | if (Math.abs(predict[j] - tag[i][j]) > 0.01) 82 | { 83 | error += 1; 84 | break; 85 | } 86 | } 87 | return error / size; 88 | } 89 | public void minibatchSGD(double[][] x, double[][] tag, int batchSize) 90 | { 91 | double[][] w_grad_tot = new double[4][]; 92 | double[][] b_grad_tot = new double[4][]; 93 | 94 | for (int i = 1; i < 3; i++) 95 | { 96 | w_grad_tot[i] = new double[16 * 16]; 97 | b_grad_tot[i] = new double[16]; 98 | } 99 | b_grad_tot[0] = new double[16]; 100 | w_grad_tot[0] = new double[64 * 16]; 101 | b_grad_tot[3] = new double[1]; 102 | w_grad_tot[3] = new double[16]; 103 | 104 | for (int k = 0; k < batchSize; k++) 105 | { 106 | backward(x[k], tag[k]); 107 | for (int i = 0; i < layers; i++) 108 | { 109 | w_grad_tot[i] = VectorMatrix.addMatrix(w_grad_tot[i], layer[i].w_grad, layer[i].nodeNum, layer[i].pre_nodeNum, 1, 1.0 / batchSize); 110 | b_grad_tot[i] = VectorMatrix.addMatrix(b_grad_tot[i], layer[i].b_grad, layer[i].nodeNum, 1, 1, 1.0 / batchSize); 111 | } 112 | } 113 | for (int i = 0; i < layers; i++) 114 | { 115 | layer[i].w_grad = VectorMatrix.addMatrix(layer[i].w_grad, w_grad_tot[i], layer[i].nodeNum, layer[i].pre_nodeNum, 1.0 - learning_rate * 0.01, -learning_rate); 116 | layer[i].b_grad = 
VectorMatrix.addMatrix(layer[i].b_grad, b_grad_tot[i], layer[i].nodeNum, 1, 1, -learning_rate); 117 | } 118 | } 119 | public void SGD(double[] x, double[] tag) 120 | { 121 | double[][] w_grad_tot = new double[4][]; 122 | double[][] b_grad_tot = new double[4][]; 123 | 124 | for (int i = 1; i < 3; i++) 125 | { 126 | w_grad_tot[i] = new double[16 * 16]; 127 | b_grad_tot[i] = new double[16]; 128 | } 129 | b_grad_tot[0] = new double[16]; 130 | w_grad_tot[0] = new double[64 * 16]; 131 | b_grad_tot[3] = new double[1]; 132 | w_grad_tot[3] = new double[16]; 133 | 134 | backward(x, tag); 135 | for (int i = 0; i < layers; i++) 136 | { 137 | w_grad_tot[i] = VectorMatrix.addMatrix(w_grad_tot[i], layer[i].w_grad, layer[i].nodeNum, layer[i].pre_nodeNum, 1.0, 1.0); 138 | b_grad_tot[i] = VectorMatrix.addMatrix(b_grad_tot[i], layer[i].b_grad, layer[i].nodeNum, 1, 1.0, 1.0); 139 | layer[i].w_grad = VectorMatrix.addMatrix(layer[i].w_grad, w_grad_tot[i], layer[i].nodeNum, layer[i].pre_nodeNum, 1.0, -learning_rate); 140 | layer[i].b_grad = VectorMatrix.addMatrix(layer[i].b_grad, b_grad_tot[i], layer[i].nodeNum, 1, 1.0, -learning_rate); 141 | } 142 | } 143 | public void store() 144 | { 145 | IOHandler handler = new IOHandler(path); 146 | handler.openFileWrite(); 147 | for (int i = 0; i < layers; i++) 148 | { 149 | for (int j = 0; j < layer[i].pre_nodeNum * layer[i].nodeNum; j++) 150 | handler.writeDouble(layer[i].w[j]); 151 | for (int j = 0; j < layer[i].nodeNum; j++) 152 | handler.writeDouble(layer[i].b[j]); 153 | } 154 | handler.flush(); 155 | handler.closeAll(); 156 | } 157 | public void load() 158 | { 159 | IOHandler handler = new IOHandler(path); 160 | handler.openFileRead(); 161 | for (int i = 0; i < layers; i++) 162 | { 163 | for (int j = 0; j < layer[i].pre_nodeNum * layer[i].nodeNum; j++) 164 | layer[i].w[j] = handler.readDouble(); 165 | for (int j = 0; j < layer[i].nodeNum; j++) 166 | layer[i].b[j] = handler.readDouble(); 167 | } 168 | handler.closeAll(); 169 | } 170 | public 
void sampling(double[][] batch1, double[][] batch2, double[][] ans1, double[][] ans2) 171 | { 172 | int count = 0, temp; 173 | HashSet set = new HashSet(); 174 | while (set.size() < 100) 175 | { 176 | temp = (int)(Math.random() * 100); 177 | set.add(new Integer(temp)); 178 | } 179 | for (Integer i : set) 180 | { 181 | ans1[count] = batch1[i.intValue()]; 182 | ans2[count] = batch2[i.intValue()]; 183 | count++; 184 | } 185 | } 186 | public void train() 187 | { 188 | double[][] x = new double[100][64]; 189 | double[][] tag = new double[100][1]; 190 | //double[][] x1 = new double[100][64]; 191 | //double[][] tag1 = new double[100][1]; 192 | double sum; 193 | for (int i = 0; i < 100; i++) 194 | { 195 | sum = 0; 196 | for (int j = 0; j < 64; j++) 197 | { 198 | x[i][j] = Math.random(); 199 | sum += x[i][j]; 200 | } 201 | sum /= 64; 202 | tag[i][0] = sum; 203 | } 204 | //sampling(x, tag, x1, tag1); 205 | for (int i = 0; i < 100; i++) 206 | SGD(x[i], tag[i]); 207 | System.out.printf("Error rate: %f\n", evaluate(x, tag, 100)); 208 | store(); 209 | } 210 | public static void main(String[] args) 211 | { 212 | VectorDNN dnn = new VectorDNN("test.txt"); 213 | for (int i = 0; i < 50; i++) 214 | { 215 | System.out.printf("Round %d:\n", i + 1); 216 | dnn.train(); 217 | } 218 | } 219 | } -------------------------------------------------------------------------------- /src/NeuralNetwork.java: -------------------------------------------------------------------------------- 1 | package NeuralNetwork; 2 | import java.io.*; 3 | import java.util.*; 4 | import IOHandler.*; 5 | public class NeuralNetwork 6 | { 7 | private double rate = 0.001; 8 | private double[][] sum; 9 | private double[][][] w, v, m, dw; 10 | private double[][] vv, mm, b, db; 11 | private String path; 12 | private double[][] error; 13 | private int[] ct = new int[64]; 14 | public NeuralNetwork(String name) 15 | { 16 | int kk = 0; 17 | path = name; 18 | vv = new double[4][200]; 19 | mm = new double[4][200]; 20 | b = new 
double[4][]; 21 | db = new double[4][]; 22 | sum = new double[3][]; 23 | for (int i = 0; i <= 2; i++) 24 | { 25 | b[i] = new double[200]; 26 | db[i] = new double[200]; 27 | sum[i] = new double[200]; 28 | } 29 | b[3] = new double[64]; 30 | db[3] = new double[64]; 31 | w = new double[4][][]; 32 | dw = new double[4][][]; 33 | v = new double[4][][]; 34 | m = new double[4][][]; 35 | w[0] = new double[64][200]; 36 | dw[0] = new double[64][200]; 37 | v[0] = new double[64][200]; 38 | m[0] = new double[64][200]; 39 | w[3] = new double[200][64]; 40 | dw[3] = new double[200][64]; 41 | v[3] = new double[200][64]; 42 | m[3] = new double[200][64]; 43 | for (int i = 1; i <= 2; i++) 44 | { 45 | w[i] = new double[200][200]; 46 | dw[i] = new double[200][200]; 47 | v[i] = new double[200][200]; 48 | m[i] = new double[200][200]; 49 | } 50 | 51 | 52 | error = new double[4][]; 53 | for (int i = 0; i < 3; i++) 54 | error[i] = new double[200]; 55 | error[3] = new double[1]; 56 | 57 | File f = new File(name); 58 | if (!f.exists()) 59 | { 60 | try 61 | { 62 | f.createNewFile(); 63 | } 64 | catch(IOException e){} 65 | IOHandler handler = new IOHandler(path); 66 | handler.openFileWrite(); 67 | for (int i = 0; i < 64; i++) 68 | for (int j = 0; j < 200; j++) 69 | { 70 | kk = (int)(Math.random() * 2); 71 | w[0][i][j] = (Math.random() * 29 + 1) * 0.01 * ((int)Math.pow(-1, kk)); 72 | handler.writeDouble(w[0][i][j]); 73 | } 74 | for (int i = 1; i < 3; i++) 75 | for (int j = 0; j < 200; j++) 76 | for (int k = 0; k < 200; k++) 77 | { 78 | kk = (int)(Math.random() * 2); 79 | w[i][j][k] = (Math.random() * 29 + 1) * 0.01 * ((int)Math.pow(-1, kk)); 80 | handler.writeDouble(w[i][j][k]); 81 | } 82 | for (int i = 0; i < 200; i++) 83 | for (int j = 0; j < 64; j++) 84 | { 85 | kk = (int)(Math.random() * 2); 86 | w[3][i][j] = (Math.random() * 29 + 1) * 0.01 * ((int)Math.pow(-1, kk)); 87 | handler.writeDouble(w[3][i][j]); 88 | } 89 | for (int i = 0; i < 3; i++) 90 | for (int j = 0; j < 200; j++) 91 | { 92 | kk = 
(int)(Math.random() * 2); 93 | b[i][j] = (Math.random() * 29 + 1) * 0.01 * ((int)Math.pow(-1, kk)); 94 | handler.writeDouble(b[i][j]); 95 | } 96 | for (int i = 0; i < 64; i++) 97 | { 98 | kk = (int)(Math.random() * 2); 99 | b[3][i] = (Math.random() * 29 + 1) * 0.01 * ((int)Math.pow(-1, kk)); 100 | handler.writeDouble(b[3][i]); 101 | } 102 | handler.flush(); 103 | handler.closeAll(); 104 | } 105 | else 106 | load(); 107 | } 108 | public void store() 109 | { 110 | IOHandler handler = new IOHandler(path); 111 | handler.openFileWrite(); 112 | for (int i = 0; i < 64; i++) 113 | for (int j = 0; j < 200; j++) 114 | handler.writeDouble(w[0][i][j]); 115 | for (int i = 1; i < 3; i++) 116 | for (int j = 0; j < 200; j++) 117 | for (int k = 0; k < 200; k++) 118 | handler.writeDouble(w[i][j][k]); 119 | for (int i = 0; i < 200; i++) 120 | for (int j = 0; j < 64; j++) 121 | handler.writeDouble(w[3][i][j]); 122 | for (int i = 0; i < 3; i++) 123 | for (int j = 0; j < 200; j++) 124 | handler.writeDouble(b[i][j]); 125 | for (int i = 0; i < 64; i++) 126 | handler.writeDouble(b[3][i]); 127 | handler.flush(); 128 | handler.closeAll(); 129 | } 130 | public void load() 131 | { 132 | IOHandler handler = new IOHandler(path); 133 | handler.openFileRead(); 134 | for (int i = 0; i < 64; i++) 135 | for (int j = 0; j < 200; j++) 136 | w[0][i][j] = handler.readDouble(); 137 | for (int i = 1; i < 3; i++) 138 | for (int j = 0; j < 200; j++) 139 | for (int k = 0; k < 200; k++) 140 | w[i][j][k] = handler.readDouble(); 141 | for (int i = 0; i < 200; i++) 142 | for (int j = 0; j < 64; j++) 143 | w[3][i][j] = handler.readDouble(); 144 | for (int i = 0; i < 3; i++) 145 | for (int j = 0; j < 200; j++) 146 | b[i][j] = handler.readDouble(); 147 | for (int i = 0; i < 64; i++) 148 | b[3][i] = handler.readDouble(); 149 | handler.closeAll(); 150 | } 151 | public double LRelu(double a) 152 | { 153 | if (a >= 0) 154 | return a; 155 | return 0.3 * a; 156 | } 157 | public double de(double a) 158 | { 159 | if (a >= 
0) 160 | return 1.0; 161 | return 0.3; 162 | } 163 | public double forward(int[] x, int action) 164 | //x is a 64 * 1 vector 165 | { 166 | double ans = 0; 167 | 168 | for (int j = 0; j < 200; j++) 169 | { 170 | for (int i = 0; i < 64; i++) 171 | { 172 | if (i == 0) sum[0][j] = 0; //clear 173 | sum[0][j] += x[i] * w[0][i][j]; 174 | } 175 | sum[0][j] = LRelu(sum[0][j] + b[0][j]); 176 | } 177 | 178 | for (int k = 1; k <= 2; k++) 179 | for (int j = 0; j < 200; j++) 180 | { 181 | for (int i = 0; i < 200; i++) 182 | { 183 | if (i == 0) sum[k][j] = 0; //clear 184 | sum[k][j] += sum[k - 1][i] * w[k][i][j]; 185 | } 186 | sum[k][j] = LRelu(sum[k][j] + b[k][j]); 187 | } 188 | 189 | ans = 0; 190 | for (int i = 0; i < 200; i++) 191 | ans += sum[2][i] * w[3][i][action]; 192 | //previous: ans = Math.tanh(ans + b[3][action]); 193 | ans += b[3][action]; 194 | return ans; 195 | } 196 | public void backward(int[][] x, int[] action, double[] y, int minibatchSize) 197 | { 198 | double t, temp = 0; 199 | for (int kk = 0; kk < minibatchSize; kk++) 200 | { 201 | t = forward(x[kk], action[kk]); 202 | ct[action[kk]]++; 203 | temp = 0; 204 | 205 | //[output -> (n-1) hidden] 206 | //dE / dnet_j = df * (y - t) 207 | //dE /dw_i,j = dE / dnet_j * x_i,j 208 | //previous: error[3][0] = (1 - Math.pow(y[kk], 2.0)) * (y[kk] - t); 209 | error[3][0] = 0.5 * y[kk] - t; 210 | for (int i = 0; i < 200; i++) 211 | { 212 | //update w_3,i,action and b_3,action 213 | dw[3][i][action[kk]] += error[3][0] * sum[2][i]; 214 | } 215 | db[3][action[kk]] += error[3][0]; 216 | 217 | 218 | 219 | //[(n-1) hidden -> (n-2) hidden] 220 | //dE / dnet_j = df(j) * sigma(error_l,k * w_l,j,k) where l is round and k is from j's downstream 221 | //dE /dw_i,j = dE / dnet_j * x_i,j 222 | for (int i = 0; i < 200; i++) 223 | { 224 | for (int j = 0; j < 200; j++) 225 | { 226 | //the statement below simply means: 227 | //temp = error[2][0] * w[3][j][action]; 228 | //error[2][j] = temp * de(sum[2][j]); 229 | 230 | error[2][j] = 
error[3][0] * w[3][j][action[kk]] * de(sum[2][i]); 231 | dw[2][i][j] += error[2][j] * sum[1][i]; 232 | } 233 | db[2][i] += error[2][i]; 234 | } 235 | 236 | 237 | 238 | //[(n-2) hidden -> (n-3) hidden] 239 | //dE / dnet_j = df(j) * sigma(error_l,k * w_l,j,k) where l is round and k is from j's downstream 240 | //dE /dw_i,j = dE / dnet_j * x_i,j 241 | for (int i = 0; i < 200; i++) 242 | { 243 | for (int j = 0; j < 200; j++) 244 | { 245 | temp = 0; 246 | for (int k = 0; k < 200; k++) 247 | temp += error[2][k] * w[2][j][k]; 248 | error[1][j] = temp * de(sum[1][i]); 249 | dw[1][i][j] += error[1][j] * sum[0][i]; 250 | } 251 | db[1][i] += error[1][i]; 252 | } 253 | 254 | 255 | 256 | //[(n-3) hidden -> input] 257 | //dE / dnet_j = df(j) * sigma(error_l,k * w_l,j,k) where l is round and k is from j's downstream 258 | //dE /dw_i,j = dE / dnet_j * x_i,j 259 | for (int i = 0; i < 64; i++) 260 | { 261 | for (int j = 0; j < 200; j++) 262 | { 263 | temp = 0; 264 | for (int k = 0; k < 200; k++) 265 | temp += error[1][k] * w[1][j][k]; 266 | error[0][j] = temp * de(sum[0][i]); 267 | dw[0][i][j] += error[0][j] * x[kk][i]; 268 | } 269 | db[0][i] = error[0][i]; 270 | } 271 | } 272 | 273 | for (int kk = 0; kk < minibatchSize; kk++) 274 | if (ct[action[kk]] > 0) 275 | { 276 | for (int i = 0; i < 200; i++) 277 | { 278 | dw[3][i][action[kk]] = (1.0 / ct[action[kk]]) * dw[3][i][action[kk]]; 279 | m[3][i][action[kk]] = 0.9 * m[3][i][action[kk]] + 0.1 * dw[3][i][action[kk]]; 280 | v[3][i][action[kk]] = 0.999 * v[3][i][action[kk]] + 0.001 * dw[3][i][action[kk]] * dw[3][i][action[kk]]; 281 | w[3][i][action[kk]] -= rate * m[3][i][action[kk]] / (Math.sqrt(v[3][i][action[kk]]) + 1e-8); 282 | dw[3][i][action[kk]] = 0; 283 | } 284 | db[3][action[kk]] = (1.0 / ct[action[kk]]) * db[3][action[kk]]; 285 | mm[3][action[kk]] = 0.9 * mm[3][action[kk]] + 0.1 * db[3][action[kk]]; 286 | vv[3][action[kk]] = 0.999 * vv[3][action[kk]] + 0.001 * db[3][action[kk]] * db[3][action[kk]]; 287 | b[3][action[kk]] -= rate 
* mm[3][action[kk]] / (Math.sqrt(vv[3][action[kk]]) + 1e-8); 288 | db[3][action[kk]] = 0; 289 | ct[action[kk]] = 0; 290 | } 291 | 292 | for (int i = 0; i < 64; i++) 293 | for (int j = 0; j < 200; j++) 294 | { 295 | dw[0][i][j] = (1 / minibatchSize) * dw[0][i][j]; 296 | m[0][i][j] = 0.9 * m[0][i][j] + 0.1 * dw[0][i][j]; 297 | v[0][i][j] = 0.999 * v[0][i][j] + 0.001 * dw[0][i][j] * dw[0][i][j]; 298 | w[0][i][j] -= rate * m[0][i][j] / (Math.sqrt(v[0][i][j]) + 1e-8); 299 | dw[0][i][j] = 0; 300 | } 301 | for (int k = 1; k <= 2; k++) 302 | for (int i = 0; i < 200; i++) 303 | for (int j = 0; j < 200; j++) 304 | { 305 | dw[k][i][j] = (1 / minibatchSize) * dw[k][i][j]; 306 | m[k][i][j] = 0.9 * m[k][i][j] + 0.1 * dw[k][i][j]; 307 | v[k][i][j] = 0.999 * v[k][i][j] + 0.001 * dw[k][i][j] * dw[k][i][j]; 308 | w[k][i][j] -= rate * m[k][i][j] / (Math.sqrt(v[k][i][j]) + 1e-8); 309 | dw[k][i][j] = 0; 310 | } 311 | for (int i = 0; i < 3; i++) 312 | for (int j = 0; j < 200; j++) 313 | { 314 | db[i][j] = (1 / minibatchSize) * db[i][j]; 315 | mm[i][j] = 0.9 * mm[i][j] + 0.1 * db[i][j]; 316 | vv[i][j] = 0.999 * vv[i][j] + 0.001 * db[i][j] * db[i][j]; 317 | b[i][j] -= rate * mm[i][j] / (Math.sqrt(vv[i][j]) + 1e-8); 318 | db[i][j] = 0; 319 | } 320 | } 321 | } -------------------------------------------------------------------------------- /src/QLearningAI.java: -------------------------------------------------------------------------------- 1 | package AI; 2 | import java.io.*; 3 | import java.util.*; 4 | import com.google.common.collect.*; 5 | import Game.*; 6 | import AI.*; 7 | import NeuralNetwork.*; 8 | public class QLearningAI extends AI 9 | { 10 | private double[] value = new double[65]; 11 | private int[] steps = new int[65]; 12 | private NeuralNetwork dqn; 13 | private HashBasedTable Q = HashBasedTable.create(); 14 | private LinkedList D = new LinkedList(); 15 | private double epsilon = 0.9, gamma = 0.99; 16 | private int maxSize = 1000, minibatchSize = 0; 17 | public 
QLearningAI(int h) 18 | { 19 | super(h); 20 | for (int i = 1; i <= 64; i++) 21 | value[i] = 0; 22 | value[1] = 0.9; 23 | value[2] = -0.6; 24 | value[3] = 0.1; 25 | value[4] = 0.1; 26 | value[5] = 0.1; 27 | value[6] = 0.1; 28 | value[7] = -0.6; 29 | value[8] = 0.9; 30 | value[9] = -0.6; 31 | value[10] = -0.8; 32 | value[11] = 0.05; 33 | value[12] = 0.05; 34 | value[13] = 0.05; 35 | value[14] = 0.05; 36 | value[15] = -0.8; 37 | value[16] = -0.6; 38 | value[17] = 0.1; 39 | value[18] = 0.05; 40 | value[19] = 0.01; 41 | value[20] = 0.01; 42 | value[21] = 0.01; 43 | value[22] = 0.01; 44 | value[23] = 0.05; 45 | value[24] = 0.1; 46 | value[25] = 0.1; 47 | value[26] = 0.05; 48 | value[27] = 0.01; 49 | value[28] = 0.01; 50 | value[29] = 0.01; 51 | value[30] = 0.01; 52 | value[31] = 0.05; 53 | value[32] = 0.1; 54 | value[33] = 0.1; 55 | value[34] = 0.05; 56 | value[35] = 0.01; 57 | value[36] = 0.01; 58 | value[37] = 0.01; 59 | value[38] = 0.01; 60 | value[39] = 0.05; 61 | value[40] = 0.1; 62 | value[41] = 0.1; 63 | value[42] = 0.05; 64 | value[43] = 0.01; 65 | value[44] = 0.01; 66 | value[45] = 0.01; 67 | value[46] = 0.01; 68 | value[47] = 0.05; 69 | value[48] = 0.1; 70 | value[49] = -0.6; 71 | value[50] = -0.8; 72 | value[51] = -0.05; 73 | value[52] = -0.05; 74 | value[53] = -0.05; 75 | value[54] = -0.05; 76 | value[55] = -0.8; 77 | value[56] = -0.6; 78 | value[57] = 0.9; 79 | value[58] = -0.6; 80 | value[59] = 0.1; 81 | value[60] = 0.1; 82 | value[61] = 0.1; 83 | value[62] = 0.1; 84 | value[63] = -0.6; 85 | value[64] = 0.9; 86 | } 87 | private double max(double a, double b) 88 | { 89 | if (a >= b) 90 | return a; 91 | return b; 92 | } 93 | private double evaluate(int[] x) 94 | { 95 | double sum = 0; 96 | for (int i = 1; i <= 8; i++) 97 | for (int j = 1; j <= 8; j++) 98 | { 99 | if (table.bigTable[i][j] != 0) continue; 100 | if (table.checkStep(i, j, hold)) 101 | sum += 0.01; 102 | if (table.checkStep(i, j, -hold)) 103 | sum -= 0.01; 104 | sum += value[(i - 1) * 8 + j] * 
table.bigTable[i][j] * hold; 105 | } 106 | return sum; 107 | } 108 | public int[][] epsilon_greedy_move(int[] now) 109 | { 110 | int action; 111 | int[] next, s1; 112 | int[][] out = new int[2][64]; 113 | double p = Math.random(), max = -100000, temp, reward = 0, reward_temp = 0, now_value = 0; 114 | int[] steps; 115 | UCTAI opponent; 116 | Integer[] st = Arrays.stream(now).boxed().toArray(Integer[]::new); 117 | Integer a; 118 | Double d; 119 | boolean flag = false; 120 | 121 | //if AI cannot move now, then jump to next moveable state 122 | now_value = evaluate(now); 123 | steps = Board.searchNext(now, hold); 124 | s1 = now; 125 | while (steps[0] == 0) 126 | { 127 | opponent = new UCTAI(s1, -hold); 128 | opponent.move(); 129 | s1 = opponent.table.getTable(); 130 | steps = Board.searchNext(s1, hold); 131 | if (Board.terminal(s1)) 132 | { 133 | flag = true; //AI is not moveable and the game is finish 134 | break; 135 | } 136 | } 137 | if (!flag) //AI can move originally or AI can move afterwards 138 | { 139 | reward_temp = evaluate(s1) - now_value; 140 | now = s1; 141 | now_value = evaluate(now); 142 | } 143 | out[0] = now; 144 | out[1] = now; 145 | st = Arrays.stream(now).boxed().toArray(Integer[]::new); 146 | 147 | //e-greedy policy 148 | if (steps[0] > 0) //AI is moveable 149 | { 150 | action = steps[1]; 151 | if (p <= epsilon) //do argmax Q(s, a) 152 | { 153 | for (int i = 1; i <= steps[0]; i++) 154 | { 155 | a = new Integer(steps[i]); 156 | if (Q.contains(st, a)) 157 | { 158 | temp = Q.get(st, a).doubleValue(); 159 | if (temp > max) 160 | { 161 | max = temp; 162 | action = steps[i]; 163 | } 164 | } 165 | else 166 | { 167 | temp = dqn.forward(now, steps[i]); 168 | d = new Double(temp); 169 | Q.put(st, a, d); 170 | if (temp > max) 171 | { 172 | max = temp; 173 | action = steps[i]; 174 | } 175 | } 176 | } 177 | } 178 | else //act randomly 179 | action = steps[(int)(Math.random() * steps[0]) + 1]; 180 | next = Board.nextState(now, action, hold); 181 | 182 | 
//opponent move 183 | opponent = new UCTAI(next, -hold); 184 | int opmove = opponent.move(); 185 | if (opmove >= 0) //opponent is moveable 186 | { 187 | next = opponent.table.getTable(); 188 | reward = evaluate(next) - now_value + reward_temp; 189 | } 190 | else //opponent is not moveable 191 | { 192 | reward = now_value + reward_temp; 193 | } 194 | out[1] = next; 195 | 196 | //store transition in memory D 197 | Transition tr = new Transition(now, next, action, reward); 198 | D.offer(tr); 199 | if (D.size() > maxSize) 200 | D.pollFirst(); 201 | } 202 | else 203 | { 204 | out[1] = now; 205 | } 206 | 207 | 208 | return out; 209 | } 210 | public double maxQ(int[] now) 211 | //max Q(s', a') 212 | { 213 | double max = -100000, temp; 214 | Integer[] st = Arrays.stream(now).boxed().toArray(Integer[]::new); 215 | Integer a; 216 | Double d; 217 | int[] nextStep = Board.searchNext(now, hold); 218 | if (nextStep[0] == 0) return 0.0; 219 | for (int i = 1; i <= nextStep[0]; i++) 220 | { 221 | a = new Integer(nextStep[i]); 222 | if (Q.contains(st, a)) 223 | { 224 | temp = Q.get(st, a).doubleValue(); 225 | if (temp > max) 226 | max = temp; 227 | } 228 | else 229 | { 230 | temp = dqn.forward(now, nextStep[i]); 231 | d = new Double(temp); 232 | Q.put(st, a, d); 233 | if (temp > max) 234 | max = temp; 235 | } 236 | } 237 | return max; 238 | } 239 | public Transition[] sampling() 240 | { 241 | Transition[] ans = new Transition[32]; 242 | int count = 0, size = D.size(), temp; 243 | if (size <= 32) 244 | { 245 | for (int i = 0; i < size; i++) 246 | ans[i] = D.get(i); 247 | minibatchSize = size; 248 | return ans; 249 | } 250 | HashSet set = new HashSet(); 251 | while (set.size() < 32) 252 | { 253 | temp = (int)(Math.random() * size); 254 | set.add(new Integer(temp)); 255 | } 256 | for (Integer i : set) 257 | { 258 | ans[count] = D.get(i.intValue()); 259 | count++; 260 | } 261 | minibatchSize = 32; 262 | return ans; 263 | } 264 | public void qLearning() 265 | { 266 | double reward, temp; 
267 | Double d; 268 | Integer a; 269 | Integer[] st; 270 | double[] y = new double[32]; 271 | int[] now, next; 272 | int[][] out = new int[2][]; 273 | int[][] input = new int[32][64]; 274 | int[] act = new int[32]; 275 | Transition[] minibatch = new Transition[32]; 276 | 277 | if (hold == 1) 278 | dqn = new NeuralNetwork("black.txt"); 279 | else if (hold == -1) 280 | dqn = new NeuralNetwork("white.txt"); 281 | 282 | for (int episode = 1; episode <= 200; episode++) 283 | { 284 | now = Board.getInit(); 285 | System.out.printf("episode: %d\n", episode); 286 | for (int T = 1; T <= 200; T++) 287 | { 288 | //Act e-greedy policy and store the transition 289 | //Then set now = next 290 | out = epsilon_greedy_move(now); 291 | now = out[0]; 292 | next = out[1]; 293 | if ((Board.terminal(next)) || (now == next)) 294 | now = Board.getInit(); 295 | else 296 | now = next; 297 | 298 | 299 | //update 300 | if (D.size() > 0) 301 | { 302 | minibatch = sampling(); //sampling minibatch 303 | 304 | for (int i = 0; i < minibatchSize; i++) 305 | { 306 | st = Arrays.stream(minibatch[i].s1).boxed().toArray(Integer[]::new); 307 | a = new Integer(minibatch[i].action); 308 | input[i] = minibatch[i].s1; 309 | act[i] = minibatch[i].action; 310 | 311 | // yi = ri for terminal s_t+1 312 | // = ri + gamma * max Q(s', a') for non-terminal s_t+1 313 | if (Board.terminal(minibatch[i].s2)) 314 | y[i] = minibatch[i].reward; 315 | else 316 | y[i] = minibatch[i].reward + gamma * maxQ(minibatch[i].s2); 317 | } 318 | 319 | //gradient descent 320 | dqn.backward(input, act, y, minibatchSize); 321 | } 322 | } 323 | if ((episode % 50 == 0) && (episode > 0)) 324 | dqn.store(); 325 | } 326 | } 327 | @Override 328 | public int move() 329 | { 330 | int[] steps = new int[65]; 331 | int step = -1; 332 | double max = -100000, temp; 333 | Integer a; 334 | Double d; 335 | double p = Math.random(); 336 | int[] now = table.getTable(); 337 | Integer[] st = Arrays.stream(now).boxed().toArray(Integer[]::new); 338 | steps = 
table.nextSteps(hold); 339 | if (steps[0] > 0) 340 | { 341 | if (steps[0] == 1) 342 | step = steps[1]; 343 | else 344 | { 345 | if (p > epsilon) 346 | step = steps[(int)(Math.random()*steps[0]) + 1]; 347 | else 348 | { 349 | for (int i = 1; i <= steps[0]; i++) 350 | { 351 | a = new Integer(steps[i]); 352 | if (Q.contains(st, a)) 353 | { 354 | temp = Q.get(st, a).doubleValue(); 355 | if (temp > max) 356 | { 357 | max = temp; 358 | step = steps[i]; 359 | } 360 | else 361 | { 362 | temp = dqn.forward(now, steps[i]); 363 | d = new Double(temp); 364 | Q.put(st, a, d); 365 | if (temp > max) 366 | { 367 | max = temp; 368 | step = steps[i]; 369 | } 370 | } 371 | } 372 | } 373 | } 374 | } 375 | table.set(step / 8 + 1, step % 8 + 1, hold); 376 | return step; 377 | } 378 | return -1; 379 | } 380 | public static void main(String[] args) 381 | { 382 | int kk = (int)(Math.random() * 2); 383 | kk = (int)Math.pow(-1, kk); 384 | /*if (kk == 1) 385 | System.out.println("Training black..."); 386 | else 387 | System.out.println("Training white...");*/ 388 | System.out.println("Training black..."); 389 | QLearningAI ai = new QLearningAI(1); 390 | ai.qLearning(); 391 | System.out.println("Finish"); 392 | } 393 | } --------------------------------------------------------------------------------