add plotting and checkpoints
This commit is contained in:
parent
6997ca3551
commit
65515d6d90
2 changed files with 65 additions and 2 deletions
66
blokus.py
66
blokus.py
|
|
@ -8,6 +8,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from tqdm.auto import trange
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
###############
|
||||
# Utilities #
|
||||
|
|
@ -44,6 +45,22 @@ def print_game_state(game_state: tuple[game.Board, list[int], list[int]]):
|
|||
print(f"Player 2 tiles left: {p2tiles}")
|
||||
|
||||
|
||||
def plot_losses(loss_history, out_path="loss_curve.png"):
|
||||
if not loss_history:
|
||||
print("No losses to plot.")
|
||||
return
|
||||
|
||||
plt.figure()
|
||||
plt.plot(range(1, len(loss_history) + 1), loss_history)
|
||||
plt.xlabel("Episode")
|
||||
plt.ylabel("Loss")
|
||||
plt.title("Training loss over episodes")
|
||||
plt.tight_layout()
|
||||
plt.savefig(out_path)
|
||||
plt.close()
|
||||
print(f"Saved loss plot to {out_path}")
|
||||
|
||||
|
||||
###################
|
||||
# Game state init #
|
||||
###################
|
||||
|
|
@ -354,7 +371,33 @@ def train(
|
|||
):
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
||||
|
||||
pbar = trange(1, num_episodes + 1, desc="Training", dynamic_ncols=True)
|
||||
# We'll keep a history of losses for plotting
|
||||
loss_history = []
|
||||
|
||||
# Checkpoint path (partial training state)
|
||||
ckpt_path = save_path + ".ckpt"
|
||||
|
||||
start_episode = 1
|
||||
|
||||
# Try to resume from checkpoint if it exists
|
||||
if os.path.exists(ckpt_path):
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu")
|
||||
model.load_state_dict(ckpt["model_state"])
|
||||
optimizer.load_state_dict(ckpt["optimizer_state"])
|
||||
start_episode = ckpt["episode"] + 1
|
||||
loss_history = ckpt.get("loss_history", [])
|
||||
print(f"Resuming training from episode {start_episode} (found checkpoint).")
|
||||
|
||||
# If we've already passed num_episodes, just plot and exit
|
||||
if start_episode > num_episodes:
|
||||
print(
|
||||
"Checkpoint episode exceeds requested num_episodes; nothing to train."
|
||||
)
|
||||
plot_losses(loss_history, out_path="loss_curve.png")
|
||||
torch.save(model.state_dict(), save_path)
|
||||
return
|
||||
|
||||
pbar = trange(start_episode, num_episodes + 1, desc="Training", dynamic_ncols=True)
|
||||
|
||||
for episode in pbar:
|
||||
log_probs_p1, log_probs_p2, r1, r2 = play_self_play_game(model, watch=watch)
|
||||
|
|
@ -369,17 +412,36 @@ def train(
|
|||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
loss_value = float(loss.item())
|
||||
loss_history.append(loss_value)
|
||||
|
||||
# Update progress bar with most recent stats
|
||||
pbar.set_postfix(
|
||||
episode=episode,
|
||||
loss=float(loss.item()),
|
||||
loss=loss_value,
|
||||
r1=float(r1),
|
||||
r2=float(r2),
|
||||
)
|
||||
|
||||
# Save checkpoint every N episodes (and at the very end)
|
||||
if episode % 50 == 0 or episode == num_episodes:
|
||||
torch.save(
|
||||
{
|
||||
"episode": episode,
|
||||
"model_state": model.state_dict(),
|
||||
"optimizer_state": optimizer.state_dict(),
|
||||
"loss_history": loss_history,
|
||||
},
|
||||
ckpt_path,
|
||||
)
|
||||
|
||||
# Final model save
|
||||
torch.save(model.state_dict(), save_path)
|
||||
print(f"\nTraining finished. Model saved to {save_path}")
|
||||
|
||||
# Save final loss plot
|
||||
plot_losses(loss_history, out_path="loss_curve.png")
|
||||
|
||||
|
||||
###################
|
||||
# Play vs the AI #
|
||||
|
|
|
|||
|
|
@ -101,6 +101,7 @@
|
|||
(pkgs.python3.withPackages (ppkgs: [
|
||||
ppkgs.torch
|
||||
ppkgs.tqdm
|
||||
ppkgs.matplotlib
|
||||
(lib.python_package ppkgs)
|
||||
]))
|
||||
];
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue