add battles, random move injection and rename to BlokuZero
commit e405a8749e
parent a7f9e68939
1 changed file with 225 additions and 23 deletions

blokus.py (248 lines changed)
@@ -73,7 +73,7 @@ def plot_losses(loss_history, out_path="loss_curve.png"):
     plt.plot(range(1, len(loss_history) + 1), loss_history)
     plt.xlabel("Training iteration")
     plt.ylabel("Loss")
-    plt.title("AlphaZero training loss")
+    plt.title("BlokuZero training loss")
     plt.tight_layout()
     plt.savefig(out_path)
     plt.close()
@@ -175,7 +175,7 @@ MOVE_DIM = 23
 ##############################
 
 
-class AlphaZeroNet(nn.Module):
+class BlokuZeroNet(nn.Module):
     """
     AlphaZero-style network for this Blokus-like game.
 
@@ -245,7 +245,7 @@ class AlphaZeroNet(nn.Module):
         return probs, logits
 
 
-def net_predict(net: AlphaZeroNet, game_state, player: int, moves):
+def net_predict(net: BlokuZeroNet, game_state, player: int, moves):
     """
     Evaluate net on a position from the perspective of 'player'.
 
@@ -322,7 +322,7 @@ class MCTSNode:
         self.is_terminal = False
         self.value = 0.0
 
-    def expand(self, net: AlphaZeroNet):
+    def expand(self, net: BlokuZeroNet):
         """
         Expand this node: generate legal moves, get priors & value from net.
         """
@@ -399,7 +399,7 @@ class MCTSNode:
 def mcts_search(
     root_state,
     root_player: int,
-    net: AlphaZeroNet,
+    net: BlokuZeroNet,
     num_simulations: int = 50,
     c_puct: float = 1.5,
     temperature: float = 1.0,
@@ -487,10 +487,11 @@ def mcts_search(
 
 
 def self_play_game(
-    net: AlphaZeroNet,
+    net: BlokuZeroNet,
     num_simulations: int = 50,
     temperature: float = 1.0,
     watch: bool = False,
+    random_move_prob: float = 0.0,
 ):
     """
     Play one self-play game using MCTS + net.
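Because `random_move_prob` defaults to 0.0, existing callers keep pure MCTS self-play unless they opt in. A minimal sketch of an opted-in call (the 0.25 value is hypothetical, not from this commit):

examples = self_play_game(
    net,
    num_simulations=50,
    temperature=1.0,
    random_move_prob=0.25,  # hypothetical exploration rate; 0.0 reproduces the old behaviour
)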
@@ -516,18 +517,29 @@ def self_play_game(
             break
 
         state_vec = encode_state(game_state, player)
-        pi, chosen_move, move_list = mcts_search(
-            game_state,
-            player,
-            net,
-            num_simulations=num_simulations,
-            c_puct=1.5,
-            temperature=temperature,
-        )
-
-        if pi is None or chosen_move is None:
-            winner = 2 if player == 1 else 1
-            break
+        if random.random() < random_move_prob:
+            print("PERFORMING RANDOM MOVE!!")
+            move_list = moves
+            num_m = len(move_list)
+            # uniform policy over legal moves
+            pi = torch.full((num_m,), 1.0 / num_m, dtype=torch.float32)
+            idx = random.randrange(num_m)
+            chosen_move = move_list[idx]
+        else:
+            # usual AlphaZero MCTS move
+            pi, chosen_move, move_list = mcts_search(
+                game_state,
+                player,
+                net,
+                num_simulations=num_simulations,
+                c_puct=1.5,
+                temperature=temperature,
+            )
+
+            if pi is None or chosen_move is None:
+                winner = 2 if player == 1 else 1
+                break
 
         trajectory.append(
             {
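Worth noting: on turns where the random branch fires, the uniform `pi` (not an MCTS visit distribution) is presumably what lands in the trajectory via `trajectory.append`, so heavy random injection also softens the policy targets. A self-contained sketch of the uniform fallback, with stand-in moves:

import random
import torch

moves = ["m0", "m1", "m2", "m3"]        # stand-ins for real (tidx, tile, placement) tuples
num_m = len(moves)
pi = torch.full((num_m,), 1.0 / num_m)  # uniform target: tensor([0.25, 0.25, 0.25, 0.25])
chosen_move = moves[random.randrange(num_m)]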
@@ -567,7 +579,7 @@ def self_play_game(
     return examples
 
 
-def alpha_zero_train_step(net: AlphaZeroNet, optimizer, batch):
+def alpha_zero_train_step(net: BlokuZeroNet, optimizer, batch):
     """
     One gradient update on a batch of (state_vec, moves, pi_target, z_target).
     """
@@ -610,7 +622,7 @@ def alpha_zero_train_step(net: AlphaZeroNet, optimizer, batch):
 
 
 def train_alpha_zero(
-    net: AlphaZeroNet,
+    net: BlokuZeroNet,
     num_iterations: int = 50,
     games_per_iter: int = 5,
     num_simulations: int = 50,
@@ -647,16 +659,27 @@ def train_alpha_zero(
         start_iter, num_iterations + 1, desc="AZ Training", dynamic_ncols=True
     )
 
+    # simple schedule for random moves in early training
+    warmup_iters = 50  # how many iterations to use randomness
+    max_random_prob = 0.5  # random move probability at iteration 1
+
     for it in pbar:
         replay_buffer = []
 
+        # linearly decay random_move_prob from max_random_prob -> 0 over warmup_iters
+        if it <= warmup_iters:
+            random_move_prob = max_random_prob * (1.0 - (it - 1) / warmup_iters)
+        else:
+            random_move_prob = 0.0
+
         # 1. Self-play games to generate fresh data
         for g in range(games_per_iter):
             examples = self_play_game(
                 net,
                 num_simulations=num_simulations,
-                temperature=1.0,  # can anneal later
+                temperature=1.0,  # can also anneal later if you like
                 watch=watch_selfplay,
+                random_move_prob=random_move_prob,
             )
             replay_buffer.extend(examples)
 
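Sanity-checking the schedule with the committed constants: the probability falls linearly from 0.5 at iteration 1 to 0.01 at iteration 50, then drops to 0.0 from iteration 51 on (it never actually reaches zero inside the warmup). A few spot values:

warmup_iters = 50
max_random_prob = 0.5
for it in (1, 26, 50, 51):
    p = max_random_prob * (1.0 - (it - 1) / warmup_iters) if it <= warmup_iters else 0.0
    print(it, p)  # 1 -> 0.5, 26 -> 0.25, 50 -> 0.01, 51 -> 0.0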
@@ -695,13 +718,115 @@ def train_alpha_zero(
     plot_losses(loss_history, out_path="loss_curve.png")
 
 
+def load_net_from_checkpoint(path: str) -> BlokuZeroNet:
+    """
+    Load a BlokuZeroNet from either:
+    - a plain state_dict file (saved with torch.save(net.state_dict(), ...)), or
+    - a training checkpoint (dict with "model_state" key).
+    """
+    net = BlokuZeroNet()
+    obj = torch.load(path, map_location="cpu")
+
+    # If it's a training checkpoint dict, extract model_state
+    if isinstance(obj, dict) and "model_state" in obj:
+        state_dict = obj["model_state"]
+    else:
+        # Assume it's already a state_dict
+        state_dict = obj
+
+    net.load_state_dict(state_dict)
+    net.eval()
+    return net
+
+
+def battle(
+    checkpoint_a: str,
+    checkpoint_b: str,
+    num_games: int = 20,
+    watch: bool = False,
+    num_simulations: int = 50,
+):
+    """
+    Load two AlphaZero checkpoints and have them battle for num_games.
+    Alternates which net is player 1 for fairness.
+    Prints a small win-loss matrix at the end.
+    """
+    print(f"Loading net A from: {checkpoint_a}")
+    print(f"Loading net B from: {checkpoint_b}")
+
+    net_a = load_net_from_checkpoint(checkpoint_a)
+    net_b = load_net_from_checkpoint(checkpoint_b)
+
+    # Matrix counters
+    # Rows = starting player (A as P1, B as P1)
+    # Cols = winner (A, B, Draw)
+    stats = {
+        "A_P1": {"A": 0, "B": 0, "D": 0},
+        "B_P1": {"A": 0, "B": 0, "D": 0},
+    }
+
+    for g in range(num_games):
+        if g % 2 == 0:
+            # Even games: A as P1, B as P2
+            start_label = "A_P1"
+            winner = play_game_between_nets(
+                net_a,
+                net_b,
+                watch=watch,
+                num_simulations=num_simulations,
+            )
+            if winner == 1:
+                stats[start_label]["A"] += 1
+            elif winner == 2:
+                stats[start_label]["B"] += 1
+            else:
+                stats[start_label]["D"] += 1
+        else:
+            # Odd games: B as P1, A as P2
+            start_label = "B_P1"
+            winner = play_game_between_nets(
+                net_b,
+                net_a,
+                watch=watch,
+                num_simulations=num_simulations,
+            )
+            if winner == 1:
+                stats[start_label]["B"] += 1  # player 1 is B
+            elif winner == 2:
+                stats[start_label]["A"] += 1  # player 2 is A
+            else:
+                stats[start_label]["D"] += 1
+
+        print(f"Game {g + 1}/{num_games} finished: winner = {winner}")
+
+    # Aggregate totals
+    total_a_wins = stats["A_P1"]["A"] + stats["B_P1"]["A"]
+    total_b_wins = stats["A_P1"]["B"] + stats["B_P1"]["B"]
+    total_draws = stats["A_P1"]["D"] + stats["B_P1"]["D"]
+
+    print("\n=== Battle results ===")
+    print(f"Total games: {num_games}")
+    print(f"Model A wins: {total_a_wins}")
+    print(f"Model B wins: {total_b_wins}")
+    print(f"Draws: {total_draws}")
+
+    print("\nWin-loss matrix (rows = starting player, cols = winner):")
+    print("              A_win B_win  Draw")
+    print(
+        f"Start A (P1): {stats['A_P1']['A']:5d} {stats['A_P1']['B']:5d} {stats['A_P1']['D']:5d}"
+    )
+    print(
+        f"Start B (P1): {stats['B_P1']['A']:5d} {stats['B_P1']['B']:5d} {stats['B_P1']['D']:5d}"
+    )
+
+
 ###################
 # Play vs the AI #
 ###################
 
 
 def az_choose_move(
-    net: AlphaZeroNet, game_state, player: int, num_simulations: int = 100
+    net: BlokuZeroNet, game_state, player: int, num_simulations: int = 100
 ):
     """
     Use MCTS with the trained net to choose a move for actual play.
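For reference, the two on-disk formats the loader's docstring names can be produced like this (a sketch; any checkpoint keys beyond "model_state" are an assumption):

import torch

net = BlokuZeroNet()
torch.save(net.state_dict(), "plain_weights.pt")           # plain state_dict file
torch.save({"model_state": net.state_dict()}, "ckpt.pt")   # training checkpoint dict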
@@ -724,7 +849,61 @@ def az_choose_move(
     return chosen_move
 
 
-def play_vs_ai(net: AlphaZeroNet, human_is: int = 1, num_simulations: int = 100):
+def play_game_between_nets(
+    net_p1: BlokuZeroNet,
+    net_p2: BlokuZeroNet,
+    watch: bool = False,
+    max_turns: int = 500,
+    num_simulations: int = 50,
+) -> int:
+    """
+    Play one game between two AlphaZero nets using MCTS for both.
+    Returns:
+        1 if player 1 (net_p1) wins
+        2 if player 2 (net_p2) wins
+        0 if draw (max_turns reached)
+    """
+    game_state = initial_game_state()
+    board, p1tiles, p2tiles = game_state
+
+    player = 1
+    turns = 0
+
+    while True:
+        turns += 1
+        if turns > max_turns:
+            # treat as draw
+            return 0
+
+        # Choose which net is playing this turn
+        net = net_p1 if player == 1 else net_p2
+        move = az_choose_move(net, game_state, player, num_simulations=num_simulations)
+
+        if move is None:
+            # current player cannot move -> they lose
+            if player == 1:
+                return 2
+            else:
+                return 1
+
+        tidx, tile, placement = move
+        gp = game.Player.P1 if player == 1 else game.Player.P2
+
+        board.place(tile, placement, gp)
+        if player == 1:
+            p1tiles.remove(tidx)
+        else:
+            p2tiles.remove(tidx)
+
+        game_state = (board, p1tiles, p2tiles)
+
+        if watch:
+            print_game_state(game_state)
+
+        player = 2 if player == 1 else 1
+
+
+def play_vs_ai(net: BlokuZeroNet, human_is: int = 1, num_simulations: int = 100):
     """
     Let a human play against the AlphaZero-style agent.
     human_is: 1 or 2
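The runner is also usable outside `battle` (a sketch; the checkpoint paths are placeholders):

net_a = load_net_from_checkpoint("ckpt_a.pt")  # placeholder path
net_b = load_net_from_checkpoint("ckpt_b.pt")  # placeholder path
winner = play_game_between_nets(net_a, net_b, watch=True, num_simulations=50)
print({1: "net_a wins", 2: "net_b wins", 0: "draw"}[winner])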
@@ -791,7 +970,30 @@ def play_vs_ai(net: AlphaZeroNet, human_is: int = 1, num_simulations: int = 100)
 
 
 def main():
-    net = AlphaZeroNet()
+    # Battle mode: --battle ckptA ckptB [--games N] [--watch]
+    if "--battle" in sys.argv:
+        idx = sys.argv.index("--battle")
+        try:
+            ckpt_a = sys.argv[idx + 1]
+            ckpt_b = sys.argv[idx + 2]
+        except IndexError:
+            print("Usage: blokus.py --battle ckptA ckptB [--games N] [--watch]")
+            return
+
+        num_games = 20
+        if "--games" in sys.argv:
+            gidx = sys.argv.index("--games")
+            try:
+                num_games = int(sys.argv[gidx + 1])
+            except (IndexError, ValueError):
+                print("Invalid or missing value for --games, using default 20.")
+
+        watch = "--watch" in sys.argv
+
+        battle(ckpt_a, ckpt_b, num_games=num_games, watch=watch)
+        return
+
+    net = BlokuZeroNet()
 
     if "--play" in sys.argv:
         model_path = "az_trained_agent.pt"
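Matching the usage string above, a battle run is invoked as (checkpoint paths are placeholders):

python blokus.py --battle ckptA.pt ckptB.pt --games 40 --watch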