add battles, random move injection and rename to BlokuZero
This commit is contained in:
parent
a7f9e68939
commit
e405a8749e
1 changed files with 225 additions and 23 deletions
248
blokus.py
248
blokus.py
|
|
@ -73,7 +73,7 @@ def plot_losses(loss_history, out_path="loss_curve.png"):
|
|||
plt.plot(range(1, len(loss_history) + 1), loss_history)
|
||||
plt.xlabel("Training iteration")
|
||||
plt.ylabel("Loss")
|
||||
plt.title("AlphaZero training loss")
|
||||
plt.title("BlokuZero training loss")
|
||||
plt.tight_layout()
|
||||
plt.savefig(out_path)
|
||||
plt.close()
|
||||
|
|
@ -175,7 +175,7 @@ MOVE_DIM = 23
|
|||
##############################
|
||||
|
||||
|
||||
class AlphaZeroNet(nn.Module):
|
||||
class BlokuZeroNet(nn.Module):
|
||||
"""
|
||||
AlphaZero-style network for this Blokus-like game.
|
||||
|
||||
|
|
@ -245,7 +245,7 @@ class AlphaZeroNet(nn.Module):
|
|||
return probs, logits
|
||||
|
||||
|
||||
def net_predict(net: AlphaZeroNet, game_state, player: int, moves):
|
||||
def net_predict(net: BlokuZeroNet, game_state, player: int, moves):
|
||||
"""
|
||||
Evaluate net on a position from the perspective of 'player'.
|
||||
|
||||
|
|
@ -322,7 +322,7 @@ class MCTSNode:
|
|||
self.is_terminal = False
|
||||
self.value = 0.0
|
||||
|
||||
def expand(self, net: AlphaZeroNet):
|
||||
def expand(self, net: BlokuZeroNet):
|
||||
"""
|
||||
Expand this node: generate legal moves, get priors & value from net.
|
||||
"""
|
||||
|
|
@ -399,7 +399,7 @@ class MCTSNode:
|
|||
def mcts_search(
|
||||
root_state,
|
||||
root_player: int,
|
||||
net: AlphaZeroNet,
|
||||
net: BlokuZeroNet,
|
||||
num_simulations: int = 50,
|
||||
c_puct: float = 1.5,
|
||||
temperature: float = 1.0,
|
||||
|
|
@ -487,10 +487,11 @@ def mcts_search(
|
|||
|
||||
|
||||
def self_play_game(
|
||||
net: AlphaZeroNet,
|
||||
net: BlokuZeroNet,
|
||||
num_simulations: int = 50,
|
||||
temperature: float = 1.0,
|
||||
watch: bool = False,
|
||||
random_move_prob: float = 0.0,
|
||||
):
|
||||
"""
|
||||
Play one self-play game using MCTS + net.
|
||||
|
|
@ -516,18 +517,29 @@ def self_play_game(
|
|||
break
|
||||
|
||||
state_vec = encode_state(game_state, player)
|
||||
pi, chosen_move, move_list = mcts_search(
|
||||
game_state,
|
||||
player,
|
||||
net,
|
||||
num_simulations=num_simulations,
|
||||
c_puct=1.5,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
if pi is None or chosen_move is None:
|
||||
winner = 2 if player == 1 else 1
|
||||
break
|
||||
if random.random() < random_move_prob:
|
||||
print("PERFORMING RANDOM MOVE!!")
|
||||
move_list = moves
|
||||
num_m = len(move_list)
|
||||
# uniform policy over legal moves
|
||||
pi = torch.full((num_m,), 1.0 / num_m, dtype=torch.float32)
|
||||
idx = random.randrange(num_m)
|
||||
chosen_move = move_list[idx]
|
||||
else:
|
||||
# usual AlphaZero MCTS move
|
||||
pi, chosen_move, move_list = mcts_search(
|
||||
game_state,
|
||||
player,
|
||||
net,
|
||||
num_simulations=num_simulations,
|
||||
c_puct=1.5,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
if pi is None or chosen_move is None:
|
||||
winner = 2 if player == 1 else 1
|
||||
break
|
||||
|
||||
trajectory.append(
|
||||
{
|
||||
|
|
@ -567,7 +579,7 @@ def self_play_game(
|
|||
return examples
|
||||
|
||||
|
||||
def alpha_zero_train_step(net: AlphaZeroNet, optimizer, batch):
|
||||
def alpha_zero_train_step(net: BlokuZeroNet, optimizer, batch):
|
||||
"""
|
||||
One gradient update on a batch of (state_vec, moves, pi_target, z_target).
|
||||
"""
|
||||
|
|
@ -610,7 +622,7 @@ def alpha_zero_train_step(net: AlphaZeroNet, optimizer, batch):
|
|||
|
||||
|
||||
def train_alpha_zero(
|
||||
net: AlphaZeroNet,
|
||||
net: BlokuZeroNet,
|
||||
num_iterations: int = 50,
|
||||
games_per_iter: int = 5,
|
||||
num_simulations: int = 50,
|
||||
|
|
@ -647,16 +659,27 @@ def train_alpha_zero(
|
|||
start_iter, num_iterations + 1, desc="AZ Training", dynamic_ncols=True
|
||||
)
|
||||
|
||||
# simple schedule for random moves in early training
|
||||
warmup_iters = 50 # how many iterations to use randomness
|
||||
max_random_prob = 0.5 # random move probability at iteration 1
|
||||
|
||||
for it in pbar:
|
||||
replay_buffer = []
|
||||
|
||||
# linearly decay random_move_prob from max_random_prob -> 0 over warmup_iters
|
||||
if it <= warmup_iters:
|
||||
random_move_prob = max_random_prob * (1.0 - (it - 1) / warmup_iters)
|
||||
else:
|
||||
random_move_prob = 0.0
|
||||
|
||||
# 1. Self-play games to generate fresh data
|
||||
for g in range(games_per_iter):
|
||||
examples = self_play_game(
|
||||
net,
|
||||
num_simulations=num_simulations,
|
||||
temperature=1.0, # can anneal later
|
||||
temperature=1.0, # can also anneal later if you like
|
||||
watch=watch_selfplay,
|
||||
random_move_prob=random_move_prob,
|
||||
)
|
||||
replay_buffer.extend(examples)
|
||||
|
||||
|
|
@ -695,13 +718,115 @@ def train_alpha_zero(
|
|||
plot_losses(loss_history, out_path="loss_curve.png")
|
||||
|
||||
|
||||
def load_net_from_checkpoint(path: str) -> BlokuZeroNet:
|
||||
"""
|
||||
Load an BlokuZeroNet from either:
|
||||
- a plain state_dict file (saved with torch.save(net.state_dict(), ...)), or
|
||||
- a training checkpoint (dict with "model_state" key).
|
||||
"""
|
||||
net = BlokuZeroNet()
|
||||
obj = torch.load(path, map_location="cpu")
|
||||
|
||||
# If it's a training checkpoint dict, extract model_state
|
||||
if isinstance(obj, dict) and "model_state" in obj:
|
||||
state_dict = obj["model_state"]
|
||||
else:
|
||||
# Assume it's already a state_dict
|
||||
state_dict = obj
|
||||
|
||||
net.load_state_dict(state_dict)
|
||||
net.eval()
|
||||
return net
|
||||
|
||||
|
||||
def battle(
|
||||
checkpoint_a: str,
|
||||
checkpoint_b: str,
|
||||
num_games: int = 20,
|
||||
watch: bool = False,
|
||||
num_simulations: int = 50,
|
||||
):
|
||||
"""
|
||||
Load two AlphaZero checkpoints and have them battle for num_games.
|
||||
Alternates which net is player 1 for fairness.
|
||||
Prints a small win-loss matrix at the end.
|
||||
"""
|
||||
print(f"Loading net A from: {checkpoint_a}")
|
||||
print(f"Loading net B from: {checkpoint_b}")
|
||||
|
||||
net_a = load_net_from_checkpoint(checkpoint_a)
|
||||
net_b = load_net_from_checkpoint(checkpoint_b)
|
||||
|
||||
# Matrix counters
|
||||
# Rows = starting player (A as P1, B as P1)
|
||||
# Cols = winner (A, B, Draw)
|
||||
stats = {
|
||||
"A_P1": {"A": 0, "B": 0, "D": 0},
|
||||
"B_P1": {"A": 0, "B": 0, "D": 0},
|
||||
}
|
||||
|
||||
for g in range(num_games):
|
||||
if g % 2 == 0:
|
||||
# Even games: A as P1, B as P2
|
||||
start_label = "A_P1"
|
||||
winner = play_game_between_nets(
|
||||
net_a,
|
||||
net_b,
|
||||
watch=watch,
|
||||
num_simulations=num_simulations,
|
||||
)
|
||||
if winner == 1:
|
||||
stats[start_label]["A"] += 1
|
||||
elif winner == 2:
|
||||
stats[start_label]["B"] += 1
|
||||
else:
|
||||
stats[start_label]["D"] += 1
|
||||
else:
|
||||
# Odd games: B as P1, A as P2
|
||||
start_label = "B_P1"
|
||||
winner = play_game_between_nets(
|
||||
net_b,
|
||||
net_a,
|
||||
watch=watch,
|
||||
num_simulations=num_simulations,
|
||||
)
|
||||
if winner == 1:
|
||||
stats[start_label]["B"] += 1 # player 1 is B
|
||||
elif winner == 2:
|
||||
stats[start_label]["A"] += 1 # player 2 is A
|
||||
else:
|
||||
stats[start_label]["D"] += 1
|
||||
|
||||
print(f"Game {g + 1}/{num_games} finished: winner = {winner}")
|
||||
|
||||
# Aggregate totals
|
||||
total_a_wins = stats["A_P1"]["A"] + stats["B_P1"]["A"]
|
||||
total_b_wins = stats["A_P1"]["B"] + stats["B_P1"]["B"]
|
||||
total_draws = stats["A_P1"]["D"] + stats["B_P1"]["D"]
|
||||
|
||||
print("\n=== Battle results ===")
|
||||
print(f"Total games: {num_games}")
|
||||
print(f"Model A wins: {total_a_wins}")
|
||||
print(f"Model B wins: {total_b_wins}")
|
||||
print(f"Draws: {total_draws}")
|
||||
|
||||
print("\nWin-loss matrix (rows = starting player, cols = winner):")
|
||||
print(" A_win B_win Draw")
|
||||
print(
|
||||
f"Start A (P1): {stats['A_P1']['A']:5d} {stats['A_P1']['B']:5d} {stats['A_P1']['D']:5d}"
|
||||
)
|
||||
print(
|
||||
f"Start B (P1): {stats['B_P1']['A']:5d} {stats['B_P1']['B']:5d} {stats['B_P1']['D']:5d}"
|
||||
)
|
||||
|
||||
|
||||
###################
|
||||
# Play vs the AI #
|
||||
###################
|
||||
|
||||
|
||||
def az_choose_move(
|
||||
net: AlphaZeroNet, game_state, player: int, num_simulations: int = 100
|
||||
net: BlokuZeroNet, game_state, player: int, num_simulations: int = 100
|
||||
):
|
||||
"""
|
||||
Use MCTS with the trained net to choose a move for actual play.
|
||||
|
|
@ -724,7 +849,61 @@ def az_choose_move(
|
|||
return chosen_move
|
||||
|
||||
|
||||
def play_vs_ai(net: AlphaZeroNet, human_is: int = 1, num_simulations: int = 100):
|
||||
def play_game_between_nets(
|
||||
net_p1: BlokuZeroNet,
|
||||
net_p2: BlokuZeroNet,
|
||||
watch: bool = False,
|
||||
max_turns: int = 500,
|
||||
num_simulations: int = 50,
|
||||
) -> int:
|
||||
"""
|
||||
Play one game between two AlphaZero nets using MCTS for both.
|
||||
Returns:
|
||||
1 if player 1 (net_p1) wins
|
||||
2 if player 2 (net_p2) wins
|
||||
0 if draw (max_turns reached)
|
||||
"""
|
||||
game_state = initial_game_state()
|
||||
board, p1tiles, p2tiles = game_state
|
||||
|
||||
player = 1
|
||||
turns = 0
|
||||
|
||||
while True:
|
||||
turns += 1
|
||||
if turns > max_turns:
|
||||
# treat as draw
|
||||
return 0
|
||||
|
||||
# Choose which net is playing this turn
|
||||
net = net_p1 if player == 1 else net_p2
|
||||
move = az_choose_move(net, game_state, player, num_simulations=num_simulations)
|
||||
|
||||
if move is None:
|
||||
# current player cannot move -> they lose
|
||||
if player == 1:
|
||||
return 2
|
||||
else:
|
||||
return 1
|
||||
|
||||
tidx, tile, placement = move
|
||||
gp = game.Player.P1 if player == 1 else game.Player.P2
|
||||
|
||||
board.place(tile, placement, gp)
|
||||
if player == 1:
|
||||
p1tiles.remove(tidx)
|
||||
else:
|
||||
p2tiles.remove(tidx)
|
||||
|
||||
game_state = (board, p1tiles, p2tiles)
|
||||
|
||||
if watch:
|
||||
print_game_state(game_state)
|
||||
|
||||
player = 2 if player == 1 else 1
|
||||
|
||||
|
||||
def play_vs_ai(net: BlokuZeroNet, human_is: int = 1, num_simulations: int = 100):
|
||||
"""
|
||||
Let a human play against the AlphaZero-style agent.
|
||||
human_is: 1 or 2
|
||||
|
|
@ -791,7 +970,30 @@ def play_vs_ai(net: AlphaZeroNet, human_is: int = 1, num_simulations: int = 100)
|
|||
|
||||
|
||||
def main():
|
||||
net = AlphaZeroNet()
|
||||
# Battle mode: --battle ckptA ckptB [--games N] [--watch]
|
||||
if "--battle" in sys.argv:
|
||||
idx = sys.argv.index("--battle")
|
||||
try:
|
||||
ckpt_a = sys.argv[idx + 1]
|
||||
ckpt_b = sys.argv[idx + 2]
|
||||
except IndexError:
|
||||
print("Usage: blokus.py --battle ckptA ckptB [--games N] [--watch]")
|
||||
return
|
||||
|
||||
num_games = 20
|
||||
if "--games" in sys.argv:
|
||||
gidx = sys.argv.index("--games")
|
||||
try:
|
||||
num_games = int(sys.argv[gidx + 1])
|
||||
except (IndexError, ValueError):
|
||||
print("Invalid or missing value for --games, using default 20.")
|
||||
|
||||
watch = "--watch" in sys.argv
|
||||
|
||||
battle(ckpt_a, ckpt_b, num_games=num_games, watch=watch)
|
||||
return
|
||||
|
||||
net = BlokuZeroNet()
|
||||
|
||||
if "--play" in sys.argv:
|
||||
model_path = "az_trained_agent.pt"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue