From 7b506195cc8b2250ec8a4729711f207178785c2f Mon Sep 17 00:00:00 2001 From: Noa Aarts Date: Fri, 5 Dec 2025 15:48:20 +0100 Subject: [PATCH] may the flake do it's job --- .envrc | 1 - .gitignore | 3 ++ blokus.py | 59 +++++++++++++++++++++- flake.lock | 122 +++++++++++++++++++++++++++------------------ flake.nix | 129 +++++++++++++++++++++++++++++++++++++++++------- game/src/lib.rs | 4 +- 6 files changed, 248 insertions(+), 70 deletions(-) diff --git a/.envrc b/.envrc index 34e6b2a..3550a30 100644 --- a/.envrc +++ b/.envrc @@ -1,2 +1 @@ use flake -source .venv/bin/activate diff --git a/.gitignore b/.gitignore index d2302b9..8c41c48 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,6 @@ wheels/ # Virtual environments .venv .direnv + +result +target diff --git a/blokus.py b/blokus.py index 3890f9f..64b8189 100755 --- a/blokus.py +++ b/blokus.py @@ -2,9 +2,14 @@ import random import game +import numpy as np +import torch + +############### +# Utilities # +############### + BOARD_SIZE = 14 - - tiles = game.game_tiles() @@ -32,12 +37,62 @@ def print_game_state(game_state: tuple[game.Board, list[int], list[int]]): print(f"Player 2 tiles left: {p2tiles}") +################### +# Game state init # +################### + game_state = ( game.Board(), [i for i in range(21)], [i for i in range(21)], ) +################### +# RL Utils # +################### + + +class Saver: + def __init__(self, results_path, experiment_seed): + self.stats_file = {"train": {}, "test": {}} + self.exp_seed = experiment_seed + self.rpath = results_path + + def get_new_episode(self, mode, episode_no): + if mode == "train": + self.stats_file[mode][episode_no] = { + "loss": [], + "actions": [], + "errors": [], + "errors_noiseless": [], + "done_threshold": 0, + "bond_distance": 0, + "nfev": [], + "opt_ang": [], + "time": [], + "save_circ": [], + "reward": [], + } + elif mode == "test": + self.stats_file[mode][episode_no] = { + "actions": [], + "errors": [], + "errors_noiseless": [], + "done_threshold": 0, + "bond_distance": 0, + "nfev": [], + "opt_ang": [], + "time": [], + } + + def save_file(self): + np.save(f"{self.rpath}/summary_{self.exp_seed}.npy", self.stats_file) + + def validate_stats(self, episode, mode): + assert len(self.stats_file[mode][episode]["actions"]) == len( + self.stats_file[mode][episode]["errors"] + ) + playing = True player = 1 diff --git a/flake.lock b/flake.lock index 779a10e..c49d635 100644 --- a/flake.lock +++ b/flake.lock @@ -1,21 +1,38 @@ { "nodes": { - "fix-py": { - "inputs": { - "flake-utils": "flake-utils", - "nixpkgs": "nixpkgs" - }, + "crane": { "locked": { - "lastModified": 1752329544, - "narHash": "sha256-LtZyywexRDB5FFmEU6A1e5tMIb0MX8/xxpjCHonUzYE=", - "owner": "GuillaumeDesforges", - "repo": "fix-python", - "rev": "248e2ea9620faee9b8a2ae10e12320de2a819fe9", + "lastModified": 1764903584, + "narHash": "sha256-RSkJtNtx0SEaQiYqsoFoRynwfZLo2OZ9z6rUq1DJR6g=", + "owner": "ipetkov", + "repo": "crane", + "rev": "2b3a5a88d852575758e1eb6ac9ee677fcd633fc1", "type": "github" }, "original": { - "owner": "GuillaumeDesforges", - "repo": "fix-python", + "owner": "ipetkov", + "repo": "crane", + "type": "github" + } + }, + "fenix": { + "inputs": { + "nixpkgs": [ + "nixpkgs" + ], + "rust-analyzer-src": "rust-analyzer-src" + }, + "locked": { + "lastModified": 1764915802, + "narHash": "sha256-eHTucU43sRCpvvTt5eey9htcWipS7ZN3B7ts6MiXLxo=", + "owner": "nix-community", + "repo": "fenix", + "rev": "a83a78fd3587d9f3388f0b459ad9c2bbd6d1b6d8", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "fenix", "type": "github" } }, @@ -24,43 +41,30 @@ "systems": "systems" }, "locked": { - "lastModified": 1689068808, - "narHash": "sha256-6ixXo3wt24N/melDWjq70UuHQLxGV8jZvooRanIHXw0=", + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", "owner": "numtide", "repo": "flake-utils", - "rev": "919d646de7be200f3bf08cb76ae1f09402b6f9b4", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", "type": "github" }, "original": { - "id": "flake-utils", - "type": "indirect" + "owner": "numtide", + "repo": "flake-utils", + "type": "github" } }, "nixpkgs": { "locked": { - "lastModified": 1682786779, - "narHash": "sha256-m7QFzPS/CE8hbkbIVK4UStihAQMtczr0vSpOgETOM1g=", + "lastModified": 1764886837, + "narHash": "sha256-Sup1QVtp30cSa1FJLVkwTHEBV75XT+lWPvgHL0I7s1s=", "owner": "nixos", "repo": "nixpkgs", - "rev": "08e4dc3a907a6dfec8bb3bbf1540d8abbffea22b", + "rev": "bd88d6c13ab85cc842b93d53f68d6d40412e5a18", "type": "github" }, "original": { - "id": "nixpkgs", - "type": "indirect" - } - }, - "nixpkgs_2": { - "locked": { - "lastModified": 1764167966, - "narHash": "sha256-nXv6xb7cq+XpjBYIjWEGTLCqQetxJu6zvVlrqHMsCOA=", - "owner": "NixOS", - "repo": "nixpkgs", - "rev": "5c46f3bd98147c8d82366df95bbef2cab3a967ea", - "type": "github" - }, - "original": { - "owner": "NixOS", + "owner": "nixos", "ref": "nixpkgs-unstable", "repo": "nixpkgs", "type": "github" @@ -68,27 +72,51 @@ }, "root": { "inputs": { - "fix-py": "fix-py", - "nixpkgs": "nixpkgs_2", - "systems": "systems_2" + "crane": "crane", + "fenix": "fenix", + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs", + "rust-overlay": "rust-overlay" } }, - "systems": { + "rust-analyzer-src": { + "flake": false, "locked": { - "lastModified": 1681028828, - "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", - "owner": "nix-systems", - "repo": "default", - "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "lastModified": 1764778537, + "narHash": "sha256-SNL+Fj1ZWiBqCrHJT1S9vMZujrWxCOmf3zkT66XSnhE=", + "owner": "rust-lang", + "repo": "rust-analyzer", + "rev": "633cff25206d5108043d87617a43c9d04aa42c88", "type": "github" }, "original": { - "owner": "nix-systems", - "repo": "default", + "owner": "rust-lang", + "ref": "nightly", + "repo": "rust-analyzer", "type": "github" } }, - "systems_2": { + "rust-overlay": { + "inputs": { + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1764902447, + "narHash": "sha256-wNqkDBj+tjK619sTHPEA7uhjr7DHHEY8OsFou31dxy0=", + "owner": "oxalica", + "repo": "rust-overlay", + "rev": "d914a744a83098eeb28125d2848ad383b209223f", + "type": "github" + }, + "original": { + "owner": "oxalica", + "repo": "rust-overlay", + "type": "github" + } + }, + "systems": { "locked": { "lastModified": 1681028828, "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", diff --git a/flake.nix b/flake.nix index d8fba7b..bfc5079 100644 --- a/flake.nix +++ b/flake.nix @@ -1,27 +1,118 @@ { + # Build Pyo3 package inputs = { - nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; - systems.url = "github:nix-systems/default"; - fix-py.url = "github:GuillaumeDesforges/fix-python"; + nixpkgs.url = "github:nixos/nixpkgs/nixpkgs-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + rust-overlay = { + url = "github:oxalica/rust-overlay"; + inputs.nixpkgs.follows = "nixpkgs"; + }; + crane.url = "github:ipetkov/crane"; + fenix = { + url = "github:nix-community/fenix"; + inputs.nixpkgs.follows = "nixpkgs"; + }; }; outputs = - { nixpkgs, fix-py, ... }: - let - eachSystem = - f: - nixpkgs.lib.genAttrs nixpkgs.lib.systems.flakeExposed (system: f nixpkgs.legacyPackages.${system}); - in - { - devShells = eachSystem (pkgs: { - default = pkgs.mkShell { - buildInputs = [ - pkgs.cargo - pkgs.rustc - pkgs.stdenv.cc.cc.lib - fix-py.packages.${pkgs.system}.default + inputs: + inputs.flake-utils.lib.eachDefaultSystem ( + system: + let + pkgs = import inputs.nixpkgs { + inherit system; + overlays = [ inputs.rust-overlay.overlays.default ]; + }; + lib = pkgs.lib; + + python_version = pkgs.python313; + wheel_tail = "cp313-cp313-linux_x86_64"; # Change if python_version changes + + # Get a custom rust toolchain + craneLib = (inputs.crane.mkLib pkgs).overrideToolchain ( + p: inputs.fenix.packages.${system}.complete.toolchain + ); + + project_name = (craneLib.crateNameFromCargoToml { cargoToml = ./game/Cargo.toml; }).pname; + project_version = (craneLib.crateNameFromCargoToml { cargoToml = ./game/Cargo.toml; }).version; + + crate_cfg = { + src = + let + fs = lib.fileset; + in + fs.toSource { + root = ./game; + fileset = fs.unions [ + ./game/src/lib.rs + ./game/Cargo.lock + ./game/Cargo.toml + ./game/pyproject.toml + ]; + }; + nativeBuildInputs = [ python_version ]; + # doCheck = true; + # buildInputs = []; + }; + + crate_artifacts = craneLib.buildDepsOnly ( + crate_cfg + // { + pname = "${project_name}-artifacts"; + version = project_version; + } + ); + + # Build the library, then re-use the target dir to generate the wheel file with maturin + crate_wheel = + (craneLib.buildPackage ( + crate_cfg + // { + pname = project_name; + version = project_version; + cargoArtifacts = crate_artifacts; + } + )).overrideAttrs + (old: { + nativeBuildInputs = old.nativeBuildInputs ++ [ pkgs.maturin ]; + buildPhase = old.buildPhase + '' + maturin build --offline --target-dir ./target + ''; + installPhase = old.installPhase + '' + cp target/wheels/${project_name}-${project_version}-${wheel_tail}.whl $out/ + ''; + }); + in + rec { + packages = { + default = crate_wheel; # The wheel itself + + # A python version with the library installed + pythonpkg = python_version.withPackages (ps: [ + (lib.python_package ps) + ]); + }; + devShells.default = craneLib.devShell { + packages = [ + (pkgs.python3.withPackages (ppkgs: [ + ppkgs.torch + (lib.python_package ppkgs) + ])) ]; }; - }); - }; + lib = { + # To use in other builds with the "withPackages" call + python_package = + ps: + ps.buildPythonPackage rec { + pname = project_name; + format = "wheel"; + version = project_version; + src = "${crate_wheel}/${project_name}-${project_version}-${wheel_tail}.whl"; + doCheck = false; + pythonImportsCheck = [ project_name ]; + }; + }; + } + ); } diff --git a/game/src/lib.rs b/game/src/lib.rs index c6298bf..18440d0 100644 --- a/game/src/lib.rs +++ b/game/src/lib.rs @@ -1,8 +1,10 @@ +#[allow(dead_code)] use pyo3::prelude::*; #[pymodule] mod game { - use pyo3::{exceptions::PyIndexError, prelude::*, types::PySequence}; + use pyo3::exceptions::PyIndexError; + use pyo3::prelude::*; use std::{ collections::HashSet, fmt::{Display, Formatter},