From 65515d6d90a8a05b9abe12318bcb3a2d41160b36 Mon Sep 17 00:00:00 2001
From: Noa Aarts
Date: Fri, 5 Dec 2025 20:56:12 +0100
Subject: [PATCH] add plotting and checkpoints

---
 blokus.py | 66 +++++++++++++++++++++++++++++++++++++++++++++++++++++--
 flake.nix |  1 +
 2 files changed, 65 insertions(+), 2 deletions(-)

diff --git a/blokus.py b/blokus.py
index 0703129..53b4b1e 100755
--- a/blokus.py
+++ b/blokus.py
@@ -8,6 +8,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from tqdm.auto import trange
+import matplotlib.pyplot as plt
 ###############
 # Utilities #
 ###############
@@ -44,6 +45,22 @@ def print_game_state(game_state: tuple[game.Board, list[int], list[int]]):
     print(f"Player 2 tiles left: {p2tiles}")
 
 
+def plot_losses(loss_history, out_path="loss_curve.png"):
+    if not loss_history:
+        print("No losses to plot.")
+        return
+
+    plt.figure()
+    plt.plot(range(1, len(loss_history) + 1), loss_history)
+    plt.xlabel("Episode")
+    plt.ylabel("Loss")
+    plt.title("Training loss over episodes")
+    plt.tight_layout()
+    plt.savefig(out_path)
+    plt.close()
+    print(f"Saved loss plot to {out_path}")
+
+
 ###################
 # Game state init #
 ###################
@@ -354,7 +371,33 @@ def train(
 ):
     optimizer = torch.optim.Adam(model.parameters(), lr=lr)
 
-    pbar = trange(1, num_episodes + 1, desc="Training", dynamic_ncols=True)
+    # We'll keep a history of losses for plotting
+    loss_history = []
+
+    # Checkpoint path (partial training state)
+    ckpt_path = save_path + ".ckpt"
+
+    start_episode = 1
+
+    # Try to resume from checkpoint if it exists
+    if os.path.exists(ckpt_path):
+        ckpt = torch.load(ckpt_path, map_location="cpu")
+        model.load_state_dict(ckpt["model_state"])
+        optimizer.load_state_dict(ckpt["optimizer_state"])
+        start_episode = ckpt["episode"] + 1
+        loss_history = ckpt.get("loss_history", [])
+        print(f"Resuming training from episode {start_episode} (found checkpoint).")
+
+        # If we've already passed num_episodes, just plot and exit
+        if start_episode > num_episodes:
+            print(
+                "Checkpoint episode exceeds requested num_episodes; nothing to train."
+            )
+            plot_losses(loss_history, out_path="loss_curve.png")
+            torch.save(model.state_dict(), save_path)
+            return
+
+    pbar = trange(start_episode, num_episodes + 1, desc="Training", dynamic_ncols=True)
 
     for episode in pbar:
         log_probs_p1, log_probs_p2, r1, r2 = play_self_play_game(model, watch=watch)
@@ -369,17 +412,36 @@
         loss.backward()
         optimizer.step()
 
+        loss_value = float(loss.item())
+        loss_history.append(loss_value)
+
         # Update progress bar with most recent stats
         pbar.set_postfix(
             episode=episode,
-            loss=float(loss.item()),
+            loss=loss_value,
             r1=float(r1),
             r2=float(r2),
         )
 
+        # Save checkpoint every N episodes (and at the very end)
+        if episode % 50 == 0 or episode == num_episodes:
+            torch.save(
+                {
+                    "episode": episode,
+                    "model_state": model.state_dict(),
+                    "optimizer_state": optimizer.state_dict(),
+                    "loss_history": loss_history,
+                },
+                ckpt_path,
+            )
+
+    # Final model save
     torch.save(model.state_dict(), save_path)
     print(f"\nTraining finished. Model saved to {save_path}")
 
+    # Save final loss plot
+    plot_losses(loss_history, out_path="loss_curve.png")
+
 ###################
 # Play vs the AI #
 ###################
diff --git a/flake.nix b/flake.nix
index 6cbb1b9..386ac5c 100644
--- a/flake.nix
+++ b/flake.nix
@@ -101,6 +101,7 @@
           (pkgs.python3.withPackages (ppkgs: [
             ppkgs.torch
             ppkgs.tqdm
+            ppkgs.matplotlib
             (lib.python_package ppkgs)
           ]))
         ];