"dumb" RL
parent d73dba80cd
commit 44e30869f8

1 changed file with 307 additions and 35 deletions:

blokus.py | 342 (+307, -35)
--- a/blokus.py
+++ b/blokus.py
@@ -1,12 +1,20 @@
 #!/usr/bin/env python
 
 import numpy as np
-import random
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
 
 BOARD_SIZE = 14
 
+
+# =======================
+# Game setup and rules
+# =======================
+
 
 def make_board():
-    a = np.array([[0 for i in range(BOARD_SIZE)] for j in range(BOARD_SIZE)])
+    a = np.array([[0 for _ in range(BOARD_SIZE)] for _ in range(BOARD_SIZE)])
     a[4, 4] = -1
     a[9, 9] = -1
     return a
@@ -38,14 +46,16 @@ tiles = [
 
 
 def get_permutations(which_tiles: list[int]):
+    """
+    For each tile index in which_tiles, generate all unique rotations/flips.
+    Returns a list of (tile_index, oriented_tile).
+    """
     permutations = []
-    for i, tile in enumerate(tiles):
-        if i not in which_tiles:
-            continue
+    for tidx in which_tiles:
+        tile = tiles[tidx]
 
         rots = [np.rot90(tile, k) for k in range(4)]
-        flips = [np.flip(r, axis=1) for r in rots]  # flip horizontally
+        flips = [np.flip(r, axis=1) for r in rots]  # horizontal flips
         all_orients = rots + flips  # 8 orientations
 
         seen = set()
 
@@ -53,12 +63,12 @@ def get_permutations(which_tiles: list[int]):
             key = (t.shape, t.tobytes())
             if key not in seen:
                 seen.add(key)
-                permutations.append((i, t))
+                permutations.append((tidx, t))
 
     return permutations
 
 
-def can_place(board, tile, player):
+def can_place(board: np.ndarray, tile: np.ndarray, player: int):
     placements = []
     has_minus_one = False
     for x in range(BOARD_SIZE):
@@ -102,35 +112,37 @@ def can_place(board, tile, player):
                 if (
                     x + i + 1 < BOARD_SIZE
                     and y + j + 1 < BOARD_SIZE
-                    and board[x + i + 1][y + j + 1] == player
+                    and board[x + i + 1, y + j + 1] == player
                 ):
                     final.append((x, y))
                     break
                 if (
                     x + i + 1 < BOARD_SIZE
                     and y + j - 1 >= 0
-                    and board[x + i + 1][y + j - 1] == player
+                    and board[x + i + 1, y + j - 1] == player
                 ):
                     final.append((x, y))
                     break
                 if (
                     x + i - 1 >= 0
                     and y + j + 1 < BOARD_SIZE
-                    and board[x + i - 1][y + j + 1] == player
+                    and board[x + i - 1, y + j + 1] == player
                 ):
                     final.append((x, y))
                     break
                 if (
                     x + i - 1 >= 0
                     and y + j - 1 >= 0
-                    and board[x + i - 1][y + j - 1] == player
+                    and board[x + i - 1, y + j - 1] == player
                 ):
                     final.append((x, y))
                     break
     return final
 
 
-def do_placement(tidx, tile, placement, game_state, player):
+def do_placement(
+    tidx: int, tile: np.ndarray, placement: tuple[int, int], game_state, player: int
+):
     (x, y) = placement
     with np.nditer(tile, flags=["multi_index"]) as it:
         for v in it:
@@ -156,35 +168,295 @@ def print_game_state(game_state):
     print("")
     print(f"Player 1 tiles left: {p1tiles}")
     print(f"Player 2 tiles left: {p2tiles}")
+    print("")
 
 
-game_state = (
-    make_board(),
-    [i for i in range(21)],
-    [i for i in range(21)],
-)
+def reset_game():
+    board = make_board()
+    p1tiles = [i for i in range(21)]
+    p2tiles = [i for i in range(21)]
+    return [board, p1tiles, p2tiles]  # list so it's mutable in-place
 
 
-playing = True
-player = 1
-while playing:
+# =======================
+# RL: encoding & policy
+# =======================
+
+
+def encode_board(board: np.ndarray, player: int) -> torch.Tensor:
+    """
+    Channels:
+      0: current player's stones
+      1: opponent's stones
+      2: starting squares (-1)
+    """
+    me = (board == player).astype(np.float32)
+    opp = ((board > 0) & (board != player)).astype(np.float32)
+    start = (board == -1).astype(np.float32)
+    state = np.stack([me, opp, start], axis=0)  # (3, 14, 14)
+    return torch.from_numpy(state)
+
+
+def encode_move(
+    tidx: int, tile: np.ndarray, placement: tuple[int, int]
+) -> torch.Tensor:
+    x, y = placement
+    area = int(tile.sum())
+    return torch.tensor([tidx, x, y, area], dtype=torch.float32)
+
+
+class PolicyNet(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(3, 32, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(32, 64, kernel_size=3, padding=1),
+            nn.ReLU(),
+        )
+        conv_out_dim = 64 * BOARD_SIZE * BOARD_SIZE  # 64 * 14 * 14
+
+        self.fc = nn.Sequential(
+            nn.Linear(conv_out_dim + 4, 256),
+            nn.ReLU(),
+            nn.Linear(256, 1),  # scalar logit
+        )
+
+    def forward(
+        self, board_tensor: torch.Tensor, move_features: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        board_tensor: (3, 14, 14)
+        move_features: (N, 4)
+        returns: logits (N,)
+        """
+        x = self.conv(board_tensor.unsqueeze(0))  # (1, 64, 14, 14)
+        x = x.view(1, -1)  # (1, conv_out_dim)
+        x = x.repeat(move_features.size(0), 1)  # (N, conv_out_dim)
+
+        combined = torch.cat([x, move_features], dim=1)  # (N, conv_out_dim + 4)
+        logits = self.fc(combined).squeeze(-1)  # (N,)
+        return logits
+
+
+# =======================
+# RL: move generation & action selection
+# =======================
+
+
+def get_all_moves(game_state, player: int):
+    board, p1tiles, p2tiles = game_state
+    available_tiles = p1tiles if player == 1 else p2tiles
+
     moves = []
-    for tidx, tile in get_permutations(game_state[player]):
-        for placement in can_place(game_state[0], tile, player):
+    for tidx, tile in get_permutations(available_tiles):
+        for placement in can_place(board, tile, player):
             moves.append((tidx, tile, placement))
 
-    print_game_state(game_state)
-    print(f"player {player} has {len(moves)} options")
+    return moves
+
+
+def select_action(policy: PolicyNet, game_state, player: int, device="cpu"):
+    board, _, _ = game_state
+    moves = get_all_moves(game_state, player)
 
     if len(moves) == 0:
-        print(f"No moves left, player {player} lost")
-        playing = False
-        continue
+        return None, None  # no legal moves
 
-    (tidx, tile, placement) = random.choice(moves)
-    do_placement(tidx, tile, placement, game_state, player)
+    board_tensor = encode_board(board, player).to(device)
 
-    if player == 1:
-        player = 2
-    elif player == 2:
-        player = 1
+    move_feats = torch.stack(
+        [encode_move(tidx, tile, placement) for (tidx, tile, placement) in moves], dim=0
+    ).to(device)
+
+    logits = policy(board_tensor, move_feats)  # (N,)
+    probs = F.softmax(logits, dim=0)
+    dist = torch.distributions.Categorical(probs)
+    idx = dist.sample()
+    log_prob = dist.log_prob(idx)
+
+    chosen_move = moves[idx.item()]
+    return chosen_move, log_prob
+
+
+# =======================
+# RL: self-play episode
+# =======================
+
+
+def play_episode(policy1: PolicyNet, policy2: PolicyNet, optim1, optim2, device="cpu"):
+    policy1.train()
+    policy2.train()
+    game_state = reset_game()
+    player = 1
+
+    log_probs1 = []
+    log_probs2 = []
+
+    while True:
+        if player == 1:
+            move, log_prob = select_action(policy1, game_state, player, device)
+        else:
+            move, log_prob = select_action(policy2, game_state, player, device)
+
+        # No move → this player loses
+        if move is None:
+            loser = player
+            winner = 2 if player == 1 else 1
+            break
+
+        tidx, tile, placement = move
+
+        if player == 1:
+            log_probs1.append(log_prob)
+        else:
+            log_probs2.append(log_prob)
+
+        do_placement(tidx, tile, placement, game_state, player)
+
+        player = 2 if player == 1 else 1
+
+    print_game_state(game_state)
+    print(f"Player {winner} is the winner")
+    # Rewards: +1 for win, -1 for loss (from each player's perspective)
+    r1 = 1.0 if winner == 1 else -1.0
+    r2 = -r1
+
+    if log_probs1:
+        loss1 = -torch.stack(log_probs1).sum() * r1
+        optim1.zero_grad()
+        loss1.backward()
+        optim1.step()
+
+    if log_probs2:
+        loss2 = -torch.stack(log_probs2).sum() * r2
+        optim2.zero_grad()
+        loss2.backward()
+        optim2.step()
+
+    return r1  # from Player 1's perspective
+
+
+# =======================
+# Evaluation: watch them play
+# =======================
+
+
+def play_game(policy1: PolicyNet, policy2: PolicyNet, device="cpu"):
+    policy1.eval()
+    policy2.eval()
+    game_state = reset_game()
+    player = 1
+    while True:
+        print_game_state(game_state)
+
+        if player == 1:
+            move, _ = select_action(policy1, game_state, player, device)
+        else:
+            move, _ = select_action(policy2, game_state, player, device)
+
+        if move is None:
+            print(f"No moves left, player {player} lost")
+            break
+
+        tidx, tile, placement = move
+        do_placement(tidx, tile, placement, game_state, player)
+
+        player = 2 if player == 1 else 1
+
+
+def load_policy(path, device="cpu"):
+    policy = PolicyNet().to(device)
+    policy.load_state_dict(torch.load(path, map_location=device))
+    policy.eval()
+    return policy
+
+
+def human_vs_ai(ai_policy: PolicyNet, device="cpu"):
+    ai_policy.eval()
+    game_state = reset_game()
+    player = 1  # AI goes first
+
+    while True:
+        print_game_state(game_state)
+
+        # Who moves?
+        if player == 1:
+            print("AI thinking...")
+            move, _ = select_action(ai_policy, game_state, player, device)
+            if move is None:
+                print("AI has no moves — AI loses!")
+                break
+            tidx, tile, placement = move
+            print(f"AI plays tile {tidx} at {placement}\n")
+        else:
+            # human turn
+            moves = get_all_moves(game_state, player)
+            if not moves:
+                print("You have no moves — you lose!")
+                break
+
+            print("Your legal moves:")
+            for i, (tidx, tile, placement) in enumerate(moves):
+                print(f"{i}: tile {tidx} at {placement}")
+
+            choice = int(input("Choose move number: "))
+            tidx, tile, placement = moves[choice]
+
+        # Apply move
+        do_placement(tidx, tile, placement, game_state, player)
+
+        # Switch players
+        player = 2 if player == 1 else 1
+
+
+# =======================
+# Main training loop
+# =======================
+
+
+def main():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"Using device: {device}")
+
+    policy1 = PolicyNet().to(device)
+    policy2 = PolicyNet().to(device)
+
+    optim1 = optim.Adam(policy1.parameters(), lr=1e-3)
+    optim2 = optim.Adam(policy2.parameters(), lr=1e-3)
+
+    best_avg_reward = -999
+    reward_history = []
+
+    num_episodes = 2000
+    for episode in range(1, num_episodes + 1):
+        reward = play_episode(policy1, policy2, optim1, optim2, device=device)
+        reward_history.append(reward)
+
+        # compute moving average every 50 episodes
+        if len(reward_history) >= 50:
+            avg = sum(reward_history[-50:]) / 50
+
+            # If policy1 improved, save it
+            if avg > best_avg_reward:
+                best_avg_reward = avg
+                torch.save(policy1.state_dict(), "best_policy1.pth")
+                print(f"Saved best policy1 at episode {episode} (avg reward={avg:.3f})")
+
+        if episode % 100 == 0:
+            print(f"Episode {episode}, last reward={reward}")
+
+    print("Training complete.")
+    print("1 = Watch AI vs AI")
+    print("2 = Play against AI")
+    print("3 = Quit")
+
+    choice = input("Select: ")
+
+    if choice == "1":
+        play_game(policy1, policy2, device)
+    elif choice == "2":
+        best_ai = load_policy("best_policy1.pth", device)
+        human_vs_ai(best_ai, device)
+
+
+if __name__ == "__main__":
+    main()