diff --git a/blokus.py b/blokus.py
index 30c20ed..8a14e31 100755
--- a/blokus.py
+++ b/blokus.py
@@ -1,12 +1,20 @@
 #!/usr/bin/env python
+
 import numpy as np
-import random
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
 
 BOARD_SIZE = 14
 
+# =======================
+# Game setup and rules
+# =======================
+
+
 def make_board():
-    a = np.array([[0 for i in range(BOARD_SIZE)] for j in range(BOARD_SIZE)])
+    a = np.array([[0 for _ in range(BOARD_SIZE)] for _ in range(BOARD_SIZE)])
     a[4, 4] = -1
     a[9, 9] = -1
     return a
@@ -38,14 +46,16 @@ tiles = [
 
 
 def get_permutations(which_tiles: list[int]):
+    """
+    For each tile index in which_tiles, generate all unique rotations/flips.
+    Returns a list of (tile_index, oriented_tile).
+    """
     permutations = []
-
-    for i, tile in enumerate(tiles):
-        if i not in which_tiles:
-            continue
+    for tidx in which_tiles:
+        tile = tiles[tidx]
         rots = [np.rot90(tile, k) for k in range(4)]
-        flips = [np.flip(r, axis=1) for r in rots]  # flip horizontally
+        flips = [np.flip(r, axis=1) for r in rots]  # horizontal flips
         all_orients = rots + flips  # 8 orientations
 
         seen = set()
@@ -53,12 +63,12 @@ def get_permutations(which_tiles: list[int]):
             key = (t.shape, t.tobytes())
             if key not in seen:
                 seen.add(key)
-                permutations.append((i, t))
+                permutations.append((tidx, t))
 
     return permutations
 
 
-def can_place(board, tile, player):
+def can_place(board: np.ndarray, tile: np.ndarray, player: int):
     placements = []
     has_minus_one = False
     for x in range(BOARD_SIZE):
@@ -102,35 +112,37 @@ def can_place(board, tile, player):
                     if (
                         x + i + 1 < BOARD_SIZE
                         and y + j + 1 < BOARD_SIZE
-                        and board[x + i + 1][y + j + 1] == player
+                        and board[x + i + 1, y + j + 1] == player
                     ):
                         final.append((x, y))
                         break
                     if (
                         x + i + 1 < BOARD_SIZE
                         and y + j - 1 >= 0
-                        and board[x + i + 1][y + j - 1] == player
+                        and board[x + i + 1, y + j - 1] == player
                     ):
                         final.append((x, y))
                         break
                     if (
                         x + i - 1 >= 0
                         and y + j + 1 < BOARD_SIZE
-                        and board[x + i - 1][y + j + 1] == player
+                        and board[x + i - 1, y + j + 1] == player
                     ):
                         final.append((x, y))
                         break
                    if (
                        x + i - 1 >= 0
                        and y + j - 1 >= 0
-                        and board[x + i - 1][y + j - 1] == player
+                        and board[x + i - 1, y + j - 1] == player
                    ):
                        final.append((x, y))
                        break
 
     return final
 
 
-def do_placement(tidx, tile, placement, game_state, player):
+def do_placement(
+    tidx: int, tile: np.ndarray, placement: tuple[int, int], game_state, player: int
+):
     (x, y) = placement
     with np.nditer(tile, flags=["multi_index"]) as it:
         for v in it:
@@ -156,35 +168,295 @@ def print_game_state(game_state):
     print("")
     print(f"Player 1 tiles left: {p1tiles}")
     print(f"Player 2 tiles left: {p2tiles}")
+    print("")
 
 
-game_state = (
-    make_board(),
-    [i for i in range(21)],
-    [i for i in range(21)],
-)
+def reset_game():
+    board = make_board()
+    p1tiles = [i for i in range(21)]
+    p2tiles = [i for i in range(21)]
+    return [board, p1tiles, p2tiles]  # list so it's mutable in-place
 
-playing = True
-player = 1
-while playing:
+# =======================
+# RL: encoding & policy
+# =======================
+
+
+def encode_board(board: np.ndarray, player: int) -> torch.Tensor:
+    """
+    Channels:
+        0: current player's stones
+        1: opponent's stones
+        2: starting squares (-1)
+    """
+    me = (board == player).astype(np.float32)
+    opp = ((board > 0) & (board != player)).astype(np.float32)
+    start = (board == -1).astype(np.float32)
+    state = np.stack([me, opp, start], axis=0)  # (3, 14, 14)
+    return torch.from_numpy(state)
+
+
+def encode_move(
+    tidx: int, tile: np.ndarray, placement: tuple[int, int]
+) -> torch.Tensor:
+    x, y = placement
+    area = int(tile.sum())
+    return torch.tensor([tidx, x, y, area], dtype=torch.float32)
+
+
+class PolicyNet(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(3, 32, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(32, 64, kernel_size=3, padding=1),
+            nn.ReLU(),
+        )
+        conv_out_dim = 64 * BOARD_SIZE * BOARD_SIZE  # 64 * 14 * 14
+
+        self.fc = nn.Sequential(
+            nn.Linear(conv_out_dim + 4, 256),
+            nn.ReLU(),
+            nn.Linear(256, 1),  # scalar logit
+        )
+
+    def forward(
+        self, board_tensor: torch.Tensor, move_features: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        board_tensor: (3, 14, 14)
+        move_features: (N, 4)
+        returns: logits (N,)
+        """
+        x = self.conv(board_tensor.unsqueeze(0))  # (1, 64, 14, 14)
+        x = x.view(1, -1)  # (1, conv_out_dim)
+        x = x.repeat(move_features.size(0), 1)  # (N, conv_out_dim)
+
+        combined = torch.cat([x, move_features], dim=1)  # (N, conv_out_dim + 4)
+        logits = self.fc(combined).squeeze(-1)  # (N,)
+        return logits
+
+
+# =======================
+# RL: move generation & action selection
+# =======================
+
+
+def get_all_moves(game_state, player: int):
+    board, p1tiles, p2tiles = game_state
+    available_tiles = p1tiles if player == 1 else p2tiles
+
     moves = []
-    for tidx, tile in get_permutations(game_state[player]):
-        for placement in can_place(game_state[0], tile, player):
+    for tidx, tile in get_permutations(available_tiles):
+        for placement in can_place(board, tile, player):
             moves.append((tidx, tile, placement))
 
-    print_game_state(game_state)
-    print(f"player {player} has {len(moves)} options")
+    return moves
+
+
+def select_action(policy: PolicyNet, game_state, player: int, device="cpu"):
+    board, _, _ = game_state
+    moves = get_all_moves(game_state, player)
 
     if len(moves) == 0:
-        print(f"No moves left, player {player} lost")
-        playing = False
-        continue
+        return None, None  # no legal moves
 
-    (tidx, tile, placement) = random.choice(moves)
-    do_placement(tidx, tile, placement, game_state, player)
+    board_tensor = encode_board(board, player).to(device)
 
-    if player == 1:
-        player = 2
-    elif player == 2:
-        player = 1
+    move_feats = torch.stack(
+        [encode_move(tidx, tile, placement) for (tidx, tile, placement) in moves], dim=0
+    ).to(device)
+
+    logits = policy(board_tensor, move_feats)  # (N,)
+    probs = F.softmax(logits, dim=0)
+    dist = torch.distributions.Categorical(probs)
+    idx = dist.sample()
+    log_prob = dist.log_prob(idx)
+
+    chosen_move = moves[idx.item()]
+    return chosen_move, log_prob
+
+
+# =======================
+# RL: self-play episode
+# =======================
+
+
+def play_episode(policy1: PolicyNet, policy2: PolicyNet, optim1, optim2, device="cpu"):
+    policy1.train()
+    policy2.train()
+    game_state = reset_game()
+    player = 1
+
+    log_probs1 = []
+    log_probs2 = []
+
+    while True:
+        if player == 1:
+            move, log_prob = select_action(policy1, game_state, player, device)
+        else:
+            move, log_prob = select_action(policy2, game_state, player, device)
+
+        # No move → this player loses
+        if move is None:
+            loser = player
+            winner = 2 if player == 1 else 1
+            break
+
+        tidx, tile, placement = move
+
+        if player == 1:
+            log_probs1.append(log_prob)
+        else:
+            log_probs2.append(log_prob)
+
+        do_placement(tidx, tile, placement, game_state, player)
+
+        player = 2 if player == 1 else 1
+
+    print_game_state(game_state)
+    print(f"Player {winner} is the winner")
+    # Rewards: +1 for win, -1 for loss (from each player's perspective)
+    r1 = 1.0 if winner == 1 else -1.0
+    r2 = -r1
+
+    if log_probs1:
+        loss1 = -torch.stack(log_probs1).sum() * r1
+        optim1.zero_grad()
+        loss1.backward()
+        optim1.step()
+
+    if log_probs2:
+        loss2 = -torch.stack(log_probs2).sum() * r2
+        optim2.zero_grad()
+        loss2.backward()
+        optim2.step()
+
+    return r1  # from Player 1's perspective
+
+
+# =======================
+# Evaluation: watch them play
+# =======================
+
+
+def play_game(policy1: PolicyNet, policy2: PolicyNet, device="cpu"):
+    policy1.eval()
+    policy2.eval()
+    game_state = reset_game()
+    player = 1
+    while True:
+        print_game_state(game_state)
+
+        if player == 1:
+            move, _ = select_action(policy1, game_state, player, device)
+        else:
+            move, _ = select_action(policy2, game_state, player, device)
+
+        if move is None:
+            print(f"No moves left, player {player} lost")
+            break
+
+        tidx, tile, placement = move
+        do_placement(tidx, tile, placement, game_state, player)
+
+        player = 2 if player == 1 else 1
+
+def load_policy(path, device="cpu"):
+    policy = PolicyNet().to(device)
+    policy.load_state_dict(torch.load(path, map_location=device))
+    policy.eval()
+    return policy
+
+def human_vs_ai(ai_policy: PolicyNet, device="cpu"):
+    ai_policy.eval()
+    game_state = reset_game()
+    player = 1  # AI goes first
+
+    while True:
+        print_game_state(game_state)
+
+        # Who moves?
+        if player == 1:
+            print("AI thinking...")
+            move, _ = select_action(ai_policy, game_state, player, device)
+            if move is None:
+                print("AI has no moves — AI loses!")
+                break
+            tidx, tile, placement = move
+            print(f"AI plays tile {tidx} at {placement}\n")
+        else:
+            # human turn
+            moves = get_all_moves(game_state, player)
+            if not moves:
+                print("You have no moves — you lose!")
+                break
+
+            print("Your legal moves:")
+            for i, (tidx, tile, placement) in enumerate(moves):
+                print(f"{i}: tile {tidx} at {placement}")
+
+            choice = int(input("Choose move number: "))
+            tidx, tile, placement = moves[choice]
+
+        # Apply move
+        do_placement(tidx, tile, placement, game_state, player)
+
+        # Switch players
+        player = 2 if player == 1 else 1
+
+
+# =======================
+# Main training loop
+# =======================
+
+
+def main():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"Using device: {device}")
+
+    policy1 = PolicyNet().to(device)
+    policy2 = PolicyNet().to(device)
+
+    optim1 = optim.Adam(policy1.parameters(), lr=1e-3)
+    optim2 = optim.Adam(policy2.parameters(), lr=1e-3)
+
+    best_avg_reward = -999
+    reward_history = []
+
+    num_episodes = 2000
+    for episode in range(1, num_episodes + 1):
+        reward = play_episode(policy1, policy2, optim1, optim2, device=device)
+        reward_history.append(reward)
+
+        # compute moving average every 50 episodes
+        if len(reward_history) >= 50:
+            avg = sum(reward_history[-50:]) / 50
+
+            # If policy1 improved, save it
+            if avg > best_avg_reward:
+                best_avg_reward = avg
+                torch.save(policy1.state_dict(), "best_policy1.pth")
+                print(f"Saved best policy1 at episode {episode} (avg reward={avg:.3f})")
+
+        if episode % 100 == 0:
+            print(f"Episode {episode}, last reward={reward}")
+
+    print("Training complete.")
+    print("1 = Watch AI vs AI")
+    print("2 = Play against AI")
+    print("3 = Quit")
+
+    choice = input("Select: ")
+
+    if choice == "1":
+        play_game(policy1, policy2, device)
+    elif choice == "2":
+        best_ai = load_policy("best_policy1.pth", device)
+        human_vs_ai(best_ai, device)
+
+
+if __name__ == "__main__":
+    main()
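
A minimal smoke test of the new encoding/policy plumbing, for reviewing without running the full 2000-episode training loop. It is not part of the patch; it assumes the patched file is importable as `blokus` (with the existing `tiles` definitions) and that PyTorch is installed. All functions used below are the ones added or kept by this diff.

    import torch
    import blokus

    state = blokus.reset_game()
    moves = blokus.get_all_moves(state, 1)   # legal opening moves for player 1
    policy = blokus.PolicyNet()
    board = blokus.encode_board(state[0], 1)  # (3, 14, 14) feature planes
    feats = torch.stack([blokus.encode_move(t, tile, pos) for (t, tile, pos) in moves])
    logits = policy(board, feats)             # one logit per legal move
    print(len(moves), tuple(logits.shape))    # e.g. N and (N,)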