add --watch system option
This commit is contained in:
parent
e2d5ca264b
commit
12b0c31239
1 changed files with 12 additions and 3 deletions
15
blokus.py
15
blokus.py
|
|
@ -289,7 +289,7 @@ def select_move_and_logprob(model: MoveScorer, game_state, player: int):
|
|||
return move, log_prob
|
||||
|
||||
|
||||
def play_self_play_game(model: MoveScorer, max_turns: int = 500):
|
||||
def play_self_play_game(model: MoveScorer, max_turns: int = 500, watch: bool = False):
|
||||
"""
|
||||
Self-play game with the same model as both players.
|
||||
Returns:
|
||||
|
|
@ -335,6 +335,8 @@ def play_self_play_game(model: MoveScorer, max_turns: int = 500):
|
|||
|
||||
# Update game_state tuple
|
||||
game_state = (board, p1tiles, p2tiles)
|
||||
if watch:
|
||||
print_game_state(game_state)
|
||||
|
||||
# Store log_prob
|
||||
log_probs[player].append(log_prob)
|
||||
|
|
@ -348,13 +350,14 @@ def train(
|
|||
num_episodes: int = 1000,
|
||||
lr: float = 1e-3,
|
||||
save_path: str = "trained_agent.pt",
|
||||
watch: bool = False,
|
||||
):
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
||||
|
||||
pbar = trange(1, num_episodes + 1, desc="Training", dynamic_ncols=True)
|
||||
|
||||
for episode in pbar:
|
||||
log_probs_p1, log_probs_p2, r1, r2 = play_self_play_game(model)
|
||||
log_probs_p1, log_probs_p2, r1, r2 = play_self_play_game(model, watch=watch)
|
||||
|
||||
loss = torch.tensor(0.0)
|
||||
if log_probs_p1:
|
||||
|
|
@ -460,7 +463,13 @@ def main():
|
|||
play_vs_ai(model, human_is=1)
|
||||
else:
|
||||
# Train by self-play
|
||||
train(model, num_episodes=1000, lr=1e-3, save_path="trained_agent.pt")
|
||||
train(
|
||||
model,
|
||||
num_episodes=1000,
|
||||
lr=1e-3,
|
||||
save_path="trained_agent.pt",
|
||||
watch="--watch" in sys.argv,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue