From 3386de9894e9cc7b8ecf1e4be988571cee29c911 Mon Sep 17 00:00:00 2001 From: Noa Aarts Date: Sun, 14 Dec 2025 10:51:02 +0100 Subject: [PATCH] streams --- .gitignore | 2 ++ main.py | 21 ++++++++++++ pyproject.toml | 2 +- qas-flow/__init__.py | 0 qas_flow/__init__.py | 4 +++ qas_flow/funcs.py | 76 ++++++++++++++++++++++++++++++++++++++++++++ qas_flow/stream.py | 39 +++++++++++++++++++++++ 7 files changed, 143 insertions(+), 1 deletion(-) create mode 100755 main.py delete mode 100644 qas-flow/__init__.py create mode 100644 qas_flow/__init__.py create mode 100644 qas_flow/funcs.py create mode 100644 qas_flow/stream.py diff --git a/.gitignore b/.gitignore index 92b2793..113ed19 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ .direnv + +__pycache__ diff --git a/main.py b/main.py new file mode 100755 index 0000000..519ee54 --- /dev/null +++ b/main.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python + + +from math import log2 +from qas_flow import Stream + + +def gen(): + n = 0 + while True: + yield n + n += 1 + + +def mapfunc(ix): + i, x = ix + return str(x) * int(log2(i + 1)) + + +strm = Stream(gen()) +print(strm.enumerate().map(mapfunc).skip(20).take(4).enumerate().collect()) diff --git a/pyproject.toml b/pyproject.toml index ec63fd9..361dc3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools"] build-backend = "setuptools.build_meta" [project] -name = "qas-flow" +name = "qas_flow" version = "0.0.1" dependencies = [ "numpy", diff --git a/qas-flow/__init__.py b/qas-flow/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/qas_flow/__init__.py b/qas_flow/__init__.py new file mode 100644 index 0000000..6cf5144 --- /dev/null +++ b/qas_flow/__init__.py @@ -0,0 +1,4 @@ +from .stream import Stream +from .funcs import map, filter, take, skip, batch, enumerate, collect + +__all__ = ["Stream", "map", "filter", "take", "skip", "batch", "enumerate", "collect"] diff --git a/qas_flow/funcs.py b/qas_flow/funcs.py new file mode 100644 index 0000000..63a645f --- /dev/null +++ b/qas_flow/funcs.py @@ -0,0 +1,76 @@ +from typing import Callable +from .stream import Stream, T, U + + +@Stream.extension() +def map(stream: Stream[T], fn: Callable[[T], U]) -> Stream[U]: + def gen(): + for x in stream: + yield fn(x) + + return Stream(gen()) + + +@Stream.extension() +def filter(stream: Stream[T], pred: Callable[[T], bool]) -> Stream[T]: + def gen(): + for x in stream: + if pred(x): + yield x + + return Stream(gen()) + + +@Stream.extension() +def take(stream: Stream[T], n: int) -> Stream[T]: + def gen(): + c = 0 + for x in stream: + if c < n: + c += 1 + yield x + else: + return + + return Stream(gen()) + + +@Stream.extension() +def skip(stream: Stream[T], n: int) -> Stream[T]: + def gen(): + c = 0 + for x in stream: + c += 1 + if c > n: + yield x + + return Stream(gen()) + + +@Stream.extension() +def batch(stream: Stream[T], n: int) -> Stream[list[T]]: + def gen(): + ls: list[T] = [] + for x in stream: + ls.append(x) + if len(ls) == n: + yield ls + ls = [] + + return Stream(gen()) + + +@Stream.extension() +def enumerate(stream: Stream[T]) -> Stream[tuple[int, T]]: + def gen(): + idx = 0 + for x in stream: + yield (idx, x) + idx += 1 + + return Stream(gen()) + + +@Stream.extension() +def collect(stream: Stream[T]) -> list[T]: + return [v for v in stream] diff --git a/qas_flow/stream.py b/qas_flow/stream.py new file mode 100644 index 0000000..c725897 --- /dev/null +++ b/qas_flow/stream.py @@ -0,0 +1,39 @@ +from collections.abc import Iterator +from typing import Any, Callable, Generic, TypeVar, final + + +T = TypeVar("T") +U = TypeVar("U") + +Op = Callable[["Stream[T]"], "Stream[U]"] + + +@final +class Stream(Generic[T]): + _extensions: dict[str, Callable[..., Any]] = {} + + def __init__(self, it: Iterator[T]) -> None: + self._it = it + + def __iter__(self) -> Iterator[T]: + return self._it + + @classmethod + def extension(cls, name: str | None = None): + """Register a function as Stream.(...). First arg will be the stream.""" + + def deco(fn: Callable[..., Any]): + cls._extensions[name or fn.__name__] = fn + return fn + + return deco + + def __getattr__(self, attr: str): + fn = self._extensions.get(attr) + if fn is None: + raise AttributeError(attr) + + def bound(*args, **kwargs): + return fn(self, *args, **kwargs) + + return bound