SDPA¶
Scaled Dot-Product Attention with optional causal masking, sliding window, and attention sinks.
Use this class instead of torch.nn.functional.scaled_dot_product_attention when you need the full attention operation preserved as a single composite op in the lowered IR. The is_causal, window_size, and sinks options compose: you can enable causal masking, restrict the attended context with a sliding window, and designate a fixed number of global sink tokens, all in a single externalized op.
Constructor¶
SDPA(scale=None, is_causal=False, window_size=0)
| Parameter | Type | Default | Description |
|---|---|---|---|
| scale | float \| None | None | Attention scale factor. If |
| is_causal | bool | False | Apply lower-right causal mask. |
| window_size | int | 0 | Sliding window size. |
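To make the way is_causal and window_size combine concrete, here is a minimal pure-Python sketch of the boolean mask they could describe. The helper name allowed_mask is hypothetical and not part of this API. The sketch assumes that window_size=0 disables the window, that the window counts the current position plus the preceding window_size - 1 keys, and that "lower-right" means the last query row is aligned with the last key column. The actual op may differ on any of these points.

```python
def allowed_mask(t_q, t_kv, is_causal=False, window_size=0):
    """Return a [t_q][t_kv] grid of booleans; True means query i may attend key j.

    Illustrative only: window semantics and the meaning of window_size=0
    are assumptions, not taken from the SDPA implementation.
    """
    # Lower-right alignment: the last query position lines up with the last key.
    offset = t_kv - t_q
    mask = []
    for i in range(t_q):
        row = []
        for j in range(t_kv):
            ok = True
            # Causal: a query may not look at keys after its aligned position.
            if is_causal and j > i + offset:
                ok = False
            # Sliding window (assumed): keep only the most recent window_size keys.
            if window_size > 0 and j <= i + offset - window_size:
                ok = False
            row.append(ok)
        mask.append(row)
    return mask


if __name__ == "__main__":
    for row in allowed_mask(2, 4, is_causal=True, window_size=2):
        print("".join("x" if ok else "." for ok in row))
```

With T_q = 2 and T_kv = 4, the lower-right alignment lets the first query row see the first three keys under causal masking alone; adding a window of 2 trims each row to its two most recent keys.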
Forward¶
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    sinks: torch.Tensor | None = None,
) -> torch.Tensor
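For orientation, the core computation behind scaled dot-product attention can be sketched for a single head in plain Python. This is a reference for the textbook formula softmax(scale * Q K^T) V only: the function name sdpa_reference is hypothetical, attn_mask and sinks are deliberately left out, and the 1/sqrt(D) fallback for scale=None is the conventional default, assumed here rather than confirmed for this class.

```python
import math


def sdpa_reference(query, key, value, scale=None):
    """Single-head attention on nested lists.

    query: [T_q][D], key: [T_kv][D], value: [T_kv][D_v].
    Returns a [T_q][D_v] nested list.
    """
    d = len(query[0])
    if scale is None:
        scale = 1.0 / math.sqrt(d)  # assumption: conventional default scale
    d_v = len(value[0])
    out = []
    for q in query:
        # Scaled dot product of this query with every key.
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in key]
        # Numerically stable softmax: subtract the max before exponentiating.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        # Weighted average of the value rows.
        out.append([
            sum(w * v[c] for w, v in zip(weights, value)) / total
            for c in range(d_v)
        ])
    return out
```

When every score is equal, the output is the plain average of the value rows, which is a quick sanity check for any implementation.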
Input names variants¶
| Arguments provided | |
|---|---|
| | |
| | |
| | |
| | |
ExternalizeSpec¶
ExternalizeSpec(
    target_class=SDPA,
    composite_op_name="scaled_dot_product_attention",
    composite_attrs=["scale", "is_causal", "window_size"],
)
Supported attention schemas¶
SDPA covers Multi-Head, Grouped-Query, and Multi-Query Attention based on the relationship between N_q (query heads) and N_kv (key/value heads):
| Schema | Constraint | Example |
|---|---|---|
| Multi-Head Attention (MHA) | N_q = N_kv | 32 query heads, 32 kv heads |
| Grouped-Query Attention (GQA) | 1 < N_kv < N_q, with N_q divisible by N_kv | 32 query heads, 8 kv heads |
| Multi-Query Attention (MQA) | N_kv = 1 | 32 query heads, 1 kv head |
Tensor shapes: query is [B, N_q, T_q, D], key is [B, N_kv, T_kv, D], value is [B, N_kv, T_kv, D_v]. For GQA / MQA, do not pre-tile key / value to match N_q. Pass them with their native N_kv; the broadcasting is recorded as part of the composite op.
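The relationship between N_q and N_kv can be illustrated with a small pure-Python sketch. Both helper names are hypothetical. The head mapping shown (consecutive query heads share one key/value head) is the common GQA convention and is an assumption about how the broadcast is resolved downstream, not a statement about this op's lowering.

```python
def classify_schema(n_q, n_kv):
    """Name the attention schema implied by the head counts."""
    if n_q <= 0 or n_kv <= 0 or n_q % n_kv != 0:
        raise ValueError("N_q must be a positive multiple of N_kv")
    if n_q == n_kv:
        return "MHA"
    if n_kv == 1:
        return "MQA"
    return "GQA"


def kv_head_for_query_head(q_head, n_q, n_kv):
    """Index of the key/value head a query head reads from.

    Assumes consecutive query heads are grouped onto one kv head.
    """
    if n_q <= 0 or n_kv <= 0 or n_q % n_kv != 0:
        raise ValueError("N_q must be a positive multiple of N_kv")
    if not 0 <= q_head < n_q:
        raise IndexError("query head index out of range")
    group_size = n_q // n_kv  # query heads per kv head
    return q_head // group_size
```

For 32 query heads and 8 kv heads, each group holds 4 query heads, so query heads 0 to 3 read kv head 0 and query head 31 reads kv head 7.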