Commit
Add Note to specify the correct ordering for weights and activations to be compatible with hardware capabilities.

PiperOrigin-RevId: 604416628
shivaniag authored and pax authors committed Feb 5, 2024
1 parent f7506e3 commit 51e5cf9
Showing 2 changed files with 16 additions and 4 deletions.
9 changes: 7 additions & 2 deletions praxis/layers/quantization/sparsity/sparsity.py
@@ -65,7 +65,10 @@ def get_sparsity_mask(
order: Apply pruning using this index order. Supported values are `C`, `R`.
  `C` and `R` indicate column-wise and row-wise masking, respectively.
  Default is `C`, indicating that N:M sparsity is applied across columns of
  the input matrix. The choice may interact with hardware capabilities: for
  a weight tensor `C` corresponds to the reduction dimension, and `R` does
  for activations.
Returns:
A mask that indicates the pruning locations (`0`: no pruning, `1`: pruned).
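As a concrete illustration of the masking convention described above, here is a minimal NumPy sketch (not the Praxis implementation; the function name is hypothetical) that builds an N:M mask along rows (`R`) or columns (`C`), marking pruned locations with `1` as the docstring specifies:

```python
import numpy as np

def get_sparsity_mask_sketch(x, n=2, m=4, order="C"):
    """Toy N:M mask: `1` marks pruned entries, `0` marks kept entries."""
    # Group m consecutive elements along each row ('R') or each column ('C').
    groups = x.reshape(-1, m) if order == "R" else x.T.reshape(-1, m)
    # Prune the (m - n) smallest-magnitude entries in every group of m.
    pruned_idx = np.argsort(np.abs(groups), axis=-1)[:, : m - n]
    mask = np.zeros_like(groups)
    np.put_along_axis(mask, pruned_idx, 1.0, axis=-1)
    return mask.reshape(x.shape) if order == "R" else mask.reshape(x.T.shape).T
```

With `order="R"` the groups of `m` run along each row; with `order="C"` they run down each column, which for a weight matrix in `y = x @ w` is the reduction dimension.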
@@ -201,7 +204,9 @@ def prune_inputs_n_m(
order: Apply pruning using this index order. Supported values are `C`, `R`.
  `C` and `R` indicate column-wise and row-wise masking, respectively.
  Default is `R`, indicating that N:M sparsity is applied across rows of the
  input matrix. The choice may interact with hardware capabilities: for a
  weight tensor `C` corresponds to the reduction dimension, and `R` does for
  activations.
Returns:
An array with the same shape as inputs pruned with N:M strategy.
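The end-to-end effect of N:M pruning on the inputs can be sketched in a few lines of NumPy (again illustrative, with a hypothetical name, not the Praxis code): keep the `n` largest-magnitude entries in each group of `m` and zero out the rest, grouping along the direction selected by `order`:

```python
import numpy as np

def prune_inputs_n_m_sketch(inputs, n=2, m=4, order="R"):
    """Toy N:M pruning: keep the n largest-magnitude entries per group of m."""
    grouped = inputs.reshape(-1, m) if order == "R" else inputs.T.reshape(-1, m)
    # Indices of the n largest-magnitude entries in each group of m.
    keep_idx = np.argsort(np.abs(grouped), axis=-1)[:, m - n:]
    pruned = np.zeros_like(grouped)
    np.put_along_axis(
        pruned, keep_idx, np.take_along_axis(grouped, keep_idx, axis=-1), axis=-1
    )
    return (
        pruned.reshape(inputs.shape)
        if order == "R"
        else pruned.reshape(inputs.T.shape).T
    )
```

For example, 2:4 pruning of the row `[1, -5, 2, 3]` keeps `-5` and `3` and zeroes the rest, preserving the input's shape.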
11 changes: 9 additions & 2 deletions praxis/layers/quantization/sparsity/sparsity_hparams.py
@@ -131,6 +131,11 @@ def __post_init__(self):
)


# NOTE: Pay attention to which dimension is sparsified and to whether the
# tensor is a weight or an activation. Some hardware may support sparsity
# only along the reduction dimension. That translates to `C`, i.e.
# column-wise pruning, for weights, and `R`, i.e. row-wise pruning, for
# activations. Enforcing the correct order may be required to suitably
# target hardware capabilities.
@enum.unique
class SparsityOrder(str, enum.Enum):
"""The different index order to apply pruning.
@@ -159,8 +164,10 @@ class SparsityHParams:
sparsity
order: Apply pruning using this index order. Supported values are `C`, `R`.
  `C` and `R` indicate column-wise and row-wise masking, respectively.
  Default is `C`, indicating that N:M sparsity is applied across columns of
  the input matrix. The choice may interact with hardware capabilities that
  support sparsity only along a reduction dimension: for a weight tensor `C`
  corresponds to the reduction dimension, and `R` does for activations.
track_sad_metric: Should we track sparse architecture divergence metric?
topk_estimator_type: Sets the type of top-k mask learning.
"""
