diff --git a/blokus.py b/blokus.py index 8a14e31..5eab2e0 100755 --- a/blokus.py +++ b/blokus.py @@ -1,4 +1,5 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 +import random import numpy as np import torch @@ -45,6 +46,11 @@ tiles = [ ] +def clone_state(game_state): + board, p1tiles, p2tiles = game_state + return [board.copy(), p1tiles.copy(), p2tiles.copy()] + + def get_permutations(which_tiles: list[int]): """ For each tile index in which_tiles, generate all unique rotations/flips. @@ -95,6 +101,7 @@ def can_place(board: np.ndarray, tile: np.ndarray, player: int): break else: placements.append((x, y)) + final = [] if has_minus_one: for x, y in placements: @@ -144,17 +151,20 @@ def do_placement( tidx: int, tile: np.ndarray, placement: tuple[int, int], game_state, player: int ): (x, y) = placement + board, p1tiles, p2tiles = game_state with np.nditer(tile, flags=["multi_index"]) as it: for v in it: (i, j) = it.multi_index if v == 1: - game_state[0][x + i, y + j] = player - game_state[player].remove(tidx) + board[x + i, y + j] = player + if player == 1: + p1tiles.remove(tidx) + else: + p2tiles.remove(tidx) def print_game_state(game_state): (board, p1tiles, p2tiles) = game_state - for row in board: print( "".join( @@ -164,7 +174,6 @@ def print_game_state(game_state): ] ) ) - print("") print(f"Player 1 tiles left: {p1tiles}") print(f"Player 2 tiles left: {p2tiles}") @@ -175,11 +184,22 @@ def reset_game(): board = make_board() p1tiles = [i for i in range(21)] p2tiles = [i for i in range(21)] - return [board, p1tiles, p2tiles] # list so it's mutable in-place + return [board, p1tiles, p2tiles] # list so it is mutable + + +def get_all_moves(game_state, player: int): + board, p1tiles, p2tiles = game_state + available_tiles = p1tiles if player == 1 else p2tiles + + moves = [] + for tidx, tile in get_permutations(available_tiles): + for placement in can_place(board, tile, player): + moves.append((tidx, tile, placement)) + return moves # ======================= -# RL: encoding & policy +# AlphaZero-style network # ======================= @@ -205,7 +225,7 @@ def encode_move( return torch.tensor([tidx, x, y, area], dtype=torch.float32) -class PolicyNet(nn.Module): +class PolicyValueNet(nn.Module): def __init__(self): super().__init__() self.conv = nn.Sequential( @@ -216,195 +236,290 @@ class PolicyNet(nn.Module): ) conv_out_dim = 64 * BOARD_SIZE * BOARD_SIZE # 64 * 14 * 14 - self.fc = nn.Sequential( - nn.Linear(conv_out_dim + 4, 256), + # Value head (board only) + self.value_head = nn.Sequential( + nn.Linear(conv_out_dim, 128), nn.ReLU(), - nn.Linear(256, 1), # scalar logit + nn.Linear(128, 1), + nn.Tanh(), # value in [-1, 1] ) - def forward( - self, board_tensor: torch.Tensor, move_features: torch.Tensor - ) -> torch.Tensor: + # Policy head (board + move features) + self.policy_head = nn.Sequential( + nn.Linear(conv_out_dim + 4, 256), + nn.ReLU(), + nn.Linear(256, 1), # logit per move + ) + + def forward(self, board_tensor: torch.Tensor, move_features: torch.Tensor): """ board_tensor: (3, 14, 14) move_features: (N, 4) - returns: logits (N,) + returns: (policy_logits: (N,), value: scalar) """ x = self.conv(board_tensor.unsqueeze(0)) # (1, 64, 14, 14) x = x.view(1, -1) # (1, conv_out_dim) - x = x.repeat(move_features.size(0), 1) # (N, conv_out_dim) + board_embed = x - combined = torch.cat([x, move_features], dim=1) # (N, conv_out_dim + 4) - logits = self.fc(combined).squeeze(-1) # (N,) - return logits + # value head + value = self.value_head(board_embed).squeeze(0).squeeze(-1) # scalar + + # policy 
head + x_rep = board_embed.repeat(move_features.size(0), 1) # (N, conv_out_dim) + combined = torch.cat([x_rep, move_features], dim=1) # (N, conv_out_dim + 4) + logits = self.policy_head(combined).squeeze(-1) # (N,) + + return logits, value # ======================= -# RL: move generation & action selection +# MCTS (AlphaZero-style) # ======================= -def get_all_moves(game_state, player: int): - board, p1tiles, p2tiles = game_state - available_tiles = p1tiles if player == 1 else p2tiles +class MCTSNode: + def __init__(self, state, player: int): + self.state = state + self.player = player + self.is_expanded = False + self.is_terminal = False - moves = [] - for tidx, tile in get_permutations(available_tiles): - for placement in can_place(board, tile, player): - moves.append((tidx, tile, placement)) + self.moves: Optional[list[tuple[int, np.ndarray, tuple[int, int]]]] = None + self.priors: Optional[np.ndarray] = None + self.Nsa: Optional[np.ndarray] = None + self.Wsa: Optional[np.ndarray] = None + self.Qsa: Optional[np.ndarray] = None - return moves + self.children: dict[int, "MCTSNode"] = {} + + def expand(self, net: PolicyValueNet, device="cpu") -> float: + """ + Returns value v from perspective of self.player. + """ + moves = get_all_moves(self.state, self.player) + if len(moves) == 0: + # No moves: this player loses + self.is_terminal = True + self.is_expanded = True + return -1.0 + + self.moves = moves + + board, _, _ = self.state + board_tensor = encode_board(board, self.player).to(device) + move_feats = torch.stack( + [encode_move(tidx, tile, placement) for (tidx, tile, placement) in moves], + dim=0, + ).to(device) + + with torch.no_grad(): + logits, value = net(board_tensor, move_feats) + probs = F.softmax(logits, dim=0).cpu().numpy() + v = float(value.item()) + + self.priors = probs + n = len(moves) + self.Nsa = np.zeros(n, dtype=np.float32) + self.Wsa = np.zeros(n, dtype=np.float32) + self.Qsa = np.zeros(n, dtype=np.float32) + self.is_expanded = True + return v + + def select_action(self, c_puct: float = 1.5) -> int: + """ + Select action index using PUCT formula. 
+ """ + Ns = np.sum(self.Nsa) + 1e-8 + u = c_puct * self.priors * np.sqrt(Ns) / (1.0 + self.Nsa) + scores = self.Qsa + u + return int(np.argmax(scores)) -def select_action(policy: PolicyNet, game_state, player: int, device="cpu"): - board, _, _ = game_state - moves = get_all_moves(game_state, player) +def mcts_search( + net: PolicyValueNet, + root_state, + root_player: int, + n_simulations: int, + device="cpu", + c_puct: float = 1.5, +): + root = MCTSNode(clone_state(root_state), root_player) - if len(moves) == 0: - return None, None # no legal moves + for _ in range(n_simulations): + node = root + path: list[tuple[MCTSNode, int]] = [] - board_tensor = encode_board(board, player).to(device) + # Traverse + while True: + if not node.is_expanded: + v = node.expand(net, device) + break + if node.is_terminal: + # Value from this player's perspective is -1 (no moves) + v = -1.0 + break - move_feats = torch.stack( - [encode_move(tidx, tile, placement) for (tidx, tile, placement) in moves], dim=0 - ).to(device) + a = node.select_action(c_puct) + path.append((node, a)) - logits = policy(board_tensor, move_feats) # (N,) - probs = F.softmax(logits, dim=0) - dist = torch.distributions.Categorical(probs) - idx = dist.sample() - log_prob = dist.log_prob(idx) + if a in node.children: + node = node.children[a] + else: + # create child + child_state = clone_state(node.state) + tidx, tile, placement = node.moves[a] + do_placement(tidx, tile, placement, child_state, node.player) + next_player = 2 if node.player == 1 else 1 + child = MCTSNode(child_state, next_player) + node.children[a] = child + node = child + # next loop iteration will expand it + # Backpropagate value v (from leaf player's perspective) + val = v + # Going back up the tree, the perspective alternates each move + for parent, action_index in reversed(path): + val = -val # switch to parent's perspective + parent.Nsa[action_index] += 1.0 + parent.Wsa[action_index] += val + parent.Qsa[action_index] = ( + parent.Wsa[action_index] / parent.Nsa[action_index] + ) - chosen_move = moves[idx.item()] - return chosen_move, log_prob + # After all simulations, derive policy target from root visit counts + if not root.is_expanded or root.is_terminal or root.moves is None: + return None, None, None + + visits = root.Nsa + pi = visits / np.sum(visits) + # Sample action from pi (exploration); you can use argmax for greedy play + action_index = int(np.random.choice(len(root.moves), p=pi)) + return root.moves, pi, action_index # ======================= -# RL: self-play episode +# Self-play + training # ======================= -def play_episode(policy1: PolicyNet, policy2: PolicyNet, optim1, optim2, device="cpu"): - policy1.train() - policy2.train() +def self_play_game(net: PolicyValueNet, n_simulations: int, device="cpu"): + """ + Plays one self-play game using MCTS + shared network. 
+ Returns a list of training examples: + each entry: (board_snapshot, player, moves, pi, z) + """ game_state = reset_game() player = 1 - - log_probs1 = [] - log_probs2 = [] + history = [] # list of dicts: board, player, moves, pi, z (filled later) while True: - if player == 1: - move, log_prob = select_action(policy1, game_state, player, device) - else: - move, log_prob = select_action(policy2, game_state, player, device) - - # No move → this player loses - if move is None: - loser = player + moves = get_all_moves(game_state, player) + if len(moves) == 0: winner = 2 if player == 1 else 1 break - tidx, tile, placement = move - - if player == 1: - log_probs1.append(log_prob) - else: - log_probs2.append(log_prob) - - do_placement(tidx, tile, placement, game_state, player) - - player = 2 if player == 1 else 1 - - print_game_state(game_state) - print(f"Player {winner} is the winner") - # Rewards: +1 for win, -1 for loss (from each player's perspective) - r1 = 1.0 if winner == 1 else -1.0 - r2 = -r1 - - if log_probs1: - loss1 = -torch.stack(log_probs1).sum() * r1 - optim1.zero_grad() - loss1.backward() - optim1.step() - - if log_probs2: - loss2 = -torch.stack(log_probs2).sum() * r2 - optim2.zero_grad() - loss2.backward() - optim2.step() - - return r1 # from Player 1's perspective - - -# ======================= -# Evaluation: watch them play -# ======================= - - -def play_game(policy1: PolicyNet, policy2: PolicyNet, device="cpu"): - policy1.eval() - policy2.eval() - game_state = reset_game() - player = 1 - while True: - print_game_state(game_state) - - if player == 1: - move, _ = select_action(policy1, game_state, player, device) - else: - move, _ = select_action(policy2, game_state, player, device) - - if move is None: - print(f"No moves left, player {player} lost") + # Run MCTS from current state + mcts_moves, pi, a_idx = mcts_search( + net, game_state, player, n_simulations, device + ) + if mcts_moves is None: + winner = 2 if player == 1 else 1 break - tidx, tile, placement = move - do_placement(tidx, tile, placement, game_state, player) + # Save training position (copy board only; moves are references) + board_snapshot = game_state[0].copy() + history.append( + { + "board": board_snapshot, + "player": player, + "moves": mcts_moves, + "pi": pi, + "z": None, # fill after game + } + ) + # Play chosen move + tidx, tile, placement = mcts_moves[a_idx] + do_placement(tidx, tile, placement, game_state, player) player = 2 if player == 1 else 1 -def load_policy(path, device="cpu"): - policy = PolicyNet().to(device) - policy.load_state_dict(torch.load(path, map_location=device)) - policy.eval() - return policy + # Game finished, assign outcomes + for entry in history: + entry["z"] = 1.0 if entry["player"] == winner else -1.0 -def human_vs_ai(ai_policy: PolicyNet, device="cpu"): - ai_policy.eval() + return history, winner + + +def train_on_history(net: PolicyValueNet, optimizer, history, device="cpu"): + """ + Single gradient step over all positions from one self-play game. 
+ """ + net.train() + optimizer.zero_grad() + total_loss = 0.0 + + for entry in history: + board = entry["board"] + player = entry["player"] + moves = entry["moves"] + pi = entry["pi"] + z = entry["z"] + + board_tensor = encode_board(board, player).to(device) + move_feats = torch.stack( + [encode_move(tidx, tile, placement) for (tidx, tile, placement) in moves], + dim=0, + ).to(device) + target_pi = torch.from_numpy(pi).to(device) + target_z = torch.tensor(z, dtype=torch.float32, device=device) + + logits, value = net(board_tensor, move_feats) + log_probs = F.log_softmax(logits, dim=0) + + policy_loss = -(target_pi * log_probs).sum() + value_loss = F.mse_loss(value, target_z) + + loss = policy_loss + value_loss + total_loss += loss + + if len(history) > 0: + total_loss = total_loss / len(history) + + total_loss.backward() + optimizer.step() + return float(total_loss.item()) + + +# ======================= +# Simple evaluation game +# ======================= + + +def play_game_with_mcts(net: PolicyValueNet, n_simulations: int, device="cpu"): + """ + Watch two MCTS+net players (same weights) play against each other. + """ + net.eval() game_state = reset_game() - player = 1 # AI goes first + player = 1 while True: print_game_state(game_state) + moves = get_all_moves(game_state, player) + if not moves: + print(f"No moves left, player {player} loses.") + break - # Who moves? - if player == 1: - print("AI thinking...") - move, _ = select_action(ai_policy, game_state, player, device) - if move is None: - print("AI has no moves — AI loses!") - break - tidx, tile, placement = move - print(f"AI plays tile {tidx} at {placement}\n") - else: - # human turn - moves = get_all_moves(game_state, player) - if not moves: - print("You have no moves — you lose!") - break + mcts_moves, pi, a_idx = mcts_search( + net, game_state, player, n_simulations, device + ) + if mcts_moves is None: + print(f"No moves left (MCTS), player {player} loses.") + break - print("Your legal moves:") - for i, (tidx, tile, placement) in enumerate(moves): - print(f"{i}: tile {tidx} at {placement}") - - choice = int(input("Choose move number: ")) - tidx, tile, placement = moves[choice] - - # Apply move + tidx, tile, placement = mcts_moves[a_idx] + print(f"Player {player} plays tile {tidx} at {placement}") do_placement(tidx, tile, placement, game_state, player) - # Switch players player = 2 if player == 1 else 1 @@ -417,45 +532,28 @@ def main(): device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") - policy1 = PolicyNet().to(device) - policy2 = PolicyNet().to(device) + net = PolicyValueNet().to(device) + optimizer = optim.Adam(net.parameters(), lr=1e-3) - optim1 = optim.Adam(policy1.parameters(), lr=1e-3) - optim2 = optim.Adam(policy2.parameters(), lr=1e-3) + num_games = 200 # increase a lot for real training + n_simulations = 50 # MCTS sims per move (increase if it's too weak) - best_avg_reward = -999 - reward_history = [] + for g in range(1, num_games + 1): + history, winner = self_play_game(net, n_simulations, device) + loss = train_on_history(net, optimizer, history, device) - num_episodes = 2000 - for episode in range(1, num_episodes + 1): - reward = play_episode(policy1, policy2, optim1, optim2, device=device) - reward_history.append(reward) + print( + f"Game {g}/{num_games}, winner: Player {winner}, loss: {loss:.4f}, positions: {len(history)}" + ) - # compute moving average every 50 episodes - if len(reward_history) >= 50: - avg = sum(reward_history[-50:]) / 50 + # occasionally watch a game + if 
g % 50 == 0: + print("Watching a game with current network:") + play_game_with_mcts(net, n_simulations=30, device=device) - # If policy1 improved, save it - if avg > best_avg_reward: - best_avg_reward = avg - torch.save(policy1.state_dict(), "best_policy1.pth") - print(f"Saved best policy1 at episode {episode} (avg reward={avg:.3f})") - - if episode % 100 == 0: - print(f"Episode {episode}, last reward={reward}") - - print("Training complete.") - print("1 = Watch AI vs AI") - print("2 = Play against AI") - print("3 = Quit") - - choice = input("Select: ") - - if choice == "1": - play_game(policy1, policy2, device) - elif choice == "2": - best_ai = load_policy("best_policy1.pth", device) - human_vs_ai(best_ai, device) + # Save final network + torch.save(net.state_dict(), "alphazero_blokus_net.pth") + print("Saved network to alphazero_blokus_net.pth") if __name__ == "__main__":
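
Note on reusing the checkpoint: the patch removes load_policy and human_vs_ai and only writes the final weights at the end of main(), so there is no longer a built-in way to load a saved network back in. The snippet below is a sketch of how that could be done by hand, not part of the patch; it assumes blokus.py is importable as a module (i.e. main() stays behind the __main__ guard) and exposes PolicyValueNet and play_game_with_mcts exactly as defined above.

    # Sketch only: reload the checkpoint written by main() and watch a game.
    import torch
    from blokus import PolicyValueNet, play_game_with_mcts

    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = PolicyValueNet().to(device)
    # Mirrors the torch.save(net.state_dict(), "alphazero_blokus_net.pth") call in main().
    net.load_state_dict(torch.load("alphazero_blokus_net.pth", map_location=device))
    net.eval()
    play_game_with_mcts(net, n_simulations=100, device=device)

The n_simulations value here is arbitrary: more simulations give stronger but slower play, in line with the comment on n_simulations in main().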