#!/usr/bin/env python3
import random
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

BOARD_SIZE = 14

# =======================
# Game setup and rules
# =======================


def make_board():
    # 14x14 board: 0 = empty, 1/2 = player stones, -1 = starting squares.
    a = np.array([[0 for _ in range(BOARD_SIZE)] for _ in range(BOARD_SIZE)])
    a[4, 4] = -1
    a[9, 9] = -1
    return a


tiles = [
    np.array([[1]]),
    np.array([[1], [1]]),
    np.array([[1], [1], [1]]),
    np.array([[1, 0], [1, 1]]),
    np.array([[1], [1], [1], [1]]),
    np.array([[1, 0], [1, 0], [1, 1]]),
    np.array([[1, 0], [1, 1], [1, 0]]),
    np.array([[1, 1], [1, 1]]),
    np.array([[1, 1, 0], [0, 1, 1]]),
    np.array([[1], [1], [1], [1], [1]]),
    np.array([[1, 0], [1, 0], [1, 0], [1, 1]]),
    np.array([[1, 0], [1, 0], [1, 1], [0, 1]]),
    np.array([[1, 0], [1, 1], [1, 1]]),
    np.array([[1, 1], [1, 0], [1, 1]]),
    np.array([[1, 0], [1, 1], [1, 0], [1, 0]]),
    np.array([[0, 1, 0], [0, 1, 0], [1, 1, 1]]),
    np.array([[1, 0, 0], [1, 0, 0], [1, 1, 1]]),
    np.array([[1, 1, 0], [0, 1, 1], [0, 0, 1]]),
    np.array([[1, 0, 0], [1, 1, 1], [0, 0, 1]]),
    np.array([[1, 0, 0], [1, 1, 1], [0, 1, 0]]),
    np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]),
]


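# Illustrative aside (not used by the training code below): the 21 tiles above
# appear to correspond to the Blokus piece set, from the single square up to
# the pentominoes. A quick way to inspect them, if desired:
#
#     sizes = [int(t.sum()) for t in tiles]   # piece areas, 1 through 5
#     print(len(tiles), sum(sizes))           # 21 pieces, 89 squares in total

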
def clone_state(game_state):
    board, p1tiles, p2tiles = game_state
    return [board.copy(), p1tiles.copy(), p2tiles.copy()]


def get_permutations(which_tiles: list[int]):
    """
    For each tile index in which_tiles, generate all unique rotations/flips.
    Returns a list of (tile_index, oriented_tile).
    """
    permutations = []
    for tidx in which_tiles:
        tile = tiles[tidx]

        rots = [np.rot90(tile, k) for k in range(4)]
        flips = [np.flip(r, axis=1) for r in rots]  # horizontal flips
        all_orients = rots + flips  # 8 orientations

        seen = set()
        for t in all_orients:
            key = (t.shape, t.tobytes())
            if key not in seen:
                seen.add(key)
                permutations.append((tidx, t))

    return permutations


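# Illustrative helper (an addition for inspection only; nothing below calls it):
# symmetric pieces collapse to fewer than 8 orientations, e.g. the 2x2 square
# (tile index 7) contributes a single orientation, while asymmetric pieces
# contribute more.
def _example_orientation_counts():
    """Print how many unique orientations each tile contributes."""
    for tidx in range(len(tiles)):
        print(tidx, len(get_permutations([tidx])))

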
def can_place(board: np.ndarray, tile: np.ndarray, player: int):
    placements = []
    has_minus_one = False
    for x in range(BOARD_SIZE):
        for y in range(BOARD_SIZE):
            if board[x, y] == -1:
                has_minus_one = True
            with np.nditer(tile, flags=["multi_index"]) as it:
                for v in it:
                    if v == 1:
                        (i, j) = it.multi_index
                        if x + i >= BOARD_SIZE:
                            break
                        if y + j >= BOARD_SIZE:
                            break
                        if board[x + i][y + j] > 0:
                            break
                        if x + i - 1 >= 0 and board[x + i - 1][y + j] == player:
                            break
                        if y + j - 1 >= 0 and board[x + i][y + j - 1] == player:
                            break
                        if x + i + 1 < BOARD_SIZE and board[x + i + 1][y + j] == player:
                            break
                        if y + j + 1 < BOARD_SIZE and board[x + i][y + j + 1] == player:
                            break
                else:
                    placements.append((x, y))

    final = []
    if has_minus_one:
        # While a starting square remains, a placement must cover one of them.
        for x, y in placements:
            with np.nditer(tile, flags=["multi_index"]) as it:
                for v in it:
                    (i, j) = it.multi_index
                    if v == 1 and board[x + i, y + j] == -1:
                        final.append((x, y))
                        break
    else:
        # Otherwise a placement must touch the player's own colour diagonally.
        # Only filled tile cells (v == 1) count as corner contact.
        for x, y in placements:
            with np.nditer(tile, flags=["multi_index"]) as it:
                for v in it:
                    (i, j) = it.multi_index
                    if v != 1:
                        continue
                    if (
                        x + i + 1 < BOARD_SIZE
                        and y + j + 1 < BOARD_SIZE
                        and board[x + i + 1, y + j + 1] == player
                    ):
                        final.append((x, y))
                        break
                    if (
                        x + i + 1 < BOARD_SIZE
                        and y + j - 1 >= 0
                        and board[x + i + 1, y + j - 1] == player
                    ):
                        final.append((x, y))
                        break
                    if (
                        x + i - 1 >= 0
                        and y + j + 1 < BOARD_SIZE
                        and board[x + i - 1, y + j + 1] == player
                    ):
                        final.append((x, y))
                        break
                    if (
                        x + i - 1 >= 0
                        and y + j - 1 >= 0
                        and board[x + i - 1, y + j - 1] == player
                    ):
                        final.append((x, y))
                        break

    return final


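# Reading of the placement rules above (explanatory note added for clarity):
# a candidate (x, y) survives the first pass only if every filled tile cell
# stays on the board, overlaps no existing stone, and shares no edge with the
# placing player's own colour. It is then kept if it covers a starting square
# (-1) while any remain, or otherwise touches the player's own colour
# diagonally, in the spirit of Blokus corner-contact rules.
#
# For example, on a fresh board, can_place(make_board(), tiles[0], 1) returns
# exactly the two starting squares, (4, 4) and (9, 9).

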
def do_placement(
    tidx: int, tile: np.ndarray, placement: tuple[int, int], game_state, player: int
):
    (x, y) = placement
    board, p1tiles, p2tiles = game_state
    with np.nditer(tile, flags=["multi_index"]) as it:
        for v in it:
            (i, j) = it.multi_index
            if v == 1:
                board[x + i, y + j] = player
    if player == 1:
        p1tiles.remove(tidx)
    else:
        p2tiles.remove(tidx)


def print_game_state(game_state):
    (board, p1tiles, p2tiles) = game_state
    for row in board:
        print(
            "".join(
                [
                    " " if x == 0 else "X" if x == 1 else "O" if x == 2 else "S"
                    for x in row
                ]
            )
        )
    print("")
    print(f"Player 1 tiles left: {p1tiles}")
    print(f"Player 2 tiles left: {p2tiles}")
    print("")


def reset_game():
    board = make_board()
    p1tiles = [i for i in range(21)]
    p2tiles = [i for i in range(21)]
    return [board, p1tiles, p2tiles]  # list so it is mutable


def get_all_moves(game_state, player: int):
    board, p1tiles, p2tiles = game_state
    available_tiles = p1tiles if player == 1 else p2tiles

    moves = []
    for tidx, tile in get_permutations(available_tiles):
        for placement in can_place(board, tile, player):
            moves.append((tidx, tile, placement))
    return moves


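# Minimal usage sketch for the rules API above (an illustrative addition; the
# actual driver is the self-play loop further down, and nothing calls this).
def _example_random_turn():
    """Play one random legal move for player 1 on a fresh game and print it."""
    state = reset_game()
    moves = get_all_moves(state, player=1)
    tidx, tile, placement = random.choice(moves)
    do_placement(tidx, tile, placement, state, player=1)
    print_game_state(state)

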
# =======================
# AlphaZero-style network
# =======================


def encode_board(board: np.ndarray, player: int) -> torch.Tensor:
    """
    Channels:
      0: current player's stones
      1: opponent's stones
      2: starting squares (-1)
    """
    me = (board == player).astype(np.float32)
    opp = ((board > 0) & (board != player)).astype(np.float32)
    start = (board == -1).astype(np.float32)
    state = np.stack([me, opp, start], axis=0)  # (3, 14, 14)
    return torch.from_numpy(state)


def encode_move(
    tidx: int, tile: np.ndarray, placement: tuple[int, int]
) -> torch.Tensor:
    x, y = placement
    area = int(tile.sum())
    return torch.tensor([tidx, x, y, area], dtype=torch.float32)


class PolicyValueNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        conv_out_dim = 64 * BOARD_SIZE * BOARD_SIZE  # 64 * 14 * 14

        # Value head (board only)
        self.value_head = nn.Sequential(
            nn.Linear(conv_out_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Tanh(),  # value in [-1, 1]
        )

        # Policy head (board + move features)
        self.policy_head = nn.Sequential(
            nn.Linear(conv_out_dim + 4, 256),
            nn.ReLU(),
            nn.Linear(256, 1),  # logit per move
        )

    def forward(self, board_tensor: torch.Tensor, move_features: torch.Tensor):
        """
        board_tensor: (3, 14, 14)
        move_features: (N, 4)
        returns: (policy_logits: (N,), value: scalar)
        """
        x = self.conv(board_tensor.unsqueeze(0))  # (1, 64, 14, 14)
        x = x.view(1, -1)  # (1, conv_out_dim)
        board_embed = x

        # value head
        value = self.value_head(board_embed).squeeze(0).squeeze(-1)  # scalar

        # policy head
        x_rep = board_embed.repeat(move_features.size(0), 1)  # (N, conv_out_dim)
        combined = torch.cat([x_rep, move_features], dim=1)  # (N, conv_out_dim + 4)
        logits = self.policy_head(combined).squeeze(-1)  # (N,)

        return logits, value


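# Shape sanity-check sketch for the network above (an illustrative addition;
# it only relies on the encoders defined earlier and is never called).
def _example_forward_shapes():
    """Run a dummy forward pass and print the output shapes."""
    net = PolicyValueNet()
    state = [make_board(), list(range(21)), list(range(21))]
    board_tensor = encode_board(state[0], player=1)  # (3, 14, 14)
    moves = get_all_moves(state, player=1)[:2]  # two candidate moves
    move_feats = torch.stack(
        [encode_move(tidx, tile, placement) for tidx, tile, placement in moves]
    )  # (2, 4)
    logits, value = net(board_tensor, move_feats)
    print(logits.shape, value.shape)  # torch.Size([2]) and a 0-d value tensor

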
# =======================
# MCTS (AlphaZero-style)
# =======================


class MCTSNode:
    def __init__(self, state, player: int):
        self.state = state
        self.player = player
        self.is_expanded = False
        self.is_terminal = False

        self.moves: Optional[list[tuple[int, np.ndarray, tuple[int, int]]]] = None
        self.priors: Optional[np.ndarray] = None
        self.Nsa: Optional[np.ndarray] = None
        self.Wsa: Optional[np.ndarray] = None
        self.Qsa: Optional[np.ndarray] = None

        self.children: dict[int, "MCTSNode"] = {}

    def expand(self, net: PolicyValueNet, device="cpu") -> float:
        """
        Returns value v from perspective of self.player.
        """
        moves = get_all_moves(self.state, self.player)
        if len(moves) == 0:
            # No moves: this player loses
            self.is_terminal = True
            self.is_expanded = True
            return -1.0

        self.moves = moves

        board, _, _ = self.state
        board_tensor = encode_board(board, self.player).to(device)
        move_feats = torch.stack(
            [encode_move(tidx, tile, placement) for (tidx, tile, placement) in moves],
            dim=0,
        ).to(device)

        with torch.no_grad():
            logits, value = net(board_tensor, move_feats)
            probs = F.softmax(logits, dim=0).cpu().numpy()
            v = float(value.item())

        self.priors = probs
        n = len(moves)
        self.Nsa = np.zeros(n, dtype=np.float32)
        self.Wsa = np.zeros(n, dtype=np.float32)
        self.Qsa = np.zeros(n, dtype=np.float32)
        self.is_expanded = True
        return v

    def select_action(self, c_puct: float = 1.5) -> int:
        """
        Select action index using PUCT formula.
        """
        # PUCT: score(a) = Q(s, a) + c_puct * P(s, a) * sqrt(sum_b N(s, b)) / (1 + N(s, a))
        Ns = np.sum(self.Nsa) + 1e-8
        u = c_puct * self.priors * np.sqrt(Ns) / (1.0 + self.Nsa)
        scores = self.Qsa + u
        return int(np.argmax(scores))


def mcts_search(
    net: PolicyValueNet,
    root_state,
    root_player: int,
    n_simulations: int,
    device="cpu",
    c_puct: float = 1.5,
):
    root = MCTSNode(clone_state(root_state), root_player)

    for _ in range(n_simulations):
        node = root
        path: list[tuple[MCTSNode, int]] = []

        # Traverse
        while True:
            if not node.is_expanded:
                v = node.expand(net, device)
                break
            if node.is_terminal:
                # Value from this player's perspective is -1 (no moves)
                v = -1.0
                break

            a = node.select_action(c_puct)
            path.append((node, a))

            if a in node.children:
                node = node.children[a]
            else:
                # create child
                child_state = clone_state(node.state)
                tidx, tile, placement = node.moves[a]
                do_placement(tidx, tile, placement, child_state, node.player)
                next_player = 2 if node.player == 1 else 1
                child = MCTSNode(child_state, next_player)
                node.children[a] = child
                node = child
                # next loop iteration will expand it

        # Backpropagate value v (from leaf player's perspective)
        val = v
        # Going back up the tree, the perspective alternates each move
        for parent, action_index in reversed(path):
            val = -val  # switch to parent's perspective
            parent.Nsa[action_index] += 1.0
            parent.Wsa[action_index] += val
            parent.Qsa[action_index] = (
                parent.Wsa[action_index] / parent.Nsa[action_index]
            )

    # After all simulations, derive policy target from root visit counts
    if not root.is_expanded or root.is_terminal or root.moves is None:
        return None, None, None

    visits = root.Nsa
    pi = visits / np.sum(visits)
    # Sample action from pi (exploration); you can use argmax for greedy play
    action_index = int(np.random.choice(len(root.moves), p=pi))
    return root.moves, pi, action_index


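# Usage sketch for mcts_search (illustrative addition, never called): with an
# untrained network and a handful of simulations this just demonstrates the
# return values, it does not produce meaningful play.
def _example_mcts_move():
    net = PolicyValueNet()
    state = reset_game()
    moves, pi, a_idx = mcts_search(net, state, root_player=1, n_simulations=8)
    tidx, tile, placement = moves[a_idx]
    print(f"chose tile {tidx} at {placement} out of {len(pi)} legal moves")

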
# =======================
# Self-play + training
# =======================


def self_play_game(net: PolicyValueNet, n_simulations: int, device="cpu"):
    """
    Plays one self-play game using MCTS + shared network.
    Returns a list of training examples:
      each entry: (board_snapshot, player, moves, pi, z)
    """
    game_state = reset_game()
    player = 1
    history = []  # list of dicts: board, player, moves, pi, z (filled later)

    while True:
        moves = get_all_moves(game_state, player)
        if len(moves) == 0:
            winner = 2 if player == 1 else 1
            break

        # Run MCTS from current state
        mcts_moves, pi, a_idx = mcts_search(
            net, game_state, player, n_simulations, device
        )
        if mcts_moves is None:
            winner = 2 if player == 1 else 1
            break

        # Save training position (copy board only; moves are references)
        board_snapshot = game_state[0].copy()
        history.append(
            {
                "board": board_snapshot,
                "player": player,
                "moves": mcts_moves,
                "pi": pi,
                "z": None,  # fill after game
            }
        )

        # Play chosen move
        tidx, tile, placement = mcts_moves[a_idx]
        do_placement(tidx, tile, placement, game_state, player)
        player = 2 if player == 1 else 1

    # Game finished, assign outcomes
    for entry in history:
        entry["z"] = 1.0 if entry["player"] == winner else -1.0

    return history, winner


def train_on_history(net: PolicyValueNet, optimizer, history, device="cpu"):
    """
    Single gradient step over all positions from one self-play game.
    """
    net.train()
    optimizer.zero_grad()
    total_loss = 0.0

    for entry in history:
        board = entry["board"]
        player = entry["player"]
        moves = entry["moves"]
        pi = entry["pi"]
        z = entry["z"]

        board_tensor = encode_board(board, player).to(device)
        move_feats = torch.stack(
            [encode_move(tidx, tile, placement) for (tidx, tile, placement) in moves],
            dim=0,
        ).to(device)
        target_pi = torch.from_numpy(pi).to(device)
        target_z = torch.tensor(z, dtype=torch.float32, device=device)

        logits, value = net(board_tensor, move_feats)
        log_probs = F.log_softmax(logits, dim=0)

        policy_loss = -(target_pi * log_probs).sum()
        value_loss = F.mse_loss(value, target_z)

        loss = policy_loss + value_loss
        total_loss += loss

    if len(history) > 0:
        total_loss = total_loss / len(history)

    total_loss.backward()
    optimizer.step()
    return float(total_loss.item())


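# The per-position objective optimised above, written out (explanatory note):
#
#     L = - sum_a pi(a) * log p_theta(a | s)  +  (z - v_theta(s))^2
#
# i.e. cross-entropy between the MCTS visit distribution pi and the policy
# head, plus squared error between the game outcome z (+1 win / -1 loss, from
# the player to move) and the value head, averaged over the positions of one
# self-play game.

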
# =======================
# Simple evaluation game
# =======================


def play_game_with_mcts(net: PolicyValueNet, n_simulations: int, device="cpu"):
    """
    Watch two MCTS+net players (same weights) play against each other.
    """
    net.eval()
    game_state = reset_game()
    player = 1

    while True:
        print_game_state(game_state)
        moves = get_all_moves(game_state, player)
        if not moves:
            print(f"No moves left, player {player} loses.")
            break

        mcts_moves, pi, a_idx = mcts_search(
            net, game_state, player, n_simulations, device
        )
        if mcts_moves is None:
            print(f"No moves left (MCTS), player {player} loses.")
            break

        tidx, tile, placement = mcts_moves[a_idx]
        print(f"Player {player} plays tile {tidx} at {placement}")
        do_placement(tidx, tile, placement, game_state, player)

        player = 2 if player == 1 else 1


# =======================
# Main training loop
# =======================


def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    net = PolicyValueNet().to(device)
    optimizer = optim.Adam(net.parameters(), lr=1e-3)

    num_games = 200  # increase a lot for real training
    n_simulations = 50  # MCTS sims per move (increase if it's too weak)

    for g in range(1, num_games + 1):
        history, winner = self_play_game(net, n_simulations, device)
        loss = train_on_history(net, optimizer, history, device)

        print(
            f"Game {g}/{num_games}, winner: Player {winner}, loss: {loss:.4f}, positions: {len(history)}"
        )

        # occasionally watch a game
        if g % 50 == 0:
            print("Watching a game with current network:")
            play_game_with_mcts(net, n_simulations=30, device=device)

    # Save final network
    torch.save(net.state_dict(), "alphazero_blokus_net.pth")
    print("Saved network to alphazero_blokus_net.pth")


if __name__ == "__main__":
    main()
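
# To reuse the saved weights later (sketch of the standard PyTorch pattern):
#
#     net = PolicyValueNet()
#     net.load_state_dict(torch.load("alphazero_blokus_net.pth", map_location="cpu"))
#     net.eval()
#     play_game_with_mcts(net, n_simulations=50)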