diff --git a/.gitignore b/.gitignore index 938243bb7..eb392d276 100644 --- a/.gitignore +++ b/.gitignore @@ -107,9 +107,9 @@ celerybeat.pid # Environments .env -.venv +.venv* env/ -venv/ +venv*/ ENV/ env.bak/ venv.bak/ diff --git a/MaxText/benchmarks/bench_paged_attention_kernel.py b/MaxText/benchmarks/bench_paged_attention_kernel.py new file mode 100644 index 000000000..0be365f73 --- /dev/null +++ b/MaxText/benchmarks/bench_paged_attention_kernel.py @@ -0,0 +1,102 @@ +import argparse +import functools +import time + +import jax +import jax.numpy as jnp +from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention + +BLOCK_SIZE = 16 +MAX_NUM_BLOCKS_PER_SEQ = 512 + + +@functools.partial(jax.jit, static_argnums=(6, 7)) +def paged_attn( + q: jax.Array, # [batch, 1, num_heads, head_size] + k_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size] + v_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size] + sm_scale: float, + block_tables: jax.Array, # [batch, max_num_blocks_per_batch] + context_lens: jax.Array, # [batch] + block_size: int, + pages_per_compute_block: int, +) -> jax.Array: # [batch, 1, num_heads, head_size] + q = q.squeeze(1) + q = q * sm_scale + + head_size = q.shape[-1] + num_slots = k_cache.shape[-2] + k_cache = k_cache.reshape(-1, num_slots // block_size, block_size, head_size) + v_cache = v_cache.reshape(-1, num_slots // block_size, block_size, head_size) + + output = paged_attention( + q, + k_cache, + v_cache, + context_lens, + block_tables, + pages_per_compute_block=pages_per_compute_block, + ) + return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2]) + + +def benchmark_paged_attn( + batch_size: int, + num_heads: int, + num_kv_heads: int, + head_size: int, + context_len: int, + num_blocks: int, + block_size: int, + pages_per_compute_block: int, +): + rng_key = jax.random.PRNGKey(0) + query = jax.random.normal(rng_key, (batch_size, 1, num_heads, head_size), dtype=jnp.bfloat16) + k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16) + v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16) + sm_scale = head_size**-0.5 + block_tables = jax.random.randint(rng_key, (batch_size, MAX_NUM_BLOCKS_PER_SEQ), 0, num_blocks, dtype=jnp.int32) + context_lens = jnp.array([context_len] * batch_size, dtype=jnp.int32) + + # For JIT compilation. 
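+  # Warm-up call: run paged_attn once so XLA compilation time is excluded from the timing loop below.
+  # The timed loop runs 100 iterations, so (end - start) / 100 * 1e6 == (end - start) * 10000 is the
+  # average latency per call, reported in microseconds.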
+ output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, block_size, pages_per_compute_block) + output.block_until_ready() + + start = time.time() + for _ in range(100): + output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, block_size, pages_per_compute_block) + output.block_until_ready() + end = time.time() + + print(f"Time taken: {(end - start) * 10000:.2f} us") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num-heads", type=int, default=64) + parser.add_argument("--num-kv-heads", type=int, default=8) + parser.add_argument("--head-size", type=int, default=128) + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument("--context-len", type=int, default=1024) + parser.add_argument("--num-blocks", type=int, default=2048) + parser.add_argument("--block-size", type=int, default=16) + args = parser.parse_args() + print(args) + + for block_size in [16, 32, 64, 128]: + for pages_per_compute_block in [1, 2, 4, 8, 16, 32, 64, 128]: + if pages_per_compute_block > MAX_NUM_BLOCKS_PER_SEQ: + continue + if block_size * pages_per_compute_block > 1024: + continue + print(f"block_size {block_size}, pages_per_compute_block: {pages_per_compute_block}") + benchmark_paged_attn( + args.batch_size, + args.num_heads, + args.num_kv_heads, + args.head_size, + args.context_len, + args.num_blocks, + block_size, + pages_per_compute_block, + ) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 6ec504b99..aa6bdce15 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -268,6 +268,10 @@ logical_axis_rules: [ ['cache_kv', []], ['cache_sequence', []], ['exp', 'expert'], + ['paged_kv_heads', []], + ['num_pages', ['tensor']], + ['tokens_per_page', []], + ['paged_kv_head_dim_size', []], ] # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_sequence', 'expert', 'autoregressive']] @@ -542,3 +546,8 @@ sa_use_fused_bwd_kernel: False sa_q_layout: "HEAD_DIM_MINOR" sa_k_layout: "HEAD_DIM_MINOR" sa_v_layout: "HEAD_DIM_MINOR" + +# Paged Attention +num_pages: 64 +tokens_per_page: 32 +pages_per_compute_block: 8 \ No newline at end of file diff --git a/MaxText/decode.py b/MaxText/decode.py index ef2f2fc79..8b65af0f9 100644 --- a/MaxText/decode.py +++ b/MaxText/decode.py @@ -49,8 +49,10 @@ def main(argv: Sequence[str]) -> None: # Split RNG before calling prefill rng, rng_prefill = jax.random.split(rng) - prefill_result, first_token = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill) slot = 0 + prefill_result, first_token = engine.prefill( + params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill, slot=slot + ) rng, rng_init_decode = jax.random.split(rng) decode_state = engine.init_decode_state(rng_init_decode) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index f5990e7d9..26e516ba3 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -26,10 +26,13 @@ from jax.experimental import shard_map from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask +from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention +from jax.sharding import PartitionSpec as P import jax.numpy as jnp import common_types from 
kernels.ragged_attention import ragged_gqa from kernels.ragged_attention import ragged_mha +import page_managers from layers import embeddings from layers import initializers from layers import linears @@ -131,6 +134,257 @@ def apply_mask_to_logits(logits: Array, mask: Array): return jnp.where((mask >= DEFAULT_MASK_VALUE * 0.5), logits, DEFAULT_MASK_VALUE) +class PagedAttentionOp(nn.Module): + mesh: Mesh + num_pages: int + tokens_per_page: int + max_pages_per_slot: int + max_pages_per_prefill: int + pages_per_compute_block: int + + num_kv_heads: int + kv_head_dim_size: int + dtype: DType = jnp.float32 + attn_logits_soft_cap: float | None = None + + query_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) + kv_pages_axis_names: AxisNames = ("paged_kv_heads", "num_pages", "tokens_per_page", "paged_kv_head_dim_size") + + def init_or_get_kv_pages(self, model_mode: str): + """Get paged attention op.""" + # Get existing variables if they exist + if self.has_variable("cache", "key_pages"): + key_pages_var = self.variable("cache", "key_pages") + value_pages_var = self.variable("cache", "value_pages") + + # For AR mode, if shape doesn't match, reinitialize values but not variables + if model_mode != common_types.MODEL_MODE_PREFILL and key_pages_var.value.shape[1] != self.num_pages: + kv_pages_shape = (self.num_kv_heads, self.num_pages, self.tokens_per_page, self.kv_head_dim_size) + key_pages_var.value = jnp.zeros(kv_pages_shape, dtype=self.dtype) + value_pages_var.value = jnp.zeros(kv_pages_shape, dtype=self.dtype) + else: + # Initial creation - choose size based on mode + num_pages = self.max_pages_per_prefill if model_mode == common_types.MODEL_MODE_PREFILL else self.num_pages + kv_pages_shape = (self.num_kv_heads, num_pages, self.tokens_per_page, self.kv_head_dim_size) + + key_pages_var = self.variable( + "cache", + "key_pages", + nn.with_logical_partitioning(jnp.zeros, self.kv_pages_axis_names), + kv_pages_shape, + self.dtype, + ) + value_pages_var = self.variable( + "cache", + "value_pages", + nn.with_logical_partitioning(jnp.zeros, self.kv_pages_axis_names), + kv_pages_shape, + self.dtype, + ) + + # Apply logical constraints + key_pages_var.value = nn.with_logical_constraint(key_pages_var.value, self.kv_pages_axis_names) + value_pages_var.value = nn.with_logical_constraint(value_pages_var.value, self.kv_pages_axis_names) + return key_pages_var, value_pages_var + + def paged_dot_product_attention_with_max_and_sum(self, query, key, value): + b, t, n, d = query.shape + _, s, n_kv, _ = key.shape + + query = jnp.reshape(query, (b, t, n_kv, n // n_kv, d)) + + attn_weights = jnp.einsum("btkgd,bskd->bkgts", query, key) + + causal_mask = jnp.triu(jnp.ones((t, s)), k=1) + causal_mask = jnp.reshape(causal_mask, (1, 1, 1, t, s)) + masked_weights = jnp.where(causal_mask, jnp.full_like(attn_weights, -1e10), attn_weights) + + local_max = jnp.max(masked_weights, axis=-1, keepdims=True) + local_exps = jnp.exp(masked_weights - local_max) + local_sums = jnp.sum(local_exps, axis=-1, keepdims=True) + + attn = jnp.einsum("bkgts,bskd->btkgd", local_exps, value) + attn = jnp.reshape(attn, (b, t, n, d)) + + local_max = jnp.moveaxis(local_max, -2, 1) + local_max = jnp.reshape(local_max, (b, t, n, 1)) + + local_sums = jnp.moveaxis(local_sums, -2, 1) + local_sums = jnp.reshape(local_sums, (b, t, n, 1)) + + return attn, local_max, local_sums + + def paged_attention( + self, + query: Array, + key_pages_var: nn.Variable, + value_pages_var: nn.Variable, + page_state: page_managers.PageState, + ) -> Array: + """Apply Paged 
Attention. + + Annotations: + b: batch_size + s: sequence length + n: query heads + + k: kv_heads + x: num_pages + p: tokens_per_page + + d: kv_head_dim_size + """ + bsnd = nn.logical_to_mesh_axes(self.query_axis_names) + kxpd = nn.logical_to_mesh_axes(self.kv_pages_axis_names) + batch_q, seqlen_q, num_heads_q, head_dim = query.shape + num_heads_kv, num_pages, tokens_per_page, head_dim = key_pages_var.value.shape + + no_shard = P(None, None, None, None) + + @functools.partial( + shard_map, + mesh=self.mesh, + in_specs=( + no_shard, + no_shard, + no_shard, + P(None), + P(None, None), + None, + ), + out_specs=no_shard, + check_rep=False, + ) + def wrap_paged_attention(q, k_pages, v_pages, lengths, page_indices, pages_per_compute_block): + q = jnp.squeeze(q, axis=1) + result = paged_attention( + q=q, + k_pages=k_pages, + v_pages=v_pages, + lengths=lengths, + page_indices=page_indices, + pages_per_compute_block=pages_per_compute_block, + ) + return jnp.expand_dims(result, axis=1) + + return wrap_paged_attention( + query, + key_pages_var.value, + value_pages_var.value, + page_state.sequence_lengths, + page_state.page_map, + self.pages_per_compute_block, + ) + + @nn.compact + def __call__( + self, + query: Array, + key: Array, + value: Array, + decoder_segment_ids: Array, + model_mode: str, + page_state: page_managers.PageState, + ) -> Array: + """Apply paged attention mechanism. + + Returns: + tuple: (output, exponentials_max, exponentials_sum) where the latter two + are None for autoregressive mode (handled by paged_attention kernel) + """ + key_pages_var, value_pages_var = self.init_or_get_kv_pages(model_mode) + self.update(key_pages_var, value_pages_var, key, value, model_mode, page_state) + + if model_mode == common_types.MODEL_MODE_PREFILL: + return self.paged_dot_product_attention_with_max_and_sum(query, key, value) + elif model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: + ar_output = self.paged_attention(query, key_pages_var, value_pages_var, page_state) + return ar_output, None, None + + def update( + self, + key_pages_var: nn.Variable, + value_pages_var: nn.Variable, + key: Array, + value: Array, + model_mode: str, + page_state: Optional[page_managers.PageState] = None, + ) -> None: + """Update KV Pages.""" + if model_mode == common_types.MODEL_MODE_PREFILL: + self.update_prefill_step_pages(key_pages_var, value_pages_var, key, value) + elif model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: + self.update_decode_step_pages(key_pages_var, value_pages_var, key, value, page_state) + + def update_prefill_step_pages( + self, + key_pages_var: nn.Variable, + value_pages_var: nn.Variable, + key: Array, + value: Array, + ) -> None: + """Update pages for prefill step.""" + + assert ( + key.shape == value.shape + ), f"prefill_step key/value should have the same shape, but getting {key.shape=} and {value.shape=} instead" + b, t, n_kv, d = key.shape + assert t % self.tokens_per_page == 0 + assert ( + key_pages_var.value.shape == value_pages_var.value.shape + ), f"prefill_step key/value_pages_var should have the same shape, but getting {key_pages_var.shape=} and {value_pages_var.shape=} instead" + + v_n_kv, v_n_p, v_p, v_d = key_pages_var.value.shape + assert v_n_kv == n_kv, f"{v_n_kv=} {n_kv=}" + assert v_p == self.tokens_per_page, f"{v_p=} {self.tokens_per_page=}" + assert v_d == d, f"{v_d=} {d=}" + assert v_n_p == self.max_pages_per_prefill, f"{v_n_p=} {self.max_pages_per_prefill=}" + + # Handle both init (b>1) and runtime (b=1) cases + if b == 1: + key = jnp.squeeze(key) + value = 
jnp.squeeze(value) + else: + key = key[0] + value = value[0] + + key = jnp.transpose(key, axes=(1, 0, 2)) + value = jnp.transpose(value, axes=(1, 0, 2)) + + key = jnp.reshape(key, shape=(n_kv, t // self.tokens_per_page, self.tokens_per_page, d)) + value = jnp.reshape(value, shape=(n_kv, t // self.tokens_per_page, self.tokens_per_page, d)) + + key_pages_var.value = nn.with_logical_constraint(key, self.kv_pages_axis_names) + value_pages_var.value = nn.with_logical_constraint(value, self.kv_pages_axis_names) + + def update_decode_step_pages(self, key_pages_var, value_pages_var, key, value, page_state): + key_pages = key_pages_var.value + value_pages = value_pages_var.value + + batch_size, seq_len, kv_heads, head_dim = key.shape + kv_heads, num_pages, tokens_per_page, head_dim = key_pages.shape + + new_key = key.reshape(batch_size, kv_heads, head_dim)[:, :, :] + new_key = jnp.transpose(new_key, (1, 0, 2)) # [n_kv, b, d] + new_value = value.reshape(batch_size, kv_heads, head_dim)[:, :, :] + new_value = jnp.transpose(new_value, (1, 0, 2)) # n_kv, b, d + + broadcast_pages = jnp.tile(page_state.current_page, (kv_heads, 1)) # [n_kv, b] + broadcast_pos = jnp.tile(page_state.current_page_position, (kv_heads, 1)) # [n_kv, b] + kv_indices = jnp.arange(kv_heads)[:, None] # [n_kv, 1] + kv_indices = jnp.tile(kv_indices, (1, batch_size)) # [n_kv, b] + + key_pages_updated = key_pages.at[kv_indices, broadcast_pages, broadcast_pos].set(new_key) + value_pages_updated = value_pages.at[kv_indices, broadcast_pages, broadcast_pos].set(new_value) + + key_pages_updated = nn.with_logical_constraint(key_pages_updated, self.kv_pages_axis_names) + value_pages_updated = nn.with_logical_constraint(value_pages_updated, self.kv_pages_axis_names) + + key_pages_var.value = key_pages_updated + value_pages_var.value = value_pages_updated + return key_pages_var, value_pages_var + + class AttentionOp(nn.Module): config: Config mesh: Mesh @@ -231,6 +485,7 @@ def apply_attention( self.attention_kernel == "dot_product" or (self.attention_kernel == "autoselected" and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE) or (self.attention_kernel == "autoselected" and length < 128) + or (self.attention_kernel == "paged") ): return self.apply_attention_dot(query, key, value, decoder_segment_ids, model_mode) elif self.attention_kernel == "flash" or self.attention_kernel == "autoselected": @@ -1021,7 +1276,7 @@ def normalize_attention(self, local_outs, local_maxes, local_sums): return attn_out @nn.compact - def __call__(self, query, key, value, decoder_segment_ids, model_mode): + def __call__(self, query, key, value, decoder_segment_ids, model_mode, page_state=None): prefill_kv_cache, ar_kv_cache = self.kv_cache( key, value, decoder_segment_ids, model_mode, use_ragged_attention=self.use_ragged_attention ) @@ -1244,6 +1499,7 @@ def __call__( *, model_mode: str = common_types.MODEL_MODE_TRAIN, deterministic: bool = False, + page_state: Optional[page_managers.PageState] = None, ): """Applies Attention on the input data. 
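The scatter in update_decode_step_pages above is the core of the paged KV-cache write path: each autoregressive step appends exactly one token per slot, and the (kv_head, current_page, current_page_position) index triple places that token's key/value vector into the physical page chosen by the PageManager. The following standalone sketch is not part of the patch; it uses toy dimensions and illustrative variable names only, and reproduces the same `.at[...].set(...)` index broadcasting with plain jax.numpy:

import jax.numpy as jnp

# Toy sizes, chosen only for illustration.
n_kv, num_pages, tokens_per_page, head_dim = 2, 4, 3, 5
batch = 2  # two active decode slots

key_pages = jnp.zeros((n_kv, num_pages, tokens_per_page, head_dim))
new_key = jnp.ones((batch, 1, n_kv, head_dim))  # one new token per slot: [b, 1, n_kv, d]

current_page = jnp.array([1, 3])           # physical page assigned to each slot
current_page_position = jnp.array([0, 2])  # write offset inside that page

# Same broadcasting as update_decode_step_pages: build [n_kv, b] index grids.
new_key = jnp.transpose(new_key.reshape(batch, n_kv, head_dim), (1, 0, 2))  # [n_kv, b, d]
pages = jnp.tile(current_page, (n_kv, 1))                                   # [n_kv, b]
positions = jnp.tile(current_page_position, (n_kv, 1))                      # [n_kv, b]
heads = jnp.tile(jnp.arange(n_kv)[:, None], (1, batch))                     # [n_kv, b]

key_pages = key_pages.at[heads, pages, positions].set(new_key)

# Slot 0 wrote into page 1 offset 0, slot 1 into page 3 offset 2, for every kv head.
assert bool((key_pages[:, 1, 0, :] != 0).all())
assert bool((key_pages[:, 3, 2, :] != 0).all())

In the real op these index arrays come from page_state, which the PageManager advances on every autoregressive step (reserve_decode_step_pages), so the physical page and in-page offset always point at the next free token slot for each sequence.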
@@ -1295,36 +1551,55 @@ def __call__( value = checkpoint_name(value, "value_proj") assert not self.config.quantize_kvcache or self.kv_quant - attention_op = AttentionOp( - config=self.config, - mesh=self.mesh, - attention_kernel=self.attention_kernel, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.max_prefill_predict_length, - float32_qk_product=self.float32_qk_product, - float32_logits=self.float32_logits, - quant=self.quant, - kv_quant=self.kv_quant, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - dropout_rate=self.dropout_rate, - dtype=self.dtype, - prefill_cache_axis_order=self.prefill_cache_axis_order, - ar_cache_axis_order=self.ar_cache_axis_order, - compute_axis_order=self.compute_axis_order, - reshape_q=self.reshape_q, - attention_type=self.attention_type, - attn_logits_soft_cap=self.attn_logits_soft_cap, - sliding_window_size=self.sliding_window_size, - use_ragged_attention=self.use_ragged_attention, - ragged_block_size=self.ragged_block_size, - ) - out = attention_op(query, key, value, decoder_segment_ids, model_mode) + if self.attention_kernel == "paged" and model_mode != common_types.MODEL_MODE_TRAIN: + attention_op = PagedAttentionOp( + mesh=self.mesh, + num_pages=self.config.num_pages, + tokens_per_page=self.config.tokens_per_page, + max_pages_per_slot=self.config.max_target_length // self.config.tokens_per_page, + max_pages_per_prefill=self.config.max_prefill_predict_length // self.config.tokens_per_page, + pages_per_compute_block=self.config.pages_per_compute_block, + num_kv_heads=self.num_kv_heads, + kv_head_dim_size=self.head_dim, + dtype=self.dtype, + attn_logits_soft_cap=self.attn_logits_soft_cap, + ) + else: + attention_op = AttentionOp( + config=self.config, + mesh=self.mesh, + attention_kernel=self.attention_kernel, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.max_prefill_predict_length, + float32_qk_product=self.float32_qk_product, + float32_logits=self.float32_logits, + quant=self.quant, + kv_quant=self.kv_quant, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + dropout_rate=self.dropout_rate, + dtype=self.dtype, + prefill_cache_axis_order=self.prefill_cache_axis_order, + ar_cache_axis_order=self.ar_cache_axis_order, + compute_axis_order=self.compute_axis_order, + reshape_q=self.reshape_q, + attention_type=self.attention_type, + attn_logits_soft_cap=self.attn_logits_soft_cap, + sliding_window_size=self.sliding_window_size, + use_ragged_attention=self.use_ragged_attention, + ragged_block_size=self.ragged_block_size, + ) + + attention_output = attention_op(query, key, value, decoder_segment_ids, model_mode, page_state=page_state) - out = nn.with_logical_constraint(out, self.out_axis_names) + if self.config.attention == "paged" and model_mode != common_types.MODEL_MODE_TRAIN: + unnormalized_out, _, exp_sum = attention_output + out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out + else: + out = attention_output - # apply output projection, output dim is set to the input dim. 
+ out = nn.with_logical_constraint(out, self.out_axis_names) out = self.out_projection(inputs_q.shape[-1], out) out = checkpoint_name(out, "out_proj") return out diff --git a/MaxText/layers/llama2.py b/MaxText/layers/llama2.py index 9b198c594..ce1f8b8a4 100644 --- a/MaxText/layers/llama2.py +++ b/MaxText/layers/llama2.py @@ -31,6 +31,7 @@ from layers import quantizations import common_types +import page_managers from typing import Optional Array = common_types.Array @@ -74,6 +75,7 @@ def __call__( decoder_positions, deterministic, model_mode, + page_state: Optional[page_managers.PageState] = None, ): cfg = self.config mesh = self.mesh @@ -122,6 +124,7 @@ def __call__( decoder_segment_ids=decoder_segment_ids, deterministic=deterministic, model_mode=model_mode, + page_state=page_state, ) attention_lnx = nn.with_logical_constraint( diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py index 4c2046c1f..7cb274fa6 100644 --- a/MaxText/layers/models.py +++ b/MaxText/layers/models.py @@ -25,6 +25,7 @@ import jax.numpy as jnp from jax.ad_checkpoint import checkpoint_name import common_types +import page_managers from layers import attentions from layers import embeddings from layers import linears @@ -274,6 +275,7 @@ def __call__( decoder_segment_ids=None, deterministic=False, model_mode=common_types.MODEL_MODE_TRAIN, + page_state: Optional[page_managers.PageState] = None, ): cfg = self.config mesh = self.mesh @@ -399,6 +401,7 @@ def __call__( decoder_positions, deterministic, model_mode, + page_state=page_state, ) y = self.get_norm_layer()( @@ -453,6 +456,10 @@ def setup(self): cfg = self.config mesh = self.mesh + if self.config.attention == "paged": + assert self.config.max_target_length % self.config.tokens_per_page == 0 + assert self.config.max_prefill_predict_length % self.config.tokens_per_page == 0 + self.shared_embedding = Embed( num_embeddings=cfg.vocab_size, features=cfg.emb_dim, @@ -465,6 +472,16 @@ def setup(self): self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding, mesh=mesh, quant=self.quant) + if self.config.attention == "paged": + self.page_manager = page_managers.PageManager( + num_pages=self.config.num_pages, + tokens_per_page=self.config.tokens_per_page, + slots=int(self.config.per_device_batch_size * jax.device_count()), + max_target_length=self.config.max_target_length, + max_prefill_predict_length=self.config.max_prefill_predict_length, + max_pages_per_slot=self.config.max_target_length // self.config.tokens_per_page, + ) + def __call__( self, decoder_input_tokens, @@ -472,6 +489,8 @@ def __call__( decoder_segment_ids=None, enable_dropout=True, model_mode=common_types.MODEL_MODE_TRAIN, + slot: Optional[int] = None, + true_length: Optional[int] = None, ): """Applies Transformer decoder-branch on encoded-input and target.""" @@ -481,11 +500,23 @@ def __call__( f" which is always {common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR}." 
) + page_state = None + if self.config.attention == "paged": + if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: + page_state = self.page_manager(model_mode) + elif model_mode == common_types.MODEL_MODE_PREFILL: + page_state = self.page_manager( + model_mode=model_mode, + slot=slot, + true_length=true_length, + ) + logits = self.decoder( decoder_input_tokens=decoder_input_tokens, decoder_positions=decoder_positions, decoder_segment_ids=decoder_segment_ids, deterministic=not enable_dropout, model_mode=model_mode, + page_state=page_state, ) return logits diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 04504f78e..042840cbc 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -954,11 +954,8 @@ def get_abstract_state(model, tx, config, rng, mesh, is_training=True): def get_prefill_kv_cache_annotations(model, config, rng, mesh): """Get a shaped abstraction of the state (including optimizer)""" - def init_kv_cache(model, config): - input_shape = ( - config.global_batch_size_to_load, - config.max_prefill_predict_length, - ) + def init_prefill_kv_cache(model, config): + input_shape = (config.global_batch_size_to_load, config.max_prefill_predict_length) model_vars = model.init( {"params": rng, "dropout": rng, "aqt": rng}, @@ -969,8 +966,8 @@ def init_kv_cache(model, config): return model_vars["cache"] with nn_partitioning.axis_rules(config.logical_axis_rules): - init_kv_cache_partial = functools.partial(init_kv_cache, model, config) - abstract_state = jax.eval_shape(init_kv_cache_partial) + init_prefill_kv_cache_partial = functools.partial(init_prefill_kv_cache, model, config) + abstract_state = jax.eval_shape(init_prefill_kv_cache_partial) state_logical_annotations = nn.get_partition_spec(abstract_state) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations) @@ -980,7 +977,7 @@ def init_kv_cache(model, config): def get_kv_cache_annotations(model, config, rng, mesh): """Get a shaped abstraction of the state (including optimizer)""" - def init_kv_cache(model, config): + def init_ar_kv_cache(model, config): input_shape = ( config.global_batch_size_to_load, 1, @@ -995,8 +992,8 @@ def init_kv_cache(model, config): return model_vars["cache"] with nn_partitioning.axis_rules(config.logical_axis_rules): - init_kv_cache_partial = functools.partial(init_kv_cache, model, config) - abstract_state = jax.eval_shape(init_kv_cache_partial) + init_ar_kv_cache_partial = functools.partial(init_ar_kv_cache, model, config) + abstract_state = jax.eval_shape(init_ar_kv_cache_partial) state_logical_annotations = nn.get_partition_spec(abstract_state) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations) @@ -1082,3 +1079,26 @@ def print_system_information(): max_logging.log(f"System Information: Jax Version: {jax.__version__}") max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") max_logging.log(f"System Information: Jax Backend: {jax.lib.xla_bridge.get_backend().platform_version}") + + +def debug_array(array, array_name): + """Debug array sizing and sharding across chips.""" + print(f"\t{array_name}:") + if isinstance(array, flax.linen.spmd.LogicallyPartitioned): + array = array.value + single_shard = array.addressable_shards[0] + n_shards = len(array.addressable_shards) + total_size_across_n_shards = single_shard.data.size * n_shards + total_nbytes_across_n_shards = single_shard.data.nbytes * n_shards + 
print(f"\t\tdtype: {array.dtype}") + print(f"\t\tshape: {array.shape}") + print(f"\t\tsharding spec: {array.sharding.spec}") + print(f"\t\tdevice local layouer: {array.layout.device_local_layout}") + print(f"\t\tsize (across n shards): {array.size} ({total_size_across_n_shards})") + print(f"\t\tbytes (across n shards): {array.nbytes} ({total_nbytes_across_n_shards})") + + +def debug_qtensor(qtensor, qtensor_name): + """Debug qtensor sizing and sharding across chips.""" + debug_array(qtensor.qvalue, f"{qtensor_name} qvalue") + debug_array(qtensor.scale[0], f"{qtensor_name} scale") diff --git a/MaxText/maxengine.py b/MaxText/maxengine.py index c0c5835e1..bde3484f6 100644 --- a/MaxText/maxengine.py +++ b/MaxText/maxengine.py @@ -216,6 +216,7 @@ def prefill( true_length: int, sampler: Optional[Callable[[Any], Any]] = None, # pylint: disable=unused-argument rng: Optional[PRNGKeyType] = None, + slot: int = 0, ) -> Tuple[Prefix, engine_api.ResultTokens]: """Computes a kv-cache for a new generate request. @@ -244,6 +245,7 @@ def prefill( sequence_indicator = jnp.expand_dims(one_d_output, 0) rng, new_rng = jax.random.split(rng) + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): flat_logits, new_vars = self.model.apply( params, @@ -254,6 +256,8 @@ def prefill( model_mode=common_types.MODEL_MODE_PREFILL, rngs={"params": new_rng}, mutable=["cache"], + slot=slot, + true_length=true_length, ) next_pos = jnp.full((1, 1), true_length, dtype=jnp.int32) @@ -534,12 +538,44 @@ def copy(path, partial_cache, full_cache, annotations): else: raise ValueError(f"We don't have a strategy for inserting {path_key}") - inserted_cache = jax.tree_util.tree_map_with_path( - copy, - unboxed_prefix["cache"], - decode_state["cache"], - self.kv_cache_annotations_named, - ) + if self.config.attention == "paged": + + def copy_paged(path, prefix_cache, decode_state_cache): + if path[-2].key == "page_manager": + return prefix_cache + path_key = path[-1].key + if path_key in ["key_pages", "value_pages"]: + + def _update_pages(prefix_page_idx, state): + decode_state_pages, prefix_pages, page_map = state + prefix_page = jax.lax.dynamic_index_in_dim(prefix_pages, prefix_page_idx, axis=1) + decode_state_pages = jax.lax.dynamic_update_slice_in_dim( + decode_state_pages, prefix_page, page_map[prefix_page_idx], axis=1 + ) + return decode_state_pages, prefix_pages, page_map + + decode_state_cache, _, _ = jax.lax.fori_loop( + 0, + prefix["cache"]["page_manager"]["num_pages_used"].value[slot], + _update_pages, + (decode_state_cache, prefix_cache, prefix["cache"]["page_manager"]["page_map"].value[slot]), + ) + return decode_state_cache + else: + raise ValueError(f"We don't have a strategy for inserting {path_key} for paged attention.") + + inserted_cache = jax.tree_util.tree_map_with_path( + copy_paged, + unboxed_prefix["cache"], + decode_state["cache"], + ) + else: + inserted_cache = jax.tree_util.tree_map_with_path( + copy, + unboxed_prefix["cache"], + decode_state["cache"], + self.kv_cache_annotations_named, + ) inserted_logits = jax.lax.dynamic_update_index_in_dim(decode_state["logits"], unboxed_prefix["logits"], slot, 0) inserted_next_pos = jax.lax.dynamic_update_index_in_dim(decode_state["next_pos"], unboxed_prefix["next_pos"], slot, 0) inserted_generated_tokens = jax.lax.dynamic_update_index_in_dim( diff --git a/MaxText/page_managers.py b/MaxText/page_managers.py new file mode 100644 index 000000000..c2ece636f --- /dev/null +++ b/MaxText/page_managers.py @@ -0,0 +1,407 @@ +# Copyright 2024 Google LLC +# +# 
Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Page Managers for implementing paged attention in MaxText. + +This module provides the PageManager class and associated PageState dataclass for +managing the paged attention mechanism. The paging system allows efficient handling +of variable-length sequences by dividing the attention context into fixed-size pages, +similar to virtual memory systems. +""" + +from typing import Optional, Tuple + +from flax import linen as nn +from flax import struct +import jax +import jax.numpy as jnp + +import common_types + +Array = common_types.Array +DType = common_types.DType +AxisNames = common_types.AxisNames + +# pylint: disable=too-many-positional-arguments + + +@struct.dataclass +class PageState: + """Represents the current state of the paging system. + + Attributes: + page_status: Array indicating whether each page is in use (1) or free (0) + page_map: Array mapping slots to their assigned pages + sequence_lengths: Array containing the current length of each sequence + num_pages_used: Array tracking how many pages each slot is using + current_page: Array indicating the current active page for each slot + current_page_position: Array tracking position within current pages + """ + + page_status: Array + page_map: Array + sequence_lengths: Array + num_pages_used: Array + current_page: Array + current_page_position: Array + + +class PageManager(nn.Module): + """Manages paged attention mechanism for efficient sequence processing. + + The PageManager implements a virtual memory-like system for attention, where the + attention context is divided into fixed-size pages. This allows efficient handling + of variable-length sequences and helps manage memory usage during inference. + + Attributes: + num_pages: Total number of available pages in the system + tokens_per_page: Number of tokens that can be stored in each page + slots: Number of sequence slots available for parallel processing + max_target_length: Maximum length of target sequences + max_prefill_predict_length: Maximum length for prefill prediction + max_pages_per_slot: Maximum number of pages that can be assigned to a slot + """ + + num_pages: int + tokens_per_page: int + slots: int + max_target_length: int + max_prefill_predict_length: int + max_pages_per_slot: int + + def init_or_get_vars(self): + """Initializes or retrieves the state variables for the paging system. 
+ + Returns: + Tuple of nn.Variable objects representing: + - page_status: Status of each page (free/used) + - page_map: Mapping between slots and their assigned pages + - sequence_lengths: Length of sequence in each slot + - num_pages_used: Number of pages used by each slot + - current_page: Current active page for each slot + - current_page_position: Position within current pages + """ + page_status_var = self.variable( + "cache", "page_status", nn.with_logical_partitioning(jnp.zeros, ("num_pages",)), (self.num_pages,), jnp.int32 + ) + page_map_var = self.variable( + "cache", + "page_map", + nn.with_logical_partitioning(jnp.zeros, ("slots", "max_pages_per_slot")), + (self.slots, self.max_pages_per_slot), + jnp.int32, + ) + sequence_lengths_var = self.variable( + "cache", "sequence_lengths", nn.with_logical_partitioning(jnp.zeros, ("slots",)), (self.slots,), jnp.int32 + ) + num_pages_used_var = self.variable( + "cache", "num_pages_used", nn.with_logical_partitioning(jnp.zeros, ("slots",)), (self.slots,), jnp.int32 + ) + current_page_var = self.variable( + "cache", "current_page", nn.with_logical_partitioning(jnp.zeros, ("slots",)), (self.slots,), jnp.int32 + ) + current_page_position_var = self.variable( + "cache", "current_page_position", nn.with_logical_partitioning(jnp.zeros, ("slots",)), (self.slots,), jnp.int32 + ) + + return ( + page_status_var, + page_map_var, + sequence_lengths_var, + num_pages_used_var, + current_page_var, + current_page_position_var, + ) + + def release_slot_pages( + self, + slot: int, + page_status_var: nn.Variable, + page_map_var: nn.Variable, + sequence_lengths_var: nn.Variable, + num_pages_used_var: nn.Variable, + current_page_var: nn.Variable, + current_page_position_var: nn.Variable, + ) -> Tuple: + """Releases all pages assigned to a specific slot. + + This method frees up all pages currently assigned to the given slot, + resetting their status and updating the page mapping accordingly. 
+ + Args: + slot: Integer identifying the slot to be released + page_status_var: Variable tracking page usage status + page_map_var: Variable mapping slots to pages + sequence_lengths_var: Variable tracking sequence lengths + num_pages_used_var: Variable tracking page usage counts + current_page_var: Variable tracking current active pages + current_page_position_var: Variable tracking positions in current pages + + Returns: + Tuple of updated variables after releasing the slot's pages + """ + page_status = page_status_var.value + page_map = page_map_var.value + sequence_lengths = sequence_lengths_var.value + num_pages_used = num_pages_used_var.value + current_page = current_page_var.value + current_page_position = current_page_position_var.value + + def _release_page(i, state): + page_map, page_status = state + page_idx = page_map[slot][i] + page_status = page_status.at[page_idx].set(0) + page_map = page_map.at[slot, i].set(0) + return page_map, page_status + + page_map, page_status = jax.lax.fori_loop(0, num_pages_used[slot], _release_page, (page_map, page_status)) + + sequence_lengths = sequence_lengths.at[slot].set(0) + num_pages_used = num_pages_used.at[slot].set(0) + current_page = current_page.at[slot].set(0) + current_page_position = current_page_position.at[slot].set(0) + + page_status_var.value = page_status + page_map_var.value = page_map + sequence_lengths_var.value = sequence_lengths + num_pages_used_var.value = num_pages_used + current_page_var.value = current_page + current_page_position_var.value = current_page_position + + return ( + page_status_var, + page_map_var, + sequence_lengths_var, + num_pages_used_var, + current_page_var, + current_page_position_var, + ) + + def reserve_prefix_slot_pages( + self, + slot: int, + true_length: int, + page_status_var: nn.Variable, + page_map_var: nn.Variable, + sequence_lengths_var: nn.Variable, + num_pages_used_var: nn.Variable, + current_page_var: nn.Variable, + current_page_position_var: nn.Variable, + ) -> Tuple: + """Reserves pages for a prefix sequence in the specified slot. + + This method allocates the necessary pages for a prefix sequence of given length, + first releasing any existing pages assigned to the slot. 
+ + Args: + slot: Integer identifying the target slot + true_length: Actual length of the prefix sequence + page_status_var: Variable tracking page usage status + page_map_var: Variable mapping slots to pages + sequence_lengths_var: Variable tracking sequence lengths + num_pages_used_var: Variable tracking page usage counts + current_page_var: Variable tracking current active pages + current_page_position_var: Variable tracking positions in current pages + + Returns: + Tuple of updated variables after reserving pages for the prefix + """ + ( + page_status_var, + page_map_var, + sequence_lengths_var, + num_pages_used_var, + current_page_var, + current_page_position_var, + ) = self.release_slot_pages( + slot, + page_status_var, + page_map_var, + sequence_lengths_var, + num_pages_used_var, + current_page_var, + current_page_position_var, + ) + + page_status = page_status_var.value + page_map = page_map_var.value + sequence_lengths = sequence_lengths_var.value + num_pages_used = num_pages_used_var.value + current_page = current_page_var.value + current_page_position = current_page_position_var.value + + prefill_slot_num_pages = jnp.ceil(true_length / self.tokens_per_page).astype(jnp.int32) + prefill_slot_page_slice_idx = jnp.where(true_length == 0, 0, (true_length - 1) % self.tokens_per_page) + + def _reserve_page(i, state): + slot, page_map, page_status, current_page = state + page_idx = jnp.where((page_status[1:] == 0), size=1)[0][0] + 1 + page_status = page_status.at[page_idx].set(1) + page_map = page_map.at[slot, i].set(page_idx) + current_page = current_page.at[slot].set(page_idx) + return slot, page_map, page_status, current_page + + _, page_map, page_status, current_page = jax.lax.fori_loop( + 0, prefill_slot_num_pages, _reserve_page, (slot, page_map, page_status, current_page) + ) + + sequence_lengths = sequence_lengths.at[slot].set(true_length) + num_pages_used = num_pages_used.at[slot].set(prefill_slot_num_pages) + current_page_position = current_page_position.at[slot].set(prefill_slot_page_slice_idx) + + page_status_var.value = page_status + page_map_var.value = page_map + sequence_lengths_var.value = sequence_lengths + num_pages_used_var.value = num_pages_used + current_page_var.value = current_page + current_page_position_var.value = current_page_position + + return ( + page_status_var, + page_map_var, + sequence_lengths_var, + num_pages_used_var, + current_page_var, + current_page_position_var, + ) + + def reserve_decode_step_pages( + self, + page_status_var: nn.Variable, + page_map_var: nn.Variable, + sequence_lengths_var: nn.Variable, + num_pages_used_var: nn.Variable, + current_page_var: nn.Variable, + current_page_position_var: nn.Variable, + ) -> Tuple: + """Reserves additional pages needed for a decoding step. + + This method allocates new pages as needed when sequences grow during + autoregressive decoding, ensuring each active slot has sufficient pages + for its sequence. 
+ + Args: + page_status_var: Variable tracking page usage status + page_map_var: Variable mapping slots to pages + sequence_lengths_var: Variable tracking sequence lengths + num_pages_used_var: Variable tracking page usage counts + current_page_var: Variable tracking current active pages + current_page_position_var: Variable tracking positions in current pages + + Returns: + Tuple of updated variables after reserving pages for the decode step + """ + page_status = page_status_var.value + page_map = page_map_var.value + sequence_lengths = sequence_lengths_var.value + num_pages_used = num_pages_used_var.value + current_page = current_page_var.value + current_page_position = current_page_position_var.value + + sequence_lengths_step = jnp.logical_and(jnp.ones(sequence_lengths.shape, dtype=jnp.int32), sequence_lengths).astype( + jnp.int32 + ) + + sequence_lengths += sequence_lengths_step + + current_num_pages_used = num_pages_used + num_pages_used = jnp.ceil(sequence_lengths / self.tokens_per_page).astype(jnp.int32) + + current_page_position = jnp.where(sequence_lengths == 0, 0, (sequence_lengths - 1) % self.tokens_per_page) + seq_new_page = num_pages_used - current_num_pages_used + + updating_slots = jnp.where((seq_new_page > 0), size=self.slots)[0] + + def _reserve_page(i, state): + page_map, page_status, current_page, updating_slots = state + slot = jax.lax.dynamic_index_in_dim(updating_slots, i, axis=0, keepdims=False) + page_idx = jnp.where((page_status[1:] == 0), size=1)[0][0] + 1 + page_status = page_status.at[page_idx].set(1) + page_map = page_map.at[slot, num_pages_used[slot] - 1].set(page_idx) + current_page = current_page.at[slot].set(page_idx) + return page_map, page_status, current_page, updating_slots + + page_map, page_status, current_page, _ = jax.lax.fori_loop( + 0, jnp.count_nonzero(seq_new_page), _reserve_page, (page_map, page_status, current_page, updating_slots) + ) + + page_status_var.value = page_status + page_map_var.value = page_map + sequence_lengths_var.value = sequence_lengths + num_pages_used_var.value = num_pages_used + current_page_var.value = current_page + current_page_position_var.value = current_page_position + + return ( + page_status_var, + page_map_var, + sequence_lengths_var, + num_pages_used_var, + current_page_var, + current_page_position_var, + ) + + @nn.compact + def __call__( + self, model_mode: Optional[str] = None, slot: Optional[int] = None, true_length: Optional[int] = None + ) -> PageState: + + ( + page_status_var, + page_map_var, + sequence_lengths_var, + num_pages_used_var, + current_page_var, + current_page_position_var, + ) = self.init_or_get_vars() + + if model_mode == common_types.MODEL_MODE_PREFILL and self.is_mutable_collection("params"): + return PageState( + page_status_var.value, + page_map_var.value, + sequence_lengths_var.value, + num_pages_used_var.value, + current_page_var.value, + current_page_position_var.value, + ) + if model_mode == common_types.MODEL_MODE_PREFILL: + assert slot is not None and true_length is not None, f"but get {slot=} and {true_length=} instead" + self.reserve_prefix_slot_pages( + slot, + true_length, + page_status_var, + page_map_var, + sequence_lengths_var, + num_pages_used_var, + current_page_var, + current_page_position_var, + ) + elif model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: + self.reserve_decode_step_pages( + page_status_var, + page_map_var, + sequence_lengths_var, + num_pages_used_var, + current_page_var, + current_page_position_var, + ) + + return PageState( + page_status_var.value, + 
page_map_var.value, + sequence_lengths_var.value, + num_pages_used_var.value, + current_page_var.value, + current_page_position_var.value, + ) diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 67ca89cbe..b1b0b776e 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -68,7 +68,7 @@ def validate_kv_quant_axis(s: str, quantize_kvcache: bool) -> None: def validate_attention_kernel(s: str) -> None: - valid_attention_kernels = ("autoselected", "dot_product", "flash", "cudnn_flash_te") + valid_attention_kernels = ("autoselected", "dot_product", "flash", "cudnn_flash_te", "paged") if s not in valid_attention_kernels: # currently supported attention raise ValueError("Invalid attention kernel was passed. Valid options ", valid_attention_kernels) diff --git a/MaxText/tests/paged_attention_test.py b/MaxText/tests/paged_attention_test.py new file mode 100644 index 000000000..5c9b3dd1b --- /dev/null +++ b/MaxText/tests/paged_attention_test.py @@ -0,0 +1,826 @@ +import unittest +import pytest +import jax +import numpy as np +import jax.numpy as jnp +from flax.core import freeze +from flax import linen as nn +from layers.attentions import PagedAttentionOp, Attention +from page_managers import PageManager, PageState +import common_types + + +def reference_attention(query, key, value): + attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key) + attn_weights = jax.nn.softmax(attn_weights, axis=-1) + + return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value) + + +class PagedAttentionTest(unittest.TestCase): + + def setUp(self): + self.cfg = { + "num_query_heads": 8, + "num_kv_heads": 8, + "head_dim": 128, + "max_prefill_predict_length": 512, + "max_target_length": 1024, + "num_pages": 64, + "tokens_per_page": 32, + "pages_per_compute_block": 16, + "dtype": jnp.float32, + } + self.rng = jax.random.PRNGKey(42) + devices = jax.devices() + if len(devices) > 1: + self.mesh = jax.sharding.Mesh(devices, axis_names=("data",)) + else: + # Fallback for single-device testing + self.mesh = jax.sharding.Mesh(devices, axis_names=()) + self.attention_op = PagedAttentionOp( + mesh=self.mesh, + num_pages=self.cfg["num_pages"], + tokens_per_page=self.cfg["tokens_per_page"], + max_pages_per_slot=self.cfg["max_target_length"] // self.cfg["tokens_per_page"], + max_pages_per_prefill=self.cfg["max_prefill_predict_length"] // self.cfg["tokens_per_page"], + pages_per_compute_block=self.cfg["pages_per_compute_block"], + num_kv_heads=self.cfg["num_kv_heads"], + kv_head_dim_size=self.cfg["head_dim"], + dtype=self.cfg["dtype"], + ) + + @pytest.mark.tpu_only + def test_paged_attention_output_shape(self): + attention_op = PagedAttentionOp( + mesh=self.mesh, + num_pages=self.cfg["num_pages"], + tokens_per_page=self.cfg["tokens_per_page"], + max_pages_per_slot=self.cfg["max_target_length"] // self.cfg["tokens_per_page"], + max_pages_per_prefill=self.cfg["max_prefill_predict_length"] // self.cfg["tokens_per_page"], + pages_per_compute_block=self.cfg["pages_per_compute_block"], + num_kv_heads=self.cfg["num_kv_heads"], + kv_head_dim_size=self.cfg["head_dim"], + dtype=self.cfg["dtype"], + ) + + query = jnp.ones((1, self.cfg["max_prefill_predict_length"], self.cfg["num_query_heads"], self.cfg["head_dim"])) + key = jnp.ones((1, self.cfg["max_prefill_predict_length"], self.cfg["num_kv_heads"], self.cfg["head_dim"])) + value = jnp.ones((1, self.cfg["max_prefill_predict_length"], self.cfg["num_kv_heads"], self.cfg["head_dim"])) + + page_state = PageState( + page_status=jnp.zeros(self.cfg["num_pages"], dtype=jnp.int32), + 
page_map=jnp.zeros((1, self.cfg["num_pages"]), dtype=jnp.int32), + sequence_lengths=jnp.zeros(1, dtype=jnp.int32), + num_pages_used=jnp.zeros(1, dtype=jnp.int32), + current_page=jnp.zeros(1, dtype=jnp.int32), + current_page_position=jnp.zeros(1, dtype=jnp.int32), + ) + + variables = attention_op.init(self.rng, query, key, value, None, common_types.MODEL_MODE_PREFILL, page_state) + output_tuple, mutated_variables = attention_op.apply( + variables, query, key, value, None, common_types.MODEL_MODE_PREFILL, page_state, mutable=["cache"] + ) + + output, _, _ = output_tuple # Unpack the tuple + self.assertEqual( + output.shape, (1, self.cfg["max_prefill_predict_length"], self.cfg["num_query_heads"], self.cfg["head_dim"]) + ) + + @pytest.mark.tpu_only + def test_paged_dot_product_attention_with_max_and_sum(self): + query = jnp.ones((1, self.cfg["max_prefill_predict_length"], self.cfg["num_query_heads"], self.cfg["head_dim"])) + key = jnp.ones((1, self.cfg["max_prefill_predict_length"], self.cfg["num_kv_heads"], self.cfg["head_dim"])) + value = jnp.ones((1, self.cfg["max_prefill_predict_length"], self.cfg["num_kv_heads"], self.cfg["head_dim"])) + + output, max_vals, sum_vals = self.attention_op.paged_dot_product_attention_with_max_and_sum(query, key, value) + self.assertEqual( + output.shape, (1, self.cfg["max_prefill_predict_length"], self.cfg["num_query_heads"], self.cfg["head_dim"]) + ) + self.assertEqual(max_vals.shape[-1], 1) + self.assertEqual(sum_vals.shape[-1], 1) + + @pytest.mark.tpu_only + def test_update_prefill_step_pages(self): + key = jnp.ones((1, self.cfg["max_prefill_predict_length"], self.cfg["num_kv_heads"], self.cfg["head_dim"])) + value = jnp.ones((1, self.cfg["max_prefill_predict_length"], self.cfg["num_kv_heads"], self.cfg["head_dim"])) + page_state = PageState( + page_status=jnp.zeros(self.cfg["num_pages"], dtype=jnp.int32), + page_map=jnp.zeros((1, self.cfg["num_pages"]), dtype=jnp.int32), + sequence_lengths=jnp.zeros(1, dtype=jnp.int32), + num_pages_used=jnp.zeros(1, dtype=jnp.int32), + current_page=jnp.zeros(1, dtype=jnp.int32), + current_page_position=jnp.zeros(1, dtype=jnp.int32), + ) + variables = self.attention_op.init( + self.rng, + jnp.ones((1, 1, self.cfg["num_query_heads"], self.cfg["head_dim"])), + key, + value, + None, + common_types.MODEL_MODE_PREFILL, + page_state, + ) + + # Use apply() to update the variable's value + _, mutated_variables = self.attention_op.apply( + variables, + query=jnp.ones((1, 1, self.cfg["num_query_heads"], self.cfg["head_dim"])), + key=key, # Provide the key + value=value, # Provide the value + decoder_segment_ids=None, # Provide a None or a suitable Array for decoder_segment_ids + model_mode=common_types.MODEL_MODE_PREFILL, # Provide the model mode + page_state=page_state, # Provide the page_state + mutable=["cache"], + ) + + # Access the updated values from mutated_variables + updated_key_pages_var = mutated_variables["cache"]["key_pages"] + updated_value_pages_var = mutated_variables["cache"]["value_pages"] + + # Assertions using the updated variables + self.assertEqual( + updated_key_pages_var.value.shape, + ( + self.cfg["num_kv_heads"], + self.cfg["max_prefill_predict_length"] // self.cfg["tokens_per_page"], + self.cfg["tokens_per_page"], + self.cfg["head_dim"], + ), + ) + self.assertEqual( + updated_value_pages_var.value.shape, + ( + self.cfg["num_kv_heads"], + self.cfg["max_prefill_predict_length"] // self.cfg["tokens_per_page"], + self.cfg["tokens_per_page"], + self.cfg["head_dim"], + ), + ) + + @pytest.mark.tpu_only + def 
test_update_decode_step_pages(self): + """Test cache update during autoregressive generation.""" + batch_size = 1 + # Create distinctive key/value patterns + rng1, rng2 = jax.random.split(self.rng) + key = jax.random.normal(rng1, (batch_size, 1, self.cfg["num_kv_heads"], self.cfg["head_dim"])) + value = jax.random.normal(rng2, (batch_size, 1, self.cfg["num_kv_heads"], self.cfg["head_dim"])) + + # Initialize page state at specific position + test_page = 2 + test_position = 3 + page_state = PageState( + page_status=jnp.zeros(self.cfg["num_pages"], dtype=jnp.int32), + page_map=jnp.zeros((batch_size, self.cfg["num_pages"]), dtype=jnp.int32), + sequence_lengths=jnp.ones(batch_size, dtype=jnp.int32), + num_pages_used=jnp.zeros(batch_size, dtype=jnp.int32), + current_page=jnp.array([test_page], dtype=jnp.int32), + current_page_position=jnp.array([test_position], dtype=jnp.int32), + ) + + # Initialize attention op and run update + variables = self.attention_op.init( + self.rng, + query=key, # Use key as query for initialization + key=key, + value=value, + decoder_segment_ids=None, + model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE, + page_state=page_state, + ) + + _, mutated_variables = self.attention_op.apply( + variables, + query=key, + key=key, + value=value, + decoder_segment_ids=None, + model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE, + page_state=page_state, + mutable=["cache"], + ) + + # Extract updated cache + updated_key_pages = mutated_variables["cache"]["key_pages"] + updated_value_pages = mutated_variables["cache"]["value_pages"] + + # Verify shapes + self.assertEqual( + updated_key_pages.value.shape, + (self.cfg["num_kv_heads"], self.cfg["num_pages"], self.cfg["tokens_per_page"], self.cfg["head_dim"]), + ) + + # Instead of trying to extract logical axes from the value, + # verify against the attention op's configured axis names + self.assertEqual( + self.attention_op.kv_pages_axis_names, ("paged_kv_heads", "num_pages", "tokens_per_page", "paged_kv_head_dim_size") + ) + + # Verify key placement + zeros = jnp.zeros_like(updated_key_pages.value) + expected_key_pages = zeros.at[:, test_page, test_position, :].set(jnp.squeeze(key)) + np.testing.assert_allclose(updated_key_pages.value, expected_key_pages, rtol=1e-5, atol=1e-5) + + # Verify surrounding positions are unchanged + np.testing.assert_allclose( + updated_key_pages.value[:, test_page, test_position + 1 :, :], + zeros[:, test_page, test_position + 1 :, :], + rtol=1e-5, + atol=1e-5, + ) + + @pytest.mark.tpu_only + def test_prefill_attention(self): + batch_size, seq_len, num_heads, head_dim = 1, self.cfg["max_prefill_predict_length"], 8, 128 + query = jax.random.normal(self.rng, (batch_size, seq_len, num_heads, head_dim)) + key = jax.random.normal(self.rng, (batch_size, seq_len, num_heads, head_dim)) + value = jax.random.normal(self.rng, (batch_size, seq_len, num_heads, head_dim)) + + page_state = PageState( + page_status=jnp.zeros(self.cfg["num_pages"], dtype=jnp.int32), + page_map=jnp.zeros((1, self.cfg["num_pages"]), dtype=jnp.int32), + sequence_lengths=jnp.array([seq_len], dtype=jnp.int32), + num_pages_used=jnp.zeros(1, dtype=jnp.int32), + current_page=jnp.zeros(1, dtype=jnp.int32), + current_page_position=jnp.zeros(1, dtype=jnp.int32), + ) + + variables = self.attention_op.init(self.rng, query, key, value, None, common_types.MODEL_MODE_PREFILL, page_state) + output_tuple, _ = self.attention_op.apply( + variables, query, key, value, None, common_types.MODEL_MODE_PREFILL, page_state, mutable=["cache"] + ) + paged_output, 
max_vals, sum_vals = output_tuple + + # Normalize the output using the returned max and sum values + paged_output = paged_output / (sum_vals + 1e-9) # Add epsilon for numerical stability + reference_output = reference_attention(query, key, value) + + np.testing.assert_allclose(paged_output, reference_output, rtol=1e-5, atol=1e-5) + + @pytest.mark.tpu_only + def test_autoregressive_attention(self): + batch_size, seq_len, num_heads, head_dim = 1, 1, 8, 128 + query = jax.random.normal(self.rng, (batch_size, 1, num_heads, head_dim)) + key = jax.random.normal(self.rng, (batch_size, seq_len, num_heads, head_dim)) + value = jax.random.normal(self.rng, (batch_size, seq_len, num_heads, head_dim)) + + page_state = PageState( + page_status=jnp.zeros(self.cfg["num_pages"], dtype=jnp.int32), + page_map=jnp.zeros((1, self.cfg["num_pages"]), dtype=jnp.int32), + sequence_lengths=jnp.array([seq_len], dtype=jnp.int32), + num_pages_used=jnp.zeros(1, dtype=jnp.int32), + current_page=jnp.zeros(1, dtype=jnp.int32), + current_page_position=jnp.zeros(1, dtype=jnp.int32), + ) + + variables = self.attention_op.init(self.rng, query, key, value, None, common_types.MODEL_MODE_AUTOREGRESSIVE, page_state) + output_tuple, _ = self.attention_op.apply( + variables, query, key, value, None, common_types.MODEL_MODE_AUTOREGRESSIVE, page_state, mutable=["cache"] + ) + + # In autoregressive mode, normalization is handled internally + paged_output, _, _ = output_tuple + reference_output = reference_attention(query, key, value) + np.testing.assert_allclose(paged_output, reference_output, rtol=1e-2, atol=1e-2) + + @pytest.mark.tpu_only + def test_basic_prefill(self): + """Test just the prefill operation without any AR steps.""" + batch_size = 1 # Prefill requires batch_size=1 + seq_len = self.cfg["max_prefill_predict_length"] + num_heads = 8 + head_dim = 128 + + # Create input sequence + prefill_tokens = jax.random.normal(self.rng, (batch_size, seq_len, num_heads, head_dim)) + + # Initialize page state + page_state = PageState( + page_status=jnp.zeros(self.cfg["num_pages"], dtype=jnp.int32), + page_map=jnp.zeros((batch_size, self.cfg["num_pages"]), dtype=jnp.int32), + sequence_lengths=jnp.array([seq_len], dtype=jnp.int32), + num_pages_used=jnp.zeros(batch_size, dtype=jnp.int32), + current_page=jnp.zeros(batch_size, dtype=jnp.int32), + current_page_position=jnp.zeros(batch_size, dtype=jnp.int32), + ) + + # Initialize attention op + variables = self.attention_op.init( + self.rng, prefill_tokens, prefill_tokens, prefill_tokens, None, common_types.MODEL_MODE_PREFILL, page_state + ) + + output_tuple, _ = self.attention_op.apply( + variables, + prefill_tokens, + prefill_tokens, + prefill_tokens, + None, + common_types.MODEL_MODE_PREFILL, + page_state, + mutable=["cache"], + ) + + paged_prefill_output, max_vals, sum_vals = output_tuple + paged_prefill_output = paged_prefill_output / (sum_vals + 1e-9) + + # Compare with reference implementation + reference_output = reference_attention(prefill_tokens, prefill_tokens, prefill_tokens) + np.testing.assert_allclose( + paged_prefill_output, reference_output, rtol=1e-2, atol=1e-2, err_msg="Prefill outputs don't match reference" + ) + + @pytest.mark.tpu_only + def test_prefill_then_single_ar(self): + """Test basic prefill followed by single AR step matches reference impl.""" + batch_size = 1 + prefill_len = self.cfg["max_prefill_predict_length"] + num_heads = 8 + head_dim = 128 + + # Create input sequence + rng1, rng2 = jax.random.split(self.rng) + prefill_tokens = jax.random.normal(rng1, 
(batch_size, prefill_len, num_heads, head_dim)) + + # Initialize page state + page_state = PageState( + page_status=jnp.zeros(self.cfg["num_pages"], dtype=jnp.int32), + page_map=jnp.zeros((batch_size, self.cfg["num_pages"]), dtype=jnp.int32), + sequence_lengths=jnp.array([prefill_len], dtype=jnp.int32), + num_pages_used=jnp.zeros(batch_size, dtype=jnp.int32), + current_page=jnp.zeros(batch_size, dtype=jnp.int32), + current_page_position=jnp.zeros(batch_size, dtype=jnp.int32), + ) + + # Initialize attention ops + variables = self.attention_op.init( + self.rng, prefill_tokens, prefill_tokens, prefill_tokens, None, common_types.MODEL_MODE_PREFILL, page_state + ) + + output_tuple, mutated_vars = self.attention_op.apply( + variables, + prefill_tokens, + prefill_tokens, + prefill_tokens, + None, + common_types.MODEL_MODE_PREFILL, + page_state, + mutable=["cache"], + ) + + prefill_output, max_vals, sum_vals = output_tuple + prefill_output = prefill_output / (sum_vals + 1e-9) + + # Use updated variables for AR step + variables = mutated_vars + ar_token = jax.random.normal(rng2, (batch_size, 1, num_heads, head_dim)) + + ar_output_tuple, _ = self.attention_op.apply( + variables, ar_token, ar_token, ar_token, None, common_types.MODEL_MODE_AUTOREGRESSIVE, page_state, mutable=["cache"] + ) + + ar_output, _, _ = ar_output_tuple + + # Compare with reference implementation including prefill context + full_sequence = jnp.concatenate([prefill_tokens, ar_token], axis=1) + reference_ar_output = reference_attention(ar_token, full_sequence, full_sequence) + + np.testing.assert_allclose(ar_output, reference_ar_output, rtol=1e-2, atol=1e-2, err_msg="AR outputs don't match") + + @pytest.mark.tpu_only + def test_basic_ar(self): + """Test just the autoregressive operation with a single step.""" + batch_size = 1 + num_heads = 8 + head_dim = 128 + + # Create separate random values for query/key/value + rng1, rng2, rng3 = jax.random.split(self.rng, 3) + query = jax.random.normal(rng1, (batch_size, 1, num_heads, head_dim)) + key = jax.random.normal(rng2, (batch_size, 1, num_heads, head_dim)) + value = jax.random.normal(rng3, (batch_size, 1, num_heads, head_dim)) + + # Initialize page state with sequence length of 1 + page_state = PageState( + page_status=jnp.zeros(self.cfg["num_pages"], dtype=jnp.int32), + page_map=jnp.zeros((batch_size, self.cfg["num_pages"]), dtype=jnp.int32), + sequence_lengths=jnp.ones(batch_size, dtype=jnp.int32), # Start with length 1 + num_pages_used=jnp.zeros(batch_size, dtype=jnp.int32), + current_page=jnp.zeros(batch_size, dtype=jnp.int32), + current_page_position=jnp.zeros(batch_size, dtype=jnp.int32), + ) + + # Initialize and apply attention + variables = self.attention_op.init(self.rng, query, key, value, None, common_types.MODEL_MODE_AUTOREGRESSIVE, page_state) + + output_tuple, _ = self.attention_op.apply( + variables, query, key, value, None, common_types.MODEL_MODE_AUTOREGRESSIVE, page_state, mutable=["cache"] + ) + + # AR mode returns (output, None, None) + ar_output, _, _ = output_tuple + + # Compare with reference implementation + reference_output = reference_attention(query, key, value) + np.testing.assert_allclose(ar_output, reference_output, rtol=1e-2, atol=1e-2, err_msg="AR outputs don't match reference") + + @pytest.mark.tpu_only + def test_paged_attention_single_token_batch(self): + """Test prefill attention for a batch of one sequence spanning multiple pages.""" + batch_size = 1 + seq_len = self.cfg["tokens_per_page"] * 16 + query = jax.random.normal(self.rng, (batch_size, seq_len,
self.cfg["num_query_heads"], self.cfg["head_dim"])) + key = jax.random.normal(self.rng, (batch_size, seq_len, self.cfg["num_kv_heads"], self.cfg["head_dim"])) + value = jax.random.normal(self.rng, (batch_size, seq_len, self.cfg["num_kv_heads"], self.cfg["head_dim"])) + + page_state = PageState( + page_status=jnp.zeros(self.cfg["num_pages"], dtype=jnp.int32), + page_map=jnp.zeros((batch_size, self.cfg["num_pages"]), dtype=jnp.int32), + sequence_lengths=jnp.array([seq_len], dtype=jnp.int32), + num_pages_used=jnp.zeros(batch_size, dtype=jnp.int32), + current_page=jnp.zeros(batch_size, dtype=jnp.int32), + current_page_position=jnp.zeros(batch_size, dtype=jnp.int32), + ) + + variables = self.attention_op.init(self.rng, query, key, value, None, common_types.MODEL_MODE_PREFILL, page_state) + + output_tuple, _ = self.attention_op.apply( + variables, query, key, value, None, common_types.MODEL_MODE_PREFILL, page_state, mutable=["cache"] + ) + + paged_output, max_vals, sum_vals = output_tuple + # Normalize using returned values + paged_output = paged_output / (sum_vals + 1e-9) + + reference_output = reference_attention(query, key, value) + np.testing.assert_allclose( + paged_output, reference_output, rtol=1e-5, atol=1e-5, err_msg="Paged attention outputs don't match reference" + ) + + @pytest.mark.tpu_only + def test_attention_pattern_consistency(self): + """Test that the attention pattern remains consistent across prefill and autoregressive steps.""" + batch_size = 1 + seq_len = self.cfg["max_prefill_predict_length"] + + query = jax.random.normal(self.rng, (batch_size, seq_len, self.cfg["num_query_heads"], self.cfg["head_dim"])) + key = jax.random.normal(self.rng, (batch_size, seq_len, self.cfg["num_kv_heads"], self.cfg["head_dim"])) + value = jax.random.normal(self.rng, (batch_size, seq_len, self.cfg["num_kv_heads"], self.cfg["head_dim"])) + + page_state = PageState( + page_status=jnp.zeros(self.cfg["num_pages"], dtype=jnp.int32), + page_map=jnp.zeros((batch_size, self.cfg["num_pages"]), dtype=jnp.int32), + sequence_lengths=jnp.array([seq_len], dtype=jnp.int32), + num_pages_used=jnp.zeros(batch_size, dtype=jnp.int32), + current_page=jnp.zeros(batch_size, dtype=jnp.int32), + current_page_position=jnp.zeros(batch_size, dtype=jnp.int32), + ) + + # Run prefill + variables = self.attention_op.init(self.rng, query, key, value, None, common_types.MODEL_MODE_PREFILL, page_state) + output_tuple, mutated_vars = self.attention_op.apply( + variables, query, key, value, None, common_types.MODEL_MODE_PREFILL, page_state, mutable=["cache"] + ) + + # Normalize the prefill output using the returned sum values + prefill_output, _, sum_vals = output_tuple + prefill_output = prefill_output / (sum_vals + 1e-9) + reference_output = reference_attention(query, key, value) + np.testing.assert_allclose(prefill_output, reference_output, rtol=1e-5, atol=1e-5) + + # Test single autoregressive step + ar_query = jax.random.normal(self.rng, (batch_size, 1, self.cfg["num_query_heads"], self.cfg["head_dim"])) + ar_key = jax.random.normal(self.rng, (batch_size, 1, self.cfg["num_kv_heads"], self.cfg["head_dim"])) + ar_value = jax.random.normal(self.rng, (batch_size, 1, self.cfg["num_kv_heads"], self.cfg["head_dim"])) + + ar_output_tuple, _ = self.attention_op.apply( + mutated_vars, ar_query, ar_key, ar_value, None, common_types.MODEL_MODE_AUTOREGRESSIVE, page_state, mutable=["cache"] + ) + + ar_output, _, _ = ar_output_tuple + + # Compare against reference + full_key = jnp.concatenate([key, ar_key], axis=1) + full_value = jnp.concatenate([value, ar_value], axis=1) + ar_reference = reference_attention(ar_query, full_key, full_value) + assert
ar_output.shape == ar_reference.shape + np.testing.assert_allclose(ar_output, ar_reference, rtol=1e-2, atol=1e-2) + + @pytest.mark.tpu_only + def test_sequential_page_updates(self): + """Test multiple sequential page updates to verify cache consistency.""" + batch_size = 1 + seq_len = 1 + num_heads = 8 + head_dim = 128 + + # Create initial key/value + rng1, rng2 = jax.random.split(self.rng) + key = jax.random.normal(rng1, (batch_size, seq_len, num_heads, head_dim)) + value = jax.random.normal(rng2, (batch_size, seq_len, num_heads, head_dim)) + + # Initialize page state for first position + page_state = PageState( + page_status=jnp.zeros(self.cfg["num_pages"], dtype=jnp.int32), + page_map=jnp.zeros((batch_size, self.cfg["num_pages"]), dtype=jnp.int32), + sequence_lengths=jnp.ones(batch_size, dtype=jnp.int32), + num_pages_used=jnp.zeros(batch_size, dtype=jnp.int32), + current_page=jnp.array([0], dtype=jnp.int32), + current_page_position=jnp.array([0], dtype=jnp.int32), + ) + + # Initialize attention op + variables = self.attention_op.init( + self.rng, key, key, value, None, common_types.MODEL_MODE_AUTOREGRESSIVE, page_state # Use as query too + ) + + # Perform multiple sequential updates + num_updates = 3 + expected_values = [] + + for i in range(num_updates): + # Generate new key/value + rng1, rng2 = jax.random.split(rng1) + new_key = jax.random.normal(rng1, (batch_size, seq_len, num_heads, head_dim)) + new_value = jax.random.normal(rng2, (batch_size, seq_len, num_heads, head_dim)) + expected_values.append((new_key, new_value)) + + # Update cache + _, variables = self.attention_op.apply( + variables, new_key, new_key, new_value, None, common_types.MODEL_MODE_AUTOREGRESSIVE, page_state, mutable=["cache"] + ) + + # Update page state + page_state = PageState( + page_status=page_state.page_status, + page_map=page_state.page_map, + sequence_lengths=page_state.sequence_lengths + 1, + num_pages_used=page_state.num_pages_used, + current_page=page_state.current_page, + current_page_position=jnp.array([i + 1], dtype=jnp.int32), + ) + + # Verify cache contents + cache = variables["cache"] + key_pages = cache["key_pages"] + value_pages = cache["value_pages"] + + # Check each position + for i, (expected_key, expected_value) in enumerate(expected_values): + for head in range(num_heads): + np.testing.assert_allclose( + key_pages.value[head, 0, i], + expected_key[0, 0, head], + rtol=1e-5, + atol=1e-5, + err_msg=f"Mismatch in key cache at position {i}, head {head}", + ) + np.testing.assert_allclose( + value_pages.value[head, 0, i], + expected_value[0, 0, head], + rtol=1e-5, + atol=1e-5, + err_msg=f"Mismatch in value cache at position {i}, head {head}", + ) + + @pytest.mark.tpu_only + def test_page_boundary_conditions(self): + """Test attention computation across page boundaries.""" + batch_size = 1 + seq_len = self.cfg["tokens_per_page"] * 2 # Two pages exactly + num_heads = 8 + head_dim = 128 + + # Create attention op with exactly 2 pages for prefill + attention_op = PagedAttentionOp( + mesh=self.mesh, + num_pages=self.cfg["num_pages"], + tokens_per_page=self.cfg["tokens_per_page"], + max_pages_per_slot=self.cfg["max_target_length"] // self.cfg["tokens_per_page"], + max_pages_per_prefill=2, # Override to exactly what we need + pages_per_compute_block=self.cfg["pages_per_compute_block"], + num_kv_heads=self.cfg["num_kv_heads"], + kv_head_dim_size=self.cfg["head_dim"], + dtype=self.cfg["dtype"], + ) + + # Create distinct patterns for each page + rng1, rng2 = jax.random.split(self.rng) + query_page1 = 
jax.random.normal(rng1, (batch_size, self.cfg["tokens_per_page"], num_heads, head_dim)) + query_page2 = jax.random.normal(rng2, (batch_size, self.cfg["tokens_per_page"], num_heads, head_dim)) + query = jnp.concatenate([query_page1, query_page2], axis=1) + + key = query # Use same patterns for key and value for simplicity + value = query + + page_state = PageState( + page_status=jnp.zeros(self.cfg["num_pages"], dtype=jnp.int32), + page_map=jnp.zeros((batch_size, self.cfg["num_pages"]), dtype=jnp.int32), + sequence_lengths=jnp.array([seq_len], dtype=jnp.int32), + num_pages_used=jnp.zeros(batch_size, dtype=jnp.int32), + current_page=jnp.zeros(batch_size, dtype=jnp.int32), + current_page_position=jnp.zeros(batch_size, dtype=jnp.int32), + ) + + variables = attention_op.init(self.rng, query, key, value, None, common_types.MODEL_MODE_PREFILL, page_state) + + output_tuple, _ = attention_op.apply( + variables, query, key, value, None, common_types.MODEL_MODE_PREFILL, page_state, mutable=["cache"] + ) + + output, max_vals, sum_vals = output_tuple + output = output / (sum_vals + 1e-9) + reference_output = reference_attention(query, key, value) + + # Test boundary attention patterns + + # 1. Check last token of first page + boundary_idx = self.cfg["tokens_per_page"] + np.testing.assert_allclose( + output[:, boundary_idx - 1 : boundary_idx], + reference_output[:, boundary_idx - 1 : boundary_idx], + rtol=1e-5, + atol=1e-5, + err_msg="Last token of first page doesn't match reference", + ) + + # 2. Check first token of second page + np.testing.assert_allclose( + output[:, boundary_idx : boundary_idx + 1], + reference_output[:, boundary_idx : boundary_idx + 1], + rtol=1e-5, + atol=1e-5, + err_msg="First token of second page doesn't match reference", + ) + + # 3. Check boundary transition + window_size = 4 # Check 2 tokens on each side of boundary + boundary_window = slice(boundary_idx - window_size // 2, boundary_idx + window_size // 2) + np.testing.assert_allclose( + output[:, boundary_window], + reference_output[:, boundary_window], + rtol=1e-5, + atol=1e-5, + err_msg="Attention pattern at page boundary doesn't match reference", + ) + + # 4. Verify no discontinuities at boundary + attention_diff = jnp.abs(output[:, boundary_idx] - output[:, boundary_idx - 1]) + self.assertTrue(jnp.all(attention_diff < 1e3), "Detected unexpected discontinuity at page boundary") + + # 5. 
Verify overall output + np.testing.assert_allclose( + output, reference_output, rtol=1e-5, atol=1e-5, err_msg="Complete attention output doesn't match reference" + ) + + @pytest.mark.tpu_only + def test_page_reuse(self): + """Test page reuse after releasing pages.""" + batch_size = 1 + seq_len = 1 + num_heads = 8 + head_dim = 128 + + # Initialize with one sequence + key1 = jax.random.normal(self.rng, (batch_size, seq_len, num_heads, head_dim)) + value1 = jax.random.normal(self.rng, (batch_size, seq_len, num_heads, head_dim)) + + # Initialize page state for first sequence + page_state = PageState( + page_status=jnp.zeros(self.cfg["num_pages"], dtype=jnp.int32), + page_map=jnp.zeros((batch_size, self.cfg["num_pages"]), dtype=jnp.int32), + sequence_lengths=jnp.ones(batch_size, dtype=jnp.int32), + num_pages_used=jnp.zeros(batch_size, dtype=jnp.int32), + current_page=jnp.zeros(batch_size, dtype=jnp.int32), + current_page_position=jnp.zeros(batch_size, dtype=jnp.int32), + ) + + variables = self.attention_op.init( + self.rng, key1, key1, value1, None, common_types.MODEL_MODE_AUTOREGRESSIVE, page_state + ) + + # Store first sequence + _, mutated_vars = self.attention_op.apply( + variables, key1, key1, value1, None, common_types.MODEL_MODE_AUTOREGRESSIVE, page_state, mutable=["cache"] + ) + + # Create new sequence with different values + key2 = jax.random.normal(self.rng, (batch_size, seq_len, num_heads, head_dim)) + value2 = jax.random.normal(self.rng, (batch_size, seq_len, num_heads, head_dim)) + + # Reset page state (simulating page release) + page_state = PageState( + page_status=jnp.zeros(self.cfg["num_pages"], dtype=jnp.int32), + page_map=jnp.zeros((batch_size, self.cfg["num_pages"]), dtype=jnp.int32), + sequence_lengths=jnp.ones(batch_size, dtype=jnp.int32), + num_pages_used=jnp.zeros(batch_size, dtype=jnp.int32), + current_page=jnp.zeros(batch_size, dtype=jnp.int32), + current_page_position=jnp.zeros(batch_size, dtype=jnp.int32), + ) + + # Store second sequence in same location + output_tuple, final_vars = self.attention_op.apply( + mutated_vars, key2, key2, value2, None, common_types.MODEL_MODE_AUTOREGRESSIVE, page_state, mutable=["cache"] + ) + + output, _, _ = output_tuple + reference_output = reference_attention(key2, key2, value2) + + # Verify second sequence is stored correctly + np.testing.assert_allclose( + output, reference_output, rtol=1e-2, atol=1e-2, err_msg="Page reuse produced incorrect attention output" + ) + + @pytest.mark.tpu_only + def test_multi_head_consistency(self): + """Test consistency across different attention heads.""" + batch_size = 1 + seq_len = self.cfg["max_prefill_predict_length"] + num_heads = self.cfg["num_query_heads"] + head_dim = self.cfg["head_dim"] + + # Create input where each head gets different patterns + query = jnp.stack( + [jax.random.normal(self.rng, (batch_size, seq_len, head_dim)) * (i + 1) for i in range(num_heads)], axis=2 + ) + + key = jnp.stack( + [jax.random.normal(self.rng, (batch_size, seq_len, head_dim)) * (i + 1) for i in range(self.cfg["num_kv_heads"])], + axis=2, + ) + + value = jnp.stack( + [jax.random.normal(self.rng, (batch_size, seq_len, head_dim)) * (i + 1) for i in range(self.cfg["num_kv_heads"])], + axis=2, + ) + + page_state = PageState( + page_status=jnp.zeros(self.cfg["num_pages"], dtype=jnp.int32), + page_map=jnp.zeros((batch_size, self.cfg["num_pages"]), dtype=jnp.int32), + sequence_lengths=jnp.array([seq_len], dtype=jnp.int32), + num_pages_used=jnp.zeros(batch_size, dtype=jnp.int32), + current_page=jnp.zeros(batch_size, 
dtype=jnp.int32), + current_page_position=jnp.zeros(batch_size, dtype=jnp.int32), + ) + + variables = self.attention_op.init(self.rng, query, key, value, None, common_types.MODEL_MODE_PREFILL, page_state) + + output_tuple, _ = self.attention_op.apply( + variables, query, key, value, None, common_types.MODEL_MODE_PREFILL, page_state, mutable=["cache"] + ) + + output, max_vals, sum_vals = output_tuple + output = output / (sum_vals + 1e-9) + reference_output = reference_attention(query, key, value) + + # Check each head separately + for head in range(num_heads): + np.testing.assert_allclose( + output[:, :, head, :], + reference_output[:, :, head, :], + rtol=1e-5, + atol=1e-5, + err_msg=f"Head {head} attention output doesn't match reference", + ) + + @pytest.mark.tpu_only + def test_long_sequence_stability(self): + """Test numerical stability with long sequences.""" + batch_size = 1 + seq_len = self.cfg["max_prefill_predict_length"] + num_heads = 8 + head_dim = 128 + + # Create sequence with large magnitude differences + query = jax.random.normal(self.rng, (batch_size, seq_len, num_heads, head_dim)) * 10 + key = jax.random.normal(self.rng, (batch_size, seq_len, num_heads, head_dim)) * 10 + value = jax.random.normal(self.rng, (batch_size, seq_len, num_heads, head_dim)) + + page_state = PageState( + page_status=jnp.zeros(self.cfg["num_pages"], dtype=jnp.int32), + page_map=jnp.zeros((batch_size, self.cfg["num_pages"]), dtype=jnp.int32), + sequence_lengths=jnp.array([seq_len], dtype=jnp.int32), + num_pages_used=jnp.zeros(batch_size, dtype=jnp.int32), + current_page=jnp.zeros(batch_size, dtype=jnp.int32), + current_page_position=jnp.zeros(batch_size, dtype=jnp.int32), + ) + + variables = self.attention_op.init(self.rng, query, key, value, None, common_types.MODEL_MODE_PREFILL, page_state) + + output_tuple, _ = self.attention_op.apply( + variables, query, key, value, None, common_types.MODEL_MODE_PREFILL, page_state, mutable=["cache"] + ) + + output, max_vals, sum_vals = output_tuple + output = output / (sum_vals + 1e-9) + reference_output = reference_attention(query, key, value) + + # Check numerical stability + np.testing.assert_allclose( + output, reference_output, rtol=1e-5, atol=1e-5, err_msg="Long sequence attention is numerically unstable" + ) + + # Verify that max values aren't too large (check for overflow) + self.assertTrue(jnp.all(jnp.abs(max_vals) < 1e5), "Attention weights may be experiencing numerical overflow") + + +if __name__ == "__main__": + unittest.main()
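Note: the assertions in these tests compare paged attention against a reference_attention helper that is assumed to be defined earlier in this test module and is not shown in this hunk. Purely as an illustrative sketch (not the helper shipped in the patch), a dense-attention reference over [batch, seq, heads, head_dim] inputs that also handles grouped KV heads could look like the following; it applies the usual 1/sqrt(head_dim) scaling and omits causal masking, either of which may differ in the real helper.

import jax
import jax.numpy as jnp


def reference_attention_sketch(query, key, value):
  # Hypothetical stand-in for the test module's reference_attention helper.
  # query: [batch, q_len, num_query_heads, head_dim]
  # key/value: [batch, kv_len, num_kv_heads, head_dim]
  num_q_heads, num_kv_heads = query.shape[2], key.shape[2]
  if num_q_heads != num_kv_heads:
    # Repeat KV heads so each query head has a matching KV head (grouped-query case).
    key = jnp.repeat(key, num_q_heads // num_kv_heads, axis=2)
    value = jnp.repeat(value, num_q_heads // num_kv_heads, axis=2)
  scores = jnp.einsum("bqhd,bkhd->bhqk", query, key) / jnp.sqrt(query.shape[-1])
  weights = jax.nn.softmax(scores, axis=-1)  # no causal mask in this sketch
  return jnp.einsum("bhqk,bkhd->bqhd", weights, value)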