add battles, random move injection and rename to BlokuZero

Noa Aarts 2025-12-08 15:38:10 +01:00
parent a7f9e68939
commit e405a8749e
Signed by: noa
GPG key ID: 1850932741EFF672

blokus.py: 226 changed lines

@@ -73,7 +73,7 @@ def plot_losses(loss_history, out_path="loss_curve.png"):
plt.plot(range(1, len(loss_history) + 1), loss_history)
plt.xlabel("Training iteration")
plt.ylabel("Loss")
plt.title("AlphaZero training loss")
plt.title("BlokuZero training loss")
plt.tight_layout()
plt.savefig(out_path)
plt.close()
@@ -175,7 +175,7 @@ MOVE_DIM = 23
##############################
class AlphaZeroNet(nn.Module):
class BlokuZeroNet(nn.Module):
"""
AlphaZero-style network for this Blokus-like game.
@@ -245,7 +245,7 @@ class AlphaZeroNet(nn.Module):
return probs, logits
def net_predict(net: AlphaZeroNet, game_state, player: int, moves):
def net_predict(net: BlokuZeroNet, game_state, player: int, moves):
"""
Evaluate net on a position from the perspective of 'player'.
@@ -322,7 +322,7 @@ class MCTSNode:
self.is_terminal = False
self.value = 0.0
def expand(self, net: AlphaZeroNet):
def expand(self, net: BlokuZeroNet):
"""
Expand this node: generate legal moves, get priors & value from net.
"""
@@ -399,7 +399,7 @@ class MCTSNode:
def mcts_search(
root_state,
root_player: int,
net: AlphaZeroNet,
net: BlokuZeroNet,
num_simulations: int = 50,
c_puct: float = 1.5,
temperature: float = 1.0,
@@ -487,10 +487,11 @@ def mcts_search(
def self_play_game(
net: AlphaZeroNet,
net: BlokuZeroNet,
num_simulations: int = 50,
temperature: float = 1.0,
watch: bool = False,
random_move_prob: float = 0.0,
):
"""
Play one self-play game using MCTS + net.
@@ -516,6 +517,17 @@ def self_play_game(
break
state_vec = encode_state(game_state, player)
if random.random() < random_move_prob:
print("PERFORMING RANDOM MOVE!!")
move_list = moves
num_m = len(move_list)
# uniform policy over legal moves
pi = torch.full((num_m,), 1.0 / num_m, dtype=torch.float32)
idx = random.randrange(num_m)
chosen_move = move_list[idx]
else:
# usual AlphaZero MCTS move
pi, chosen_move, move_list = mcts_search(
game_state,
player,
@@ -567,7 +579,7 @@ def self_play_game(
return examples
def alpha_zero_train_step(net: AlphaZeroNet, optimizer, batch):
def alpha_zero_train_step(net: BlokuZeroNet, optimizer, batch):
"""
One gradient update on a batch of (state_vec, moves, pi_target, z_target).
"""
@@ -610,7 +622,7 @@ def alpha_zero_train_step(net: AlphaZeroNet, optimizer, batch):
def train_alpha_zero(
net: AlphaZeroNet,
net: BlokuZeroNet,
num_iterations: int = 50,
games_per_iter: int = 5,
num_simulations: int = 50,
@@ -647,16 +659,27 @@ def train_alpha_zero(
start_iter, num_iterations + 1, desc="AZ Training", dynamic_ncols=True
)
# simple schedule for random moves in early training
warmup_iters = 50 # how many iterations to use randomness
max_random_prob = 0.5 # random move probability at iteration 1
for it in pbar:
replay_buffer = []
# linearly decay random_move_prob from max_random_prob -> 0 over warmup_iters
if it <= warmup_iters:
random_move_prob = max_random_prob * (1.0 - (it - 1) / warmup_iters)
else:
random_move_prob = 0.0
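# Illustrative schedule values (computed from the defaults above, warmup_iters=50, max_random_prob=0.5):
#   it=1 -> 0.50, it=11 -> 0.40, it=26 -> 0.25, it=50 -> 0.01, it>50 -> 0.0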
# 1. Self-play games to generate fresh data
for g in range(games_per_iter):
examples = self_play_game(
net,
num_simulations=num_simulations,
temperature=1.0, # can anneal later
temperature=1.0, # can also anneal later if you like
watch=watch_selfplay,
random_move_prob=random_move_prob,
)
replay_buffer.extend(examples)
@@ -695,13 +718,115 @@ def train_alpha_zero(
plot_losses(loss_history, out_path="loss_curve.png")
def load_net_from_checkpoint(path: str) -> BlokuZeroNet:
"""
Load a BlokuZeroNet from either:
- a plain state_dict file (saved with torch.save(net.state_dict(), ...)), or
- a training checkpoint (dict with "model_state" key).
"""
net = BlokuZeroNet()
obj = torch.load(path, map_location="cpu")
# If it's a training checkpoint dict, extract model_state
if isinstance(obj, dict) and "model_state" in obj:
state_dict = obj["model_state"]
else:
# Assume it's already a state_dict
state_dict = obj
net.load_state_dict(state_dict)
net.eval()
return net
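# Rough usage sketch (the checkpoint path here is hypothetical):
#   net = load_net_from_checkpoint("checkpoints/iter_050.pt")
#   move = az_choose_move(net, initial_game_state(), player=1)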
def battle(
checkpoint_a: str,
checkpoint_b: str,
num_games: int = 20,
watch: bool = False,
num_simulations: int = 50,
):
"""
Load two BlokuZeroNet checkpoints and have them battle for num_games games.
Alternates which net is player 1 for fairness.
Prints a small win-loss matrix at the end.
"""
print(f"Loading net A from: {checkpoint_a}")
print(f"Loading net B from: {checkpoint_b}")
net_a = load_net_from_checkpoint(checkpoint_a)
net_b = load_net_from_checkpoint(checkpoint_b)
# Matrix counters
# Rows = starting player (A as P1, B as P1)
# Cols = winner (A, B, Draw)
stats = {
"A_P1": {"A": 0, "B": 0, "D": 0},
"B_P1": {"A": 0, "B": 0, "D": 0},
}
for g in range(num_games):
if g % 2 == 0:
# Even games: A as P1, B as P2
start_label = "A_P1"
winner = play_game_between_nets(
net_a,
net_b,
watch=watch,
num_simulations=num_simulations,
)
if winner == 1:
stats[start_label]["A"] += 1
elif winner == 2:
stats[start_label]["B"] += 1
else:
stats[start_label]["D"] += 1
else:
# Odd games: B as P1, A as P2
start_label = "B_P1"
winner = play_game_between_nets(
net_b,
net_a,
watch=watch,
num_simulations=num_simulations,
)
if winner == 1:
stats[start_label]["B"] += 1 # player 1 is B
elif winner == 2:
stats[start_label]["A"] += 1 # player 2 is A
else:
stats[start_label]["D"] += 1
print(f"Game {g + 1}/{num_games} finished: winner = {winner}")
# Aggregate totals
total_a_wins = stats["A_P1"]["A"] + stats["B_P1"]["A"]
total_b_wins = stats["A_P1"]["B"] + stats["B_P1"]["B"]
total_draws = stats["A_P1"]["D"] + stats["B_P1"]["D"]
print("\n=== Battle results ===")
print(f"Total games: {num_games}")
print(f"Model A wins: {total_a_wins}")
print(f"Model B wins: {total_b_wins}")
print(f"Draws: {total_draws}")
print("\nWin-loss matrix (rows = starting player, cols = winner):")
print(" A_win B_win Draw")
print(
f"Start A (P1): {stats['A_P1']['A']:5d} {stats['A_P1']['B']:5d} {stats['A_P1']['D']:5d}"
)
print(
f"Start B (P1): {stats['B_P1']['A']:5d} {stats['B_P1']['B']:5d} {stats['B_P1']['D']:5d}"
)
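# Rough usage sketch (checkpoint filenames are hypothetical):
#   battle("ckpt_iter_010.pt", "ckpt_iter_050.pt", num_games=10, watch=False)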
###################
# Play vs the AI #
###################
def az_choose_move(
net: AlphaZeroNet, game_state, player: int, num_simulations: int = 100
net: BlokuZeroNet, game_state, player: int, num_simulations: int = 100
):
"""
Use MCTS with the trained net to choose a move for actual play.
@@ -724,7 +849,61 @@ def az_choose_move(
return chosen_move
def play_vs_ai(net: AlphaZeroNet, human_is: int = 1, num_simulations: int = 100):
def play_game_between_nets(
net_p1: BlokuZeroNet,
net_p2: BlokuZeroNet,
watch: bool = False,
max_turns: int = 500,
num_simulations: int = 50,
) -> int:
"""
Play one game between two AlphaZero nets using MCTS for both.
Returns:
1 if player 1 (net_p1) wins
2 if player 2 (net_p2) wins
0 if draw (max_turns reached)
"""
game_state = initial_game_state()
board, p1tiles, p2tiles = game_state
player = 1
turns = 0
while True:
turns += 1
if turns > max_turns:
# treat as draw
return 0
# Choose which net is playing this turn
net = net_p1 if player == 1 else net_p2
move = az_choose_move(net, game_state, player, num_simulations=num_simulations)
if move is None:
# current player cannot move -> they lose
if player == 1:
return 2
else:
return 1
tidx, tile, placement = move
gp = game.Player.P1 if player == 1 else game.Player.P2
board.place(tile, placement, gp)
if player == 1:
p1tiles.remove(tidx)
else:
p2tiles.remove(tidx)
game_state = (board, p1tiles, p2tiles)
if watch:
print_game_state(game_state)
player = 2 if player == 1 else 1
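# Rough usage sketch, assuming both nets were loaded via load_net_from_checkpoint:
#   winner = play_game_between_nets(net_a, net_b, num_simulations=100)  # 1, 2, or 0 for a draw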
def play_vs_ai(net: BlokuZeroNet, human_is: int = 1, num_simulations: int = 100):
"""
Let a human play against the AlphaZero-style agent.
human_is: 1 or 2
@@ -791,7 +970,30 @@ def play_vs_ai(net: AlphaZeroNet, human_is: int = 1, num_simulations: int = 100)
def main():
net = AlphaZeroNet()
# Battle mode: --battle ckptA ckptB [--games N] [--watch]
if "--battle" in sys.argv:
idx = sys.argv.index("--battle")
try:
ckpt_a = sys.argv[idx + 1]
ckpt_b = sys.argv[idx + 2]
except IndexError:
print("Usage: blokus.py --battle ckptA ckptB [--games N] [--watch]")
return
num_games = 20
if "--games" in sys.argv:
gidx = sys.argv.index("--games")
try:
num_games = int(sys.argv[gidx + 1])
except (IndexError, ValueError):
print("Invalid or missing value for --games, using default 20.")
watch = "--watch" in sys.argv
battle(ckpt_a, ckpt_b, num_games=num_games, watch=watch)
return
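# Example invocation (checkpoint names are hypothetical):
#   python blokus.py --battle ckpt_old.pt ckpt_new.pt --games 10 --watch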
net = BlokuZeroNet()
if "--play" in sys.argv:
model_path = "az_trained_agent.pt"