Skip to content

Commit

Permalink
add scalar quantizer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 687071008
  • Loading branch information
bignamehyp authored and pax authors committed Oct 17, 2024
1 parent 619fc34 commit 572b571
Showing 1 changed file with 93 additions and 0 deletions.
93 changes: 93 additions & 0 deletions praxis/layers/video/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

"""Quantizer for vqvae."""

from collections.abc import Sequence
import math

import jax
Expand Down Expand Up @@ -135,6 +136,98 @@ def _entropy_loss(
return loss


class ScalarQuantizer(base_layer.BaseLayer):
"""Scalar quantizer (SQ)."""

embedding_dim: int = 8
dim_widths: Sequence[int] = (8, 8, 4, 4, 4, 4, 4, 4)
bound_method: str = 'sine' # 'sine' or 'tanh'
eps: float = 1e-3 # relatively large due to use of float16

def _bound_embedding(self, z: JTensor) -> JTensor:
"""Applies bounding function (e.g., tanh, sine)."""
dim_widths = jnp.asarray(self.dim_widths, np.int32)
k = dim_widths * (1 - self.eps) / 2
offset = jnp.where(dim_widths % 2 == 1, 0.0, 0.5)
if self.bound_method == 'sine':
return jnp.sin(z) * k - offset
elif self.bound_method == 'tanh':
return jnp.tanh(z) * k - offset
else:
raise ValueError(f'Bound method {self.bound_method} not supported')

def _quantize_embedding(self, z: JTensor) -> JTensor:
"""Returns quantized values like cdx.ops.ste_round(z)."""
# use cdx.ops.ste_round(z) for training
z_q = jnp.round(z)
z_q = z + jax.lax.stop_gradient(z_q - z)
return z_q

def _get_indices(self, z: JTensor) -> tuple[JTensor, JTensor]:
"""Returns indices (SQ or implied codebook) from quantized values."""
dim_widths = jnp.asarray(self.dim_widths, np.int32)
left = -(dim_widths // 2)
right = dim_widths + left - 1
# In {left, ..., right}
clipped = jnp.clip(z, left, right)
# In {0, ..., left + right}
zeroed = clipped - left
if len(self.dim_widths) != zeroed.shape[-1]:
raise ValueError('Sum of dim_widths and codebook sizes do not match.')
if zeroed.dtype != jnp.int32:
zeroed = jnp.round(zeroed).astype(jnp.int32)
zeroed = jnp.clip(zeroed, 0, dim_widths - 1)
indices = jnp.zeros(zeroed.shape[:-1], dtype=jnp.int32)
for i, n in enumerate(self.dim_widths):
indices *= n
indices += zeroed[..., i]
return indices, clipped

def __call__(self, inputs: JTensor) -> tuple[JTensor, NestedMap]:
if len(self.dim_widths) != inputs.shape[-1]:
raise ValueError(
'Number of dim widths must match channels: '
f'{len(self.dim_widths)} vs {inputs.shape}'
)

dim_widths = jnp.asarray(self.dim_widths, np.int32)
unquantized = self._bound_embedding(inputs)
# In {left, ..., right}.
quantized = self._quantize_embedding(unquantized)
indices, clipped = self._get_indices(quantized)

if self.do_eval:
quantized = clipped
middle = (dim_widths - dim_widths // 2 * 2 - 1) / 2
# This might not be integers.
quantized_centered = quantized - middle
unquantized_centered = unquantized - middle

result_dict = NestedMap(
raw=unquantized_centered,
quantized=quantized_centered,
encoding_indices=indices,
quantizer_loss=0.0,
)

return quantized_centered, result_dict

def decode_ids(self, ids: JTensor) -> JTensor:
if ids.dtype != jnp.int32:
ids = jnp.round(ids).astype(jnp.int32)
vals = []
dim_widths = jnp.asarray(self.dim_widths, np.int32)
for dim_width in dim_widths[::-1]:
vals.append(jnp.mod(ids, dim_width))
ids = ids // dim_width
vals = jnp.stack(vals[::-1], axis=-1)
vals -= dim_widths // 2 # tokens are [0..N), vals are [-N/2..N/2)

middle = (dim_widths - dim_widths // 2 * 2 - 1) / 2
vals -= middle
return vals


class LookupFreeQuantizer(base_layer.BaseLayer):
"""Lookup free quantizer with fixed value sets."""

Expand Down

0 comments on commit 572b571

Please sign in to comment.