#!/usr/bin/env python3 import random import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim BOARD_SIZE = 14 # ======================= # Game setup and rules # ======================= def make_board(): a = np.array([[0 for _ in range(BOARD_SIZE)] for _ in range(BOARD_SIZE)]) a[4, 4] = -1 a[9, 9] = -1 return a tiles = [ np.array([[1]]), np.array([[1], [1]]), np.array([[1], [1], [1]]), np.array([[1, 0], [1, 1]]), np.array([[1], [1], [1], [1]]), np.array([[1, 0], [1, 0], [1, 1]]), np.array([[1, 0], [1, 1], [1, 0]]), np.array([[1, 1], [1, 1]]), np.array([[1, 1, 0], [0, 1, 1]]), np.array([[1], [1], [1], [1], [1]]), np.array([[1, 0], [1, 0], [1, 0], [1, 1]]), np.array([[1, 0], [1, 0], [1, 1], [0, 1]]), np.array([[1, 0], [1, 1], [1, 1]]), np.array([[1, 1], [1, 0], [1, 1]]), np.array([[1, 0], [1, 1], [1, 0], [1, 0]]), np.array([[0, 1, 0], [0, 1, 0], [1, 1, 1]]), np.array([[1, 0, 0], [1, 0, 0], [1, 1, 1]]), np.array([[1, 1, 0], [0, 1, 1], [0, 0, 1]]), np.array([[1, 0, 0], [1, 1, 1], [0, 0, 1]]), np.array([[1, 0, 0], [1, 1, 1], [0, 1, 0]]), np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]), ] def clone_state(game_state): board, p1tiles, p2tiles = game_state return [board.copy(), p1tiles.copy(), p2tiles.copy()] def get_permutations(which_tiles: list[int]): """ For each tile index in which_tiles, generate all unique rotations/flips. Returns a list of (tile_index, oriented_tile). """ permutations = [] for tidx in which_tiles: tile = tiles[tidx] rots = [np.rot90(tile, k) for k in range(4)] flips = [np.flip(r, axis=1) for r in rots] # horizontal flips all_orients = rots + flips # 8 orientations seen = set() for t in all_orients: key = (t.shape, t.tobytes()) if key not in seen: seen.add(key) permutations.append((tidx, t)) return permutations def can_place(board: np.ndarray, tile: np.ndarray, player: int): placements = [] has_minus_one = False for x in range(BOARD_SIZE): for y in range(BOARD_SIZE): if board[x, y] == -1: has_minus_one = True with np.nditer(tile, flags=["multi_index"]) as it: for v in it: if v == 1: (i, j) = it.multi_index if x + i >= BOARD_SIZE: break if y + j >= BOARD_SIZE: break if board[x + i][y + j] > 0: break if x + i - 1 >= 0 and board[x + i - 1][y + j] == player: break if y + j - 1 >= 0 and board[x + i][y + j - 1] == player: break if x + i + 1 < BOARD_SIZE and board[x + i + 1][y + j] == player: break if y + j + 1 < BOARD_SIZE and board[x + i][y + j + 1] == player: break else: placements.append((x, y)) final = [] if has_minus_one: for x, y in placements: with np.nditer(tile, flags=["multi_index"]) as it: for v in it: (i, j) = it.multi_index if v == 1 and board[x + i, y + j] == -1: final.append((x, y)) break else: for x, y in placements: with np.nditer(tile, flags=["multi_index"]) as it: for v in it: (i, j) = it.multi_index if ( x + i + 1 < BOARD_SIZE and y + j + 1 < BOARD_SIZE and board[x + i + 1, y + j + 1] == player ): final.append((x, y)) break if ( x + i + 1 < BOARD_SIZE and y + j - 1 >= 0 and board[x + i + 1, y + j - 1] == player ): final.append((x, y)) break if ( x + i - 1 >= 0 and y + j + 1 < BOARD_SIZE and board[x + i - 1, y + j + 1] == player ): final.append((x, y)) break if ( x + i - 1 >= 0 and y + j - 1 >= 0 and board[x + i - 1, y + j - 1] == player ): final.append((x, y)) break return final def do_placement( tidx: int, tile: np.ndarray, placement: tuple[int, int], game_state, player: int ): (x, y) = placement board, p1tiles, p2tiles = game_state with np.nditer(tile, flags=["multi_index"]) as it: for v in it: (i, j) = it.multi_index if v == 1: board[x + i, y + j] = player if player == 1: p1tiles.remove(tidx) else: p2tiles.remove(tidx) def print_game_state(game_state): (board, p1tiles, p2tiles) = game_state for row in board: print( "".join( [ " " if x == 0 else "X" if x == 1 else "O" if x == 2 else "S" for x in row ] ) ) print("") print(f"Player 1 tiles left: {p1tiles}") print(f"Player 2 tiles left: {p2tiles}") print("") def reset_game(): board = make_board() p1tiles = [i for i in range(21)] p2tiles = [i for i in range(21)] return [board, p1tiles, p2tiles] # list so it is mutable def get_all_moves(game_state, player: int): board, p1tiles, p2tiles = game_state available_tiles = p1tiles if player == 1 else p2tiles moves = [] for tidx, tile in get_permutations(available_tiles): for placement in can_place(board, tile, player): moves.append((tidx, tile, placement)) return moves # ======================= # AlphaZero-style network # ======================= def encode_board(board: np.ndarray, player: int) -> torch.Tensor: """ Channels: 0: current player's stones 1: opponent's stones 2: starting squares (-1) """ me = (board == player).astype(np.float32) opp = ((board > 0) & (board != player)).astype(np.float32) start = (board == -1).astype(np.float32) state = np.stack([me, opp, start], axis=0) # (3, 14, 14) return torch.from_numpy(state) def encode_move( tidx: int, tile: np.ndarray, placement: tuple[int, int] ) -> torch.Tensor: x, y = placement area = int(tile.sum()) return torch.tensor([tidx, x, y, area], dtype=torch.float32) class PolicyValueNet(nn.Module): def __init__(self): super().__init__() self.conv = nn.Sequential( nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(), nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), ) conv_out_dim = 64 * BOARD_SIZE * BOARD_SIZE # 64 * 14 * 14 # Value head (board only) self.value_head = nn.Sequential( nn.Linear(conv_out_dim, 128), nn.ReLU(), nn.Linear(128, 1), nn.Tanh(), # value in [-1, 1] ) # Policy head (board + move features) self.policy_head = nn.Sequential( nn.Linear(conv_out_dim + 4, 256), nn.ReLU(), nn.Linear(256, 1), # logit per move ) def forward(self, board_tensor: torch.Tensor, move_features: torch.Tensor): """ board_tensor: (3, 14, 14) move_features: (N, 4) returns: (policy_logits: (N,), value: scalar) """ x = self.conv(board_tensor.unsqueeze(0)) # (1, 64, 14, 14) x = x.view(1, -1) # (1, conv_out_dim) board_embed = x # value head value = self.value_head(board_embed).squeeze(0).squeeze(-1) # scalar # policy head x_rep = board_embed.repeat(move_features.size(0), 1) # (N, conv_out_dim) combined = torch.cat([x_rep, move_features], dim=1) # (N, conv_out_dim + 4) logits = self.policy_head(combined).squeeze(-1) # (N,) return logits, value # ======================= # MCTS (AlphaZero-style) # ======================= class MCTSNode: def __init__(self, state, player: int): self.state = state self.player = player self.is_expanded = False self.is_terminal = False self.moves: Optional[list[tuple[int, np.ndarray, tuple[int, int]]]] = None self.priors: Optional[np.ndarray] = None self.Nsa: Optional[np.ndarray] = None self.Wsa: Optional[np.ndarray] = None self.Qsa: Optional[np.ndarray] = None self.children: dict[int, "MCTSNode"] = {} def expand(self, net: PolicyValueNet, device="cpu") -> float: """ Returns value v from perspective of self.player. """ moves = get_all_moves(self.state, self.player) if len(moves) == 0: # No moves: this player loses self.is_terminal = True self.is_expanded = True return -1.0 self.moves = moves board, _, _ = self.state board_tensor = encode_board(board, self.player).to(device) move_feats = torch.stack( [encode_move(tidx, tile, placement) for (tidx, tile, placement) in moves], dim=0, ).to(device) with torch.no_grad(): logits, value = net(board_tensor, move_feats) probs = F.softmax(logits, dim=0).cpu().numpy() v = float(value.item()) self.priors = probs n = len(moves) self.Nsa = np.zeros(n, dtype=np.float32) self.Wsa = np.zeros(n, dtype=np.float32) self.Qsa = np.zeros(n, dtype=np.float32) self.is_expanded = True return v def select_action(self, c_puct: float = 1.5) -> int: """ Select action index using PUCT formula. """ Ns = np.sum(self.Nsa) + 1e-8 u = c_puct * self.priors * np.sqrt(Ns) / (1.0 + self.Nsa) scores = self.Qsa + u return int(np.argmax(scores)) def mcts_search( net: PolicyValueNet, root_state, root_player: int, n_simulations: int, device="cpu", c_puct: float = 1.5, ): root = MCTSNode(clone_state(root_state), root_player) for _ in range(n_simulations): node = root path: list[tuple[MCTSNode, int]] = [] # Traverse while True: if not node.is_expanded: v = node.expand(net, device) break if node.is_terminal: # Value from this player's perspective is -1 (no moves) v = -1.0 break a = node.select_action(c_puct) path.append((node, a)) if a in node.children: node = node.children[a] else: # create child child_state = clone_state(node.state) tidx, tile, placement = node.moves[a] do_placement(tidx, tile, placement, child_state, node.player) next_player = 2 if node.player == 1 else 1 child = MCTSNode(child_state, next_player) node.children[a] = child node = child # next loop iteration will expand it # Backpropagate value v (from leaf player's perspective) val = v # Going back up the tree, the perspective alternates each move for parent, action_index in reversed(path): val = -val # switch to parent's perspective parent.Nsa[action_index] += 1.0 parent.Wsa[action_index] += val parent.Qsa[action_index] = ( parent.Wsa[action_index] / parent.Nsa[action_index] ) # After all simulations, derive policy target from root visit counts if not root.is_expanded or root.is_terminal or root.moves is None: return None, None, None visits = root.Nsa pi = visits / np.sum(visits) # Sample action from pi (exploration); you can use argmax for greedy play action_index = int(np.random.choice(len(root.moves), p=pi)) return root.moves, pi, action_index # ======================= # Self-play + training # ======================= def self_play_game(net: PolicyValueNet, n_simulations: int, device="cpu"): """ Plays one self-play game using MCTS + shared network. Returns a list of training examples: each entry: (board_snapshot, player, moves, pi, z) """ game_state = reset_game() player = 1 history = [] # list of dicts: board, player, moves, pi, z (filled later) while True: moves = get_all_moves(game_state, player) if len(moves) == 0: winner = 2 if player == 1 else 1 break # Run MCTS from current state mcts_moves, pi, a_idx = mcts_search( net, game_state, player, n_simulations, device ) if mcts_moves is None: winner = 2 if player == 1 else 1 break # Save training position (copy board only; moves are references) board_snapshot = game_state[0].copy() history.append( { "board": board_snapshot, "player": player, "moves": mcts_moves, "pi": pi, "z": None, # fill after game } ) # Play chosen move tidx, tile, placement = mcts_moves[a_idx] do_placement(tidx, tile, placement, game_state, player) player = 2 if player == 1 else 1 # Game finished, assign outcomes for entry in history: entry["z"] = 1.0 if entry["player"] == winner else -1.0 return history, winner def train_on_history(net: PolicyValueNet, optimizer, history, device="cpu"): """ Single gradient step over all positions from one self-play game. """ net.train() optimizer.zero_grad() total_loss = 0.0 for entry in history: board = entry["board"] player = entry["player"] moves = entry["moves"] pi = entry["pi"] z = entry["z"] board_tensor = encode_board(board, player).to(device) move_feats = torch.stack( [encode_move(tidx, tile, placement) for (tidx, tile, placement) in moves], dim=0, ).to(device) target_pi = torch.from_numpy(pi).to(device) target_z = torch.tensor(z, dtype=torch.float32, device=device) logits, value = net(board_tensor, move_feats) log_probs = F.log_softmax(logits, dim=0) policy_loss = -(target_pi * log_probs).sum() value_loss = F.mse_loss(value, target_z) loss = policy_loss + value_loss total_loss += loss if len(history) > 0: total_loss = total_loss / len(history) total_loss.backward() optimizer.step() return float(total_loss.item()) # ======================= # Simple evaluation game # ======================= def play_game_with_mcts(net: PolicyValueNet, n_simulations: int, device="cpu"): """ Watch two MCTS+net players (same weights) play against each other. """ net.eval() game_state = reset_game() player = 1 while True: print_game_state(game_state) moves = get_all_moves(game_state, player) if not moves: print(f"No moves left, player {player} loses.") break mcts_moves, pi, a_idx = mcts_search( net, game_state, player, n_simulations, device ) if mcts_moves is None: print(f"No moves left (MCTS), player {player} loses.") break tidx, tile, placement = mcts_moves[a_idx] print(f"Player {player} plays tile {tidx} at {placement}") do_placement(tidx, tile, placement, game_state, player) player = 2 if player == 1 else 1 # ======================= # Main training loop # ======================= def main(): device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") net = PolicyValueNet().to(device) optimizer = optim.Adam(net.parameters(), lr=1e-3) num_games = 200 # increase a lot for real training n_simulations = 50 # MCTS sims per move (increase if it's too weak) for g in range(1, num_games + 1): history, winner = self_play_game(net, n_simulations, device) loss = train_on_history(net, optimizer, history, device) print( f"Game {g}/{num_games}, winner: Player {winner}, loss: {loss:.4f}, positions: {len(history)}" ) # occasionally watch a game if g % 50 == 0: print("Watching a game with current network:") play_game_with_mcts(net, n_simulations=30, device=device) # Save final network torch.save(net.state_dict(), "alphazero_blokus_net.pth") print("Saved network to alphazero_blokus_net.pth") if __name__ == "__main__": main()