may the flake do it's job
This commit is contained in:
parent
7e1aaf7eee
commit
7b506195cc
6 changed files with 248 additions and 70 deletions
1
.envrc
1
.envrc
|
|
@ -1,2 +1 @@
|
||||||
use flake
|
use flake
|
||||||
source .venv/bin/activate
|
|
||||||
|
|
|
||||||
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -9,3 +9,6 @@ wheels/
|
||||||
# Virtual environments
|
# Virtual environments
|
||||||
.venv
|
.venv
|
||||||
.direnv
|
.direnv
|
||||||
|
|
||||||
|
result
|
||||||
|
target
|
||||||
|
|
|
||||||
59
blokus.py
59
blokus.py
|
|
@ -2,9 +2,14 @@
|
||||||
import random
|
import random
|
||||||
import game
|
import game
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
###############
|
||||||
|
# Utilities #
|
||||||
|
###############
|
||||||
|
|
||||||
BOARD_SIZE = 14
|
BOARD_SIZE = 14
|
||||||
|
|
||||||
|
|
||||||
tiles = game.game_tiles()
|
tiles = game.game_tiles()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -32,12 +37,62 @@ def print_game_state(game_state: tuple[game.Board, list[int], list[int]]):
|
||||||
print(f"Player 2 tiles left: {p2tiles}")
|
print(f"Player 2 tiles left: {p2tiles}")
|
||||||
|
|
||||||
|
|
||||||
|
###################
|
||||||
|
# Game state init #
|
||||||
|
###################
|
||||||
|
|
||||||
game_state = (
|
game_state = (
|
||||||
game.Board(),
|
game.Board(),
|
||||||
[i for i in range(21)],
|
[i for i in range(21)],
|
||||||
[i for i in range(21)],
|
[i for i in range(21)],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
###################
|
||||||
|
# RL Utils #
|
||||||
|
###################
|
||||||
|
|
||||||
|
|
||||||
|
class Saver:
|
||||||
|
def __init__(self, results_path, experiment_seed):
|
||||||
|
self.stats_file = {"train": {}, "test": {}}
|
||||||
|
self.exp_seed = experiment_seed
|
||||||
|
self.rpath = results_path
|
||||||
|
|
||||||
|
def get_new_episode(self, mode, episode_no):
|
||||||
|
if mode == "train":
|
||||||
|
self.stats_file[mode][episode_no] = {
|
||||||
|
"loss": [],
|
||||||
|
"actions": [],
|
||||||
|
"errors": [],
|
||||||
|
"errors_noiseless": [],
|
||||||
|
"done_threshold": 0,
|
||||||
|
"bond_distance": 0,
|
||||||
|
"nfev": [],
|
||||||
|
"opt_ang": [],
|
||||||
|
"time": [],
|
||||||
|
"save_circ": [],
|
||||||
|
"reward": [],
|
||||||
|
}
|
||||||
|
elif mode == "test":
|
||||||
|
self.stats_file[mode][episode_no] = {
|
||||||
|
"actions": [],
|
||||||
|
"errors": [],
|
||||||
|
"errors_noiseless": [],
|
||||||
|
"done_threshold": 0,
|
||||||
|
"bond_distance": 0,
|
||||||
|
"nfev": [],
|
||||||
|
"opt_ang": [],
|
||||||
|
"time": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
def save_file(self):
|
||||||
|
np.save(f"{self.rpath}/summary_{self.exp_seed}.npy", self.stats_file)
|
||||||
|
|
||||||
|
def validate_stats(self, episode, mode):
|
||||||
|
assert len(self.stats_file[mode][episode]["actions"]) == len(
|
||||||
|
self.stats_file[mode][episode]["errors"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
playing = True
|
playing = True
|
||||||
player = 1
|
player = 1
|
||||||
|
|
|
||||||
122
flake.lock
generated
122
flake.lock
generated
|
|
@ -1,21 +1,38 @@
|
||||||
{
|
{
|
||||||
"nodes": {
|
"nodes": {
|
||||||
"fix-py": {
|
"crane": {
|
||||||
"inputs": {
|
|
||||||
"flake-utils": "flake-utils",
|
|
||||||
"nixpkgs": "nixpkgs"
|
|
||||||
},
|
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1752329544,
|
"lastModified": 1764903584,
|
||||||
"narHash": "sha256-LtZyywexRDB5FFmEU6A1e5tMIb0MX8/xxpjCHonUzYE=",
|
"narHash": "sha256-RSkJtNtx0SEaQiYqsoFoRynwfZLo2OZ9z6rUq1DJR6g=",
|
||||||
"owner": "GuillaumeDesforges",
|
"owner": "ipetkov",
|
||||||
"repo": "fix-python",
|
"repo": "crane",
|
||||||
"rev": "248e2ea9620faee9b8a2ae10e12320de2a819fe9",
|
"rev": "2b3a5a88d852575758e1eb6ac9ee677fcd633fc1",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
"owner": "GuillaumeDesforges",
|
"owner": "ipetkov",
|
||||||
"repo": "fix-python",
|
"repo": "crane",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fenix": {
|
||||||
|
"inputs": {
|
||||||
|
"nixpkgs": [
|
||||||
|
"nixpkgs"
|
||||||
|
],
|
||||||
|
"rust-analyzer-src": "rust-analyzer-src"
|
||||||
|
},
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1764915802,
|
||||||
|
"narHash": "sha256-eHTucU43sRCpvvTt5eey9htcWipS7ZN3B7ts6MiXLxo=",
|
||||||
|
"owner": "nix-community",
|
||||||
|
"repo": "fenix",
|
||||||
|
"rev": "a83a78fd3587d9f3388f0b459ad9c2bbd6d1b6d8",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "nix-community",
|
||||||
|
"repo": "fenix",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
@ -24,43 +41,30 @@
|
||||||
"systems": "systems"
|
"systems": "systems"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1689068808,
|
"lastModified": 1731533236,
|
||||||
"narHash": "sha256-6ixXo3wt24N/melDWjq70UuHQLxGV8jZvooRanIHXw0=",
|
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
||||||
"owner": "numtide",
|
"owner": "numtide",
|
||||||
"repo": "flake-utils",
|
"repo": "flake-utils",
|
||||||
"rev": "919d646de7be200f3bf08cb76ae1f09402b6f9b4",
|
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
"id": "flake-utils",
|
"owner": "numtide",
|
||||||
"type": "indirect"
|
"repo": "flake-utils",
|
||||||
|
"type": "github"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nixpkgs": {
|
"nixpkgs": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1682786779,
|
"lastModified": 1764886837,
|
||||||
"narHash": "sha256-m7QFzPS/CE8hbkbIVK4UStihAQMtczr0vSpOgETOM1g=",
|
"narHash": "sha256-Sup1QVtp30cSa1FJLVkwTHEBV75XT+lWPvgHL0I7s1s=",
|
||||||
"owner": "nixos",
|
"owner": "nixos",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "08e4dc3a907a6dfec8bb3bbf1540d8abbffea22b",
|
"rev": "bd88d6c13ab85cc842b93d53f68d6d40412e5a18",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
"id": "nixpkgs",
|
"owner": "nixos",
|
||||||
"type": "indirect"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nixpkgs_2": {
|
|
||||||
"locked": {
|
|
||||||
"lastModified": 1764167966,
|
|
||||||
"narHash": "sha256-nXv6xb7cq+XpjBYIjWEGTLCqQetxJu6zvVlrqHMsCOA=",
|
|
||||||
"owner": "NixOS",
|
|
||||||
"repo": "nixpkgs",
|
|
||||||
"rev": "5c46f3bd98147c8d82366df95bbef2cab3a967ea",
|
|
||||||
"type": "github"
|
|
||||||
},
|
|
||||||
"original": {
|
|
||||||
"owner": "NixOS",
|
|
||||||
"ref": "nixpkgs-unstable",
|
"ref": "nixpkgs-unstable",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
|
|
@ -68,27 +72,51 @@
|
||||||
},
|
},
|
||||||
"root": {
|
"root": {
|
||||||
"inputs": {
|
"inputs": {
|
||||||
"fix-py": "fix-py",
|
"crane": "crane",
|
||||||
"nixpkgs": "nixpkgs_2",
|
"fenix": "fenix",
|
||||||
"systems": "systems_2"
|
"flake-utils": "flake-utils",
|
||||||
|
"nixpkgs": "nixpkgs",
|
||||||
|
"rust-overlay": "rust-overlay"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"systems": {
|
"rust-analyzer-src": {
|
||||||
|
"flake": false,
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1681028828,
|
"lastModified": 1764778537,
|
||||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
"narHash": "sha256-SNL+Fj1ZWiBqCrHJT1S9vMZujrWxCOmf3zkT66XSnhE=",
|
||||||
"owner": "nix-systems",
|
"owner": "rust-lang",
|
||||||
"repo": "default",
|
"repo": "rust-analyzer",
|
||||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
"rev": "633cff25206d5108043d87617a43c9d04aa42c88",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
"owner": "nix-systems",
|
"owner": "rust-lang",
|
||||||
"repo": "default",
|
"ref": "nightly",
|
||||||
|
"repo": "rust-analyzer",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"systems_2": {
|
"rust-overlay": {
|
||||||
|
"inputs": {
|
||||||
|
"nixpkgs": [
|
||||||
|
"nixpkgs"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1764902447,
|
||||||
|
"narHash": "sha256-wNqkDBj+tjK619sTHPEA7uhjr7DHHEY8OsFou31dxy0=",
|
||||||
|
"owner": "oxalica",
|
||||||
|
"repo": "rust-overlay",
|
||||||
|
"rev": "d914a744a83098eeb28125d2848ad383b209223f",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "oxalica",
|
||||||
|
"repo": "rust-overlay",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"systems": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1681028828,
|
"lastModified": 1681028828,
|
||||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||||
|
|
|
||||||
129
flake.nix
129
flake.nix
|
|
@ -1,27 +1,118 @@
|
||||||
{
|
{
|
||||||
|
# Build Pyo3 package
|
||||||
inputs = {
|
inputs = {
|
||||||
nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable";
|
nixpkgs.url = "github:nixos/nixpkgs/nixpkgs-unstable";
|
||||||
systems.url = "github:nix-systems/default";
|
flake-utils.url = "github:numtide/flake-utils";
|
||||||
fix-py.url = "github:GuillaumeDesforges/fix-python";
|
rust-overlay = {
|
||||||
|
url = "github:oxalica/rust-overlay";
|
||||||
|
inputs.nixpkgs.follows = "nixpkgs";
|
||||||
|
};
|
||||||
|
crane.url = "github:ipetkov/crane";
|
||||||
|
fenix = {
|
||||||
|
url = "github:nix-community/fenix";
|
||||||
|
inputs.nixpkgs.follows = "nixpkgs";
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
outputs =
|
outputs =
|
||||||
{ nixpkgs, fix-py, ... }:
|
inputs:
|
||||||
let
|
inputs.flake-utils.lib.eachDefaultSystem (
|
||||||
eachSystem =
|
system:
|
||||||
f:
|
let
|
||||||
nixpkgs.lib.genAttrs nixpkgs.lib.systems.flakeExposed (system: f nixpkgs.legacyPackages.${system});
|
pkgs = import inputs.nixpkgs {
|
||||||
in
|
inherit system;
|
||||||
{
|
overlays = [ inputs.rust-overlay.overlays.default ];
|
||||||
devShells = eachSystem (pkgs: {
|
};
|
||||||
default = pkgs.mkShell {
|
lib = pkgs.lib;
|
||||||
buildInputs = [
|
|
||||||
pkgs.cargo
|
python_version = pkgs.python313;
|
||||||
pkgs.rustc
|
wheel_tail = "cp313-cp313-linux_x86_64"; # Change if python_version changes
|
||||||
pkgs.stdenv.cc.cc.lib
|
|
||||||
fix-py.packages.${pkgs.system}.default
|
# Get a custom rust toolchain
|
||||||
|
craneLib = (inputs.crane.mkLib pkgs).overrideToolchain (
|
||||||
|
p: inputs.fenix.packages.${system}.complete.toolchain
|
||||||
|
);
|
||||||
|
|
||||||
|
project_name = (craneLib.crateNameFromCargoToml { cargoToml = ./game/Cargo.toml; }).pname;
|
||||||
|
project_version = (craneLib.crateNameFromCargoToml { cargoToml = ./game/Cargo.toml; }).version;
|
||||||
|
|
||||||
|
crate_cfg = {
|
||||||
|
src =
|
||||||
|
let
|
||||||
|
fs = lib.fileset;
|
||||||
|
in
|
||||||
|
fs.toSource {
|
||||||
|
root = ./game;
|
||||||
|
fileset = fs.unions [
|
||||||
|
./game/src/lib.rs
|
||||||
|
./game/Cargo.lock
|
||||||
|
./game/Cargo.toml
|
||||||
|
./game/pyproject.toml
|
||||||
|
];
|
||||||
|
};
|
||||||
|
nativeBuildInputs = [ python_version ];
|
||||||
|
# doCheck = true;
|
||||||
|
# buildInputs = [];
|
||||||
|
};
|
||||||
|
|
||||||
|
crate_artifacts = craneLib.buildDepsOnly (
|
||||||
|
crate_cfg
|
||||||
|
// {
|
||||||
|
pname = "${project_name}-artifacts";
|
||||||
|
version = project_version;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
# Build the library, then re-use the target dir to generate the wheel file with maturin
|
||||||
|
crate_wheel =
|
||||||
|
(craneLib.buildPackage (
|
||||||
|
crate_cfg
|
||||||
|
// {
|
||||||
|
pname = project_name;
|
||||||
|
version = project_version;
|
||||||
|
cargoArtifacts = crate_artifacts;
|
||||||
|
}
|
||||||
|
)).overrideAttrs
|
||||||
|
(old: {
|
||||||
|
nativeBuildInputs = old.nativeBuildInputs ++ [ pkgs.maturin ];
|
||||||
|
buildPhase = old.buildPhase + ''
|
||||||
|
maturin build --offline --target-dir ./target
|
||||||
|
'';
|
||||||
|
installPhase = old.installPhase + ''
|
||||||
|
cp target/wheels/${project_name}-${project_version}-${wheel_tail}.whl $out/
|
||||||
|
'';
|
||||||
|
});
|
||||||
|
in
|
||||||
|
rec {
|
||||||
|
packages = {
|
||||||
|
default = crate_wheel; # The wheel itself
|
||||||
|
|
||||||
|
# A python version with the library installed
|
||||||
|
pythonpkg = python_version.withPackages (ps: [
|
||||||
|
(lib.python_package ps)
|
||||||
|
]);
|
||||||
|
};
|
||||||
|
devShells.default = craneLib.devShell {
|
||||||
|
packages = [
|
||||||
|
(pkgs.python3.withPackages (ppkgs: [
|
||||||
|
ppkgs.torch
|
||||||
|
(lib.python_package ppkgs)
|
||||||
|
]))
|
||||||
];
|
];
|
||||||
};
|
};
|
||||||
});
|
lib = {
|
||||||
};
|
# To use in other builds with the "withPackages" call
|
||||||
|
python_package =
|
||||||
|
ps:
|
||||||
|
ps.buildPythonPackage rec {
|
||||||
|
pname = project_name;
|
||||||
|
format = "wheel";
|
||||||
|
version = project_version;
|
||||||
|
src = "${crate_wheel}/${project_name}-${project_version}-${wheel_tail}.whl";
|
||||||
|
doCheck = false;
|
||||||
|
pythonImportsCheck = [ project_name ];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
}
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
|
#[allow(dead_code)]
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
|
|
||||||
#[pymodule]
|
#[pymodule]
|
||||||
mod game {
|
mod game {
|
||||||
use pyo3::{exceptions::PyIndexError, prelude::*, types::PySequence};
|
use pyo3::exceptions::PyIndexError;
|
||||||
|
use pyo3::prelude::*;
|
||||||
use std::{
|
use std::{
|
||||||
collections::HashSet,
|
collections::HashSet,
|
||||||
fmt::{Display, Formatter},
|
fmt::{Display, Formatter},
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue