diff --git a/praxis/layers/quantization/attentions.py b/praxis/layers/quantization/attentions.py
index 9d478906..dbcf711e 100644
--- a/praxis/layers/quantization/attentions.py
+++ b/praxis/layers/quantization/attentions.py
@@ -335,7 +335,7 @@ def setup(self) -> None:
         weight_params=pc,
         scale_shape=[3] + hd_shape,
     )
-    self.create_aux_variables('w', pc)
+    self.create_aux_variables('w', pc, scale_shape=[3] + hd_shape)
     if self.use_bias:
       # Combined bias weight for q, k, v projections.
       pc_bias = WeightHParams(
diff --git a/praxis/layers/quantization/linears.py b/praxis/layers/quantization/linears.py
index ac962e25..bfa326e5 100644
--- a/praxis/layers/quantization/linears.py
+++ b/praxis/layers/quantization/linears.py
@@ -147,7 +147,11 @@ def setup(self) -> None:
           weight_params=wp_a,
           scale_shape=[self.rank],
       )
-      self.create_aux_variables('w_a', wp_a)
+      self.create_aux_variables(
+          'w_a',
+          wp_a,
+          scale_shape=[self.rank],
+      )
       wp_b = WeightHParams(
           shape=shape_b,
           mesh_shape=self.mesh_shape,
@@ -158,7 +162,11 @@ def setup(self) -> None:
           weight_params=wp_b,
           scale_shape=[self.output_dims],
       )
-      self.create_aux_variables('w_b', wp_b)
+      self.create_aux_variables(
+          'w_b',
+          wp_b,
+          scale_shape=[self.output_dims],
+      )
     else:
       block_size = self._sub_channel_block_size()
 
diff --git a/praxis/layers/quantization/sparsity/sparsifier.py b/praxis/layers/quantization/sparsity/sparsifier.py
index e86564ae..7bf09898 100644
--- a/praxis/layers/quantization/sparsity/sparsifier.py
+++ b/praxis/layers/quantization/sparsity/sparsifier.py
@@ -164,6 +164,9 @@ def _create_masks_variables(
     )
     if self.sparsity.topk_estimator_type:
       # create learnable mask parameters for top-k methods
+      assert (
+          scale_shape is not None
+      ), 'scale_shape is required for top-k methods.'
       sparsity_mask_hp = copy.deepcopy(weight_hp)
       self.set_up_weights(
           weight_name='w_mask',
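
For context, the change threads the scale shape into `create_aux_variables` at each call site, and the new assert in `sparsifier.py` requires it whenever a top-k estimator is configured. Below is a minimal, self-contained sketch of that calling pattern, assuming the argument is forwarded down to the mask-creation path; `FakeWeightHParams`, `FakeSparsifiableLayer`, and the aux-variable bookkeeping are hypothetical stand-ins for illustration, not the Praxis implementation.

```python
import copy
import dataclasses
from typing import Optional, Sequence


@dataclasses.dataclass
class FakeWeightHParams:
  """Hypothetical stand-in for praxis WeightHParams."""

  shape: Sequence[int]


class FakeSparsifiableLayer:
  """Hypothetical stand-in showing the scale_shape plumbing in this diff."""

  def __init__(self, topk_estimator_type: Optional[str] = None):
    self.topk_estimator_type = topk_estimator_type
    self.aux_vars = {}

  def create_aux_variables(
      self,
      name: str,
      weight_hp: FakeWeightHParams,
      scale_shape: Optional[Sequence[int]] = None,
  ) -> None:
    # Mirrors the sparsifier change: a learnable top-k mask cannot be set up
    # without knowing the scale shape, so require it up front.
    if self.topk_estimator_type:
      assert (
          scale_shape is not None
      ), 'scale_shape is required for top-k methods.'
      # The learnable mask reuses a copy of the weight's hyperparameters.
      self.aux_vars[f'{name}_mask'] = copy.deepcopy(weight_hp)
    self.aux_vars[f'{name}_scale_shape'] = scale_shape


# Usage mirroring the updated call sites above.
layer = FakeSparsifiableLayer(topk_estimator_type='sigmoid')
wp = FakeWeightHParams(shape=[16, 32])
layer.create_aux_variables('w', wp, scale_shape=[32])
```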