#!/usr/bin/env python import random import sys import os import math from collections import defaultdict import game import torch import torch.nn as nn import torch.nn.functional as F from tqdm.auto import trange import matplotlib.pyplot as plt #################### # Global constants # #################### BOARD_SIZE = 14 tiles = game.game_tiles() TILE_COUNT = len(tiles) # We need a max orientation count to encode orientation if desired # If you want to use orientations later, this is handy. ORIENT_MAX = max(len(t.permutations()) for t in tiles) ################### # Utility helpers # ################### def print_game_state(game_state: tuple[game.Board, list[int], list[int]]): (board, p1tiles, p2tiles) = game_state barr = [] for y in range(BOARD_SIZE): row = [] for x in range(BOARD_SIZE): row.append(board[(x, y)]) barr.append(row) print(f" {'--' * BOARD_SIZE} ") for row in barr: print( f"|{ ''.join( [ ' ' if x == 0 else '\033[93m██\033[00m' if x == 1 else '\033[94m██\033[00m' if x == 2 else '██' for x in row ] ) }|" ) print(f" {'--' * BOARD_SIZE} ") print(f"\033[93mPlayer 1\033[00m tiles left: {p1tiles}") print(f"\033[94mPlayer 2\033[00m tiles left: {p2tiles}") def plot_losses(loss_history, out_path="loss_curve.png"): if not loss_history: print("No losses to plot.") return plt.figure() plt.plot(range(1, len(loss_history) + 1), loss_history) plt.xlabel("Training iteration") plt.ylabel("Loss") plt.title("BlokuZero training loss") plt.tight_layout() plt.savefig(out_path) plt.close() print(f"Saved loss plot to {out_path}") ################### # Game state init # ################### def initial_game_state(): return ( game.Board(), [i for i in range(21)], # tiles for player 1 [i for i in range(21)], # tiles for player 2 ) ############ # Encoding # ############ def encode_board(board: game.Board) -> torch.Tensor: """ Encode board as (3, H, W) float tensor: channel 0: player 1 stones channel 1: player 2 stones channel 2: value 3 ("S" or special) """ arr = torch.zeros((3, BOARD_SIZE, BOARD_SIZE), dtype=torch.float32) for y in range(BOARD_SIZE): for x in range(BOARD_SIZE): v = board[(x, y)] if v == 1: arr[0, y, x] = 1.0 elif v == 2: arr[1, y, x] = 1.0 elif v == 3: arr[2, y, x] = 1.0 return arr def encode_tiles(p1tiles, p2tiles) -> torch.Tensor: """ Encode which tiles each player still has. 21 tiles per player -> 42-dim vector: indices 0..20: tile i available for P1? indices 21..41: tile i available for P2? """ v = torch.zeros(42, dtype=torch.float32) for t in p1tiles: v[t] = 1.0 offset = 21 for t in p2tiles: v[offset + t] = 1.0 return v def encode_state(game_state, player: int) -> torch.Tensor: """ Flattened state encoding for the NN: - board one-hot: 3 * 14 * 14 = 588 - tiles: 42 - current player bit: 1 Total dim = 631. """ board, p1tiles, p2tiles = game_state board_tensor = encode_board(board).flatten() # (588,) tiles_tensor = encode_tiles(p1tiles, p2tiles) # (42,) player_bit = torch.tensor([1.0 if player == 1 else 0.0], dtype=torch.float32) return torch.cat([board_tensor, tiles_tensor, player_bit], dim=0) # (631,) def encode_move(tile_idx: int, placement: tuple[int, int]) -> torch.Tensor: """ Encode a move (ignoring orientation for now): - tile index one-hot (21) - normalized x, y Total dim = 23. """ x, y = placement tile_vec = torch.zeros(21, dtype=torch.float32) tile_vec[tile_idx] = 1.0 pos_vec = torch.tensor( [x / (BOARD_SIZE - 1), y / (BOARD_SIZE - 1)], dtype=torch.float32, ) return torch.cat([tile_vec, pos_vec], dim=0) # (23,) STATE_DIM = 588 + 42 + 1 # 631 MOVE_DIM = 23 ############################## # AlphaZero-style neural net # ############################## class BlokuZeroNet(nn.Module): """ AlphaZero-style network for this Blokus-like game. Given a state vector (flattened board + tiles + player), it produces: - value v in [-1,1] (estimate of who will win) - a policy over moves *relative to a given move set*, by combining a state embedding with move features. """ def __init__(self, state_emb_dim=256, move_emb_dim=64): super().__init__() # State encoder: simple MLP on the 1D state vector self.state_mlp = nn.Sequential( nn.Linear(STATE_DIM, 256), nn.ReLU(), nn.Linear(256, state_emb_dim), nn.ReLU(), ) # Value head: from state embedding -> scalar, tanh in [-1,1] self.value_head = nn.Sequential( nn.Linear(state_emb_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Tanh(), ) # Move encoder: encode move features self.move_mlp = nn.Sequential( nn.Linear(MOVE_DIM, 64), nn.ReLU(), nn.Linear(64, move_emb_dim), nn.ReLU(), ) # Policy head: combine state_emb + move_emb -> logit per move self.policy_head = nn.Sequential( nn.Linear(state_emb_dim + move_emb_dim, 128), nn.ReLU(), nn.Linear(128, 1), # scalar logit per move ) def forward_state(self, state_vec: torch.Tensor): """ state_vec: (STATE_DIM,) returns: state_emb (state_emb_dim,), value scalar """ x = self.state_mlp(state_vec) v = self.value_head(x).squeeze(-1) return x, v def forward_policy(self, state_emb: torch.Tensor, move_feats: torch.Tensor): """ state_emb: (state_emb_dim,) move_feats: (N, MOVE_DIM) returns: policy_probs: (N,), logits: (N,) """ move_emb = self.move_mlp(move_feats) # (N, move_emb_dim) state_expanded = state_emb.unsqueeze(0).expand(move_emb.size(0), -1) h = torch.cat( [state_expanded, move_emb], dim=1 ) # (N, state_emb_dim+move_emb_dim) logits = self.policy_head(h).squeeze(-1) # (N,) probs = F.softmax(logits, dim=0) return probs, logits def net_predict(net: BlokuZeroNet, game_state, player: int, moves): """ Evaluate net on a position from the perspective of 'player'. Returns: priors: tensor (N,) over given moves (N=len(moves)), or None if no moves value: float in [-1,1] from player's POV """ if not moves: return None, 0.0 state_vec = encode_state(game_state, player) move_feats = [] for tidx, perm, placement in moves: mf = encode_move(tidx, placement) move_feats.append(mf) state_vec = state_vec move_feats = torch.stack(move_feats, dim=0) with torch.no_grad(): state_emb, v = net.forward_state(state_vec) priors, logits = net.forward_policy(state_emb, move_feats) return priors, float(v.item()) ################## # Move generation # ################## def get_legal_moves(game_state, player: int): """ Returns list of (tile_idx, perm, placement) """ board, p1tiles, p2tiles = game_state gp = game.Player.P1 if player == 1 else game.Player.P2 tiles_left = p1tiles if player == 1 else p2tiles moves = [] for tile_idx in tiles_left: tile = tiles[tile_idx] perms = tile.permutations() for perm in perms: plcs = board.tile_placements(perm, gp) for plc in plcs: moves.append((tile_idx, perm, plc)) return moves ############### # MCTS (PUCT) # ############### class MCTSNode: def __init__(self, game_state, player: int, parent=None, move_from_parent=None): self.game_state = game_state # (board, p1tiles, p2tiles) self.player = player # player to move at this node (1 or 2) self.parent = parent self.move_from_parent = ( move_from_parent # index of move from parent to this node ) self.children = {} # move_index -> child MCTSNode self.moves = None # list of legal moves at this node self.P = None # tensor of priors over moves (N,) self.N = defaultdict(int) # visit count per move_index self.W = defaultdict(float) # total value per move_index self.Q = defaultdict(float) # mean value per move_index self.is_expanded = False self.is_terminal = False self.value = 0.0 def expand(self, net: BlokuZeroNet): """ Expand this node: generate legal moves, get priors & value from net. """ if self.is_expanded: return moves = get_legal_moves(self.game_state, self.player) self.moves = moves if not moves: # terminal: current player has no moves -> they lose self.is_terminal = True self.P = None self.value = 0.0 self.is_expanded = True return priors, v = net_predict(net, self.game_state, self.player, moves) self.P = priors self.value = v # value from this node's player POV self.is_expanded = True def select_child(self, c_puct=1.5): """ Select move index according to PUCT: a* = argmax_a (Q(s,a) + c_puct * P(s,a) * sqrt(sum_b N(s,b)) / (1 + N(s,a))) """ total_N = sum(self.N[i] for i in range(len(self.moves))) + 1e-8 best_score = -1e9 best_i = None for i in range(len(self.moves)): Q = self.Q[i] P = float(self.P[i]) N_sa = self.N[i] U = c_puct * P * math.sqrt(total_N) / (1 + N_sa) score = Q + U if score > best_score: best_score = score best_i = i return best_i def simulate_child(self, move_index: int): """ Given a move index in self.moves, return or create the child node. """ if move_index in self.children: return self.children[move_index] tidx, perm, placement = self.moves[move_index] board, p1tiles, p2tiles = self.game_state gp = game.Player.P1 if self.player == 1 else game.Player.P2 # Simulate placing the tile on a copy of the board next_board = board.sim_place(perm, placement, gp) # does not modify original next_p1tiles = list(p1tiles) next_p2tiles = list(p2tiles) if self.player == 1: next_p1tiles.remove(tidx) next_player = 2 else: next_p2tiles.remove(tidx) next_player = 1 next_state = (next_board, next_p1tiles, next_p2tiles) child = MCTSNode( next_state, next_player, parent=self, move_from_parent=move_index ) self.children[move_index] = child return child def mcts_search( root_state, root_player: int, net: BlokuZeroNet, num_simulations: int = 50, c_puct: float = 1.5, temperature: float = 1.0, ): """ Run MCTS from (root_state, root_player). Returns: pi: tensor of size (num_moves,), MCTS policy over root moves chosen_move: (tile_idx, perm, placement) root_moves: list of all legal moves at root """ root = MCTSNode(root_state, root_player) root.expand(net) if root.is_terminal or root.P is None or not root.moves: # No moves at root return None, None, [] # Run simulations for _ in range(num_simulations): node = root path = [node] # 1. Selection: descend tree until leaf while node.is_expanded and not node.is_terminal: move_i = node.select_child(c_puct) node = node.simulate_child(move_i) path.append(node) # 2. Expansion / evaluation of leaf node.expand(net) # 3. Compute leaf value from root player's perspective if node.is_terminal: # terminal: node.player has no moves -> node.player loses if root_player == node.player: v_root = -1.0 else: v_root = +1.0 else: # non-terminal: node.value is from node.player's POV v_node = node.value v_root = v_node if node.player == root_player else -v_node # 4. Backup: update Q, N along the path value = v_root for i in range(len(path) - 1): parent = path[i] child = path[i + 1] mv_i = child.move_from_parent parent.N[mv_i] += 1 parent.W[mv_i] += value parent.Q[mv_i] = parent.W[mv_i] / parent.N[mv_i] value = -value # flip sign at each ply # Visits at root visits = torch.tensor( [root.N[i] for i in range(len(root.moves))], dtype=torch.float32, ) if temperature == 0: # deterministic: all mass on argmax visit best = torch.argmax(visits).item() pi = torch.zeros_like(visits) pi[best] = 1.0 else: pi = visits ** (1.0 / temperature) if pi.sum() > 0: pi = pi / pi.sum() else: pi = torch.ones_like(pi) / pi.numel() # Choose move to actually play (sample from pi) move_index = torch.multinomial(pi, num_samples=1).item() chosen_move = root.moves[move_index] return pi, chosen_move, root.moves ######################## # Self-play + training # ######################## def self_play_game( net: BlokuZeroNet, num_simulations: int = 50, temperature: float = 1.0, watch: bool = False, random_move_prob: float = 0.0, ): """ Play one self-play game using MCTS + net. Returns: examples: list of (state_tensor, moves, pi, z) - state_tensor: encoding of state at root of each move - moves: list of legal moves at that root - pi: tensor of MCTS policy over moves - z: final outcome from that state's player POV (+1 or -1) """ game_state = initial_game_state() board, p1tiles, p2tiles = game_state player = 1 trajectory = [] # list of dicts: {"state": state_vec, "moves": moves, "pi": pi, "player": player} while True: moves = get_legal_moves(game_state, player) if not moves: # current player has no moves -> they lose winner = 2 if player == 1 else 1 break state_vec = encode_state(game_state, player) if random.random() < random_move_prob: print("PERFORMING RANDOM MOVE!!") move_list = moves num_m = len(move_list) # uniform policy over legal moves pi = torch.full((num_m,), 1.0 / num_m, dtype=torch.float32) idx = random.randrange(num_m) chosen_move = move_list[idx] else: # usual AlphaZero MCTS move pi, chosen_move, move_list = mcts_search( game_state, player, net, num_simulations=num_simulations, c_puct=1.5, temperature=temperature, ) if pi is None or chosen_move is None: winner = 2 if player == 1 else 1 break trajectory.append( { "state": state_vec, "moves": move_list, "pi": pi, "player": player, } ) # Apply chosen move to real game tidx, tile, placement = chosen_move gp = game.Player.P1 if player == 1 else game.Player.P2 board.place(tile, placement, gp) if player == 1: p1tiles.remove(tidx) else: p2tiles.remove(tidx) game_state = (board, p1tiles, p2tiles) if watch: print_game_state(game_state) player = 2 if player == 1 else 1 # Convert trajectory into training examples examples = [] for step in trajectory: state_vec = step["state"] moves = step["moves"] pi = step["pi"] p = step["player"] z = 1.0 if p == winner else -1.0 examples.append((state_vec, moves, pi, z)) return examples def alpha_zero_train_step(net: BlokuZeroNet, optimizer, batch): """ One gradient update on a batch of (state_vec, moves, pi_target, z_target). """ if not batch: return 0.0 total_value_loss = 0.0 total_policy_loss = 0.0 for state_vec, moves, pi_target, z in batch: state_vec = state_vec z_t = torch.tensor(z, dtype=torch.float32) # Encode moves move_feats = [] for tidx, perm, placement in moves: move_feats.append(encode_move(tidx, placement)) move_feats = torch.stack(move_feats, dim=0) state_emb, v_pred = net.forward_state(state_vec) p_pred, logits = net.forward_policy(state_emb, move_feats) # Value loss value_loss = (v_pred - z_t) ** 2 # Policy loss: cross-entropy between pi_target and p_pred pi_target = pi_target.to(p_pred.dtype) policy_loss = -torch.sum(pi_target * torch.log(p_pred + 1e-8)) total_value_loss = total_value_loss + value_loss total_policy_loss = total_policy_loss + policy_loss loss = (total_value_loss + total_policy_loss) / len(batch) optimizer.zero_grad() loss.backward() optimizer.step() return float(loss.item()) def train_alpha_zero( net: BlokuZeroNet, num_iterations: int = 50, games_per_iter: int = 5, num_simulations: int = 50, batch_size: int = 32, lr: float = 1e-3, save_path: str = "az_trained_agent.pt", watch_selfplay: bool = False, ): optimizer = torch.optim.Adam(net.parameters(), lr=lr) loss_history = [] ckpt_path = save_path + ".ckpt" start_iter = 1 # Try to resume from checkpoint if os.path.exists(ckpt_path): ckpt = torch.load(ckpt_path, map_location="cpu") try: net.load_state_dict(ckpt["model_state"]) optimizer.load_state_dict(ckpt["optimizer_state"]) start_iter = ckpt["iteration"] + 1 loss_history = ckpt.get("loss_history", []) print(f"Resuming training from iteration {start_iter} (found checkpoint).") if start_iter > num_iterations: print("Checkpoint exceeds requested num_iterations; nothing to train.") plot_losses(loss_history, out_path="loss_curve.png") torch.save(net.state_dict(), save_path) return except Exception as e: print("Checkpoint incompatible, starting fresh.") print("Reason:", e) pbar = trange( start_iter, num_iterations + 1, desc="AZ Training", dynamic_ncols=True ) # simple schedule for random moves in early training warmup_iters = 50 # how many iterations to use randomness max_random_prob = 0.5 # random move probability at iteration 1 for it in pbar: replay_buffer = [] # linearly decay random_move_prob from max_random_prob -> 0 over warmup_iters if it <= warmup_iters: random_move_prob = max_random_prob * (1.0 - (it - 1) / warmup_iters) else: random_move_prob = 0.0 # 1. Self-play games to generate fresh data for g in range(games_per_iter): examples = self_play_game( net, num_simulations=num_simulations, temperature=1.0, # can also anneal later if you like watch=watch_selfplay, random_move_prob=random_move_prob, ) replay_buffer.extend(examples) if not replay_buffer: print("No data generated this iteration, something is wrong.") continue random.shuffle(replay_buffer) # 2. Train on this batch of data batch_losses = [] for i in range(0, len(replay_buffer), batch_size): batch = replay_buffer[i : i + batch_size] loss = alpha_zero_train_step(net, optimizer, batch) batch_losses.append(loss) mean_loss = sum(batch_losses) / len(batch_losses) loss_history.append(mean_loss) pbar.set_postfix(loss=mean_loss, data=len(replay_buffer)) # 3. Save checkpoint if it % 5 == 0 or it == num_iterations: torch.save( { "iteration": it, "model_state": net.state_dict(), "optimizer_state": optimizer.state_dict(), "loss_history": loss_history, }, ckpt_path, ) plot_losses(loss_history, out_path="loss_curve_ckpt.png") torch.save(net.state_dict(), save_path) print(f"\nTraining finished. Model saved to {save_path}") plot_losses(loss_history, out_path="loss_curve.png") def load_net_from_checkpoint(path: str) -> BlokuZeroNet: """ Load an BlokuZeroNet from either: - a plain state_dict file (saved with torch.save(net.state_dict(), ...)), or - a training checkpoint (dict with "model_state" key). """ net = BlokuZeroNet() obj = torch.load(path, map_location="cpu") # If it's a training checkpoint dict, extract model_state if isinstance(obj, dict) and "model_state" in obj: state_dict = obj["model_state"] else: # Assume it's already a state_dict state_dict = obj net.load_state_dict(state_dict) net.eval() return net def battle( checkpoint_a: str, checkpoint_b: str, num_games: int = 20, watch: bool = False, num_simulations: int = 50, ): """ Load two AlphaZero checkpoints and have them battle for num_games. Alternates which net is player 1 for fairness. Prints a small win-loss matrix at the end. """ print(f"Loading net A from: {checkpoint_a}") print(f"Loading net B from: {checkpoint_b}") net_a = load_net_from_checkpoint(checkpoint_a) net_b = load_net_from_checkpoint(checkpoint_b) # Matrix counters # Rows = starting player (A as P1, B as P1) # Cols = winner (A, B, Draw) stats = { "A_P1": {"A": 0, "B": 0, "D": 0}, "B_P1": {"A": 0, "B": 0, "D": 0}, } for g in range(num_games): if g % 2 == 0: # Even games: A as P1, B as P2 start_label = "A_P1" winner = play_game_between_nets( net_a, net_b, watch=watch, num_simulations=num_simulations, ) if winner == 1: stats[start_label]["A"] += 1 elif winner == 2: stats[start_label]["B"] += 1 else: stats[start_label]["D"] += 1 else: # Odd games: B as P1, A as P2 start_label = "B_P1" winner = play_game_between_nets( net_b, net_a, watch=watch, num_simulations=num_simulations, ) if winner == 1: stats[start_label]["B"] += 1 # player 1 is B elif winner == 2: stats[start_label]["A"] += 1 # player 2 is A else: stats[start_label]["D"] += 1 print(f"Game {g + 1}/{num_games} finished: winner = {winner}") # Aggregate totals total_a_wins = stats["A_P1"]["A"] + stats["B_P1"]["A"] total_b_wins = stats["A_P1"]["B"] + stats["B_P1"]["B"] total_draws = stats["A_P1"]["D"] + stats["B_P1"]["D"] print("\n=== Battle results ===") print(f"Total games: {num_games}") print(f"Model A wins: {total_a_wins}") print(f"Model B wins: {total_b_wins}") print(f"Draws: {total_draws}") print("\nWin-loss matrix (rows = starting player, cols = winner):") print(" A_win B_win Draw") print( f"Start A (P1): {stats['A_P1']['A']:5d} {stats['A_P1']['B']:5d} {stats['A_P1']['D']:5d}" ) print( f"Start B (P1): {stats['B_P1']['A']:5d} {stats['B_P1']['B']:5d} {stats['B_P1']['D']:5d}" ) ################### # Play vs the AI # ################### def az_choose_move( net: BlokuZeroNet, game_state, player: int, num_simulations: int = 100 ): """ Use MCTS with the trained net to choose a move for actual play. """ moves = get_legal_moves(game_state, player) if not moves: return None # Temperature 0 for deterministic play pi, chosen_move, root_moves = mcts_search( game_state, player, net, num_simulations=num_simulations, c_puct=1.5, temperature=0.0, ) if chosen_move is None: return None return chosen_move def play_game_between_nets( net_p1: BlokuZeroNet, net_p2: BlokuZeroNet, watch: bool = False, max_turns: int = 500, num_simulations: int = 50, ) -> int: """ Play one game between two AlphaZero nets using MCTS for both. Returns: 1 if player 1 (net_p1) wins 2 if player 2 (net_p2) wins 0 if draw (max_turns reached) """ game_state = initial_game_state() board, p1tiles, p2tiles = game_state player = 1 turns = 0 while True: turns += 1 if turns > max_turns: # treat as draw return 0 # Choose which net is playing this turn net = net_p1 if player == 1 else net_p2 move = az_choose_move(net, game_state, player, num_simulations=num_simulations) if move is None: # current player cannot move -> they lose if player == 1: return 2 else: return 1 tidx, tile, placement = move gp = game.Player.P1 if player == 1 else game.Player.P2 board.place(tile, placement, gp) if player == 1: p1tiles.remove(tidx) else: p2tiles.remove(tidx) game_state = (board, p1tiles, p2tiles) if watch: print_game_state(game_state) player = 2 if player == 1 else 1 def play_vs_ai(net: BlokuZeroNet, human_is: int = 1, num_simulations: int = 100): """ Let a human play against the AlphaZero-style agent. human_is: 1 or 2 """ game_state = initial_game_state() board, p1tiles, p2tiles = game_state player = 1 while True: print_game_state(game_state) if player == human_is: # Human move moves = get_legal_moves(game_state, player) if not moves: print(f"No moves left for player {player}") print("You lost 😢") break print(f"Player {player}, you have {len(moves)} possible moves.") for i, (tidx, perm, plc) in enumerate(moves[:50]): print(f"[{i}] tile {tidx} at {plc}") if len(moves) > 50: print(f"... and {len(moves) - 50} more moves not listed") while True: try: choice = int(input("Enter move index: ")) if 0 <= choice < len(moves): move = moves[choice] break else: print("Invalid index, try again.") except ValueError: print("Please enter an integer.") else: # AI move print(f"AI (player {player}) is thinking...") move = az_choose_move( net, game_state, player, num_simulations=num_simulations ) if move is None: print(f"No moves left for AI player {player}") print("AI lost, you win! 🎉") break tidx, tile, placement = move gp = game.Player.P1 if player == 1 else game.Player.P2 print(f"player {player} places tile {tidx} at {placement}\n{tile}") board.place(tile, placement, gp) if player == 1: p1tiles.remove(tidx) else: p2tiles.remove(tidx) game_state = (board, p1tiles, p2tiles) player = 2 if player == 1 else 1 ############ # main # ############ def main(): # Battle mode: --battle ckptA ckptB [--games N] [--watch] if "--battle" in sys.argv: idx = sys.argv.index("--battle") try: ckpt_a = sys.argv[idx + 1] ckpt_b = sys.argv[idx + 2] except IndexError: print("Usage: blokus.py --battle ckptA ckptB [--games N] [--watch]") return num_games = 20 if "--games" in sys.argv: gidx = sys.argv.index("--games") try: num_games = int(sys.argv[gidx + 1]) except (IndexError, ValueError): print("Invalid or missing value for --games, using default 20.") watch = "--watch" in sys.argv battle(ckpt_a, ckpt_b, num_games=num_games, watch=watch) return net = BlokuZeroNet() if "--play" in sys.argv: model_path = "az_trained_agent.pt" if os.path.exists(model_path): print(f"Loading model from {model_path}") state = torch.load(model_path, map_location="cpu") try: net.load_state_dict(state) net.eval() except Exception as e: print("Saved model incompatible with current net, playing untrained.") print("Reason:", e) else: print( "Warning: az_trained_agent.pt not found. Playing with an untrained model." ) # By default, human is player 1; change to 2 if you want play_vs_ai(net, human_is=1, num_simulations=100) else: # AlphaZero-style training by self-play train_alpha_zero( net, num_iterations=250, # number of training iterations games_per_iter=5, # self-play games per iteration num_simulations=50, # MCTS simulations per move during self-play batch_size=32, lr=1e-3, save_path="az_trained_agent.pt", watch_selfplay="--watch" in sys.argv, ) if __name__ == "__main__": main()