Fix leftover code from adding the top-k binary mask method.
PiperOrigin-RevId: 604856770
shivaniag authored and pax authors committed Feb 7, 2024
1 parent de5fd7d commit a3ff279
Showing 3 changed files with 14 additions and 3 deletions.
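The common change across all three files: each create_aux_variables call site now forwards the same scale_shape that is already passed to set_up_weights, because the sparsifier's top-k path now asserts that scale_shape is provided (see sparsifier.py below). A minimal sketch of the signature these call sites imply; the actual definition is not part of this diff, so the optional parameter and its default are assumptions:

    # Hedged sketch only: this signature is inferred from the call sites in
    # this commit, not taken from the actual Praxis definition.
    from typing import Optional, Sequence


    class AuxVariableMixinSketch:

      def create_aux_variables(
          self,
          weight_name: str,
          weight_params: object,
          scale_shape: Optional[Sequence[int]] = None,
      ) -> None:
        # Illustration only; the real method would register the auxiliary
        # (scale/mask) variables for the named weight.
        del weight_name, weight_params, scale_shape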
praxis/layers/quantization/attentions.py (2 changes: 1 addition & 1 deletion)
@@ -335,7 +335,7 @@ def setup(self) -> None:
         weight_params=pc,
         scale_shape=[3] + hd_shape,
     )
-    self.create_aux_variables('w', pc)
+    self.create_aux_variables('w', pc, scale_shape=[3] + hd_shape)
     if self.use_bias:
       # Combined bias weight for q, k, v projections.
       pc_bias = WeightHParams(
praxis/layers/quantization/linears.py (12 changes: 10 additions & 2 deletions)
@@ -147,7 +147,11 @@ def setup(self) -> None:
           weight_params=wp_a,
           scale_shape=[self.rank],
       )
-      self.create_aux_variables('w_a', wp_a)
+      self.create_aux_variables(
+          'w_a',
+          wp_a,
+          scale_shape=[self.rank],
+      )
       wp_b = WeightHParams(
           shape=shape_b,
           mesh_shape=self.mesh_shape,
@@ -158,7 +162,11 @@ def setup(self) -> None:
           weight_params=wp_b,
           scale_shape=[self.output_dims],
       )
-      self.create_aux_variables('w_b', wp_b)
+      self.create_aux_variables(
+          'w_b',
+          wp_b,
+          scale_shape=[self.output_dims],
+      )
 
     else:
       block_size = self._sub_channel_block_size()
praxis/layers/quantization/sparsity/sparsifier.py (3 changes: 3 additions & 0 deletions)
@@ -164,6 +164,9 @@ def _create_masks_variables(
     )
     if self.sparsity.topk_estimator_type:
       # create learnable mask parameters for top-k methods
+      assert (
+          scale_shape is not None
+      ), 'scale_shape is required for top-k methods.'
       sparsity_mask_hp = copy.deepcopy(weight_hp)
       self.set_up_weights(
           weight_name='w_mask',
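For context on the method being wired up: a top-k binary mask keeps only the k largest-magnitude entries of a weight tensor and zeroes the rest. Below is a standalone sketch in JAX, independent of the Praxis sparsifier API; the function name and shapes are illustrative, not Praxis code:

    import jax
    import jax.numpy as jnp


    def topk_binary_mask(w: jax.Array, k: int) -> jax.Array:
      # Keep the k largest-magnitude entries of w, zero out the rest.
      scores = jnp.abs(w).reshape(-1)
      # top_k values are sorted descending, so the last one is the k-th
      # largest magnitude, which serves as the keep threshold.
      threshold = jax.lax.top_k(scores, k)[0][-1]
      return (jnp.abs(w) >= threshold).astype(w.dtype)


    w = jnp.array([[0.1, -2.0], [3.0, 0.5]])
    print(topk_binary_mask(w, 2))  # [[0. 1.] [1. 0.]], keeping -2.0 and 3.0

In the learnable variant this commit finishes wiring up, the mask itself is a trained parameter (the w_mask variable created above), and the new assertion guarantees that scale_shape is supplied on that path.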
