"""
Continuous batching = iteration-level scheduling + ragged (packed) batching.

Two approaches are compared (both run BATCH_SIZE sequences concurrently, so the
comparison is slot-for-slot fair):

1. Static batching (baseline):
   Prompts are processed BATCH_SIZE at a time. Each wave is padded to a
   common length and run together until the LONGEST request in that wave
   finishes; a hard "batch barrier" then has to clear before the next wave
   starts. Short requests sit idle behind the barrier.

2. Continuous batching (production-aligned):
   Two ideas combine to keep the GPU busy.

   (a) Iteration-level scheduling: the moment a sequence finishes it frees
       its slot, and the next queued prompt is admitted on the SAME step --
       no waiting for the rest of the batch.

   (b) Ragged / packed batching -- the part that makes it truly "continuous":
       instead of padding every sequence into a rectangular (B, max_len)
       tensor, ALL in-flight tokens are concatenated into a single unpadded
       (1, total_tokens) row and run in ONE forward pass. A block-diagonal
       causal attention mask stops tokens from attending across sequence
       boundaries, so packing is mathematically equivalent to running each
       sequence on its own (verified: greedy output matches per-prompt
       generation token-for-token).

   Because attention is governed entirely by the mask, a newly admitted
   prompt's multi-token PREFILL rides along in the same forward pass as
   every other sequence's single-token DECODE step. Prefill and decode are
   fused: no padding, no separate prefill pass.

KV cache: each sequence keeps its own DynamicCache; every step the caches
are concatenated along the time axis into one packed cache, and the newly
computed KV is scattered back per sequence. (Real engines store the
cache in fixed-size pages -- "paged attention" -- to avoid this per-step
reassembly, but the attention/masking logic is exactly what you see here.)
"""
import time
import torch
from dataclasses import dataclass, field
from typing import Optional
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
from transformers.cache_utils import DynamicLayer
MODEL_ID = "openai-community/gpt2"  # swap for any causal LM
BATCH_SIZE = 3  # max concurrent sequences (slots) in both batching strategies
def _device_sync(model) -> None:
“””Block until queued GPU work finishes, so timings are accurate.”””
if model.device.type == “cuda”:
torch.cuda.synchronize()
elif model.device.type == “mps”:
torch.mps.synchronize()
def static_batching(requests: list[tuple[str, int]], tokenizer, model) -> list[str]:
    """Baseline. Process requests BATCH_SIZE at a time; each wave runs together
    until its LONGEST request finishes, then a batch barrier clears before the
    next wave starts.

    Downside: short requests in a wave idle until the wave's longest is done --
    and no slot can be refilled until the whole wave clears the barrier.

    Args:
        requests: ``(prompt, max_new_tokens)`` pairs.
        tokenizer: HF-style tokenizer. Its ``padding_side`` is set to ``"left"``
            for the duration of the call and restored before returning (even
            on error), so the caller's tokenizer configuration is not leaked.
        model: model exposing ``generate(...)`` and ``device``.

    Returns:
        One string per request, in request order: the prompt followed by at
        most ``cap`` decoded generated tokens.
    """
    if not requests:
        return []

    previous_padding_side = tokenizer.padding_side
    # ``generate`` appends new tokens after the last column of the padded batch,
    # so prompts are left-padded: every row's real last prompt token then sits
    # in that column and generation starts at index ``width`` for all rows.
    tokenizer.padding_side = "left"
    results: list[str] = [""] * len(requests)
    indexed = list(enumerate(requests))  # (req_id, (prompt, cap))
    try:
        for wave_start in range(0, len(indexed), BATCH_SIZE):
            wave = indexed[wave_start:wave_start + BATCH_SIZE]
            wave_max = max(cap for _, (_, cap) in wave)

            # Show which request occupies each slot in this wave.
            for slot, (req_id, (prompt, cap)) in enumerate(wave):
                print(f" ++ slot {slot} <- req {req_id} ({cap} tok cap): {prompt!r}", flush=True)

            prompts = [p for _, (p, _) in wave]
            inputs = tokenizer(
                prompts, return_tensors="pt", padding=True, truncation=True
            ).to(model.device)
            with torch.no_grad():
                output_ids = model.generate(
                    **inputs,
                    max_new_tokens=wave_max,  # whole wave decodes to the longest
                    pad_token_id=tokenizer.eos_token_id,
                    do_sample=False,
                )
            width = inputs.input_ids.shape[1]  # padded prompt length of this wave
            print(
                f" *** batch barrier: all {len(wave)} slots wait for the longest "
                f"({wave_max} tokens) ***",
                flush=True,
            )
            for slot, ((req_id, (prompt, cap)), row) in enumerate(zip(wave, output_ids)):
                # Keep only this request's own cap; the rest of the row is the
                # work it was forced to wait for behind the barrier.
                text = prompt + tokenizer.decode(row[width:width + cap], skip_special_tokens=True)
                results[req_id] = text
                print(
                    f" -- slot {slot} done req {req_id} ({cap}/{wave_max} tokens): {text[:90]}",
                    flush=True,
                )
    finally:
        tokenizer.padding_side = previous_padding_side
    return results
@dataclass
class Sequence:
    """State for a single in-flight sequence in the continuous-batching loop."""

    req_id: int  # original request index (for ordering results)
    prompt: str
    max_new_tokens: int  # per-request cap so short requests finish early
    # Tokens to feed on the NEXT step: the whole prompt right after admission
    # (prefill), then a single token per step (decode).
    pending_ids: list[int]
    # Per-sequence KV cache; None until this sequence has run once.
    kv_cache: Optional[DynamicCache] = None
    kv_len: int = 0  # number of cached tokens (prompt + generated)
    # Presumably maintained by the driver loop (not visible here): how many
    # tokens have been generated so far, and which ones.
    tokens_generated: int = 0
    output_ids: list[int] = field(default_factory=list)
def _make_cache(layers_kv: list[tuple[torch.Tensor, torch.Tensor]]) -> DynamicCache:
    """Build a DynamicCache from explicit per-layer (keys, values) tensors.

    We SET the tensors directly instead of calling DynamicLayer.update() (which
    would append), because we are assembling caches from scratch each step.

    Args:
        layers_kv: one ``(keys, values)`` pair per model layer, in layer order.
            The tensors are stored as-is (no copy).

    Returns:
        A new DynamicCache whose i-th layer holds exactly ``layers_kv[i]``.
    """
    cache = DynamicCache()
    for keys, values in layers_kv:
        layer = DynamicLayer()
        # NOTE(review): lazy_initialization is called only for its bookkeeping
        # (presumably dtype/device/initialized flags); any tensors it sets are
        # overwritten right below. Its signature has varied across transformers
        # releases -- confirm it accepts (keys, values) in the installed version.
        layer.lazy_initialization(keys, values)
        layer.keys = keys
        layer.values = values
        cache.layers.append(layer)
    return cache
def _ragged_step(seqs: list[Sequence], model, device, dtype) -> list[int]:
    """Run ONE packed forward pass over every active sequence.

    All sequences are flattened into a single row (batch dim = 1):

        input_ids       (1, total_q)                        -- every sequence's pending tokens
        position_ids    (1, total_q)                        -- each token's position in ITS sequence
        attention_mask  (1, 1, total_q, total_kv + total_q) -- block-diagonal causal
        past_key_values packed cache (1, H, total_kv, D)

    total_q  = sum of pending tokens (1 per decoding seq, prompt_len per new seq)
    total_kv = sum of already-cached tokens across sequences

    Side effects: every sequence's ``kv_cache`` is replaced by one that also
    holds this step's KV, and its ``kv_len`` grows by its number of pending
    tokens. ``pending_ids`` is not modified here.

    Args:
        seqs: active sequences; each must have at least one pending token.
        model: causal LM, called with a 4-D additive ``attention_mask``,
            explicit ``position_ids`` and a packed ``past_key_values``.
        device: device on which the packed tensors are created.
        dtype: floating dtype of the additive attention mask.

    Returns:
        The next greedy token for each sequence (same order as ``seqs``);
        an empty list when ``seqs`` is empty.

    Raises:
        ValueError: if any sequence has no pending tokens -- it would own no
            logits row, so its "next token" would silently be a neighbour's.
    """
    if not seqs:
        # Nothing in flight: a (1, 0) forward pass is meaningless.
        return []
    q_lens = [len(s.pending_ids) for s in seqs]
    if min(q_lens) == 0:
        raise ValueError("_ragged_step: every sequence needs at least one pending token")
    total_q = sum(q_lens)
    total_kv = sum(s.kv_len for s in seqs)

    # Packed inputs: concatenate every sequence's pending tokens into one row.
    flat_ids = [t for s in seqs for t in s.pending_ids]
    input_ids = torch.tensor([flat_ids], dtype=torch.long, device=device)  # (1, total_q)

    # Tag every KEY and every QUERY token with (sequence index, position-in-sequence).
    # Key space is laid out as [ cached tokens | this step's new tokens ], matching
    # how the model appends new KV to the end of the packed cache.
    key_seq: list[int] = []
    key_pos: list[int] = []
    for si, s in enumerate(seqs):  # cached block
        key_seq.extend([si] * s.kv_len)
        key_pos.extend(range(s.kv_len))
    q_seq: list[int] = []
    q_pos: list[int] = []
    for si, s in enumerate(seqs):  # new block (these tokens are also the queries)
        q_seq.extend([si] * len(s.pending_ids))
        q_pos.extend(range(s.kv_len, s.kv_len + len(s.pending_ids)))
    # The new tokens are keys too, appended after the whole cached block.
    key_seq.extend(q_seq)
    key_pos.extend(q_pos)

    q_seq_t = torch.tensor(q_seq, device=device)
    q_pos_t = torch.tensor(q_pos, device=device)
    key_seq_t = torch.tensor(key_seq, device=device)
    key_pos_t = torch.tensor(key_pos, device=device)

    # Each token's positional embedding uses its own sequence position, not its
    # offset in the packed row.
    position_ids = q_pos_t.unsqueeze(0)  # (1, total_q)

    # Block-diagonal causal mask: a query may attend to a key only if they belong
    # to the SAME sequence (block-diagonal) and the key is not in the future
    # (causal). This is the whole trick -- it makes packing equivalent to running
    # each sequence separately. 0.0 = attend, large-negative = blocked (additive).
    same = q_seq_t[:, None] == key_seq_t[None, :]
    causal = key_pos_t[None, :] <= q_pos_t[:, None]
    allowed = same & causal  # (total_q, total_kv + total_q)
    attn_mask = torch.zeros(1, 1, total_q, total_kv + total_q, dtype=dtype, device=device)
    attn_mask.masked_fill_(~allowed[None, None], torch.finfo(dtype).min)

    # Packed KV cache: concatenate each sequence's cache along the time axis.
    # Freshly admitted sequences (kv_len == 0) contribute nothing here.
    cached = [s for s in seqs if s.kv_len > 0]
    if cached:
        num_layers = len(cached[0].kv_cache.layers)
        packed_kv = []
        for li in range(num_layers):
            ks = torch.cat([s.kv_cache.layers[li].keys for s in cached], dim=2)
            vs = torch.cat([s.kv_cache.layers[li].values for s in cached], dim=2)
            packed_kv.append((ks, vs))
        past = _make_cache(packed_kv)
    else:
        past = DynamicCache()

    with torch.no_grad():
        out = model(
            input_ids=input_ids,
            attention_mask=attn_mask,
            position_ids=position_ids,
            past_key_values=past,
            use_cache=True,
        )

    # Greedy next token for each sequence: read the logits at its LAST pending
    # token (for a prefilling sequence that is the final prompt token).
    logits = out.logits[0]  # (total_q, vocab)
    offsets: list[int] = []
    last_idx: list[int] = []
    off = 0
    for ql in q_lens:
        offsets.append(off)
        last_idx.append(off + ql - 1)
        off += ql
    next_tokens = [int(logits[i].argmax()) for i in last_idx]

    # Scatter the newly computed KV back to each sequence. The output cache is
    # [ old packed block | new packed block ]; slice this step's new block per
    # sequence and append it to that sequence's own cache.
    out_kv = out.past_key_values
    num_layers = len(out_kv.layers)
    for si, s in enumerate(seqs):
        start = total_kv + offsets[si]
        stop = start + q_lens[si]
        seq_kv = []
        for li in range(num_layers):
            k_new = out_kv.layers[li].keys[:, :, start:stop, :]
            v_new = out_kv.layers[li].values[:, :, start:stop, :]
            if s.kv_cache is None:
                seq_kv.append((k_new, v_new))
            else:
                seq_kv.append((
                    torch.cat([s.kv_cache.layers[li].keys, k_new], dim=2),
                    torch.cat([s.kv_cache.layers[li].values, v_new], dim=2),
                ))
        s.kv_cache = _make_cache(seq_kv)
        s.kv_len += q_lens[si]
    return next_tokens
def visualize_ragged_step(seqs: list(Sequence), tokenizer, title: str, slot_ids: list(int)) -> None:
“””Illustrative print of ONE packed step: the concatenated input row and the
block-diagonal causal attention mask.
This mirrors the masking logic in _ragged_step (recomputed here as a boolean
grid purely for display) so you can SEE that sequences are packed together
yet isolated by the mask. Each sequence gets a letter A, B, C, …
# = a query may attend to that key . = blocked
“””
labels = (chr(ord(“A”) + s.req_id) for s in seqs)
q_lens = (len(s.pending_ids) for s in seqs)
total_q = sum(q_lens)
total_kv = sum(s.kv_len for s in seqs)
print(f”\n{‘=’ * 72}\n {title}”)
print(f” total_q={total_q} tokens fed this step | total_kv={total_kv} cached”)
print(f” {len(seqs)} sequences packed into ONE unpadded row of shape (1, {total_q}):\n”)
# The concatenated tokens, grouped per sequence (this is the “ragged” row).
for i, s in enumerate(seqs):
kind = f”PREFILL({q_lens(i)})” if s.kv_len == 0 else f”decode({q_lens(i)})”
toks = ” “.join(repr(tokenizer.decode(
Source link



GIPHY App Key not set. Please check settings