Skip to content

Commit

Permalink
Unblocking subchannel quantization with activation quantization.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 606725799
  • Loading branch information
The praxis Authors committed Feb 13, 2024
1 parent be97eb5 commit 433c0b0
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions praxis/layers/quantization/linears.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,16 @@ def __call__(self, inputs: JTensor) -> JTensor:
inputs_shape, block_size, len(inputs_shape) - 1
),
)
q_einsum_params['eqn'] = 'scz,...sc->...sz'
q_einsum_params['scale_eqn'] = '...sz,sz->...z'
q_einsum_params['zp_eqn'] = '...sc,sz->...z'
q_einsum_params['swap_xw'] = True
if self.quantization.act_params is not None:
q_einsum_params['eqn'] = '...sc,scz->...sz'
q_einsum_params['scale_eqn'] = '...sz,sz->...z'
q_einsum_params['zp_eqn'] = '...sc,sz->...z'
q_einsum_params['swap_xw'] = False
else:
q_einsum_params['eqn'] = 'scz,...sc->...sz'
q_einsum_params['scale_eqn'] = '...sz,sz->...z'
q_einsum_params['zp_eqn'] = '...sc,sz->...z'
q_einsum_params['swap_xw'] = True
if len(w.shape) == 2:
q_einsum_params['reshape'] = self._get_sub_channel_shape(
list(w.shape), block_size, 0
Expand Down

0 comments on commit 433c0b0

Please sign in to comment.