RMSNormImpl

Root Mean Square Layer Normalization (Zhang & Sennrich, 2019).

\[\text{RMSNorm}(x, \gamma) = \frac{x}{\sqrt{E[x^2] + \varepsilon}} \cdot \gamma\]
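As a worked instance of the formula above, here is a minimal pure-Python sketch that applies it to a single vector. The function name `rms_norm_reference` is hypothetical and exists only for illustration; it is not part of the library, and it treats E[x^2] as the mean of squares over the one vector given.

```python
import math

def rms_norm_reference(x, gamma, eps=1e-5):
    """Apply x / sqrt(mean(x^2) + eps) * gamma to one vector of floats."""
    mean_sq = sum(v * v for v in x) / len(x)   # E[x^2] over the vector
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)   # 1 / sqrt(E[x^2] + eps)
    return [v * inv_rms * g for v, g in zip(x, gamma)]

# For x = [3, 4]: E[x^2] = (9 + 16) / 2 = 12.5, so each element is
# divided by roughly sqrt(12.5) and then multiplied by its gamma entry.
out = rms_norm_reference([3.0, 4.0], [1.0, 1.0])
print(out)
```

Note that the denominator is the root mean square of the vector, not its standard deviation: no mean is subtracted before squaring.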

RMSNormImpl is the true composite op — the class the converter externalizes as rms_norm. It takes both the input x and the scale γ as explicit forward arguments so that, when externalized, the scale appears as a graph input on the composite op boundary rather than being baked in as a constant from a sibling parameter. Hold the scale as an nn.Parameter on your enclosing module and pass it through.
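The ownership pattern described above might look like the following sketch. `RMSNormImplSketch` is a hypothetical plain-PyTorch stand-in written so the example runs without the library; in real code you would import `RMSNormImpl` instead. The `Block` module, the attribute name `norm_scale`, and the ones initialization are likewise assumptions made for illustration.

```python
import torch
from torch import nn

class RMSNormImplSketch(nn.Module):
    """Hypothetical stand-in for RMSNormImpl: holds no parameters itself."""
    def __init__(self, eps: float = 1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # Mean of squares over the last dimension, kept for broadcasting.
        mean_sq = input.pow(2).mean(dim=-1, keepdim=True)
        return input * torch.rsqrt(mean_sq + self.eps) * scale

class Block(nn.Module):
    """Hypothetical enclosing module that owns the scale parameter."""
    def __init__(self, dim: int):
        super().__init__()
        # The enclosing module, not the norm op, registers the parameter.
        self.norm_scale = nn.Parameter(torch.ones(dim))  # ones init is an assumption
        self.norm = RMSNormImplSketch(eps=1e-5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The scale is passed through as an explicit forward argument.
        return self.norm(x, self.norm_scale)

block = Block(dim=8)
y = block(torch.randn(2, 4, 8))
```

Because the scale reaches the op as a forward argument, the only parameter in this sketch lives on `Block`, and the norm submodule itself has none.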

Constructor

RMSNormImpl(eps=1e-5)

Parameter  Type   Default  Description
---------  -----  -------  --------------------------------
eps        float  1e-5     Epsilon for numerical stability.

The reduction axis is fixed to the last dimension (axes = -1).

Forward

def forward(self, input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor

Normalizes input over its last dimension and multiplies by scale. The caller owns the scale tensor — typically an nn.Parameter on the enclosing module, with shape (dim,) for the standard case or (n_heads, 1, dim) for fused Q/K normalization.
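To illustrate how the two scale shapes line up with the input, here is a minimal sketch using a hypothetical function `rms_norm_sketch` that mirrors the formula in plain PyTorch. It is not the library's implementation, and the `(batch, n_heads, seq, dim)` input layout for the fused case is an assumption chosen so that a `(n_heads, 1, dim)` scale broadcasts per head.

```python
import torch

def rms_norm_sketch(input: torch.Tensor, scale: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Hypothetical reference: normalize over the last dim, then multiply by scale."""
    mean_sq = input.pow(2).mean(dim=-1, keepdim=True)
    return input * torch.rsqrt(mean_sq + eps) * scale

torch.manual_seed(0)
dim, n_heads, seq = 16, 4, 5

# Standard case: scale of shape (dim,) broadcasts over all leading dimensions.
x = torch.randn(2, seq, dim)
y = rms_norm_sketch(x, torch.ones(dim))

# Fused Q/K case: scale of shape (n_heads, 1, dim) gives each head its own
# vector, broadcast over the sequence axis (assumed layout: batch, heads, seq, dim).
q = torch.randn(2, n_heads, seq, dim)
head_scale = torch.ones(n_heads, 1, dim)
q_normed = rms_norm_sketch(q, head_scale)
```

In both cases the output keeps the input's shape; only the way the scale broadcasts differs.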

ExternalizeSpec

from coreai_torch.composite_ops import RMSNormImpl

ExternalizeSpec(
    target_class=RMSNormImpl,
    composite_op_name="rms_norm",
    composite_attrs=["axes", "eps"],
)

Data types

Tensor                Allowed types
--------------------  ----------------
input, scale, output  fp32, fp16, bf16

RMSNorm: convenience wrapper

from coreai_torch.composite_ops import RMSNorm

RMSNorm(dim, eps=1e-5, n_heads=None)

RMSNorm is a thin nn.Module wrapper around RMSNormImpl that owns the learnable scale parameter so callers don’t have to wire one up themselves. Its forward(x) applies the normalization with the internally-held weight.

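Parameter  Type        Default     Description
---------  ----------  ----------  -----------------------------------------------------------------------------------------------------
dim        int         (required)  Size of the last dimension. Determines the shape of the learnable scale.
eps        float       1e-5        Epsilon for numerical stability.
n_heads    int | None  None        If set, scale shape is (n_heads, 1, dim) for fused Q/K normalization. If None, scale shape is (dim,).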

The wrapper itself is not the externalization target — it composes RMSNormImpl internally, so target_class=RMSNormImpl still produces the rms_norm composite op.
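For intuition, a wrapper of this kind could be sketched roughly as follows. Both classes are hypothetical stand-ins written in plain PyTorch, not the library's source: the attribute names `weight` and `impl`, the ones initialization, and the `(batch, n_heads, seq, dim)` input layout are all assumptions.

```python
import torch
from torch import nn

class RMSNormImplSketch(nn.Module):
    """Hypothetical stand-in for RMSNormImpl (scale passed as an argument)."""
    def __init__(self, eps: float = 1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        mean_sq = input.pow(2).mean(dim=-1, keepdim=True)
        return input * torch.rsqrt(mean_sq + self.eps) * scale

class RMSNormSketch(nn.Module):
    """Hypothetical wrapper: owns the scale and delegates to the impl module."""
    def __init__(self, dim: int, eps: float = 1e-5, n_heads=None):
        super().__init__()
        # Scale shape follows the n_heads rule from the parameter table.
        shape = (dim,) if n_heads is None else (n_heads, 1, dim)
        self.weight = nn.Parameter(torch.ones(shape))  # name and init are assumptions
        self.impl = RMSNormImplSketch(eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The inner impl module is what an externalizer would match on.
        return self.impl(x, self.weight)

plain = RMSNormSketch(8)
per_head = RMSNormSketch(8, n_heads=2)
out = per_head(torch.randn(1, 2, 3, 8))
```

The point of the composition is that the wrapper adds only parameter ownership; the normalization itself still runs inside the inner impl module.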