蒙特卡洛树搜索算法(实现 TicTacToe 机-机对战)

                     

贡献者: int256

预备知识 蒙特卡洛树搜索算法(理论)

   本文使用了 python 3.7 中提供的特性(future 中的 annotations)与 numpy 库。使用版本 pytohn 3.12.2 与 numpy 1.26.4 可以确保无问题运行。全部代码可以在笔者的 Github 仓库查看: https://github.com/tripleInt/MCTS-TTT/,或 Pastebin https://paste.ubuntu.com/p/NQtdDSx3sY/

   前面已经讨论过蒙特卡洛树搜索算法的理论,下面通过讲解例题进行实战练习,这利于我们更深入地理解这算法。首先回顾例题:

例 1 

   使用蒙特卡洛树搜索算法实现一个机器 vs. 机器的 Tic-Tac-Toe 井字棋对战游戏。Tic-Tac-Toe 有以下规则:

  1. 井字棋棋盘是 $n \times n$ 的正方形网格棋盘,例如下面是一个有一些棋子的 $7 \times 7$ 的棋盘:
    图
    图 1:棋盘例
  2. 游戏的下法是有 $ \boldsymbol\times $、$\bigcirc$ 两方,每次都可以在没有棋子的正方形内部落子。输赢定义为:最先有连续 $m$ 个我方棋子出现的一方获胜。
  3. 连续的定义是:横向、纵向,或两个 $45^\circ$ 对角线方向。

   你需要使得以下内容是可以自定义的:

  1. $n$,即棋盘大小可以自定义。
  2. $m$,即输赢的(棋子连续数)条件可以自定义。
  3. 先下棋的一方可以自定义。也就是谁第一步下棋可以自定义。

   这是一个工程问题,不用考虑时间限制。你需要输出机器对战的中间过程、最终的结果棋盘与谁赢得了这场对战。

   在实现蒙特卡洛树搜索这个算法前,我们需要先做一些准备,定义好游戏的各种内容类。

1. 预备工作——游戏规则相关

   因为要处理连续问题,我们可以使用 “求和” 的方式检查,故可以取特殊值 $1$ 与 $-1$ 表示两种棋子 $\bigcirc$ 和 $ \boldsymbol\times $。定义棋子类表示棋子:

class Chess:
    def __init__(self, name: str, val: int) -> None:
        self.name: str = name # X/O
        self.val: int = val

    def __repr__(self) -> str:
        return f"Chess Object({str(self.name)}, {str(self.val)})"

   其中 __init__ 方法相当于是构造函数,__repr__ 方法提供了一个将类转化为 str 的方式。具体原理是利用了 python 的魔法方法。

   然后对于每一步操作,我们可以考虑为是落点与棋子类型的组合。故定义一个操作类 Move

# 操作
class Move:
    def __init__(self, x: int, y: int, chess: Chess) -> None:
        self.x: int = x
        self.y: int = y
        self.chess: Chess = chess

    def __repr__(self) -> str:
        return (
            "Move Object[("
            + str(self.x)
            + ", "
            + str(self.y)
            + ") , "
            + str(self.chess)
            + "]"
        )
然后就可以定义以下常量便于我们在后面使用:
STATUS = {0: " ", 1: "O", -1: "X"}
X = Chess(STATUS[-1], -1)
O = Chess(STATUS[1], 1)
其中 STATUS 常量存储了后面棋盘中每个位置的数字代表这个位置的状态的情况,XO 分别对应 $ \boldsymbol\times $、$\bigcirc$ 两种棋子。

2. 搜索的状态——棋盘

   显然对于这个问题来说,搜索的状态应该是当前棋盘的情况。我们定义一个表示状态的类 State 并实现一些方法帮助我们在后面进行搜索。 首先考虑其构造函数,需要记录的信息,显然有当前棋盘的情况(使用 numpy 提供的 np.array 来表示)、下一步应当哪方下棋。我们额外开辟一个属性用来记录需要多少连续棋子可以赢得这场游戏。故可以写出 __init__ 方法:

class State:
    def __init__(self, nxtMove,
                checkerboardStat: np.array,
                winNeed: int = -1) -> None:
        """
        Args:
                nxtMove: 接下来该谁下棋了
                checkerboardStat (2 D 网格棋盘):
                    棋盘状态
                winNeed (int, optional):
                    连续多少个棋子可获得胜利. Defaults to -1.
        """
        if len(checkerboardStat.shape) != 2:
            raise Exception(
                "checkerboardStat must be 2D array")

        if (checkerboardStat.shape[0] !=
            checkerboardStat.shape[1]):
            raise Exception(
                "checkerboardStat must be square")

        self.checkerboard: np.array = checkerboardStat
        if winNeed == -1:
            winNeed = self.checkerboard.shape[0]
        self.winNeed = winNeed
        self.nxtMove: Chess = nxtMove
在声明属性的时候尽量使用 “属性名: 类型=值” 的方法,这有助于我们后续实现代码(这样一般的编辑器可以为我们更好地提示代码补全)。

   我们经常会需要获取棋盘的形状(即大小),故再定义一个属性用来表示棋盘大小。这里使用装饰器 @property,这类似于定义了一个属性的 get

@property
def checkerboardSize(self):
    return self.checkerboard.shape[0]

   接下来考虑到我们需要判断当前游戏局势(判断游戏是否结束),故类似的定义一个 @property 装饰的 result 表示当前局势与 isOver 表示是否结束(已经不能继续下棋):

@property
def result(self):
    """判断游戏结果

    Returns:
            Chess | 0 | None: 若返回 0 代表游戏平局
                        返回 None 表示游戏还未结束,
                        否则返回 X/O(Chess 对象) 表示赢家。
    """
    # 横纵连续
    for i in range(self.checkerboardSize - self.winNeed + 1):
        xSum = np.sum(self.checkerboard[i : i + self.winNeed, :], axis=0)
        ySum = np.sum(self.checkerboard[:, i : i + self.winNeed], axis=1)

        if xSum.min() == -self.winNeed or ySum.min() == -self.winNeed:
            return X
        if xSum.max() == self.winNeed or ySum.max() == self.winNeed:
            return O

    # 对角线连续
    for i in range(self.checkerboardSize - self.winNeed + 1):
        for j in range(self.checkerboardSize - self.winNeed + 1):
            subCheckerboard = self.checkerboard[
                i : i + self.winNeed, j : j + self.winNeed
            ]
            # 两条斜向对角线
            diag1Sum, diag2Sum = (
                subCheckerboard.trace(),
                np.fliplr(subCheckerboard).trace(),
            )

            if diag1Sum == -self.winNeed or diag2Sum == -self.winNeed:
                return X
            if diag1Sum == self.winNeed or diag2Sum == self.winNeed:
                return O

    if np.all(self.checkerboard != 0):
        # 平局
        return 0

    # 游戏还未结束
    return None

@property
def isOver(self) -> bool:
    """游戏是否结束

    Returns:
            bool: 游戏是否结束
    """
    return self.result is not None
这里使用了一个小技巧,对角线对应矩阵的迹。同时使用 np.minnp.max 来帮助我们通过求和解决判断是否有棋子连续到足够个数的一方。

走法

   接下来需要考虑下棋操作。首先考虑搜索需要用到当前的所有可能的走法,故这里编写一个方法来实现这个功能:

def getMoves(self) -> List[Move]:
    """获取所有可能的走法

    Returns:
            List[Move]
    """
    return [
        Move(d[0], d[1], self.nxtMove)
        for d in list(zip(*np.where(self.checkerboard == 0)))
    ]

   然后考虑需要判断某种走法对当前的棋盘来说是否合法,由位置(是否在棋盘内、该位置是否有棋子)和这种走法下棋的一方是否是State 记录下来的将要下棋的一方共同决定:

def couldMove(self, move: Move) -> bool:
    """判断走法是否合法

    Args:
            move (Move)

    Returns:
            bool: 是否合法
    """
    if move.chess != self.nxtMove:
        # 下棋的一方不对
        return False

    if not (
        0 <= move.x < self.checkerboardSize and
        0 <= move.y < self.checkerboardSize
    ):
        # 位置不合法
        return False

    # 这位置还没有棋子
    return self.checkerboard[move.x, move.y] == 0

   以及需要根据走法获取一个在当前棋盘进行该走法后的下一状态,这也可以编写一个方法来实现,这里需要注意是返回一个新的 State 对象,用到了 python 3.7 的 future 库中的特性。同时要分清什么时候使用 `self.checkerboard`,什么时候是更新返回的 newCheckerboard

# def doMove(self, move):
def doMove(self, move: Move) -> State:  # Python 3.7 need(PEP 563)
    if not self.couldMove(move):
        raise Exception("Move must be legel")

    newCheckerboard = self.checkerboard.copy()
    newCheckerboard[move.x, move.y] = move.chess.val

    if self.nxtMove == X:
        nxtMove = O
    elif self.nxtMove == O:
        nxtMove = X

    # return type(self)(nxtMove, newCheckerboard, self.winNeed)

    # Python 3.7 need(PEP 563)
    return State(nxtMove, newCheckerboard, self.winNeed)

可视化/输出

   最后实现一个方法来输出当前棋盘状态(这是题目要求的):

def show(self, outputFn: function = print) -> None:
    """显示当前棋盘状态

    Args:
            outputFn (function, optional):
                输出的函数. Defaults to print.
    """

    board = np.copy(self.checkerboard)

    def strLines(r):
        return (" "
                + " | ".join(map(
                    lambda x: STATUS.get(int(x), " "), r))
                + " ")

    for r in board[:-1]:
        outputFn(strLines(r))
        outputFn("-" * (len(r) * 4 - 1))

    outputFn(strLines(board[-1]))
    outputFn()
对于不同的电脑环境、不同的终端、不同的字体可能每个字符的长度不同,需要适当调整这里的第 $19$ 行。

3. 蒙特卡洛树搜索——结点

   接下来考虑蒙特卡洛树的每个结点。将其定义为一个类 MCTSNode

# 结点
class MCTSNode:
    def __init__(self, stat, fa=None):
        """
        Args:
                stat (State): 结点对应状态
                fa (MCTSNode, optional): 父结点. Defaults to None.
        """
        self.stat: State = stat
        self.fa: MCTSNode = fa

        # 子结点列表
        self.sons: List[MCTSNode] = []

        self._visits = 0  # 已访问过结点
        self._results = {}
        self._notTried = None
    
    @property
    def isFullyExpanded(self):
        return len(self.notTried) == 0
            
    @property
    def notTried(self):
        if self._notTried is None:
            self._notTried = self.stat.getMoves()

            # 通过打乱实现“随机”
            shuffle(self._notTried)
        return self._notTried
    
    @property
    def isEnd(self) -> bool:
        """是否是终端结点(叶子结点)
        """
        return self.stat.isOver
其中 shuffle 可以直接使用 random 库提供的 random.shuffle

UCB

   接下来是实现蒙特卡洛树搜索算法中的通过 UCB 选择子节点扩展:

def bestSon(self, c=1.5):
    return self.sons[
        np.argmax(
            [
                (nod.q / nod.n) + c* np.sqrt((2 * np.log(self.n)) / nod.n)
                for nod in self.sons
            ]
        )
    ]

@property
def q(self):
    v = self.fa.stat.nxtMove.val
    return (
        self._results.get(v, 0)
        - self._results.get(-1 * v, 0)
    )

@property
def n(self):
    return self._visits
其中 q 就是 UCB 公式式 1 中的 $N_r$。这里不进行 $+1$ 与 $\times 2$ 的修正(实际效果与修正后是一样的)。

   需要注意这里的 q 应当使用 dict.get 的方式,因为有可能还没有更新过,否则就要使用 defaultdict 来定义 self._results

各种操作

   下面实现蒙特卡洛搜索需要的各种操作(Expand、Simulation 对应 Rollout,Back Propagate)。

def expand(self) -> MCTSNode:
    stat = self.stat.doMove(self.notTried.pop())

    son: MCTSNode = MCTSNode(stat, self)
    self.sons.append(son)
    return son

def rollout(self):
    stat: State = self.stat
    while not stat.isOver:
        stat = stat.doMove(np.random.choice(stat.getMoves()))
    return stat.result

def backpropagate(self, result):
    self._visits += 1
    self._results[result] = self._results.get(result, 0) + 1
    if self.fa is not None:
        self.fa.backpropagate(result)
这里反向传播使用了尾递归的方式,也可以使用 while 循环的方法。

   同样需要注意这里 backpropagate 也使用的是 dict.get 方法。

4. 蒙特卡洛树搜索

   最后是实现搜索树。对于一棵树我们往往只需要记录根节点就获得了所有信息。

class MCTS(object):
    def __init__(self, rootNod: MCTSNode):
        """蒙特卡洛树

        Args:
                rootNod (MCTSNode): 根结点
        """
        self.rootNod: MCTSNode = rootNod

   接下来是实现蒙特卡洛树搜索中的操作。

选择扩展的结点

   可以直接根据前面的 MCTSNode.bestSon() 方法获得到要选择的结点,一直 while 循环到叶子结点就可以了:

def chooseNod(self) -> MCTSNode:
    """选择要扩展的结点

    Returns:
            MCTSNode
    """

    cur = self.rootNod
    
    # 递归到叶子结点并返回
    while not cur.isEnd:
        if (not cur.isFullyExpanded):
            return cur.expand()
        else:
            cur = cur.bestSon()
    return cur

搜索当前结点,找最好的操作

   根据 UCB 算法进行搜索扩展找最好的操作。

def bstAction(self, simulationTimes: int = None, duration: float = None):
    """根据 UCB 算法进行搜索扩展,找到最佳操作

    Args:
            simulationTimes (int, optional): 
                为找到最佳操作已经模拟的次数. Defaults to None.
            duration (float, optional): 
                算法搜索的时间(秒). Defaults to None.
    """

    if simulationTimes is None:
        if duration is None:
            raise Exception("duration must be set")
        endTime: float = time.time() + duration
        while time.time() <= endTime:
            nod = self.chooseNod()
            nod.backpropagate(nod.rollout())
    else:
        for _ in range(simulationTimes):
            nod = self.chooseNod()
            nod.backpropagate(nod.rollout())

    # 展开
    return self.rootNod.bestSon(c=0.0)
这里使用时间限制搜索扩展。

5. 实现程序主体

def main():
    # 棋盘
    board_size = 7
    checkerboardStat = np.zeros((board_size, board_size), dtype=int)
    # 游戏
    game = State(X, checkerboardStat, 4)
    # 蒙特卡洛树搜索
    while game.result is None:
        game.show()
        mcts = MCTS(MCTSNode(game))
        bstNod = mcts.bstAction(simulationTimes=2)
        game = bstNod.stat

    result = game.result
    if type(result) == Chess:
        print("Game Over! Winner is:" + STATUS.get(result.val, "Unknown"))
    else:
        print("Game Over! Tie!")

    print("End At: ")
    game.show()


if __name__ == "__main__":
    main()
每次选择最佳走法进行更新即可。这里限制搜索扩展时间为 $2 \,\mathrm{s} $。

   全部代码整合后的可以在笔者的 Github 仓库查看: https://github.com/tripleInt/MCTS-TTT/,或 Pastebin:https://paste.ubuntu.com/p/NQtdDSx3sY/

                     

© 小时科技 保留一切权利