128 lines
3.2 KiB
Python
Executable file
128 lines
3.2 KiB
Python
Executable file
#!/usr/bin/env python
|
|
import random
|
|
import game
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
###############
|
|
# Utilities #
|
|
###############
|
|
|
|
# Side length of the square playing board; print_game_state renders a
# BOARD_SIZE x BOARD_SIZE grid.
BOARD_SIZE: int = 14

# Tile catalogue from the game module. The game loop indexes it with the
# integers stored in each player's remaining-tile list (tiles[tile_idx]).
tiles = game.game_tiles()
|
|
|
|
|
|
def print_game_state(game_state: tuple[game.Board, list[int], list[int]]):
    """Print the board as ASCII art, then each player's remaining tiles.

    Args:
        game_state: ``(board, player-1 tile indices, player-2 tile indices)``.
            The board is indexed with ``(x, y)`` tuples; each printed line is
            one ``y`` value, with ``x`` running left to right.

    Cell values are rendered as: 0 -> " ", 1 -> "X", 2 -> "O", anything
    else -> "S". NOTE(review): presumably 0 is empty and 1/2 are the two
    players' squares; the meaning of the "S" fallback is not visible here —
    confirm against the game module.
    """
    board, remaining_p1, remaining_p2 = game_state

    def cell_char(value):
        # Same comparison order as the rendering contract above.
        if value == 0:
            return " "
        if value == 1:
            return "X"
        if value == 2:
            return "O"
        return "S"

    # Read the whole grid before printing anything; outer list is rows (y),
    # inner list is columns (x).
    grid = [
        [board[(col, row)] for col in range(BOARD_SIZE)]
        for row in range(BOARD_SIZE)
    ]

    for cells in grid:
        print("".join([cell_char(value) for value in cells]))

    print("")
    print(f"Player 1 tiles left: {remaining_p1}")
    print(f"Player 2 tiles left: {remaining_p2}")
|
|
|
|
|
|
###################
|
|
# Game state init #
|
|
###################
|
|
|
|
# Initial game state: (board, player 1's remaining tile indices,
# player 2's remaining tile indices). Tuple positions 1 and 2 match the
# player numbers, which the game loop relies on via game_state[player].
# Each player starts with tile indices 0..20. The two lists must be separate
# objects because the game loop removes placed indices from them in place.
game_state = (
    game.Board(),
    list(range(21)),
    list(range(21)),
)
|
|
|
|
###################
|
|
# RL Utils #
|
|
###################
|
|
|
|
|
|
class Saver:
    """Collect per-episode statistics for train/test runs and save them with NumPy.

    Statistics are kept in ``self.stats_file``, a nested dict of the form
    ``{mode: {episode_no: {stat_name: value}}}`` where ``mode`` is
    ``"train"`` or ``"test"``.
    """

    def __init__(self, results_path, experiment_seed):
        """Create an empty statistics store.

        Args:
            results_path: directory that ``save_file`` writes into.
            experiment_seed: value embedded in the saved file's name.
        """
        # {mode: {episode_no: {stat_name: value}}}
        self.stats_file = {"train": {}, "test": {}}
        # Within this class, only used to name the summary file.
        self.exp_seed = experiment_seed
        # Output directory for save_file().
        self.rpath = results_path

    def get_new_episode(self, mode, episode_no):
        """Initialise (or reset) the statistics record for one episode.

        List-valued fields start empty and scalar fields start at 0. A fresh
        dict with fresh lists is built on every call, so episodes never share
        state.

        Args:
            mode: ``"train"`` or ``"test"``. Train episodes additionally
                track ``loss``, ``save_circ`` and ``reward``.
            episode_no: key under which the record is stored in
                ``self.stats_file[mode]``. An existing record is overwritten.

        Raises:
            ValueError: if ``mode`` is neither ``"train"`` nor ``"test"``,
                rather than silently creating no record.
        """
        if mode == "train":
            self.stats_file[mode][episode_no] = {
                "loss": [],
                "actions": [],
                "errors": [],
                "errors_noiseless": [],
                "done_threshold": 0,
                "bond_distance": 0,
                "nfev": [],
                "opt_ang": [],
                "time": [],
                "save_circ": [],
                "reward": [],
            }
        elif mode == "test":
            self.stats_file[mode][episode_no] = {
                "actions": [],
                "errors": [],
                "errors_noiseless": [],
                "done_threshold": 0,
                "bond_distance": 0,
                "nfev": [],
                "opt_ang": [],
                "time": [],
            }
        else:
            raise ValueError(f"unknown mode {mode!r}; expected 'train' or 'test'")

    def save_file(self):
        """Write ``stats_file`` to ``<results_path>/summary_<experiment_seed>.npy``.

        ``np.save`` stores the dict as a pickled 0-d object array. To read it
        back, use ``np.load(path, allow_pickle=True).item()``.
        """
        np.save(f"{self.rpath}/summary_{self.exp_seed}.npy", self.stats_file)

    def validate_stats(self, episode, mode):
        """Check that an episode's ``actions`` and ``errors`` lists have equal length.

        Raises:
            AssertionError: if the lengths differ. It is raised explicitly
                (not via ``assert``) so the check still runs under
                ``python -O``.
            KeyError: if no record exists for ``mode`` / ``episode``.
        """
        stats = self.stats_file[mode][episode]
        n_actions = len(stats["actions"])
        n_errors = len(stats["errors"])
        if n_actions != n_errors:
            raise AssertionError(
                f"{mode} episode {episode}: {n_actions} actions but {n_errors} errors"
            )
|
|
|
|
|
|
# Play one game in which both players pick uniformly random legal moves.
# Players alternate 1, 2, 1, ...; the game ends as soon as the player to move
# cannot place a tile. game_state[player] is that player's list of remaining
# tile indices (indices into ``tiles``), and it is mutated as tiles are placed.
playing = True
player = 1
board = game_state[0]

while playing:
    assert player == 1 or player == 2
    gp = game.Player.P1 if player == 1 else game.Player.P2
    remaining = game_state[player]

    # With no tiles left there is also no legal move, but that means every
    # tile was placed; report it instead of declaring this player the loser.
    if not remaining:
        print(f"player {player} placed all tiles and won")
        playing = False
        continue

    # Collect every legal (tile index, tile permutation, placement) triple for
    # this player. NOTE(review): permutations() presumably yields the tile's
    # rotations/reflections — confirm in the game module.
    moves = []
    for tile_idx in remaining:
        for perm in tiles[tile_idx].permutations():
            moves.extend(
                (tile_idx, perm, plc) for plc in board.tile_placements(perm, gp)
            )

    print(f"player {player} has {len(moves)} options")

    if not moves:
        print(f"No moves left, player {player} lost")
        playing = False
        continue

    (tidx, tile, placement) = random.choice(moves)
    print(
        f"player {player} is placing the following tile with index {tidx} at {placement}\n{tile}"
    )
    board.place(tile, placement, gp)
    remaining.remove(tidx)
    print_game_state(game_state)

    # Hand the turn to the other player (1 <-> 2).
    player = 3 - player
|