近几年来,随着谷歌DeepMind的AlphaGo(阿尔法狗)和AlphaGo Zero(阿尔法元)的问世,蒙特卡洛树搜索(MCTS)作为一种不需要特定领域先验知识的搜索算法,逐渐受到人们的重视。它不需要任何领域知识,只需了解模拟规则和终止状态,就能得到非常好的策略。但由于其搜索带有一定的盲目性,运行时间和内存需求成为衡量其性能的主要因素之一。随着计算机计算能力的提升,对于一些状态空间较小的特定问题,MCTS已经可以表现得相当出色。
五子棋是初学者熟悉MCTS的常用用例。
图中红色代表模拟结果中下一步棋子的可能落子位置;红色越不透明,表示该位置的评估得分越高,即获胜可能性越大(不透明度由该位置得分与最高得分之比的平方决定)。
参数:
1.棋盘维度:7*7
2.根节点模拟次数:一百万次
3.每次运行时间:约40秒
4.赢棋赋分:1分,输棋赋分:-1分
5. Cp = 2
【目前无做任何优化,尚有很大性能优化空间,请查看后续版本】
基本呈现对称分布
下图中,白棋的落子位置明显既能堵住黑棋,又能与已有白棋连成一条线
黑棋落子位置为两白棋之间,同时与已有黑棋连线
白棋的候选落子位置为黑棋三子的两端,并选择了直接堵死这三个子的一端
这里还存在bug,白棋没有选择挡住左面即将连成四子的黑棋
黑棋也没有选择把即将连成四子的棋形连成四子,因此怀疑判断输赢的函数出现了问题
黑棋发现已有棋面无法继续,故在四周拓展
白棋阻挡黑棋赢棋
出现了bug,所有得分均为负分。
代码尚存在一些bug,暂时开源如下,在后续的版本中修改。
代码如下
1. MCTSwuziqi.java:主程序,负责创建并显示棋盘窗口
2. DrawChessBoard.java:负责绘制棋盘、处理鼠标事件监听,并调用MCTS
3.MCTS_01.java 负责进行MCTS模拟
MCTSwuziqi.java
package bwjiang; public class MCTSwuziqi { public static void main(String[] args) { // TODO Auto-generated method stub System.out.println("main init"); DrawChessBoard chessBoard = new DrawChessBoard(); //chessBoard.boardFrame.setVisible(true); //System.out.println(chessBoard.boardFrame.getChessmans()[0][0].getColor()); } }
DrawChessBoard.java
package bwjiang; import javax.imageio.ImageIO; import javax.swing.*; import javax.swing.border.*; import java.awt.*; import java.awt.event.*; import java.awt.geom.*; import java.io.*; public class DrawChessBoard { public static final int BLACKWIN = 1; public static final int WHITEWIN = 2; public static final int NOTWIN = 0; public static final int ALLFILLED = -1; public static final int BLACK = 1; public static final int NOCHESS = 0; public static final int WHITE = -1; public static final int WINLENGTH = 5; public BoardFrame boardFrame; public int rows = 9; JTextField rowsText; JLabel nowChessColor; public DrawChessBoard() { //棋盘窗体 this.boardFrame = new BoardFrame(); boardFrame.setVisible(true); } //棋盘内容 class BoardPanel extends JPanel implements MouseListener{ public Image boardImage;//棋盘边框 public int lastChessColor = WHITE; //记录全部落子的棋子类,第一项为行数,第二项为列数 public int[][] chessmans = new int[rows][rows]; public int[][] predictChessmans = new int[rows][rows]; public int maxReward = 0; public int nextChessColor = BLACK;//控制交换棋权 int FrameWidth;//窗体 int FrameHeight; int chessBoardX;//棋盘边框左上角 int chessBoardY; int realChessBoardX;//真正落子的左上角 int realChessBoardY; int deltaX;//间距 int deltaY; public int[][] getChessmans(){ return this.chessmans; } //棋盘边框 public BoardPanel() { try { boardImage = ImageIO.read(new File("res/boardFrame.jpg")); } catch (IOException e) { System.out.print("error: boardImage not exist"); return; } addMouseListener(this); } @Override //棋盘内部的画线与棋子 protected void paintComponent(Graphics g) { super.paintComponent(g); int imageWidth = boardImage.getWidth(this); int imageHeight = boardImage.getHeight(this); FrameWidth = getWidth(); FrameHeight = getHeight(); //将图片呈现在画布中间 chessBoardX = (FrameWidth - imageWidth)/2; chessBoardY = (FrameHeight - imageHeight)/2; g.drawImage(boardImage, chessBoardX, chessBoardY, null); //画出棋盘线 int margin = 40; deltaX = (imageWidth - 2*margin)/(rows-1); deltaY = (imageHeight - 2*margin)/(rows-1); realChessBoardX = 
chessBoardX+(imageWidth-deltaX*(rows-1))/2; realChessBoardY = chessBoardY+(imageHeight-deltaY*(rows-1))/2; for(int i=0; i<rows; i++) {//横线 g.drawLine(realChessBoardX, realChessBoardY+deltaY*i, realChessBoardX+deltaX*(rows-1), realChessBoardY+deltaY*i); } for(int i=0; i<rows; i++) {//竖线 g.drawLine(realChessBoardX+deltaX*i, realChessBoardY, realChessBoardX+deltaX*i, realChessBoardY+deltaY*(rows-1)); } //画棋子 int radius = (int)(deltaX*0.3); for(int col_i=0; col_i<rows; col_i++) { for(int row_j=0; row_j<rows; row_j++) { if(chessmans[col_i][row_j]!=NOCHESS) { if(chessmans[col_i][row_j] == WHITE) { g.setColor(Color.white); g.fillOval(realChessBoardX - radius + deltaX*col_i, realChessBoardY - radius + deltaY*row_j, 2*radius, 2*radius); }else if(chessmans[col_i][row_j] == BLACK) { g.setColor(Color.black); g.fillOval(realChessBoardX - radius + deltaX*col_i, realChessBoardY - radius + deltaY*row_j, 2*radius, 2*radius); } } if(predictChessmans[col_i][row_j] != 0) { if(predictChessmans[col_i][row_j] < 0) { predictChessmans[col_i][row_j] =0; } System.out.println(predictChessmans[col_i][row_j] + "--"+maxReward); float pingfangcof = (predictChessmans[col_i][row_j]/(float)maxReward)*(predictChessmans[col_i][row_j]/(float)maxReward); System.out.println(pingfangcof); g.setColor(new Color(255,0,0,(int)(255*pingfangcof))); g.fillOval(realChessBoardX - radius + deltaX*col_i, realChessBoardY - radius + deltaY*row_j, 2*radius, 2*radius); } } } System.out.println("更新棋盘状态"); } @Override //控制落子 public void mouseClicked(MouseEvent e) { // TODO Auto-generated method stub if(e.getButton() == MouseEvent.BUTTON1){ int x = e.getX(); //得到鼠标x坐标 int y = e.getY(); //得到鼠标y坐标 //检验是否在棋盘内 if(x>chessBoardX&&x<FrameWidth-chessBoardX&&y>chessBoardY&&y<FrameHeight-chessBoardY) { int col_i = (x-realChessBoardX+deltaX/2)/deltaX; int row_j = (y-realChessBoardY+deltaY/2)/deltaY; //检验当前位置是否有棋 if(chessmans[col_i][row_j] == NOCHESS) { if(nextChessColor==BLACK) { System.out.print("黑棋 "); nowChessColor.setText("白棋"); 
chessmans[col_i][row_j] = BLACK; nextChessColor = WHITE; }else if(nextChessColor==WHITE) { System.out.print("白棋 "); nowChessColor.setText("黑棋"); chessmans[col_i][row_j] = WHITE; nextChessColor = BLACK; } String banner = "鼠标当前点击位置的坐标是" + x + "," + y+" 对应棋盘坐标为"+col_i+","+row_j+" "; System.out.print(banner); int isWin = checkTerminal(chessmans); System.out.print("isWin:"+isWin); repaint(); ////////////////////////////////// //调用MCTS算法 //传回来一个记录每个位置落子概率的矩阵 MCTS_01 MCTS = new MCTS_01(chessmans); System.out.println("行走策略"+MCTS.getProbabilityMap()); int[][] tempMap = new int[rows][rows]; tempMap = MCTS.getProbabilityMap(); maxReward = 0; for(int i=0; i<rows; i++) { for(int j=0; j<rows; j++) { System.out.print(" "+-tempMap[i][j]); predictChessmans[i][j] = -tempMap[i][j]; if(-tempMap[i][j]>maxReward) { maxReward = -tempMap[i][j]; } } System.out.println(); } ////////////////////////////////// }else { System.out.println("当前位置有棋"); } }else { System.out.println("鼠标位置非法"); } } } @Override public void mouseEntered(MouseEvent arg0) { // TODO Auto-generated method stub } @Override public void mouseExited(MouseEvent arg0) { // TODO Auto-generated method stub } @Override public void mousePressed(MouseEvent e) { // TODO Auto-generated method stub } @Override public void mouseReleased(MouseEvent arg0) { // TODO Auto-generated method stub } public void refreshChessBoard() { if(Integer.parseInt(rowsText.getText())>30 || Integer.parseInt(rowsText.getText())<5) { System.out.println("棋盘维度错误"); return; } for(int row_i=0; row_i<rows; row_i++) { for(int col_j=0; col_j<rows; col_j++) { chessmans[row_i][col_j] = NOCHESS; } } nowChessColor.setText("黑棋"); nextChessColor = BLACK; rows = Integer.parseInt(rowsText.getText()); chessmans = new int[rows][rows]; predictChessmans = new int[rows][rows]; repaint(); } public void restart() { for(int row_i=0; row_i<rows; row_i++) { for(int col_j=0; col_j<rows; col_j++) { chessmans[row_i][col_j] = NOCHESS; predictChessmans[rows][rows] = NOCHESS; } } 
nowChessColor.setText("黑棋"); nextChessColor = BLACK; repaint(); } } class BoardFrame extends JFrame{ private BoardPanel boardPanel; JButton refreshBoard; JButton restartBoard; public int[][] getChessmans(){ return boardPanel.chessmans; } public BoardFrame() { System.out.println("BoardFrame init"); setTitle("MCTS五子棋"); setSize(800,800); Container containerPane = getContentPane(); JPanel BoardInfo = new JPanel(); BoardInfo.setLayout(new FlowLayout(FlowLayout.CENTER)); BoardInfo.setSize(800, 100); JLabel rowsLabel = new JLabel("棋盘维度"); rowsText = new JTextField(5); rowsText.setText("19"); refreshBoard = new JButton("修改棋盘维度"); restartBoard = new JButton("重新开始"); JLabel nowChessColorLabel = new JLabel("当前轮到:"); nowChessColor = new JLabel("黑棋"); BoardInfo.add(rowsLabel); BoardInfo.add(rowsText); BoardInfo.add(refreshBoard); BoardInfo.add(restartBoard); BoardInfo.add(nowChessColorLabel); BoardInfo.add(nowChessColor); BoardInfo.setBorder(new TitledBorder("棋盘信息栏")); containerPane.add(BoardInfo, BorderLayout.NORTH); boardPanel = new BoardPanel(); containerPane.add(boardPanel); AddActionListener(); } private void AddActionListener() { refreshBoard.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { boardPanel.refreshChessBoard(); } }); restartBoard.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { boardPanel.restart(); } }); } } class Chessman{ private int color; //WHITE=1;BLACK=0 private boolean placed; public Chessman(int color, boolean placed) { this.color = color; this.placed = placed; } public boolean getPlaced() { return placed; } public void setPlaced(boolean placed) { this.placed = placed; } public int getColor() { return color; } } //检查当前结点是否为终端结点 public int checkTerminal(int[][] chessmans) { int totalExistChessman = 0; for(int row_i=0; row_i<rows; row_i++) { for(int col_j=0; col_j<rows; col_j++) { if(chessmans[row_i][col_j] != NOCHESS) { System.out.print("("+col_j+","+row_i+") "); int firstColor = 
chessmans[row_i][col_j]; int[] distance = new int[4]; //向右判断(包含向左) if(col_j+WINLENGTH-1<rows) { while(firstColor == chessmans[row_i][col_j+distance[0]]) { distance[0]++; } } //向下判断(包含向上) if(row_i+WINLENGTH-1 < rows) { while(firstColor == chessmans[row_i+distance[1]][col_j]) { distance[1]++; } } //向右上判断(包含左下) if(row_i-WINLENGTH+1 >= 0 && col_j+WINLENGTH-1 < rows) { while(firstColor == chessmans[row_i-distance[2]][col_j+distance[2]]) { distance[2]++; } } //向右下判断(包含左上) if(row_i+WINLENGTH-1 < rows && col_j+WINLENGTH-1 < rows) { while(firstColor == chessmans[row_i+distance[3]][col_j+distance[3]]) { distance[3]++; } } for(int distance_num=0; distance_num<4;distance_num++) { if(WINLENGTH == distance[distance_num]) { if(firstColor == BLACK) { return BLACKWIN; }else { return WHITEWIN; } } } totalExistChessman++; } } } if(rows*rows == totalExistChessman) { return ALLFILLED; } return NOTWIN; } }
MCTS_01.java
package bwjiang; import bwjiang.DrawChessBoard.Chessman; public class MCTS_01 { public static final int BUDGET = 1000000; public static final int BLACKWIN = 1; public static final int WHITEWIN = 2; public static final int NOTWIN = 0; public static final int ALLFILLED = -1; public static final int BLACK = 1; public static final int NOCHESS = 0; public static final int WHITE = -1; public static final int ERROR = -999; public static final int WINLENGTH = 5; public static final int UCTCP = 2; public int rows; Node rootNode; public MCTS_01(int[][] chessmans) { rows = chessmans.length; rootNode = new Node(null, chessmans); // System.out.println(); // for(int col_i=0; col_i<row; col_i++) { // for(int row_j=0; row_j<row; row_j++) { // System.out.print(" "+probabilityMap[col_i][row_j]+"-"+accessNumberMap[col_i][row_j]); // } // System.out.println(); // } while (rootNode.accessNumber<BUDGET) { Node node = treePolicy(rootNode); int reward = defaultPolicy(node); backup(node, reward); } //return bestChild(rootNode); } class Node { public int[][] chessmans;//当前结点棋盘状态 public int[][] rewardMap;//当前棋盘状态结点,落子的概率图 public int[][] accessNumberMap;//当前棋盘状态结点,访问数量 public int accessNumber;//当前节点总访问数 public int reward; //下一节点集合 public Node[][] nextNode; public Node lastNode; //本节点的最后一个棋子 public int[] chessYX; //下一节点颜色 int nextColor; public Node(Node lastNode, int[][] chessmans) { this.chessmans = chessmans; this.lastNode = lastNode; this.nextColor = nextChessColor(chessmans); nextNode = new Node[rows][rows]; accessNumber = 0; reward = 0; rewardMap = new int[rows][rows]; accessNumberMap = new int[rows][rows]; chessYX = new int[2]; } public void setChessYX(int Y, int X) { this.chessYX[0] = Y; this.chessYX[1] = X; } } public int[][] getProbabilityMap() { return rootNode.rewardMap; } public Node treePolicy(Node node) { while(node != null) { for(int col_i=0; col_i<rows; col_i++) { for(int row_j=0; row_j<rows; row_j++) { //System.out.println("["+col_i+","+row_j+"]"); if(null == 
node.nextNode[col_i][row_j] && NOCHESS == node.chessmans[col_i][row_j]) { //System.out.println("hi"); return expand(node, col_i, row_j); } } } node = bestChild(node); } return node; } //扩张下一个为扩展的动作 public Node expand(Node node, int Y, int X) { int[][] tempChessmans = new int[rows][rows]; for(int col_i=0; col_i<rows; col_i++) { for(int row_j=0; row_j<rows; row_j++) { tempChessmans[col_i][row_j] = node.chessmans[col_i][row_j]; //System.out.print(tempChessmans[col_i][row_j]); } //System.out.println(); } int tempNextColor = node.nextColor; tempChessmans[Y][X] = tempNextColor; node.nextNode[Y][X] = new Node(node, tempChessmans); node.nextNode[Y][X].setChessYX(Y, X); return node.nextNode[Y][X]; } //UCT部分精髓 public Node bestChild(Node node) { double maxUCT = 0; int maxUCTcol_i = -1; int maxUCTrow_j = -1; for(int col_i=0; col_i<rows; col_i++) { for(int row_j=0; row_j<rows; row_j++) { if(NOCHESS == node.chessmans[col_i][row_j]) { double tempMaxUCT = node.rewardMap[col_i][row_j]/node.accessNumberMap[col_i][row_j] + UCTCP * Math.sqrt(2*Math.log(node.accessNumber)/node.accessNumberMap[col_i][row_j]); if(tempMaxUCT > maxUCT) { maxUCT = tempMaxUCT; maxUCTcol_i = col_i; maxUCTrow_j = row_j; } } } } if(maxUCTcol_i == -1 || maxUCTrow_j == -1) { System.out.println("error:没找到极大值UCT"); for(int col_i=0; col_i<rows; col_i++) { for(int row_j=0; row_j<rows; row_j++) { if(NOCHESS == node.chessmans[col_i][row_j]) { double tempMaxUCT = node.rewardMap[col_i][row_j]/node.accessNumberMap[col_i][row_j] + UCTCP * Math.sqrt(2*Math.log(node.accessNumber)/node.accessNumberMap[col_i][row_j]); System.out.println("error:"+tempMaxUCT+"="+node.rewardMap[col_i][row_j]+"/"+node.accessNumberMap[col_i][row_j] +"+"+ UCTCP +"*" + "Math.sqrt(2*"+Math.log(node.accessNumber)+"/"+node.accessNumberMap[col_i][row_j]+")"); } } } } return node.nextNode[maxUCTcol_i][maxUCTrow_j]; } //默认策略 public int defaultPolicy(Node node) { //System.out.println("默认"); ///这里不知道是否需要把chessmans复制一下 int nextColor = node.nextColor; int[][] 
tempChessmans = new int[rows][rows]; for(int col_i=0; col_i<rows; col_i++) { for(int row_j=0; row_j<rows; row_j++) { tempChessmans[col_i][row_j] = node.chessmans[col_i][row_j]; } } while(NOTWIN == checkTerminal(tempChessmans)) { int[] nextYX = new int[2]; nextYX = randomNext(tempChessmans); tempChessmans[nextYX[0]][nextYX[1]] = nextColor; if(nextColor == BLACK) { nextColor = WHITE; }else{ nextColor = BLACK; } } return stateReward(tempChessmans, node.nextColor); } //对于该棋盘状态的奖励 public int stateReward(int[][] chessmans, int nextColor) { final int win = 2; final int loss = 0; final int flat = 1; if(checkTerminal(chessmans) == BLACKWIN) { if(nextColor == BLACK) { return win; }else { return loss; } }else if(checkTerminal(chessmans) == WHITEWIN){ if(nextColor == WHITE) { return win; }else { return loss; } }else { return flat; } } //奖励回传 public void backup(Node lastNode, int reward) { //me为1,him为-1 int meOrHim = 1; //最后一个结点不需要记录其子节点的访问 lastNode.accessNumber++; lastNode.reward +=meOrHim*reward; int[] chessYX = lastNode.chessYX; lastNode = lastNode.lastNode; while(lastNode != null) { //最后一个结点之前的所有节点都要记录其子节点的访问 lastNode.accessNumberMap[chessYX[0]][chessYX[1]]++; lastNode.accessNumber++; lastNode.rewardMap[chessYX[0]][chessYX[1]] += meOrHim*reward; if(1 == meOrHim) { meOrHim = -1; }else { meOrHim = 1; } lastNode.reward += meOrHim*reward; //System.out.println(meOrHim); chessYX = lastNode.chessYX; lastNode = lastNode.lastNode; } } //返回当前棋盘的下一个棋子的颜色 public int nextChessColor(int[][] chessmans) { int whiteNumber = 0; int blackNumber = 0; // for(int col_i=0; col_i<rows; col_i++) { // for(int row_j=0; row_j<rows; row_j++) { // // System.out.print(chessmans[col_i][row_j]); // } // System.out.println(); // } // System.out.println(); for(int col_i=0; col_i<rows; col_i++) { for(int row_j=0; row_j<rows; row_j++) { if(chessmans[col_i][row_j] == BLACK) { //System.out.print("@"); blackNumber++; } if(chessmans[col_i][row_j] == WHITE) { //System.out.print("o"); whiteNumber++; } } } 
if(blackNumber-whiteNumber == 1) { //System.out.print("@"); return WHITE; }else if(blackNumber-whiteNumber == 0) { //System.out.print("o"); //System.out.println("白黑一样,下一个黑"); return BLACK; }else { System.out.println("error:棋盘棋子不同颜色数量差值不为1或0,为 黑"+blackNumber+" 白"+whiteNumber); return ERROR; } } //按照均等概率随机返回一个当前棋盘的落子位置的xy坐标 public int[] randomNext(int[][] chessmans) { int[] nextYX = new int[2]; int totalNumber = 0; for(int col_i=0; col_i<rows; col_i++) { for(int row_j=0; row_j<rows; row_j++) { if(chessmans[col_i][row_j] != NOCHESS) { totalNumber++; } } } int remainNumber = rows*rows - totalNumber; //System.out.print("随机生成新棋子:"+remainNumber+" = "+rows+"*"+rows+"-"+totalNumber); int randomNumber = (int)(Math.random()*remainNumber); //System.out.print(randomNumber); int nowNumber = 0; for(int col_i=0; col_i<rows; col_i++) { for(int row_j=0; row_j<rows; row_j++) { if(chessmans[col_i][row_j] == NOCHESS) { if(nowNumber == randomNumber) { //System.out.println("插入新棋子"+"("+col_i+","+row_j+")"); nextYX[0] = col_i; nextYX[1] = row_j; return nextYX; } nowNumber++; } } } System.out.println("error:随机数超出总剩余棋子数"); return nextYX; } //检查当前结点是否为终端结点 public int checkTerminal(int[][] chessmans) { int totalExistChessman = 0; for(int col_i=0; col_i<rows; col_i++) { for(int row_j=0; row_j<rows; row_j++) { if(chessmans[col_i][row_j] != NOCHESS) { //System.out.print("("+row_j+","+col_i+") "); int firstColor = chessmans[col_i][row_j]; int[] distance = new int[4]; //向右判断(包含向左) if(row_j+WINLENGTH-1<rows) { while(distance[0]<5 && firstColor == chessmans[col_i][row_j+distance[0]]) { distance[0]++; } } //向下判断(包含向上) if(col_i+WINLENGTH-1 < rows) { while(distance[1]<5 && firstColor == chessmans[col_i+distance[1]][row_j]) { distance[1]++; } } //向右上判断(包含左下) if(col_i-WINLENGTH+1 >= 0 && row_j+WINLENGTH-1 < rows) { while(distance[2]<5 && firstColor == chessmans[col_i-distance[2]][row_j+distance[2]]) { distance[2]++; } } //向右下判断(包含左上) if(col_i+WINLENGTH-1 < rows && row_j+WINLENGTH-1 < rows) { 
while(distance[3]<5 && firstColor == chessmans[col_i+distance[3]][row_j+distance[3]]) { distance[3]++; } } for(int distance_num=0; distance_num<4;distance_num++) { if(WINLENGTH == distance[distance_num]) { if(firstColor == BLACK) { return BLACKWIN; }else { return WHITEWIN; } } } totalExistChessman++; } } } if(rows*rows == totalExistChessman) { return ALLFILLED; } return NOTWIN; } }