something like alphazero??

Noa Aarts 2025-11-28 00:51:15 +01:00
parent 44e30869f8
commit f0001cc0df
Signed by: noa
GPG key ID: 1850932741EFF672

blokus.py

@@ -1,4 +1,5 @@
#!/usr/bin/env python
#!/usr/bin/env python3
import random
import numpy as np
import torch
@@ -45,6 +46,11 @@ tiles = [
]
def clone_state(game_state):
board, p1tiles, p2tiles = game_state
return [board.copy(), p1tiles.copy(), p2tiles.copy()]
def get_permutations(which_tiles: list[int]):
"""
For each tile index in which_tiles, generate all unique rotations/flips.
@@ -95,6 +101,7 @@ def can_place(board: np.ndarray, tile: np.ndarray, player: int):
break
else:
placements.append((x, y))
final = []
if has_minus_one:
for x, y in placements:
@@ -144,17 +151,20 @@ def do_placement(
tidx: int, tile: np.ndarray, placement: tuple[int, int], game_state, player: int
):
(x, y) = placement
board, p1tiles, p2tiles = game_state
with np.nditer(tile, flags=["multi_index"]) as it:
for v in it:
(i, j) = it.multi_index
if v == 1:
game_state[0][x + i, y + j] = player
game_state[player].remove(tidx)
board[x + i, y + j] = player
if player == 1:
p1tiles.remove(tidx)
else:
p2tiles.remove(tidx)
def print_game_state(game_state):
(board, p1tiles, p2tiles) = game_state
for row in board:
print(
"".join(
@@ -164,7 +174,6 @@ def print_game_state(game_state):
]
)
)
print("")
print(f"Player 1 tiles left: {p1tiles}")
print(f"Player 2 tiles left: {p2tiles}")
@@ -175,11 +184,22 @@ def reset_game():
board = make_board()
p1tiles = [i for i in range(21)]
p2tiles = [i for i in range(21)]
return [board, p1tiles, p2tiles] # list so it's mutable in-place
return [board, p1tiles, p2tiles] # list so it is mutable
def get_all_moves(game_state, player: int):
board, p1tiles, p2tiles = game_state
available_tiles = p1tiles if player == 1 else p2tiles
moves = []
for tidx, tile in get_permutations(available_tiles):
for placement in can_place(board, tile, player):
moves.append((tidx, tile, placement))
return moves
# =======================
# RL: encoding & policy
# AlphaZero-style network
# =======================
@@ -205,7 +225,7 @@ def encode_move(
return torch.tensor([tidx, x, y, area], dtype=torch.float32)
class PolicyNet(nn.Module):
class PolicyValueNet(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Sequential(
@@ -216,195 +236,290 @@ class PolicyNet(nn.Module):
)
conv_out_dim = 64 * BOARD_SIZE * BOARD_SIZE # 64 * 14 * 14
self.fc = nn.Sequential(
nn.Linear(conv_out_dim + 4, 256),
# Value head (board only)
self.value_head = nn.Sequential(
nn.Linear(conv_out_dim, 128),
nn.ReLU(),
nn.Linear(256, 1), # scalar logit
nn.Linear(128, 1),
nn.Tanh(), # value in [-1, 1]
)
def forward(
self, board_tensor: torch.Tensor, move_features: torch.Tensor
) -> torch.Tensor:
# Policy head (board + move features)
self.policy_head = nn.Sequential(
nn.Linear(conv_out_dim + 4, 256),
nn.ReLU(),
nn.Linear(256, 1), # logit per move
)
def forward(self, board_tensor: torch.Tensor, move_features: torch.Tensor):
"""
board_tensor: (3, 14, 14)
move_features: (N, 4)
returns: logits (N,)
returns: (policy_logits: (N,), value: scalar)
"""
x = self.conv(board_tensor.unsqueeze(0)) # (1, 64, 14, 14)
x = x.view(1, -1) # (1, conv_out_dim)
x = x.repeat(move_features.size(0), 1) # (N, conv_out_dim)
board_embed = x
combined = torch.cat([x, move_features], dim=1) # (N, conv_out_dim + 4)
logits = self.fc(combined).squeeze(-1) # (N,)
return logits
# value head
value = self.value_head(board_embed).squeeze(0).squeeze(-1) # scalar
# policy head
x_rep = board_embed.repeat(move_features.size(0), 1) # (N, conv_out_dim)
combined = torch.cat([x_rep, move_features], dim=1) # (N, conv_out_dim + 4)
logits = self.policy_head(combined).squeeze(-1) # (N,)
return logits, value
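# A minimal usage sketch of the two-headed network above, using the reset_game /
# get_all_moves / encode_board / encode_move helpers defined earlier in this file;
# `demo_forward` is an illustrative name only, not a required part of the API.
def demo_forward(device="cpu"):
    net = PolicyValueNet().to(device)
    state = reset_game()
    moves = get_all_moves(state, 1)                        # legal moves for player 1
    board_tensor = encode_board(state[0], 1).to(device)    # (3, 14, 14)
    move_feats = torch.stack(
        [encode_move(tidx, tile, placement) for (tidx, tile, placement) in moves]
    ).to(device)                                           # (N, 4)
    logits, value = net(board_tensor, move_feats)          # (N,), scalar in [-1, 1]
    return F.softmax(logits, dim=0), value                 # prior over legal moves, value estimate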
# =======================
# RL: move generation & action selection
# MCTS (AlphaZero-style)
# =======================
def get_all_moves(game_state, player: int):
board, p1tiles, p2tiles = game_state
available_tiles = p1tiles if player == 1 else p2tiles
class MCTSNode:
def __init__(self, state, player: int):
self.state = state
self.player = player
self.is_expanded = False
self.is_terminal = False
moves = []
for tidx, tile in get_permutations(available_tiles):
for placement in can_place(board, tile, player):
moves.append((tidx, tile, placement))
self.moves: Optional[list[tuple[int, np.ndarray, tuple[int, int]]]] = None
self.priors: Optional[np.ndarray] = None
self.Nsa: Optional[np.ndarray] = None
self.Wsa: Optional[np.ndarray] = None
self.Qsa: Optional[np.ndarray] = None
return moves
self.children: dict[int, "MCTSNode"] = {}
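# Per-action search statistics, indexed in parallel with self.moves
# (filled in by expand() and updated during backpropagation):
#   priors[i]   -- network prior P(s, a_i) from the policy head
#   Nsa[i]      -- visit count N(s, a_i)
#   Wsa[i]      -- total backed-up value W(s, a_i)
#   Qsa[i]      -- mean value Q(s, a_i) = Wsa[i] / Nsa[i]
#   children[i] -- child node reached by playing moves[i]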
def expand(self, net: PolicyValueNet, device="cpu") -> float:
"""
Returns value v from perspective of self.player.
"""
moves = get_all_moves(self.state, self.player)
if len(moves) == 0:
# No moves: this player loses
self.is_terminal = True
self.is_expanded = True
return -1.0
self.moves = moves
board, _, _ = self.state
board_tensor = encode_board(board, self.player).to(device)
move_feats = torch.stack(
[encode_move(tidx, tile, placement) for (tidx, tile, placement) in moves],
dim=0,
).to(device)
with torch.no_grad():
logits, value = net(board_tensor, move_feats)
probs = F.softmax(logits, dim=0).cpu().numpy()
v = float(value.item())
self.priors = probs
n = len(moves)
self.Nsa = np.zeros(n, dtype=np.float32)
self.Wsa = np.zeros(n, dtype=np.float32)
self.Qsa = np.zeros(n, dtype=np.float32)
self.is_expanded = True
return v
def select_action(self, c_puct: float = 1.5) -> int:
"""
Select action index using PUCT formula.
"""
Ns = np.sum(self.Nsa) + 1e-8
u = c_puct * self.priors * np.sqrt(Ns) / (1.0 + self.Nsa)
scores = self.Qsa + u
return int(np.argmax(scores))
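# The PUCT rule used above follows AlphaZero:
#   score(a) = Q(s, a) + c_puct * P(s, a) * sqrt(sum_b N(s, b)) / (1 + N(s, a))
# An action with N = 0 has Q = 0, so its score is just the scaled prior term;
# the search therefore starts out following the policy head and shifts toward
# actions whose backed-up Q proves high. Small worked example with c_puct = 1.5,
# priors = [0.7, 0.3], Nsa = [4, 0], Qsa = [0.1, 0.0]:
#   score(a0) = 0.1 + 1.5 * 0.7 * sqrt(4) / 5 = 0.52
#   score(a1) = 0.0 + 1.5 * 0.3 * sqrt(4) / 1 = 0.90   -> a1 is selected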
def select_action(policy: PolicyNet, game_state, player: int, device="cpu"):
board, _, _ = game_state
moves = get_all_moves(game_state, player)
def mcts_search(
net: PolicyValueNet,
root_state,
root_player: int,
n_simulations: int,
device="cpu",
c_puct: float = 1.5,
):
root = MCTSNode(clone_state(root_state), root_player)
if len(moves) == 0:
return None, None # no legal moves
for _ in range(n_simulations):
node = root
path: list[tuple[MCTSNode, int]] = []
board_tensor = encode_board(board, player).to(device)
# Traverse
while True:
if not node.is_expanded:
v = node.expand(net, device)
break
if node.is_terminal:
# Value from this player's perspective is -1 (no moves)
v = -1.0
break
move_feats = torch.stack(
[encode_move(tidx, tile, placement) for (tidx, tile, placement) in moves], dim=0
).to(device)
a = node.select_action(c_puct)
path.append((node, a))
logits = policy(board_tensor, move_feats) # (N,)
probs = F.softmax(logits, dim=0)
dist = torch.distributions.Categorical(probs)
idx = dist.sample()
log_prob = dist.log_prob(idx)
if a in node.children:
node = node.children[a]
else:
# create child
child_state = clone_state(node.state)
tidx, tile, placement = node.moves[a]
do_placement(tidx, tile, placement, child_state, node.player)
next_player = 2 if node.player == 1 else 1
child = MCTSNode(child_state, next_player)
node.children[a] = child
node = child
# next loop iteration will expand it
# Backpropagate value v (from leaf player's perspective)
val = v
# Going back up the tree, the perspective alternates each move
for parent, action_index in reversed(path):
val = -val # switch to parent's perspective
parent.Nsa[action_index] += 1.0
parent.Wsa[action_index] += val
parent.Qsa[action_index] = (
parent.Wsa[action_index] / parent.Nsa[action_index]
)
chosen_move = moves[idx.item()]
return chosen_move, log_prob
# After all simulations, derive policy target from root visit counts
if not root.is_expanded or root.is_terminal or root.moves is None:
return None, None, None
visits = root.Nsa
pi = visits / np.sum(visits)
# Sample action from pi (exploration); you can use argmax for greedy play
action_index = int(np.random.choice(len(root.moves), p=pi))
return root.moves, pi, action_index
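# A minimal sketch of greedy play on top of mcts_search, taking the most-visited
# root action instead of sampling from pi (the argmax alternative mentioned in
# the comment above); `pick_greedy_move` is an illustrative helper name.
def pick_greedy_move(net: PolicyValueNet, game_state, player: int, n_simulations: int = 50, device="cpu"):
    moves, pi, _ = mcts_search(net, game_state, player, n_simulations, device)
    if moves is None:
        return None                    # no legal moves: this player loses
    best = int(np.argmax(pi))          # pi is proportional to root visit counts
    return moves[best]                 # (tidx, tile, placement)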
# =======================
# RL: self-play episode
# Self-play + training
# =======================
def play_episode(policy1: PolicyNet, policy2: PolicyNet, optim1, optim2, device="cpu"):
policy1.train()
policy2.train()
def self_play_game(net: PolicyValueNet, n_simulations: int, device="cpu"):
"""
Plays one self-play game using MCTS + shared network.
Returns a list of training examples, one dict per position, with keys:
"board", "player", "moves", "pi", "z"
"""
game_state = reset_game()
player = 1
log_probs1 = []
log_probs2 = []
history = [] # list of dicts: board, player, moves, pi, z (filled later)
while True:
if player == 1:
move, log_prob = select_action(policy1, game_state, player, device)
else:
move, log_prob = select_action(policy2, game_state, player, device)
# No move → this player loses
if move is None:
loser = player
moves = get_all_moves(game_state, player)
if len(moves) == 0:
winner = 2 if player == 1 else 1
break
tidx, tile, placement = move
if player == 1:
log_probs1.append(log_prob)
else:
log_probs2.append(log_prob)
do_placement(tidx, tile, placement, game_state, player)
player = 2 if player == 1 else 1
print_game_state(game_state)
print(f"Player {winner} is the winner")
# Rewards: +1 for win, -1 for loss (from each player's perspective)
r1 = 1.0 if winner == 1 else -1.0
r2 = -r1
if log_probs1:
loss1 = -torch.stack(log_probs1).sum() * r1
optim1.zero_grad()
loss1.backward()
optim1.step()
if log_probs2:
loss2 = -torch.stack(log_probs2).sum() * r2
optim2.zero_grad()
loss2.backward()
optim2.step()
return r1 # from Player 1's perspective
# =======================
# Evaluation: watch them play
# =======================
def play_game(policy1: PolicyNet, policy2: PolicyNet, device="cpu"):
policy1.eval()
policy2.eval()
game_state = reset_game()
player = 1
while True:
print_game_state(game_state)
if player == 1:
move, _ = select_action(policy1, game_state, player, device)
else:
move, _ = select_action(policy2, game_state, player, device)
if move is None:
print(f"No moves left, player {player} lost")
# Run MCTS from current state
mcts_moves, pi, a_idx = mcts_search(
net, game_state, player, n_simulations, device
)
if mcts_moves is None:
winner = 2 if player == 1 else 1
break
tidx, tile, placement = move
do_placement(tidx, tile, placement, game_state, player)
# Save training position (copy board only; moves are references)
board_snapshot = game_state[0].copy()
history.append(
{
"board": board_snapshot,
"player": player,
"moves": mcts_moves,
"pi": pi,
"z": None, # fill after game
}
)
# Play chosen move
tidx, tile, placement = mcts_moves[a_idx]
do_placement(tidx, tile, placement, game_state, player)
player = 2 if player == 1 else 1
def load_policy(path, device="cpu"):
policy = PolicyNet().to(device)
policy.load_state_dict(torch.load(path, map_location=device))
policy.eval()
return policy
# Game finished, assign outcomes
for entry in history:
entry["z"] = 1.0 if entry["player"] == winner else -1.0
def human_vs_ai(ai_policy: PolicyNet, device="cpu"):
ai_policy.eval()
return history, winner
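# Layout of the training data produced above, as consumed by train_on_history
# below (one dict per root position of the self-play game):
#   entry["board"]  -- np.ndarray (BOARD_SIZE, BOARD_SIZE) snapshot before the move
#   entry["player"] -- 1 or 2, the player to move at that position
#   entry["moves"]  -- list of (tidx, tile, (x, y)) legal moves at that position
#   entry["pi"]     -- np.ndarray (N,), MCTS visit-count distribution over those moves
#   entry["z"]      -- +1.0 if that player went on to win the game, else -1.0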
def train_on_history(net: PolicyValueNet, optimizer, history, device="cpu"):
"""
Single gradient step over all positions from one self-play game.
"""
net.train()
optimizer.zero_grad()
total_loss = 0.0
for entry in history:
board = entry["board"]
player = entry["player"]
moves = entry["moves"]
pi = entry["pi"]
z = entry["z"]
board_tensor = encode_board(board, player).to(device)
move_feats = torch.stack(
[encode_move(tidx, tile, placement) for (tidx, tile, placement) in moves],
dim=0,
).to(device)
target_pi = torch.from_numpy(pi).to(device)
target_z = torch.tensor(z, dtype=torch.float32, device=device)
logits, value = net(board_tensor, move_feats)
log_probs = F.log_softmax(logits, dim=0)
policy_loss = -(target_pi * log_probs).sum()
value_loss = F.mse_loss(value, target_z)
loss = policy_loss + value_loss
total_loss += loss
if len(history) > 0:
total_loss = total_loss / len(history)
total_loss.backward()
optimizer.step()
return float(total_loss.item())
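# Loss minimised above, per stored position (AlphaZero-style, without the L2
# regularisation term of the original paper):
#   L = -sum_a pi(a) * log p_theta(a | s) + (z - v_theta(s))^2
# i.e. cross-entropy between the MCTS visit distribution pi and the policy head,
# plus squared error between the final game outcome z and the value head,
# averaged over the positions of one game.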
# =======================
# Simple evaluation game
# =======================
def play_game_with_mcts(net: PolicyValueNet, n_simulations: int, device="cpu"):
"""
Watch two MCTS+net players (same weights) play against each other.
"""
net.eval()
game_state = reset_game()
player = 1 # AI goes first
player = 1
while True:
print_game_state(game_state)
moves = get_all_moves(game_state, player)
if not moves:
print(f"No moves left, player {player} loses.")
break
# Who moves?
if player == 1:
print("AI thinking...")
move, _ = select_action(ai_policy, game_state, player, device)
if move is None:
print("AI has no moves — AI loses!")
break
tidx, tile, placement = move
print(f"AI plays tile {tidx} at {placement}\n")
else:
# human turn
moves = get_all_moves(game_state, player)
if not moves:
print("You have no moves — you lose!")
break
mcts_moves, pi, a_idx = mcts_search(
net, game_state, player, n_simulations, device
)
if mcts_moves is None:
print(f"No moves left (MCTS), player {player} loses.")
break
print("Your legal moves:")
for i, (tidx, tile, placement) in enumerate(moves):
print(f"{i}: tile {tidx} at {placement}")
choice = int(input("Choose move number: "))
tidx, tile, placement = moves[choice]
# Apply move
tidx, tile, placement = mcts_moves[a_idx]
print(f"Player {player} plays tile {tidx} at {placement}")
do_placement(tidx, tile, placement, game_state, player)
# Switch players
player = 2 if player == 1 else 1
@@ -417,45 +532,28 @@ def main():
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
policy1 = PolicyNet().to(device)
policy2 = PolicyNet().to(device)
net = PolicyValueNet().to(device)
optimizer = optim.Adam(net.parameters(), lr=1e-3)
optim1 = optim.Adam(policy1.parameters(), lr=1e-3)
optim2 = optim.Adam(policy2.parameters(), lr=1e-3)
num_games = 200 # increase a lot for real training
n_simulations = 50 # MCTS sims per move (increase if it's too weak)
best_avg_reward = -999
reward_history = []
for g in range(1, num_games + 1):
history, winner = self_play_game(net, n_simulations, device)
loss = train_on_history(net, optimizer, history, device)
num_episodes = 2000
for episode in range(1, num_episodes + 1):
reward = play_episode(policy1, policy2, optim1, optim2, device=device)
reward_history.append(reward)
print(
f"Game {g}/{num_games}, winner: Player {winner}, loss: {loss:.4f}, positions: {len(history)}"
)
# compute moving average every 50 episodes
if len(reward_history) >= 50:
avg = sum(reward_history[-50:]) / 50
# occasionally watch a game
if g % 50 == 0:
print("Watching a game with current network:")
play_game_with_mcts(net, n_simulations=30, device=device)
# If policy1 improved, save it
if avg > best_avg_reward:
best_avg_reward = avg
torch.save(policy1.state_dict(), "best_policy1.pth")
print(f"Saved best policy1 at episode {episode} (avg reward={avg:.3f})")
if episode % 100 == 0:
print(f"Episode {episode}, last reward={reward}")
print("Training complete.")
print("1 = Watch AI vs AI")
print("2 = Play against AI")
print("3 = Quit")
choice = input("Select: ")
if choice == "1":
play_game(policy1, policy2, device)
elif choice == "2":
best_ai = load_policy("best_policy1.pth", device)
human_vs_ai(best_ai, device)
# Save final network
torch.save(net.state_dict(), "alphazero_blokus_net.pth")
print("Saved network to alphazero_blokus_net.pth")
if __name__ == "__main__":