From 2c93b4763053363d13d6c9ac1fac202f6c54927c Mon Sep 17 00:00:00 2001
From: Noa Aarts
Date: Fri, 5 Dec 2025 20:18:20 +0100
Subject: [PATCH] add RL to the game (I hope)

---
 .envrc    |   2 +-
 blokus.py | 462 +++++++++++++++++++++++++++++++++++++++++++++---------
 flake.nix |  11 +-
 3 files changed, 397 insertions(+), 78 deletions(-)

diff --git a/.envrc b/.envrc
index 3550a30..7a3598f 100644
--- a/.envrc
+++ b/.envrc
@@ -1 +1 @@
-use flake
+use flake . -L
diff --git a/blokus.py b/blokus.py
index 64b8189..3083cb9 100755
--- a/blokus.py
+++ b/blokus.py
@@ -1,9 +1,13 @@
 #!/usr/bin/env python

 import random
+import sys
+import os

 import game
-import numpy as np
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tqdm.auto import trange

 ###############
 # Utilities #
 ###############
@@ -41,88 +45,402 @@ def print_game_state(game_state: tuple[game.Board, list[int], list[int]]):
 # Game state init #
 ###################

-game_state = (
-    game.Board(),
-    [i for i in range(21)],
-    [i for i in range(21)],
-)

-###################
-# RL Utils #
-###################

+def initial_game_state():
+    return (
+        game.Board(),
+        [i for i in range(21)],
+        [i for i in range(21)],
+    )

-class Saver:
-    def __init__(self, results_path, experiment_seed):
-        self.stats_file = {"train": {}, "test": {}}
-        self.exp_seed = experiment_seed
-        self.rpath = results_path
-
-    def get_new_episode(self, mode, episode_no):
-        if mode == "train":
-            self.stats_file[mode][episode_no] = {
-                "loss": [],
-                "actions": [],
-                "errors": [],
-                "errors_noiseless": [],
-                "done_threshold": 0,
-                "bond_distance": 0,
-                "nfev": [],
-                "opt_ang": [],
-                "time": [],
-                "save_circ": [],
-                "reward": [],
-            }
-        elif mode == "test":
-            self.stats_file[mode][episode_no] = {
-                "actions": [],
-                "errors": [],
-                "errors_noiseless": [],
-                "done_threshold": 0,
-                "bond_distance": 0,
-                "nfev": [],
-                "opt_ang": [],
-                "time": [],
-            }
-
-    def save_file(self):
-        np.save(f"{self.rpath}/summary_{self.exp_seed}.npy", self.stats_file)
-
-    def validate_stats(self, episode, mode):
-        assert len(self.stats_file[mode][episode]["actions"]) == len(
-            self.stats_file[mode][episode]["errors"]
-        )

+############
+# Encoding #
+############

-playing = True
-player = 1
-while playing:
-    moves = []
-    assert player == 1 or player == 2
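+# A rough sketch of the feature layout built below, assuming BOARD_SIZE = 14
+# (which the 3*14*14 = 588 flatten in encode_state_and_move implies):
+#
+#   board  -> three one-hot 14x14 planes (P1, P2, other markers)
+#   tiles  -> 42-dim availability vector (21 tile types per player)
+#   move   -> 21-dim one-hot tile id plus (x, y) scaled into [0, 1]
+#   player -> one bit, set when player 1 is to move
+#
+# For example, tile 5 placed at (0, 13) contributes one-hot index 5 plus the
+# scaled coordinates (0.0, 1.0).
+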
+def encode_board(board: game.Board) -> torch.Tensor:
+    # board[(x, y)] returns 0 (empty), 1 (P1), 2 (P2), ... -- the same codes
+    # that print_game_state renders
+    arr = torch.zeros((3, BOARD_SIZE, BOARD_SIZE), dtype=torch.float32)
+    for y in range(BOARD_SIZE):
+        for x in range(BOARD_SIZE):
+            v = board[(x, y)]
+            if v == 1:
+                arr[0, y, x] = 1.0
+            elif v == 2:
+                arr[1, y, x] = 1.0
+            elif v == 3:  # any other marker (e.g. "S" cells) lands in channel 2
+                arr[2, y, x] = 1.0
+    return arr
+
+
+def encode_tiles(p1tiles, p2tiles) -> torch.Tensor:
+    # 21 tile types per player, so a 42-dim availability vector
+    v = torch.zeros(42, dtype=torch.float32)
+    for t in p1tiles:
+        v[t] = 1.0
+    offset = 21
+    for t in p2tiles:
+        v[offset + t] = 1.0
+    return v
+
+
+def encode_move(tile_idx: int, placement: tuple[int, int]) -> torch.Tensor:
+    (x, y) = placement
+    tile_vec = torch.zeros(21, dtype=torch.float32)
+    tile_vec[tile_idx] = 1.0
+    pos_vec = torch.tensor(
+        [x / (BOARD_SIZE - 1), y / (BOARD_SIZE - 1)], dtype=torch.float32
+    )
+    return torch.cat([tile_vec, pos_vec], dim=0)  # 23-dim
+
+
+def encode_state_and_move(game_state, player: int, tile_idx: int, placement):
+    board, p1tiles, p2tiles = game_state
+    board_tensor = encode_board(board).flatten()  # 3*14*14 = 588
+    tiles_tensor = encode_tiles(p1tiles, p2tiles)  # 42
+    move_tensor = encode_move(tile_idx, placement)  # 23
+
+    # Encode "current player" as a bit
+    player_tensor = torch.tensor([1.0 if player == 1 else 0.0], dtype=torch.float32)
+
+    return torch.cat([board_tensor, tiles_tensor, move_tensor, player_tensor], dim=0)
+    # Total size = 588 + 42 + 23 + 1 = 654
+
+
+###########
+# Model #
+###########
+
+FEATURE_SIZE = 654  # 588 + 42 + 23 + 1, see encode_state_and_move
+
+# Keep the model and every feature batch on a single device, so the GPU is
+# actually used when one is available
+DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
+class MoveScorer(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(FEATURE_SIZE, 256)
+        self.fc2 = nn.Linear(256, 128)
+        self.fc_out = nn.Linear(128, 1)  # scalar score
+
+    def forward(self, x):
+        # x: (batch_size, FEATURE_SIZE)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        return self.fc_out(x)  # (batch_size, 1)
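+
+
+# Shape sketch (illustrative only, not executed): scoring a batch of seven
+# candidate moves at once.
+#
+#   scorer = MoveScorer().to(DEVICE)
+#   X = torch.randn(7, FEATURE_SIZE, device=DEVICE)
+#   scores = scorer(X).squeeze(-1)  # shape (7,)
+#
+# One scalar per candidate move; the agents below either take the argmax of
+# the scores or sample from their softmax.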
+
+
+##################
+# Move generation
+##################
+
+
+def get_legal_moves(game_state, player: int):
+    board, p1tiles, p2tiles = game_state
     gp = game.Player.P1 if player == 1 else game.Player.P2
-    for tile_idx in game_state[player]:
+
+    tiles_left = p1tiles if player == 1 else p2tiles
+
+    moves = []
+    for tile_idx in tiles_left:
         tile = tiles[tile_idx]
         perms = tile.permutations()
         for perm in perms:
-            plcs = game_state[0].tile_placements(perm, gp)
+            plcs = board.tile_placements(perm, gp)
             moves.extend((tile_idx, perm, plc) for plc in plcs)
+    return moves

-    print(f"player {player} has {len(moves)} options")
-    if len(moves) == 0:
-        print(f"No moves left, player {player} lost")
-        playing = False
-        continue

+###########
+# Agents #
+###########
+
+
+class Agent:
+    def choose_move(self, game_state, player: int):
+        """Return (tile_idx, perm, placement) or None if no moves."""
+        raise NotImplementedError
+
+
+class RandomAgent(Agent):
+    def choose_move(self, game_state, player: int):
+        moves = get_legal_moves(game_state, player)
+        if not moves:
+            return None
+        return random.choice(moves)

-    (tidx, tile, placement) = random.choice(moves)
-    print(
-        f"player {player} is placing the following tile with index {tidx} at {placement}\n{tile}"
-    )
-    game_state[0].place(tile, placement, gp)
-    game_state[player].remove(tidx)
-    print_game_state(game_state)
-    if player == 1:
-        player = 2
-    elif player == 2:
-        player = 1

+class HumanAgent(Agent):
+    def choose_move(self, game_state, player: int):
+        moves = get_legal_moves(game_state, player)
+        if not moves:
+            print(f"No moves left for player {player}")
+            return None
+
+        print_game_state(game_state)
+        print(f"Player {player}, you have {len(moves)} possible moves.")
+
+        # Show at most the first 50 moves; tweak as needed
+        for i, (tidx, perm, plc) in enumerate(moves[:50]):
+            print(f"[{i}] tile {tidx} at {plc}")
+        if len(moves) > 50:
+            print(f"... and {len(moves) - 50} more moves not listed")
+
+        while True:
+            try:
+                choice = int(input("Enter move index: "))
+                if 0 <= choice < len(moves):
+                    return moves[choice]
+                else:
+                    print("Invalid index, try again.")
+            except ValueError:
+                print("Please enter an integer.")
+
+
+class MLAgent(Agent):
+    def __init__(
+        self, model: MoveScorer, deterministic: bool = True, epsilon: float = 0.0
+    ):
+        self.model = model
+        self.deterministic = deterministic
+        self.epsilon = epsilon
+
+    def choose_move(self, game_state, player: int):
+        moves = get_legal_moves(game_state, player)
+        if not moves:
+            return None
+
+        # Optional epsilon-greedy: use 0 for "serious play"
+        if self.epsilon > 0.0 and random.random() < self.epsilon:
+            return random.choice(moves)
+
+        # Build feature batch
+        features = []
+        for tidx, perm, placement in moves:
+            feat = encode_state_and_move(game_state, player, tidx, placement)
+            features.append(feat)
+        X = torch.stack(features, dim=0).to(DEVICE)
+
+        self.model.eval()
+        with torch.no_grad():
+            scores = self.model(X).squeeze(-1)  # (num_moves,)
+
+        if self.deterministic:
+            best_idx = torch.argmax(scores).item()
+            return moves[best_idx]
+        else:
+            # Sample from softmax for more variety
+            probs = torch.softmax(scores, dim=0)
+            idx = torch.multinomial(probs, num_samples=1).item()
+            return moves[idx]
+
+
+######################
+# Training utilities #
+######################
+
+
+def select_move_and_logprob(model: MoveScorer, game_state, player: int):
+    """
+    For training: sample a move from softmax over scores
+    and return (move, log_prob). If no moves, returns (None, None).
+    """
+    moves = get_legal_moves(game_state, player)
+    if not moves:
+        return None, None
+
+    features = []
+    for tidx, perm, placement in moves:
+        feat = encode_state_and_move(game_state, player, tidx, placement)
+        features.append(feat)
+    X = torch.stack(features, dim=0).to(DEVICE)  # (num_moves, FEATURE_SIZE)
+
+    scores = model(X).squeeze(-1)  # (num_moves,)
+    probs = F.softmax(scores, dim=0)
+
+    dist = torch.distributions.Categorical(probs)
+    idx = dist.sample()
+    log_prob = dist.log_prob(idx)
+
+    move = moves[idx.item()]
+    return move, log_prob
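+
+
+# Training note, a sketch of what train() below optimizes: this is plain
+# REINFORCE. With the game's return R in {-1, 0, +1} and the log-probability
+# log pi(a_t | s_t) of each move sampled above, every self-play episode
+# contributes
+#
+#     loss = -(R_1 * sum_t log pi_1(t)) - (R_2 * sum_t log pi_2(t))
+#
+# with no baseline and no discounting -- simple, but high-variance.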
+
+
+def play_self_play_game(model: MoveScorer, max_turns: int = 500):
+    """
+    Self-play game with the same model as both players.
+    Returns:
+        log_probs_p1, log_probs_p2, reward_p1, reward_p2
+    where rewards are +1/-1 for win/loss.
+    """
+    game_state = initial_game_state()
+    board, p1tiles, p2tiles = game_state
+
+    log_probs = {1: [], 2: []}
+    player = 1
+    turns = 0
+
+    while True:
+        turns += 1
+        if turns > max_turns:
+            # Safety: declare a draw
+            reward_p1 = 0.0
+            reward_p2 = 0.0
+            return log_probs[1], log_probs[2], reward_p1, reward_p2
+
+        move, log_prob = select_move_and_logprob(model, game_state, player)
+
+        if move is None:
+            # Current player cannot move -> they lose
+            if player == 1:
+                reward_p1 = -1.0
+                reward_p2 = +1.0
+            else:
+                reward_p1 = +1.0
+                reward_p2 = -1.0
+            return log_probs[1], log_probs[2], reward_p1, reward_p2
+
+        tidx, tile, placement = move
+        gp = game.Player.P1 if player == 1 else game.Player.P2
+
+        # Apply move
+        board.place(tile, placement, gp)
+        if player == 1:
+            p1tiles.remove(tidx)
+        else:
+            p2tiles.remove(tidx)
+
+        # Update game_state tuple
+        game_state = (board, p1tiles, p2tiles)
+
+        # Store log_prob
+        log_probs[player].append(log_prob)
+
+        # Switch player
+        player = 2 if player == 1 else 1
+
+
+def train(
+    model: MoveScorer,
+    num_episodes: int = 1000,
+    lr: float = 1e-3,
+    save_path: str = "trained_agent.pt",
+):
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+
+    pbar = trange(1, num_episodes + 1, desc="Training", dynamic_ncols=True)
+
+    for episode in pbar:
+        log_probs_p1, log_probs_p2, r1, r2 = play_self_play_game(model)
+
+        loss = torch.tensor(0.0, device=DEVICE)
+        if log_probs_p1:
+            loss = loss - r1 * torch.stack(log_probs_p1).sum()
+        if log_probs_p2:
+            loss = loss - r2 * torch.stack(log_probs_p2).sum()
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        # Update progress bar with most recent stats
+        pbar.set_postfix(
+            episode=episode,
+            loss=float(loss.item()),
+            r1=float(r1),
+            r2=float(r2),
+        )
+
+    torch.save(model.state_dict(), save_path)
+    print(f"\nTraining finished. Model saved to {save_path}")
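+
+
+# Expected entry points, given the argv handling in main() below:
+#
+#   ./blokus.py          # self-play REINFORCE training; saves trained_agent.pt
+#   ./blokus.py --play   # human vs. the model (loads trained_agent.pt if present)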
+
+
+###################
+# Play vs the AI #
+###################
+
+
+def play_vs_ai(model: MoveScorer, human_is: int = 1):
+    """
+    Let a human play against the trained model.
+    human_is: 1 or 2
+    """
+    game_state = initial_game_state()
+    board, p1tiles, p2tiles = game_state
+
+    human = HumanAgent()
+    ai = MLAgent(model, deterministic=True, epsilon=0.0)
+
+    agents = {
+        human_is: human,
+        1 if human_is == 2 else 2: ai,
+    }
+
+    player = 1
+    while True:
+        agent = agents[player]
+        move = agent.choose_move(game_state, player)
+
+        if move is None:
+            print(f"No moves left, player {player} lost")
+            if player == human_is:
+                print("You lost 😢")
+            else:
+                print("You won! 🎉")
+            break
+
+        tidx, tile, placement = move
+        gp = game.Player.P1 if player == 1 else game.Player.P2
+
+        print(f"player {player} places tile {tidx} at {placement}\n{tile}")
+
+        board.place(tile, placement, gp)
+        if player == 1:
+            p1tiles.remove(tidx)
+        else:
+            p2tiles.remove(tidx)
+
+        game_state = (board, p1tiles, p2tiles)
+        print_game_state(game_state)
+
+        player = 2 if player == 1 else 1
+
+
+############
+# main #
+############
+
+
+def main():
+    model = MoveScorer()
+
+    # Move the model to the GPU when one is available; the feature batches
+    # are moved to the same DEVICE where they are built
+    model.to(DEVICE)
+    if DEVICE.type == "cuda":
+        print("using CUDA")
+    else:
+        print("Not using CUDA")
+
+    if "--play" in sys.argv:
+        # Try to load trained weights if they exist
+        model_path = "trained_agent.pt"
+        if os.path.exists(model_path):
+            print(f"Loading model from {model_path}")
+            state = torch.load(model_path, map_location="cpu")
+            model.load_state_dict(state)
+            model.eval()
+        else:
+            print(
+                "Warning: trained_agent.pt not found. Playing with an untrained model."
+            )
+
+        # By default, human is player 1; change to 2 if you want
+        play_vs_ai(model, human_is=1)
+    else:
+        # Train by self-play
+        train(model, num_episodes=1000, lr=1e-3, save_path="trained_agent.pt")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/flake.nix b/flake.nix
index 3d4cc05..6cbb1b9 100644
--- a/flake.nix
+++ b/flake.nix
@@ -22,10 +22,10 @@
       pkgs = import inputs.nixpkgs {
         inherit system;
         overlays = [ inputs.rust-overlay.overlays.default ];
-        config = {
-          allowUnfree = true;
-          cudaSupport = true;
-        };
+        config = {
+          allowUnfree = true;
+          cudaSupport = true;
+        };
       };

       lib = pkgs.lib;
@@ -100,6 +100,7 @@
           packages = [
             (pkgs.python3.withPackages (ppkgs: [
               ppkgs.torch
+              ppkgs.tqdm
               (lib.python_package ppkgs)
             ]))
           ];
@@ -111,7 +112,7 @@
       # To use in other builds with the "withPackages" call
       python_package =
         ps:
-        ps.buildPythonPackage rec {
+        ps.buildPythonPackage {
           pname = project_name;
           format = "wheel";
           version = project_version;