贡献者: int256
本文使用了 python 3.7 中提供的特性(future
中的 annotations
)与 numpy
库。使用版本 pytohn 3.12.2 与 numpy 1.26.4 可以确保无问题运行。全部代码可以在笔者的 Github 仓库查看:
https://github.com/tripleInt/MCTS-TTT/,或 Pastebin https://paste.ubuntu.com/p/NQtdDSx3sY/。
前面已经讨论过蒙特卡洛树搜索算法的理论,下面通过讲解例题进行实战练习,这利于我们更深入地理解这算法。首先回顾例题:
在实现蒙特卡洛树搜索这个算法前,我们需要先做一些准备,定义好游戏的各种内容类。
因为要处理连续问题,我们可以使用 “求和” 的方式检查,故可以取特殊值 $1$ 与 $-1$ 表示两种棋子 $\bigcirc$ 和 $ \boldsymbol\times $。定义棋子类表示棋子:
class Chess:
def __init__(self, name: str, val: int) -> None:
self.name: str = name # X/O
self.val: int = val
def __repr__(self) -> str:
return f"Chess Object({str(self.name)}, {str(self.val)})"
其中 __init__
方法相当于是构造函数,__repr__
方法提供了一个将类转化为 str
的方式。具体原理是利用了 python 的魔法方法。
然后对于每一步操作,我们可以考虑为是落点与棋子类型的组合。故定义一个操作类 Move
:
# 操作
class Move:
def __init__(self, x: int, y: int, chess: Chess) -> None:
self.x: int = x
self.y: int = y
self.chess: Chess = chess
def __repr__(self) -> str:
return (
"Move Object[("
+ str(self.x)
+ ", "
+ str(self.y)
+ ") , "
+ str(self.chess)
+ "]"
)
然后就可以定义以下常量便于我们在后面使用:
STATUS = {0: " ", 1: "O", -1: "X"}
X = Chess(STATUS[-1], -1)
O = Chess(STATUS[1], 1)
其中 STATUS
常量存储了后面棋盘中每个位置的数字代表这个位置的状态的情况,X
、O
分别对应 $ \boldsymbol\times $、$\bigcirc$ 两种棋子。
显然对于这个问题来说,搜索的状态应该是当前棋盘的情况。我们定义一个表示状态的类 State
并实现一些方法帮助我们在后面进行搜索。
首先考虑其构造函数,需要记录的信息,显然有当前棋盘的情况(使用 numpy
提供的 np.array
来表示)、下一步应当哪方下棋。我们额外开辟一个属性用来记录需要多少连续棋子可以赢得这场游戏。故可以写出 __init__
方法:
class State:
def __init__(self, nxtMove,
checkerboardStat: np.array,
winNeed: int = -1) -> None:
"""
Args:
nxtMove: 接下来该谁下棋了
checkerboardStat (2 D 网格棋盘):
棋盘状态
winNeed (int, optional):
连续多少个棋子可获得胜利. Defaults to -1.
"""
if len(checkerboardStat.shape) != 2:
raise Exception(
"checkerboardStat must be 2D array")
if (checkerboardStat.shape[0] !=
checkerboardStat.shape[1]):
raise Exception(
"checkerboardStat must be square")
self.checkerboard: np.array = checkerboardStat
if winNeed == -1:
winNeed = self.checkerboard.shape[0]
self.winNeed = winNeed
self.nxtMove: Chess = nxtMove
在声明属性的时候尽量使用 “属性名: 类型=值” 的方法,这有助于我们后续实现代码(这样一般的编辑器可以为我们更好地提示代码补全)。
我们经常会需要获取棋盘的形状(即大小),故再定义一个属性用来表示棋盘大小。这里使用装饰器 @property
,这类似于定义了一个属性的 get
。
@property
def checkerboardSize(self):
return self.checkerboard.shape[0]
接下来考虑到我们需要判断当前游戏局势(判断游戏是否结束),故类似的定义一个 @property
装饰的 result
表示当前局势与 isOver
表示是否结束(已经不能继续下棋):
@property
def result(self):
"""判断游戏结果
Returns:
Chess | 0 | None: 若返回 0 代表游戏平局
返回 None 表示游戏还未结束,
否则返回 X/O(Chess 对象) 表示赢家。
"""
# 横纵连续
for i in range(self.checkerboardSize - self.winNeed + 1):
xSum = np.sum(self.checkerboard[i : i + self.winNeed, :], axis=0)
ySum = np.sum(self.checkerboard[:, i : i + self.winNeed], axis=1)
if xSum.min() == -self.winNeed or ySum.min() == -self.winNeed:
return X
if xSum.max() == self.winNeed or ySum.max() == self.winNeed:
return O
# 对角线连续
for i in range(self.checkerboardSize - self.winNeed + 1):
for j in range(self.checkerboardSize - self.winNeed + 1):
subCheckerboard = self.checkerboard[
i : i + self.winNeed, j : j + self.winNeed
]
# 两条斜向对角线
diag1Sum, diag2Sum = (
subCheckerboard.trace(),
np.fliplr(subCheckerboard).trace(),
)
if diag1Sum == -self.winNeed or diag2Sum == -self.winNeed:
return X
if diag1Sum == self.winNeed or diag2Sum == self.winNeed:
return O
if np.all(self.checkerboard != 0):
# 平局
return 0
# 游戏还未结束
return None
@property
def isOver(self) -> bool:
"""游戏是否结束
Returns:
bool: 游戏是否结束
"""
return self.result is not None
这里使用了一个小技巧,对角线对应矩阵的迹。同时使用 np.min
和 np.max
来帮助我们通过求和解决判断是否有棋子连续到足够个数的一方。
接下来需要考虑下棋操作。首先考虑搜索需要用到当前的所有可能的走法,故这里编写一个方法来实现这个功能:
def getMoves(self) -> List[Move]:
"""获取所有可能的走法
Returns:
List[Move]
"""
return [
Move(d[0], d[1], self.nxtMove)
for d in list(zip(*np.where(self.checkerboard == 0)))
]
然后考虑需要判断某种走法对当前的棋盘来说是否合法,由位置(是否在棋盘内、该位置是否有棋子)和这种走法下棋的一方是否是State 记录下来的将要下棋的一方共同决定:
def couldMove(self, move: Move) -> bool:
"""判断走法是否合法
Args:
move (Move)
Returns:
bool: 是否合法
"""
if move.chess != self.nxtMove:
# 下棋的一方不对
return False
if not (
0 <= move.x < self.checkerboardSize and
0 <= move.y < self.checkerboardSize
):
# 位置不合法
return False
# 这位置还没有棋子
return self.checkerboard[move.x, move.y] == 0
以及需要根据走法获取一个在当前棋盘进行该走法后的下一状态,这也可以编写一个方法来实现,这里需要注意是返回一个新的 State
对象,用到了 python 3.7 的 future 库中的特性。同时要分清什么时候使用 `self.checkerboard`,什么时候是更新返回的 newCheckerboard
:
# def doMove(self, move):
def doMove(self, move: Move) -> State: # Python 3.7 need(PEP 563)
if not self.couldMove(move):
raise Exception("Move must be legel")
newCheckerboard = self.checkerboard.copy()
newCheckerboard[move.x, move.y] = move.chess.val
if self.nxtMove == X:
nxtMove = O
elif self.nxtMove == O:
nxtMove = X
# return type(self)(nxtMove, newCheckerboard, self.winNeed)
# Python 3.7 need(PEP 563)
return State(nxtMove, newCheckerboard, self.winNeed)
最后实现一个方法来输出当前棋盘状态(这是题目要求的):
def show(self, outputFn: function = print) -> None:
"""显示当前棋盘状态
Args:
outputFn (function, optional):
输出的函数. Defaults to print.
"""
board = np.copy(self.checkerboard)
def strLines(r):
return (" "
+ " | ".join(map(
lambda x: STATUS.get(int(x), " "), r))
+ " ")
for r in board[:-1]:
outputFn(strLines(r))
outputFn("-" * (len(r) * 4 - 1))
outputFn(strLines(board[-1]))
outputFn()
对于不同的电脑环境、不同的终端、不同的字体可能每个字符的长度不同,需要适当调整这里的第 $19$ 行。
接下来考虑蒙特卡洛树的每个结点。将其定义为一个类 MCTSNode
:
# 结点
class MCTSNode:
def __init__(self, stat, fa=None):
"""
Args:
stat (State): 结点对应状态
fa (MCTSNode, optional): 父结点. Defaults to None.
"""
self.stat: State = stat
self.fa: MCTSNode = fa
# 子结点列表
self.sons: List[MCTSNode] = []
self._visits = 0 # 已访问过结点
self._results = {}
self._notTried = None
@property
def isFullyExpanded(self):
return len(self.notTried) == 0
@property
def notTried(self):
if self._notTried is None:
self._notTried = self.stat.getMoves()
# 通过打乱实现“随机”
shuffle(self._notTried)
return self._notTried
@property
def isEnd(self) -> bool:
"""是否是终端结点(叶子结点)
"""
return self.stat.isOver
其中 shuffle
可以直接使用 random
库提供的 random.shuffle
。
接下来是实现蒙特卡洛树搜索算法中的通过 UCB 选择子节点扩展:
def bestSon(self, c=1.5):
return self.sons[
np.argmax(
[
(nod.q / nod.n) + c* np.sqrt((2 * np.log(self.n)) / nod.n)
for nod in self.sons
]
)
]
@property
def q(self):
v = self.fa.stat.nxtMove.val
return (
self._results.get(v, 0)
- self._results.get(-1 * v, 0)
)
@property
def n(self):
return self._visits
其中 q
就是 UCB 公式式 1 中的 $N_r$。这里不进行 $+1$ 与 $\times 2$ 的修正(实际效果与修正后是一样的)。
需要注意这里的 q
应当使用 dict.get
的方式,因为有可能还没有更新过,否则就要使用 defaultdict
来定义 self._results
。
下面实现蒙特卡洛搜索需要的各种操作(Expand、Simulation 对应 Rollout,Back Propagate)。
def expand(self) -> MCTSNode:
stat = self.stat.doMove(self.notTried.pop())
son: MCTSNode = MCTSNode(stat, self)
self.sons.append(son)
return son
def rollout(self):
stat: State = self.stat
while not stat.isOver:
stat = stat.doMove(np.random.choice(stat.getMoves()))
return stat.result
def backpropagate(self, result):
self._visits += 1
self._results[result] = self._results.get(result, 0) + 1
if self.fa is not None:
self.fa.backpropagate(result)
这里反向传播使用了尾递归的方式,也可以使用 while 循环的方法。
同样需要注意这里 backpropagate
也使用的是 dict.get
方法。
最后是实现搜索树。对于一棵树我们往往只需要记录根节点就获得了所有信息。
class MCTS(object):
def __init__(self, rootNod: MCTSNode):
"""蒙特卡洛树
Args:
rootNod (MCTSNode): 根结点
"""
self.rootNod: MCTSNode = rootNod
接下来是实现蒙特卡洛树搜索中的操作。
可以直接根据前面的 MCTSNode.bestSon()
方法获得到要选择的结点,一直 while
循环到叶子结点就可以了:
def chooseNod(self) -> MCTSNode:
"""选择要扩展的结点
Returns:
MCTSNode
"""
cur = self.rootNod
# 递归到叶子结点并返回
while not cur.isEnd:
if (not cur.isFullyExpanded):
return cur.expand()
else:
cur = cur.bestSon()
return cur
根据 UCB 算法进行搜索扩展找最好的操作。
def bstAction(self, simulationTimes: int = None, duration: float = None):
"""根据 UCB 算法进行搜索扩展,找到最佳操作
Args:
simulationTimes (int, optional):
为找到最佳操作已经模拟的次数. Defaults to None.
duration (float, optional):
算法搜索的时间(秒). Defaults to None.
"""
if simulationTimes is None:
if duration is None:
raise Exception("duration must be set")
endTime: float = time.time() + duration
while time.time() <= endTime:
nod = self.chooseNod()
nod.backpropagate(nod.rollout())
else:
for _ in range(simulationTimes):
nod = self.chooseNod()
nod.backpropagate(nod.rollout())
# 展开
return self.rootNod.bestSon(c=0.0)
这里使用时间限制搜索扩展。
def main():
# 棋盘
board_size = 7
checkerboardStat = np.zeros((board_size, board_size), dtype=int)
# 游戏
game = State(X, checkerboardStat, 4)
# 蒙特卡洛树搜索
while game.result is None:
game.show()
mcts = MCTS(MCTSNode(game))
bstNod = mcts.bstAction(simulationTimes=2)
game = bstNod.stat
result = game.result
if type(result) == Chess:
print("Game Over! Winner is:" + STATUS.get(result.val, "Unknown"))
else:
print("Game Over! Tie!")
print("End At: ")
game.show()
if __name__ == "__main__":
main()
每次选择最佳走法进行更新即可。这里限制搜索扩展时间为 $2 \,\mathrm{s} $。
全部代码整合后的可以在笔者的 Github 仓库查看: https://github.com/tripleInt/MCTS-TTT/,或 Pastebin:https://paste.ubuntu.com/p/NQtdDSx3sY/。