From e405a8749e65024e48dd78c0d1c43f3bd9220b32 Mon Sep 17 00:00:00 2001
From: Noa Aarts
Date: Mon, 8 Dec 2025 15:38:10 +0100
Subject: [PATCH] add battles, random move injection and rename to BlokuZero

---
 blokus.py | 248 +++++++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 225 insertions(+), 23 deletions(-)

diff --git a/blokus.py b/blokus.py
index 975c252..183bfe1 100755
--- a/blokus.py
+++ b/blokus.py
@@ -73,7 +73,7 @@ def plot_losses(loss_history, out_path="loss_curve.png"):
     plt.plot(range(1, len(loss_history) + 1), loss_history)
     plt.xlabel("Training iteration")
     plt.ylabel("Loss")
-    plt.title("AlphaZero training loss")
+    plt.title("BlokuZero training loss")
     plt.tight_layout()
     plt.savefig(out_path)
     plt.close()
@@ -175,7 +175,7 @@ MOVE_DIM = 23
 ##############################


-class AlphaZeroNet(nn.Module):
+class BlokuZeroNet(nn.Module):
     """
     AlphaZero-style network for this Blokus-like game.

@@ -245,7 +245,7 @@ class AlphaZeroNet(nn.Module):
         return probs, logits


-def net_predict(net: AlphaZeroNet, game_state, player: int, moves):
+def net_predict(net: BlokuZeroNet, game_state, player: int, moves):
     """
     Evaluate net on a position from the perspective of 'player'.

@@ -322,7 +322,7 @@ class MCTSNode:
         self.is_terminal = False
         self.value = 0.0

-    def expand(self, net: AlphaZeroNet):
+    def expand(self, net: BlokuZeroNet):
         """
         Expand this node: generate legal moves, get priors & value from net.
         """
@@ -399,7 +399,7 @@ class MCTSNode:
 def mcts_search(
     root_state,
     root_player: int,
-    net: AlphaZeroNet,
+    net: BlokuZeroNet,
     num_simulations: int = 50,
     c_puct: float = 1.5,
     temperature: float = 1.0,
@@ -487,10 +487,11 @@ def mcts_search(


 def self_play_game(
-    net: AlphaZeroNet,
+    net: BlokuZeroNet,
     num_simulations: int = 50,
     temperature: float = 1.0,
     watch: bool = False,
+    random_move_prob: float = 0.0,
 ):
     """
     Play one self-play game using MCTS + net.
@@ -516,18 +517,29 @@ def self_play_game(
             break

         state_vec = encode_state(game_state, player)

-        pi, chosen_move, move_list = mcts_search(
-            game_state,
-            player,
-            net,
-            num_simulations=num_simulations,
-            c_puct=1.5,
-            temperature=temperature,
-        )
-        if pi is None or chosen_move is None:
-            winner = 2 if player == 1 else 1
-            break
+        if random.random() < random_move_prob:
+            print("PERFORMING RANDOM MOVE!!")
+            move_list = moves
+            num_m = len(move_list)
+            # uniform policy over legal moves
+            pi = torch.full((num_m,), 1.0 / num_m, dtype=torch.float32)
+            idx = random.randrange(num_m)
+            chosen_move = move_list[idx]
+        else:
+            # usual AlphaZero MCTS move
+            pi, chosen_move, move_list = mcts_search(
+                game_state,
+                player,
+                net,
+                num_simulations=num_simulations,
+                c_puct=1.5,
+                temperature=temperature,
+            )
+
+            if pi is None or chosen_move is None:
+                winner = 2 if player == 1 else 1
+                break

         trajectory.append(
             {
@@ -567,7 +579,7 @@ def self_play_game(
     return examples


-def alpha_zero_train_step(net: AlphaZeroNet, optimizer, batch):
+def alpha_zero_train_step(net: BlokuZeroNet, optimizer, batch):
     """
     One gradient update on a batch of (state_vec, moves, pi_target, z_target).
""" @@ -610,7 +622,7 @@ def alpha_zero_train_step(net: AlphaZeroNet, optimizer, batch): def train_alpha_zero( - net: AlphaZeroNet, + net: BlokuZeroNet, num_iterations: int = 50, games_per_iter: int = 5, num_simulations: int = 50, @@ -647,16 +659,27 @@ def train_alpha_zero( start_iter, num_iterations + 1, desc="AZ Training", dynamic_ncols=True ) + # simple schedule for random moves in early training + warmup_iters = 50 # how many iterations to use randomness + max_random_prob = 0.5 # random move probability at iteration 1 + for it in pbar: replay_buffer = [] + # linearly decay random_move_prob from max_random_prob -> 0 over warmup_iters + if it <= warmup_iters: + random_move_prob = max_random_prob * (1.0 - (it - 1) / warmup_iters) + else: + random_move_prob = 0.0 + # 1. Self-play games to generate fresh data for g in range(games_per_iter): examples = self_play_game( net, num_simulations=num_simulations, - temperature=1.0, # can anneal later + temperature=1.0, # can also anneal later if you like watch=watch_selfplay, + random_move_prob=random_move_prob, ) replay_buffer.extend(examples) @@ -695,13 +718,115 @@ def train_alpha_zero( plot_losses(loss_history, out_path="loss_curve.png") +def load_net_from_checkpoint(path: str) -> BlokuZeroNet: + """ + Load an BlokuZeroNet from either: + - a plain state_dict file (saved with torch.save(net.state_dict(), ...)), or + - a training checkpoint (dict with "model_state" key). + """ + net = BlokuZeroNet() + obj = torch.load(path, map_location="cpu") + + # If it's a training checkpoint dict, extract model_state + if isinstance(obj, dict) and "model_state" in obj: + state_dict = obj["model_state"] + else: + # Assume it's already a state_dict + state_dict = obj + + net.load_state_dict(state_dict) + net.eval() + return net + + +def battle( + checkpoint_a: str, + checkpoint_b: str, + num_games: int = 20, + watch: bool = False, + num_simulations: int = 50, +): + """ + Load two AlphaZero checkpoints and have them battle for num_games. + Alternates which net is player 1 for fairness. + Prints a small win-loss matrix at the end. 
+ """ + print(f"Loading net A from: {checkpoint_a}") + print(f"Loading net B from: {checkpoint_b}") + + net_a = load_net_from_checkpoint(checkpoint_a) + net_b = load_net_from_checkpoint(checkpoint_b) + + # Matrix counters + # Rows = starting player (A as P1, B as P1) + # Cols = winner (A, B, Draw) + stats = { + "A_P1": {"A": 0, "B": 0, "D": 0}, + "B_P1": {"A": 0, "B": 0, "D": 0}, + } + + for g in range(num_games): + if g % 2 == 0: + # Even games: A as P1, B as P2 + start_label = "A_P1" + winner = play_game_between_nets( + net_a, + net_b, + watch=watch, + num_simulations=num_simulations, + ) + if winner == 1: + stats[start_label]["A"] += 1 + elif winner == 2: + stats[start_label]["B"] += 1 + else: + stats[start_label]["D"] += 1 + else: + # Odd games: B as P1, A as P2 + start_label = "B_P1" + winner = play_game_between_nets( + net_b, + net_a, + watch=watch, + num_simulations=num_simulations, + ) + if winner == 1: + stats[start_label]["B"] += 1 # player 1 is B + elif winner == 2: + stats[start_label]["A"] += 1 # player 2 is A + else: + stats[start_label]["D"] += 1 + + print(f"Game {g + 1}/{num_games} finished: winner = {winner}") + + # Aggregate totals + total_a_wins = stats["A_P1"]["A"] + stats["B_P1"]["A"] + total_b_wins = stats["A_P1"]["B"] + stats["B_P1"]["B"] + total_draws = stats["A_P1"]["D"] + stats["B_P1"]["D"] + + print("\n=== Battle results ===") + print(f"Total games: {num_games}") + print(f"Model A wins: {total_a_wins}") + print(f"Model B wins: {total_b_wins}") + print(f"Draws: {total_draws}") + + print("\nWin-loss matrix (rows = starting player, cols = winner):") + print(" A_win B_win Draw") + print( + f"Start A (P1): {stats['A_P1']['A']:5d} {stats['A_P1']['B']:5d} {stats['A_P1']['D']:5d}" + ) + print( + f"Start B (P1): {stats['B_P1']['A']:5d} {stats['B_P1']['B']:5d} {stats['B_P1']['D']:5d}" + ) + + ################### # Play vs the AI # ################### def az_choose_move( - net: AlphaZeroNet, game_state, player: int, num_simulations: int = 100 + net: BlokuZeroNet, game_state, player: int, num_simulations: int = 100 ): """ Use MCTS with the trained net to choose a move for actual play. @@ -724,7 +849,61 @@ def az_choose_move( return chosen_move -def play_vs_ai(net: AlphaZeroNet, human_is: int = 1, num_simulations: int = 100): +def play_game_between_nets( + net_p1: BlokuZeroNet, + net_p2: BlokuZeroNet, + watch: bool = False, + max_turns: int = 500, + num_simulations: int = 50, +) -> int: + """ + Play one game between two AlphaZero nets using MCTS for both. + Returns: + 1 if player 1 (net_p1) wins + 2 if player 2 (net_p2) wins + 0 if draw (max_turns reached) + """ + game_state = initial_game_state() + board, p1tiles, p2tiles = game_state + + player = 1 + turns = 0 + + while True: + turns += 1 + if turns > max_turns: + # treat as draw + return 0 + + # Choose which net is playing this turn + net = net_p1 if player == 1 else net_p2 + move = az_choose_move(net, game_state, player, num_simulations=num_simulations) + + if move is None: + # current player cannot move -> they lose + if player == 1: + return 2 + else: + return 1 + + tidx, tile, placement = move + gp = game.Player.P1 if player == 1 else game.Player.P2 + + board.place(tile, placement, gp) + if player == 1: + p1tiles.remove(tidx) + else: + p2tiles.remove(tidx) + + game_state = (board, p1tiles, p2tiles) + + if watch: + print_game_state(game_state) + + player = 2 if player == 1 else 1 + + +def play_vs_ai(net: BlokuZeroNet, human_is: int = 1, num_simulations: int = 100): """ Let a human play against the AlphaZero-style agent. 
     human_is: 1 or 2
@@ -791,7 +970,30 @@ def play_vs_ai(net: AlphaZeroNet, human_is: int = 1, num_simulations: int = 100)


 def main():
-    net = AlphaZeroNet()
+    # Battle mode: --battle ckptA ckptB [--games N] [--watch]
+    if "--battle" in sys.argv:
+        idx = sys.argv.index("--battle")
+        try:
+            ckpt_a = sys.argv[idx + 1]
+            ckpt_b = sys.argv[idx + 2]
+        except IndexError:
+            print("Usage: blokus.py --battle ckptA ckptB [--games N] [--watch]")
+            return
+
+        num_games = 20
+        if "--games" in sys.argv:
+            gidx = sys.argv.index("--games")
+            try:
+                num_games = int(sys.argv[gidx + 1])
+            except (IndexError, ValueError):
+                print("Invalid or missing value for --games, using default 20.")
+
+        watch = "--watch" in sys.argv
+
+        battle(ckpt_a, ckpt_b, num_games=num_games, watch=watch)
+        return
+
+    net = BlokuZeroNet()

     if "--play" in sys.argv:
         model_path = "az_trained_agent.pt"