diff --git a/blokus.py b/blokus.py index ba7d533..569184b 100755 --- a/blokus.py +++ b/blokus.py @@ -2,6 +2,9 @@ import random import sys import os +import math +from collections import defaultdict + import game import torch @@ -10,39 +13,49 @@ import torch.nn.functional as F from tqdm.auto import trange import matplotlib.pyplot as plt -############### -# Utilities # -############### +#################### +# Global constants # +#################### BOARD_SIZE = 14 tiles = game.game_tiles() +TILE_COUNT = len(tiles) + +# We need a max orientation count to encode orientation if desired +# If you want to use orientations later, this is handy. +ORIENT_MAX = max(len(t.permutations()) for t in tiles) + +################### +# Utility helpers # +################### def print_game_state(game_state: tuple[game.Board, list[int], list[int]]): (board, p1tiles, p2tiles) = game_state barr = [] - for i in range(BOARD_SIZE): - barr.append([]) - for j in range(BOARD_SIZE): - barr[i].append(board[(j, i)]) + for y in range(BOARD_SIZE): + row = [] + for x in range(BOARD_SIZE): + row.append(board[(x, y)]) + barr.append(row) - print(f" {'-' * BOARD_SIZE} ") + print(f" {'--' * BOARD_SIZE} ") for row in barr: print( f"|{ ''.join( [ - ' ' if x == 0 else 'X' if x == 1 else 'O' if x == 2 else 'S' + ' ' if x == 0 else '\033[93m██\033[00m' if x == 1 else '\033[94m██\033[00m' if x == 2 else '██' for x in row ] ) }|" ) - print(f" {'-' * BOARD_SIZE} ") + print(f" {'--' * BOARD_SIZE} ") - print(f"Player 1 tiles left: {p1tiles}") - print(f"Player 2 tiles left: {p2tiles}") + print(f"\033[93mPlayer 1\033[00m tiles left: {p1tiles}") + print(f"\033[94mPlayer 2\033[00m tiles left: {p2tiles}") def plot_losses(loss_history, out_path="loss_curve.png"): @@ -52,9 +65,9 @@ def plot_losses(loss_history, out_path="loss_curve.png"): plt.figure() plt.plot(range(1, len(loss_history) + 1), loss_history) - plt.xlabel("Episode") + plt.xlabel("Training iteration") plt.ylabel("Loss") - plt.title("Training loss over episodes") + plt.title("AlphaZero training loss") plt.tight_layout() plt.savefig(out_path) plt.close() @@ -69,8 +82,8 @@ def plot_losses(loss_history, out_path="loss_curve.png"): def initial_game_state(): return ( game.Board(), - [i for i in range(21)], - [i for i in range(21)], + [i for i in range(21)], # tiles for player 1 + [i for i in range(21)], # tiles for player 2 ) @@ -80,7 +93,12 @@ def initial_game_state(): def encode_board(board: game.Board) -> torch.Tensor: - # board[(x, y)] returns 0,1,2,... according to your print function + """ + Encode board as (3, H, W) float tensor: + channel 0: player 1 stones + channel 1: player 2 stones + channel 2: value 3 ("S" or special) + """ arr = torch.zeros((3, BOARD_SIZE, BOARD_SIZE), dtype=torch.float32) for y in range(BOARD_SIZE): for x in range(BOARD_SIZE): @@ -89,13 +107,18 @@ def encode_board(board: game.Board) -> torch.Tensor: arr[0, y, x] = 1.0 elif v == 2: arr[1, y, x] = 1.0 - elif v == 3: # if "S" or something else + elif v == 3: arr[2, y, x] = 1.0 return arr def encode_tiles(p1tiles, p2tiles) -> torch.Tensor: - # 21 tiles total, so 42-dim vector + """ + Encode which tiles each player still has. + 21 tiles per player -> 42-dim vector: + indices 0..20: tile i available for P1? + indices 21..41: tile i available for P2? 
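+      Example (illustrative): if P1 still holds tiles [0, 2] and P2 still
+      holds tile 5, then v[0] = v[2] = 1.0 and v[21 + 5] = 1.0; every
+      other entry stays 0.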
+ """ v = torch.zeros(42, dtype=torch.float32) for t in p1tiles: v[t] = 1.0 @@ -105,74 +128,153 @@ def encode_tiles(p1tiles, p2tiles) -> torch.Tensor: return v +def encode_state(game_state, player: int) -> torch.Tensor: + """ + Flattened state encoding for the NN: + - board one-hot: 3 * 14 * 14 = 588 + - tiles: 42 + - current player bit: 1 + Total dim = 631. + """ + board, p1tiles, p2tiles = game_state + board_tensor = encode_board(board).flatten() # (588,) + tiles_tensor = encode_tiles(p1tiles, p2tiles) # (42,) + player_bit = torch.tensor([1.0 if player == 1 else 0.0], dtype=torch.float32) + return torch.cat([board_tensor, tiles_tensor, player_bit], dim=0) # (631,) + + def encode_move(tile_idx: int, placement: tuple[int, int]) -> torch.Tensor: - (x, y) = placement + """ + Encode a move (ignoring orientation for now): + - tile index one-hot (21) + - normalized x, y + Total dim = 23. + """ + x, y = placement tile_vec = torch.zeros(21, dtype=torch.float32) tile_vec[tile_idx] = 1.0 pos_vec = torch.tensor( - [x / (BOARD_SIZE - 1), y / (BOARD_SIZE - 1)], dtype=torch.float32 + [x / (BOARD_SIZE - 1), y / (BOARD_SIZE - 1)], + dtype=torch.float32, ) - return torch.cat([tile_vec, pos_vec], dim=0) # 23-dim + return torch.cat([tile_vec, pos_vec], dim=0) # (23,) -def encode_state_and_move( - game_state, player: int, tile_idx: int, placement: tuple[int, int], perm: game.Tile -): - board, p1tiles, p2tiles = game_state - - # Encode board BEFORE the move - board_before = encode_board(board).flatten() - - # Encode board AFTER the move using sim_place - gp = game.Player.P1 if player == 1 else game.Player.P2 - board_after_sim = board.sim_place( - perm, placement, gp - ) # <--- uses your new function - board_after = encode_board(board_after_sim).flatten() - - tiles_tensor = encode_tiles(p1tiles, p2tiles) - move_tensor = encode_move(tile_idx, placement) # still tile+position encoding - player_tensor = torch.tensor([1.0 if player == 1 else 0.0], dtype=torch.float32) - - return torch.cat( - [ - board_before, # 588 - board_after, # 588 - tiles_tensor, # 42 - move_tensor, # 23 - player_tensor, # 1 - ], - dim=0, - ) +STATE_DIM = 588 + 42 + 1 # 631 +MOVE_DIM = 23 -########### -# Model # -########### - -FEATURE_SIZE = 1242 # from above +############################## +# AlphaZero-style neural net # +############################## -class MoveScorer(nn.Module): - def __init__(self): +class AlphaZeroNet(nn.Module): + """ + AlphaZero-style network for this Blokus-like game. + + Given a state vector (flattened board + tiles + player), + it produces: + - value v in [-1,1] (estimate of who will win) + - a policy over moves *relative to a given move set*, by + combining a state embedding with move features. 
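+
+    Illustrative use with the helpers defined in this module:
+        net = AlphaZeroNet()
+        state_emb, v = net.forward_state(encode_state(game_state, player))
+        priors, logits = net.forward_policy(state_emb, move_feats)
+    where move_feats stacks encode_move(...) rows into an (N, MOVE_DIM) tensor.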
+ """ + + def __init__(self, state_emb_dim=256, move_emb_dim=64): super().__init__() - self.fc1 = nn.Linear(FEATURE_SIZE, 256) - self.fc2 = nn.Linear(256, 128) - self.fc_out = nn.Linear(128, 1) # scalar score - def forward(self, x): - # x: (batch_size, FEATURE_SIZE) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - return self.fc_out(x) # (batch_size, 1) + # State encoder: simple MLP on the 1D state vector + self.state_mlp = nn.Sequential( + nn.Linear(STATE_DIM, 256), + nn.ReLU(), + nn.Linear(256, state_emb_dim), + nn.ReLU(), + ) + + # Value head: from state embedding -> scalar, tanh in [-1,1] + self.value_head = nn.Sequential( + nn.Linear(state_emb_dim, 64), + nn.ReLU(), + nn.Linear(64, 1), + nn.Tanh(), + ) + + # Move encoder: encode move features + self.move_mlp = nn.Sequential( + nn.Linear(MOVE_DIM, 64), + nn.ReLU(), + nn.Linear(64, move_emb_dim), + nn.ReLU(), + ) + + # Policy head: combine state_emb + move_emb -> logit per move + self.policy_head = nn.Sequential( + nn.Linear(state_emb_dim + move_emb_dim, 128), + nn.ReLU(), + nn.Linear(128, 1), # scalar logit per move + ) + + def forward_state(self, state_vec: torch.Tensor): + """ + state_vec: (STATE_DIM,) + returns: state_emb (state_emb_dim,), value scalar + """ + x = self.state_mlp(state_vec) + v = self.value_head(x).squeeze(-1) + return x, v + + def forward_policy(self, state_emb: torch.Tensor, move_feats: torch.Tensor): + """ + state_emb: (state_emb_dim,) + move_feats: (N, MOVE_DIM) + returns: policy_probs: (N,), logits: (N,) + """ + move_emb = self.move_mlp(move_feats) # (N, move_emb_dim) + state_expanded = state_emb.unsqueeze(0).expand(move_emb.size(0), -1) + h = torch.cat( + [state_expanded, move_emb], dim=1 + ) # (N, state_emb_dim+move_emb_dim) + logits = self.policy_head(h).squeeze(-1) # (N,) + probs = F.softmax(logits, dim=0) + return probs, logits + + +def net_predict(net: AlphaZeroNet, game_state, player: int, moves): + """ + Evaluate net on a position from the perspective of 'player'. 
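+    'moves' is the list of (tile_idx, perm, placement) tuples produced by
+    get_legal_moves(); each move is re-encoded with encode_move(), which for
+    now ignores the orientation (perm).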
+ + Returns: + priors: tensor (N,) over given moves (N=len(moves)), or None if no moves + value: float in [-1,1] from player's POV + """ + if not moves: + return None, 0.0 + + state_vec = encode_state(game_state, player) + move_feats = [] + for tidx, perm, placement in moves: + mf = encode_move(tidx, placement) + move_feats.append(mf) + + state_vec = state_vec + move_feats = torch.stack(move_feats, dim=0) + + with torch.no_grad(): + state_emb, v = net.forward_state(state_vec) + priors, logits = net.forward_policy(state_emb, move_feats) + + return priors, float(v.item()) ################## -# Move generation +# Move generation # ################## def get_legal_moves(game_state, player: int): + """ + Returns list of (tile_idx, perm, placement) + """ board, p1tiles, p2tiles = game_state gp = game.Player.P1 if player == 1 else game.Player.P2 @@ -184,259 +286,406 @@ def get_legal_moves(game_state, player: int): perms = tile.permutations() for perm in perms: plcs = board.tile_placements(perm, gp) - moves.extend((tile_idx, perm, plc) for plc in plcs) + for plc in plcs: + moves.append((tile_idx, perm, plc)) return moves -########### -# Agents # -########### +############### +# MCTS (PUCT) # +############### -class Agent: - def choose_move(self, game_state, player: int): - """Return (tile_idx, perm, placement) or None if no moves.""" - raise NotImplementedError +class MCTSNode: + def __init__(self, game_state, player: int, parent=None, move_from_parent=None): + self.game_state = game_state # (board, p1tiles, p2tiles) + self.player = player # player to move at this node (1 or 2) + self.parent = parent + self.move_from_parent = ( + move_from_parent # index of move from parent to this node + ) + self.children = {} # move_index -> child MCTSNode + self.moves = None # list of legal moves at this node + self.P = None # tensor of priors over moves (N,) + self.N = defaultdict(int) # visit count per move_index + self.W = defaultdict(float) # total value per move_index + self.Q = defaultdict(float) # mean value per move_index + + self.is_expanded = False + self.is_terminal = False + self.value = 0.0 + + def expand(self, net: AlphaZeroNet): + """ + Expand this node: generate legal moves, get priors & value from net. 
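+        If the player to move has no legal moves here, the node is marked
+        terminal instead; the search treats that player as the loser.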
+ """ + if self.is_expanded: + return + + moves = get_legal_moves(self.game_state, self.player) + self.moves = moves -class RandomAgent(Agent): - def choose_move(self, game_state, player: int): - moves = get_legal_moves(game_state, player) if not moves: - return None - return random.choice(moves) + # terminal: current player has no moves -> they lose + self.is_terminal = True + self.P = None + self.value = 0.0 + self.is_expanded = True + return + priors, v = net_predict(net, self.game_state, self.player, moves) + self.P = priors + self.value = v # value from this node's player POV + self.is_expanded = True -class HumanAgent(Agent): - def choose_move(self, game_state, player: int): - moves = get_legal_moves(game_state, player) - if not moves: - print(f"No moves left for player {player}") - return None + def select_child(self, c_puct=1.5): + """ + Select move index according to PUCT: + a* = argmax_a (Q(s,a) + c_puct * P(s,a) * sqrt(sum_b N(s,b)) / (1 + N(s,a))) + """ + total_N = sum(self.N[i] for i in range(len(self.moves))) + 1e-8 + best_score = -1e9 + best_i = None + for i in range(len(self.moves)): + Q = self.Q[i] + P = float(self.P[i]) + N_sa = self.N[i] + U = c_puct * P * math.sqrt(total_N) / (1 + N_sa) + score = Q + U + if score > best_score: + best_score = score + best_i = i + return best_i - print_game_state(game_state) - print(f"Player {player}, you have {len(moves)} possible moves.") + def simulate_child(self, move_index: int): + """ + Given a move index in self.moves, return or create the child node. + """ + if move_index in self.children: + return self.children[move_index] - # Show a *subset* or all moves - for i, (tidx, perm, plc) in enumerate(moves): - if i < 50: # don't spam too hard; tweak as needed - print(f"[{i}] tile {tidx} at {plc}") - else: - break - if len(moves) > 50: - print(f"... 
and {len(moves) - 50} more moves not listed") + tidx, perm, placement = self.moves[move_index] - while True: - try: - choice = int(input("Enter move index: ")) - if 0 <= choice < len(moves): - return moves[choice] - else: - print("Invalid index, try again.") - except ValueError: - print("Please enter an integer.") + board, p1tiles, p2tiles = self.game_state + gp = game.Player.P1 if self.player == 1 else game.Player.P2 + # Simulate placing the tile on a copy of the board + next_board = board.sim_place(perm, placement, gp) # does not modify original -class MLAgent(Agent): - def __init__( - self, model: MoveScorer, deterministic: bool = True, epsilon: float = 0.0 - ): - self.model = model - self.deterministic = deterministic - self.epsilon = epsilon - - def choose_move(self, game_state, player: int): - moves = get_legal_moves(game_state, player) - if not moves: - return None - - # Optional epsilon-greedy: use 0 for “serious play” - if self.epsilon > 0.0 and random.random() < self.epsilon: - return random.choice(moves) - - # Build feature batch - features = [] - for tidx, perm, placement in moves: - feat = encode_state_and_move(game_state, player, tidx, placement, perm) - features.append(feat) - X = torch.stack(features, dim=0) - - self.model.eval() - with torch.no_grad(): - scores = self.model(X).squeeze(-1) # (num_moves,) - - if self.deterministic: - best_idx = torch.argmax(scores).item() - return moves[best_idx] + next_p1tiles = list(p1tiles) + next_p2tiles = list(p2tiles) + if self.player == 1: + next_p1tiles.remove(tidx) + next_player = 2 else: - # Sample from softmax for more variety - probs = torch.softmax(scores, dim=0) - idx = torch.multinomial(probs, num_samples=1).item() - return moves[idx] + next_p2tiles.remove(tidx) + next_player = 1 + + next_state = (next_board, next_p1tiles, next_p2tiles) + child = MCTSNode( + next_state, next_player, parent=self, move_from_parent=move_index + ) + self.children[move_index] = child + return child -###################### -# Training utilities # -###################### - - -def select_move_and_logprob(model: MoveScorer, game_state, player: int): +def mcts_search( + root_state, + root_player: int, + net: AlphaZeroNet, + num_simulations: int = 50, + c_puct: float = 1.5, + temperature: float = 1.0, +): """ - For training: sample a move from softmax over scores - and return (move, log_prob). If no moves, returns (None, None). - """ - moves = get_legal_moves(game_state, player) - if not moves: - return None, None + Run MCTS from (root_state, root_player). - features = [] - for tidx, perm, placement in moves: - feat = encode_state_and_move(game_state, player, tidx, placement, perm) - features.append(feat) - X = torch.stack(features, dim=0) # (num_moves, FEATURE_SIZE) - - scores = model(X).squeeze(-1) # (num_moves,) - probs = F.softmax(scores, dim=0) - - dist = torch.distributions.Categorical(probs) - idx = dist.sample() - log_prob = dist.log_prob(idx) - - move = moves[idx.item()] - return move, log_prob - - -def play_self_play_game(model: MoveScorer, max_turns: int = 500, watch: bool = False): - """ - Self-play game with the same model as both players. Returns: - log_probs_p1, log_probs_p2, reward_p1, reward_p2 - where rewards are +1/-1 for win/loss. 
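+
+    Returns: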
+ pi: tensor of size (num_moves,), MCTS policy over root moves + chosen_move: (tile_idx, perm, placement) + root_moves: list of all legal moves at root + """ + root = MCTSNode(root_state, root_player) + root.expand(net) + + if root.is_terminal or root.P is None or not root.moves: + # No moves at root + return None, None, [] + + # Run simulations + for _ in range(num_simulations): + node = root + path = [node] + + # 1. Selection: descend tree until leaf + while node.is_expanded and not node.is_terminal: + move_i = node.select_child(c_puct) + node = node.simulate_child(move_i) + path.append(node) + + # 2. Expansion / evaluation of leaf + node.expand(net) + + # 3. Compute leaf value from root player's perspective + if node.is_terminal: + # terminal: node.player has no moves -> node.player loses + if root_player == node.player: + v_root = -1.0 + else: + v_root = +1.0 + else: + # non-terminal: node.value is from node.player's POV + v_node = node.value + v_root = v_node if node.player == root_player else -v_node + + # 4. Backup: update Q, N along the path + value = v_root + for i in range(len(path) - 1): + parent = path[i] + child = path[i + 1] + mv_i = child.move_from_parent + parent.N[mv_i] += 1 + parent.W[mv_i] += value + parent.Q[mv_i] = parent.W[mv_i] / parent.N[mv_i] + value = -value # flip sign at each ply + + # Visits at root + visits = torch.tensor( + [root.N[i] for i in range(len(root.moves))], + dtype=torch.float32, + ) + + if temperature == 0: + # deterministic: all mass on argmax visit + best = torch.argmax(visits).item() + pi = torch.zeros_like(visits) + pi[best] = 1.0 + else: + pi = visits ** (1.0 / temperature) + if pi.sum() > 0: + pi = pi / pi.sum() + else: + pi = torch.ones_like(pi) / pi.numel() + + # Choose move to actually play (sample from pi) + move_index = torch.multinomial(pi, num_samples=1).item() + chosen_move = root.moves[move_index] + + return pi, chosen_move, root.moves + + +######################## +# Self-play + training # +######################## + + +def self_play_game( + net: AlphaZeroNet, + num_simulations: int = 50, + temperature: float = 1.0, + watch: bool = False, +): + """ + Play one self-play game using MCTS + net. 
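+    The game ends when the side to move has no legal placements; the other
+    player is recorded as the winner.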
+ + Returns: + examples: list of (state_tensor, moves, pi, z) + - state_tensor: encoding of state at root of each move + - moves: list of legal moves at that root + - pi: tensor of MCTS policy over moves + - z: final outcome from that state's player POV (+1 or -1) """ game_state = initial_game_state() board, p1tiles, p2tiles = game_state - - log_probs = {1: [], 2: []} player = 1 - turns = 0 + + trajectory = [] # list of dicts: {"state": state_vec, "moves": moves, "pi": pi, "player": player} while True: - turns += 1 - if turns > max_turns: - # Safety: declare a draw - reward_p1 = 0.0 - reward_p2 = 0.0 - return log_probs[1], log_probs[2], reward_p1, reward_p2 + moves = get_legal_moves(game_state, player) + if not moves: + # current player has no moves -> they lose + winner = 2 if player == 1 else 1 + break - move, log_prob = select_move_and_logprob(model, game_state, player) + state_vec = encode_state(game_state, player) + pi, chosen_move, move_list = mcts_search( + game_state, + player, + net, + num_simulations=num_simulations, + c_puct=1.5, + temperature=temperature, + ) - if move is None: - # Current player cannot move -> they lose - if player == 1: - reward_p1 = -1.0 - reward_p2 = +1.0 - else: - reward_p1 = +1.0 - reward_p2 = -1.0 - return log_probs[1], log_probs[2], reward_p1, reward_p2 + if pi is None or chosen_move is None: + winner = 2 if player == 1 else 1 + break - tidx, tile, placement = move + trajectory.append( + { + "state": state_vec, + "moves": move_list, + "pi": pi, + "player": player, + } + ) + + # Apply chosen move to real game + tidx, tile, placement = chosen_move gp = game.Player.P1 if player == 1 else game.Player.P2 - # Apply move board.place(tile, placement, gp) if player == 1: p1tiles.remove(tidx) else: p2tiles.remove(tidx) - # Update game_state tuple game_state = (board, p1tiles, p2tiles) if watch: print_game_state(game_state) - # Store log_prob - log_probs[player].append(log_prob) - - # Switch player player = 2 if player == 1 else 1 + # Convert trajectory into training examples + examples = [] + for step in trajectory: + state_vec = step["state"] + moves = step["moves"] + pi = step["pi"] + p = step["player"] + z = 1.0 if p == winner else -1.0 + examples.append((state_vec, moves, pi, z)) -def train( - model: MoveScorer, - num_episodes: int = 1000, + return examples + + +def alpha_zero_train_step(net: AlphaZeroNet, optimizer, batch): + """ + One gradient update on a batch of (state_vec, moves, pi_target, z_target). 
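+
+    Per-example loss (sketch of what the code below computes):
+        (v_pred - z)^2  -  sum_a pi_target[a] * log(p_pred[a] + 1e-8)
+    averaged over the batch.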
+ """ + if not batch: + return 0.0 + + total_value_loss = 0.0 + total_policy_loss = 0.0 + + for state_vec, moves, pi_target, z in batch: + state_vec = state_vec + z_t = torch.tensor(z, dtype=torch.float32) + + # Encode moves + move_feats = [] + for tidx, perm, placement in moves: + move_feats.append(encode_move(tidx, placement)) + move_feats = torch.stack(move_feats, dim=0) + + state_emb, v_pred = net.forward_state(state_vec) + p_pred, logits = net.forward_policy(state_emb, move_feats) + + # Value loss + value_loss = (v_pred - z_t) ** 2 + + # Policy loss: cross-entropy between pi_target and p_pred + pi_target = pi_target.to(p_pred.dtype) + policy_loss = -torch.sum(pi_target * torch.log(p_pred + 1e-8)) + + total_value_loss = total_value_loss + value_loss + total_policy_loss = total_policy_loss + policy_loss + + loss = (total_value_loss + total_policy_loss) / len(batch) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + return float(loss.item()) + + +def train_alpha_zero( + net: AlphaZeroNet, + num_iterations: int = 50, + games_per_iter: int = 5, + num_simulations: int = 50, + batch_size: int = 32, lr: float = 1e-3, - save_path: str = "trained_agent.pt", - watch: bool = False, + save_path: str = "az_trained_agent.pt", + watch_selfplay: bool = False, ): - optimizer = torch.optim.Adam(model.parameters(), lr=lr) + optimizer = torch.optim.Adam(net.parameters(), lr=lr) - # We'll keep a history of losses for plotting loss_history = [] - - # Checkpoint path (partial training state) ckpt_path = save_path + ".ckpt" + start_iter = 1 - start_episode = 1 - - # Try to resume from checkpoint if it exists + # Try to resume from checkpoint if os.path.exists(ckpt_path): ckpt = torch.load(ckpt_path, map_location="cpu") - model.load_state_dict(ckpt["model_state"]) - optimizer.load_state_dict(ckpt["optimizer_state"]) - start_episode = ckpt["episode"] + 1 - loss_history = ckpt.get("loss_history", []) - print(f"Resuming training from episode {start_episode} (found checkpoint).") + try: + net.load_state_dict(ckpt["model_state"]) + optimizer.load_state_dict(ckpt["optimizer_state"]) + start_iter = ckpt["iteration"] + 1 + loss_history = ckpt.get("loss_history", []) + print(f"Resuming training from iteration {start_iter} (found checkpoint).") + if start_iter > num_iterations: + print("Checkpoint exceeds requested num_iterations; nothing to train.") + plot_losses(loss_history, out_path="loss_curve.png") + torch.save(net.state_dict(), save_path) + return + except Exception as e: + print("Checkpoint incompatible, starting fresh.") + print("Reason:", e) - # If we've already passed num_episodes, just plot and exit - if start_episode > num_episodes: - print( - "Checkpoint episode exceeds requested num_episodes; nothing to train." + pbar = trange( + start_iter, num_iterations + 1, desc="AZ Training", dynamic_ncols=True + ) + + for it in pbar: + replay_buffer = [] + + # 1. 
Self-play games to generate fresh data + for g in range(games_per_iter): + examples = self_play_game( + net, + num_simulations=num_simulations, + temperature=1.0, # can anneal later + watch=watch_selfplay, ) - plot_losses(loss_history, out_path="loss_curve.png") - torch.save(model.state_dict(), save_path) - return + replay_buffer.extend(examples) - pbar = trange(start_episode, num_episodes + 1, desc="Training", dynamic_ncols=True) + if not replay_buffer: + print("No data generated this iteration, something is wrong.") + continue - for episode in pbar: - log_probs_p1, log_probs_p2, r1, r2 = play_self_play_game(model, watch=watch) + random.shuffle(replay_buffer) - loss = torch.tensor(0.0) - if log_probs_p1: - loss = loss - r1 * torch.stack(log_probs_p1).sum() - if log_probs_p2: - loss = loss - r2 * torch.stack(log_probs_p2).sum() + # 2. Train on this batch of data + batch_losses = [] + for i in range(0, len(replay_buffer), batch_size): + batch = replay_buffer[i : i + batch_size] + loss = alpha_zero_train_step(net, optimizer, batch) + batch_losses.append(loss) - optimizer.zero_grad() - loss.backward() - optimizer.step() + mean_loss = sum(batch_losses) / len(batch_losses) + loss_history.append(mean_loss) + pbar.set_postfix(loss=mean_loss, data=len(replay_buffer)) - loss_value = float(loss.item()) - loss_history.append(loss_value) - - # Update progress bar with most recent stats - pbar.set_postfix( - loss=loss_value, - ) - - # Save checkpoint every N episodes (and at the very end) - if episode % 50 == 0 or episode == num_episodes: + # 3. Save checkpoint + if it % 5 == 0 or it == num_iterations: torch.save( { - "episode": episode, - "model_state": model.state_dict(), + "iteration": it, + "model_state": net.state_dict(), "optimizer_state": optimizer.state_dict(), "loss_history": loss_history, }, ckpt_path, ) + plot_losses(loss_history, out_path="loss_curve_ckpt.png") - # Final model save - torch.save(model.state_dict(), save_path) + torch.save(net.state_dict(), save_path) print(f"\nTraining finished. Model saved to {save_path}") - - # Save final loss plot plot_losses(loss_history, out_path="loss_curve.png") @@ -445,38 +694,79 @@ def train( ################### -def play_vs_ai(model: MoveScorer, human_is: int = 1): +def az_choose_move( + net: AlphaZeroNet, game_state, player: int, num_simulations: int = 100 +): """ - Let a human play against the trained model. + Use MCTS with the trained net to choose a move for actual play. + """ + moves = get_legal_moves(game_state, player) + if not moves: + return None + + # Temperature 0 for deterministic play + pi, chosen_move, root_moves = mcts_search( + game_state, + player, + net, + num_simulations=num_simulations, + c_puct=1.5, + temperature=0.0, + ) + if chosen_move is None: + return None + return chosen_move + + +def play_vs_ai(net: AlphaZeroNet, human_is: int = 1, num_simulations: int = 100): + """ + Let a human play against the AlphaZero-style agent. 
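+    The AI move comes from az_choose_move(), i.e. MCTS with the trained net
+    at temperature 0 (deterministic argmax over visit counts).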
human_is: 1 or 2 """ game_state = initial_game_state() board, p1tiles, p2tiles = game_state - human = HumanAgent() - ai = MLAgent(model, deterministic=True, epsilon=0.0) - - agents = { - human_is: human, - 1 if human_is == 2 else 2: ai, - } - player = 1 while True: - agent = agents[player] - move = agent.choose_move(game_state, player) + print_game_state(game_state) - if move is None: - print(f"No moves left, player {player} lost") - if player == human_is: + if player == human_is: + # Human move + moves = get_legal_moves(game_state, player) + if not moves: + print(f"No moves left for player {player}") print("You lost 😢") - else: - print("You won! 🎉") - break + break + + print(f"Player {player}, you have {len(moves)} possible moves.") + for i, (tidx, perm, plc) in enumerate(moves[:50]): + print(f"[{i}] tile {tidx} at {plc}") + if len(moves) > 50: + print(f"... and {len(moves) - 50} more moves not listed") + + while True: + try: + choice = int(input("Enter move index: ")) + if 0 <= choice < len(moves): + move = moves[choice] + break + else: + print("Invalid index, try again.") + except ValueError: + print("Please enter an integer.") + else: + # AI move + print(f"AI (player {player}) is thinking...") + move = az_choose_move( + net, game_state, player, num_simulations=num_simulations + ) + if move is None: + print(f"No moves left for AI player {player}") + print("AI lost, you win! 🎉") + break tidx, tile, placement = move gp = game.Player.P1 if player == 1 else game.Player.P2 - print(f"player {player} places tile {tidx} at {placement}\n{tile}") board.place(tile, placement, gp) @@ -486,8 +776,6 @@ def play_vs_ai(model: MoveScorer, human_is: int = 1): p2tiles.remove(tidx) game_state = (board, p1tiles, p2tiles) - print_game_state(game_state) - player = 2 if player == 1 else 1 @@ -497,37 +785,38 @@ def play_vs_ai(model: MoveScorer, human_is: int = 1): def main(): - model = MoveScorer() - - if torch.cuda.is_available(): - print("using CUDA") - torch.device("cuda:0") - else: - print("Not using CUDA") + net = AlphaZeroNet() if "--play" in sys.argv: - # Try to load trained weights if they exist - model_path = "trained_agent.pt" + model_path = "az_trained_agent.pt" if os.path.exists(model_path): print(f"Loading model from {model_path}") state = torch.load(model_path, map_location="cpu") - model.load_state_dict(state) - model.eval() + try: + net.load_state_dict(state) + net.eval() + except Exception as e: + print("Saved model incompatible with current net, playing untrained.") + print("Reason:", e) else: print( - "Warning: trained_agent.pt not found. Playing with an untrained model." + "Warning: az_trained_agent.pt not found. Playing with an untrained model." ) # By default, human is player 1; change to 2 if you want - play_vs_ai(model, human_is=1) + play_vs_ai(net, human_is=1, num_simulations=100) + else: - # Train by self-play - train( - model, - num_episodes=1000, + # AlphaZero-style training by self-play + train_alpha_zero( + net, + num_iterations=250, # number of training iterations + games_per_iter=5, # self-play games per iteration + num_simulations=50, # MCTS simulations per move during self-play + batch_size=32, lr=1e-3, - save_path="trained_agent.pt", - watch="--watch" in sys.argv, + save_path="az_trained_agent.pt", + watch_selfplay="--watch" in sys.argv, ) diff --git a/pyrightconfig.json b/pyrightconfig.json deleted file mode 100644 index 71f286f..0000000 --- a/pyrightconfig.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "venv": ".venv", - "venvPath": "./" -}
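Usage sketch (assuming the file keeps its usual "if __name__ == '__main__': main()" guard, which lies outside this diff):

    ./blokus.py            # AlphaZero-style self-play training; saves az_trained_agent.pt
    ./blokus.py --watch    # train, printing each self-play board state
    ./blokus.py --play     # play against the trained net (loads az_trained_agent.pt if present)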