Get Started
This guide demonstrates how to apply post-training quantization (PTQ) to a simple MLP model, written once with Flax Linen and once with Flax NNX.
# Linen
import jax
from flax import linen as nn

class MLP(nn.Module):
  dhidden: int
  dout: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.dhidden, use_bias=False)(x)
    x = nn.relu(x)
    x = nn.Dense(self.dout, use_bias=False)(x)
    return x

model = MLP(64, 16)
model_input = jax.random.uniform(jax.random.key(0), (8, 16))
# NNX
import jax
from flax import nnx

class MLP(nnx.Module):
  def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs)
    self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs)

  def __call__(self, x):
    x = self.linear1(x)
    x = nnx.relu(x)
    x = self.linear2(x)
    return x

model = MLP(16, 64, 16, rngs=nnx.Rngs(0))
model_input = jax.random.uniform(jax.random.key(0), (8, 16))
Because Qwix quantizes the whole model implicitly, the model code does not need to be modified. Any other Linen or NNX model can be substituted for the one above.
Quantization config
Qwix uses a regex-based configuration system to specify how a JAX model is
quantized. A configuration is a list of QuantizationRule objects. Each rule
consists of a key that matches Flax modules and a set of values that control
quantization behavior.
For example, to quantize the model above to int8 for both weights and activations (w8a8), define the rules as follows.
import qwix

rules = [
    qwix.QuantizationRule(
        module_path='.*',  # this rule matches all modules.
        weight_qtype='int8',  # quantizes weights in int8.
        act_qtype='int8',  # quantizes activations in int8.
    )
]
Unlike some other libraries that provide a limited number of quantization
recipes, Qwix doesn’t have a list of presets. Instead, different quantization
schemes are achieved by combining quantization configs. For the full list of
available options, see the
QuantizationRule class.
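To make the regex-based rule list more concrete, here is a minimal sketch in plain Python of how a list of rules could be resolved against module paths. It does not use Qwix: `ToyRule` and `find_rule` are hypothetical names, and both the full-match semantics and the "first matching rule wins" policy are assumptions made for illustration, so check the QuantizationRule documentation for the actual behavior.

```python
import re
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass(frozen=True)
class ToyRule:
    """Hypothetical stand-in for a quantization rule (not the Qwix class)."""
    module_path: str                    # regex matched against a module path
    weight_qtype: Optional[str] = None  # None means "leave weights as-is"
    act_qtype: Optional[str] = None     # None means "leave activations as-is"


def find_rule(rules: Sequence[ToyRule], module_path: str) -> Optional[ToyRule]:
    """Return the first rule whose regex matches the whole path (assumed policy)."""
    for rule in rules:
        if re.fullmatch(rule.module_path, module_path):
            return rule
    return None


# Combining rules: weight-only int8 for one layer, w8a8 for everything else.
rules = [
    ToyRule(module_path="linear1", weight_qtype="int8"),
    ToyRule(module_path=".*", weight_qtype="int8", act_qtype="int8"),
]

for path in ("linear1", "linear2"):
    rule = find_rule(rules, path)
    print(path, "->", rule.weight_qtype, rule.act_qtype)
```

Under these assumptions, a more specific rule placed before the catch-all `.*` rule overrides it for the modules it matches, which is one way a mixed scheme can be built without presets.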
Apply quantization
With the above code, applying quantization is as simple as one line.
# Linen
ptq_model = qwix.quantize_model(model, qwix.PtqProvider(rules))

# NNX (the model input is passed as well)
ptq_model = qwix.quantize_model(model, qwix.PtqProvider(rules), model_input)
We can inspect the params to verify that the weights are now pre-quantized: each kernel holds an int8 qvalue and a float32 scale.
>>> jax.eval_shape(ptq_model.init, jax.random.key(0), model_input)['params']
{
  'Dense_0': {
    'kernel': WithAux(
      array=QArray(
        qvalue=ShapeDtypeStruct(shape=(16, 64), dtype=int8),
        scale=ShapeDtypeStruct(shape=(1, 64), dtype=float32),
        ...
      ),
      ...
    )
  },
  'Dense_1': {
    'kernel': WithAux(
      array=QArray(
        qvalue=ShapeDtypeStruct(shape=(64, 16), dtype=int8),
        scale=ShapeDtypeStruct(shape=(1, 16), dtype=float32),
        ...
      ),
      ...
    )
  }
}
>>> jax.eval_shape(nnx.to_pure_dict, nnx.state(ptq_model))
{
  'linear1': {
    'kernel': {
      'array': {
        'qvalue': ShapeDtypeStruct(shape=(16, 64), dtype=int8),
        'scale': ShapeDtypeStruct(shape=(1, 64), dtype=float32)
      }
    }
  },
  'linear2': {
    'kernel': {
      'array': {
        'qvalue': ShapeDtypeStruct(shape=(64, 16), dtype=int8),
        'scale': ShapeDtypeStruct(shape=(1, 16), dtype=float32)
      }
    }
  }
}
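The qvalue and scale pairs in the outputs above have shapes such as (16, 64) and (1, 64), which is consistent with one scale per output column of the kernel. The following plain-Python sketch shows what such a pair can represent, assuming symmetric absmax quantization. That scheme, and the helper names `quantize_per_column` and `dequantize`, are assumptions for illustration; the exact calibration Qwix applies may differ.

```python
def quantize_per_column(weights, bits=8):
    """Symmetric absmax quantization with one scale per column (assumed scheme)."""
    qmax = 2 ** (bits - 1) - 1  # 127 for int8
    n_cols = len(weights[0])
    scales = []
    for c in range(n_cols):
        absmax = max(abs(row[c]) for row in weights)
        # Avoid a zero scale for an all-zero column.
        scales.append(absmax / qmax if absmax > 0 else 1.0)
    qvalues = [
        [max(-qmax, min(qmax, round(row[c] / scales[c]))) for c in range(n_cols)]
        for row in weights
    ]
    # The scale is returned as a single row, mirroring the (1, n_cols) shape.
    return qvalues, [scales]


def dequantize(qvalues, scale):
    """Recover approximate float weights as qvalue * scale."""
    return [[q * s for q, s in zip(row, scale[0])] for row in qvalues]


weights = [
    [0.50, -1.00],
    [0.25, 2.00],
    [-0.10, 0.40],
]
qvalues, scale = quantize_per_column(weights)
restored = dequantize(qvalues, scale)

# Largest absolute difference between original and restored weights.
max_error = max(
    abs(w - r)
    for w_row, r_row in zip(weights, restored)
    for w, r in zip(w_row, r_row)
)
print("scale:", scale)
print("max round-trip error:", max_error)
```

In this sketch the round-trip error of each entry is bounded by half of its column's scale, which is why a separate scale per column tends to preserve more precision than a single scale for the whole kernel.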
Quantization providers
You may notice that we initialized a PtqProvider object above and applied it
to the model. PtqProvider implements the QuantizationProvider interface, an
abstraction that allows different quantization modes to be implemented and
consumed in a consistent way.
Qwix ships with the following providers.
It’s also possible to implement your own provider by subclassing an existing one, which is useful for researchers exploring novel quantization algorithms.
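The provider pattern described above can be pictured with a small, self-contained sketch. None of this is the Qwix API: `ToyProvider`, `ToyPtqProvider`, `ToyFakeQuantProvider`, `process_weight`, and `toy_quantize_model` are hypothetical names invented for this example. It only shows how one entry point can consume interchangeable providers, and how a subclass can reuse an existing provider's logic while changing its behavior.

```python
from abc import ABC, abstractmethod


class ToyProvider(ABC):
    """Hypothetical provider interface: decides how each weight is handled."""

    @abstractmethod
    def process_weight(self, weight):
        ...


class ToyPtqProvider(ToyProvider):
    """Stores integer values plus a scale, similar in spirit to PTQ."""

    def process_weight(self, weight):
        # Symmetric absmax scale; fall back to 1.0 for an all-zero weight.
        scale = max(abs(w) for w in weight) / 127 or 1.0
        return {"qvalue": [round(w / scale) for w in weight], "scale": scale}


class ToyFakeQuantProvider(ToyPtqProvider):
    """Subclass that reuses the PTQ math but returns dequantized floats."""

    def process_weight(self, weight):
        quantized = super().process_weight(weight)
        return [q * quantized["scale"] for q in quantized["qvalue"]]


def toy_quantize_model(weights, provider):
    """Single entry point that works with any provider."""
    return {name: provider.process_weight(w) for name, w in weights.items()}


weights = {"linear1": [0.5, -1.0], "linear2": [0.0, 0.0]}
ptq = toy_quantize_model(weights, ToyPtqProvider())
fake = toy_quantize_model(weights, ToyFakeQuantProvider())
print(ptq["linear1"])
print(fake["linear1"])
```

The calling code stays the same for both providers, and only the provider object changes. That separation is the property the paragraph above attributes to the QuantizationProvider abstraction.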