instance_norm
Normalizes each (sample, channel) slice independently over its spatial dimensions. The mean and variance are computed separately for every sample and channel pair, rather than across the batch as in batch normalization.
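The following is a minimal reference sketch in plain tensor ops. It makes the per-slice statistics concrete: y = gamma * (x - mean) / sqrt(var + eps) + beta, where mean and (biased) variance are taken over the spatial dims of each (sample, channel) slice. The helper `instance_norm_reference` is hypothetical and is not the op's actual lowering. Its output is compared against PyTorch's `torch.nn.functional.instance_norm`.

```python
import torch
import torch.nn.functional as F


def instance_norm_reference(x, weight=None, bias=None, eps=1e-5):
    """Hypothetical helper: instance norm written with plain tensor ops."""
    # Spatial dims are everything after (N, C): 1, 2, or 3 of them.
    spatial = tuple(range(2, x.dim()))
    # One mean and one variance per (sample, channel) slice.
    mean = x.mean(dim=spatial, keepdim=True)
    var = x.var(dim=spatial, unbiased=False, keepdim=True)  # biased variance
    y = (x - mean) / torch.sqrt(var + eps)
    # Reshape (C,) parameters to (1, C, 1, ...) so they broadcast per channel.
    param_shape = (1, -1) + (1,) * len(spatial)
    if weight is not None:
        y = y * weight.reshape(param_shape)
    if bias is not None:
        y = y + bias.reshape(param_shape)
    return y


torch.manual_seed(0)
x = torch.randn(2, 6, 10, 10)
gamma = torch.randn(6)
beta = torch.randn(6)

ref = instance_norm_reference(x, gamma, beta)
lib = F.instance_norm(x, weight=gamma, bias=beta, eps=1e-5)
print(torch.allclose(ref, lib, atol=1e-4))

# Without the affine step, every (sample, channel) slice has mean close to 0.
plain = instance_norm_reference(x)
print(plain.mean(dim=(2, 3)).abs().max().item())
```

The statistics are reduced over dims 2 and up only, never over dim 0. Each of the N * C slices is therefore normalized on its own.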
ATen source: aten.instance_norm (preserved as composite by get_decomp_table())
Inputs

| Name | Shape | Description |
|---|---|---|
|  |  | 1, 2, or 3 spatial dims (e.g., …) |
|  |  | Per-channel scale applied after normalization |
|  |  | Per-channel shift applied after the scale |
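The sketch below illustrates the input layouts. The tensor to normalize carries 1, 2, or 3 spatial dims after (N, C), and the scale and shift are per-channel vectors of length C. It uses PyTorch's `torch.nn.functional.instance_norm` as a stand-in for the op. The specific sizes are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C = 4
weight = torch.randn(C)  # per-channel scale, shape (C,)
bias = torch.randn(C)    # per-channel shift, shape (C,)

# 1, 2, and 3 spatial dims: (N, C, L), (N, C, H, W), (N, C, D, H, W).
shapes = [(2, C, 16), (2, C, 8, 8), (2, C, 4, 4, 4)]
results = []
for shape in shapes:
    x = torch.randn(shape)
    y = F.instance_norm(x, weight=weight, bias=bias, eps=1e-5)
    results.append((tuple(x.shape), tuple(y.shape)))
    print(tuple(x.shape), "->", tuple(y.shape))
```

Each printed line shows the input shape followed by an identical output shape. The same (C,) scale and shift broadcast across every spatial rank.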
Attributes

| Name | Type | Description |
|---|---|---|
|  |  | Numerical-stability epsilon |
|  |  | Composite op version |
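Why the epsilon matters: a slice whose values are all equal has zero variance, so an unguarded normalization divides 0 by 0. The sketch below assumes a hypothetical `normalize` helper written with plain tensor ops (no affine step). It contrasts eps = 0 with eps = 1e-5, then checks that PyTorch's `torch.nn.functional.instance_norm` maps such a slice to its channel's shift.

```python
import torch
import torch.nn.functional as F


def normalize(x, eps):
    """Hypothetical helper: per-(sample, channel) normalization over (H, W)."""
    mean = x.mean(dim=(2, 3), keepdim=True)
    var = x.var(dim=(2, 3), unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)


# Every (sample, channel) slice is constant, so its variance is exactly 0.
x = torch.ones(1, 3, 4, 4)

no_eps = normalize(x, eps=0.0)     # 0 / sqrt(0) gives NaN everywhere
with_eps = normalize(x, eps=1e-5)  # 0 / sqrt(1e-5) gives 0 everywhere
print(torch.isnan(no_eps).all().item())
print(with_eps.abs().max().item())

# With affine params, a zero-variance slice collapses to its channel's shift
# (up to floating-point rounding).
gamma = torch.tensor([2.0, -1.0, 0.5])
beta = torch.tensor([0.1, 0.2, 0.3])
y = F.instance_norm(x, weight=gamma, bias=beta, eps=1e-5)
print(y[0, :, 0, 0])
```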
Output

| Name | Shape | Description |
|---|---|---|
|  |  | Same shape as the input |
Data types
fp16, fp32.
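PyTorch example

```python
import torch

N, C, H, W = 2, 6, 10, 10
x = torch.randn(N, C, H, W)
gamma = torch.randn(C)  # per-channel scale, shape (C,)
beta = torch.randn(C)   # per-channel shift, shape (C,)

output = torch.ops.aten.instance_norm.default(
    x, gamma, beta,
    None, None,  # running_mean / running_var: not needed when use_input_stats=True
    True,        # use_input_stats: normalize with per-instance statistics
    0.1,         # momentum: only updates running stats, so no effect here
    1e-5,        # eps
    True,        # cudnn_enabled: backend-selection hint, no effect on CPU
)
assert output.shape == x.shape
```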
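As a sanity check, the direct ATen call can be compared against PyTorch's module form, `torch.nn.InstanceNorm2d`. With `affine=True` the module owns (C,) weight and bias parameters, which this sketch overwrites with the same scale and shift. Because the module keeps no running stats by default, it also normalizes with per-instance statistics. The seed and sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
N, C, H, W = 2, 6, 10, 10
x = torch.randn(N, C, H, W)
gamma = torch.randn(C)
beta = torch.randn(C)

# Direct ATen call, same argument order as the example above.
out_aten = torch.ops.aten.instance_norm.default(
    x, gamma, beta, None, None, True, 0.1, 1e-5, False
)

# Module form; copy the scale/shift into its affine parameters.
m = torch.nn.InstanceNorm2d(C, eps=1e-5, affine=True)
with torch.no_grad():
    m.weight.copy_(gamma)
    m.bias.copy_(beta)
out_module = m(x)

print(torch.allclose(out_aten, out_module, atol=1e-5))
```

Both paths reduce over (H, W) for each (sample, channel) pair, so their outputs agree to floating-point tolerance.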