This commit is contained in:
Noa Aarts 2025-12-14 10:51:02 +01:00
parent e9cf83145f
commit 3386de9894
Signed by: noa
GPG key ID: 1850932741EFF672
7 changed files with 143 additions and 1 deletions

2
.gitignore vendored
View file

@ -1 +1,3 @@
.direnv
__pycache__

21
main.py Executable file
View file

@ -0,0 +1,21 @@
#!/usr/bin/env python
from math import log2
from qas_flow import Stream
def gen():
n = 0
while True:
yield n
n += 1
def mapfunc(ix):
i, x = ix
return str(x) * int(log2(i + 1))
strm = Stream(gen())
print(strm.enumerate().map(mapfunc).skip(20).take(4).enumerate().collect())

View file

@ -3,7 +3,7 @@ requires = ["setuptools"]
build-backend = "setuptools.build_meta"
[project]
name = "qas-flow"
name = "qas_flow"
version = "0.0.1"
dependencies = [
"numpy",

View file

4
qas_flow/__init__.py Normal file
View file

@ -0,0 +1,4 @@
from .stream import Stream
from .funcs import map, filter, take, skip, batch, enumerate, collect
__all__ = ["Stream", "map", "filter", "take", "skip", "batch", "enumerate", "collect"]

76
qas_flow/funcs.py Normal file
View file

@ -0,0 +1,76 @@
from typing import Callable
from .stream import Stream, T, U
@Stream.extension()
def map(stream: Stream[T], fn: Callable[[T], U]) -> Stream[U]:
def gen():
for x in stream:
yield fn(x)
return Stream(gen())
@Stream.extension()
def filter(stream: Stream[T], pred: Callable[[T], bool]) -> Stream[T]:
def gen():
for x in stream:
if pred(x):
yield x
return Stream(gen())
@Stream.extension()
def take(stream: Stream[T], n: int) -> Stream[T]:
def gen():
c = 0
for x in stream:
if c < n:
c += 1
yield x
else:
return
return Stream(gen())
@Stream.extension()
def skip(stream: Stream[T], n: int) -> Stream[T]:
def gen():
c = 0
for x in stream:
c += 1
if c > n:
yield x
return Stream(gen())
@Stream.extension()
def batch(stream: Stream[T], n: int) -> Stream[list[T]]:
def gen():
ls: list[T] = []
for x in stream:
ls.append(x)
if len(ls) == n:
yield ls
ls = []
return Stream(gen())
@Stream.extension()
def enumerate(stream: Stream[T]) -> Stream[tuple[int, T]]:
def gen():
idx = 0
for x in stream:
yield (idx, x)
idx += 1
return Stream(gen())
@Stream.extension()
def collect(stream: Stream[T]) -> list[T]:
return [v for v in stream]

39
qas_flow/stream.py Normal file
View file

@ -0,0 +1,39 @@
from collections.abc import Iterator
from typing import Any, Callable, Generic, TypeVar, final
T = TypeVar("T")
U = TypeVar("U")
Op = Callable[["Stream[T]"], "Stream[U]"]
@final
class Stream(Generic[T]):
_extensions: dict[str, Callable[..., Any]] = {}
def __init__(self, it: Iterator[T]) -> None:
self._it = it
def __iter__(self) -> Iterator[T]:
return self._it
@classmethod
def extension(cls, name: str | None = None):
"""Register a function as Stream.<name>(...). First arg will be the stream."""
def deco(fn: Callable[..., Any]):
cls._extensions[name or fn.__name__] = fn
return fn
return deco
def __getattr__(self, attr: str):
fn = self._extensions.get(attr)
if fn is None:
raise AttributeError(attr)
def bound(*args, **kwargs):
return fn(self, *args, **kwargs)
return bound