Skip to content

Commit

Permalink
Rename create_aux_variables to create_sparsity_variables for readability.

Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 605141468
  • Loading branch information
shivaniag authored and pax authors committed Feb 8, 2024
1 parent 775727f commit bfb6127
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
4 changes: 2 additions & 2 deletions praxis/layers/quantization/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def setup(self) -> None:
weight_params=pc,
scale_shape=scale_shape,
)
self.create_aux_variables('w', pc, scale_shape=scale_shape)
self.create_sparsity_variables('w', pc, scale_shape=scale_shape)

if self.use_bias:
if self.is_output_projection:
Expand Down Expand Up @@ -335,7 +335,7 @@ def setup(self) -> None:
weight_params=pc,
scale_shape=[3] + hd_shape,
)
self.create_aux_variables('w', pc, scale_shape=[3] + hd_shape)
self.create_sparsity_variables('w', pc, scale_shape=[3] + hd_shape)
if self.use_bias:
# Combined bias weight for q, k, v projections.
pc_bias = WeightHParams(
Expand Down
6 changes: 3 additions & 3 deletions praxis/layers/quantization/linears.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def setup(self) -> None:
weight_params=wp_a,
scale_shape=[self.rank],
)
self.create_aux_variables(
self.create_sparsity_variables(
'w_a',
wp_a,
scale_shape=[self.rank],
Expand All @@ -162,7 +162,7 @@ def setup(self) -> None:
weight_params=wp_b,
scale_shape=[self.output_dims],
)
self.create_aux_variables(
self.create_sparsity_variables(
'w_b',
wp_b,
scale_shape=[self.output_dims],
Expand All @@ -185,7 +185,7 @@ def setup(self) -> None:
weight_params=weight_hparams,
scale_hparams=scale_hparams,
)
self.create_aux_variables(
self.create_sparsity_variables(
'w',
weight_hparams,
scale_shape=[self.output_dims],
Expand Down
4 changes: 2 additions & 2 deletions praxis/layers/quantization/sparsity/sparsifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def sr_ste(
Algorithm description: https://arxiv.org/abs/2102.04010
The last arguement is forced to be static to simplify
The last argument is forced to be static to simplify
the implementation.
Args:
Expand Down Expand Up @@ -129,7 +129,7 @@ class SparsityBaseLayer(base_layer.BaseLayer):

sparsity: Optional[SparsityHParams] = None

def create_aux_variables(
def create_sparsity_variables(
self,
name: str,
weight_hparams: WeightHParams,
Expand Down
4 changes: 2 additions & 2 deletions praxis/layers/quantization/sparsity/sparsifier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def setup(self):
name = 'w'
self.create_variable(name, weight_hp)
self.create_child('einsum', self.einsum_tpl.clone())
self.create_aux_variables(name, weight_hp)
self.create_sparsity_variables(name, weight_hp)

def __call__(self, inputs):
w = self.sparsifiy(
Expand Down Expand Up @@ -98,7 +98,7 @@ def setUp(self):
),
),
)
def test_create_aux_variables(self, mode_name, mode):
def test_create_sparsity_variables(self, mode_name, mode):
sparsity_p = pax_fiddle.Config(
SparsityHParams,
sparsity_type=SparsityType.STRUCTURED_NM,
Expand Down

0 comments on commit bfb6127

Please sign in to comment.