Generator Basics
A generator function looks like a normal function but contains yield. Calling it returns a generator object - nothing executes until you iterate:
def count_up(start, stop):
    """Yield the integers from start to stop, inclusive, one at a time.

    Calling count_up() runs none of this body; each next() call runs it
    up to the next yield, which suspends the function until the next call.
    """
    n = start
    while n <= stop:
        yield n  # pause here, hand n to the caller
        n += 1   # resume here on the next next() call


gen = count_up(1, 3)  # creates the generator object, runs nothing
print(next(gen))  # 1 - runs until the first yield
print(next(gen))  # 2
print(next(gen))  # 3
try:
    print(next(gen))  # loop condition fails, the generator finishes...
except StopIteration:
    # ...so next() raises StopIteration. Catching it here keeps the rest of
    # the script running; left uncaught it would abort everything below.
    print('StopIteration')
# for loops call next() automatically and stop quietly on StopIteration
for n in count_up(1, 5):
    print(n)  # 1 2 3 4 5
1 2 3 StopIteration
Generators have state. Each generator object has its own position and local variables:
def fibonacci():
    """Generate the Fibonacci numbers 0, 1, 1, 2, 3, ... without end.

    Only the two most recent values are kept between steps.
    """
    previous, current = 0, 1
    while True:
        yield previous
        previous, current = current, previous + current


fib = fibonacci()
first_ten = [next(fib) for _ in range(10)]
print(first_ten)
# Every call to fibonacci() creates a generator with its own private state
fib1 = fibonacci()
fib2 = fibonacci()
for _ in range(3):  # advance fib1 three steps (0, 1, 1)
    next(fib1)
print(next(fib1))  # 2
print(next(fib2))  # 0 - fib2 is unaffected by fib1
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34] 2 0
A generator function can have a return statement. When the generator finishes, the returned value is attached to the StopIteration exception as its value attribute:
def first_match(iterable, predicate):
    """Yield the first item for which predicate(item) is true, then stop.

    The same item is also the generator's return value, so it shows up as
    StopIteration.value on the following next() call. If nothing matches,
    the generator finishes without yielding and returns None.
    """
    for candidate in iterable:
        if not predicate(candidate):
            continue
        yield candidate
        return candidate  # finish right after the first match


gen = first_match([1, 4, 7, 3, 9], lambda x: x > 5)
try:
    print(next(gen))  # 7
    print(next(gen))  # raises StopIteration whose .value is 7
except StopIteration as e:
    print(f'Generator returned: {e.value}')  # 7
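A list comprehension [x*2 for x in range(10_000_000)] needs about 80 MB just for the list's internal array of 10 million 8-byte pointers. On top of that come the int objects it points to (roughly 28 bytes each on 64-bit CPython), so the total is several hundred MB. The equivalent generator expression (x*2 for x in range(10_000_000)) is a small fixed-size object - on the order of 100-200 bytes, depending on the Python version - regardless of the range length, because it stores only its paused execution state, not the values.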
Generator Expressions
Generator expressions are the lazy counterpart to list comprehensions. Use parentheses instead of square brackets:
import sys
# List comprehension - every value is computed and stored immediately
squares_list = [i ** 2 for i in range(1000)]
print(sys.getsizeof(squares_list))  # ~8856 bytes on 64-bit CPython
# Generator expression - values are produced on demand; fixed-size object
squares_gen = (i ** 2 for i in range(1000))
print(sys.getsizeof(squares_gen))  # ~100-200 bytes, depending on version
# A generator expression that is a call's only argument needs no extra parens
total = sum(i ** 2 for i in range(1000))
largest = max(len(word) for word in ['hi', 'hello', 'hey'])
any_neg = any(v < 0 for v in [1, 2, -3, 4])
all_pos = all(v > 0 for v in [1, 2, 3])
# Filter and transform in a single lazy pass
data = [1, -2, 3, -4, 5]
pos_doubled = list(v * 2 for v in data if v > 0)
print(pos_doubled)  # [2, 6, 10]
Generator expressions are single-use. Once exhausted, they cannot be restarted:
gen = (x**2 for x in range(5))
print(list(gen))  # [0, 1, 4, 9, 16]
print(list(gen))  # [] - already exhausted, nothing left to produce
# To iterate more than once, keep the values in a list,
# or wrap the expression in a function that builds a fresh generator


def squares(n):
    """Return a brand-new generator over the squares of 0 .. n-1."""
    return (k ** 2 for k in range(n))


print(list(squares(5)))  # [0, 1, 4, 9, 16]
print(list(squares(5)))  # [0, 1, 4, 9, 16] - a new generator on each call
yield from
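yield from iterable delegates to another iterable, yielding each of its values in turn. It is more than a shorthand for a loop: when the delegate is a generator, send(), throw() and close() are forwarded to it transparently, and the yield from expression evaluates to the sub-generator's return value: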
def chain(*iterables):
    """Yield every item of each argument iterable, left to right."""
    for source in iterables:
        # Delegate to source; for plain iteration this behaves like
        # `for item in source: yield item`.
        yield from source


result = list(chain([1, 2], [3, 4], [5]))
print(result)  # [1, 2, 3, 4, 5]
# Flatten nested lists
def flatten(nested):
    """Yield the leaves of an arbitrarily nested list, depth-first.

    Only list instances are descended into; tuples, strings and any other
    values are yielded as single items.
    """
    for element in nested:
        if not isinstance(element, list):
            yield element
            continue
        yield from flatten(element)  # recurse into the sub-list


nested = [1, [2, 3, [4, 5]], 6, [7, [8, 9]]]
print(list(flatten(nested)))  # [1, 2, 3, 4, 5, 6, 7, 8, 9]
The return value of the sub-generator becomes the result of the yield from expression in the delegating generator:
def sub_gen():
    """Yield 1 and 2, then finish with the return value 'sub done'."""
    yield 1
    yield 2
    return 'sub done'  # becomes the value of `yield from sub_gen()`


def delegating_gen():
    """Re-yield everything from sub_gen(), report its return value, yield 3."""
    outcome = yield from sub_gen()
    print(f'sub returned: {outcome}')
    yield 3


gen = delegating_gen()
print(next(gen))  # 1
print(next(gen))  # 2 (values come straight from sub_gen)
print(next(gen))  # prints "sub returned: sub done", then yields 3 -> 3
1 2 sub returned: sub done 3
send(), throw(), close()
Generators support two-way communication. send(value) resumes the generator AND provides a value that becomes the result of the yield expression:
def accumulator():
    """Running-sum coroutine.

    Yields the current total (0 at first). Every non-None value sent in is
    added to the total, and the new total is yielded back. Resuming with
    None (next() or send(None)) ends the generator, whose return value is
    the final total.
    """
    total = 0
    received = yield total
    while received is not None:
        total += received
        received = yield total
    return total


gen = accumulator()
next(gen)  # prime: run to the first yield (which hands back 0)
gen.send(10)  # total = 10
gen.send(20)  # total = 30
result = gen.send(5)
print(result)  # 35
# A fresh generator must first be advanced with next() or send(None);
# calling send() with anything else before that raises TypeError
# Must call next() or send(None) first to prime the generator
# Calling send(non_None) before priming raises TypeError
def logged_gen():
    """Print every value sent in; keep running after an injected ValueError.

    - send(value): prints 'received: <value>'.
    - throw(ValueError(...)): prints 'ValueError injected: ...' and keeps
      looping, so the throw() call returns normally (with the next yielded
      value, None).
    - close(): prints 'generator closing cleanly', then lets GeneratorExit
      finish the generator.
    """
    try:
        while True:
            try:
                value = yield
            except ValueError as e:
                # Handle the injected error *inside* the loop. If the handler
                # were outside the loop, the generator would finish here and
                # throw() would raise StopIteration in the caller.
                print(f'ValueError injected: {e}')
                continue
            print(f'received: {value}')
    except GeneratorExit:
        print('generator closing cleanly')
        raise  # let close() complete normally


gen = logged_gen()
next(gen)  # prime the generator
gen.send('hello')  # received: hello
gen.send('world')  # received: world
# Pass an exception instance; the throw(type, value) form is deprecated (3.12+)
gen.throw(ValueError('bad input'))  # ValueError injected: bad input
gen.close()  # generator closing cleanly (close() raises GeneratorExit inside)
received: hello received: world ValueError injected: bad input generator closing cleanly
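When a generator is used as a coroutine (driven with send()), it must be primed with next(gen) or send(None) before the first send() of a real value. A common pattern is a @coroutine decorator that calls next() automatically. For asynchronous I/O in modern Python, use native coroutines (async def and await) instead of generator-based ones. They use the same suspend/resume mechanism but are driven by an event loop such as asyncio, so there is no manual priming.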
Generator Pipelines
Generators compose naturally into lazy pipelines - each stage pulls from the previous one on demand. Each stage holds only the item currently passing through, so memory use stays constant (O(1), apart from the size of a single line) regardless of input size:
import re
# Each stage is a generator function
def read_lines(path):
    """Yield the lines of the text file at path, line endings included.

    The file is opened on the first next() call, and the `with` block
    closes it when the generator is exhausted or closed.
    """
    with open(path) as source:
        yield from source
def strip_comments(lines):
    """Yield each line with its '#' comment and surrounding whitespace removed.

    Lines that are empty after stripping are dropped. Every '#' starts a
    comment, even one that appears inside quotes.
    """
    for raw in lines:
        code, _, _ = raw.partition('#')
        cleaned = code.strip()
        if cleaned:
            yield cleaned
def parse_numbers(lines):
    """Yield every whitespace-separated token that float() can parse.

    Tokens that are not numbers are skipped on purpose (best-effort
    parsing). float() also accepts tokens such as 'inf' and 'nan'.
    """
    for line in lines:
        for word in line.split():
            # Keep the try body to the conversion alone. Wrapping the yield
            # too would silently swallow a ValueError thrown into this
            # generator via throw().
            try:
                number = float(word)
            except ValueError:
                continue  # not a number - skip it
            yield number
def above_threshold(numbers, threshold):
    """Return a lazy generator over the values strictly greater than threshold."""
    selected = (value for value in numbers if value > threshold)
    return selected
# Wire the stages together; building the pipeline reads nothing yet
lines = read_lines('data.txt')
stripped = strip_comments(lines)
numbers = parse_numbers(stripped)
filtered = above_threshold(numbers, 100.0)
# sum() pulls values through every stage, one line at a time
total = sum(filtered)
print(f'Sum of numbers above 100: {total}')
Pipeline with a pass-through debugging stage that counts and reports the items flowing through it:
def tee(iterable, label=''):
    """Pass every item of iterable through unchanged, counting as it goes.

    When this stage stops, it prints '[label] yielded N items'. That covers
    both the input running out and a downstream consumer stopping early, in
    which case the generator is closed or garbage-collected.
    (Unrelated to itertools.tee, which splits one iterator into several.)
    """
    count = 0
    try:
        for item in iterable:
            count += 1
            yield item
    finally:
        # `finally` also runs on GeneratorExit, so the count is reported even
        # when the pipeline is cut short downstream (e.g. by take()).
        print(f'[{label}] yielded {count} items')


def take(iterable, n):
    """Yield at most the first n items of iterable.

    Exactly min(n, available) items are pulled from the input, never one
    extra, so the rest of a shared iterator is left intact (the same
    contract as itertools.islice(iterable, n)). n <= 0 yields nothing and
    pulls nothing.
    """
    if n <= 0:
        return
    for i, item in enumerate(iterable, start=1):
        yield item
        if i >= n:
            return  # stop before asking the input for another item


# Add tee() stages for debugging without breaking the pipeline
pipeline = take(
    tee(
        (x**2 for x in range(100)),
        label='squares'
    ),
    n=5
)
print(list(pipeline))  # [0, 1, 4, 9, 16]
# Also prints "[squares] yielded 5 items" (only 5 pulled - lazy!). In CPython
# that line typically appears just before the list: it is printed when take()
# finishes and the suspended tee() stage is garbage-collected and closed.
Batching items from a pipeline (Python 3.12+ also provides itertools.batched, which yields tuples instead of lists):
from itertools import islice
def batched(iterable, n):
    """Yield consecutive lists of up to n items taken from iterable.

    Every list holds exactly n items except possibly the last one; an
    empty input yields nothing.
    """
    source = iter(iterable)
    # iter(callable, sentinel) keeps calling until it returns [] (input used up)
    for chunk in iter(lambda: list(islice(source, n)), []):
        yield chunk
# Process a million records in batches of 1000
def generate_records():
    """Lazily produce one million synthetic records {'id': i, 'value': 2*i}."""
    for record_id in range(1_000_000):
        record = {'id': record_id, 'value': record_id * 2}
        yield record
for batch in batched(generate_records(), 1000):
    # Each batch is a list of up to 1000 record dicts; records are streamed,
    # so only the current batch is held in memory.
    # Placeholder: hand the batch to your own handler, e.g. process_batch(batch)
    pass
Infinite Sequences
Generators can model infinite sequences without computing or storing the whole sequence up front - just pair them with islice or a termination condition:
from itertools import islice
def naturals(start=1):
    """Count upward forever: start, start + 1, start + 2, ..."""
    current = start
    while True:
        yield current
        current += 1


def primes():
    """Generate the primes 2, 3, 5, ... endlessly by trial division.

    Every integer from 3 upward (even ones included) is tested against all
    primes found so far; the list of found primes grows without bound.
    """
    found = [2]
    yield 2
    for candidate in naturals(3):
        if any(candidate % p == 0 for p in found):
            continue  # divisible by a smaller prime -> composite
        found.append(candidate)
        yield candidate


# Take first 10 primes
first_ten = list(islice(primes(), 10))
print(first_ten)  # [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
# Collatz sequence
def collatz(n):
    """Yield the Collatz (3n + 1) sequence that starts at n and ends at 1.

    Raises:
        ValueError: if n < 1 (raised on the first next() call). Starting
            from 0 or a negative number the sequence never reaches 1, so
            iterating it would otherwise loop forever.
    """
    if n < 1:
        raise ValueError(f'collatz() requires n >= 1, got {n}')
    yield n
    while n != 1:
        n = n // 2 if n % 2 == 0 else 3 * n + 1
        yield n


print(list(collatz(6)))  # [6, 3, 10, 5, 16, 8, 4, 2, 1]
from collections import deque
def moving_average(iterable, window):
    """Yield the mean of each full sliding window of `window` values.

    Nothing is yielded until `window` values have arrived; after that,
    each new value produces one average.
    """
    recent = deque(maxlen=window)  # the oldest value drops off automatically
    for value in iterable:
        recent.append(value)
        if len(recent) < window:
            continue
        yield sum(recent) / window


# Moving average over a data stream
data = [1, 3, 5, 7, 9, 11, 13]
avgs = list(moving_average(data, window=3))
print(avgs)  # [3.0, 5.0, 7.0, 9.0, 11.0]
# Running statistics on infinite stream
def running_stats(iterable):
    """Yield (count, mean) after each value; works on endless streams too."""
    total = 0
    for count, value in enumerate(iterable, start=1):
        total += value
        yield count, total / count
# Running mean over the first five natural numbers of an endless stream
for count, mean in islice(running_stats(naturals(start=1)), 5):
    print(f'n={count}, mean={mean:.2f}')
n=1, mean=1.00 n=2, mean=1.50 n=3, mean=2.00 n=4, mean=2.50 n=5, mean=3.00
For common infinite sequences (count, cycle, repeat) and combiners (chain, islice), prefer itertools - they are implemented in C and faster than equivalent generator functions. Write custom generators when the logic is specific to your problem.