#!/usr/bin/env python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

BOARD_SIZE = 14

# =======================
# Game setup and rules
# =======================


def make_board():
    board = np.zeros((BOARD_SIZE, BOARD_SIZE), dtype=int)
    # Starting squares: each player's first piece must cover one of these.
    board[4, 4] = -1
    board[9, 9] = -1
    return board


tiles = [
    np.array([[1]]),
    np.array([[1], [1]]),
    np.array([[1], [1], [1]]),
    np.array([[1, 0], [1, 1]]),
    np.array([[1], [1], [1], [1]]),
    np.array([[1, 0], [1, 0], [1, 1]]),
    np.array([[1, 0], [1, 1], [1, 0]]),
    np.array([[1, 1], [1, 1]]),
    np.array([[1, 1, 0], [0, 1, 1]]),
    np.array([[1], [1], [1], [1], [1]]),
    np.array([[1, 0], [1, 0], [1, 0], [1, 1]]),
    np.array([[1, 0], [1, 0], [1, 1], [0, 1]]),
    np.array([[1, 0], [1, 1], [1, 1]]),
    np.array([[1, 1], [1, 0], [1, 1]]),
    np.array([[1, 0], [1, 1], [1, 0], [1, 0]]),
    np.array([[0, 1, 0], [0, 1, 0], [1, 1, 1]]),
    np.array([[1, 0, 0], [1, 0, 0], [1, 1, 1]]),
    np.array([[1, 1, 0], [0, 1, 1], [0, 0, 1]]),
    np.array([[1, 0, 0], [1, 1, 1], [0, 0, 1]]),
    np.array([[1, 0, 0], [1, 1, 1], [0, 1, 0]]),
    np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]),
]


def get_permutations(which_tiles: list[int]):
    """
    For each tile index in which_tiles, generate all unique rotations/flips.
    Returns a list of (tile_index, oriented_tile).
    """
    permutations = []
    for tidx in which_tiles:
        tile = tiles[tidx]
        rots = [np.rot90(tile, k) for k in range(4)]
        flips = [np.flip(r, axis=1) for r in rots]  # horizontal flips
        all_orients = rots + flips  # 8 orientations
        seen = set()
        for t in all_orients:
            key = (t.shape, t.tobytes())
            if key not in seen:
                seen.add(key)
                permutations.append((tidx, t))
    return permutations


def can_place(board: np.ndarray, tile: np.ndarray, player: int):
    """
    Return all top-left anchors (x, y) where `tile` may legally be placed:
    every filled cell in bounds and on an empty square, no orthogonal contact
    with the player's own color, and either covering a starting square (first
    move) or touching the player's own color diagonally (later moves).
    """
    # A -1 anywhere on the board means a starting square is still uncovered,
    # so the first-move rule applies.
    has_minus_one = bool((board == -1).any())
    placements = []
    for x in range(BOARD_SIZE):
        for y in range(BOARD_SIZE):
            with np.nditer(tile, flags=["multi_index"]) as it:
                for v in it:
                    if v == 1:
                        (i, j) = it.multi_index
                        if x + i >= BOARD_SIZE:
                            break
                        if y + j >= BOARD_SIZE:
                            break
                        if board[x + i, y + j] > 0:
                            break
                        # No orthogonal contact with the player's own color.
                        if x + i - 1 >= 0 and board[x + i - 1, y + j] == player:
                            break
                        if y + j - 1 >= 0 and board[x + i, y + j - 1] == player:
                            break
                        if x + i + 1 < BOARD_SIZE and board[x + i + 1, y + j] == player:
                            break
                        if y + j + 1 < BOARD_SIZE and board[x + i, y + j + 1] == player:
                            break
                else:
                    # Loop finished without a break: every cell fits.
                    placements.append((x, y))

    final = []
    if has_minus_one:
        # First move: the tile must cover a starting square.
        for x, y in placements:
            with np.nditer(tile, flags=["multi_index"]) as it:
                for v in it:
                    (i, j) = it.multi_index
                    if v == 1 and board[x + i, y + j] == -1:
                        final.append((x, y))
                        break
    else:
        # Later moves: some filled cell must touch the player's own color
        # diagonally.
        for x, y in placements:
            with np.nditer(tile, flags=["multi_index"]) as it:
                for v in it:
                    (i, j) = it.multi_index
                    if v != 1:
                        # Only filled tile cells count for diagonal contact;
                        # skipping the zero cells fixes a false positive where
                        # an empty corner of the bounding box touched a stone.
                        continue
                    if (
                        x + i + 1 < BOARD_SIZE
                        and y + j + 1 < BOARD_SIZE
                        and board[x + i + 1, y + j + 1] == player
                    ):
                        final.append((x, y))
                        break
                    if (
                        x + i + 1 < BOARD_SIZE
                        and y + j - 1 >= 0
                        and board[x + i + 1, y + j - 1] == player
                    ):
                        final.append((x, y))
                        break
                    if (
                        x + i - 1 >= 0
                        and y + j + 1 < BOARD_SIZE
                        and board[x + i - 1, y + j + 1] == player
                    ):
                        final.append((x, y))
                        break
                    if (
                        x + i - 1 >= 0
                        and y + j - 1 >= 0
                        and board[x + i - 1, y + j - 1] == player
                    ):
                        final.append((x, y))
                        break
    return final


def do_placement(
    tidx: int, tile: np.ndarray, placement: tuple[int, int], game_state, player: int
):
    """Write the tile onto the board and remove it from the player's hand."""
    (x, y) = placement
    with np.nditer(tile, flags=["multi_index"]) as it:
        for v in it:
            (i, j) = it.multi_index
            if v == 1:
                game_state[0][x + i, y + j] = player
    game_state[player].remove(tidx)


def print_game_state(game_state):
    (board, p1tiles, p2tiles) = game_state
    for row in board:
        print(
            "".join(
                [
                    " " if x == 0 else "X" if x == 1 else "O" if x == 2 else "S"
                    for x in row
                ]
            )
        )
    print("")
    print(f"Player 1 tiles left: {p1tiles}")
    print(f"Player 2 tiles left: {p2tiles}")
    print("")
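
# A minimal sanity check of the rules code above (illustrative only;
# `_demo_rules` is an assumed helper name and is not called by the training
# pipeline). On an empty board, the 1x1 piece should be placeable only on
# the two starting squares, since the first move must cover a -1.
def _demo_rules():
    board = make_board()
    spots = can_place(board, tiles[0], player=1)  # tiles[0] is the 1x1 piece
    assert sorted(spots) == [(4, 4), (9, 9)]
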

def reset_game():
    board = make_board()
    p1tiles = [i for i in range(21)]
    p2tiles = [i for i in range(21)]
    return [board, p1tiles, p2tiles]  # list so it's mutable in-place


# =======================
# RL: encoding & policy
# =======================


def encode_board(board: np.ndarray, player: int) -> torch.Tensor:
    """
    Channels:
      0: current player's stones
      1: opponent's stones
      2: starting squares (-1)
    """
    me = (board == player).astype(np.float32)
    opp = ((board > 0) & (board != player)).astype(np.float32)
    start = (board == -1).astype(np.float32)
    state = np.stack([me, opp, start], axis=0)  # (3, 14, 14)
    return torch.from_numpy(state)


def encode_move(
    tidx: int, tile: np.ndarray, placement: tuple[int, int]
) -> torch.Tensor:
    x, y = placement
    area = int(tile.sum())
    return torch.tensor([tidx, x, y, area], dtype=torch.float32)


class PolicyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        conv_out_dim = 64 * BOARD_SIZE * BOARD_SIZE  # 64 * 14 * 14
        self.fc = nn.Sequential(
            nn.Linear(conv_out_dim + 4, 256),
            nn.ReLU(),
            nn.Linear(256, 1),  # scalar logit per candidate move
        )

    def forward(
        self, board_tensor: torch.Tensor, move_features: torch.Tensor
    ) -> torch.Tensor:
        """
        board_tensor: (3, 14, 14)
        move_features: (N, 4)
        returns: logits (N,)
        """
        x = self.conv(board_tensor.unsqueeze(0))  # (1, 64, 14, 14)
        x = x.view(1, -1)  # (1, conv_out_dim)
        x = x.repeat(move_features.size(0), 1)  # (N, conv_out_dim)
        combined = torch.cat([x, move_features], dim=1)  # (N, conv_out_dim + 4)
        logits = self.fc(combined).squeeze(-1)  # (N,)
        return logits


# =======================
# RL: move generation & action selection
# =======================


def get_all_moves(game_state, player: int):
    board, p1tiles, p2tiles = game_state
    available_tiles = p1tiles if player == 1 else p2tiles
    moves = []
    for tidx, tile in get_permutations(available_tiles):
        for placement in can_place(board, tile, player):
            moves.append((tidx, tile, placement))
    return moves


def select_action(policy: PolicyNet, game_state, player: int, device="cpu"):
    board, _, _ = game_state
    moves = get_all_moves(game_state, player)
    if len(moves) == 0:
        return None, None  # no legal moves
    board_tensor = encode_board(board, player).to(device)
    move_feats = torch.stack(
        [encode_move(tidx, tile, placement) for (tidx, tile, placement) in moves],
        dim=0,
    ).to(device)
    logits = policy(board_tensor, move_feats)  # (N,)
    probs = F.softmax(logits, dim=0)
    dist = torch.distributions.Categorical(probs)
    idx = dist.sample()
    log_prob = dist.log_prob(idx)
    chosen_move = moves[idx.item()]
    return chosen_move, log_prob


# =======================
# RL: self-play episode
# =======================


def play_episode(policy1: PolicyNet, policy2: PolicyNet, optim1, optim2, device="cpu"):
    policy1.train()
    policy2.train()
    game_state = reset_game()
    player = 1
    log_probs1 = []
    log_probs2 = []

    while True:
        if player == 1:
            move, log_prob = select_action(policy1, game_state, player, device)
        else:
            move, log_prob = select_action(policy2, game_state, player, device)

        # No move → this player loses
        if move is None:
            winner = 2 if player == 1 else 1
            break

        tidx, tile, placement = move
        if player == 1:
            log_probs1.append(log_prob)
        else:
            log_probs2.append(log_prob)

        do_placement(tidx, tile, placement, game_state, player)
        player = 2 if player == 1 else 1

    print_game_state(game_state)
    print(f"Player {winner} is the winner")

    # Rewards: +1 for win, -1 for loss (from each player's perspective)
    r1 = 1.0 if winner == 1 else -1.0
    r2 = -r1

    # REINFORCE update: maximize the reward-weighted log-likelihood of the
    # actions each policy actually took. The two graphs are disjoint, so each
    # backward pass is independent.
    if log_probs1:
        loss1 = -torch.stack(log_probs1).sum() * r1
        optim1.zero_grad()
        loss1.backward()
        optim1.step()
    if log_probs2:
        loss2 = -torch.stack(log_probs2).sum() * r2
        optim2.zero_grad()
        loss2.backward()
        optim2.step()

    return r1  # from Player 1's perspective
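
# Sketch: a greedy (argmax) variant of select_action for deterministic
# evaluation. `select_action_greedy` is an assumed name and is not wired into
# play_game() or main(); it reuses the PolicyNet interface defined above.
def select_action_greedy(policy: PolicyNet, game_state, player: int, device="cpu"):
    board, _, _ = game_state
    moves = get_all_moves(game_state, player)
    if not moves:
        return None
    board_tensor = encode_board(board, player).to(device)
    move_feats = torch.stack(
        [encode_move(tidx, tile, placement) for (tidx, tile, placement) in moves],
        dim=0,
    ).to(device)
    with torch.no_grad():  # no gradients needed at evaluation time
        logits = policy(board_tensor, move_feats)
    return moves[int(torch.argmax(logits).item())]
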

# =======================
# Evaluation: watch them play
# =======================


def play_game(policy1: PolicyNet, policy2: PolicyNet, device="cpu"):
    policy1.eval()
    policy2.eval()
    game_state = reset_game()
    player = 1
    while True:
        print_game_state(game_state)
        if player == 1:
            move, _ = select_action(policy1, game_state, player, device)
        else:
            move, _ = select_action(policy2, game_state, player, device)
        if move is None:
            print(f"No moves left, player {player} lost")
            break
        tidx, tile, placement = move
        do_placement(tidx, tile, placement, game_state, player)
        player = 2 if player == 1 else 1


def load_policy(path, device="cpu"):
    policy = PolicyNet().to(device)
    policy.load_state_dict(torch.load(path, map_location=device))
    policy.eval()
    return policy


def human_vs_ai(ai_policy: PolicyNet, device="cpu"):
    ai_policy.eval()
    game_state = reset_game()
    player = 1  # AI goes first

    while True:
        print_game_state(game_state)

        # Who moves?
        if player == 1:
            print("AI thinking...")
            move, _ = select_action(ai_policy, game_state, player, device)
            if move is None:
                print("AI has no moves — AI loses!")
                break
            tidx, tile, placement = move
            print(f"AI plays tile {tidx} at {placement}\n")
        else:
            # human turn
            moves = get_all_moves(game_state, player)
            if not moves:
                print("You have no moves — you lose!")
                break
            print("Your legal moves:")
            for i, (tidx, tile, placement) in enumerate(moves):
                print(f"{i}: tile {tidx} at {placement}")
            # Re-prompt until the input parses as a valid move index.
            while True:
                try:
                    choice = int(input("Choose move number: "))
                    if 0 <= choice < len(moves):
                        break
                except ValueError:
                    pass
                print(f"Please enter a number between 0 and {len(moves) - 1}.")
            tidx, tile, placement = moves[choice]

        # Apply move
        do_placement(tidx, tile, placement, game_state, player)

        # Switch players
        player = 2 if player == 1 else 1


# =======================
# Main training loop
# =======================


def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    policy1 = PolicyNet().to(device)
    policy2 = PolicyNet().to(device)
    optim1 = optim.Adam(policy1.parameters(), lr=1e-3)
    optim2 = optim.Adam(policy2.parameters(), lr=1e-3)

    best_avg_reward = float("-inf")
    reward_history = []
    num_episodes = 2000

    for episode in range(1, num_episodes + 1):
        reward = play_episode(policy1, policy2, optim1, optim2, device=device)
        reward_history.append(reward)

        # Moving average of Player 1's reward over the last 50 episodes
        if len(reward_history) >= 50:
            avg = sum(reward_history[-50:]) / 50
            # If policy1 improved, save it
            if avg > best_avg_reward:
                best_avg_reward = avg
                torch.save(policy1.state_dict(), "best_policy1.pth")
                print(f"Saved best policy1 at episode {episode} (avg reward={avg:.3f})")

        if episode % 100 == 0:
            print(f"Episode {episode}, last reward={reward}")

    print("Training complete.")
    print("1 = Watch AI vs AI")
    print("2 = Play against AI")
    print("3 = Quit")
    choice = input("Select: ")
    if choice == "1":
        play_game(policy1, policy2, device)
    elif choice == "2":
        best_ai = load_policy("best_policy1.pth", device)
        human_vs_ai(best_ai, device)


if __name__ == "__main__":
    main()
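
# Usage (the filename blokus_rl.py is illustrative, not prescribed):
#
#   python blokus_rl.py
#
# Training runs 2000 self-play REINFORCE episodes, checkpointing the best
# policy1 (by 50-episode average reward) to best_policy1.pth, then offers
# AI-vs-AI playback or a human-vs-AI game against the saved checkpoint.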