blokus/blokus.py

#!/usr/bin/env python3
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from typing import Optional
BOARD_SIZE = 14
# =======================
# Game setup and rules
# =======================
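# Board encoding used throughout this file: a 14x14 Blokus Duo-style board where
# 0 = empty, 1 = player 1's stones, 2 = player 2's stones, and -1 marks the two
# start squares at (4, 4) and (9, 9) that the opening moves must cover.
# Simplified end condition: the first player with no legal placement loses
# (no square counting as in the full Blokus rules).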
def make_board():
a = np.zeros((BOARD_SIZE, BOARD_SIZE), dtype=int)
a[4, 4] = -1
a[9, 9] = -1
return a
tiles = [
np.array([[1]]),
np.array([[1], [1]]),
np.array([[1], [1], [1]]),
np.array([[1, 0], [1, 1]]),
np.array([[1], [1], [1], [1]]),
np.array([[1, 0], [1, 0], [1, 1]]),
np.array([[1, 0], [1, 1], [1, 0]]),
np.array([[1, 1], [1, 1]]),
np.array([[1, 1, 0], [0, 1, 1]]),
np.array([[1], [1], [1], [1], [1]]),
np.array([[1, 0], [1, 0], [1, 0], [1, 1]]),
np.array([[1, 0], [1, 0], [1, 1], [0, 1]]),
np.array([[1, 0], [1, 1], [1, 1]]),
np.array([[1, 1], [1, 0], [1, 1]]),
np.array([[1, 0], [1, 1], [1, 0], [1, 0]]),
np.array([[0, 1, 0], [0, 1, 0], [1, 1, 1]]),
np.array([[1, 0, 0], [1, 0, 0], [1, 1, 1]]),
np.array([[1, 1, 0], [0, 1, 1], [0, 0, 1]]),
np.array([[1, 0, 0], [1, 1, 1], [0, 0, 1]]),
np.array([[1, 0, 0], [1, 1, 1], [0, 1, 0]]),
np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]),
]
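# The 21 standard Blokus piece shapes (1 monomino, 1 domino, 2 trominoes,
# 5 tetrominoes, 12 pentominoes), each stored as a 0/1 array in one fixed
# orientation; rotations and flips are generated on demand by get_permutations().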
def clone_state(game_state):
board, p1tiles, p2tiles = game_state
return [board.copy(), p1tiles.copy(), p2tiles.copy()]
def get_permutations(which_tiles: list[int]):
"""
For each tile index in which_tiles, generate all unique rotations/flips.
Returns a list of (tile_index, oriented_tile).
"""
permutations = []
for tidx in which_tiles:
tile = tiles[tidx]
rots = [np.rot90(tile, k) for k in range(4)]
flips = [np.flip(r, axis=1) for r in rots] # horizontal flips
all_orients = rots + flips # 8 orientations
seen = set()
for t in all_orients:
key = (t.shape, t.tobytes())
if key not in seen:
seen.add(key)
permutations.append((tidx, t))
return permutations
def can_place(board: np.ndarray, tile: np.ndarray, player: int):
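"""
Return every (x, y) where the top-left cell of `tile`'s bounding box can be
placed for `player`: all tile cells must stay on the board, must not overlap
existing stones, and must not share an edge with the player's own stones.
While a start square (-1) is still uncovered, a placement must cover one of
them; afterwards it must touch the player's stones diagonally.
"""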
placements = []
has_minus_one = False
for x in range(BOARD_SIZE):
for y in range(BOARD_SIZE):
if board[x, y] == -1:
has_minus_one = True
with np.nditer(tile, flags=["multi_index"]) as it:
for v in it:
if v == 1:
(i, j) = it.multi_index
if x + i >= BOARD_SIZE:
break
if y + j >= BOARD_SIZE:
break
if board[x + i][y + j] > 0:
break
if x + i - 1 >= 0 and board[x + i - 1][y + j] == player:
break
if y + j - 1 >= 0 and board[x + i][y + j - 1] == player:
break
if x + i + 1 < BOARD_SIZE and board[x + i + 1][y + j] == player:
break
if y + j + 1 < BOARD_SIZE and board[x + i][y + j + 1] == player:
break
else:
placements.append((x, y))
final = []
if has_minus_one:
for x, y in placements:
with np.nditer(tile, flags=["multi_index"]) as it:
for v in it:
(i, j) = it.multi_index
if v == 1 and board[x + i, y + j] == -1:
final.append((x, y))
break
else:
for x, y in placements:
with np.nditer(tile, flags=["multi_index"]) as it:
for v in it:
(i, j) = it.multi_index
# only actual tile cells (v == 1) can supply the required corner contact
if v != 1:
continue
if (
x + i + 1 < BOARD_SIZE
and y + j + 1 < BOARD_SIZE
and board[x + i + 1, y + j + 1] == player
):
final.append((x, y))
break
if (
x + i + 1 < BOARD_SIZE
and y + j - 1 >= 0
and board[x + i + 1, y + j - 1] == player
):
final.append((x, y))
break
if (
x + i - 1 >= 0
and y + j + 1 < BOARD_SIZE
and board[x + i - 1, y + j + 1] == player
):
final.append((x, y))
break
if (
x + i - 1 >= 0
and y + j - 1 >= 0
and board[x + i - 1, y + j - 1] == player
):
final.append((x, y))
break
return final
def do_placement(
tidx: int, tile: np.ndarray, placement: tuple[int, int], game_state, player: int
):
(x, y) = placement
board, p1tiles, p2tiles = game_state
with np.nditer(tile, flags=["multi_index"]) as it:
for v in it:
(i, j) = it.multi_index
if v == 1:
board[x + i, y + j] = player
if player == 1:
p1tiles.remove(tidx)
else:
p2tiles.remove(tidx)
def print_game_state(game_state):
(board, p1tiles, p2tiles) = game_state
for row in board:
print(
"".join(
[
" " if x == 0 else "X" if x == 1 else "O" if x == 2 else "S"
for x in row
]
)
)
print("")
print(f"Player 1 tiles left: {p1tiles}")
print(f"Player 2 tiles left: {p2tiles}")
print("")
def reset_game():
board = make_board()
p1tiles = list(range(len(tiles)))
p2tiles = list(range(len(tiles)))
return [board, p1tiles, p2tiles] # list so it is mutable
def get_all_moves(game_state, player: int):
board, p1tiles, p2tiles = game_state
available_tiles = p1tiles if player == 1 else p2tiles
moves = []
for tidx, tile in get_permutations(available_tiles):
for placement in can_place(board, tile, player):
moves.append((tidx, tile, placement))
return moves
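# A "move" is the tuple (tile_index, oriented_tile_array, (x, y) anchor); the
# same representation is reused by the MCTS nodes and the training examples below.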
# =======================
# AlphaZero-style network
# =======================
def encode_board(board: np.ndarray, player: int) -> torch.Tensor:
"""
Channels:
0: current player's stones
1: opponent's stones
2: starting squares (-1)
"""
me = (board == player).astype(np.float32)
opp = ((board > 0) & (board != player)).astype(np.float32)
start = (board == -1).astype(np.float32)
state = np.stack([me, opp, start], axis=0) # (3, 14, 14)
return torch.from_numpy(state)
def encode_move(
tidx: int, tile: np.ndarray, placement: tuple[int, int]
) -> torch.Tensor:
x, y = placement
area = int(tile.sum())
return torch.tensor([tidx, x, y, area], dtype=torch.float32)
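# Move features are deliberately simple: raw tile index, anchor coordinates, and
# tile area. The policy head scores each legal move from these four numbers
# concatenated with the shared board embedding.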
class PolicyValueNet(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(),
)
conv_out_dim = 64 * BOARD_SIZE * BOARD_SIZE # 64 * 14 * 14
# Value head (board only)
self.value_head = nn.Sequential(
nn.Linear(conv_out_dim, 128),
nn.ReLU(),
nn.Linear(128, 1),
nn.Tanh(), # value in [-1, 1]
)
# Policy head (board + move features)
self.policy_head = nn.Sequential(
nn.Linear(conv_out_dim + 4, 256),
nn.ReLU(),
nn.Linear(256, 1), # logit per move
)
def forward(self, board_tensor: torch.Tensor, move_features: torch.Tensor):
"""
board_tensor: (3, 14, 14)
move_features: (N, 4)
returns: (policy_logits: (N,), value: scalar)
"""
x = self.conv(board_tensor.unsqueeze(0)) # (1, 64, 14, 14)
x = x.view(1, -1) # (1, conv_out_dim)
board_embed = x
# value head
value = self.value_head(board_embed).squeeze(0).squeeze(-1) # scalar
# policy head
x_rep = board_embed.repeat(move_features.size(0), 1) # (N, conv_out_dim)
combined = torch.cat([x_rep, move_features], dim=1) # (N, conv_out_dim + 4)
logits = self.policy_head(combined).squeeze(-1) # (N,)
return logits, value
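# Shape sketch (assuming a position with N legal moves; not executable as-is):
#   board_tensor = encode_board(board, player)                 # (3, 14, 14)
#   move_feats = torch.stack([encode_move(*m) for m in ...])   # (N, 4)
#   logits, value = net(board_tensor, move_feats)              # logits: (N,), value: scalar in [-1, 1]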
# =======================
# MCTS (AlphaZero-style)
# =======================
class MCTSNode:
def __init__(self, state, player: int):
self.state = state
self.player = player
self.is_expanded = False
self.is_terminal = False
self.moves: Optional[list[tuple[int, np.ndarray, tuple[int, int]]]] = None
self.priors: Optional[np.ndarray] = None
self.Nsa: Optional[np.ndarray] = None
self.Wsa: Optional[np.ndarray] = None
self.Qsa: Optional[np.ndarray] = None
self.children: dict[int, "MCTSNode"] = {}
def expand(self, net: PolicyValueNet, device="cpu") -> float:
"""
Returns value v from perspective of self.player.
"""
moves = get_all_moves(self.state, self.player)
if len(moves) == 0:
# No moves: this player loses
self.is_terminal = True
self.is_expanded = True
return -1.0
self.moves = moves
board, _, _ = self.state
board_tensor = encode_board(board, self.player).to(device)
move_feats = torch.stack(
[encode_move(tidx, tile, placement) for (tidx, tile, placement) in moves],
dim=0,
).to(device)
with torch.no_grad():
logits, value = net(board_tensor, move_feats)
probs = F.softmax(logits, dim=0).cpu().numpy()
v = float(value.item())
self.priors = probs
n = len(moves)
self.Nsa = np.zeros(n, dtype=np.float32)
self.Wsa = np.zeros(n, dtype=np.float32)
self.Qsa = np.zeros(n, dtype=np.float32)
self.is_expanded = True
return v
def select_action(self, c_puct: float = 1.5) -> int:
"""
Select action index using PUCT formula.
"""
Ns = np.sum(self.Nsa) + 1e-8
u = c_puct * self.priors * np.sqrt(Ns) / (1.0 + self.Nsa)
scores = self.Qsa + u
return int(np.argmax(scores))
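# PUCT selection, as in AlphaZero:
#   score(a) = Q(s, a) + c_puct * P(s, a) * sqrt(sum_b N(s, b)) / (1 + N(s, a))
# Unvisited moves with high prior are explored first; as visit counts grow,
# the empirical value Q dominates the exploration bonus.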
def mcts_search(
net: PolicyValueNet,
root_state,
root_player: int,
n_simulations: int,
device="cpu",
c_puct: float = 1.5,
):
root = MCTSNode(clone_state(root_state), root_player)
for _ in range(n_simulations):
node = root
path: list[tuple[MCTSNode, int]] = []
# Traverse
while True:
if not node.is_expanded:
v = node.expand(net, device)
break
if node.is_terminal:
# Value from this player's perspective is -1 (no moves)
v = -1.0
break
a = node.select_action(c_puct)
path.append((node, a))
if a in node.children:
node = node.children[a]
else:
# create child
child_state = clone_state(node.state)
tidx, tile, placement = node.moves[a]
do_placement(tidx, tile, placement, child_state, node.player)
next_player = 2 if node.player == 1 else 1
child = MCTSNode(child_state, next_player)
node.children[a] = child
node = child
# next loop iteration will expand it
# Backpropagate value v (from leaf player's perspective)
val = v
# Going back up the tree, the perspective alternates each move
for parent, action_index in reversed(path):
val = -val # switch to parent's perspective
parent.Nsa[action_index] += 1.0
parent.Wsa[action_index] += val
parent.Qsa[action_index] = (
parent.Wsa[action_index] / parent.Nsa[action_index]
)
# After all simulations, derive policy target from root visit counts
if not root.is_expanded or root.is_terminal or root.moves is None:
return None, None, None
visits = root.Nsa
pi = visits / np.sum(visits)
# Sample action from pi (exploration); you can use argmax for greedy play
action_index = int(np.random.choice(len(root.moves), p=pi))
return root.moves, pi, action_index
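# The policy target pi comes straight from the root visit counts; unlike full
# AlphaZero self-play, no Dirichlet noise on the root priors and no
# move-selection temperature schedule are applied here.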
# =======================
# Self-play + training
# =======================
def self_play_game(net: PolicyValueNet, n_simulations: int, device="cpu"):
"""
Plays one self-play game using MCTS + shared network.
Returns a list of training examples:
each entry is a dict with keys "board", "player", "moves", "pi", "z"
(z is filled in after the game ends).
"""
game_state = reset_game()
player = 1
history = [] # list of dicts: board, player, moves, pi, z (filled later)
while True:
moves = get_all_moves(game_state, player)
if len(moves) == 0:
winner = 2 if player == 1 else 1
break
# Run MCTS from current state
mcts_moves, pi, a_idx = mcts_search(
net, game_state, player, n_simulations, device
)
if mcts_moves is None:
winner = 2 if player == 1 else 1
break
# Save training position (copy board only; moves are references)
board_snapshot = game_state[0].copy()
history.append(
{
"board": board_snapshot,
"player": player,
"moves": mcts_moves,
"pi": pi,
"z": None, # fill after game
}
)
# Play chosen move
tidx, tile, placement = mcts_moves[a_idx]
do_placement(tidx, tile, placement, game_state, player)
player = 2 if player == 1 else 1
# Game finished, assign outcomes
for entry in history:
entry["z"] = 1.0 if entry["player"] == winner else -1.0
return history, winner
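# Each stored position gets z = +1 if its player went on to win the game and
# z = -1 otherwise, matching the tanh-bounded value head.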
def train_on_history(net: PolicyValueNet, optimizer, history, device="cpu"):
"""
Single gradient step over all positions from one self-play game.
"""
net.train()
optimizer.zero_grad()
total_loss = 0.0
for entry in history:
board = entry["board"]
player = entry["player"]
moves = entry["moves"]
pi = entry["pi"]
z = entry["z"]
board_tensor = encode_board(board, player).to(device)
move_feats = torch.stack(
[encode_move(tidx, tile, placement) for (tidx, tile, placement) in moves],
dim=0,
).to(device)
target_pi = torch.from_numpy(pi).to(device)
target_z = torch.tensor(z, dtype=torch.float32, device=device)
logits, value = net(board_tensor, move_feats)
log_probs = F.log_softmax(logits, dim=0)
policy_loss = -(target_pi * log_probs).sum()
value_loss = F.mse_loss(value, target_z)
loss = policy_loss + value_loss
total_loss += loss
if len(history) == 0:
return 0.0
total_loss = total_loss / len(history)
total_loss.backward()
optimizer.step()
return float(total_loss.item())
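# Per-game loss, averaged over the game's positions (one Adam step per self-play game):
#   L = -sum_a pi(a) * log p_theta(a)  +  (v_theta - z)^2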
# =======================
# Simple evaluation game
# =======================
def play_game_with_mcts(net: PolicyValueNet, n_simulations: int, device="cpu"):
"""
Watch two MCTS+net players (same weights) play against each other.
"""
net.eval()
game_state = reset_game()
player = 1
while True:
print_game_state(game_state)
moves = get_all_moves(game_state, player)
if not moves:
print(f"No moves left, player {player} loses.")
break
mcts_moves, pi, a_idx = mcts_search(
net, game_state, player, n_simulations, device
)
if mcts_moves is None:
print(f"No moves left (MCTS), player {player} loses.")
break
tidx, tile, placement = mcts_moves[a_idx]
print(f"Player {player} plays tile {tidx} at {placement}")
do_placement(tidx, tile, placement, game_state, player)
player = 2 if player == 1 else 1
# =======================
# Main training loop
# =======================
def main():
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
net = PolicyValueNet().to(device)
optimizer = optim.Adam(net.parameters(), lr=1e-3)
num_games = 200 # increase a lot for real training
n_simulations = 50 # MCTS sims per move (increase if it's too weak)
for g in range(1, num_games + 1):
history, winner = self_play_game(net, n_simulations, device)
loss = train_on_history(net, optimizer, history, device)
print(
f"Game {g}/{num_games}, winner: Player {winner}, loss: {loss:.4f}, positions: {len(history)}"
)
# occasionally watch a game
if g % 50 == 0:
print("Watching a game with current network:")
play_game_with_mcts(net, n_simulations=30, device=device)
# Save final network
torch.save(net.state_dict(), "alphazero_blokus_net.pth")
print("Saved network to alphazero_blokus_net.pth")
if __name__ == "__main__":
main()
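# To reload the saved weights later for play or evaluation (illustrative snippet):
#   net = PolicyValueNet()
#   net.load_state_dict(torch.load("alphazero_blokus_net.pth"))
#   net.eval()
#   play_game_with_mcts(net, n_simulations=100)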