Post-Training Quantization (PTQ)

Note: This page covers PTQ on XLA devices (CPU/GPU/TPU). ODML models deployed through the LiteRT converter should use the ODML modes instead.

Post-training quantization (PTQ) improves serving performance on XLA devices by quantizing the weights ahead of time and performing computation in quantized types. When static-range quantization (SRQ) is enabled, PTQ also precomputes the activation scales, which keeps the runtime cost of quantizing activations minimal.
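As a conceptual illustration only (plain JAX, not Qwix's implementation), the sketch below quantizes a dense layer's weights to int8 once, ahead of time, and then runs an int8 matmul with int32 accumulation. Symmetric per-channel weight scales, a per-tensor activation scale, and the helper names `quantize_int8`, `per_channel_scale`, and `ptq_dense` are assumptions of this sketch. Passing a precomputed `act_scale` mimics static-range quantization: it skips the max-abs reduction over the activations that a dynamically computed scale requires on every call.

```python
import jax
import jax.numpy as jnp


def quantize_int8(x, scale):
    """Symmetric quantization: divide by scale, round, clip to [-127, 127]."""
    return jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)


def per_channel_scale(w):
    """One scale per output channel of an (in_features, out_features) kernel."""
    absmax = jnp.max(jnp.abs(w), axis=0, keepdims=True)  # shape (1, out_features)
    return jnp.maximum(absmax, 1e-8) / 127.0


# Ahead of time (once): quantize the weights and keep their scales.
w = jax.random.normal(jax.random.key(0), (64, 32), dtype=jnp.float32)
w_scale = per_channel_scale(w)
w_q = quantize_int8(w, w_scale)


def ptq_dense(x, w_q, w_scale, act_scale=None):
    """int8 x int8 matmul with int32 accumulation, rescaled back to float32.

    act_scale=None: the activation scale is computed from x at runtime.
    act_scale given: a precomputed (static) scale is used instead.
    """
    if act_scale is None:
        act_scale = jnp.maximum(jnp.max(jnp.abs(x)), 1e-8) / 127.0
    x_q = quantize_int8(x, act_scale)
    acc = jax.lax.dot(x_q, w_q, preferred_element_type=jnp.int32)
    return acc.astype(jnp.float32) * act_scale * w_scale


x = jax.random.normal(jax.random.key(1), (8, 64), dtype=jnp.float32)
y_ref = x @ w
y_dynamic = ptq_dense(x, w_q, w_scale)
# Hypothetical static scale, chosen so that values with |x| <= 4 are representable.
y_static = ptq_dense(x, w_q, w_scale, act_scale=jnp.float32(4.0 / 127.0))
```

Both quantized outputs stay close to the floating-point reference; the static variant simply trades the runtime reduction for a fixed, calibrated scale.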

PTQ can be used on its own, or together with QT to recover some of the quality lost to quantization.

PTQ with Qwix

PTQ in Qwix is implemented by PtqProvider and can be applied to a model with quantize_model.

import qwix

fp_model = SomeLinenModel()  # Any Flax Linen model.
ptq_model = qwix.quantize_model(fp_model, qwix.PtqProvider(rules))  # rules: a list of QuantizationRule.

Rather than running ptq_model.init to materialize real variables, a more common practice is to use jax.eval_shape to obtain abstract PTQ variables, and then call quantize_params to obtain the quantized weights, as demonstrated below.

Weight quantization

Besides quantizing the model, PTQ also requires the weights to be quantized ahead of time. This is done with the quantize_params function.

import jax
import qwix

# Floating-point params, usually loaded from a checkpoint.
fp_params = ...

# Initialize abstract quantized params, which serve as a template so that
# quantize_params knows how to quantize each weight. eval_shape computes
# only shapes and dtypes, so no device memory is allocated.
abs_ptq_variables = jax.eval_shape(ptq_model.init, jax.random.key(0), model_input)

ptq_params = qwix.quantize_params(fp_params, abs_ptq_variables['params'])

# ptq_params contains the quantized weights and can be consumed by ptq_model.
quantized_model_output = ptq_model.apply({'params': ptq_params}, model_input)
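To illustrate why an abstract template is enough, here is a toy sketch in plain JAX. jax.eval_shape traces a hypothetical `init_ptq` function and returns only `jax.ShapeDtypeStruct` leaves, so no arrays are allocated; a hypothetical `quantize_like` helper then inspects each template leaf to decide whether the matching floating-point weight should be stored as int8 values plus a scale. The `{"qvalue", "scale"}` layout is an assumption made for this example and does not reflect the actual structure of Qwix's ptq_params.

```python
import jax
import jax.numpy as jnp


def init_fp(key):
    """Hypothetical fp init of a single dense layer."""
    return {"kernel": jax.random.normal(key, (16, 8)), "bias": jnp.zeros((8,))}


def init_ptq(key):
    """Hypothetical PTQ init: the kernel is stored as int8 values plus a scale."""
    fp = init_fp(key)
    return {
        "kernel": {
            "qvalue": jnp.zeros((16, 8), jnp.int8),
            "scale": jnp.ones((1, 8), jnp.float32),
        },
        "bias": fp["bias"],
    }


# Traces init_ptq abstractly: every leaf is a jax.ShapeDtypeStruct, no arrays.
abs_ptq = jax.eval_shape(init_ptq, jax.random.key(0))


def quantize_like(fp_leaf, abs_leaf):
    """Quantizes fp_leaf if the template stores it as int8, else casts it."""
    if isinstance(abs_leaf, dict) and abs_leaf["qvalue"].dtype == jnp.int8:
        absmax = jnp.max(jnp.abs(fp_leaf), axis=0, keepdims=True)
        scale = (jnp.maximum(absmax, 1e-8) / 127.0).astype(abs_leaf["scale"].dtype)
        qvalue = jnp.clip(jnp.round(fp_leaf / scale), -127, 127).astype(jnp.int8)
        return {"qvalue": qvalue, "scale": scale}
    return fp_leaf.astype(abs_leaf.dtype)


fp_params = init_fp(jax.random.key(1))  # Stands in for weights loaded from a checkpoint.
ptq_params = {name: quantize_like(fp_params[name], abs_ptq[name]) for name in fp_params}
```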

The intermediate ptq_params can be saved to disk to create a quantized checkpoint, a practice commonly known as offline quantization. Whenever possible, Qwix recommends online quantization instead, i.e. quantizing the floating-point weights at load time, because:

  • Eliminating the offline quantization step improves development velocity and avoids the cost of maintaining multiple checkpoints.

  • The structure of ptq_params is an implementation detail of Qwix and is subject to change, which can make existing quantized checkpoints incompatible.

With online quantization, fp_params may be too large to fit in the HBM of the serving topology. To handle this, quantize_params also accepts a subtree of fp_params. For example, the checkpoint can be loaded layer by layer, with each layer quantized immediately after it is loaded; this is known as pipelined checkpoint loading and quantization.
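The loop below sketches pipelined loading and quantization under stated assumptions: `load_fp_layer` is a hypothetical stand-in for reading one layer from a checkpoint, and `quantize_layer` is a simplified per-channel int8 quantizer used here in place of quantize_params. The point is that only one floating-point layer needs to be resident at a time, while the accumulated int8 kernels take a quarter of the memory of their fp32 originals (plus a small per-channel scale).

```python
import jax
import jax.numpy as jnp

NUM_LAYERS = 4
KERNEL_SHAPE = (32, 32)


def load_fp_layer(i):
    """Hypothetical stand-in for reading layer i's fp32 weights from a checkpoint."""
    return {"kernel": jax.random.normal(jax.random.key(i), KERNEL_SHAPE, jnp.float32)}


@jax.jit
def quantize_layer(layer):
    """Simplified symmetric per-channel int8 quantization of one layer's kernel."""
    w = layer["kernel"]
    scale = jnp.maximum(jnp.max(jnp.abs(w), axis=0, keepdims=True), 1e-8) / 127.0
    qvalue = jnp.clip(jnp.round(w / scale), -127, 127).astype(jnp.int8)
    return {"kernel": {"qvalue": qvalue, "scale": scale}}


quantized_layers = []
for i in range(NUM_LAYERS):
    fp_layer = load_fp_layer(i)  # Only this layer's fp32 weights are resident.
    quantized_layers.append(quantize_layer(fp_layer))
    del fp_layer  # Drop the fp32 reference before loading the next layer.

fp32_bytes = NUM_LAYERS * KERNEL_SHAPE[0] * KERNEL_SHAPE[1] * 4
int8_bytes = sum(layer["kernel"]["qvalue"].nbytes for layer in quantized_layers)
```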

Alternative way to quantize weights

For smaller models where the HBM limit is not a concern, weights can also be quantized by feeding the unquantized weights to the PTQ model itself. The PTQ model quantizes them and replaces the original weights with the quantized ones. This can be especially convenient for NNX models.

# Assume fp_variables contains the correct unquantized weights.
_, ptq_variables = ptq_model.apply(fp_variables, model_input, mutable=True)
# ptq_variables contains the quantized weights now.

This approach can be non-obvious to most users, so for Linen models it is recommended to always use quantize_params.
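To make the mechanism in this section more concrete, it can be pictured with a toy functional layer written in plain JAX (hypothetical, not Qwix's code): when it receives a floating-point kernel, it quantizes it and returns int8 values and a scale in place of the original weight, analogous to the updated variables returned by apply with mutable=True. The `kernel_q`/`kernel_scale` names are invented for this sketch, and for simplicity the layer dequantizes the weight before the matmul.

```python
import jax
import jax.numpy as jnp


def ptq_dense_apply(variables, x):
    """Toy PTQ layer returning (output, updated_variables).

    If an fp 'kernel' is present, it is replaced by int8 values plus a scale.
    """
    params = dict(variables["params"])  # Copy so the caller's dict is untouched.
    if "kernel" in params:  # Unquantized weight was fed in.
        w = params.pop("kernel")
        scale = jnp.maximum(jnp.max(jnp.abs(w), axis=0, keepdims=True), 1e-8) / 127.0
        params["kernel_q"] = jnp.clip(jnp.round(w / scale), -127, 127).astype(jnp.int8)
        params["kernel_scale"] = scale
    w_dq = params["kernel_q"].astype(x.dtype) * params["kernel_scale"]
    return x @ w_dq, {"params": params}


fp_variables = {"params": {"kernel": jax.random.normal(jax.random.key(0), (16, 4))}}
x = jnp.ones((2, 16))
_, ptq_variables = ptq_dense_apply(fp_variables, x)
# A second call consumes the already-quantized variables without changing them.
y, same_variables = ptq_dense_apply(ptq_variables, x)
```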

Static-range quantization

In static-range quantization (SRQ), the PTQ model contains extra static scales that need to be calculated from the quant_stats collected during QT. In this case, the quant_stats collection must also be passed to quantize_params.

import jax
import qwix

model = SomeLinenModel(...)
rules = [
    qwix.QuantizationRule(
        weight_qtype="int8",
        act_qtype="int8",
        act_static_scale=True,  # Use precomputed (static) activation scales.
    ),
]

qt_model = qwix.quantize_model(model, qwix.QtProvider(rules))
qt_variables = qt_model.init(jax.random.key(0), model_input)
# qt_variables contains "params" and "quant_stats". In practice, they come
# from a QT run, during which quant_stats are collected.

ptq_model = qwix.quantize_model(model, qwix.PtqProvider(rules))
abs_ptq_variables = jax.eval_shape(ptq_model.init, jax.random.key(0), model_input)

ptq_params = qwix.quantize_params(
    qt_variables['params'],
    abs_ptq_variables['params'],
    qt_variables['quant_stats'],
)
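How static scales can be derived from collected statistics is sketched below in plain JAX. The statistics format (a running sum of per-batch abs-max values plus a step count), the averaging rule, and the helper names are assumptions for illustration; Qwix's actual quant_stats layout and calibration method may differ. The key property is that the serving-time quantizer uses a fixed scale and needs no reduction over the activations.

```python
import jax
import jax.numpy as jnp


def update_stats(stats, x):
    """QT-time step: accumulate the per-batch abs-max of an activation."""
    return {
        "sum_of_absmax": stats["sum_of_absmax"] + jnp.max(jnp.abs(x)),
        "count": stats["count"] + 1,
    }


stats = {"sum_of_absmax": jnp.float32(0.0), "count": jnp.int32(0)}
for step in range(10):
    batch = jax.random.normal(jax.random.key(step), (32, 64))
    stats = update_stats(stats, batch)

# Conversion time: turn the collected statistics into a fixed int8 scale.
static_scale = (stats["sum_of_absmax"] / stats["count"]) / 127.0


def quantize_act_static(x, scale):
    """Serving time: no abs-max reduction, just divide, round, and clip."""
    return jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)


x_q = quantize_act_static(jax.random.normal(jax.random.key(100), (32, 64)), static_scale)
```

Averaging the per-batch maxima is only one possible calibration choice; a running maximum or a percentile would trade clipping error against resolution differently.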