batch_norm
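Inference-time batch normalization using running statistics:

$$
y = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta
$$

where $\mu$ and $\sigma^2$ are `running_mean` and `running_var`, $\gamma$ is `weight`, $\beta$ is `bias`, and $\epsilon$ is `eps`. All per-channel quantities broadcast along the channel dimension (dim 1) of `input`.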
The mean and variance are pre-computed running statistics passed in as inputs. `momentum`, which only controls how the running statistics are updated during training, is dropped during conversion.
ATen source: aten._native_batch_norm_legit_no_training
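Because the running statistics are constants at inference time, the normalization can be expanded algebraically into a single per-channel multiply-add. In the derivation below, $c$ is the channel index, $\mu_c$ and $\sigma_c^2$ are the running mean and variance, $\gamma_c$ and $\beta_c$ are the weight and bias, and $s_c$, $t_c$ are names introduced here for illustration only; this is an identity, not a description of how the conversion is implemented:

$$
y_{n,c,\ldots}
= \gamma_c \, \frac{x_{n,c,\ldots} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c
= s_c \, x_{n,c,\ldots} + t_c,
\qquad
s_c = \frac{\gamma_c}{\sqrt{\sigma_c^2 + \epsilon}},
\quad
t_c = \beta_c - s_c \, \mu_c
$$

Expanding $s_c x + t_c = \gamma_c x / \sqrt{\sigma_c^2 + \epsilon} + \beta_c - \gamma_c \mu_c / \sqrt{\sigma_c^2 + \epsilon}$ recovers the left-hand side term by term.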
Inputs

| Name | Shape | Description |
|---|---|---|
| `input` | `(N, C, *)` | Input tensor. Supported ranks: 2, 3, 4, 5 |
| `weight` | `(C,)` | Per-channel scale, applied after normalization |
| `bias` | `(C,)` | Per-channel shift, added after the scale |
| `running_mean` | `(C,)` | Per-channel running mean |
| `running_var` | `(C,)` | Per-channel running variance |
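The four per-channel tensors hold one value per channel, so for a rank-2 to rank-5 input they have to be broadcast along dimension 1. The following is a minimal sketch in plain PyTorch, assuming the channel dimension is dim 1 (as in the PyTorch example on this page). `batch_norm_inference` is a hypothetical helper, not part of any library; `torch.nn.functional.batch_norm` with `training=False` is used only as a reference to compare against.

```python
import torch
import torch.nn.functional as F


def batch_norm_inference(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    running_mean: torch.Tensor,
    running_var: torch.Tensor,
    eps: float = 1e-5,
) -> torch.Tensor:
    """Hypothetical reference: inference-time batch norm with fixed statistics.

    x has shape (N, C, *) and rank 2 to 5; the other tensors have shape (C,).
    """
    if not 2 <= x.dim() <= 5:
        raise ValueError(f"unsupported rank {x.dim()}; expected 2 to 5")
    # (C,) -> (1, C, 1, ..., 1) so each per-channel tensor broadcasts along dim 1.
    shape = [1, x.shape[1]] + [1] * (x.dim() - 2)
    mean = running_mean.reshape(shape)
    var = running_var.reshape(shape)
    # Normalize with the running statistics, then scale and shift.
    return (x - mean) / torch.sqrt(var + eps) * weight.reshape(shape) + bias.reshape(shape)


if __name__ == "__main__":
    torch.manual_seed(0)
    C = 5
    w, b = torch.randn(C), torch.randn(C)
    m, v = torch.randn(C), torch.rand(C) + 0.5  # keep the variance positive
    for shape in [(4, C), (4, C, 7), (4, C, 6, 6), (2, C, 3, 4, 4)]:
        x = torch.randn(shape)
        ours = batch_norm_inference(x, w, b, m, v)
        ref = F.batch_norm(x, m, v, w, b, training=False, eps=1e-5)
        print(shape, torch.allclose(ours, ref, rtol=1e-5, atol=1e-5))
```

The reshape-then-broadcast approach keeps the helper rank-agnostic: the same expression covers every supported rank, and the output always has the same shape as the input.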
Attributes

| Name | Type | Description |
|---|---|---|
| `eps` | float | Numerical-stability epsilon added to the variance |
| | | Composite op version |
Output

| Name | Shape | Description |
|---|---|---|
| `output` | `(N, C, *)` | Same shape as `input` |
Data types
fp16, fp32, bf16 for all tensor inputs and the output.
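PyTorch example

```python
import torch

N, C, H, W = 20, 5, 10, 10
input = torch.randn(N, C, H, W)
running_mean = torch.zeros(C)  # per-channel running mean
running_var = torch.ones(C)    # per-channel running variance

# The ATen op returns a tuple of three tensors; only the first one
# (the normalized tensor) is used here.
output, _, _ = torch.ops.aten._native_batch_norm_legit_no_training(
    input,
    weight=torch.ones(C),   # per-channel scale
    bias=torch.zeros(C),    # per-channel shift
    running_mean=running_mean,
    running_var=running_var,
    momentum=0.1,           # required by the op schema; unused without training
    eps=1e-5,
)
```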