add --watch system option

This commit is contained in:
Noa Aarts 2025-12-05 20:46:52 +01:00
parent e2d5ca264b
commit 12b0c31239
Signed by: noa
GPG key ID: 1850932741EFF672

View file

@ -289,7 +289,7 @@ def select_move_and_logprob(model: MoveScorer, game_state, player: int):
return move, log_prob
def play_self_play_game(model: MoveScorer, max_turns: int = 500):
def play_self_play_game(model: MoveScorer, max_turns: int = 500, watch: bool = False):
"""
Self-play game with the same model as both players.
Returns:
@ -335,6 +335,8 @@ def play_self_play_game(model: MoveScorer, max_turns: int = 500):
# Update game_state tuple
game_state = (board, p1tiles, p2tiles)
if watch:
print_game_state(game_state)
# Store log_prob
log_probs[player].append(log_prob)
@ -348,13 +350,14 @@ def train(
num_episodes: int = 1000,
lr: float = 1e-3,
save_path: str = "trained_agent.pt",
watch: bool = False,
):
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
pbar = trange(1, num_episodes + 1, desc="Training", dynamic_ncols=True)
for episode in pbar:
log_probs_p1, log_probs_p2, r1, r2 = play_self_play_game(model)
log_probs_p1, log_probs_p2, r1, r2 = play_self_play_game(model, watch=watch)
loss = torch.tensor(0.0)
if log_probs_p1:
@ -460,7 +463,13 @@ def main():
play_vs_ai(model, human_is=1)
else:
# Train by self-play
train(model, num_episodes=1000, lr=1e-3, save_path="trained_agent.pt")
train(
model,
num_episodes=1000,
lr=1e-3,
save_path="trained_agent.pt",
watch="--watch" in sys.argv,
)
if __name__ == "__main__":