76 lines
1.5 KiB
Python
76 lines
1.5 KiB
Python
from typing import Callable
|
|
from .stream import Stream, T, U
|
|
|
|
|
|
@Stream.extension()
def map(stream: Stream[T], fn: Callable[[T], U]) -> Stream[U]:
    """Return a lazy stream whose elements are ``fn(item)`` for each item.

    Nothing is computed up front: *stream* is only iterated, and *fn* only
    called, as the returned stream is consumed.

    Args:
        stream: Source stream.
        fn: Transformation applied to every element.

    Returns:
        A new ``Stream`` of transformed elements.
    """
    def _mapped():
        # The generator expression is created inside this generator body,
        # so iter(stream) still happens on first next(), not at call time.
        yield from (fn(item) for item in stream)

    return Stream(_mapped())
|
|
|
|
|
|
@Stream.extension()
def filter(stream: Stream[T], pred: Callable[[T], bool]) -> Stream[T]:
    """Return a lazy stream holding only the items for which *pred* is truthy.

    Args:
        stream: Source stream.
        pred: Predicate evaluated once per element, in order.

    Returns:
        A new ``Stream`` with the rejected elements removed.
    """
    def _kept():
        for item in stream:
            if not pred(item):
                continue
            yield item

    return Stream(_kept())
|
|
|
|
|
|
@Stream.extension()
def take(stream: Stream[T], n: int) -> Stream[T]:
    """Return a lazy stream of at most the first *n* elements of *stream*.

    The generator stops immediately after producing the n-th element, so it
    never pulls an extra element from upstream. This matters for shared
    iterators (where that element would be lost) and for slow or infinite
    sources. When *n* is zero or negative, *stream* is never iterated.

    Args:
        stream: Source stream.
        n: Maximum number of elements to yield.

    Returns:
        A new ``Stream`` limited to *n* elements.
    """
    def gen():
        if n <= 0:
            return
        c = 0
        for x in stream:
            yield x
            c += 1
            # Check after yielding so we return before requesting the next
            # upstream element.
            if c >= n:
                return

    return Stream(gen())
|
|
|
|
|
|
@Stream.extension()
def skip(stream: Stream[T], n: int) -> Stream[T]:
    """Return a lazy stream that drops the first *n* elements of *stream*.

    Every element is still pulled from *stream*. The first *n* are
    discarded and the rest are passed through unchanged. A non-positive
    *n* passes everything through.

    Args:
        stream: Source stream.
        n: Number of leading elements to discard.

    Returns:
        A new ``Stream`` without the skipped prefix.
    """
    def _after_prefix():
        position = 0
        for item in stream:
            position += 1
            if position > n:
                yield item

    return Stream(_after_prefix())
|
|
|
|
|
|
@Stream.extension()
def batch(stream: Stream[T], n: int) -> Stream[list[T]]:
    """Group consecutive elements of *stream* into lists of length *n*.

    Every batch except possibly the last contains exactly *n* elements. A
    trailing partial batch is emitted rather than silently dropped, matching
    ``itertools.batched``. Each yielded list is a fresh object, so callers
    may keep references to earlier batches.

    Args:
        stream: Source stream.
        n: Batch size; must be at least 1.

    Returns:
        A lazy ``Stream`` of lists.

    Raises:
        ValueError: If *n* is less than 1. This is raised when ``batch`` is
            called, not when the result is iterated.
    """
    # Without this check, n <= 0 would never fill a batch and would either
    # swallow the whole stream or emit it as one giant "remainder" batch.
    if n < 1:
        raise ValueError("batch size n must be at least 1")

    def gen():
        ls: list[T] = []
        for x in stream:
            ls.append(x)
            if len(ls) == n:
                yield ls
                ls = []
        # Flush the remainder so no input elements are lost.
        if ls:
            yield ls

    return Stream(gen())
|
|
|
|
|
|
@Stream.extension()
def enumerate(stream: Stream[T], start: int = 0) -> Stream[tuple[int, T]]:
    """Pair each element of *stream* with a running index.

    This mirrors the builtin ``enumerate``, including its *start* parameter.
    This definition rebinds the module-level name ``enumerate``, which is why
    the counter is kept by hand instead of calling the builtin.

    Args:
        stream: Source stream.
        start: Index assigned to the first element (default 0).

    Returns:
        A lazy ``Stream`` of ``(index, element)`` tuples.
    """
    def gen():
        idx = start
        for x in stream:
            yield (idx, x)
            idx += 1

    return Stream(gen())
|
|
|
|
|
|
@Stream.extension()
def collect(stream: Stream[T]) -> list[T]:
    """Consume *stream* eagerly and return its elements as a new list.

    Args:
        stream: Stream to drain; it is fully iterated by this call.

    Returns:
        A list with every element of *stream*, in order.
    """
    # A plain copy is list(...), not an identity comprehension (ruff PERF402).
    return list(stream)
|