QAS-Flow/qas_flow/funcs.py
2025-12-14 10:51:02 +01:00

76 lines
1.5 KiB
Python

from typing import Callable
from .stream import Stream, T, U
@Stream.extension()
def map(stream: Stream[T], fn: Callable[[T], U]) -> Stream[U]:
    """Apply ``fn`` to every item of ``stream``.

    Returns a new ``Stream`` wrapping a generator that yields ``fn(item)`` for
    each item of ``stream``, in iteration order. The generator does not read
    from ``stream`` until it is itself iterated.

    Note: this function shadows the builtin ``map`` inside this module.
    """

    def mapped():
        # The generator expression lives inside the generator function so that
        # ``stream`` is only touched once iteration actually starts.
        yield from (fn(item) for item in stream)

    return Stream(mapped())
@Stream.extension()
def filter(stream: Stream[T], pred: Callable[[T], bool]) -> Stream[T]:
    """Keep only the items of ``stream`` for which ``pred`` is truthy.

    Returns a new ``Stream`` wrapping a generator that yields, in iteration
    order, every item whose ``pred(item)`` result is truthy. The generator
    does not read from ``stream`` until it is itself iterated.

    Note: this function shadows the builtin ``filter`` inside this module.
    """

    def kept():
        # Wrapped in a generator function so ``stream`` is only touched once
        # iteration actually starts.
        yield from (item for item in stream if pred(item))

    return Stream(kept())
@Stream.extension()
def take(stream: Stream[T], n: int) -> Stream[T]:
    """Limit ``stream`` to at most its first ``n`` items.

    Returns a new ``Stream`` wrapping a generator that yields the first ``n``
    items of ``stream`` (fewer if ``stream`` ends earlier).

    The upstream is advanced only as far as needed: once ``n`` items have been
    yielded the generator finishes without requesting another item, so the
    element following the taken prefix is left unconsumed. For ``n <= 0`` the
    result is empty and ``stream`` is not iterated at all.
    """

    def gen():
        if n <= 0:
            # Nothing to take; do not pull (and thereby lose) an upstream item.
            return
        remaining = n
        for x in stream:
            yield x
            remaining -= 1
            # Stop right after the n-th item rather than on the next pull, so
            # the (n+1)-th upstream element is never fetched and discarded.
            if remaining <= 0:
                return

    return Stream(gen())
@Stream.extension()
def skip(stream: Stream[T], n: int) -> Stream[T]:
    """Drop the first ``n`` items of ``stream`` and pass the rest through.

    Returns a new ``Stream`` wrapping a generator that yields every item of
    ``stream`` whose 1-based position is greater than ``n``. When ``n`` is
    zero or negative every item is yielded.
    """

    def remainder():
        seen = 0
        for item in stream:
            seen += 1
            if seen <= n:
                # Still inside the prefix that must be discarded.
                continue
            yield item

    return Stream(remainder())
@Stream.extension()
def batch(
    stream: Stream[T], n: int, *, drop_last: bool = True
) -> Stream[list[T]]:
    """Group consecutive items of ``stream`` into lists of ``n`` items.

    Args:
        stream: Source of items.
        n: Number of items per batch; must be at least 1.
        drop_last: When true (the default, matching the previous behaviour),
            a trailing group holding fewer than ``n`` items is discarded.
            When false, that trailing group is yielded as a final, shorter
            list.

    Returns:
        A new ``Stream`` wrapping a generator of lists. Each yielded list is
        a fresh object; it is not reused for the following batch.

    Raises:
        ValueError: If ``n`` is smaller than 1. Previously such a value made
            the generator buffer the entire stream in memory without ever
            yielding a batch.
    """
    if n < 1:
        raise ValueError(f"batch size must be at least 1, got {n!r}")

    def gen():
        ls: list[T] = []
        for x in stream:
            ls.append(x)
            if len(ls) == n:
                yield ls
                # Rebind instead of clearing so the list already handed to the
                # consumer is not mutated behind its back.
                ls = []
        if ls and not drop_last:
            # Leftover items that did not fill a complete batch.
            yield ls

    return Stream(gen())
@Stream.extension()
def enumerate(stream: Stream[T]) -> Stream[tuple[int, T]]:
    """Pair every item of ``stream`` with its zero-based position.

    Returns a new ``Stream`` wrapping a generator that yields
    ``(index, item)`` tuples, with ``index`` starting at 0 and increasing by
    one per item.

    Note: this function shadows the builtin ``enumerate`` inside this module,
    which is why the counter is maintained by hand here.
    """

    def indexed():
        position = -1
        for item in stream:
            position += 1
            yield position, item

    return Stream(indexed())
@Stream.extension()
def collect(stream: Stream[T]) -> list[T]:
    """Consume ``stream`` and return all of its items as a list.

    The items appear in iteration order. Iterating ``stream`` to the end is
    what drives any generators it wraps.
    """
    # ``list(...)`` is the idiomatic way to materialise an iterable; the
    # previous identity comprehension did the same work item by item.
    return list(stream)