add plotting and checkpoints

parent 6997ca3551
commit 65515d6d90

2 changed files with 65 additions and 2 deletions

blokus.py (66 changed lines)
blokus.py

@@ -8,6 +8,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from tqdm.auto import trange
+import matplotlib.pyplot as plt
 
 ###############
 # Utilities #
@@ -44,6 +45,22 @@ def print_game_state(game_state: tuple[game.Board, list[int], list[int]]):
     print(f"Player 2 tiles left: {p2tiles}")
 
 
+def plot_losses(loss_history, out_path="loss_curve.png"):
+    if not loss_history:
+        print("No losses to plot.")
+        return
+
+    plt.figure()
+    plt.plot(range(1, len(loss_history) + 1), loss_history)
+    plt.xlabel("Episode")
+    plt.ylabel("Loss")
+    plt.title("Training loss over episodes")
+    plt.tight_layout()
+    plt.savefig(out_path)
+    plt.close()
+    print(f"Saved loss plot to {out_path}")
+
+
 ###################
 # Game state init #
 ###################
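As a quick reference, a minimal usage sketch of the new plot_losses helper (not part of this commit); the loss values below are invented for illustration:

    # Hypothetical standalone call; in train() the list is filled one entry per episode.
    example_losses = [1.31, 1.12, 0.97, 0.88, 0.80]
    plot_losses(example_losses, out_path="loss_curve.png")
    # Writes loss_curve.png and prints "Saved loss plot to loss_curve.png".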
@@ -354,7 +371,33 @@ def train(
 ):
     optimizer = torch.optim.Adam(model.parameters(), lr=lr)
 
-    pbar = trange(1, num_episodes + 1, desc="Training", dynamic_ncols=True)
+    # We'll keep a history of losses for plotting
+    loss_history = []
+
+    # Checkpoint path (partial training state)
+    ckpt_path = save_path + ".ckpt"
+
+    start_episode = 1
+
+    # Try to resume from checkpoint if it exists
+    if os.path.exists(ckpt_path):
+        ckpt = torch.load(ckpt_path, map_location="cpu")
+        model.load_state_dict(ckpt["model_state"])
+        optimizer.load_state_dict(ckpt["optimizer_state"])
+        start_episode = ckpt["episode"] + 1
+        loss_history = ckpt.get("loss_history", [])
+        print(f"Resuming training from episode {start_episode} (found checkpoint).")
+
+    # If we've already passed num_episodes, just plot and exit
+    if start_episode > num_episodes:
+        print(
+            "Checkpoint episode exceeds requested num_episodes; nothing to train."
+        )
+        plot_losses(loss_history, out_path="loss_curve.png")
+        torch.save(model.state_dict(), save_path)
+        return
+
+    pbar = trange(start_episode, num_episodes + 1, desc="Training", dynamic_ncols=True)
+
     for episode in pbar:
         log_probs_p1, log_probs_p2, r1, r2 = play_self_play_game(model, watch=watch)
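For context, a minimal sketch (not part of this commit) of reading the same checkpoint outside train(), assuming the key layout used above; "model.pt" stands in for whatever save_path is in practice:

    import torch

    ckpt = torch.load("model.pt.ckpt", map_location="cpu")  # hypothetical save_path + ".ckpt"
    print("last completed episode:", ckpt["episode"])
    plot_losses(ckpt.get("loss_history", []), out_path="loss_curve.png")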
@@ -369,17 +412,36 @@ def train(
         loss.backward()
         optimizer.step()
 
+        loss_value = float(loss.item())
+        loss_history.append(loss_value)
+
         # Update progress bar with most recent stats
         pbar.set_postfix(
             episode=episode,
-            loss=float(loss.item()),
+            loss=loss_value,
             r1=float(r1),
             r2=float(r2),
         )
 
+        # Save checkpoint every N episodes (and at the very end)
+        if episode % 50 == 0 or episode == num_episodes:
+            torch.save(
+                {
+                    "episode": episode,
+                    "model_state": model.state_dict(),
+                    "optimizer_state": optimizer.state_dict(),
+                    "loss_history": loss_history,
+                },
+                ckpt_path,
+            )
+
+    # Final model save
     torch.save(model.state_dict(), save_path)
     print(f"\nTraining finished. Model saved to {save_path}")
 
+    # Save final loss plot
+    plot_losses(loss_history, out_path="loss_curve.png")
+
 
 ###################
 # Play vs the AI #
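And a sketch (not part of this commit) of loading the final weights-only file for later play or evaluation; the constructor name is a placeholder, since the model class is not shown in this diff:

    import torch

    model = PolicyNet()  # placeholder: build the same architecture used during training
    model.load_state_dict(torch.load("model.pt", map_location="cpu"))  # "model.pt" stands in for save_path
    model.eval()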
@@ -101,6 +101,7 @@
       (pkgs.python3.withPackages (ppkgs: [
         ppkgs.torch
         ppkgs.tqdm
+        ppkgs.matplotlib
         (lib.python_package ppkgs)
       ]))
     ];
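A quick way to confirm the added dependency resolves inside the updated dev shell (a sketch; how the shell is entered depends on how this Nix file is wired up):

    import matplotlib
    print(matplotlib.__version__)  # any version string means the package is on the path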