add --watch system option
This commit is contained in:
parent
e2d5ca264b
commit
12b0c31239
1 changed files with 12 additions and 3 deletions
15
blokus.py
15
blokus.py
|
|
@ -289,7 +289,7 @@ def select_move_and_logprob(model: MoveScorer, game_state, player: int):
|
||||||
return move, log_prob
|
return move, log_prob
|
||||||
|
|
||||||
|
|
||||||
def play_self_play_game(model: MoveScorer, max_turns: int = 500):
|
def play_self_play_game(model: MoveScorer, max_turns: int = 500, watch: bool = False):
|
||||||
"""
|
"""
|
||||||
Self-play game with the same model as both players.
|
Self-play game with the same model as both players.
|
||||||
Returns:
|
Returns:
|
||||||
|
|
@ -335,6 +335,8 @@ def play_self_play_game(model: MoveScorer, max_turns: int = 500):
|
||||||
|
|
||||||
# Update game_state tuple
|
# Update game_state tuple
|
||||||
game_state = (board, p1tiles, p2tiles)
|
game_state = (board, p1tiles, p2tiles)
|
||||||
|
if watch:
|
||||||
|
print_game_state(game_state)
|
||||||
|
|
||||||
# Store log_prob
|
# Store log_prob
|
||||||
log_probs[player].append(log_prob)
|
log_probs[player].append(log_prob)
|
||||||
|
|
@ -348,13 +350,14 @@ def train(
|
||||||
num_episodes: int = 1000,
|
num_episodes: int = 1000,
|
||||||
lr: float = 1e-3,
|
lr: float = 1e-3,
|
||||||
save_path: str = "trained_agent.pt",
|
save_path: str = "trained_agent.pt",
|
||||||
|
watch: bool = False,
|
||||||
):
|
):
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
||||||
|
|
||||||
pbar = trange(1, num_episodes + 1, desc="Training", dynamic_ncols=True)
|
pbar = trange(1, num_episodes + 1, desc="Training", dynamic_ncols=True)
|
||||||
|
|
||||||
for episode in pbar:
|
for episode in pbar:
|
||||||
log_probs_p1, log_probs_p2, r1, r2 = play_self_play_game(model)
|
log_probs_p1, log_probs_p2, r1, r2 = play_self_play_game(model, watch=watch)
|
||||||
|
|
||||||
loss = torch.tensor(0.0)
|
loss = torch.tensor(0.0)
|
||||||
if log_probs_p1:
|
if log_probs_p1:
|
||||||
|
|
@ -460,7 +463,13 @@ def main():
|
||||||
play_vs_ai(model, human_is=1)
|
play_vs_ai(model, human_is=1)
|
||||||
else:
|
else:
|
||||||
# Train by self-play
|
# Train by self-play
|
||||||
train(model, num_episodes=1000, lr=1e-3, save_path="trained_agent.pt")
|
train(
|
||||||
|
model,
|
||||||
|
num_episodes=1000,
|
||||||
|
lr=1e-3,
|
||||||
|
save_path="trained_agent.pt",
|
||||||
|
watch="--watch" in sys.argv,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue