blokus/blokus.py

#!/usr/bin/env python
import random
import sys
import os
import math
from collections import defaultdict
import game
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import trange
import matplotlib.pyplot as plt
####################
# Global constants #
####################
BOARD_SIZE = 14
tiles = game.game_tiles()
TILE_COUNT = len(tiles)
# Maximum number of orientations across all tiles; useful if tile
# orientation is added to the move encoding later.
ORIENT_MAX = max(len(t.permutations()) for t in tiles)
###################
# Utility helpers #
###################
def print_game_state(game_state: tuple[game.Board, list[int], list[int]]):
(board, p1tiles, p2tiles) = game_state
barr = []
for y in range(BOARD_SIZE):
row = []
for x in range(BOARD_SIZE):
row.append(board[(x, y)])
barr.append(row)
print(f" {'--' * BOARD_SIZE} ")
    for row in barr:
        # Each cell is rendered two characters wide so rows line up with the border.
        cells = "".join(
            "  " if v == 0
            else "\033[93m██\033[00m" if v == 1
            else "\033[94m██\033[00m" if v == 2
            else "██"
            for v in row
        )
        print(f"|{cells}|")
print(f" {'--' * BOARD_SIZE} ")
print(f"\033[93mPlayer 1\033[00m tiles left: {p1tiles}")
print(f"\033[94mPlayer 2\033[00m tiles left: {p2tiles}")
def plot_losses(loss_history, out_path="loss_curve.png"):
if not loss_history:
print("No losses to plot.")
return
plt.figure()
plt.plot(range(1, len(loss_history) + 1), loss_history)
plt.xlabel("Training iteration")
plt.ylabel("Loss")
plt.title("BlokuZero training loss")
plt.tight_layout()
plt.savefig(out_path)
plt.close()
print(f"Saved loss plot to {out_path}")
###################
# Game state init #
###################
def initial_game_state():
    return (
        game.Board(),
        list(range(21)),  # tiles for player 1
        list(range(21)),  # tiles for player 2
    )
############
# Encoding #
############
def encode_board(board: game.Board) -> torch.Tensor:
"""
Encode board as (3, H, W) float tensor:
channel 0: player 1 stones
channel 1: player 2 stones
channel 2: value 3 ("S" or special)
"""
arr = torch.zeros((3, BOARD_SIZE, BOARD_SIZE), dtype=torch.float32)
for y in range(BOARD_SIZE):
for x in range(BOARD_SIZE):
v = board[(x, y)]
if v == 1:
arr[0, y, x] = 1.0
elif v == 2:
arr[1, y, x] = 1.0
elif v == 3:
arr[2, y, x] = 1.0
return arr
def encode_tiles(p1tiles, p2tiles) -> torch.Tensor:
"""
Encode which tiles each player still has.
21 tiles per player -> 42-dim vector:
indices 0..20: tile i available for P1?
indices 21..41: tile i available for P2?
"""
v = torch.zeros(42, dtype=torch.float32)
for t in p1tiles:
v[t] = 1.0
offset = 21
for t in p2tiles:
v[offset + t] = 1.0
return v
def encode_state(game_state, player: int) -> torch.Tensor:
"""
Flattened state encoding for the NN:
- board one-hot: 3 * 14 * 14 = 588
- tiles: 42
- current player bit: 1
Total dim = 631.
"""
board, p1tiles, p2tiles = game_state
board_tensor = encode_board(board).flatten() # (588,)
tiles_tensor = encode_tiles(p1tiles, p2tiles) # (42,)
player_bit = torch.tensor([1.0 if player == 1 else 0.0], dtype=torch.float32)
return torch.cat([board_tensor, tiles_tensor, player_bit], dim=0) # (631,)
def encode_move(tile_idx: int, placement: tuple[int, int]) -> torch.Tensor:
"""
Encode a move (ignoring orientation for now):
- tile index one-hot (21)
- normalized x, y
Total dim = 23.
"""
x, y = placement
tile_vec = torch.zeros(21, dtype=torch.float32)
tile_vec[tile_idx] = 1.0
pos_vec = torch.tensor(
[x / (BOARD_SIZE - 1), y / (BOARD_SIZE - 1)],
dtype=torch.float32,
)
return torch.cat([tile_vec, pos_vec], dim=0) # (23,)
STATE_DIM = 588 + 42 + 1 # 631
MOVE_DIM = 23
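# Optional sanity check for the encoding sizes above. A minimal sketch that is
# not called anywhere by default; it assumes the standard 21-piece tile set
# from game.game_tiles() and an arbitrary placement such as (4, 4).
def _check_encoding_dims():
    state = encode_state(initial_game_state(), player=1)
    assert state.shape == (STATE_DIM,)  # 3*14*14 + 42 + 1 = 631
    move = encode_move(0, (4, 4))
    assert move.shape == (MOVE_DIM,)  # 21 + 2 = 23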
##############################
# AlphaZero-style neural net #
##############################
class BlokuZeroNet(nn.Module):
"""
AlphaZero-style network for this Blokus-like game.
Given a state vector (flattened board + tiles + player),
it produces:
- value v in [-1,1] (estimate of who will win)
- a policy over moves *relative to a given move set*, by
combining a state embedding with move features.
"""
def __init__(self, state_emb_dim=256, move_emb_dim=64):
super().__init__()
# State encoder: simple MLP on the 1D state vector
self.state_mlp = nn.Sequential(
nn.Linear(STATE_DIM, 256),
nn.ReLU(),
nn.Linear(256, state_emb_dim),
nn.ReLU(),
)
# Value head: from state embedding -> scalar, tanh in [-1,1]
self.value_head = nn.Sequential(
nn.Linear(state_emb_dim, 64),
nn.ReLU(),
nn.Linear(64, 1),
nn.Tanh(),
)
# Move encoder: encode move features
self.move_mlp = nn.Sequential(
nn.Linear(MOVE_DIM, 64),
nn.ReLU(),
nn.Linear(64, move_emb_dim),
nn.ReLU(),
)
# Policy head: combine state_emb + move_emb -> logit per move
self.policy_head = nn.Sequential(
nn.Linear(state_emb_dim + move_emb_dim, 128),
nn.ReLU(),
nn.Linear(128, 1), # scalar logit per move
)
def forward_state(self, state_vec: torch.Tensor):
"""
state_vec: (STATE_DIM,)
returns: state_emb (state_emb_dim,), value scalar
"""
x = self.state_mlp(state_vec)
v = self.value_head(x).squeeze(-1)
return x, v
def forward_policy(self, state_emb: torch.Tensor, move_feats: torch.Tensor):
"""
state_emb: (state_emb_dim,)
move_feats: (N, MOVE_DIM)
returns: policy_probs: (N,), logits: (N,)
"""
move_emb = self.move_mlp(move_feats) # (N, move_emb_dim)
state_expanded = state_emb.unsqueeze(0).expand(move_emb.size(0), -1)
h = torch.cat(
[state_expanded, move_emb], dim=1
) # (N, state_emb_dim+move_emb_dim)
logits = self.policy_head(h).squeeze(-1) # (N,)
probs = F.softmax(logits, dim=0)
return probs, logits
def net_predict(net: BlokuZeroNet, game_state, player: int, moves):
"""
Evaluate net on a position from the perspective of 'player'.
Returns:
priors: tensor (N,) over given moves (N=len(moves)), or None if no moves
value: float in [-1,1] from player's POV
"""
if not moves:
return None, 0.0
state_vec = encode_state(game_state, player)
    move_feats = torch.stack(
        [encode_move(tidx, placement) for tidx, _perm, placement in moves],
        dim=0,
    )
with torch.no_grad():
state_emb, v = net.forward_state(state_vec)
priors, logits = net.forward_policy(state_emb, move_feats)
return priors, float(v.item())
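# Minimal shape smoke test for the network, using random inputs rather than a
# real position. Purely illustrative; it is not called by the training code.
def _smoke_test_net(num_moves: int = 8):
    test_net = BlokuZeroNet()
    state_vec = torch.randn(STATE_DIM)
    move_feats = torch.randn(num_moves, MOVE_DIM)
    with torch.no_grad():
        state_emb, value = test_net.forward_state(state_vec)
        probs, _logits = test_net.forward_policy(state_emb, move_feats)
    assert probs.shape == (num_moves,) and abs(float(probs.sum()) - 1.0) < 1e-5
    assert -1.0 <= float(value) <= 1.0  # value head is tanh-bounded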
###################
# Move generation #
###################
def get_legal_moves(game_state, player: int):
"""
Returns list of (tile_idx, perm, placement)
"""
board, p1tiles, p2tiles = game_state
gp = game.Player.P1 if player == 1 else game.Player.P2
tiles_left = p1tiles if player == 1 else p2tiles
moves = []
for tile_idx in tiles_left:
tile = tiles[tile_idx]
perms = tile.permutations()
for perm in perms:
plcs = board.tile_placements(perm, gp)
for plc in plcs:
moves.append((tile_idx, perm, plc))
return moves
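# Small illustrative helper (not used by the training loop): shows the move
# format returned above, i.e. (tile_index, permuted_tile, placement) triples.
def _describe_initial_moves(player: int = 1):
    moves = get_legal_moves(initial_game_state(), player)
    print(f"player {player} has {len(moves)} legal opening moves")
    if moves:
        tile_idx, _perm, placement = moves[0]
        print(f"example: tile {tile_idx} at {placement}")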
###############
# MCTS (PUCT) #
###############
class MCTSNode:
def __init__(self, game_state, player: int, parent=None, move_from_parent=None):
self.game_state = game_state # (board, p1tiles, p2tiles)
self.player = player # player to move at this node (1 or 2)
self.parent = parent
self.move_from_parent = (
move_from_parent # index of move from parent to this node
)
self.children = {} # move_index -> child MCTSNode
self.moves = None # list of legal moves at this node
self.P = None # tensor of priors over moves (N,)
self.N = defaultdict(int) # visit count per move_index
self.W = defaultdict(float) # total value per move_index
self.Q = defaultdict(float) # mean value per move_index
self.is_expanded = False
self.is_terminal = False
self.value = 0.0
def expand(self, net: BlokuZeroNet):
"""
Expand this node: generate legal moves, get priors & value from net.
"""
if self.is_expanded:
return
moves = get_legal_moves(self.game_state, self.player)
self.moves = moves
if not moves:
# terminal: current player has no moves -> they lose
self.is_terminal = True
self.P = None
self.value = 0.0
self.is_expanded = True
return
priors, v = net_predict(net, self.game_state, self.player, moves)
self.P = priors
self.value = v # value from this node's player POV
self.is_expanded = True
def select_child(self, c_puct=1.5):
"""
Select move index according to PUCT:
a* = argmax_a (Q(s,a) + c_puct * P(s,a) * sqrt(sum_b N(s,b)) / (1 + N(s,a)))
"""
total_N = sum(self.N[i] for i in range(len(self.moves))) + 1e-8
best_score = -1e9
best_i = None
for i in range(len(self.moves)):
Q = self.Q[i]
P = float(self.P[i])
N_sa = self.N[i]
U = c_puct * P * math.sqrt(total_N) / (1 + N_sa)
score = Q + U
if score > best_score:
best_score = score
best_i = i
return best_i
def simulate_child(self, move_index: int):
"""
Given a move index in self.moves, return or create the child node.
"""
if move_index in self.children:
return self.children[move_index]
tidx, perm, placement = self.moves[move_index]
board, p1tiles, p2tiles = self.game_state
gp = game.Player.P1 if self.player == 1 else game.Player.P2
# Simulate placing the tile on a copy of the board
next_board = board.sim_place(perm, placement, gp) # does not modify original
next_p1tiles = list(p1tiles)
next_p2tiles = list(p2tiles)
if self.player == 1:
next_p1tiles.remove(tidx)
next_player = 2
else:
next_p2tiles.remove(tidx)
next_player = 1
next_state = (next_board, next_p1tiles, next_p2tiles)
child = MCTSNode(
next_state, next_player, parent=self, move_from_parent=move_index
)
self.children[move_index] = child
return child
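# Worked example of the PUCT rule in MCTSNode.select_child, with c_puct = 1.5
# and two candidate moves (values rounded):
#   N = [10, 0], Q = [0.2, 0.0], P = [0.5, 0.5], total_N = 10
#   score_0 = 0.2 + 1.5 * 0.5 * sqrt(10) / (1 + 10) ~= 0.42
#   score_1 = 0.0 + 1.5 * 0.5 * sqrt(10) / (1 + 0)  ~= 2.37
# so the unvisited move gets explored despite its lower Q.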
def mcts_search(
root_state,
root_player: int,
net: BlokuZeroNet,
num_simulations: int = 50,
c_puct: float = 1.5,
temperature: float = 1.0,
):
"""
Run MCTS from (root_state, root_player).
Returns:
pi: tensor of size (num_moves,), MCTS policy over root moves
chosen_move: (tile_idx, perm, placement)
root_moves: list of all legal moves at root
"""
root = MCTSNode(root_state, root_player)
root.expand(net)
if root.is_terminal or root.P is None or not root.moves:
# No moves at root
return None, None, []
# Run simulations
for _ in range(num_simulations):
node = root
path = [node]
# 1. Selection: descend tree until leaf
while node.is_expanded and not node.is_terminal:
move_i = node.select_child(c_puct)
node = node.simulate_child(move_i)
path.append(node)
# 2. Expansion / evaluation of leaf
node.expand(net)
# 3. Compute leaf value from root player's perspective
if node.is_terminal:
# terminal: node.player has no moves -> node.player loses
if root_player == node.player:
v_root = -1.0
else:
v_root = +1.0
else:
# non-terminal: node.value is from node.player's POV
v_node = node.value
v_root = v_node if node.player == root_player else -v_node
# 4. Backup: update Q, N along the path
value = v_root
for i in range(len(path) - 1):
parent = path[i]
child = path[i + 1]
mv_i = child.move_from_parent
parent.N[mv_i] += 1
parent.W[mv_i] += value
parent.Q[mv_i] = parent.W[mv_i] / parent.N[mv_i]
value = -value # flip sign at each ply
# Visits at root
visits = torch.tensor(
[root.N[i] for i in range(len(root.moves))],
dtype=torch.float32,
)
if temperature == 0:
# deterministic: all mass on argmax visit
best = torch.argmax(visits).item()
pi = torch.zeros_like(visits)
pi[best] = 1.0
else:
pi = visits ** (1.0 / temperature)
if pi.sum() > 0:
pi = pi / pi.sum()
else:
pi = torch.ones_like(pi) / pi.numel()
# Choose move to actually play (sample from pi)
move_index = torch.multinomial(pi, num_samples=1).item()
chosen_move = root.moves[move_index]
return pi, chosen_move, root.moves
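# Minimal usage sketch for mcts_search (illustrative only; the simulation
# budget is tiny and the net here is untrained).
def _demo_mcts_from_start(num_simulations: int = 8):
    demo_net = BlokuZeroNet()
    _pi, chosen_move, root_moves = mcts_search(
        initial_game_state(),
        root_player=1,
        net=demo_net,
        num_simulations=num_simulations,
        temperature=1.0,
    )
    if chosen_move is not None:
        tile_idx, _perm, placement = chosen_move
        print(
            f"MCTS picked tile {tile_idx} at {placement} "
            f"out of {len(root_moves)} root moves"
        )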
########################
# Self-play + training #
########################
def self_play_game(
net: BlokuZeroNet,
num_simulations: int = 50,
temperature: float = 1.0,
watch: bool = False,
random_move_prob: float = 0.0,
):
"""
Play one self-play game using MCTS + net.
Returns:
examples: list of (state_tensor, moves, pi, z)
- state_tensor: encoding of state at root of each move
- moves: list of legal moves at that root
- pi: tensor of MCTS policy over moves
- z: final outcome from that state's player POV (+1 or -1)
"""
game_state = initial_game_state()
board, p1tiles, p2tiles = game_state
player = 1
trajectory = [] # list of dicts: {"state": state_vec, "moves": moves, "pi": pi, "player": player}
while True:
moves = get_legal_moves(game_state, player)
if not moves:
# current player has no moves -> they lose
winner = 2 if player == 1 else 1
break
state_vec = encode_state(game_state, player)
if random.random() < random_move_prob:
print("PERFORMING RANDOM MOVE!!")
move_list = moves
num_m = len(move_list)
# uniform policy over legal moves
pi = torch.full((num_m,), 1.0 / num_m, dtype=torch.float32)
idx = random.randrange(num_m)
chosen_move = move_list[idx]
else:
# usual AlphaZero MCTS move
pi, chosen_move, move_list = mcts_search(
game_state,
player,
net,
num_simulations=num_simulations,
c_puct=1.5,
temperature=temperature,
)
if pi is None or chosen_move is None:
winner = 2 if player == 1 else 1
break
trajectory.append(
{
"state": state_vec,
"moves": move_list,
"pi": pi,
"player": player,
}
)
# Apply chosen move to real game
tidx, tile, placement = chosen_move
gp = game.Player.P1 if player == 1 else game.Player.P2
board.place(tile, placement, gp)
if player == 1:
p1tiles.remove(tidx)
else:
p2tiles.remove(tidx)
game_state = (board, p1tiles, p2tiles)
if watch:
print_game_state(game_state)
player = 2 if player == 1 else 1
# Convert trajectory into training examples
examples = []
for step in trajectory:
state_vec = step["state"]
moves = step["moves"]
pi = step["pi"]
p = step["player"]
z = 1.0 if p == winner else -1.0
examples.append((state_vec, moves, pi, z))
return examples
def alpha_zero_train_step(net: BlokuZeroNet, optimizer, batch):
"""
One gradient update on a batch of (state_vec, moves, pi_target, z_target).
"""
if not batch:
return 0.0
total_value_loss = 0.0
total_policy_loss = 0.0
for state_vec, moves, pi_target, z in batch:
z_t = torch.tensor(z, dtype=torch.float32)
# Encode moves
move_feats = []
for tidx, perm, placement in moves:
move_feats.append(encode_move(tidx, placement))
move_feats = torch.stack(move_feats, dim=0)
state_emb, v_pred = net.forward_state(state_vec)
p_pred, logits = net.forward_policy(state_emb, move_feats)
# Value loss
value_loss = (v_pred - z_t) ** 2
# Policy loss: cross-entropy between pi_target and p_pred
pi_target = pi_target.to(p_pred.dtype)
policy_loss = -torch.sum(pi_target * torch.log(p_pred + 1e-8))
total_value_loss = total_value_loss + value_loss
total_policy_loss = total_policy_loss + policy_loss
loss = (total_value_loss + total_policy_loss) / len(batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return float(loss.item())
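# Illustrative helper tying self-play and the update step together. This is a
# sketch only: the simulation budget is far too small for meaningful training.
def _demo_single_update():
    demo_net = BlokuZeroNet()
    demo_opt = torch.optim.Adam(demo_net.parameters(), lr=1e-3)
    examples = self_play_game(demo_net, num_simulations=4)
    loss = alpha_zero_train_step(demo_net, demo_opt, examples[:32])
    print(f"demo loss on {min(len(examples), 32)} examples: {loss:.4f}")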
def train_alpha_zero(
net: BlokuZeroNet,
num_iterations: int = 50,
games_per_iter: int = 5,
num_simulations: int = 50,
batch_size: int = 32,
lr: float = 1e-3,
save_path: str = "az_trained_agent.pt",
watch_selfplay: bool = False,
):
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
loss_history = []
ckpt_path = save_path + ".ckpt"
start_iter = 1
# Try to resume from checkpoint
if os.path.exists(ckpt_path):
ckpt = torch.load(ckpt_path, map_location="cpu")
try:
net.load_state_dict(ckpt["model_state"])
optimizer.load_state_dict(ckpt["optimizer_state"])
start_iter = ckpt["iteration"] + 1
loss_history = ckpt.get("loss_history", [])
print(f"Resuming training from iteration {start_iter} (found checkpoint).")
if start_iter > num_iterations:
print("Checkpoint exceeds requested num_iterations; nothing to train.")
plot_losses(loss_history, out_path="loss_curve.png")
torch.save(net.state_dict(), save_path)
return
except Exception as e:
print("Checkpoint incompatible, starting fresh.")
print("Reason:", e)
pbar = trange(
start_iter, num_iterations + 1, desc="AZ Training", dynamic_ncols=True
)
# simple schedule for random moves in early training
warmup_iters = 50 # how many iterations to use randomness
max_random_prob = 0.5 # random move probability at iteration 1
for it in pbar:
replay_buffer = []
# linearly decay random_move_prob from max_random_prob -> 0 over warmup_iters
if it <= warmup_iters:
random_move_prob = max_random_prob * (1.0 - (it - 1) / warmup_iters)
else:
random_move_prob = 0.0
# 1. Self-play games to generate fresh data
for g in range(games_per_iter):
examples = self_play_game(
net,
num_simulations=num_simulations,
temperature=1.0, # can also anneal later if you like
watch=watch_selfplay,
random_move_prob=random_move_prob,
)
replay_buffer.extend(examples)
if not replay_buffer:
print("No data generated this iteration, something is wrong.")
continue
random.shuffle(replay_buffer)
# 2. Train on this batch of data
batch_losses = []
for i in range(0, len(replay_buffer), batch_size):
batch = replay_buffer[i : i + batch_size]
loss = alpha_zero_train_step(net, optimizer, batch)
batch_losses.append(loss)
mean_loss = sum(batch_losses) / len(batch_losses)
loss_history.append(mean_loss)
pbar.set_postfix(loss=mean_loss, data=len(replay_buffer))
# 3. Save checkpoint
if it % 5 == 0 or it == num_iterations:
torch.save(
{
"iteration": it,
"model_state": net.state_dict(),
"optimizer_state": optimizer.state_dict(),
"loss_history": loss_history,
},
ckpt_path,
)
plot_losses(loss_history, out_path="loss_curve_ckpt.png")
torch.save(net.state_dict(), save_path)
print(f"\nTraining finished. Model saved to {save_path}")
plot_losses(loss_history, out_path="loss_curve.png")
def load_net_from_checkpoint(path: str) -> BlokuZeroNet:
"""
    Load a BlokuZeroNet from either:
- a plain state_dict file (saved with torch.save(net.state_dict(), ...)), or
- a training checkpoint (dict with "model_state" key).
"""
net = BlokuZeroNet()
obj = torch.load(path, map_location="cpu")
# If it's a training checkpoint dict, extract model_state
if isinstance(obj, dict) and "model_state" in obj:
state_dict = obj["model_state"]
else:
# Assume it's already a state_dict
state_dict = obj
net.load_state_dict(state_dict)
net.eval()
return net
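# Example, assuming training has produced the default files named above:
#   net = load_net_from_checkpoint("az_trained_agent.pt")        # plain state_dict
#   net = load_net_from_checkpoint("az_trained_agent.pt.ckpt")   # full checkpoint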
def battle(
checkpoint_a: str,
checkpoint_b: str,
num_games: int = 20,
watch: bool = False,
num_simulations: int = 50,
):
"""
Load two AlphaZero checkpoints and have them battle for num_games.
Alternates which net is player 1 for fairness.
Prints a small win-loss matrix at the end.
"""
print(f"Loading net A from: {checkpoint_a}")
print(f"Loading net B from: {checkpoint_b}")
net_a = load_net_from_checkpoint(checkpoint_a)
net_b = load_net_from_checkpoint(checkpoint_b)
# Matrix counters
# Rows = starting player (A as P1, B as P1)
# Cols = winner (A, B, Draw)
stats = {
"A_P1": {"A": 0, "B": 0, "D": 0},
"B_P1": {"A": 0, "B": 0, "D": 0},
}
for g in range(num_games):
if g % 2 == 0:
# Even games: A as P1, B as P2
start_label = "A_P1"
winner = play_game_between_nets(
net_a,
net_b,
watch=watch,
num_simulations=num_simulations,
)
if winner == 1:
stats[start_label]["A"] += 1
elif winner == 2:
stats[start_label]["B"] += 1
else:
stats[start_label]["D"] += 1
else:
# Odd games: B as P1, A as P2
start_label = "B_P1"
winner = play_game_between_nets(
net_b,
net_a,
watch=watch,
num_simulations=num_simulations,
)
if winner == 1:
stats[start_label]["B"] += 1 # player 1 is B
elif winner == 2:
stats[start_label]["A"] += 1 # player 2 is A
else:
stats[start_label]["D"] += 1
print(f"Game {g + 1}/{num_games} finished: winner = {winner}")
# Aggregate totals
total_a_wins = stats["A_P1"]["A"] + stats["B_P1"]["A"]
total_b_wins = stats["A_P1"]["B"] + stats["B_P1"]["B"]
total_draws = stats["A_P1"]["D"] + stats["B_P1"]["D"]
print("\n=== Battle results ===")
print(f"Total games: {num_games}")
print(f"Model A wins: {total_a_wins}")
print(f"Model B wins: {total_b_wins}")
print(f"Draws: {total_draws}")
print("\nWin-loss matrix (rows = starting player, cols = winner):")
print(" A_win B_win Draw")
print(
f"Start A (P1): {stats['A_P1']['A']:5d} {stats['A_P1']['B']:5d} {stats['A_P1']['D']:5d}"
)
print(
f"Start B (P1): {stats['B_P1']['A']:5d} {stats['B_P1']['B']:5d} {stats['B_P1']['D']:5d}"
)
##################
# Play vs the AI #
##################
def az_choose_move(
net: BlokuZeroNet, game_state, player: int, num_simulations: int = 100
):
"""
Use MCTS with the trained net to choose a move for actual play.
"""
moves = get_legal_moves(game_state, player)
if not moves:
return None
# Temperature 0 for deterministic play
pi, chosen_move, root_moves = mcts_search(
game_state,
player,
net,
num_simulations=num_simulations,
c_puct=1.5,
temperature=0.0,
)
if chosen_move is None:
return None
return chosen_move
def play_game_between_nets(
net_p1: BlokuZeroNet,
net_p2: BlokuZeroNet,
watch: bool = False,
max_turns: int = 500,
num_simulations: int = 50,
) -> int:
"""
Play one game between two AlphaZero nets using MCTS for both.
Returns:
1 if player 1 (net_p1) wins
2 if player 2 (net_p2) wins
0 if draw (max_turns reached)
"""
game_state = initial_game_state()
board, p1tiles, p2tiles = game_state
player = 1
turns = 0
while True:
turns += 1
if turns > max_turns:
# treat as draw
return 0
# Choose which net is playing this turn
net = net_p1 if player == 1 else net_p2
move = az_choose_move(net, game_state, player, num_simulations=num_simulations)
if move is None:
# current player cannot move -> they lose
if player == 1:
return 2
else:
return 1
tidx, tile, placement = move
gp = game.Player.P1 if player == 1 else game.Player.P2
board.place(tile, placement, gp)
if player == 1:
p1tiles.remove(tidx)
else:
p2tiles.remove(tidx)
game_state = (board, p1tiles, p2tiles)
if watch:
print_game_state(game_state)
player = 2 if player == 1 else 1
def play_vs_ai(net: BlokuZeroNet, human_is: int = 1, num_simulations: int = 100):
"""
Let a human play against the AlphaZero-style agent.
human_is: 1 or 2
"""
game_state = initial_game_state()
board, p1tiles, p2tiles = game_state
player = 1
while True:
print_game_state(game_state)
if player == human_is:
# Human move
moves = get_legal_moves(game_state, player)
if not moves:
print(f"No moves left for player {player}")
print("You lost 😢")
break
print(f"Player {player}, you have {len(moves)} possible moves.")
for i, (tidx, perm, plc) in enumerate(moves[:50]):
print(f"[{i}] tile {tidx} at {plc}")
if len(moves) > 50:
print(f"... and {len(moves) - 50} more moves not listed")
while True:
try:
choice = int(input("Enter move index: "))
if 0 <= choice < len(moves):
move = moves[choice]
break
else:
print("Invalid index, try again.")
except ValueError:
print("Please enter an integer.")
else:
# AI move
print(f"AI (player {player}) is thinking...")
move = az_choose_move(
net, game_state, player, num_simulations=num_simulations
)
if move is None:
print(f"No moves left for AI player {player}")
print("AI lost, you win! 🎉")
break
tidx, tile, placement = move
gp = game.Player.P1 if player == 1 else game.Player.P2
print(f"player {player} places tile {tidx} at {placement}\n{tile}")
board.place(tile, placement, gp)
if player == 1:
p1tiles.remove(tidx)
else:
p2tiles.remove(tidx)
game_state = (board, p1tiles, p2tiles)
player = 2 if player == 1 else 1
########
# main #
########
def main():
# Battle mode: --battle ckptA ckptB [--games N] [--watch]
if "--battle" in sys.argv:
idx = sys.argv.index("--battle")
try:
ckpt_a = sys.argv[idx + 1]
ckpt_b = sys.argv[idx + 2]
except IndexError:
print("Usage: blokus.py --battle ckptA ckptB [--games N] [--watch]")
return
num_games = 20
if "--games" in sys.argv:
gidx = sys.argv.index("--games")
try:
num_games = int(sys.argv[gidx + 1])
except (IndexError, ValueError):
print("Invalid or missing value for --games, using default 20.")
watch = "--watch" in sys.argv
battle(ckpt_a, ckpt_b, num_games=num_games, watch=watch)
return
net = BlokuZeroNet()
if "--play" in sys.argv:
model_path = "az_trained_agent.pt"
if os.path.exists(model_path):
print(f"Loading model from {model_path}")
state = torch.load(model_path, map_location="cpu")
try:
net.load_state_dict(state)
net.eval()
except Exception as e:
print("Saved model incompatible with current net, playing untrained.")
print("Reason:", e)
else:
print(
"Warning: az_trained_agent.pt not found. Playing with an untrained model."
)
# By default, human is player 1; change to 2 if you want
play_vs_ai(net, human_is=1, num_simulations=100)
else:
# AlphaZero-style training by self-play
train_alpha_zero(
net,
num_iterations=250, # number of training iterations
games_per_iter=5, # self-play games per iteration
num_simulations=50, # MCTS simulations per move during self-play
batch_size=32,
lr=1e-3,
save_path="az_trained_agent.pt",
watch_selfplay="--watch" in sys.argv,
)
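# Command-line usage, as parsed in main() above (run from the directory
# containing game.py):
#   python blokus.py                                 # train via self-play (default)
#   python blokus.py --watch                         # train and print boards during self-play
#   python blokus.py --play                          # play against the saved model
#   python blokus.py --battle A.pt B.pt --games 10 --watch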
if __name__ == "__main__":
main()