#!/usr/bin/env python
import random
import sys
import os
import game

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import trange
import matplotlib.pyplot as plt


###############
#  Utilities  #
###############

BOARD_SIZE = 14
tiles = game.game_tiles()


def print_game_state(game_state: tuple[game.Board, list[int], list[int]]):
    (board, p1tiles, p2tiles) = game_state

    barr = []
    for i in range(BOARD_SIZE):
        barr.append([])
        for j in range(BOARD_SIZE):
            barr[i].append(board[(j, i)])

    print(f" {'-' * BOARD_SIZE} ")
    for row in barr:
        row_str = "".join(
            " " if x == 0 else "X" if x == 1 else "O" if x == 2 else "S" for x in row
        )
        print(f"|{row_str}|")
    print(f" {'-' * BOARD_SIZE} ")

    print(f"Player 1 tiles left: {p1tiles}")
    print(f"Player 2 tiles left: {p2tiles}")


def plot_losses(loss_history, out_path="loss_curve.png"):
    if not loss_history:
        print("No losses to plot.")
        return

    plt.figure()
    plt.plot(range(1, len(loss_history) + 1), loss_history)
    plt.xlabel("Episode")
    plt.ylabel("Loss")
    plt.title("Training loss over episodes")
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()
    print(f"Saved loss plot to {out_path}")
###################
# Game state init #
###################


def initial_game_state():
    return (
        game.Board(),
        list(range(21)),
        list(range(21)),
    )


############
# Encoding #
############


def encode_board(board: game.Board) -> torch.Tensor:
    # board[(x, y)] follows the convention used in print_game_state:
    # 0 = empty, 1 = player 1, 2 = player 2, 3 = the squares rendered as "S".
    arr = torch.zeros((3, BOARD_SIZE, BOARD_SIZE), dtype=torch.float32)
    for y in range(BOARD_SIZE):
        for x in range(BOARD_SIZE):
            v = board[(x, y)]
            if v == 1:
                arr[0, y, x] = 1.0
            elif v == 2:
                arr[1, y, x] = 1.0
            elif v == 3:  # rendered as "S" by print_game_state
                arr[2, y, x] = 1.0
    return arr


def encode_tiles(p1tiles, p2tiles) -> torch.Tensor:
    # 21 tiles per player, so a 42-dim multi-hot vector
    v = torch.zeros(42, dtype=torch.float32)
    for t in p1tiles:
        v[t] = 1.0
    offset = 21
    for t in p2tiles:
        v[offset + t] = 1.0
    return v


def encode_move(tile_idx: int, placement: tuple[int, int]) -> torch.Tensor:
    (x, y) = placement
    tile_vec = torch.zeros(21, dtype=torch.float32)
    tile_vec[tile_idx] = 1.0
    pos_vec = torch.tensor(
        [x / (BOARD_SIZE - 1), y / (BOARD_SIZE - 1)], dtype=torch.float32
    )
    return torch.cat([tile_vec, pos_vec], dim=0)  # 23-dim


def encode_state_and_move(
    game_state, player: int, tile_idx: int, placement: tuple[int, int], perm: game.Tile
):
    board, p1tiles, p2tiles = game_state

    # Encode the board BEFORE the move
    board_before = encode_board(board).flatten()

    # Encode the board AFTER the move, simulated with sim_place
    gp = game.Player.P1 if player == 1 else game.Player.P2
    board_after_sim = board.sim_place(perm, placement, gp)
    board_after = encode_board(board_after_sim).flatten()

    tiles_tensor = encode_tiles(p1tiles, p2tiles)
    move_tensor = encode_move(tile_idx, placement)  # tile + position encoding
    player_tensor = torch.tensor([1.0 if player == 1 else 0.0], dtype=torch.float32)

    return torch.cat(
        [
            board_before,  # 588
            board_after,  # 588
            tiles_tensor,  # 42
            move_tensor,  # 23
            player_tensor,  # 1
        ],
        dim=0,
    )
###########
#  Model  #
###########

FEATURE_SIZE = 1242  # = 588 + 588 + 42 + 23 + 1, see encode_state_and_move
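# Sanity check: the feature layout in encode_state_and_move is two flattened
# 3 x BOARD_SIZE x BOARD_SIZE board encodings, 42 tile bits, a 23-dim move
# encoding, and a 1-dim player flag.
assert FEATURE_SIZE == 2 * 3 * BOARD_SIZE * BOARD_SIZE + 42 + 23 + 1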
class MoveScorer(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(FEATURE_SIZE, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc_out = nn.Linear(128, 1)  # scalar score

    def forward(self, x):
        # x: (batch_size, FEATURE_SIZE)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc_out(x)  # (batch_size, 1)
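# Minimal usage sketch (illustrative only; nothing else in this file calls it):
# score a batch of candidate-move feature vectors and pick the argmax, which is
# what MLAgent.choose_move does with real encodings.
def _movescorer_demo(num_moves: int = 8) -> int:
    model = MoveScorer()
    model.eval()
    # Random stand-ins for encode_state_and_move outputs.
    fake_features = torch.rand(num_moves, FEATURE_SIZE)
    with torch.no_grad():
        scores = model(fake_features).squeeze(-1)  # (num_moves,)
    return int(torch.argmax(scores).item())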
###################
# Move generation #
###################


def get_legal_moves(game_state, player: int):
    board, p1tiles, p2tiles = game_state
    gp = game.Player.P1 if player == 1 else game.Player.P2

    tiles_left = p1tiles if player == 1 else p2tiles

    moves = []
    for tile_idx in tiles_left:
        tile = tiles[tile_idx]
        perms = tile.permutations()
        for perm in perms:
            plcs = board.tile_placements(perm, gp)
            moves.extend((tile_idx, perm, plc) for plc in plcs)
    return moves
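# Illustrative helper (not used by the rest of the file): each legal move is a
# (tile_idx, perm, placement) tuple, so the branching factor per tile can be
# summarised like this.
def count_moves_per_tile(game_state, player: int) -> dict[int, int]:
    counts: dict[int, int] = {}
    for tile_idx, _perm, _placement in get_legal_moves(game_state, player):
        counts[tile_idx] = counts.get(tile_idx, 0) + 1
    return counts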
###########
# Agents  #
###########


class Agent:
    def choose_move(self, game_state, player: int):
        """Return (tile_idx, perm, placement) or None if no moves."""
        raise NotImplementedError


class RandomAgent(Agent):
    def choose_move(self, game_state, player: int):
        moves = get_legal_moves(game_state, player)
        if not moves:
            return None
        return random.choice(moves)


class HumanAgent(Agent):
    def choose_move(self, game_state, player: int):
        moves = get_legal_moves(game_state, player)
        if not moves:
            print(f"No moves left for player {player}")
            return None

        print_game_state(game_state)
        print(f"Player {player}, you have {len(moves)} possible moves.")

        # Show at most the first 50 moves so the terminal isn't flooded; tweak as needed
        for i, (tidx, perm, plc) in enumerate(moves[:50]):
            print(f"[{i}] tile {tidx} at {plc}")
        if len(moves) > 50:
            print(f"... and {len(moves) - 50} more moves not listed")

        while True:
            try:
                choice = int(input("Enter move index: "))
                if 0 <= choice < len(moves):
                    return moves[choice]
                else:
                    print("Invalid index, try again.")
            except ValueError:
                print("Please enter an integer.")
class MLAgent(Agent):
    def __init__(
        self, model: MoveScorer, deterministic: bool = True, epsilon: float = 0.0
    ):
        self.model = model
        self.deterministic = deterministic
        self.epsilon = epsilon

    def choose_move(self, game_state, player: int):
        moves = get_legal_moves(game_state, player)
        if not moves:
            return None

        # Optional epsilon-greedy exploration: use 0 for "serious play"
        if self.epsilon > 0.0 and random.random() < self.epsilon:
            return random.choice(moves)

        # Build a feature batch, one row per legal move
        features = []
        for tidx, perm, placement in moves:
            feat = encode_state_and_move(game_state, player, tidx, placement, perm)
            features.append(feat)
        X = torch.stack(features, dim=0)

        self.model.eval()
        with torch.no_grad():
            scores = self.model(X).squeeze(-1)  # (num_moves,)

        if self.deterministic:
            best_idx = torch.argmax(scores).item()
            return moves[best_idx]
        else:
            # Sample from a softmax over scores for more variety
            probs = torch.softmax(scores, dim=0)
            idx = torch.multinomial(probs, num_samples=1).item()
            return moves[idx]
######################
# Training utilities #
######################


def select_move_and_logprob(model: MoveScorer, game_state, player: int):
    """
    For training: sample a move from a softmax over scores
    and return (move, log_prob). If no moves, returns (None, None).
    """
    moves = get_legal_moves(game_state, player)
    if not moves:
        return None, None

    features = []
    for tidx, perm, placement in moves:
        feat = encode_state_and_move(game_state, player, tidx, placement, perm)
        features.append(feat)
    X = torch.stack(features, dim=0)  # (num_moves, FEATURE_SIZE)

    scores = model(X).squeeze(-1)  # (num_moves,)
    probs = F.softmax(scores, dim=0)

    dist = torch.distributions.Categorical(probs)
    idx = dist.sample()
    log_prob = dist.log_prob(idx)

    move = moves[idx.item()]
    return move, log_prob
def play_self_play_game(model: MoveScorer, max_turns: int = 500, watch: bool = False):
    """
    Self-play game with the same model as both players.
    Returns:
        log_probs_p1, log_probs_p2, reward_p1, reward_p2
    where rewards are +1/-1 for win/loss, or 0/0 if max_turns is reached
    and the game is declared a draw.
    """
    game_state = initial_game_state()
    board, p1tiles, p2tiles = game_state

    log_probs = {1: [], 2: []}
    player = 1
    turns = 0

    while True:
        turns += 1
        if turns > max_turns:
            # Safety: declare a draw
            reward_p1 = 0.0
            reward_p2 = 0.0
            return log_probs[1], log_probs[2], reward_p1, reward_p2

        move, log_prob = select_move_and_logprob(model, game_state, player)

        if move is None:
            # Current player cannot move -> they lose
            if player == 1:
                reward_p1 = -1.0
                reward_p2 = +1.0
            else:
                reward_p1 = +1.0
                reward_p2 = -1.0
            return log_probs[1], log_probs[2], reward_p1, reward_p2

        tidx, tile, placement = move
        gp = game.Player.P1 if player == 1 else game.Player.P2

        # Apply the move
        board.place(tile, placement, gp)
        if player == 1:
            p1tiles.remove(tidx)
        else:
            p2tiles.remove(tidx)

        # Update the game_state tuple
        game_state = (board, p1tiles, p2tiles)
        if watch:
            print_game_state(game_state)

        # Store the log-prob for the policy-gradient update in train()
        log_probs[player].append(log_prob)

        # Switch player
        player = 2 if player == 1 else 1
def train(
    model: MoveScorer,
    num_episodes: int = 1000,
    lr: float = 1e-3,
    save_path: str = "trained_agent.pt",
    watch: bool = False,
):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Keep a history of losses for plotting
    loss_history = []

    # Checkpoint path (partial training state)
    ckpt_path = save_path + ".ckpt"

    start_episode = 1

    # Try to resume from a checkpoint if one exists
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt["model_state"])
        optimizer.load_state_dict(ckpt["optimizer_state"])
        start_episode = ckpt["episode"] + 1
        loss_history = ckpt.get("loss_history", [])
        print(f"Resuming training from episode {start_episode} (found checkpoint).")

        # If we've already passed num_episodes, just plot and exit
        if start_episode > num_episodes:
            print(
                "Checkpoint episode exceeds requested num_episodes; nothing to train."
            )
            plot_losses(loss_history, out_path="loss_curve.png")
            torch.save(model.state_dict(), save_path)
            return

    pbar = trange(start_episode, num_episodes + 1, desc="Training", dynamic_ncols=True)

    for episode in pbar:
        log_probs_p1, log_probs_p2, r1, r2 = play_self_play_game(model, watch=watch)

        # REINFORCE-style loss: push up the log-probs of the winner's moves and
        # push down the loser's (rewards are +1 / -1, or 0 on a draw).
        loss = torch.tensor(0.0)
        if log_probs_p1:
            loss = loss - r1 * torch.stack(log_probs_p1).sum()
        if log_probs_p2:
            loss = loss - r2 * torch.stack(log_probs_p2).sum()

        optimizer.zero_grad()
        # If no move was recorded for either player, loss is a constant with no
        # gradient, so skip the update for that episode.
        if loss.requires_grad:
            loss.backward()
            optimizer.step()

        loss_value = float(loss.item())
        loss_history.append(loss_value)

        # Update the progress bar with the most recent loss
        pbar.set_postfix(loss=loss_value)

        # Save a checkpoint every 50 episodes (and at the very end)
        if episode % 50 == 0 or episode == num_episodes:
            torch.save(
                {
                    "episode": episode,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "loss_history": loss_history,
                },
                ckpt_path,
            )

    # Final model save
    torch.save(model.state_dict(), save_path)
    print(f"\nTraining finished. Model saved to {save_path}")

    # Save the final loss plot
    plot_losses(loss_history, out_path="loss_curve.png")
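# Evaluation sketch (an optional helper; nothing else in this file calls it):
# pit an MLAgent against a RandomAgent and return the ML agent's win rate.
# As in self-play, a player with no legal move loses; hitting max_turns counts
# as no win for the ML agent.
def evaluate_vs_random(
    model: MoveScorer, num_games: int = 10, ml_plays: int = 1, max_turns: int = 500
) -> float:
    wins = 0
    for _ in range(num_games):
        board, p1tiles, p2tiles = initial_game_state()
        game_state = (board, p1tiles, p2tiles)
        agents = {
            ml_plays: MLAgent(model, deterministic=True, epsilon=0.0),
            2 if ml_plays == 1 else 1: RandomAgent(),
        }
        player = 1
        for _turn in range(max_turns):
            move = agents[player].choose_move(game_state, player)
            if move is None:
                if player != ml_plays:
                    wins += 1
                break
            tidx, tile, placement = move
            gp = game.Player.P1 if player == 1 else game.Player.P2
            board.place(tile, placement, gp)
            (p1tiles if player == 1 else p2tiles).remove(tidx)
            game_state = (board, p1tiles, p2tiles)
            player = 2 if player == 1 else 1
    return wins / num_games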
###################
# Play vs the AI  #
###################


def play_vs_ai(model: MoveScorer, human_is: int = 1):
    """
    Let a human play against the trained model.
    human_is: 1 or 2
    """
    game_state = initial_game_state()
    board, p1tiles, p2tiles = game_state

    human = HumanAgent()
    ai = MLAgent(model, deterministic=True, epsilon=0.0)

    agents = {
        human_is: human,
        1 if human_is == 2 else 2: ai,
    }

    player = 1
    while True:
        agent = agents[player]
        move = agent.choose_move(game_state, player)

        if move is None:
            print(f"No moves left, player {player} lost")
            if player == human_is:
                print("You lost 😢")
            else:
                print("You won! 🎉")
            break

        tidx, tile, placement = move
        gp = game.Player.P1 if player == 1 else game.Player.P2

        print(f"player {player} places tile {tidx} at {placement}\n{tile}")

        board.place(tile, placement, gp)
        if player == 1:
            p1tiles.remove(tidx)
        else:
            p2tiles.remove(tidx)

        game_state = (board, p1tiles, p2tiles)
        print_game_state(game_state)

        player = 2 if player == 1 else 1
############
#   main   #
############


def main():
    model = MoveScorer()

    # The model and all feature tensors are created on (and stay on) the CPU,
    # so CUDA availability is only reported here, not used.
    if torch.cuda.is_available():
        print("CUDA is available, but this script runs on the CPU")
    else:
        print("Not using CUDA")
    if "--play" in sys.argv:
        # Try to load trained weights if they exist
        model_path = "trained_agent.pt"
        if os.path.exists(model_path):
            print(f"Loading model from {model_path}")
            state = torch.load(model_path, map_location="cpu")
            model.load_state_dict(state)
            model.eval()
        else:
            print(
                "Warning: trained_agent.pt not found. Playing with an untrained model."
            )

        # By default, human is player 1; change to 2 if you want
        play_vs_ai(model, human_is=1)
    else:
        # Train by self-play
        train(
            model,
            num_episodes=1000,
            lr=1e-3,
            save_path="trained_agent.pt",
            watch="--watch" in sys.argv,
        )


if __name__ == "__main__":
    main()