Make float32_qk_product and float32_logits apply during inference #1225

Merged
merged 3 commits on Feb 11, 2025
2 changes: 2 additions & 0 deletions MaxText/configs/base.yml
@@ -120,6 +120,8 @@ logits_via_embedding: False
normalize_embedding_logits: True # whether to normalize pre-softmax logits if logits_via_embedding is true
logits_dot_in_fp32: False # whether to use fp32 in logits_dense or shared_embedding dot product for stability
cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher precision is generally beneficial, but it can vary slightly.
+ float32_qk_product: False # in dot_product attention, whether to cast to fp32 the inputs to qk product
+ float32_logits: False # in dot_product attention, whether to cast to fp32 the inputs to softmax

# mixture of experts (moe)
num_experts: 1
4 changes: 2 additions & 2 deletions MaxText/layers/attentions.py
@@ -477,7 +477,7 @@ def apply_attention_dot(
"""Apply Attention."""
validate_compute_axis_order(self.compute_axis_order)
# Casting qk_product and softmax computation to float32 for model stability.
- if model_mode == common_types.MODEL_MODE_TRAIN and self.float32_qk_product:
+ if self.float32_qk_product:
Collaborator:
I think you have to set the precision as well for the float32 cast to actually take effect (https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision). We have this option in MaxText:

matmul_precision: "default"
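For reference, a minimal standalone sketch of the knob being pointed at here (plain JAX, not MaxText code): the multiply precision is a per-call argument, separate from the input dtype, and `matmul_precision: "default"` corresponds to `jax.lax.Precision.DEFAULT`.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((256, 256), jnp.float32)
y = jnp.ones((256, 256), jnp.float32)

# With DEFAULT precision the TPU MXU may still run the multiplies in bf16,
# even though both inputs are fp32.
out_default = jnp.matmul(x, y, precision=jax.lax.Precision.DEFAULT)

# HIGHEST forces full-precision multiplies (slower, numerically stricter).
out_highest = jnp.matmul(x, y, precision=jax.lax.Precision.HIGHEST)
```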

Contributor Author:
That's a good point which I hadn't considered. I suppose there are two moving parts here: are the flops performed in bf16 or fp32, and are the results accumulated in bf16 or fp32. I think these are generally controlled with the precision and preferred_element_type arguments, respectively.

It appears that at default precision the flops happen in bf16 even when one or more of the inputs are in fp32. However, the accumulation can still happen in fp32, and that seems to have been enough to solve our particular problem. In particular, the compiler seems to recognize that even though the Python says to upcast to fp32, it can elide that cast because it is going to do the multiplies in bf16 anyway; it still emits the output in fp32.

This is the qk product with float32_qk_product=False:
[screenshot of the compiled graph]

And this is with float32_qk_product=True (note the output type is now f32):
[screenshot of the compiled graph]

I'm not 100% confident in my interpretation of those graphs, but this would explain why it takes longer even without changing the precision parameter.

Separately, it looks like matmul_precision consistently gets routed into DenseGeneral usages, but not into the raw einsums used in qk_product and wv_product. When I change matmul_precision in the config it does not affect the runtime of those operations, but if I add it explicitly to the einsums then the wv_product does take longer, which makes sense. Is that something we should fix just by adding those arguments to the einsums?
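(As a standalone toy check of the distinction described above, with made-up shapes and an illustrative einsum pattern rather than MaxText's: upcasting the inputs changes the dtype the product is emitted and accumulated in, even while the default precision still permits bf16 multiplies, which is consistent with the f32 output in the second graph.)

```python
import jax.numpy as jnp

q = jnp.ones((1, 16, 128), jnp.bfloat16)  # toy (batch, seq, head_dim)
k = jnp.ones((1, 16, 128), jnp.bfloat16)

# float32_qk_product=False: bf16 in, bf16 out.
w_bf16 = jnp.einsum("bqd,bkd->bqk", q, k)

# float32_qk_product=True: the astype may be folded away for the multiplies
# under default precision, but the result comes out in f32.
w_f32 = jnp.einsum("bqd,bkd->bqk", q.astype(jnp.float32), k.astype(jnp.float32))

print(w_bf16.dtype, w_f32.dtype)  # bfloat16 float32
```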

if isinstance(key, KVTensor):
key = key.dequant()
query = query.astype(jnp.float32)
@@ -491,7 +491,7 @@ def apply_attention_dot(
attn_weights = attn_weights * self.attn_logits_soft_cap

# Casting softmax computation to float32 for model stability.
- if model_mode == common_types.MODEL_MODE_TRAIN and self.float32_logits:
+ if self.float32_logits:
attn_weights = attn_weights.astype(jnp.float32)
attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
if attn_mask is not None:
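A minimal sketch of what the guard change above means in practice (a toy standalone function, not MaxText's class; the einsum pattern and shapes are illustrative): previously the fp32 cast was additionally gated on model_mode == MODEL_MODE_TRAIN, so decode kept bf16 numerics even with the flag set; now the flag alone decides, and training and inference share the same attention numerics.

```python
import jax.numpy as jnp

def qk_product(query, key, float32_qk_product=True):
    # Before this change: `if model_mode == MODEL_MODE_TRAIN and float32_qk_product:`
    # After this change: the flag applies in every model mode.
    if float32_qk_product:
        query = query.astype(jnp.float32)
        key = key.astype(jnp.float32)
    return jnp.einsum("bqhd,bkhd->bhqk", query, key)

q = jnp.ones((1, 16, 8, 64), jnp.bfloat16)  # (batch, q_len, heads, head_dim)
k = jnp.ones((1, 16, 8, 64), jnp.bfloat16)  # (batch, kv_len, heads, head_dim)
print(qk_product(q, k).dtype)  # float32, now during decode as well as train
```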
4 changes: 2 additions & 2 deletions MaxText/layers/gemma.py
@@ -91,8 +91,8 @@ def __call__(
weight_dtype=cfg.weight_dtype,
dropout_rate=cfg.dropout_rate,
name="self_attention",
- float32_qk_product=True,
- float32_logits=True,
+ float32_qk_product=cfg.float32_qk_product,
+ float32_logits=cfg.float32_logits,
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(cfg),
use_ragged_attention=cfg.use_ragged_attention,
4 changes: 2 additions & 2 deletions MaxText/layers/gemma2.py
@@ -91,8 +91,8 @@ def __call__(
weight_dtype=cfg.weight_dtype,
dropout_rate=cfg.dropout_rate,
name="self_attention_local",
- float32_qk_product=True,
- float32_logits=True,
+ float32_qk_product=cfg.float32_qk_product,
+ float32_logits=cfg.float32_logits,
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(cfg),
attention_type=attentions.AttentionType.LOCAL_SLIDING,
2 changes: 2 additions & 0 deletions MaxText/layers/gpt3.py
@@ -312,6 +312,8 @@ def __call__(
mesh=mesh,
dropout_rate=cfg.dropout_rate,
name="self_attention",
+ float32_qk_product=cfg.float32_qk_product,
+ float32_logits=cfg.float32_logits,
fused_qkv=cfg.fused_qkv,
use_bias=True,
quant=self.quant,
2 changes: 2 additions & 0 deletions MaxText/layers/llama2.py
@@ -105,6 +105,8 @@ def __call__(
weight_dtype=cfg.weight_dtype,
dropout_rate=cfg.dropout_rate,
name="self_attention",
+ float32_qk_product=cfg.float32_qk_product,
+ float32_logits=cfg.float32_logits,
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(cfg),
prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]),
2 changes: 2 additions & 0 deletions MaxText/layers/mistral.py
@@ -97,6 +97,8 @@ def __call__(
weight_dtype=cfg.weight_dtype,
dropout_rate=cfg.dropout_rate,
name="self_attention",
+ float32_qk_product=cfg.float32_qk_product,
+ float32_logits=cfg.float32_logits,
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(cfg),
)
2 changes: 2 additions & 0 deletions MaxText/layers/models.py
@@ -92,6 +92,8 @@ def __call__(
weight_dtype=cfg.weight_dtype,
dropout_rate=cfg.dropout_rate,
name="self_attention",
+ float32_qk_product=cfg.float32_qk_product,
+ float32_logits=cfg.float32_logits,
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(cfg),
prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]),