may the flake do it's job
This commit is contained in:
parent
7e1aaf7eee
commit
7b506195cc
6 changed files with 248 additions and 70 deletions
1
.envrc
1
.envrc
|
|
@ -1,2 +1 @@
|
|||
use flake
|
||||
source .venv/bin/activate
|
||||
|
|
|
|||
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -9,3 +9,6 @@ wheels/
|
|||
# Virtual environments
|
||||
.venv
|
||||
.direnv
|
||||
|
||||
result
|
||||
target
|
||||
|
|
|
|||
59
blokus.py
59
blokus.py
|
|
@ -2,9 +2,14 @@
|
|||
import random
|
||||
import game
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
###############
|
||||
# Utilities #
|
||||
###############
|
||||
|
||||
BOARD_SIZE = 14
|
||||
|
||||
|
||||
tiles = game.game_tiles()
|
||||
|
||||
|
||||
|
|
@ -32,12 +37,62 @@ def print_game_state(game_state: tuple[game.Board, list[int], list[int]]):
|
|||
print(f"Player 2 tiles left: {p2tiles}")
|
||||
|
||||
|
||||
###################
|
||||
# Game state init #
|
||||
###################
|
||||
|
||||
game_state = (
|
||||
game.Board(),
|
||||
[i for i in range(21)],
|
||||
[i for i in range(21)],
|
||||
)
|
||||
|
||||
###################
|
||||
# RL Utils #
|
||||
###################
|
||||
|
||||
|
||||
class Saver:
|
||||
def __init__(self, results_path, experiment_seed):
|
||||
self.stats_file = {"train": {}, "test": {}}
|
||||
self.exp_seed = experiment_seed
|
||||
self.rpath = results_path
|
||||
|
||||
def get_new_episode(self, mode, episode_no):
|
||||
if mode == "train":
|
||||
self.stats_file[mode][episode_no] = {
|
||||
"loss": [],
|
||||
"actions": [],
|
||||
"errors": [],
|
||||
"errors_noiseless": [],
|
||||
"done_threshold": 0,
|
||||
"bond_distance": 0,
|
||||
"nfev": [],
|
||||
"opt_ang": [],
|
||||
"time": [],
|
||||
"save_circ": [],
|
||||
"reward": [],
|
||||
}
|
||||
elif mode == "test":
|
||||
self.stats_file[mode][episode_no] = {
|
||||
"actions": [],
|
||||
"errors": [],
|
||||
"errors_noiseless": [],
|
||||
"done_threshold": 0,
|
||||
"bond_distance": 0,
|
||||
"nfev": [],
|
||||
"opt_ang": [],
|
||||
"time": [],
|
||||
}
|
||||
|
||||
def save_file(self):
|
||||
np.save(f"{self.rpath}/summary_{self.exp_seed}.npy", self.stats_file)
|
||||
|
||||
def validate_stats(self, episode, mode):
|
||||
assert len(self.stats_file[mode][episode]["actions"]) == len(
|
||||
self.stats_file[mode][episode]["errors"]
|
||||
)
|
||||
|
||||
|
||||
playing = True
|
||||
player = 1
|
||||
|
|
|
|||
122
flake.lock
generated
122
flake.lock
generated
|
|
@ -1,21 +1,38 @@
|
|||
{
|
||||
"nodes": {
|
||||
"fix-py": {
|
||||
"inputs": {
|
||||
"flake-utils": "flake-utils",
|
||||
"nixpkgs": "nixpkgs"
|
||||
},
|
||||
"crane": {
|
||||
"locked": {
|
||||
"lastModified": 1752329544,
|
||||
"narHash": "sha256-LtZyywexRDB5FFmEU6A1e5tMIb0MX8/xxpjCHonUzYE=",
|
||||
"owner": "GuillaumeDesforges",
|
||||
"repo": "fix-python",
|
||||
"rev": "248e2ea9620faee9b8a2ae10e12320de2a819fe9",
|
||||
"lastModified": 1764903584,
|
||||
"narHash": "sha256-RSkJtNtx0SEaQiYqsoFoRynwfZLo2OZ9z6rUq1DJR6g=",
|
||||
"owner": "ipetkov",
|
||||
"repo": "crane",
|
||||
"rev": "2b3a5a88d852575758e1eb6ac9ee677fcd633fc1",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "GuillaumeDesforges",
|
||||
"repo": "fix-python",
|
||||
"owner": "ipetkov",
|
||||
"repo": "crane",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"fenix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"rust-analyzer-src": "rust-analyzer-src"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1764915802,
|
||||
"narHash": "sha256-eHTucU43sRCpvvTt5eey9htcWipS7ZN3B7ts6MiXLxo=",
|
||||
"owner": "nix-community",
|
||||
"repo": "fenix",
|
||||
"rev": "a83a78fd3587d9f3388f0b459ad9c2bbd6d1b6d8",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-community",
|
||||
"repo": "fenix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
|
|
@ -24,43 +41,30 @@
|
|||
"systems": "systems"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1689068808,
|
||||
"narHash": "sha256-6ixXo3wt24N/melDWjq70UuHQLxGV8jZvooRanIHXw0=",
|
||||
"lastModified": 1731533236,
|
||||
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "919d646de7be200f3bf08cb76ae1f09402b6f9b4",
|
||||
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"id": "flake-utils",
|
||||
"type": "indirect"
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1682786779,
|
||||
"narHash": "sha256-m7QFzPS/CE8hbkbIVK4UStihAQMtczr0vSpOgETOM1g=",
|
||||
"lastModified": 1764886837,
|
||||
"narHash": "sha256-Sup1QVtp30cSa1FJLVkwTHEBV75XT+lWPvgHL0I7s1s=",
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "08e4dc3a907a6dfec8bb3bbf1540d8abbffea22b",
|
||||
"rev": "bd88d6c13ab85cc842b93d53f68d6d40412e5a18",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"id": "nixpkgs",
|
||||
"type": "indirect"
|
||||
}
|
||||
},
|
||||
"nixpkgs_2": {
|
||||
"locked": {
|
||||
"lastModified": 1764167966,
|
||||
"narHash": "sha256-nXv6xb7cq+XpjBYIjWEGTLCqQetxJu6zvVlrqHMsCOA=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "5c46f3bd98147c8d82366df95bbef2cab3a967ea",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"owner": "nixos",
|
||||
"ref": "nixpkgs-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
|
|
@ -68,27 +72,51 @@
|
|||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"fix-py": "fix-py",
|
||||
"nixpkgs": "nixpkgs_2",
|
||||
"systems": "systems_2"
|
||||
"crane": "crane",
|
||||
"fenix": "fenix",
|
||||
"flake-utils": "flake-utils",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"rust-overlay": "rust-overlay"
|
||||
}
|
||||
},
|
||||
"systems": {
|
||||
"rust-analyzer-src": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"lastModified": 1764778537,
|
||||
"narHash": "sha256-SNL+Fj1ZWiBqCrHJT1S9vMZujrWxCOmf3zkT66XSnhE=",
|
||||
"owner": "rust-lang",
|
||||
"repo": "rust-analyzer",
|
||||
"rev": "633cff25206d5108043d87617a43c9d04aa42c88",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"owner": "rust-lang",
|
||||
"ref": "nightly",
|
||||
"repo": "rust-analyzer",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"systems_2": {
|
||||
"rust-overlay": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1764902447,
|
||||
"narHash": "sha256-wNqkDBj+tjK619sTHPEA7uhjr7DHHEY8OsFou31dxy0=",
|
||||
"owner": "oxalica",
|
||||
"repo": "rust-overlay",
|
||||
"rev": "d914a744a83098eeb28125d2848ad383b209223f",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "oxalica",
|
||||
"repo": "rust-overlay",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"systems": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
|
|
|
|||
129
flake.nix
129
flake.nix
|
|
@ -1,27 +1,118 @@
|
|||
{
|
||||
# Build Pyo3 package
|
||||
inputs = {
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable";
|
||||
systems.url = "github:nix-systems/default";
|
||||
fix-py.url = "github:GuillaumeDesforges/fix-python";
|
||||
nixpkgs.url = "github:nixos/nixpkgs/nixpkgs-unstable";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
rust-overlay = {
|
||||
url = "github:oxalica/rust-overlay";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
crane.url = "github:ipetkov/crane";
|
||||
fenix = {
|
||||
url = "github:nix-community/fenix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
};
|
||||
|
||||
outputs =
|
||||
{ nixpkgs, fix-py, ... }:
|
||||
let
|
||||
eachSystem =
|
||||
f:
|
||||
nixpkgs.lib.genAttrs nixpkgs.lib.systems.flakeExposed (system: f nixpkgs.legacyPackages.${system});
|
||||
in
|
||||
{
|
||||
devShells = eachSystem (pkgs: {
|
||||
default = pkgs.mkShell {
|
||||
buildInputs = [
|
||||
pkgs.cargo
|
||||
pkgs.rustc
|
||||
pkgs.stdenv.cc.cc.lib
|
||||
fix-py.packages.${pkgs.system}.default
|
||||
inputs:
|
||||
inputs.flake-utils.lib.eachDefaultSystem (
|
||||
system:
|
||||
let
|
||||
pkgs = import inputs.nixpkgs {
|
||||
inherit system;
|
||||
overlays = [ inputs.rust-overlay.overlays.default ];
|
||||
};
|
||||
lib = pkgs.lib;
|
||||
|
||||
python_version = pkgs.python313;
|
||||
wheel_tail = "cp313-cp313-linux_x86_64"; # Change if python_version changes
|
||||
|
||||
# Get a custom rust toolchain
|
||||
craneLib = (inputs.crane.mkLib pkgs).overrideToolchain (
|
||||
p: inputs.fenix.packages.${system}.complete.toolchain
|
||||
);
|
||||
|
||||
project_name = (craneLib.crateNameFromCargoToml { cargoToml = ./game/Cargo.toml; }).pname;
|
||||
project_version = (craneLib.crateNameFromCargoToml { cargoToml = ./game/Cargo.toml; }).version;
|
||||
|
||||
crate_cfg = {
|
||||
src =
|
||||
let
|
||||
fs = lib.fileset;
|
||||
in
|
||||
fs.toSource {
|
||||
root = ./game;
|
||||
fileset = fs.unions [
|
||||
./game/src/lib.rs
|
||||
./game/Cargo.lock
|
||||
./game/Cargo.toml
|
||||
./game/pyproject.toml
|
||||
];
|
||||
};
|
||||
nativeBuildInputs = [ python_version ];
|
||||
# doCheck = true;
|
||||
# buildInputs = [];
|
||||
};
|
||||
|
||||
crate_artifacts = craneLib.buildDepsOnly (
|
||||
crate_cfg
|
||||
// {
|
||||
pname = "${project_name}-artifacts";
|
||||
version = project_version;
|
||||
}
|
||||
);
|
||||
|
||||
# Build the library, then re-use the target dir to generate the wheel file with maturin
|
||||
crate_wheel =
|
||||
(craneLib.buildPackage (
|
||||
crate_cfg
|
||||
// {
|
||||
pname = project_name;
|
||||
version = project_version;
|
||||
cargoArtifacts = crate_artifacts;
|
||||
}
|
||||
)).overrideAttrs
|
||||
(old: {
|
||||
nativeBuildInputs = old.nativeBuildInputs ++ [ pkgs.maturin ];
|
||||
buildPhase = old.buildPhase + ''
|
||||
maturin build --offline --target-dir ./target
|
||||
'';
|
||||
installPhase = old.installPhase + ''
|
||||
cp target/wheels/${project_name}-${project_version}-${wheel_tail}.whl $out/
|
||||
'';
|
||||
});
|
||||
in
|
||||
rec {
|
||||
packages = {
|
||||
default = crate_wheel; # The wheel itself
|
||||
|
||||
# A python version with the library installed
|
||||
pythonpkg = python_version.withPackages (ps: [
|
||||
(lib.python_package ps)
|
||||
]);
|
||||
};
|
||||
devShells.default = craneLib.devShell {
|
||||
packages = [
|
||||
(pkgs.python3.withPackages (ppkgs: [
|
||||
ppkgs.torch
|
||||
(lib.python_package ppkgs)
|
||||
]))
|
||||
];
|
||||
};
|
||||
});
|
||||
};
|
||||
lib = {
|
||||
# To use in other builds with the "withPackages" call
|
||||
python_package =
|
||||
ps:
|
||||
ps.buildPythonPackage rec {
|
||||
pname = project_name;
|
||||
format = "wheel";
|
||||
version = project_version;
|
||||
src = "${crate_wheel}/${project_name}-${project_version}-${wheel_tail}.whl";
|
||||
doCheck = false;
|
||||
pythonImportsCheck = [ project_name ];
|
||||
};
|
||||
};
|
||||
}
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
#[allow(dead_code)]
|
||||
use pyo3::prelude::*;
|
||||
|
||||
#[pymodule]
|
||||
mod game {
|
||||
use pyo3::{exceptions::PyIndexError, prelude::*, types::PySequence};
|
||||
use pyo3::exceptions::PyIndexError;
|
||||
use pyo3::prelude::*;
|
||||
use std::{
|
||||
collections::HashSet,
|
||||
fmt::{Display, Formatter},
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue