#!/usr/bin/env python
import random
import sys
import os
import game
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import trange
###############
# Utilities #
###############
BOARD_SIZE = 14
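# Piece shapes provided by the game module; indices 0-20 are what the
# per-player tile lists below refer to.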
tiles = game.game_tiles()
def print_game_state(game_state: tuple[game.Board, list[int], list[int]]):
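    """Render the board as ASCII (X = player 1, O = player 2, S = any other
    non-empty cell) followed by each player's remaining tile indices.

    The board is indexed as board[(x, y)]; each printed row fixes y and
    sweeps x.
    """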
    (board, p1tiles, p2tiles) = game_state
    barr = []
    for i in range(BOARD_SIZE):
        barr.append([])
        for j in range(BOARD_SIZE):
            barr[i].append(board[(j, i)])
    print(f" {'-' * BOARD_SIZE} ")
    for row in barr:
        row_str = "".join(
            " " if x == 0 else "X" if x == 1 else "O" if x == 2 else "S" for x in row
        )
        print(f"|{row_str}|")
    print(f" {'-' * BOARD_SIZE} ")
    print(f"Player 1 tiles left: {p1tiles}")
    print(f"Player 2 tiles left: {p2tiles}")
###################
# Game state init #
###################
def initial_game_state():
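    """A fresh board plus the full set of 21 tile indices (0-20) for each player."""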
    return (
        game.Board(),
        [i for i in range(21)],
        [i for i in range(21)],
    )
############
# Encoding #
############
def encode_board(board: game.Board) -> torch.Tensor:
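    """One-hot encode the board into a (3, BOARD_SIZE, BOARD_SIZE) tensor.

    Channel 0 marks player 1's cells, channel 1 player 2's, and channel 2
    cells with value 3 (rendered as 'S' by print_game_state).
    """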
    # board[(x, y)] is 0 for empty, 1 for P1, 2 for P2, or 3 for anything else
    # (shown as "S" by print_game_state)
    arr = torch.zeros((3, BOARD_SIZE, BOARD_SIZE), dtype=torch.float32)
    for y in range(BOARD_SIZE):
        for x in range(BOARD_SIZE):
            v = board[(x, y)]
            if v == 1:
                arr[0, y, x] = 1.0
            elif v == 2:
                arr[1, y, x] = 1.0
            elif v == 3:  # any other marker goes in channel 2
                arr[2, y, x] = 1.0
    return arr
def encode_tiles(p1tiles, p2tiles) -> torch.Tensor:
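    """Multi-hot encode the remaining tiles: entries 0-20 for player 1,
    entries 21-41 for player 2."""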
    # 21 tiles per player, so a 42-dim vector
    v = torch.zeros(42, dtype=torch.float32)
    for t in p1tiles:
        v[t] = 1.0
    offset = 21
    for t in p2tiles:
        v[offset + t] = 1.0
    return v
def encode_move(tile_idx: int, placement: tuple[int, int]) -> torch.Tensor:
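    """Encode a move as a 21-dim one-hot tile id plus the (x, y) placement
    normalized to [0, 1], giving a 23-dim vector."""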
    (x, y) = placement
    tile_vec = torch.zeros(21, dtype=torch.float32)
    tile_vec[tile_idx] = 1.0
    pos_vec = torch.tensor(
        [x / (BOARD_SIZE - 1), y / (BOARD_SIZE - 1)], dtype=torch.float32
    )
    return torch.cat([tile_vec, pos_vec], dim=0)  # 23-dim
def encode_state_and_move(
    game_state, player: int, tile_idx: int, placement: tuple[int, int], perm: game.Tile
):
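    """Build the full feature vector for one candidate move.

    Layout (sizes in parentheses): board before the move (588), board after
    the simulated move (588), remaining tiles for both players (42), tile +
    position encoding (23), and a current-player flag (1): 1242 features total.
    """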
    board, p1tiles, p2tiles = game_state
    # Encode board BEFORE the move
    board_before = encode_board(board).flatten()
    # Encode board AFTER the move using sim_place
    gp = game.Player.P1 if player == 1 else game.Player.P2
    board_after_sim = board.sim_place(
        perm, placement, gp
    )  # simulated placement; the real board is left untouched
    board_after = encode_board(board_after_sim).flatten()
    tiles_tensor = encode_tiles(p1tiles, p2tiles)
    move_tensor = encode_move(tile_idx, placement)  # still tile+position encoding
    player_tensor = torch.tensor([1.0 if player == 1 else 0.0], dtype=torch.float32)
    return torch.cat(
        [
            board_before,  # 588
            board_after,  # 588
            tiles_tensor,  # 42
            move_tensor,  # 23
            player_tensor,  # 1
        ],
        dim=0,
    )
###########
# Model #
###########
FEATURE_SIZE = 1242  # 588 (board before) + 588 (board after) + 42 (tiles) + 23 (move) + 1 (player)
class MoveScorer(nn.Module):
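    """Small MLP that maps one (state, move) feature vector to a scalar score.

    The agents score every legal move independently and then either take the
    argmax or sample from a softmax over the scores.
    """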
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(FEATURE_SIZE, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc_out = nn.Linear(128, 1)  # scalar score

    def forward(self, x):
        # x: (batch_size, FEATURE_SIZE)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc_out(x)  # (batch_size, 1)
##################
# Move generation
##################
def get_legal_moves(game_state, player: int):
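    """Enumerate every legal (tile_idx, permutation, placement) triple for the
    given player's remaining tiles."""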
    board, p1tiles, p2tiles = game_state
    gp = game.Player.P1 if player == 1 else game.Player.P2
    tiles_left = p1tiles if player == 1 else p2tiles
    moves = []
    for tile_idx in tiles_left:
        tile = tiles[tile_idx]
        perms = tile.permutations()
        for perm in perms:
            plcs = board.tile_placements(perm, gp)
            moves.extend((tile_idx, perm, plc) for plc in plcs)
    return moves
###########
# Agents #
###########
class Agent:
    def choose_move(self, game_state, player: int):
        """Return (tile_idx, perm, placement) or None if no moves."""
        raise NotImplementedError
class RandomAgent(Agent):
    def choose_move(self, game_state, player: int):
        moves = get_legal_moves(game_state, player)
        if not moves:
            return None
        return random.choice(moves)
class HumanAgent(Agent):
    def choose_move(self, game_state, player: int):
        moves = get_legal_moves(game_state, player)
        if not moves:
            print(f"No moves left for player {player}")
            return None
        print_game_state(game_state)
        print(f"Player {player}, you have {len(moves)} possible moves.")
        # Show a *subset* or all moves
        for i, (tidx, perm, plc) in enumerate(moves):
            if i < 50:  # don't spam too hard; tweak as needed
                print(f"[{i}] tile {tidx} at {plc}")
            else:
                break
        if len(moves) > 50:
            print(f"... and {len(moves) - 50} more moves not listed")
        while True:
            try:
                choice = int(input("Enter move index: "))
                if 0 <= choice < len(moves):
                    return moves[choice]
                else:
                    print("Invalid index, try again.")
            except ValueError:
                print("Please enter an integer.")
class MLAgent(Agent):
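    """Agent that scores every legal move with a MoveScorer and either plays
    the highest-scoring one (deterministic=True), samples from a softmax over
    the scores, or, with probability epsilon, plays a uniformly random move."""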
    def __init__(
        self, model: MoveScorer, deterministic: bool = True, epsilon: float = 0.0
    ):
        self.model = model
        self.deterministic = deterministic
        self.epsilon = epsilon

    def choose_move(self, game_state, player: int):
        moves = get_legal_moves(game_state, player)
        if not moves:
            return None
        # Optional epsilon-greedy: use 0 for "serious play"
        if self.epsilon > 0.0 and random.random() < self.epsilon:
            return random.choice(moves)
        # Build feature batch
        features = []
        for tidx, perm, placement in moves:
            feat = encode_state_and_move(game_state, player, tidx, placement, perm)
            features.append(feat)
        X = torch.stack(features, dim=0)
        self.model.eval()
        with torch.no_grad():
            scores = self.model(X).squeeze(-1)  # (num_moves,)
        if self.deterministic:
            best_idx = torch.argmax(scores).item()
            return moves[best_idx]
        else:
            # Sample from softmax for more variety
            probs = torch.softmax(scores, dim=0)
            idx = torch.multinomial(probs, num_samples=1).item()
            return moves[idx]
######################
# Training utilities #
######################
def select_move_and_logprob(model: MoveScorer, game_state, player: int):
"""
For training: sample a move from softmax over scores
and return (move, log_prob). If no moves, returns (None, None).
"""
moves = get_legal_moves(game_state, player)
if not moves:
return None, None
features = []
for tidx, perm, placement in moves:
feat = encode_state_and_move(game_state, player, tidx, placement, perm)
features.append(feat)
X = torch.stack(features, dim=0) # (num_moves, FEATURE_SIZE)
scores = model(X).squeeze(-1) # (num_moves,)
probs = F.softmax(scores, dim=0)
dist = torch.distributions.Categorical(probs)
idx = dist.sample()
log_prob = dist.log_prob(idx)
move = moves[idx.item()]
return move, log_prob
def play_self_play_game(model: MoveScorer, max_turns: int = 500, watch: bool = False):
"""
Self-play game with the same model as both players.
Returns:
log_probs_p1, log_probs_p2, reward_p1, reward_p2
where rewards are +1/-1 for win/loss.
"""
game_state = initial_game_state()
board, p1tiles, p2tiles = game_state
log_probs = {1: [], 2: []}
player = 1
turns = 0
while True:
turns += 1
if turns > max_turns:
# Safety: declare a draw
reward_p1 = 0.0
reward_p2 = 0.0
return log_probs[1], log_probs[2], reward_p1, reward_p2
move, log_prob = select_move_and_logprob(model, game_state, player)
if move is None:
# Current player cannot move -> they lose
if player == 1:
reward_p1 = -1.0
reward_p2 = +1.0
else:
reward_p1 = +1.0
reward_p2 = -1.0
return log_probs[1], log_probs[2], reward_p1, reward_p2
tidx, tile, placement = move
gp = game.Player.P1 if player == 1 else game.Player.P2
# Apply move
board.place(tile, placement, gp)
if player == 1:
p1tiles.remove(tidx)
else:
p2tiles.remove(tidx)
# Update game_state tuple
game_state = (board, p1tiles, p2tiles)
if watch:
print_game_state(game_state)
# Store log_prob
log_probs[player].append(log_prob)
# Switch player
player = 2 if player == 1 else 1
def train(
    model: MoveScorer,
    num_episodes: int = 1000,
    lr: float = 1e-3,
    save_path: str = "trained_agent.pt",
    watch: bool = False,
):
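    """Self-play training with a REINFORCE-style update.

    Each episode plays one full game with the same model as both players, then
    takes a single gradient step on
        loss = -(r1 * sum(log-probs of player 1's moves)
                 + r2 * sum(log-probs of player 2's moves)),
    where r1/r2 are the terminal rewards (+1/-1, or 0/0 for a draw).
    No baseline or discounting is used.
    """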
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    pbar = trange(1, num_episodes + 1, desc="Training", dynamic_ncols=True)
    for episode in pbar:
        log_probs_p1, log_probs_p2, r1, r2 = play_self_play_game(model, watch=watch)
        loss = torch.tensor(0.0)
        if log_probs_p1:
            loss = loss - r1 * torch.stack(log_probs_p1).sum()
        if log_probs_p2:
            loss = loss - r2 * torch.stack(log_probs_p2).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Update progress bar with most recent stats
        pbar.set_postfix(
            episode=episode,
            loss=float(loss.item()),
            r1=float(r1),
            r2=float(r2),
        )
    torch.save(model.state_dict(), save_path)
    print(f"\nTraining finished. Model saved to {save_path}")
###################
# Play vs the AI #
###################
def play_vs_ai(model: MoveScorer, human_is: int = 1):
"""
Let a human play against the trained model.
human_is: 1 or 2
"""
game_state = initial_game_state()
board, p1tiles, p2tiles = game_state
human = HumanAgent()
ai = MLAgent(model, deterministic=True, epsilon=0.0)
agents = {
human_is: human,
1 if human_is == 2 else 2: ai,
}
player = 1
while True:
agent = agents[player]
move = agent.choose_move(game_state, player)
if move is None:
print(f"No moves left, player {player} lost")
if player == human_is:
print("You lost 😢")
else:
print("You won! 🎉")
break
tidx, tile, placement = move
gp = game.Player.P1 if player == 1 else game.Player.P2
print(f"player {player} places tile {tidx} at {placement}\n{tile}")
board.place(tile, placement, gp)
if player == 1:
p1tiles.remove(tidx)
else:
p2tiles.remove(tidx)
game_state = (board, p1tiles, p2tiles)
print_game_state(game_state)
player = 2 if player == 1 else 1
############
# main #
############
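# Usage:
#   python blokus.py            # train by self-play, saving trained_agent.pt
#   python blokus.py --watch    # as above, printing the board after every move
#   python blokus.py --play     # play against the saved (or untrained) model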
def main():
    model = MoveScorer()
    # Note: the model and all feature tensors in this script live on the CPU;
    # we only report whether CUDA happens to be available.
    if torch.cuda.is_available():
        print("CUDA is available (not used; everything runs on the CPU)")
    else:
        print("Not using CUDA")
    if "--play" in sys.argv:
        # Try to load trained weights if they exist
        model_path = "trained_agent.pt"
        if os.path.exists(model_path):
            print(f"Loading model from {model_path}")
            state = torch.load(model_path, map_location="cpu")
            model.load_state_dict(state)
            model.eval()
        else:
            print(
                "Warning: trained_agent.pt not found. Playing with an untrained model."
            )
        # By default, human is player 1; change to 2 if you want
        play_vs_ai(model, human_is=1)
    else:
        # Train by self-play
        train(
            model,
            num_episodes=1000,
            lr=1e-3,
            save_path="trained_agent.pt",
            watch="--watch" in sys.argv,
        )
if __name__ == "__main__":
    main()