
Commit

Add calibration.
PiperOrigin-RevId: 608187171
jianlijianli authored and pax authors committed Feb 19, 2024
1 parent f684999 commit 1dd2bb0
Showing 3 changed files with 70 additions and 0 deletions.
40 changes: 40 additions & 0 deletions praxis/layers/quantization/linears_test.py
@@ -251,6 +251,46 @@ def test_linear_step_count_in_train(self):
        updated_vars[SUMMARIES]['step_count_scalar'], np.array([1])
    )

  def test_linear_calibration(self):
    p_q = pax_fiddle.Config(
        qlinears.Linear,
        name='_linear_q',
        quantization=QuantizationParams(
            mode=QuantizationMode.CALIB,
            quantization_type=QuantizationType.FR,
        ),
        input_dims=3,
        output_dims=3,
    )
    linear_q = instantiate(p_q)
    inputs = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=jnp.float32)

    context_params = base_layer.JaxContext.HParams(do_eval=False)
    with base_layer.JaxContext.new_context(hparams=context_params):
      prng_key = jax.random.PRNGKey(seed=123)
      initial_vars_q = linear_q.init(prng_key, inputs)
      self.assertArraysEqual(
          initial_vars_q[NON_TRAINABLE]['framestat'],
          jnp.array([0], dtype=jnp.bfloat16),
      )
      _, updated_vars = linear_q.apply(
          initial_vars_q, inputs, mutable=[PARAMS, NON_TRAINABLE, SUMMARIES]
      )
      self.assertArraysEqual(
          updated_vars[NON_TRAINABLE]['framestat'],
          jnp.array([9], dtype=jnp.bfloat16),
      )

      # No grad.
      def loss(params, inputs):
        return jnp.sum(linear_q.apply(params, inputs)[0])

      grad = jax.grad(loss)(initial_vars_q, inputs)
      self.assertArraysEqual(
          grad['params']['w'],
          jnp.zeros_like(initial_vars_q['params']['w'], dtype=jnp.float32),
      )

  def test_int4_weight_init(self):
    p = pax_fiddle.Config(
        qlinears.Linear,
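The new test exercises the CALIB path end to end: after one forward pass the non-trainable 'framestat' variable holds the running max of the inputs (9), and no gradient flows to the weights. A minimal standalone sketch of that behaviour in plain JAX (the helper name calib_matmul is hypothetical, not Praxis API):

import jax
import jax.numpy as jnp

# Standalone sketch, not Praxis code: keep a running max of the activations
# in bfloat16 and block gradients through both operands, mirroring the
# CALIB branch added in quantizer.py below.
def calib_matmul(stat, x, w):
  new_stat = jnp.maximum(jnp.max(x), stat).astype(jnp.bfloat16)
  x = jax.lax.stop_gradient(x)
  w = jax.lax.stop_gradient(w)
  return x @ w, new_stat

x = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=jnp.float32)
w = jnp.ones((3, 3), dtype=jnp.float32)
stat = jnp.array([0], dtype=jnp.bfloat16)

_, stat = calib_matmul(stat, x, w)
print(stat)  # [9] -- the running max asserted in the test.

# Gradients w.r.t. the weights are zero because of stop_gradient.
grad_w = jax.grad(lambda w: jnp.sum(calib_matmul(stat, x, w)[0]))(w)
print(jnp.all(grad_w == 0))  # True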
5 changes: 5 additions & 0 deletions praxis/layers/quantization/quantization_hparams.py
@@ -31,11 +31,13 @@ class QuantizationType(str, enum.Enum):
  AQT: Accurate Quantized Training, which is one flavor of QAT.
  FQ: Fake Quantization, which is one flavor of QAT.
  FQ_VN: Use variational noise to emulate quantization noise.
  FR: Frame quantization; used together with the CALIB mode for calibration.
  """
  PTQ = 'ptq'
  AQT = 'aqt'
  FQ = 'fq'
  FQ_VN = 'fq_vn'
  FR = 'fr'
  # Internal quantization type.


@@ -49,11 +51,14 @@ class QuantizationMode(str, enum.Enum):
    INFERENCE. This mode is referenced only by `ServableModelParams` for
    serving.
  INFERENCE indicates that the model is in inference mode.
  QT indicates that the model will be trained with quantization.
  CALIB indicates that the model is going to be calibrated.
  """
  TRAINING = 'training'
  MATERIALIZE = 'materialize'
  INFERENCE = 'inference'
  QT = 'qt'
  CALIB = 'calib'


@dataclasses.dataclass
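The two new enum values are meant to be used together: quantizer.py below asserts that CALIB mode is only combined with the FR quantization type. A minimal configuration sketch mirroring the test above (same imports and helpers as linears_test.py; the layer name here is arbitrary):

# Calibration is requested by pairing QuantizationMode.CALIB with
# QuantizationType.FR, exactly as in the new test.
p = pax_fiddle.Config(
    qlinears.Linear,
    name='_linear_calib',
    quantization=QuantizationParams(
        mode=QuantizationMode.CALIB,
        quantization_type=QuantizationType.FR,
    ),
    input_dims=3,
    output_dims=3,
)
linear = instantiate(p)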
25 changes: 25 additions & 0 deletions praxis/layers/quantization/quantizer.py
@@ -99,6 +99,10 @@ def set_up_weights(
    )

    if self.quantization.mode == QuantizationMode.INFERENCE:
      if self.quantization.quantization_type == QuantizationType.FR:
        raise NotImplementedError(
            'FRAME Quantization is not supported yet for inference.'
        )
      if (
          precision == 4
          and self.quantization.weight_params.use_int4_packed_weights
@@ -136,6 +140,15 @@ def set_up_weights(
          dtype=dtype,
          use_symmetric=self.quantization.weight_params.use_symmetric,
      )
    elif self.quantization.mode == QuantizationMode.CALIB:
      assert self.quantization.quantization_type == QuantizationType.FR
      stats = base_layer.WeightHParams(
          shape=[1],
          init=base_layer.WeightInit.Constant(0),
          dtype=jnp.bfloat16,
      )
      self.create_variable('framestat', stats, trainable=False)
      self.create_variable(weight_name, weight_params)
    else:
      self.create_variable(weight_name, weight_params)

@@ -215,6 +228,11 @@ def quantized_einsum(
      self.add_summary('step_count', step_count)

    if self.quantization.mode == QuantizationMode.INFERENCE:
      if self.quantization.quantization_type == QuantizationType.FR:
        # It takes extra parameters during inference.
        raise NotImplementedError(
            'FR Quantization is not supported yet for inference.'
        )
      # PTQ and QAT have the same inference graph; the only difference is in
      # the activations.
      # No matter which quantization type is used, the weight and scale
      # dimensions are the same for all types.
@@ -270,6 +288,13 @@ def quantized_einsum(
    elif self.quantization.mode == QuantizationMode.QT:
      key = self.next_prng_key()
      return operations.custom_einsum(x, w, key)
    elif self.quantization.mode == QuantizationMode.CALIB:
      stat = self.get_var('framestat')
      new_stat = jnp.maximum(jnp.max(x), stat)
      self.update_var('framestat', new_stat.astype(jnp.bfloat16))
      x = jax.lax.stop_gradient(x)
      w = jax.lax.stop_gradient(w)
      return jnp.einsum(eqn, x, w)
    else:
      assert not swap_xw, 'Swapping xw is only supported in inference mode.'
      if reshape:
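This commit only collects the statistic; the FR inference path still raises NotImplementedError, so how 'framestat' is consumed is not shown here. As a hedged sketch of one conventional way (an assumption, not taken from this commit) that a calibrated activation max could later drive a symmetric int8 fake-quantization of the activations:

import jax.numpy as jnp

# Assumption: a standard symmetric int8 scheme, not from this commit.
def scale_from_framestat(framestat):
  # framestat holds the calibrated running max of the activations, e.g. [9].
  return framestat.astype(jnp.float32) / 127.0

def fake_quantize_activations(x, scale):
  # Quantize-dequantize the activations with the calibrated scale.
  q = jnp.clip(jnp.round(x / scale), -127, 127)
  return q * scale

scale = scale_from_framestat(jnp.array([9], dtype=jnp.bfloat16))
x = jnp.array([[1.0, 4.5, 9.0]])
print(fake_quantize_activations(x, scale))  # values snapped to the int8 grid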
