Data Structures
The QArray is the backbone of the Qwix library. It serves as the fundamental
data structure that encapsulates quantized data (integers) alongside its
quantization parameters (scale and zero-point).
By wrapping these elements together, QArray allows quantized tensors to be
passed around and manipulated seamlessly, ensuring that the quantization
context is preserved throughout the computation graph.
- class qwix.QArray(qvalue: Array, scale: Array, zero_point: Array | None = None, qtype: str | type[Any] | dtype | SupportsDType = None)
A quantized array implementation with subchannel support.
The following conditions hold:
qvalue.shape == original.shape
len(scale.shape) == len(original.shape)
len(scale.shape) == len(zero_point.shape) (when zero_point is not None)
To enable subchannel quantization, scale and zero_point can be “generic broadcasted” to original.shape, which means that every dimension of original.shape is a multiple of the corresponding dimension of scale.shape:
all(o % s == 0 for o, s in zip(original.shape, scale.shape))
The original array is then approximately recovered as:
original ≈ (qvalue - zero_point) * generic_broadcast(scale, original.shape)
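The shape conditions and the reconstruction formula above can be illustrated with a small NumPy sketch. The `generic_broadcast` helper below is a hypothetical stand-in written for this example (it is not taken from Qwix), and it assumes that each scale entry covers a contiguous tile of the original array.

```python
import numpy as np

def generic_broadcast(x, shape):
    """Hypothetical helper: repeat entries of x so the result has `shape`.

    Assumes x has the same rank as `shape` and that every target
    dimension is a multiple of the corresponding dimension of x.
    """
    assert x.ndim == len(shape)
    for axis, (target, size) in enumerate(zip(shape, x.shape)):
        assert target % size == 0, "scale shape must divide the original shape"
        # Each entry is repeated target // size times along this axis.
        x = np.repeat(x, target // size, axis=axis)
    return x

# Original shape (2, 4) with scale shape (2, 2): each scale covers a 1 x 2 tile.
qvalue = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.int8)
scale = np.array([[0.5, 2.0], [1.0, 0.25]], dtype=np.float32)

full_scale = generic_broadcast(scale, qvalue.shape)
# Symmetric case (zero_point is None), so nothing is subtracted.
dequantized = qvalue.astype(np.float32) * full_scale
```

Here the first two columns of each row share one scale and the last two share another, which is the subchannel arrangement the divisibility condition permits.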
- qvalue
The quantized value.
- Type:
jax.Array
- scale
The scale used to quantize the value.
- Type:
jax.Array
- zero_point
The quantized value that represents the floating-point value 0 exactly, or None for symmetric quantization.
- Type:
jax.Array | None
- qtype
The logical type of the qvalue, which could be different from the dtype used for storage in qvalue. If None, the qvalue’s dtype will be used as the logical type.
- Type:
str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType
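To make the role of zero_point concrete, here is a minimal NumPy sketch of asymmetric int8 quantization. The min/max calibration and the rounding used to derive `scale` and `zero_point` are assumptions made for this example and may differ from what Qwix does; only the reconstruction formula `(qvalue - zero_point) * scale` comes from the class description above.

```python
import numpy as np

x = np.array([-1.0, 0.0, 0.5, 3.0], dtype=np.float32)
qmin, qmax = -128, 127  # int8 range

# Assumed min/max calibration, for illustration only.
scale = (x.max() - x.min()) / (qmax - qmin)
# The integer that the floating-point value 0.0 maps to.
zero_point = np.round(qmin - x.min() / scale)

qvalue = np.clip(np.round(x / scale) + zero_point, qmin, qmax).astype(np.int8)

# Reconstruction: original ≈ (qvalue - zero_point) * scale
dequantized = (qvalue.astype(np.float32) - zero_point) * scale
```

Because zero_point is itself an integer in the quantized range, the input value 0.0 is recovered with no rounding error, while the other values are recovered to within about one quantization step.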
- property T
- astype(dtype: str | type[Any] | dtype | SupportsDType) -> QArray
Casts the dequantized dtype (the dtype the array has after dequantization) to the given dtype.
- property dtype
- property mT
- property ndim
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- property scale_tile_shape: tuple[int, ...]
Returns the tile shape for the scale values.
- property shape
- property zero_point_tile_shape: tuple[int, ...] | None
Returns the tile shape for the zero point values.
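One plausible reading of "tile shape" is the block of elements that shares a single scale (or zero point) entry. The sketch below assumes it is the elementwise quotient of the array shape by the parameter shape, which is consistent with the divisibility condition in the class description. The `tile_shape` helper is hypothetical, and this is not confirmed to be how Qwix computes these properties.

```python
def tile_shape(array_shape, param_shape):
    """Hypothetical helper: shape of the block covered by one parameter entry."""
    assert len(array_shape) == len(param_shape)
    assert all(o % s == 0 for o, s in zip(array_shape, param_shape))
    return tuple(o // s for o, s in zip(array_shape, param_shape))

# An (8, 256) array with an (8, 2) scale: each scale covers a 1 x 128 tile.
example = tile_shape((8, 256), (8, 2))
```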
- qwix.quantize(array: Array, qtype: str | type[Any] | dtype | SupportsDType, *, channelwise_axes: Collection[int] = (), tiled_axes: Mapping[int, int | float] | None = None, calibration_method: str = 'absmax', scale_dtype: str | type[Any] | dtype | SupportsDType | None = None) -> QArray
Quantizes a JAX Array into a QArray using a dynamic range.
This function exists as a public API so that callers do not need to construct a HowToQuantize.
- Parameters:
array – The array to quantize.
qtype – The logical type of the quantized value, e.g. jnp.int8, jnp.int4, jnp.float8_e4m3fn, “nf4”, etc.
channelwise_axes – Channelwise axes have individual scales. This has the same effect as setting their tile sizes to 1 in tiled_axes.
tiled_axes – Tiled axes have blockwise scales, aka subchannel quantization. The value is a mapping from the tiled axis to the tile size. If the tile size is a float, it will be interpreted as “1 / tile_count” and the actual tile size will be round(axis_size * tile_size).
calibration_method – The calibration method to use. The format is “<method>[,<args>]”, e.g. “absmax” or “fixed,-10,10”.
scale_dtype – The dtype of the scale. If not given, the array’s dtype is used. Note that the scale’s dtype determines the dequantized array’s dtype.
- Returns:
The quantized array.
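The parameters above can be pictured with a NumPy sketch of symmetric "absmax" quantization to int8. Both helper functions are hypothetical and written only for this example, not Qwix's implementation. The sketch assumes that absmax means scaling by the largest absolute value in each group, and it takes the float tile-size rule directly from the tiled_axes description.

```python
import numpy as np

def absmax_quantize_int8(x, channelwise_axes=()):
    """Illustrative symmetric int8 quantization with per-channel scales."""
    # Reduce over every axis that is NOT channelwise. keepdims=True keeps
    # the scale at the same rank as x, as the QArray conditions require.
    reduce_axes = tuple(a for a in range(x.ndim) if a not in channelwise_axes)
    absmax = np.max(np.abs(x), axis=reduce_axes, keepdims=True)
    # Guard against all-zero groups to avoid dividing by zero.
    scale = np.where(absmax == 0, 1.0, absmax / 127.0).astype(x.dtype)
    qvalue = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return qvalue, scale

def resolve_tile_size(axis_size, tile_size):
    """A float tile size is read as 1 / tile_count, per the tiled_axes text."""
    if isinstance(tile_size, float):
        return round(axis_size * tile_size)
    return tile_size

x = np.array([[1.0, -2.0, 4.0], [0.5, 0.25, -0.125]], dtype=np.float32)
qvalue, scale = absmax_quantize_int8(x, channelwise_axes=(0,))
```

With `channelwise_axes=(0,)`, each row gets its own scale, so the small values in the second row are not squeezed into a few quantization levels by the larger values in the first row.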
- qwix.dequantize(array: QArray) -> Array
Dequantizes an array. The reverse of quantize.
- Parameters:
array – The quantized array to dequantize.
- Returns:
The dequantized array, whose dtype is the same as the scale’s dtype.
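The statement that the result dtype follows the scale's dtype can be sketched in NumPy as follows. The `dequantize_sketch` function is a hypothetical illustration rather than the Qwix implementation, and it assumes the scale already broadcasts against qvalue under ordinary NumPy rules.

```python
import numpy as np

def dequantize_sketch(qvalue, scale, zero_point=None):
    """Illustrative dequantization: (qvalue - zero_point) * scale."""
    # Cast the stored integers to the scale's dtype first, so the result
    # dtype is determined by the scale.
    values = qvalue.astype(scale.dtype)
    if zero_point is not None:
        values = values - zero_point.astype(scale.dtype)
    return values * scale

qvalue = np.array([[-127, 0, 64]], dtype=np.int8)
out16 = dequantize_sketch(qvalue, np.array([[0.5]], dtype=np.float16))
out32 = dequantize_sketch(qvalue, np.array([[0.5]], dtype=np.float32))
```

The same stored integers yield a float16 array in the first call and a float32 array in the second, purely because of the scale's dtype.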