# RoPE
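Rotary Positional Embedding (RoPE; Su et al., 2021) encodes both the absolute position of each token and the relative distance between tokens by rotating pairs of elements in the last dimension. Each pair is rotated by an angle proportional to the token's absolute position. Because the dot product of two rotated vectors depends only on the difference of their angles, query-key scores depend on relative distance. Pairs are formed either by splitting the last dimension in half (the default) or by pairing adjacent elements (`interleaved=True`) before the rotation matrix is applied.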
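The following is a minimal sketch of the rotation in plain PyTorch, not a call into this op. The helpers `rotate`, `angles`, and `score` are hypothetical names used only for illustration. The sketch shows both pairings and checks the relative-distance property: two position pairs that are the same distance apart give the same query-key score.

```python
import torch


def rotate(x, cos, sin, interleaved=False):
    """Rotate feature pairs of x (..., embed); cos/sin have shape (..., embed // 2)."""
    if interleaved:
        # Pair adjacent elements: (x[0], x[1]), (x[2], x[3]), ...
        x1, x2 = x[..., 0::2], x[..., 1::2]
    else:
        # Split the last dimension in half: pair x[i] with x[i + embed // 2]
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
    # 2-D rotation of every pair by its own angle
    y1 = x1 * cos - x2 * sin
    y2 = x1 * sin + x2 * cos
    if interleaved:
        # Put each rotated pair back side by side
        return torch.stack((y1, y2), dim=-1).flatten(-2)
    return torch.cat((y1, y2), dim=-1)


def angles(position, embed, base=1e4):
    """Angles position * freqs[i], with freqs[i] = 1 / base ** (i / (embed / 2))."""
    half = embed // 2
    freqs = 1.0 / base ** (torch.arange(half, dtype=torch.float64) / half)
    return position * freqs


def score(q, k, m, n, interleaved=False):
    """Dot product of q rotated to position m and k rotated to position n."""
    tq, tk = angles(m, q.shape[-1]), angles(n, k.shape[-1])
    q_rot = rotate(q, tq.cos(), tq.sin(), interleaved)
    k_rot = rotate(k, tk.cos(), tk.sin(), interleaved)
    return torch.dot(q_rot, k_rot)


torch.manual_seed(0)
q = torch.randn(8, dtype=torch.float64)
k = torch.randn(8, dtype=torch.float64)
# Both position pairs are 3 apart, so the scores match
print(torch.allclose(score(q, k, 5, 2), score(q, k, 103, 100)))  # True
```

The rotation is orthogonal, so it never changes a vector's norm. Only the angle between the rotated query and key changes, and that angle depends on `m - n`.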
There are three ways to drive the rotation, in priority order:

1. Pass precomputed `cos` and `sin` directly.
2. Pass `position_ids` (and optionally `freqs`). The op constructs `cos`/`sin` internally.
3. Pass nothing extra. The op derives `position_ids` from `offset` and `scale`, and `freqs` from `base`.
Use `position_ids` when your model computes position indices externally, for example with custom sequence packing or variable-length inputs. Use `cos`/`sin` when you already have precomputed frequency tensors. Use `offset` for KV-cache decoding steps, where only the new token's position is needed and it continues from the positions already cached.
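The sketch below shows the two usual sources of position indices. Packed sequences get externally computed ids; `packed_position_ids` is a hypothetical helper. During decoding, the ids come from the documented default `(offset + arange(seq_len)) * scale`. For brevity, offsets are plain Python ints here, even though the op takes `offset` as a tensor. Because `cos`/`sin` depend only on the position ids, a decode step receives the same rotation as a full pass whenever its ids match the ones that pass would assign.

```python
import torch


def offset_position_ids(offset, seq_len, scale=1.0):
    # Documented default: position_ids = (offset + arange(seq_len)) * scale
    return (offset + torch.arange(seq_len)) * scale


def packed_position_ids(lengths):
    # Hypothetical helper: restart positions at 0 for every packed sequence
    return torch.cat([torch.arange(n) for n in lengths])


# Prefill: 6 prompt tokens starting at position 0
prefill = offset_position_ids(0, 6)
# Decode step: one new token after 6 cached positions
step = offset_position_ids(6, 1)
print(prefill.tolist(), step.tolist())  # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0] [6.0]

# Two sequences of lengths 3 and 2 packed into one row
print(packed_position_ids([3, 2]).tolist())  # [0, 1, 2, 0, 1]
```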
## Constructor

```python
RoPE(scale=1.0, base=1e4, dims=None, interleaved=False)
```
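| Parameter | Type | Default | Description |
|---|---|---|---|
| `scale` | `float` | `1.0` | Frequency scaling factor applied to positions. |
| `base` | `float` | `1e4` | Base for the geometric frequency sequence. |
| `dims` | `int` or `None` | `None` | Number of dimensions to rotate. If `None` (or at least `embed`), the full last dimension is rotated. |
| `interleaved` | `bool` | `False` | If `True`, pair adjacent elements; otherwise split the last dimension in half. |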
## Forward

```python
def forward(
    self,
    input: torch.Tensor,
    cos: torch.Tensor | None = None,
    sin: torch.Tensor | None = None,
    position_ids: torch.Tensor | None = None,
    freqs: torch.Tensor | None = None,
    offset: torch.Tensor | None = None,
) -> torch.Tensor
```
| Argument | Description |
|---|---|
| `input` | Tensor of rank ≥ 3, shape … |
| `cos`, `sin` | Precomputed cosines / sines, broadcastable to … |
| `position_ids` | Position indices, broadcastable to … |
| `freqs` | Custom angular frequencies of shape … |
| `offset` | Starting position for the sequence. Tensor of shape … |
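When `cos`/`sin` are not provided, position ids are constructed as `position_ids = (offset + arange(seq_len)) * scale`, and frequencies as `freqs[i] = 1 / base ** (i / (embed / 2))` for `i = 0, …, embed/2 − 1`. The rotation angle for position `p` and pair `i` is `p * freqs[i]`, and `cos`/`sin` are the cosines and sines of these angles.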
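Written out per token and per pair, the construction above amounts to the following. Here $t$ indexes tokens within the sequence, and $x_{1,i}, x_{2,i}$ are the two elements of pair $i$. The rotation uses the standard RoPE sign convention, which is an assumption, since the text does not spell out the matrix.

```math
\begin{aligned}
p_t &= (\mathrm{offset} + t)\,\mathrm{scale}, \qquad t = 0, \dots, \mathrm{seq\_len} - 1 \\
\omega_i &= \frac{1}{\mathrm{base}^{\,i/(\mathrm{embed}/2)}} = \mathrm{base}^{-2i/\mathrm{embed}}, \qquad i = 0, \dots, \tfrac{\mathrm{embed}}{2} - 1 \\
\theta_{t,i} &= p_t\,\omega_i \\
\begin{pmatrix} y_{1,i} \\ y_{2,i} \end{pmatrix} &= \begin{pmatrix} \cos\theta_{t,i} & -\sin\theta_{t,i} \\ \sin\theta_{t,i} & \cos\theta_{t,i} \end{pmatrix} \begin{pmatrix} x_{1,i} \\ x_{2,i} \end{pmatrix}
\end{aligned}
```

By default, $(x_{1,i}, x_{2,i}) = (x_i, x_{i+\mathrm{embed}/2})$. With `interleaved=True`, the pair is $(x_{2i}, x_{2i+1})$.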
### Optional input resolution order

1. If `cos` and `sin` are both provided, use them directly.
2. Otherwise, build `cos`/`sin` from `position_ids` and `freqs`:
   - `position_ids`: use the argument if provided; otherwise construct it from `offset` and `scale`.
   - `freqs`: use the argument if provided; otherwise construct it from `base`.
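The resolution order above can be sketched as a single function. `resolve_cos_sin` is a hypothetical name, not part of this API. The sketch assumes a single sequence of length `seq_len` with no batch dimension on the position ids. The text does not say what happens when only one of `cos`/`sin` is passed. This sketch falls through and builds both, which is an assumption.

```python
import torch


def resolve_cos_sin(seq_len, embed, cos=None, sin=None, position_ids=None,
                    freqs=None, offset=None, scale=1.0, base=1e4):
    """Hypothetical sketch of the documented resolution order.

    When built here, cos/sin have shape (seq_len, embed // 2).
    """
    # 1. Explicit cos/sin take priority over everything else
    if cos is not None and sin is not None:
        return cos, sin
    # 2a. position_ids: the argument, else (offset + arange(seq_len)) * scale
    if position_ids is None:
        start = 0 if offset is None else offset
        position_ids = (start + torch.arange(seq_len, dtype=torch.float32)) * scale
    # 2b. freqs: the argument, else 1 / base ** (i / (embed / 2))
    if freqs is None:
        half = embed // 2
        freqs = 1.0 / base ** (torch.arange(half, dtype=torch.float32) / half)
    # Angle for every (position, frequency) combination
    theta = position_ids.to(freqs.dtype)[..., None] * freqs
    return theta.cos(), theta.sin()


cos, sin = resolve_cos_sin(seq_len=4, embed=8, offset=3)
print(cos.shape)  # torch.Size([4, 4])
```

When `position_ids` is passed, `offset` and `scale` are ignored entirely. This matches the second step of the list.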
## ExternalizeSpec

```python
ExternalizeSpec(
    target_class=RoPE,
    composite_op_name="rope",
    composite_attrs=["scale", "base", "dims", "interleaved"],
)
```
## Partial rotation (`dims`)

When `dims` is set to a positive even integer smaller than `embed`, only the first `dims` features are rotated; the rest pass through unchanged:

```python
# `rope` denotes the full rotation, applied here to the sliced features only
y_partial = rope(input[..., :dims])
output = torch.cat([y_partial, input[..., dims:]], dim=-1)
```

When `dims` is `None` or `dims >= embed`, the full last dimension is rotated.
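A runnable sketch of partial rotation for a single token follows, using the default half-split pairing. `partial_rope` and `rotate_half_split` are hypothetical names. The text gives the frequency formula in terms of `embed`, so whether the op computes frequencies over `dims` or over `embed` is not stated. This sketch assumes `dims`.

```python
import torch


def rotate_half_split(x, theta):
    """Rotate pairs (x[i], x[i + d/2]) by angles theta of shape (..., d/2)."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    cos, sin = theta.cos(), theta.sin()
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)


def partial_rope(x, dims, position, base=1e4):
    """Hypothetical sketch: rotate only the first `dims` features of x."""
    embed = x.shape[-1]
    if dims is None or dims >= embed:
        dims = embed  # full rotation
    half = dims // 2
    # Assumption: frequencies are computed over the rotated width `dims`
    freqs = 1.0 / base ** (torch.arange(half, dtype=x.dtype) / half)
    y_partial = rotate_half_split(x[..., :dims], position * freqs)
    # Features beyond `dims` pass through unchanged
    return torch.cat([y_partial, x[..., dims:]], dim=-1)


x = torch.arange(6, dtype=torch.float64)  # embed = 6
y = partial_rope(x, dims=4, position=1.0)
print(y[4:].tolist())  # [4.0, 5.0]
```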
## Data types

| Tensor | Allowed dtypes |
|---|---|
|  |  |
|  | integer |
The dtypes of `input`, `cos`, and `sin` must promote to a common dtype, and the output has that promoted dtype.
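A small sketch of what "promotable" means in practice follows. It assumes the op follows PyTorch's standard type-promotion rules; the text above only requires the dtypes to be promotable. The tensors are illustrative placeholders.

```python
import torch

# Illustrative tensors: half-precision input, float32 cos/sin
x = torch.ones(2, 4, 8, dtype=torch.float16)
cos = torch.ones(4, 4, dtype=torch.float32)
sin = torch.zeros(4, 4, dtype=torch.float32)

# Common dtype under PyTorch's type promotion
out_dtype = torch.promote_types(torch.promote_types(x.dtype, cos.dtype), sin.dtype)
print(out_dtype)  # torch.float32

# Elementwise arithmetic on the mixed tensors lands on the same dtype
print((x[..., :4] * cos - x[..., 4:] * sin).dtype)  # torch.float32
```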