#!/usr/bin/env python
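#
# Usage (all three modes are dispatched in main() at the bottom of this file;
# the script name "blokus.py" is the one main()'s own usage message uses):
#
#   python blokus.py                      # AlphaZero-style self-play training
#   python blokus.py --play               # play against the saved agent
#   python blokus.py --battle A.pt B.pt   # pit two checkpoints against each other
#                                         # optional flags: --games N, --watch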
import random
import sys
import os
import math
from collections import defaultdict

import game

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import trange
import matplotlib.pyplot as plt

####################
# Global constants #
####################

BOARD_SIZE = 14
tiles = game.game_tiles()
TILE_COUNT = len(tiles)

# We need a max orientation count to encode orientation if desired.
# If you want to use orientations later, this is handy.
ORIENT_MAX = max(len(t.permutations()) for t in tiles)

###################
# Utility helpers #
###################


def print_game_state(game_state: tuple[game.Board, list[int], list[int]]):
    (board, p1tiles, p2tiles) = game_state

    barr = []
    for y in range(BOARD_SIZE):
        row = []
        for x in range(BOARD_SIZE):
            row.append(board[(x, y)])
        barr.append(row)

    # 0 = empty, 1 = P1 (yellow), 2 = P2 (blue); anything else renders as a plain block
    cell_chars = {
        0: "  ",
        1: "\033[93m██\033[00m",
        2: "\033[94m██\033[00m",
    }

    print(f" {'--' * BOARD_SIZE} ")
    for row in barr:
        rendered = "".join(cell_chars.get(x, "██") for x in row)
        print(f"|{rendered}|")
    print(f" {'--' * BOARD_SIZE} ")

    print(f"\033[93mPlayer 1\033[00m tiles left: {p1tiles}")
    print(f"\033[94mPlayer 2\033[00m tiles left: {p2tiles}")


def plot_losses(loss_history, out_path="loss_curve.png"):
    if not loss_history:
        print("No losses to plot.")
        return

    plt.figure()
    plt.plot(range(1, len(loss_history) + 1), loss_history)
    plt.xlabel("Training iteration")
    plt.ylabel("Loss")
    plt.title("BlokuZero training loss")
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()
    print(f"Saved loss plot to {out_path}")


###################
# Game state init #
###################


def initial_game_state():
    return (
        game.Board(),
        list(range(21)),  # tiles for player 1
        list(range(21)),  # tiles for player 2
    )


############
# Encoding #
############


def encode_board(board: game.Board) -> torch.Tensor:
    """
    Encode board as (3, H, W) float tensor:
      channel 0: player 1 stones
      channel 1: player 2 stones
      channel 2: value 3 ("S" or special)
    """
    arr = torch.zeros((3, BOARD_SIZE, BOARD_SIZE), dtype=torch.float32)
    for y in range(BOARD_SIZE):
        for x in range(BOARD_SIZE):
            v = board[(x, y)]
            if v == 1:
                arr[0, y, x] = 1.0
            elif v == 2:
                arr[1, y, x] = 1.0
            elif v == 3:
                arr[2, y, x] = 1.0
    return arr


def encode_tiles(p1tiles, p2tiles) -> torch.Tensor:
    """
    Encode which tiles each player still has.
    21 tiles per player -> 42-dim vector:
      indices 0..20: tile i available for P1?
      indices 21..41: tile i available for P2?
    """
    v = torch.zeros(42, dtype=torch.float32)
    for t in p1tiles:
        v[t] = 1.0
    offset = 21
    for t in p2tiles:
        v[offset + t] = 1.0
    return v


def encode_state(game_state, player: int) -> torch.Tensor:
    """
    Flattened state encoding for the NN:
      - board one-hot: 3 * 14 * 14 = 588
      - tiles: 42
      - current player bit: 1
    Total dim = 631.
    """
    board, p1tiles, p2tiles = game_state
    board_tensor = encode_board(board).flatten()  # (588,)
    tiles_tensor = encode_tiles(p1tiles, p2tiles)  # (42,)
    player_bit = torch.tensor([1.0 if player == 1 else 0.0], dtype=torch.float32)
    return torch.cat([board_tensor, tiles_tensor, player_bit], dim=0)  # (631,)


def encode_move(tile_idx: int, placement: tuple[int, int]) -> torch.Tensor:
    """
    Encode a move (ignoring orientation for now):
      - tile index one-hot (21)
      - normalized x, y
    Total dim = 23.
    """
    x, y = placement
    tile_vec = torch.zeros(21, dtype=torch.float32)
    tile_vec[tile_idx] = 1.0
    pos_vec = torch.tensor(
        [x / (BOARD_SIZE - 1), y / (BOARD_SIZE - 1)],
        dtype=torch.float32,
    )
    return torch.cat([tile_vec, pos_vec], dim=0)  # (23,)


STATE_DIM = 588 + 42 + 1  # 631
MOVE_DIM = 23
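
# Quick shape sanity check (a sketch, not part of the training flow; it assumes
# game.Board() gives the empty 14x14 starting board used above):
#
#     s = encode_state(initial_game_state(), player=1)
#     assert s.shape == (STATE_DIM,)  # 588 board cells + 42 tile bits + 1 player bit
#     m = encode_move(0, (0, 0))
#     assert m.shape == (MOVE_DIM,)   # 21 tile one-hot + 2 normalized coordinates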


##############################
# AlphaZero-style neural net #
##############################


class BlokuZeroNet(nn.Module):
    """
    AlphaZero-style network for this Blokus-like game.

    Given a state vector (flattened board + tiles + player),
    it produces:
      - value v in [-1, 1] (estimate of who will win)
      - a policy over moves *relative to a given move set*, by
        combining a state embedding with move features.
    """

    def __init__(self, state_emb_dim=256, move_emb_dim=64):
        super().__init__()

        # State encoder: simple MLP on the 1D state vector
        self.state_mlp = nn.Sequential(
            nn.Linear(STATE_DIM, 256),
            nn.ReLU(),
            nn.Linear(256, state_emb_dim),
            nn.ReLU(),
        )

        # Value head: from state embedding -> scalar, tanh in [-1, 1]
        self.value_head = nn.Sequential(
            nn.Linear(state_emb_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Tanh(),
        )

        # Move encoder: encode move features
        self.move_mlp = nn.Sequential(
            nn.Linear(MOVE_DIM, 64),
            nn.ReLU(),
            nn.Linear(64, move_emb_dim),
            nn.ReLU(),
        )

        # Policy head: combine state_emb + move_emb -> logit per move
        self.policy_head = nn.Sequential(
            nn.Linear(state_emb_dim + move_emb_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),  # scalar logit per move
        )

    def forward_state(self, state_vec: torch.Tensor):
        """
        state_vec: (STATE_DIM,)
        returns: state_emb (state_emb_dim,), value scalar
        """
        x = self.state_mlp(state_vec)
        v = self.value_head(x).squeeze(-1)
        return x, v

    def forward_policy(self, state_emb: torch.Tensor, move_feats: torch.Tensor):
        """
        state_emb: (state_emb_dim,)
        move_feats: (N, MOVE_DIM)
        returns: policy_probs: (N,), logits: (N,)
        """
        move_emb = self.move_mlp(move_feats)  # (N, move_emb_dim)
        state_expanded = state_emb.unsqueeze(0).expand(move_emb.size(0), -1)
        h = torch.cat(
            [state_expanded, move_emb], dim=1
        )  # (N, state_emb_dim + move_emb_dim)
        logits = self.policy_head(h).squeeze(-1)  # (N,)
        probs = F.softmax(logits, dim=0)
        return probs, logits


def net_predict(net: BlokuZeroNet, game_state, player: int, moves):
    """
    Evaluate net on a position from the perspective of 'player'.

    Returns:
      priors: tensor (N,) over given moves (N = len(moves)), or None if no moves
      value: float in [-1, 1] from player's POV
    """
    if not moves:
        return None, 0.0

    state_vec = encode_state(game_state, player)
    move_feats = []
    for tidx, perm, placement in moves:
        mf = encode_move(tidx, placement)
        move_feats.append(mf)

    move_feats = torch.stack(move_feats, dim=0)

    with torch.no_grad():
        state_emb, v = net.forward_state(state_vec)
        priors, logits = net.forward_policy(state_emb, move_feats)

    return priors, float(v.item())
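
# Rough usage sketch (illustrative only; assumes an untrained net and that
# get_legal_moves, defined below, finds at least one opening move):
#
#     net = BlokuZeroNet()
#     state = initial_game_state()
#     moves = get_legal_moves(state, player=1)
#     priors, value = net_predict(net, state, 1, moves)
#     # priors has shape (len(moves),) and sums to 1; value lies in [-1, 1].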


###################
# Move generation #
###################


def get_legal_moves(game_state, player: int):
    """
    Returns list of (tile_idx, perm, placement)
    """
    board, p1tiles, p2tiles = game_state
    gp = game.Player.P1 if player == 1 else game.Player.P2

    tiles_left = p1tiles if player == 1 else p2tiles

    moves = []
    for tile_idx in tiles_left:
        tile = tiles[tile_idx]
        perms = tile.permutations()
        for perm in perms:
            plcs = board.tile_placements(perm, gp)
            for plc in plcs:
                moves.append((tile_idx, perm, plc))
    return moves


###############
# MCTS (PUCT) #
###############


class MCTSNode:
    def __init__(self, game_state, player: int, parent=None, move_from_parent=None):
        self.game_state = game_state  # (board, p1tiles, p2tiles)
        self.player = player  # player to move at this node (1 or 2)
        self.parent = parent
        self.move_from_parent = (
            move_from_parent  # index of move from parent to this node
        )

        self.children = {}  # move_index -> child MCTSNode
        self.moves = None  # list of legal moves at this node
        self.P = None  # tensor of priors over moves (N,)
        self.N = defaultdict(int)  # visit count per move_index
        self.W = defaultdict(float)  # total value per move_index
        self.Q = defaultdict(float)  # mean value per move_index

        self.is_expanded = False
        self.is_terminal = False
        self.value = 0.0

    def expand(self, net: BlokuZeroNet):
        """
        Expand this node: generate legal moves, get priors & value from net.
        """
        if self.is_expanded:
            return

        moves = get_legal_moves(self.game_state, self.player)
        self.moves = moves

        if not moves:
            # terminal: current player has no moves -> they lose
            self.is_terminal = True
            self.P = None
            self.value = 0.0
            self.is_expanded = True
            return

        priors, v = net_predict(net, self.game_state, self.player, moves)
        self.P = priors
        self.value = v  # value from this node's player POV
        self.is_expanded = True

    def select_child(self, c_puct=1.5):
        """
        Select move index according to PUCT:
          a* = argmax_a (Q(s,a) + c_puct * P(s,a) * sqrt(sum_b N(s,b)) / (1 + N(s,a)))
        """
        total_N = sum(self.N[i] for i in range(len(self.moves))) + 1e-8
        best_score = -1e9
        best_i = None
        for i in range(len(self.moves)):
            Q = self.Q[i]
            P = float(self.P[i])
            N_sa = self.N[i]
            U = c_puct * P * math.sqrt(total_N) / (1 + N_sa)
            score = Q + U
            if score > best_score:
                best_score = score
                best_i = i
        return best_i

    def simulate_child(self, move_index: int):
        """
        Given a move index in self.moves, return or create the child node.
        """
        if move_index in self.children:
            return self.children[move_index]

        tidx, perm, placement = self.moves[move_index]

        board, p1tiles, p2tiles = self.game_state
        gp = game.Player.P1 if self.player == 1 else game.Player.P2

        # Simulate placing the tile on a copy of the board
        next_board = board.sim_place(perm, placement, gp)  # does not modify original

        next_p1tiles = list(p1tiles)
        next_p2tiles = list(p2tiles)
        if self.player == 1:
            next_p1tiles.remove(tidx)
            next_player = 2
        else:
            next_p2tiles.remove(tidx)
            next_player = 1

        next_state = (next_board, next_p1tiles, next_p2tiles)
        child = MCTSNode(
            next_state, next_player, parent=self, move_from_parent=move_index
        )
        self.children[move_index] = child
        return child


def mcts_search(
    root_state,
    root_player: int,
    net: BlokuZeroNet,
    num_simulations: int = 50,
    c_puct: float = 1.5,
    temperature: float = 1.0,
):
    """
    Run MCTS from (root_state, root_player).

    Returns:
      pi: tensor of size (num_moves,), MCTS policy over root moves
      chosen_move: (tile_idx, perm, placement)
      root_moves: list of all legal moves at root
    """
    root = MCTSNode(root_state, root_player)
    root.expand(net)

    if root.is_terminal or root.P is None or not root.moves:
        # No moves at root
        return None, None, []

    # Run simulations
    for _ in range(num_simulations):
        node = root
        path = [node]

        # 1. Selection: descend tree until leaf
        while node.is_expanded and not node.is_terminal:
            move_i = node.select_child(c_puct)
            node = node.simulate_child(move_i)
            path.append(node)

        # 2. Expansion / evaluation of leaf
        node.expand(net)

        # 3. Compute leaf value from root player's perspective
        if node.is_terminal:
            # terminal: node.player has no moves -> node.player loses
            if root_player == node.player:
                v_root = -1.0
            else:
                v_root = +1.0
        else:
            # non-terminal: node.value is from node.player's POV
            v_node = node.value
            v_root = v_node if node.player == root_player else -v_node

        # 4. Backup: update Q, N along the path
        value = v_root
        for i in range(len(path) - 1):
            parent = path[i]
            child = path[i + 1]
            mv_i = child.move_from_parent
            parent.N[mv_i] += 1
            parent.W[mv_i] += value
            parent.Q[mv_i] = parent.W[mv_i] / parent.N[mv_i]
            value = -value  # flip sign at each ply

    # Visits at root
    visits = torch.tensor(
        [root.N[i] for i in range(len(root.moves))],
        dtype=torch.float32,
    )

    if temperature == 0:
        # deterministic: all mass on argmax visit
        best = torch.argmax(visits).item()
        pi = torch.zeros_like(visits)
        pi[best] = 1.0
    else:
        pi = visits ** (1.0 / temperature)
        if pi.sum() > 0:
            pi = pi / pi.sum()
        else:
            pi = torch.ones_like(pi) / pi.numel()

    # Choose move to actually play (sample from pi)
    move_index = torch.multinomial(pi, num_samples=1).item()
    chosen_move = root.moves[move_index]

    return pi, chosen_move, root.moves
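
# Illustrative call (with an untrained net the visit counts mostly follow the
# near-uniform priors, so treat this as a smoke test rather than strong play):
#
#     pi, move, moves = mcts_search(initial_game_state(), 1, BlokuZeroNet(),
#                                   num_simulations=25)
#     # pi[i] is proportional to N(root, moves[i]) ** (1 / temperature),
#     # and `move` is sampled from pi.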


########################
# Self-play + training #
########################


def self_play_game(
    net: BlokuZeroNet,
    num_simulations: int = 50,
    temperature: float = 1.0,
    watch: bool = False,
    random_move_prob: float = 0.0,
):
    """
    Play one self-play game using MCTS + net.

    Returns:
      examples: list of (state_tensor, moves, pi, z)
        - state_tensor: encoding of state at root of each move
        - moves: list of legal moves at that root
        - pi: tensor of MCTS policy over moves
        - z: final outcome from that state's player POV (+1 or -1)
    """
    game_state = initial_game_state()
    board, p1tiles, p2tiles = game_state
    player = 1

    # list of dicts: {"state": state_vec, "moves": moves, "pi": pi, "player": player}
    trajectory = []

    while True:
        moves = get_legal_moves(game_state, player)
        if not moves:
            # current player has no moves -> they lose
            winner = 2 if player == 1 else 1
            break

        state_vec = encode_state(game_state, player)

        if random.random() < random_move_prob:
            print("PERFORMING RANDOM MOVE!!")
            move_list = moves
            num_m = len(move_list)
            # uniform policy over legal moves
            pi = torch.full((num_m,), 1.0 / num_m, dtype=torch.float32)
            idx = random.randrange(num_m)
            chosen_move = move_list[idx]
        else:
            # usual AlphaZero MCTS move
            pi, chosen_move, move_list = mcts_search(
                game_state,
                player,
                net,
                num_simulations=num_simulations,
                c_puct=1.5,
                temperature=temperature,
            )

        if pi is None or chosen_move is None:
            winner = 2 if player == 1 else 1
            break

        trajectory.append(
            {
                "state": state_vec,
                "moves": move_list,
                "pi": pi,
                "player": player,
            }
        )

        # Apply chosen move to real game
        tidx, perm, placement = chosen_move
        gp = game.Player.P1 if player == 1 else game.Player.P2

        board.place(perm, placement, gp)
        if player == 1:
            p1tiles.remove(tidx)
        else:
            p2tiles.remove(tidx)

        game_state = (board, p1tiles, p2tiles)
        if watch:
            print_game_state(game_state)

        player = 2 if player == 1 else 1

    # Convert trajectory into training examples
    examples = []
    for step in trajectory:
        state_vec = step["state"]
        moves = step["moves"]
        pi = step["pi"]
        p = step["player"]
        z = 1.0 if p == winner else -1.0
        examples.append((state_vec, moves, pi, z))

    return examples


def alpha_zero_train_step(net: BlokuZeroNet, optimizer, batch):
    """
    One gradient update on a batch of (state_vec, moves, pi_target, z_target).
    """
    if not batch:
        return 0.0

    total_value_loss = 0.0
    total_policy_loss = 0.0

    for state_vec, moves, pi_target, z in batch:
        z_t = torch.tensor(z, dtype=torch.float32)

        # Encode moves
        move_feats = []
        for tidx, perm, placement in moves:
            move_feats.append(encode_move(tidx, placement))
        move_feats = torch.stack(move_feats, dim=0)

        state_emb, v_pred = net.forward_state(state_vec)
        p_pred, logits = net.forward_policy(state_emb, move_feats)

        # Value loss
        value_loss = (v_pred - z_t) ** 2

        # Policy loss: cross-entropy between pi_target and p_pred
        pi_target = pi_target.to(p_pred.dtype)
        policy_loss = -torch.sum(pi_target * torch.log(p_pred + 1e-8))

        total_value_loss = total_value_loss + value_loss
        total_policy_loss = total_policy_loss + policy_loss
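
    # Mean loss over the batch follows the AlphaZero objective, minus the L2
    # regularization term (the Adam optimizer below is created without weight
    # decay): per example, (z - v)^2 - sum_a pi(a) * log p(a).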
    loss = (total_value_loss + total_policy_loss) / len(batch)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return float(loss.item())


def train_alpha_zero(
    net: BlokuZeroNet,
    num_iterations: int = 50,
    games_per_iter: int = 5,
    num_simulations: int = 50,
    batch_size: int = 32,
    lr: float = 1e-3,
    save_path: str = "az_trained_agent.pt",
    watch_selfplay: bool = False,
):
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    loss_history = []
    ckpt_path = save_path + ".ckpt"
    start_iter = 1

    # Try to resume from checkpoint
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location="cpu")
        try:
            net.load_state_dict(ckpt["model_state"])
            optimizer.load_state_dict(ckpt["optimizer_state"])
            start_iter = ckpt["iteration"] + 1
            loss_history = ckpt.get("loss_history", [])
            print(f"Resuming training from iteration {start_iter} (found checkpoint).")
            if start_iter > num_iterations:
                print("Checkpoint exceeds requested num_iterations; nothing to train.")
                plot_losses(loss_history, out_path="loss_curve.png")
                torch.save(net.state_dict(), save_path)
                return
        except Exception as e:
            print("Checkpoint incompatible, starting fresh.")
            print("Reason:", e)

    pbar = trange(
        start_iter, num_iterations + 1, desc="AZ Training", dynamic_ncols=True
    )

    # simple schedule for random moves in early training
    warmup_iters = 50  # how many iterations to use randomness
    max_random_prob = 0.5  # random move probability at iteration 1

    for it in pbar:
        replay_buffer = []

        # linearly decay random_move_prob from max_random_prob -> 0 over warmup_iters
        if it <= warmup_iters:
            random_move_prob = max_random_prob * (1.0 - (it - 1) / warmup_iters)
        else:
            random_move_prob = 0.0

        # 1. Self-play games to generate fresh data
        for g in range(games_per_iter):
            examples = self_play_game(
                net,
                num_simulations=num_simulations,
                temperature=1.0,  # can also anneal later if you like
                watch=watch_selfplay,
                random_move_prob=random_move_prob,
            )
            replay_buffer.extend(examples)

        if not replay_buffer:
            print("No data generated this iteration, something is wrong.")
            continue

        random.shuffle(replay_buffer)

        # 2. Train on this batch of data
        batch_losses = []
        for i in range(0, len(replay_buffer), batch_size):
            batch = replay_buffer[i : i + batch_size]
            loss = alpha_zero_train_step(net, optimizer, batch)
            batch_losses.append(loss)

        mean_loss = sum(batch_losses) / len(batch_losses)
        loss_history.append(mean_loss)
        pbar.set_postfix(loss=mean_loss, data=len(replay_buffer))

        # 3. Save checkpoint
        if it % 5 == 0 or it == num_iterations:
            torch.save(
                {
                    "iteration": it,
                    "model_state": net.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "loss_history": loss_history,
                },
                ckpt_path,
            )
            plot_losses(loss_history, out_path="loss_curve_ckpt.png")

    torch.save(net.state_dict(), save_path)
    print(f"\nTraining finished. Model saved to {save_path}")
    plot_losses(loss_history, out_path="loss_curve.png")


def load_net_from_checkpoint(path: str) -> BlokuZeroNet:
    """
    Load a BlokuZeroNet from either:
      - a plain state_dict file (saved with torch.save(net.state_dict(), ...)), or
      - a training checkpoint (dict with "model_state" key).
    """
    net = BlokuZeroNet()
    obj = torch.load(path, map_location="cpu")

    # If it's a training checkpoint dict, extract model_state
    if isinstance(obj, dict) and "model_state" in obj:
        state_dict = obj["model_state"]
    else:
        # Assume it's already a state_dict
        state_dict = obj

    net.load_state_dict(state_dict)
    net.eval()
    return net


def battle(
    checkpoint_a: str,
    checkpoint_b: str,
    num_games: int = 20,
    watch: bool = False,
    num_simulations: int = 50,
):
    """
    Load two AlphaZero checkpoints and have them battle for num_games.
    Alternates which net is player 1 for fairness.
    Prints a small win-loss matrix at the end.
    """
    print(f"Loading net A from: {checkpoint_a}")
    print(f"Loading net B from: {checkpoint_b}")

    net_a = load_net_from_checkpoint(checkpoint_a)
    net_b = load_net_from_checkpoint(checkpoint_b)

    # Matrix counters
    # Rows = starting player (A as P1, B as P1)
    # Cols = winner (A, B, Draw)
    stats = {
        "A_P1": {"A": 0, "B": 0, "D": 0},
        "B_P1": {"A": 0, "B": 0, "D": 0},
    }

    for g in range(num_games):
        if g % 2 == 0:
            # Even games: A as P1, B as P2
            start_label = "A_P1"
            winner = play_game_between_nets(
                net_a,
                net_b,
                watch=watch,
                num_simulations=num_simulations,
            )
            if winner == 1:
                stats[start_label]["A"] += 1
            elif winner == 2:
                stats[start_label]["B"] += 1
            else:
                stats[start_label]["D"] += 1
        else:
            # Odd games: B as P1, A as P2
            start_label = "B_P1"
            winner = play_game_between_nets(
                net_b,
                net_a,
                watch=watch,
                num_simulations=num_simulations,
            )
            if winner == 1:
                stats[start_label]["B"] += 1  # player 1 is B
            elif winner == 2:
                stats[start_label]["A"] += 1  # player 2 is A
            else:
                stats[start_label]["D"] += 1

        print(f"Game {g + 1}/{num_games} finished: winner = {winner}")

    # Aggregate totals
    total_a_wins = stats["A_P1"]["A"] + stats["B_P1"]["A"]
    total_b_wins = stats["A_P1"]["B"] + stats["B_P1"]["B"]
    total_draws = stats["A_P1"]["D"] + stats["B_P1"]["D"]

    print("\n=== Battle results ===")
    print(f"Total games: {num_games}")
    print(f"Model A wins: {total_a_wins}")
    print(f"Model B wins: {total_b_wins}")
    print(f"Draws: {total_draws}")

    print("\nWin-loss matrix (rows = starting player, cols = winner):")
    print("              A_win B_win  Draw")
    print(
        f"Start A (P1): {stats['A_P1']['A']:5d} {stats['A_P1']['B']:5d} {stats['A_P1']['D']:5d}"
    )
    print(
        f"Start B (P1): {stats['B_P1']['A']:5d} {stats['B_P1']['B']:5d} {stats['B_P1']['D']:5d}"
    )


##################
# Play vs the AI #
##################


def az_choose_move(
    net: BlokuZeroNet, game_state, player: int, num_simulations: int = 100
):
    """
    Use MCTS with the trained net to choose a move for actual play.
    """
    moves = get_legal_moves(game_state, player)
    if not moves:
        return None

    # Temperature 0 for deterministic play
    pi, chosen_move, root_moves = mcts_search(
        game_state,
        player,
        net,
        num_simulations=num_simulations,
        c_puct=1.5,
        temperature=0.0,
    )
    if chosen_move is None:
        return None
    return chosen_move


def play_game_between_nets(
    net_p1: BlokuZeroNet,
    net_p2: BlokuZeroNet,
    watch: bool = False,
    max_turns: int = 500,
    num_simulations: int = 50,
) -> int:
    """
    Play one game between two AlphaZero nets using MCTS for both.
    Returns:
      1 if player 1 (net_p1) wins
      2 if player 2 (net_p2) wins
      0 if draw (max_turns reached)
    """
    game_state = initial_game_state()
    board, p1tiles, p2tiles = game_state

    player = 1
    turns = 0

    while True:
        turns += 1
        if turns > max_turns:
            # treat as draw
            return 0

        # Choose which net is playing this turn
        net = net_p1 if player == 1 else net_p2
        move = az_choose_move(net, game_state, player, num_simulations=num_simulations)

        if move is None:
            # current player cannot move -> they lose
            if player == 1:
                return 2
            else:
                return 1

        tidx, perm, placement = move
        gp = game.Player.P1 if player == 1 else game.Player.P2

        board.place(perm, placement, gp)
        if player == 1:
            p1tiles.remove(tidx)
        else:
            p2tiles.remove(tidx)

        game_state = (board, p1tiles, p2tiles)

        if watch:
            print_game_state(game_state)

        player = 2 if player == 1 else 1


def play_vs_ai(net: BlokuZeroNet, human_is: int = 1, num_simulations: int = 100):
    """
    Let a human play against the AlphaZero-style agent.
    human_is: 1 or 2
    """
    game_state = initial_game_state()
    board, p1tiles, p2tiles = game_state

    player = 1
    while True:
        print_game_state(game_state)

        if player == human_is:
            # Human move
            moves = get_legal_moves(game_state, player)
            if not moves:
                print(f"No moves left for player {player}")
                print("You lost 😢")
                break

            print(f"Player {player}, you have {len(moves)} possible moves.")
            for i, (tidx, perm, plc) in enumerate(moves[:50]):
                print(f"[{i}] tile {tidx} at {plc}")
            if len(moves) > 50:
                print(f"... and {len(moves) - 50} more moves not listed")

            while True:
                try:
                    choice = int(input("Enter move index: "))
                    if 0 <= choice < len(moves):
                        move = moves[choice]
                        break
                    else:
                        print("Invalid index, try again.")
                except ValueError:
                    print("Please enter an integer.")
        else:
            # AI move
            print(f"AI (player {player}) is thinking...")
            move = az_choose_move(
                net, game_state, player, num_simulations=num_simulations
            )
            if move is None:
                print(f"No moves left for AI player {player}")
                print("AI lost, you win! 🎉")
                break

        tidx, perm, placement = move
        gp = game.Player.P1 if player == 1 else game.Player.P2
        print(f"player {player} places tile {tidx} at {placement}\n{perm}")

        board.place(perm, placement, gp)
        if player == 1:
            p1tiles.remove(tidx)
        else:
            p2tiles.remove(tidx)

        game_state = (board, p1tiles, p2tiles)
        player = 2 if player == 1 else 1


############
#   main   #
############


def main():
    # Battle mode: --battle ckptA ckptB [--games N] [--watch]
    if "--battle" in sys.argv:
        idx = sys.argv.index("--battle")
        try:
            ckpt_a = sys.argv[idx + 1]
            ckpt_b = sys.argv[idx + 2]
        except IndexError:
            print("Usage: blokus.py --battle ckptA ckptB [--games N] [--watch]")
            return

        num_games = 20
        if "--games" in sys.argv:
            gidx = sys.argv.index("--games")
            try:
                num_games = int(sys.argv[gidx + 1])
            except (IndexError, ValueError):
                print("Invalid or missing value for --games, using default 20.")

        watch = "--watch" in sys.argv

        battle(ckpt_a, ckpt_b, num_games=num_games, watch=watch)
        return

    net = BlokuZeroNet()

    if "--play" in sys.argv:
        model_path = "az_trained_agent.pt"
        if os.path.exists(model_path):
            print(f"Loading model from {model_path}")
            state = torch.load(model_path, map_location="cpu")
            try:
                net.load_state_dict(state)
                net.eval()
            except Exception as e:
                print("Saved model incompatible with current net, playing untrained.")
                print("Reason:", e)
        else:
            print(
                "Warning: az_trained_agent.pt not found. Playing with an untrained model."
            )

        # By default, human is player 1; change to 2 if you want
        play_vs_ai(net, human_is=1, num_simulations=100)

    else:
        # AlphaZero-style training by self-play
        train_alpha_zero(
            net,
            num_iterations=250,  # number of training iterations
            games_per_iter=5,  # self-play games per iteration
            num_simulations=50,  # MCTS simulations per move during self-play
            batch_size=32,
            lr=1e-3,
            save_path="az_trained_agent.pt",
            watch_selfplay="--watch" in sys.argv,
        )


if __name__ == "__main__":
    main()