GatedDeltaUpdate
Gated Delta Network recurrence: a linear-complexity alternative to softmax attention for sequence modeling. It is used in modern efficient attention mechanisms such as Delta Networks (Qwen3-Next) and other recurrent-style transformers. Use this op when your model implements a state-space or linear-recurrence layer that you want preserved as a single composite op in the lowered IR.

The state tensor is a key-value memory matrix of shape `[B, N_v_heads, D_k, D_v]` that accumulates over timesteps; `initial_state` lets you pass a cached state for autoregressive generation or chunked processing.
Dimension variables

| Symbol | Meaning |
|---|---|
| `B` | Batch size |
| `S` | Sequence length |
| `N_kq_heads` | Number of attention heads for query and key (note: Q and K have the same head count, unlike SDPA) |
| `N_v_heads` | Number of attention heads for value |
| `D_k` | Per-head dim for query / key |
| `D_v` | Per-head dim for value / output |
Constructor
GatedDeltaUpdate(use_qk_l2_norm=True)

| Parameter | Type | Default | Description |
|---|---|---|---|
| `use_qk_l2_norm` | `bool` | `True` | Whether to apply L2 normalization to query and key before the delta update. |
Forward
def forward(
    self,
    query: torch.Tensor,          # [B, N_kq_heads, S, D_k]
    key: torch.Tensor,            # [B, N_kq_heads, S, D_k]
    value: torch.Tensor,          # [B, N_v_heads, S, D_v]
    g: torch.Tensor,              # [B, N_v_heads, S]
    beta: torch.Tensor,           # [B, N_v_heads, S]
    initial_state: torch.Tensor,  # [B, N_v_heads, D_k, D_v]
) -> tuple[torch.Tensor, torch.Tensor]

| Argument | Description |
|---|---|
| `query` | Queries. |
| `key` | Keys. Same head count as `query`. |
| `value` | Values. May have a different head count than `query` / `key`. |
| `g` | Gate / decay factors. Should be negative (the op applies `exp` internally). |
| `beta` | Update strength factors, typically in `[0, 1]`. |
| `initial_state` | Recurrent state from prior sequence; pass zeros for a fresh sequence. |
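
The shape relationships in the signature and argument table can be checked before calling the op. The helper below is a hypothetical, dependency-free sketch: the name `check_gated_delta_shapes` is not part of the library, and it works on plain shape tuples such as `tuple(t.shape)` rather than on tensors.

```python
def check_gated_delta_shapes(query, key, value, g, beta, initial_state):
    """Validate shape tuples against the documented forward() layout.

    Hypothetical helper for illustration only. Each argument is a shape
    tuple. Returns the resolved dimensions; raises ValueError on mismatch.
    """
    if len(query) != 4 or len(value) != 4:
        raise ValueError("query and value must be rank-4 shapes")
    b, n_kq, s, d_k = query
    _, n_v, _, d_v = value

    # Expected shapes, derived from query (B, N_kq_heads, S, D_k)
    # and value (N_v_heads, D_v).
    expected = {
        "key": (b, n_kq, s, d_k),
        "value": (b, n_v, s, d_v),
        "g": (b, n_v, s),
        "beta": (b, n_v, s),
        "initial_state": (b, n_v, d_k, d_v),
    }
    actual = {
        "key": key,
        "value": value,
        "g": g,
        "beta": beta,
        "initial_state": initial_state,
    }
    for name, want in expected.items():
        got = tuple(actual[name])
        if got != want:
            raise ValueError(f"{name}: expected {want}, got {got}")
    return {
        "B": b, "S": s,
        "N_kq_heads": n_kq, "N_v_heads": n_v,
        "D_k": d_k, "D_v": d_v,
    }
```

For example, `check_gated_delta_shapes((2, 4, 16, 64), (2, 4, 16, 64), (2, 8, 16, 128), (2, 8, 16), (2, 8, 16), (2, 8, 64, 128))` resolves `N_kq_heads=4` and `N_v_heads=8`, while a `g` shaped with the Q/K head count is rejected.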
Returns (output, final_state):
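
| Return value | Shape | Description |
|---|---|---|
| `output` | `[B, N_v_heads, S, D_v]` | Per-timestep retrieval outputs. Same dtype as the (promoted) input dtype. |
| `final_state` | `[B, N_v_heads, D_k, D_v]` | State matrix after processing all `S` timesteps. |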
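
This page does not spell out the recurrence itself, so the NumPy sketch below shows one common formulation of the gated delta rule, purely to illustrate what the two return values mean. Several things here are assumptions rather than documented behavior of the op: the function name `gated_delta_reference`, the `eps` value, the head-grouping rule (each Q/K head shared by `N_v_heads // N_kq_heads` consecutive value heads), the order of decay and update, and the absence of any query scaling.

```python
import numpy as np

def _l2_normalize(x, eps=1e-6):
    # Assumed normalization: unit L2 norm along the last axis.
    return x / np.sqrt(np.sum(x * x, axis=-1, keepdims=True) + eps)

def gated_delta_reference(query, key, value, g, beta, initial_state,
                          use_qk_l2_norm=True):
    """Illustrative per-timestep gated delta recurrence (not the op's kernel)."""
    if use_qk_l2_norm:
        query, key = _l2_normalize(query), _l2_normalize(key)
    n_kq, n_v = query.shape[1], value.shape[1]
    if n_v % n_kq != 0:
        raise ValueError("sketch assumes N_v_heads is a multiple of N_kq_heads")
    # Assumption: expand Q/K heads so every value head has a matching Q/K head.
    query = np.repeat(query, n_v // n_kq, axis=1)
    key = np.repeat(key, n_v // n_kq, axis=1)

    batch, heads, seq_len, d_v = value.shape
    state = np.array(initial_state, dtype=np.float64)  # [B, N_v, D_k, D_v]
    output = np.zeros((batch, heads, seq_len, d_v))
    for t in range(seq_len):
        q_t, k_t, v_t = query[:, :, t], key[:, :, t], value[:, :, t]
        # Decay the whole memory by exp(g_t), which lies in (0, 1] when g_t <= 0.
        state = state * np.exp(g[:, :, t])[..., None, None]
        # What the memory currently returns for this key: [B, N_v, D_v].
        recalled = np.einsum("bhk,bhkv->bhv", k_t, state)
        # Delta rule: write only the beta-scaled correction for this key.
        delta = beta[:, :, t][..., None] * (v_t - recalled)
        state = state + np.einsum("bhk,bhv->bhkv", k_t, delta)
        # Read the updated memory with the query.
        output[:, :, t] = np.einsum("bhk,bhkv->bhv", q_t, state)
    return output, state

# Chunked processing: carry final_state of chunk 1 into chunk 2.
rng = np.random.default_rng(0)
q = rng.normal(size=(1, 2, 6, 3))
k = rng.normal(size=(1, 2, 6, 3))
v = rng.normal(size=(1, 4, 6, 5))
g = -np.abs(rng.normal(size=(1, 4, 6)))                 # non-positive gates
beta = 1.0 / (1.0 + np.exp(-rng.normal(size=(1, 4, 6))))  # sigmoid output
s0 = np.zeros((1, 4, 3, 5))

full_out, full_state = gated_delta_reference(q, k, v, g, beta, s0)
out_a, s_a = gated_delta_reference(q[:, :, :3], k[:, :, :3], v[:, :, :3],
                                   g[:, :, :3], beta[:, :, :3], s0)
out_b, s_b = gated_delta_reference(q[:, :, 3:], k[:, :, 3:], v[:, :, 3:],
                                   g[:, :, 3:], beta[:, :, 3:], s_a)
assert np.allclose(np.concatenate([out_a, out_b], axis=2), full_out)
assert np.allclose(s_b, full_state)
```

Because the state is the only quantity carried across timesteps, feeding the first chunk's `final_state` in as the second chunk's `initial_state` reproduces the single-pass result in this sketch. That is the same pattern used for cached autoregressive decoding.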
ExternalizeSpec
ExternalizeSpec(
    target_class=GatedDeltaUpdate,
    composite_op_name="gated_delta_update",
    composite_attrs=["use_qk_l2_norm"],
)
Data types
fp32, fp16, bf16 for all tensor inputs and outputs.
Constraints
All input tensors must have compatible (promotable) dtypes; the output dtype matches the promoted input dtype.
`g` should be negative: the op applies `exp` internally, and `exp(g)` must lie in `[0, 1]` for the decay to behave correctly.
`beta` is typically in `[0, 1]` (often the output of a sigmoid).
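
One way to satisfy both range constraints is to derive `g` and `beta` from unconstrained projections. The NumPy sketch below uses a negated softplus for `g` and a sigmoid for `beta`. This parameterization and the name `make_gate_inputs` are assumptions for illustration; a given model may compute its gates differently.

```python
import numpy as np

def make_gate_inputs(raw_g, raw_beta):
    """Map unconstrained arrays to valid (g, beta) inputs.

    Hypothetical helper: g = -softplus(raw_g) is always <= 0, so exp(g)
    stays in (0, 1]; beta = sigmoid(raw_beta) stays in [0, 1].
    """
    # logaddexp(0, x) == log(1 + exp(x)), a numerically stable softplus.
    g = -np.logaddexp(0.0, raw_g)
    # sigmoid(x) == exp(-softplus(-x)), which avoids overflow for large |x|.
    beta = np.exp(-np.logaddexp(0.0, -raw_beta))
    return g, beta

raw = np.array([-50.0, -1.0, 0.0, 1.0, 50.0])
g, beta = make_gate_inputs(raw, raw)
decay = np.exp(g)  # the factor the op is documented to apply internally
assert np.all(g <= 0.0) and np.all((decay > 0.0) & (decay <= 1.0))
assert np.all((beta >= 0.0) & (beta <= 1.0))
```

Passing the decay in log space (as `g`) rather than as `exp(g)` keeps very small decay factors representable in low-precision dtypes.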
L2 normalization flag
When `use_qk_l2_norm=True` (the default), the op applies L2 normalization to `query` and `key` before the recurrence. Set it to `False` if your model already L2-normalizes Q/K externally.
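
As a rough illustration of what normalizing Q/K means here, the NumPy sketch below scales each per-head vector to unit L2 norm along the last (`D_k`) axis. The function name `l2_normalize`, the `eps` value, and the choice of axis are assumptions; the op's exact numerics are not documented on this page.

```python
import numpy as np

def l2_normalize(x, eps=1e-6):
    """Scale vectors along the last axis to approximately unit L2 norm."""
    return x / np.sqrt(np.sum(x * x, axis=-1, keepdims=True) + eps)

rng = np.random.default_rng(0)
query = rng.normal(size=(2, 4, 16, 8))  # [B, N_kq_heads, S, D_k]

q_once = l2_normalize(query)
q_twice = l2_normalize(q_once)

# Every per-head vector now has norm close to 1.
assert np.allclose(np.linalg.norm(q_once, axis=-1), 1.0, atol=1e-4)
# Normalizing again changes almost nothing.
assert np.allclose(q_once, q_twice, atol=1e-5)
```

Since a second normalization pass leaves already-normalized vectors essentially unchanged, a model that normalizes Q/K itself gains nothing from the internal pass, which is the case the `False` setting covers.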