Compare commits
No commits in common. "2b2c81e66ea01f2b838d36bf34d7b50d57713652" and "7b506195cc8b2250ec8a4729711f207178785c2f" have entirely different histories.
2b2c81e66e ... 7b506195cc
5 changed files with 84 additions and 510 deletions
.envrc (2 changes)

@@ -1 +1 @@
-use flake . -L --impure
+use flake
.gitignore (vendored, 4 changes)

@@ -12,7 +12,3 @@ wheels/
 
 result
 target
-
-*.ckpt
-trained_agent.pt
-loss_curve.png
blokus.py (551 changes)

@@ -1,14 +1,9 @@
 #!/usr/bin/env python
 import random
-import sys
-import os
 import game
 
+import numpy as np
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from tqdm.auto import trange
-import matplotlib.pyplot as plt
 
 ###############
 # Utilities #
@@ -27,509 +22,107 @@ def print_game_state(game_state: tuple[game.Board, list[int], list[int]]):
         for j in range(BOARD_SIZE):
             barr[i].append(board[(j, i)])
 
-    print(f" {'-' * BOARD_SIZE} ")
     for row in barr:
         print(
-            f"|{
-            ''.join(
+            "".join(
                 [
-                    ' ' if x == 0 else 'X' if x == 1 else 'O' if x == 2 else 'S'
+                    " " if x == 0 else "X" if x == 1 else "O" if x == 2 else "S"
                     for x in row
                 ]
             )
-            }|"
         )
-    print(f" {'-' * BOARD_SIZE} ")
 
+    print("")
     print(f"Player 1 tiles left: {p1tiles}")
     print(f"Player 2 tiles left: {p2tiles}")
 
 
-def plot_losses(loss_history, out_path="loss_curve.png"):
-    if not loss_history:
-        print("No losses to plot.")
-        return
-
-    plt.figure()
-    plt.plot(range(1, len(loss_history) + 1), loss_history)
-    plt.xlabel("Episode")
-    plt.ylabel("Loss")
-    plt.title("Training loss over episodes")
-    plt.tight_layout()
-    plt.savefig(out_path)
-    plt.close()
-    print(f"Saved loss plot to {out_path}")
-
-
 ###################
 # Game state init #
 ###################
 
-def initial_game_state():
-    return (
+game_state = (
     game.Board(),
     [i for i in range(21)],
     [i for i in range(21)],
+)
 
+###################
+# RL Utils #
+###################
 
 
+class Saver:
+    def __init__(self, results_path, experiment_seed):
+        self.stats_file = {"train": {}, "test": {}}
+        self.exp_seed = experiment_seed
+        self.rpath = results_path
+
+    def get_new_episode(self, mode, episode_no):
+        if mode == "train":
+            self.stats_file[mode][episode_no] = {
+                "loss": [],
+                "actions": [],
+                "errors": [],
+                "errors_noiseless": [],
+                "done_threshold": 0,
+                "bond_distance": 0,
+                "nfev": [],
+                "opt_ang": [],
+                "time": [],
+                "save_circ": [],
+                "reward": [],
+            }
+        elif mode == "test":
+            self.stats_file[mode][episode_no] = {
+                "actions": [],
+                "errors": [],
+                "errors_noiseless": [],
+                "done_threshold": 0,
+                "bond_distance": 0,
+                "nfev": [],
+                "opt_ang": [],
+                "time": [],
+            }
+
+    def save_file(self):
+        np.save(f"{self.rpath}/summary_{self.exp_seed}.npy", self.stats_file)
+
+    def validate_stats(self, episode, mode):
+        assert len(self.stats_file[mode][episode]["actions"]) == len(
+            self.stats_file[mode][episode]["errors"]
         )
 
 
-############
-# Encoding #
-############
+playing = True
+player = 1
+while playing:
 
 
-def encode_board(board: game.Board) -> torch.Tensor:
-    # board[(x, y)] returns 0,1,2,... according to your print function
-    arr = torch.zeros((3, BOARD_SIZE, BOARD_SIZE), dtype=torch.float32)
-    for y in range(BOARD_SIZE):
-        for x in range(BOARD_SIZE):
-            v = board[(x, y)]
-            if v == 1:
-                arr[0, y, x] = 1.0
-            elif v == 2:
-                arr[1, y, x] = 1.0
-            elif v == 3: # if "S" or something else
-                arr[2, y, x] = 1.0
-    return arr
-
-
-def encode_tiles(p1tiles, p2tiles) -> torch.Tensor:
-    # 21 tiles total, so 42-dim vector
-    v = torch.zeros(42, dtype=torch.float32)
-    for t in p1tiles:
-        v[t] = 1.0
-    offset = 21
-    for t in p2tiles:
-        v[offset + t] = 1.0
-    return v
-
-
-def encode_move(tile_idx: int, placement: tuple[int, int]) -> torch.Tensor:
-    (x, y) = placement
-    tile_vec = torch.zeros(21, dtype=torch.float32)
-    tile_vec[tile_idx] = 1.0
-    pos_vec = torch.tensor(
-        [x / (BOARD_SIZE - 1), y / (BOARD_SIZE - 1)], dtype=torch.float32
-    )
-    return torch.cat([tile_vec, pos_vec], dim=0) # 23-dim
-
-
-def encode_state_and_move(
-    game_state, player: int, tile_idx: int, placement: tuple[int, int], perm: game.Tile
-):
-    board, p1tiles, p2tiles = game_state
-
-    # Encode board BEFORE the move
-    board_before = encode_board(board).flatten()
-
-    # Encode board AFTER the move using sim_place
-    gp = game.Player.P1 if player == 1 else game.Player.P2
-    board_after_sim = board.sim_place(
-        perm, placement, gp
-    ) # <--- uses your new function
-    board_after = encode_board(board_after_sim).flatten()
-
-    tiles_tensor = encode_tiles(p1tiles, p2tiles)
-    move_tensor = encode_move(tile_idx, placement) # still tile+position encoding
-    player_tensor = torch.tensor([1.0 if player == 1 else 0.0], dtype=torch.float32)
-
-    return torch.cat(
-        [
-            board_before, # 588
-            board_after, # 588
-            tiles_tensor, # 42
-            move_tensor, # 23
-            player_tensor, # 1
-        ],
-        dim=0,
-    )
-
-
-###########
-# Model #
-###########
-
-FEATURE_SIZE = 1242 # from above
-
-
-class MoveScorer(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.fc1 = nn.Linear(FEATURE_SIZE, 256)
-        self.fc2 = nn.Linear(256, 128)
-        self.fc_out = nn.Linear(128, 1) # scalar score
-
-    def forward(self, x):
-        # x: (batch_size, FEATURE_SIZE)
-        x = F.relu(self.fc1(x))
-        x = F.relu(self.fc2(x))
-        return self.fc_out(x) # (batch_size, 1)
-
-
-##################
-# Move generation
-##################
-
-
-def get_legal_moves(game_state, player: int):
-    board, p1tiles, p2tiles = game_state
-    gp = game.Player.P1 if player == 1 else game.Player.P2
-
-    tiles_left = p1tiles if player == 1 else p2tiles
-
     moves = []
-    for tile_idx in tiles_left:
+    assert player == 1 or player == 2
+    gp = game.Player.P1 if player == 1 else game.Player.P2
+    for tile_idx in game_state[player]:
         tile = tiles[tile_idx]
         perms = tile.permutations()
         for perm in perms:
-            plcs = board.tile_placements(perm, gp)
+            plcs = game_state[0].tile_placements(perm, gp)
             moves.extend((tile_idx, perm, plc) for plc in plcs)
-    return moves
 
+    print(f"player {player} has {len(moves)} options")
 
-###########
-# Agents #
-###########
+    if len(moves) == 0:
 
 
-class Agent:
-    def choose_move(self, game_state, player: int):
-        """Return (tile_idx, perm, placement) or None if no moves."""
-        raise NotImplementedError
-
-
-class RandomAgent(Agent):
-    def choose_move(self, game_state, player: int):
-        moves = get_legal_moves(game_state, player)
-        if not moves:
-            return None
-        return random.choice(moves)
-
-
-class HumanAgent(Agent):
-    def choose_move(self, game_state, player: int):
-        moves = get_legal_moves(game_state, player)
-        if not moves:
-            print(f"No moves left for player {player}")
-            return None
-
-        print_game_state(game_state)
-        print(f"Player {player}, you have {len(moves)} possible moves.")
-
-        # Show a *subset* or all moves
-        for i, (tidx, perm, plc) in enumerate(moves):
-            if i < 50: # don't spam too hard; tweak as needed
-                print(f"[{i}] tile {tidx} at {plc}")
-            else:
-                break
-        if len(moves) > 50:
-            print(f"... and {len(moves) - 50} more moves not listed")
-
-        while True:
-            try:
-                choice = int(input("Enter move index: "))
-                if 0 <= choice < len(moves):
-                    return moves[choice]
-                else:
-                    print("Invalid index, try again.")
-            except ValueError:
-                print("Please enter an integer.")
-
-
-class MLAgent(Agent):
-    def __init__(
-        self, model: MoveScorer, deterministic: bool = True, epsilon: float = 0.0
-    ):
-        self.model = model
-        self.deterministic = deterministic
-        self.epsilon = epsilon
-
-    def choose_move(self, game_state, player: int):
-        moves = get_legal_moves(game_state, player)
-        if not moves:
-            return None
-
-        # Optional epsilon-greedy: use 0 for “serious play”
-        if self.epsilon > 0.0 and random.random() < self.epsilon:
-            return random.choice(moves)
-
-        # Build feature batch
-        features = []
-        for tidx, perm, placement in moves:
-            feat = encode_state_and_move(game_state, player, tidx, placement, perm)
-            features.append(feat)
-        X = torch.stack(features, dim=0)
-
-        self.model.eval()
-        with torch.no_grad():
-            scores = self.model(X).squeeze(-1) # (num_moves,)
-
-        if self.deterministic:
-            best_idx = torch.argmax(scores).item()
-            return moves[best_idx]
-        else:
-            # Sample from softmax for more variety
-            probs = torch.softmax(scores, dim=0)
-            idx = torch.multinomial(probs, num_samples=1).item()
-            return moves[idx]
-
-
-######################
-# Training utilities #
-######################
-
-
-def select_move_and_logprob(model: MoveScorer, game_state, player: int):
-    """
-    For training: sample a move from softmax over scores
-    and return (move, log_prob). If no moves, returns (None, None).
-    """
-    moves = get_legal_moves(game_state, player)
-    if not moves:
-        return None, None
-
-    features = []
-    for tidx, perm, placement in moves:
-        feat = encode_state_and_move(game_state, player, tidx, placement, perm)
-        features.append(feat)
-    X = torch.stack(features, dim=0) # (num_moves, FEATURE_SIZE)
-
-    scores = model(X).squeeze(-1) # (num_moves,)
-    probs = F.softmax(scores, dim=0)
-
-    dist = torch.distributions.Categorical(probs)
-    idx = dist.sample()
-    log_prob = dist.log_prob(idx)
-
-    move = moves[idx.item()]
-    return move, log_prob
-
-
-def play_self_play_game(model: MoveScorer, max_turns: int = 500, watch: bool = False):
-    """
-    Self-play game with the same model as both players.
-    Returns:
-        log_probs_p1, log_probs_p2, reward_p1, reward_p2
-    where rewards are +1/-1 for win/loss.
-    """
-    game_state = initial_game_state()
-    board, p1tiles, p2tiles = game_state
-
-    log_probs = {1: [], 2: []}
-    player = 1
-    turns = 0
-
-    while True:
-        turns += 1
-        if turns > max_turns:
-            # Safety: declare a draw
-            reward_p1 = 0.0
-            reward_p2 = 0.0
-            return log_probs[1], log_probs[2], reward_p1, reward_p2
-
-        move, log_prob = select_move_and_logprob(model, game_state, player)
-
-        if move is None:
-            # Current player cannot move -> they lose
-            if player == 1:
-                reward_p1 = -1.0
-                reward_p2 = +1.0
-            else:
-                reward_p1 = +1.0
-                reward_p2 = -1.0
-            return log_probs[1], log_probs[2], reward_p1, reward_p2
-
-        tidx, tile, placement = move
-        gp = game.Player.P1 if player == 1 else game.Player.P2
-
-        # Apply move
-        board.place(tile, placement, gp)
-        if player == 1:
-            p1tiles.remove(tidx)
-        else:
-            p2tiles.remove(tidx)
-
-        # Update game_state tuple
-        game_state = (board, p1tiles, p2tiles)
-        if watch:
-            print_game_state(game_state)
-
-        # Store log_prob
-        log_probs[player].append(log_prob)
-
-        # Switch player
-        player = 2 if player == 1 else 1
-
-
-def train(
-    model: MoveScorer,
-    num_episodes: int = 1000,
-    lr: float = 1e-3,
-    save_path: str = "trained_agent.pt",
-    watch: bool = False,
-):
-    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
-
-    # We'll keep a history of losses for plotting
-    loss_history = []
-
-    # Checkpoint path (partial training state)
-    ckpt_path = save_path + ".ckpt"
-
-    start_episode = 1
-
-    # Try to resume from checkpoint if it exists
-    if os.path.exists(ckpt_path):
-        ckpt = torch.load(ckpt_path, map_location="cpu")
-        model.load_state_dict(ckpt["model_state"])
-        optimizer.load_state_dict(ckpt["optimizer_state"])
-        start_episode = ckpt["episode"] + 1
-        loss_history = ckpt.get("loss_history", [])
-        print(f"Resuming training from episode {start_episode} (found checkpoint).")
-
-    # If we've already passed num_episodes, just plot and exit
-    if start_episode > num_episodes:
-        print(
-            "Checkpoint episode exceeds requested num_episodes; nothing to train."
-        )
-        plot_losses(loss_history, out_path="loss_curve.png")
-        torch.save(model.state_dict(), save_path)
-        return
-
-    pbar = trange(start_episode, num_episodes + 1, desc="Training", dynamic_ncols=True)
-
-    for episode in pbar:
-        log_probs_p1, log_probs_p2, r1, r2 = play_self_play_game(model, watch=watch)
-
-        loss = torch.tensor(0.0)
-        if log_probs_p1:
-            loss = loss - r1 * torch.stack(log_probs_p1).sum()
-        if log_probs_p2:
-            loss = loss - r2 * torch.stack(log_probs_p2).sum()
-
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-
-        loss_value = float(loss.item())
-        loss_history.append(loss_value)
-
-        # Update progress bar with most recent stats
-        pbar.set_postfix(
-            loss=loss_value,
-        )
-
-        # Save checkpoint every N episodes (and at the very end)
-        if episode % 50 == 0 or episode == num_episodes:
-            torch.save(
-                {
-                    "episode": episode,
-                    "model_state": model.state_dict(),
-                    "optimizer_state": optimizer.state_dict(),
-                    "loss_history": loss_history,
-                },
-                ckpt_path,
-            )
-
-    # Final model save
-    torch.save(model.state_dict(), save_path)
-    print(f"\nTraining finished. Model saved to {save_path}")
-
-    # Save final loss plot
-    plot_losses(loss_history, out_path="loss_curve.png")
-
-
-###################
-# Play vs the AI #
-###################
-
-
-def play_vs_ai(model: MoveScorer, human_is: int = 1):
-    """
-    Let a human play against the trained model.
-    human_is: 1 or 2
-    """
-    game_state = initial_game_state()
-    board, p1tiles, p2tiles = game_state
-
-    human = HumanAgent()
-    ai = MLAgent(model, deterministic=True, epsilon=0.0)
-
-    agents = {
-        human_is: human,
-        1 if human_is == 2 else 2: ai,
-    }
-
-    player = 1
-    while True:
-        agent = agents[player]
-        move = agent.choose_move(game_state, player)
-
-        if move is None:
print(f"No moves left, player {player} lost")
|
print(f"No moves left, player {player} lost")
|
||||||
if player == human_is:
|
playing = False
|
||||||
print("You lost 😢")
|
continue
|
||||||
else:
|
|
||||||
print("You won! 🎉")
|
|
||||||
break
|
|
||||||
|
|
||||||
tidx, tile, placement = move
|
(tidx, tile, placement) = random.choice(moves)
|
||||||
gp = game.Player.P1 if player == 1 else game.Player.P2
|
print(
|
||||||
|
f"player {player} is placing the following tile with index {tidx} at {placement}\n{tile}"
|
||||||
print(f"player {player} places tile {tidx} at {placement}\n{tile}")
|
)
|
||||||
|
game_state[0].place(tile, placement, gp)
|
||||||
board.place(tile, placement, gp)
|
game_state[player].remove(tidx)
|
||||||
if player == 1:
|
|
||||||
p1tiles.remove(tidx)
|
|
||||||
else:
|
|
||||||
p2tiles.remove(tidx)
|
|
||||||
|
|
||||||
game_state = (board, p1tiles, p2tiles)
|
|
||||||
print_game_state(game_state)
|
print_game_state(game_state)
|
||||||
|
|
||||||
player = 2 if player == 1 else 1
|
if player == 1:
|
||||||
|
player = 2
|
||||||
|
elif player == 2:
|
||||||
############
|
player = 1
|
||||||
# main #
|
|
||||||
############
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
model = MoveScorer()
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
print("using CUDA")
|
|
||||||
torch.device("cuda:0")
|
|
||||||
else:
|
|
||||||
print("Not using CUDA")
|
|
||||||
|
|
||||||
if "--play" in sys.argv:
|
|
||||||
# Try to load trained weights if they exist
|
|
||||||
model_path = "trained_agent.pt"
|
|
||||||
if os.path.exists(model_path):
|
|
||||||
print(f"Loading model from {model_path}")
|
|
||||||
state = torch.load(model_path, map_location="cpu")
|
|
||||||
model.load_state_dict(state)
|
|
||||||
model.eval()
|
|
||||||
else:
|
|
||||||
print(
|
|
||||||
"Warning: trained_agent.pt not found. Playing with an untrained model."
|
|
||||||
)
|
|
||||||
|
|
||||||
# By default, human is player 1; change to 2 if you want
|
|
||||||
play_vs_ai(model, human_is=1)
|
|
||||||
else:
|
|
||||||
# Train by self-play
|
|
||||||
train(
|
|
||||||
model,
|
|
||||||
num_episodes=1000,
|
|
||||||
lr=1e-3,
|
|
||||||
save_path="trained_agent.pt",
|
|
||||||
watch="--watch" in sys.argv,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
|
||||||
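Note on the removed training code: the deleted train() loop applies a REINFORCE-style update, scaling each player's summed log-probabilities by that player's final reward. The sketch below restates that update on stand-in tensors; it is illustrative only, and the dummy logits stand in for the MoveScorer outputs collected during self-play.

    # Sketch only (not part of the diff): the policy-gradient step used by the
    # deleted train() function, shown on dummy data.
    import torch

    logits = torch.randn(8, requires_grad=True)         # stand-in for move scores
    dist = torch.distributions.Categorical(logits=logits)
    log_probs_p1 = [dist.log_prob(dist.sample()) for _ in range(3)]
    log_probs_p2 = [dist.log_prob(dist.sample()) for _ in range(3)]
    r1, r2 = 1.0, -1.0                                   # +1 win / -1 loss, as in the deleted code

    loss = torch.tensor(0.0)
    loss = loss - r1 * torch.stack(log_probs_p1).sum()   # raise probability of the winner's moves
    loss = loss - r2 * torch.stack(log_probs_p2).sum()   # lower probability of the loser's moves
    loss.backward()                                      # gradients reach the scorer through the logits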
flake.nix (11 changes)

@@ -22,10 +22,6 @@
       pkgs = import inputs.nixpkgs {
         inherit system;
         overlays = [ inputs.rust-overlay.overlays.default ];
-        config = {
-          allowUnfree = true;
-          cudaSupport = true;
-        };
       };
       lib = pkgs.lib;
 
@@ -100,20 +96,15 @@
           packages = [
             (pkgs.python3.withPackages (ppkgs: [
               ppkgs.torch
-              ppkgs.tqdm
-              ppkgs.matplotlib
               (lib.python_package ppkgs)
             ]))
           ];
-          shellHook = ''
-            export CUDA_PATH=${pkgs.cudatoolkit}
-          '';
         };
         lib = {
           # To use in other builds with the "withPackages" call
           python_package =
             ps:
-            ps.buildPythonPackage {
+            ps.buildPythonPackage rec {
              pname = project_name;
              format = "wheel";
              version = project_version;
(fifth changed file, Rust source; file name not captured in this view)

@@ -230,12 +230,6 @@ mod game {
         })
     }
 
-    fn sim_place(&self, tile: Tile, pos: (usize, usize), player: Player) -> Self {
-        let mut other = self.clone();
-        other.place(tile, pos, player);
-        other
-    }
-
     fn place(&mut self, tile: Tile, pos: (usize, usize), player: Player) {
         let (x, y) = pos;
         for &(i, j) in tile.parts.iter() {