add plotting and checkpoints

Noa Aarts 2025-12-05 20:56:12 +01:00
parent 6997ca3551
commit 65515d6d90
Signed by: noa
GPG key ID: 1850932741EFF672
2 changed files with 65 additions and 2 deletions

@@ -8,6 +8,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import trange
import matplotlib.pyplot as plt
###############
# Utilities #
@@ -44,6 +45,22 @@ def print_game_state(game_state: tuple[game.Board, list[int], list[int]]):
    print(f"Player 2 tiles left: {p2tiles}")


def plot_losses(loss_history, out_path="loss_curve.png"):
    if not loss_history:
        print("No losses to plot.")
        return
    plt.figure()
    plt.plot(range(1, len(loss_history) + 1), loss_history)
    plt.xlabel("Episode")
    plt.ylabel("Loss")
    plt.title("Training loss over episodes")
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()
    print(f"Saved loss plot to {out_path}")

###################
# Game state init #
###################
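
For reference, plot_losses only writes the figure to disk (savefig followed by close, never plt.show()), so it also behaves in headless runs. A minimal usage sketch, with made-up loss values:

    plot_losses([0.92, 0.85, 0.81, 0.74])      # writes loss_curve.png
    plot_losses([], out_path="ignored.png")    # prints "No losses to plot." and returns
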
@@ -354,7 +371,33 @@ def train(
):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    pbar = trange(1, num_episodes + 1, desc="Training", dynamic_ncols=True)

    # We'll keep a history of losses for plotting
    loss_history = []

    # Checkpoint path (partial training state)
    ckpt_path = save_path + ".ckpt"
    start_episode = 1

    # Try to resume from checkpoint if it exists
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt["model_state"])
        optimizer.load_state_dict(ckpt["optimizer_state"])
        start_episode = ckpt["episode"] + 1
        loss_history = ckpt.get("loss_history", [])
        print(f"Resuming training from episode {start_episode} (found checkpoint).")

        # If we've already passed num_episodes, just plot and exit
        if start_episode > num_episodes:
            print(
                "Checkpoint episode exceeds requested num_episodes; nothing to train."
            )
            plot_losses(loss_history, out_path="loss_curve.png")
            torch.save(model.state_dict(), save_path)
            return

    pbar = trange(start_episode, num_episodes + 1, desc="Training", dynamic_ncols=True)
    for episode in pbar:
        log_probs_p1, log_probs_p2, r1, r2 = play_self_play_game(model, watch=watch)
@@ -369,17 +412,36 @@ def train(
        loss.backward()
        optimizer.step()

        loss_value = float(loss.item())
        loss_history.append(loss_value)

        # Update progress bar with most recent stats
        pbar.set_postfix(
            episode=episode,
            loss=float(loss.item()),
            loss=loss_value,
            r1=float(r1),
            r2=float(r2),
        )

        # Save a checkpoint every 50 episodes (and at the very end)
        if episode % 50 == 0 or episode == num_episodes:
            torch.save(
                {
                    "episode": episode,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "loss_history": loss_history,
                },
                ckpt_path,
            )

    # Final model save
    torch.save(model.state_dict(), save_path)
    print(f"\nTraining finished. Model saved to {save_path}")

    # Save final loss plot
    plot_losses(loss_history, out_path="loss_curve.png")

###################
# Play vs the AI #
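
Since the checkpoint bundles the model weights, optimizer state, last finished episode, and loss history into one dict, a partial run can be inspected without calling train(). A minimal sketch, assuming save_path was "model.pt" so the checkpoint sits at "model.pt.ckpt" (both names are illustrative):

    import torch

    ckpt = torch.load("model.pt.ckpt", map_location="cpu")
    print(f"trained through episode {ckpt['episode']}")
    print(f"{len(ckpt.get('loss_history', []))} loss values recorded")
    # Same keys the resume logic reads: "model_state", "optimizer_state",
    # "episode", "loss_history"
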

@@ -101,6 +101,7 @@
  (pkgs.python3.withPackages (ppkgs: [
    ppkgs.torch
    ppkgs.tqdm
    ppkgs.matplotlib
    (lib.python_package ppkgs)
  ]))
];
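
One caveat: the script only ever saves plots to files, so it can be worth forcing a non-interactive backend to keep matplotlib from probing for a display (relevant in a headless Nix shell). A minimal sketch of what could go before the pyplot import; this is a suggestion, not part of the commit:

    import matplotlib
    matplotlib.use("Agg")  # file-only backend, no display needed
    import matplotlib.pyplot as plt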