paged attention
patemotter committed Jan 31, 2025
1 parent 449cbcb commit c0e8573
Showing 12 changed files with 1,759 additions and 48 deletions.
4 changes: 2 additions & 2 deletions .gitignore
@@ -107,9 +107,9 @@ celerybeat.pid

# Environments
.env
-.venv
+.venv*
env/
-venv/
+venv*/
ENV/
env.bak/
venv.bak/
102 changes: 102 additions & 0 deletions MaxText/benchmarks/bench_paged_attention_kernel.py
@@ -0,0 +1,102 @@
import argparse
import functools
import time

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention

BLOCK_SIZE = 16
MAX_NUM_BLOCKS_PER_SEQ = 512


@functools.partial(jax.jit, static_argnums=(6, 7))
def paged_attn(
    q: jax.Array,  # [batch, 1, num_heads, head_size]
    k_cache: jax.Array,  # [num_kv_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_kv_heads, num_blocks * block_size, head_size]
    sm_scale: float,
    block_tables: jax.Array,  # [batch, max_num_blocks_per_batch]
    context_lens: jax.Array,  # [batch]
    block_size: int,
    pages_per_compute_block: int,
) -> jax.Array:  # [batch, 1, num_heads, head_size]
  q = q.squeeze(1)
  q = q * sm_scale

  head_size = q.shape[-1]
  num_slots = k_cache.shape[-2]
  # Split the flat slot dimension into pages:
  # [num_kv_heads, num_blocks, block_size, head_size].
  k_cache = k_cache.reshape(-1, num_slots // block_size, block_size, head_size)
  v_cache = v_cache.reshape(-1, num_slots // block_size, block_size, head_size)

  output = paged_attention(
      q,
      k_cache,
      v_cache,
      context_lens,
      block_tables,
      pages_per_compute_block=pages_per_compute_block,
  )
  return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])


def benchmark_paged_attn(
    batch_size: int,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    context_len: int,
    num_blocks: int,
    block_size: int,
    pages_per_compute_block: int,
):
  rng_key = jax.random.PRNGKey(0)
  query = jax.random.normal(rng_key, (batch_size, 1, num_heads, head_size), dtype=jnp.bfloat16)
  k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
  v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
  sm_scale = head_size**-0.5
  block_tables = jax.random.randint(rng_key, (batch_size, MAX_NUM_BLOCKS_PER_SEQ), 0, num_blocks, dtype=jnp.int32)
  context_lens = jnp.array([context_len] * batch_size, dtype=jnp.int32)

  # Warm-up call to trigger JIT compilation.
  output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, block_size, pages_per_compute_block)
  output.block_until_ready()

  # Time 100 async dispatches and sync once at the end.
  start = time.time()
  for _ in range(100):
    output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, block_size, pages_per_compute_block)
  output.block_until_ready()
  end = time.time()

  # Average over the 100 timed iterations, reported in microseconds.
  print(f"Time taken per iteration: {(end - start) / 100 * 1e6:.2f} us")


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("--num-heads", type=int, default=64)
  parser.add_argument("--num-kv-heads", type=int, default=8)
  parser.add_argument("--head-size", type=int, default=128)
  parser.add_argument("--batch-size", type=int, default=8)
  parser.add_argument("--context-len", type=int, default=1024)
  parser.add_argument("--num-blocks", type=int, default=2048)
  parser.add_argument("--block-size", type=int, default=16)
  args = parser.parse_args()
  print(args)

  # Sweep block sizes and pages per compute block; note that the sweep
  # supersedes the --block-size argument.
  for block_size in [16, 32, 64, 128]:
    for pages_per_compute_block in [1, 2, 4, 8, 16, 32, 64, 128]:
      if pages_per_compute_block > MAX_NUM_BLOCKS_PER_SEQ:
        continue
      if block_size * pages_per_compute_block > 1024:
        continue
      print(f"block_size: {block_size}, pages_per_compute_block: {pages_per_compute_block}")
      benchmark_paged_attn(
          args.batch_size,
          args.num_heads,
          args.num_kv_heads,
          args.head_size,
          args.context_len,
          args.num_blocks,
          block_size,
          pages_per_compute_block,
      )
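
For reference, a minimal sketch of calling the benchmark for a single configuration from Python instead of the CLI sweep; the import path is assumed from the file location, and the argument values mirror the parser defaults above:

from MaxText.benchmarks.bench_paged_attention_kernel import benchmark_paged_attn

# One configuration: 8 sequences of 1024 tokens, 64 query heads, 8 KV heads.
benchmark_paged_attn(
    batch_size=8,
    num_heads=64,
    num_kv_heads=8,
    head_size=128,
    context_len=1024,
    num_blocks=2048,
    block_size=16,
    pages_per_compute_block=8,
)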
9 changes: 9 additions & 0 deletions MaxText/configs/base.yml
@@ -268,6 +268,10 @@ logical_axis_rules: [
['cache_kv', []],
['cache_sequence', []],
['exp', 'expert'],
+['paged_kv_heads', []],
+['num_pages', ['tensor']],
+['tokens_per_page', []],
+['paged_kv_head_dim_size', []],
]
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_sequence', 'expert', 'autoregressive']]
@@ -542,3 +546,8 @@ sa_use_fused_bwd_kernel: False
sa_q_layout: "HEAD_DIM_MINOR"
sa_k_layout: "HEAD_DIM_MINOR"
sa_v_layout: "HEAD_DIM_MINOR"
+
+# Paged Attention
+num_pages: 64
+tokens_per_page: 32
+pages_per_compute_block: 8
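
As a rough sanity check on these defaults, assuming for illustration that cache capacity is simply num_pages × tokens_per_page, the paged KV cache addresses 64 × 32 = 2048 token positions, and a sequence of length L occupies ceil(L / tokens_per_page) pages:

import math

num_pages = 64        # from base.yml above
tokens_per_page = 32  # from base.yml above

# Token positions addressable through the page table (assumed relationship).
print(num_pages * tokens_per_page)  # 2048

# Pages needed for a 1024-token context.
print(math.ceil(1024 / tokens_per_page))  # 32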
4 changes: 3 additions & 1 deletion MaxText/decode.py
@@ -49,8 +49,10 @@ def main(argv: Sequence[str]) -> None:

# Split RNG before calling prefill
rng, rng_prefill = jax.random.split(rng)
-prefill_result, first_token = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill)
+slot = 0
+prefill_result, first_token = engine.prefill(
+    params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill, slot=slot
+)

rng, rng_init_decode = jax.random.split(rng)
decode_state = engine.init_decode_state(rng_init_decode)
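
A hypothetical sketch of why prefill now takes a slot (the names and shapes below are illustrative, not MaxText's actual API): with a paged KV cache, each concurrent decode slot owns one row of the page table, so prefill must know which row its KV pages belong to:

import jax.numpy as jnp

# Illustrative bookkeeping only: one page-table row per decode slot.
max_slots, max_pages_per_slot = 8, 512
page_table = jnp.zeros((max_slots, max_pages_per_slot), dtype=jnp.int32)

slot = 0
pages_for_slot = page_table[slot]  # the row a prefill at this slot would populate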
(Diffs for the remaining 8 changed files are not shown here.)
