add tile permutations to ai knowledge
This commit is contained in:
parent
7483e6060a
commit
e2d5ca264b
1 changed file with 28 additions and 10 deletions
38
blokus.py
38
blokus.py
|
|
@@ -98,24 +98,42 @@ def encode_move(tile_idx: int, placement: tuple[int, int]) -> torch.Tensor:
|
|||
return torch.cat([tile_vec, pos_vec], dim=0) # 23-dim
|
||||
|
||||
|
||||
def encode_state_and_move(game_state, player: int, tile_idx: int, placement):
|
||||
def encode_state_and_move(
|
||||
game_state, player: int, tile_idx: int, placement: tuple[int, int], perm: game.Tile
|
||||
):
|
||||
board, p1tiles, p2tiles = game_state
|
||||
board_tensor = encode_board(board).flatten() # 3*14*14 = 588
|
||||
tiles_tensor = encode_tiles(p1tiles, p2tiles) # 42
|
||||
move_tensor = encode_move(tile_idx, placement) # 23
|
||||
|
||||
# Encode "current player" as a bit
|
||||
# Encode board BEFORE the move
|
||||
board_before = encode_board(board).flatten()
|
||||
|
||||
# Encode board AFTER the move using sim_place
|
||||
gp = game.Player.P1 if player == 1 else game.Player.P2
|
||||
board_after_sim = board.sim_place(
|
||||
perm, placement, gp
|
||||
) # <--- uses your new function
|
||||
board_after = encode_board(board_after_sim).flatten()
|
||||
|
||||
tiles_tensor = encode_tiles(p1tiles, p2tiles)
|
||||
move_tensor = encode_move(tile_idx, placement) # still tile+position encoding
|
||||
player_tensor = torch.tensor([1.0 if player == 1 else 0.0], dtype=torch.float32)
|
||||
|
||||
return torch.cat([board_tensor, tiles_tensor, move_tensor, player_tensor], dim=0)
|
||||
# Total size = 588 + 42 + 23 + 1 = 654
|
||||
return torch.cat(
|
||||
[
|
||||
board_before, # 588
|
||||
board_after, # 588
|
||||
tiles_tensor, # 42
|
||||
move_tensor, # 23
|
||||
player_tensor, # 1
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
||||
###########
|
||||
# Model #
|
||||
###########
|
||||
|
||||
FEATURE_SIZE = 654 # from above
|
||||
FEATURE_SIZE = 1242 # from above
|
||||
|
||||
|
||||
class MoveScorer(nn.Module):
|
||||
|
|
@@ -222,7 +240,7 @@ class MLAgent(Agent):
|
|||
# Build feature batch
|
||||
features = []
|
||||
for tidx, perm, placement in moves:
|
||||
feat = encode_state_and_move(game_state, player, tidx, placement)
|
||||
feat = encode_state_and_move(game_state, player, tidx, placement, perm)
|
||||
features.append(feat)
|
||||
X = torch.stack(features, dim=0)
|
||||
|
||||
|
|
@@ -256,7 +274,7 @@ def select_move_and_logprob(model: MoveScorer, game_state, player: int):
|
|||
|
||||
features = []
|
||||
for tidx, perm, placement in moves:
|
||||
feat = encode_state_and_move(game_state, player, tidx, placement)
|
||||
feat = encode_state_and_move(game_state, player, tidx, placement, perm)
|
||||
features.append(feat)
|
||||
X = torch.stack(features, dim=0) # (num_moves, FEATURE_SIZE)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue