diff --git a/blokus.py b/blokus.py
index 569184b..ba7d533 100755
--- a/blokus.py
+++ b/blokus.py
@@ -2,9 +2,6 @@
 import random
 import sys
 import os
-import math
-from collections import defaultdict
-
 import game
 import torch
@@ -13,49 +10,39 @@
 import torch.nn.functional as F
 from tqdm.auto import trange
 import matplotlib.pyplot as plt
 
-####################
-# Global constants #
-####################
+###############
+# Utilities   #
+###############
 
 BOARD_SIZE = 14
 tiles = game.game_tiles()
-TILE_COUNT = len(tiles)
-
-# We need a max orientation count to encode orientation if desired
-# If you want to use orientations later, this is handy.
-ORIENT_MAX = max(len(t.permutations()) for t in tiles)
-
-###################
-# Utility helpers #
-###################
 
 
 def print_game_state(game_state: tuple[game.Board, list[int], list[int]]):
     (board, p1tiles, p2tiles) = game_state
     barr = []
-    for y in range(BOARD_SIZE):
-        row = []
-        for x in range(BOARD_SIZE):
-            row.append(board[(x, y)])
-        barr.append(row)
+    for i in range(BOARD_SIZE):
+        barr.append([])
+        for j in range(BOARD_SIZE):
+            barr[i].append(board[(j, i)])
 
-    print(f" {'--' * BOARD_SIZE} ")
+    print(f" {'-' * BOARD_SIZE} ")
     for row in barr:
         print(
             f"|{
                 ''.join(
                     [
-                        ' ' if x == 0 else '\033[93m██\033[00m' if x == 1 else '\033[94m██\033[00m' if x == 2 else '██'
+                        ' ' if x == 0 else 'X' if x == 1 else 'O' if x == 2 else 'S'
                         for x in row
                     ]
                 )
             }|"
         )
-    print(f" {'--' * BOARD_SIZE} ")
+    print(f" {'-' * BOARD_SIZE} ")
 
-    print(f"\033[93mPlayer 1\033[00m tiles left: {p1tiles}")
-    print(f"\033[94mPlayer 2\033[00m tiles left: {p2tiles}")
+    print(f"Player 1 tiles left: {p1tiles}")
+    print(f"Player 2 tiles left: {p2tiles}")
 
 
 def plot_losses(loss_history, out_path="loss_curve.png"):
@@ -65,9 +52,9 @@
 
     plt.figure()
    plt.plot(range(1, len(loss_history) + 1), loss_history)
-    plt.xlabel("Training iteration")
+    plt.xlabel("Episode")
     plt.ylabel("Loss")
-    plt.title("AlphaZero training loss")
+    plt.title("Training loss over episodes")
     plt.tight_layout()
     plt.savefig(out_path)
     plt.close()
@@ -82,8 +69,8 @@ def plot_losses(loss_history, out_path="loss_curve.png"):
 def initial_game_state():
     return (
         game.Board(),
-        [i for i in range(21)],  # tiles for player 1
-        [i for i in range(21)],  # tiles for player 2
+        [i for i in range(21)],
+        [i for i in range(21)],
     )
 
 
@@ -93,12 +80,7 @@ def initial_game_state():
 
 
 def encode_board(board: game.Board) -> torch.Tensor:
-    """
-    Encode board as (3, H, W) float tensor:
-      channel 0: player 1 stones
-      channel 1: player 2 stones
-      channel 2: value 3 ("S" or special)
-    """
+    # board[(x, y)] returns 0, 1, 2 or 3, matching the symbols in print_game_state
     arr = torch.zeros((3, BOARD_SIZE, BOARD_SIZE), dtype=torch.float32)
     for y in range(BOARD_SIZE):
         for x in range(BOARD_SIZE):
@@ -107,18 +89,13 @@
                 arr[0, y, x] = 1.0
             elif v == 2:
                 arr[1, y, x] = 1.0
-            elif v == 3:
+            elif v == 3:  # value 3 marks the special "S" squares
                 arr[2, y, x] = 1.0
     return arr
 
 
 def encode_tiles(p1tiles, p2tiles) -> torch.Tensor:
-    """
-    Encode which tiles each player still has.
-    21 tiles per player -> 42-dim vector:
-      indices 0..20: tile i available for P1?
-      indices 21..41: tile i available for P2?
-    """
+    # 21 tiles per player -> 42-dim availability vector
     v = torch.zeros(42, dtype=torch.float32)
     for t in p1tiles:
         v[t] = 1.0
@@ -128,153 +105,74 @@
     return v
 
 
-def encode_state(game_state, player: int) -> torch.Tensor:
-    """
-    Flattened state encoding for the NN:
-      - board one-hot: 3 * 14 * 14 = 588
-      - tiles: 42
-      - current player bit: 1
-    Total dim = 631.
-    """
-    board, p1tiles, p2tiles = game_state
-    board_tensor = encode_board(board).flatten()  # (588,)
-    tiles_tensor = encode_tiles(p1tiles, p2tiles)  # (42,)
-    player_bit = torch.tensor([1.0 if player == 1 else 0.0], dtype=torch.float32)
-    return torch.cat([board_tensor, tiles_tensor, player_bit], dim=0)  # (631,)
-
-
 def encode_move(tile_idx: int, placement: tuple[int, int]) -> torch.Tensor:
-    """
-    Encode a move (ignoring orientation for now):
-      - tile index one-hot (21)
-      - normalized x, y
-    Total dim = 23.
-    """
-    x, y = placement
+    (x, y) = placement
     tile_vec = torch.zeros(21, dtype=torch.float32)
     tile_vec[tile_idx] = 1.0
     pos_vec = torch.tensor(
-        [x / (BOARD_SIZE - 1), y / (BOARD_SIZE - 1)],
-        dtype=torch.float32,
+        [x / (BOARD_SIZE - 1), y / (BOARD_SIZE - 1)], dtype=torch.float32
     )
-    return torch.cat([tile_vec, pos_vec], dim=0)  # (23,)
+    return torch.cat([tile_vec, pos_vec], dim=0)  # 23-dim
 
 
-STATE_DIM = 588 + 42 + 1  # 631
-MOVE_DIM = 23
+def encode_state_and_move(
+    game_state, player: int, tile_idx: int, placement: tuple[int, int], perm: game.Tile
+):
+    board, p1tiles, p2tiles = game_state
+
+    # Encode the board BEFORE the move
+    board_before = encode_board(board).flatten()
+
+    # Encode the board AFTER the move; sim_place returns a copy and leaves the
+    # original board untouched
+    gp = game.Player.P1 if player == 1 else game.Player.P2
+    board_after_sim = board.sim_place(perm, placement, gp)
+    board_after = encode_board(board_after_sim).flatten()
+
+    tiles_tensor = encode_tiles(p1tiles, p2tiles)
+    move_tensor = encode_move(tile_idx, placement)  # tile one-hot + normalized position
+    player_tensor = torch.tensor([1.0 if player == 1 else 0.0], dtype=torch.float32)
+
+    return torch.cat(
+        [
+            board_before,  # 588
+            board_after,  # 588
+            tiles_tensor,  # 42
+            move_tensor,  # 23
+            player_tensor,  # 1
+        ],
+        dim=0,
+    )
 
 
-##############################
-# AlphaZero-style neural net #
-##############################
+###########
+# Model   #
+###########
+
+FEATURE_SIZE = 1242  # 588 + 588 + 42 + 23 + 1, see encode_state_and_move
 
 
-class AlphaZeroNet(nn.Module):
-    """
-    AlphaZero-style network for this Blokus-like game.
-
-    Given a state vector (flattened board + tiles + player),
-    it produces:
-      - value v in [-1,1] (estimate of who will win)
-      - a policy over moves *relative to a given move set*, by
-        combining a state embedding with move features.
-    """
-
-    def __init__(self, state_emb_dim=256, move_emb_dim=64):
+class MoveScorer(nn.Module):
+    def __init__(self):
         super().__init__()
+        self.fc1 = nn.Linear(FEATURE_SIZE, 256)
+        self.fc2 = nn.Linear(256, 128)
+        self.fc_out = nn.Linear(128, 1)  # scalar score
 
-        # State encoder: simple MLP on the 1D state vector
-        self.state_mlp = nn.Sequential(
-            nn.Linear(STATE_DIM, 256),
-            nn.ReLU(),
-            nn.Linear(256, state_emb_dim),
-            nn.ReLU(),
-        )
-
-        # Value head: from state embedding -> scalar, tanh in [-1,1]
-        self.value_head = nn.Sequential(
-            nn.Linear(state_emb_dim, 64),
-            nn.ReLU(),
-            nn.Linear(64, 1),
-            nn.Tanh(),
-        )
-
-        # Move encoder: encode move features
-        self.move_mlp = nn.Sequential(
-            nn.Linear(MOVE_DIM, 64),
-            nn.ReLU(),
-            nn.Linear(64, move_emb_dim),
-            nn.ReLU(),
-        )
-
-        # Policy head: combine state_emb + move_emb -> logit per move
-        self.policy_head = nn.Sequential(
-            nn.Linear(state_emb_dim + move_emb_dim, 128),
-            nn.ReLU(),
-            nn.Linear(128, 1),  # scalar logit per move
-        )
-
-    def forward_state(self, state_vec: torch.Tensor):
-        """
-        state_vec: (STATE_DIM,)
-        returns: state_emb (state_emb_dim,), value scalar
-        """
-        x = self.state_mlp(state_vec)
-        v = self.value_head(x).squeeze(-1)
-        return x, v
-
-    def forward_policy(self, state_emb: torch.Tensor, move_feats: torch.Tensor):
-        """
-        state_emb: (state_emb_dim,)
-        move_feats: (N, MOVE_DIM)
-        returns: policy_probs: (N,), logits: (N,)
-        """
-        move_emb = self.move_mlp(move_feats)  # (N, move_emb_dim)
-        state_expanded = state_emb.unsqueeze(0).expand(move_emb.size(0), -1)
-        h = torch.cat(
-            [state_expanded, move_emb], dim=1
-        )  # (N, state_emb_dim+move_emb_dim)
-        logits = self.policy_head(h).squeeze(-1)  # (N,)
-        probs = F.softmax(logits, dim=0)
-        return probs, logits
-
-
-def net_predict(net: AlphaZeroNet, game_state, player: int, moves):
-    """
-    Evaluate net on a position from the perspective of 'player'.
-
-    Returns:
-      priors: tensor (N,) over given moves (N=len(moves)), or None if no moves
-      value: float in [-1,1] from player's POV
-    """
-    if not moves:
-        return None, 0.0
-
-    state_vec = encode_state(game_state, player)
-    move_feats = []
-    for tidx, perm, placement in moves:
-        mf = encode_move(tidx, placement)
-        move_feats.append(mf)
-
-    state_vec = state_vec
-    move_feats = torch.stack(move_feats, dim=0)
-
-    with torch.no_grad():
-        state_emb, v = net.forward_state(state_vec)
-        priors, logits = net.forward_policy(state_emb, move_feats)
-
-    return priors, float(v.item())
+    def forward(self, x):
+        # x: (batch_size, FEATURE_SIZE)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        return self.fc_out(x)  # (batch_size, 1)
 
 
 ##################
-# Move generation #
+# Move generation
 ##################
 
 
 def get_legal_moves(game_state, player: int):
-    """
-    Returns list of (tile_idx, perm, placement)
-    """
     board, p1tiles, p2tiles = game_state
     gp = game.Player.P1 if player == 1 else game.Player.P2
@@ -286,406 +184,259 @@ def get_legal_moves(game_state, player: int):
         perms = tile.permutations()
         for perm in perms:
             plcs = board.tile_placements(perm, gp)
-            for plc in plcs:
-                moves.append((tile_idx, perm, plc))
+            moves.extend((tile_idx, perm, plc) for plc in plcs)
     return moves
 
 
-###############
-# MCTS (PUCT) #
-###############
+###########
+# Agents  #
+###########
 
 
-class MCTSNode:
-    def __init__(self, game_state, player: int, parent=None, move_from_parent=None):
-        self.game_state = game_state  # (board, p1tiles, p2tiles)
-        self.player = player  # player to move at this node (1 or 2)
-        self.parent = parent
-        self.move_from_parent = (
-            move_from_parent  # index of move from parent to this node
-        )
+class Agent:
+    def choose_move(self, game_state, player: int):
+        """Return (tile_idx, perm, placement) or None if no moves."""
+        raise NotImplementedError
 
-        self.children = {}  # move_index -> child MCTSNode
-        self.moves = None  # list of legal moves at this node
-        self.P = None  # tensor of priors over moves (N,)
-        self.N = defaultdict(int)  # visit count per move_index
-        self.W = defaultdict(float)  # total value per move_index
-        self.Q = defaultdict(float)  # mean value per move_index
-
-        self.is_expanded = False
-        self.is_terminal = False
-        self.value = 0.0
-
-    def expand(self, net: AlphaZeroNet):
-        """
-        Expand this node: generate legal moves, get priors & value from net.
-        """
-        if self.is_expanded:
-            return
-
-        moves = get_legal_moves(self.game_state, self.player)
-        self.moves = moves
+class RandomAgent(Agent):
+    def choose_move(self, game_state, player: int):
+        moves = get_legal_moves(game_state, player)
         if not moves:
-            # terminal: current player has no moves -> they lose
-            self.is_terminal = True
-            self.P = None
-            self.value = 0.0
-            self.is_expanded = True
-            return
-
-        priors, v = net_predict(net, self.game_state, self.player, moves)
-        self.P = priors
-        self.value = v  # value from this node's player POV
-        self.is_expanded = True
-
-    def select_child(self, c_puct=1.5):
-        """
-        Select move index according to PUCT:
-          a* = argmax_a (Q(s,a) + c_puct * P(s,a) * sqrt(sum_b N(s,b)) / (1 + N(s,a)))
-        """
-        total_N = sum(self.N[i] for i in range(len(self.moves))) + 1e-8
-        best_score = -1e9
-        best_i = None
-        for i in range(len(self.moves)):
-            Q = self.Q[i]
-            P = float(self.P[i])
-            N_sa = self.N[i]
-            U = c_puct * P * math.sqrt(total_N) / (1 + N_sa)
-            score = Q + U
-            if score > best_score:
-                best_score = score
-                best_i = i
-        return best_i
-
-    def simulate_child(self, move_index: int):
-        """
-        Given a move index in self.moves, return or create the child node.
-        """
-        if move_index in self.children:
-            return self.children[move_index]
-
-        tidx, perm, placement = self.moves[move_index]
-
-        board, p1tiles, p2tiles = self.game_state
-        gp = game.Player.P1 if self.player == 1 else game.Player.P2
-
-        # Simulate placing the tile on a copy of the board
-        next_board = board.sim_place(perm, placement, gp)  # does not modify original
-
-        next_p1tiles = list(p1tiles)
-        next_p2tiles = list(p2tiles)
-        if self.player == 1:
-            next_p1tiles.remove(tidx)
-            next_player = 2
-        else:
-            next_p2tiles.remove(tidx)
-            next_player = 1
-
-        next_state = (next_board, next_p1tiles, next_p2tiles)
-        child = MCTSNode(
-            next_state, next_player, parent=self, move_from_parent=move_index
-        )
-        self.children[move_index] = child
-        return child
+            return None
+        return random.choice(moves)
 
 
-def mcts_search(
-    root_state,
-    root_player: int,
-    net: AlphaZeroNet,
-    num_simulations: int = 50,
-    c_puct: float = 1.5,
-    temperature: float = 1.0,
-):
-    """
-    Run MCTS from (root_state, root_player).
+class HumanAgent(Agent):
+    def choose_move(self, game_state, player: int):
+        moves = get_legal_moves(game_state, player)
+        if not moves:
+            print(f"No moves left for player {player}")
+            return None
 
-    Returns:
-      pi: tensor of size (num_moves,), MCTS policy over root moves
-      chosen_move: (tile_idx, perm, placement)
-      root_moves: list of all legal moves at root
-    """
-    root = MCTSNode(root_state, root_player)
-    root.expand(net)
+        print_game_state(game_state)
+        print(f"Player {player}, you have {len(moves)} possible moves.")
 
-    if root.is_terminal or root.P is None or not root.moves:
-        # No moves at root
-        return None, None, []
-
-    # Run simulations
-    for _ in range(num_simulations):
-        node = root
-        path = [node]
-
-        # 1. Selection: descend tree until leaf
-        while node.is_expanded and not node.is_terminal:
-            move_i = node.select_child(c_puct)
-            node = node.simulate_child(move_i)
-            path.append(node)
-
-        # 2. Expansion / evaluation of leaf
-        node.expand(net)
-
-        # 3. Compute leaf value from root player's perspective
-        if node.is_terminal:
-            # terminal: node.player has no moves -> node.player loses
-            if root_player == node.player:
-                v_root = -1.0
+        # Show at most the first 50 moves
+        for i, (tidx, perm, plc) in enumerate(moves):
+            if i < 50:  # don't spam too hard; tweak as needed
+                print(f"[{i}] tile {tidx} at {plc}")
             else:
-                v_root = +1.0
+                break
+        if len(moves) > 50:
+            print(f"... and {len(moves) - 50} more moves not listed")
+
+        while True:
+            try:
+                choice = int(input("Enter move index: "))
+                if 0 <= choice < len(moves):
+                    return moves[choice]
+                else:
+                    print("Invalid index, try again.")
+            except ValueError:
+                print("Please enter an integer.")
+
+
+class MLAgent(Agent):
+    def __init__(
+        self, model: MoveScorer, deterministic: bool = True, epsilon: float = 0.0
+    ):
+        self.model = model
+        self.deterministic = deterministic
+        self.epsilon = epsilon
+
+    def choose_move(self, game_state, player: int):
+        moves = get_legal_moves(game_state, player)
+        if not moves:
+            return None
+
+        # Optional epsilon-greedy exploration: use 0.0 for "serious" play
+        if self.epsilon > 0.0 and random.random() < self.epsilon:
+            return random.choice(moves)
+
+        # Build a feature batch, one row per legal move
+        features = []
+        for tidx, perm, placement in moves:
+            feat = encode_state_and_move(game_state, player, tidx, placement, perm)
+            features.append(feat)
+        X = torch.stack(features, dim=0)
+
+        self.model.eval()
+        with torch.no_grad():
+            scores = self.model(X).squeeze(-1)  # (num_moves,)
+
+        if self.deterministic:
+            best_idx = torch.argmax(scores).item()
+            return moves[best_idx]
         else:
-            # non-terminal: node.value is from node.player's POV
-            v_node = node.value
-            v_root = v_node if node.player == root_player else -v_node
-
-        # 4. Backup: update Q, N along the path
-        value = v_root
-        for i in range(len(path) - 1):
-            parent = path[i]
-            child = path[i + 1]
-            mv_i = child.move_from_parent
-            parent.N[mv_i] += 1
-            parent.W[mv_i] += value
-            parent.Q[mv_i] = parent.W[mv_i] / parent.N[mv_i]
-            value = -value  # flip sign at each ply
-
-    # Visits at root
-    visits = torch.tensor(
-        [root.N[i] for i in range(len(root.moves))],
-        dtype=torch.float32,
-    )
-
-    if temperature == 0:
-        # deterministic: all mass on argmax visit
-        best = torch.argmax(visits).item()
-        pi = torch.zeros_like(visits)
-        pi[best] = 1.0
-    else:
-        pi = visits ** (1.0 / temperature)
-        if pi.sum() > 0:
-            pi = pi / pi.sum()
-        else:
-            pi = torch.ones_like(pi) / pi.numel()
-
-    # Choose move to actually play (sample from pi)
-    move_index = torch.multinomial(pi, num_samples=1).item()
-    chosen_move = root.moves[move_index]
-
-    return pi, chosen_move, root.moves
+            # Sample from a softmax over scores for more variety
+            probs = torch.softmax(scores, dim=0)
+            idx = torch.multinomial(probs, num_samples=1).item()
+            return moves[idx]
 
 
-########################
-# Self-play + training #
-########################
+######################
+# Training utilities #
+######################
 
 
-def self_play_game(
-    net: AlphaZeroNet,
-    num_simulations: int = 50,
-    temperature: float = 1.0,
-    watch: bool = False,
-):
+def select_move_and_logprob(model: MoveScorer, game_state, player: int):
     """
-    Play one self-play game using MCTS + net.
+    For training: sample a move from softmax over scores
+    and return (move, log_prob). If no moves, returns (None, None).
+ """ + moves = get_legal_moves(game_state, player) + if not moves: + return None, None + features = [] + for tidx, perm, placement in moves: + feat = encode_state_and_move(game_state, player, tidx, placement, perm) + features.append(feat) + X = torch.stack(features, dim=0) # (num_moves, FEATURE_SIZE) + + scores = model(X).squeeze(-1) # (num_moves,) + probs = F.softmax(scores, dim=0) + + dist = torch.distributions.Categorical(probs) + idx = dist.sample() + log_prob = dist.log_prob(idx) + + move = moves[idx.item()] + return move, log_prob + + +def play_self_play_game(model: MoveScorer, max_turns: int = 500, watch: bool = False): + """ + Self-play game with the same model as both players. Returns: - examples: list of (state_tensor, moves, pi, z) - - state_tensor: encoding of state at root of each move - - moves: list of legal moves at that root - - pi: tensor of MCTS policy over moves - - z: final outcome from that state's player POV (+1 or -1) + log_probs_p1, log_probs_p2, reward_p1, reward_p2 + where rewards are +1/-1 for win/loss. """ game_state = initial_game_state() board, p1tiles, p2tiles = game_state - player = 1 - trajectory = [] # list of dicts: {"state": state_vec, "moves": moves, "pi": pi, "player": player} + log_probs = {1: [], 2: []} + player = 1 + turns = 0 while True: - moves = get_legal_moves(game_state, player) - if not moves: - # current player has no moves -> they lose - winner = 2 if player == 1 else 1 - break + turns += 1 + if turns > max_turns: + # Safety: declare a draw + reward_p1 = 0.0 + reward_p2 = 0.0 + return log_probs[1], log_probs[2], reward_p1, reward_p2 - state_vec = encode_state(game_state, player) - pi, chosen_move, move_list = mcts_search( - game_state, - player, - net, - num_simulations=num_simulations, - c_puct=1.5, - temperature=temperature, - ) + move, log_prob = select_move_and_logprob(model, game_state, player) - if pi is None or chosen_move is None: - winner = 2 if player == 1 else 1 - break + if move is None: + # Current player cannot move -> they lose + if player == 1: + reward_p1 = -1.0 + reward_p2 = +1.0 + else: + reward_p1 = +1.0 + reward_p2 = -1.0 + return log_probs[1], log_probs[2], reward_p1, reward_p2 - trajectory.append( - { - "state": state_vec, - "moves": move_list, - "pi": pi, - "player": player, - } - ) - - # Apply chosen move to real game - tidx, tile, placement = chosen_move + tidx, tile, placement = move gp = game.Player.P1 if player == 1 else game.Player.P2 + # Apply move board.place(tile, placement, gp) if player == 1: p1tiles.remove(tidx) else: p2tiles.remove(tidx) + # Update game_state tuple game_state = (board, p1tiles, p2tiles) if watch: print_game_state(game_state) + # Store log_prob + log_probs[player].append(log_prob) + + # Switch player player = 2 if player == 1 else 1 - # Convert trajectory into training examples - examples = [] - for step in trajectory: - state_vec = step["state"] - moves = step["moves"] - pi = step["pi"] - p = step["player"] - z = 1.0 if p == winner else -1.0 - examples.append((state_vec, moves, pi, z)) - return examples - - -def alpha_zero_train_step(net: AlphaZeroNet, optimizer, batch): - """ - One gradient update on a batch of (state_vec, moves, pi_target, z_target). 
-    """
-    if not batch:
-        return 0.0
-
-    total_value_loss = 0.0
-    total_policy_loss = 0.0
-
-    for state_vec, moves, pi_target, z in batch:
-        state_vec = state_vec
-        z_t = torch.tensor(z, dtype=torch.float32)
-
-        # Encode moves
-        move_feats = []
-        for tidx, perm, placement in moves:
-            move_feats.append(encode_move(tidx, placement))
-        move_feats = torch.stack(move_feats, dim=0)
-
-        state_emb, v_pred = net.forward_state(state_vec)
-        p_pred, logits = net.forward_policy(state_emb, move_feats)
-
-        # Value loss
-        value_loss = (v_pred - z_t) ** 2
-
-        # Policy loss: cross-entropy between pi_target and p_pred
-        pi_target = pi_target.to(p_pred.dtype)
-        policy_loss = -torch.sum(pi_target * torch.log(p_pred + 1e-8))
-
-        total_value_loss = total_value_loss + value_loss
-        total_policy_loss = total_policy_loss + policy_loss
-
-    loss = (total_value_loss + total_policy_loss) / len(batch)
-
-    optimizer.zero_grad()
-    loss.backward()
-    optimizer.step()
-
-    return float(loss.item())
-
-
-def train_alpha_zero(
-    net: AlphaZeroNet,
-    num_iterations: int = 50,
-    games_per_iter: int = 5,
-    num_simulations: int = 50,
-    batch_size: int = 32,
+def train(
+    model: MoveScorer,
+    num_episodes: int = 1000,
     lr: float = 1e-3,
-    save_path: str = "az_trained_agent.pt",
-    watch_selfplay: bool = False,
+    save_path: str = "trained_agent.pt",
+    watch: bool = False,
 ):
-    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
 
+    # We'll keep a history of losses for plotting
     loss_history = []
 
-    ckpt_path = save_path + ".ckpt"
-    start_iter = 1
-    # Try to resume from checkpoint
+    # Checkpoint path (partial training state)
+    ckpt_path = save_path + ".ckpt"
+
+    start_episode = 1
+
+    # Try to resume from checkpoint if it exists
     if os.path.exists(ckpt_path):
         ckpt = torch.load(ckpt_path, map_location="cpu")
-        try:
-            net.load_state_dict(ckpt["model_state"])
-            optimizer.load_state_dict(ckpt["optimizer_state"])
-            start_iter = ckpt["iteration"] + 1
-            loss_history = ckpt.get("loss_history", [])
-            print(f"Resuming training from iteration {start_iter} (found checkpoint).")
-            if start_iter > num_iterations:
-                print("Checkpoint exceeds requested num_iterations; nothing to train.")
-                plot_losses(loss_history, out_path="loss_curve.png")
-                torch.save(net.state_dict(), save_path)
-                return
-        except Exception as e:
-            print("Checkpoint incompatible, starting fresh.")
-            print("Reason:", e)
+        model.load_state_dict(ckpt["model_state"])
+        optimizer.load_state_dict(ckpt["optimizer_state"])
+        start_episode = ckpt["episode"] + 1
+        loss_history = ckpt.get("loss_history", [])
+        print(f"Resuming training from episode {start_episode} (found checkpoint).")
 
-    pbar = trange(
-        start_iter, num_iterations + 1, desc="AZ Training", dynamic_ncols=True
-    )
-
-    for it in pbar:
-        replay_buffer = []
-
-        # 1. Self-play games to generate fresh data
-        for g in range(games_per_iter):
-            examples = self_play_game(
-                net,
-                num_simulations=num_simulations,
-                temperature=1.0,  # can anneal later
-                watch=watch_selfplay,
+    # If we've already passed num_episodes, just plot and exit
+    if start_episode > num_episodes:
+        print(
+            "Checkpoint episode exceeds requested num_episodes; nothing to train."
+        )
+        plot_losses(loss_history, out_path="loss_curve.png")
+        torch.save(model.state_dict(), save_path)
+        return
 
-            )
-            replay_buffer.extend(examples)
+    pbar = trange(start_episode, num_episodes + 1, desc="Training", dynamic_ncols=True)
 
-        if not replay_buffer:
-            print("No data generated this iteration, something is wrong.")
-            continue
+    for episode in pbar:
+        log_probs_p1, log_probs_p2, r1, r2 = play_self_play_game(model, watch=watch)
 
-        random.shuffle(replay_buffer)
+        loss = torch.tensor(0.0)
+        if log_probs_p1:
+            loss = loss - r1 * torch.stack(log_probs_p1).sum()
+        if log_probs_p2:
+            loss = loss - r2 * torch.stack(log_probs_p2).sum()
 
-        # 2. Train on this batch of data
-        batch_losses = []
-        for i in range(0, len(replay_buffer), batch_size):
-            batch = replay_buffer[i : i + batch_size]
-            loss = alpha_zero_train_step(net, optimizer, batch)
-            batch_losses.append(loss)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
 
-        mean_loss = sum(batch_losses) / len(batch_losses)
-        loss_history.append(mean_loss)
-        pbar.set_postfix(loss=mean_loss, data=len(replay_buffer))
+        loss_value = float(loss.item())
+        loss_history.append(loss_value)
 
-        # 3. Save checkpoint
-        if it % 5 == 0 or it == num_iterations:
+        # Update progress bar with the most recent loss
+        pbar.set_postfix(
+            loss=loss_value,
+        )
+
+        # Save a checkpoint every 50 episodes (and at the very end)
+        if episode % 50 == 0 or episode == num_episodes:
             torch.save(
                 {
-                    "iteration": it,
-                    "model_state": net.state_dict(),
+                    "episode": episode,
+                    "model_state": model.state_dict(),
                     "optimizer_state": optimizer.state_dict(),
                     "loss_history": loss_history,
                 },
                 ckpt_path,
             )
-            plot_losses(loss_history, out_path="loss_curve_ckpt.png")
 
-    torch.save(net.state_dict(), save_path)
+    # Final model save
+    torch.save(model.state_dict(), save_path)
     print(f"\nTraining finished. Model saved to {save_path}")
+
+    # Save final loss plot
     plot_losses(loss_history, out_path="loss_curve.png")
 
 
@@ -694,79 +445,38 @@
 ###################
 
 
-def az_choose_move(
-    net: AlphaZeroNet, game_state, player: int, num_simulations: int = 100
-):
+def play_vs_ai(model: MoveScorer, human_is: int = 1):
     """
-    Use MCTS with the trained net to choose a move for actual play.
-    """
-    moves = get_legal_moves(game_state, player)
-    if not moves:
-        return None
-
-    # Temperature 0 for deterministic play
-    pi, chosen_move, root_moves = mcts_search(
-        game_state,
-        player,
-        net,
-        num_simulations=num_simulations,
-        c_puct=1.5,
-        temperature=0.0,
-    )
-    if chosen_move is None:
-        return None
-    return chosen_move
-
-
-def play_vs_ai(net: AlphaZeroNet, human_is: int = 1, num_simulations: int = 100):
-    """
-    Let a human play against the AlphaZero-style agent.
+    Let a human play against the trained model.
     human_is: 1 or 2
     """
     game_state = initial_game_state()
     board, p1tiles, p2tiles = game_state
 
+    human = HumanAgent()
+    ai = MLAgent(model, deterministic=True, epsilon=0.0)
+
+    agents = {
+        human_is: human,
+        1 if human_is == 2 else 2: ai,
+    }
+
     player = 1
     while True:
-        print_game_state(game_state)
+        agent = agents[player]
+        move = agent.choose_move(game_state, player)
 
-        if player == human_is:
-            # Human move
-            moves = get_legal_moves(game_state, player)
-            if not moves:
-                print(f"No moves left for player {player}")
+        if move is None:
+            print(f"No moves left, player {player} lost")
+            if player == human_is:
                 print("You lost 😢")
-                break
-
-            print(f"Player {player}, you have {len(moves)} possible moves.")
-            for i, (tidx, perm, plc) in enumerate(moves[:50]):
-                print(f"[{i}] tile {tidx} at {plc}")
-            if len(moves) > 50:
-                print(f"... and {len(moves) - 50} more moves not listed")
-
-            while True:
-                try:
-                    choice = int(input("Enter move index: "))
-                    if 0 <= choice < len(moves):
-                        move = moves[choice]
-                        break
-                    else:
-                        print("Invalid index, try again.")
-                except ValueError:
-                    print("Please enter an integer.")
-        else:
-            # AI move
-            print(f"AI (player {player}) is thinking...")
-            move = az_choose_move(
-                net, game_state, player, num_simulations=num_simulations
-            )
-            if move is None:
-                print(f"No moves left for AI player {player}")
-                print("AI lost, you win! 🎉")
-                break
+            else:
+                print("You won! 🎉")
+            break
 
         tidx, tile, placement = move
         gp = game.Player.P1 if player == 1 else game.Player.P2
+
+        print(f"Player {player} places tile {tidx} at {placement}\n{tile}")
         board.place(tile, placement, gp)
@@ -776,6 +486,8 @@ def play_vs_ai(net: AlphaZeroNet, human_is: int = 1, num_simulations: int = 100)
             p2tiles.remove(tidx)
         game_state = (board, p1tiles, p2tiles)
 
+        print_game_state(game_state)
+
         player = 2 if player == 1 else 1
@@ -785,38 +497,37 @@ def play_vs_ai(net: AlphaZeroNet, human_is: int = 1, num_simulations: int = 100)
 ###################
 
 
 def main():
-    net = AlphaZeroNet()
+    model = MoveScorer()
+
+    # NOTE: every tensor in this script is created on the CPU, so we only
+    # report CUDA availability here; the model itself stays on the CPU.
+    if torch.cuda.is_available():
+        print("CUDA is available (model still runs on the CPU)")
+    else:
+        print("Not using CUDA")
 
     if "--play" in sys.argv:
-        model_path = "az_trained_agent.pt"
+        # Try to load trained weights if they exist
+        model_path = "trained_agent.pt"
         if os.path.exists(model_path):
             print(f"Loading model from {model_path}")
             state = torch.load(model_path, map_location="cpu")
-            try:
-                net.load_state_dict(state)
-                net.eval()
-            except Exception as e:
-                print("Saved model incompatible with current net, playing untrained.")
-                print("Reason:", e)
+            model.load_state_dict(state)
+            model.eval()
         else:
            print(
-                "Warning: az_trained_agent.pt not found. Playing with an untrained model."
+                "Warning: trained_agent.pt not found. Playing with an untrained model."
             )
 
         # By default, human is player 1; change to 2 if you want
-        play_vs_ai(net, human_is=1, num_simulations=100)
-
+        play_vs_ai(model, human_is=1)
     else:
-        # AlphaZero-style training by self-play
-        train_alpha_zero(
-            net,
-            num_iterations=250,  # number of training iterations
-            games_per_iter=5,  # self-play games per iteration
-            num_simulations=50,  # MCTS simulations per move during self-play
-            batch_size=32,
+        # Train by self-play
+        train(
+            model,
+            num_episodes=1000,
             lr=1e-3,
-            save_path="az_trained_agent.pt",
-            watch_selfplay="--watch" in sys.argv,
+            save_path="trained_agent.pt",
+            watch="--watch" in sys.argv,
         )
diff --git a/pyrightconfig.json b/pyrightconfig.json
new file mode 100644
index 0000000..71f286f
--- /dev/null
+++ b/pyrightconfig.json
@@ -0,0 +1,4 @@
+{
+    "venv": ".venv",
+    "venvPath": "./"
+}
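
Note on the new training loop: the loss it optimizes, -(reward * sum of per-move log-probabilities) computed separately for each player's moves, is a REINFORCE-style policy gradient over the MoveScorer's per-move softmax. A quick way to check whether the scorer is learning anything is to pit it against RandomAgent using only the pieces defined in this diff. The helper below is a hypothetical sketch, not part of the patch: evaluate_vs_random is a name introduced here for illustration, and it assumes it is added to blokus.py next to the agent classes above.

# Hypothetical evaluation helper (not part of the diff): plays the trained
# MoveScorer against RandomAgent and returns the model's win rate.
def evaluate_vs_random(model, num_games=20, max_turns=500):
    ml_wins = 0
    for g in range(num_games):
        game_state = initial_game_state()
        board, p1tiles, p2tiles = game_state
        # Alternate which side the model plays to average out first-move bias.
        ml_player = 1 if g % 2 == 0 else 2
        agents = {
            ml_player: MLAgent(model, deterministic=True, epsilon=0.0),
            3 - ml_player: RandomAgent(),
        }
        player, turns = 1, 0
        while turns < max_turns:
            turns += 1
            move = agents[player].choose_move(game_state, player)
            if move is None:
                # The player to move has no legal placement and loses.
                if player != ml_player:
                    ml_wins += 1
                break
            tidx, tile, placement = move
            gp = game.Player.P1 if player == 1 else game.Player.P2
            board.place(tile, placement, gp)
            (p1tiles if player == 1 else p2tiles).remove(tidx)
            game_state = (board, p1tiles, p2tiles)
            player = 2 if player == 1 else 1
    return ml_wins / num_games

Calling evaluate_vs_random(model) after train(...) should give a win rate that drifts above 0.5 if the policy-gradient updates are doing anything useful; a rate stuck at roughly 0.5 suggests the scorer is still playing at random strength.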