Data Structures

The QArray is the backbone of the Qwix library. It serves as the fundamental data structure that encapsulates quantized data (integers) alongside its quantization parameters (scale and zero-point).

By wrapping these elements together, QArray allows quantized tensors to be passed around and manipulated seamlessly, ensuring that the quantization context is preserved throughout the computation graph.

class qwix.QArray(qvalue: Array, scale: Array, zero_point: Array | None = None, qtype: str | type[Any] | dtype | SupportsDType = None)

A quantized array implementation with subchannel support.

The following conditions hold:
  • qvalue.shape == original.shape

  • len(scale.shape) == len(original.shape)

  • len(scale.shape) == len(zero_point.shape)

  • To enable subchannel quantization, scale and zero_point can be “generic broadcasted” to original.shape, which means

    all(o % s == 0 for o, s in zip(original.shape, scale.shape))

  • original ≈ (qvalue - zero_point) * generic_broadcast(scale, original.shape)
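
The conditions above can be illustrated with a small NumPy sketch. It is not Qwix's implementation: generic_broadcast and dequantize_sketch are hypothetical helpers written only to mirror the formulas in this list, and the sketch assumes that each scale entry covers one contiguous tile of the original array.

```python
import numpy as np


def generic_broadcast(x, shape):
    """Repeat each entry of x along every axis until the result has `shape`.

    Hypothetical helper: it assumes each entry of x covers one contiguous tile.
    """
    if x.ndim != len(shape):
        raise ValueError("x must have the same rank as the target shape")
    for axis, (target, size) in enumerate(zip(shape, x.shape)):
        if target % size != 0:
            raise ValueError(f"axis {axis}: {target} is not divisible by {size}")
        # Each entry is repeated target // size times, i.e. once per tile element.
        x = np.repeat(x, target // size, axis=axis)
    return x


def dequantize_sketch(qvalue, scale, zero_point=None):
    """original ≈ (qvalue - zero_point) * generic_broadcast(scale, shape)."""
    q = qvalue.astype(scale.dtype)
    if zero_point is not None:
        q = q - generic_broadcast(zero_point, qvalue.shape).astype(scale.dtype)
    return q * generic_broadcast(scale, qvalue.shape)


qvalue = np.array([[10, 20, -30, 40],
                   [-5, 15, 25, -35]], dtype=np.int8)
# scale has shape (1, 2), so every scale entry covers a tile of shape (2, 2).
scale = np.array([[0.5, 0.25]], dtype=np.float32)
restored = dequantize_sketch(qvalue, scale)
```

Here the first two columns are multiplied by 0.5 and the last two by 0.25, and zero_point is left as None to model symmetric quantization.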

qvalue

The quantized value.

Type:

jax.Array

scale

The scale used to quantize the value.

Type:

jax.Array

zero_point

The quantized value that represents the exact floating-point value 0, or None for symmetric quantization.

Type:

jax.Array | None

qtype

The logical type of the qvalue, which could be different from the dtype used for storage in qvalue. If None, the qvalue’s dtype will be used as the logical type.

Type:

str | type[Any] | numpy.dtype | jax._src.typing.SupportsDType

property T
astype(dtype: str | type[Any] | dtype | SupportsDType) -> QArray

Casts the dequantized dtype to the given dtype.

property dtype
property mT
property ndim
replace(**updates)

Returns a new object replacing the specified fields with new values.

reshape(*new_shape) -> QArray
property scale_tile_shape: tuple[int, ...]

Returns the tile shape for the scale values.

property shape
swapaxes(axis1: int, axis2: int) -> QArray
transpose(*args) -> QArray
property zero_point_tile_shape: tuple[int, ...] | None

Returns the tile shape for the zero point values.
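
As a rough illustration of what a tile shape means here, the sketch below derives one from the divisibility condition stated for QArray. The helper tile_shape is hypothetical, and the elementwise quotient is an assumption inferred from that condition, not a confirmed description of how these properties are computed.

```python
def tile_shape(full_shape, param_shape):
    """Shape of the block of values covered by a single scale or zero-point entry.

    Assumption: the tile shape is the elementwise quotient of the two shapes.
    """
    if len(full_shape) != len(param_shape):
        raise ValueError("both shapes must have the same rank")
    if any(o % s != 0 for o, s in zip(full_shape, param_shape)):
        raise ValueError("each axis size must be divisible by the parameter's axis size")
    return tuple(o // s for o, s in zip(full_shape, param_shape))


# One scale per output channel: each tile spans a whole column.
per_channel = tile_shape((512, 1024), (1, 1024))   # (512, 1)
# Four scales per column: subchannel quantization with 128-row tiles.
subchannel = tile_shape((512, 1024), (4, 1024))    # (128, 1)
```

Under the same assumption, a QArray without a zero_point has no zero-point tile shape, which is consistent with the `| None` in the property's type.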

qwix.quantize(array: Array, qtype: str | type[Any] | dtype | SupportsDType, *, channelwise_axes: Collection[int] = (), tiled_axes: Mapping[int, int | float] | None = None, calibration_method: str = 'absmax', scale_dtype: str | type[Any] | dtype | SupportsDType | None = None) -> QArray

Quantizes a JAX array into a QArray using a dynamic range.

This function is provided as a public API so that callers do not have to construct a HowToQuantize.

Parameters:
  • array – The array to quantize.

  • qtype – The logical type of the quantized value, e.g. jnp.int8, jnp.int4, jnp.float8_e4m3fn, “nf4”, etc.

  • channelwise_axes – Channelwise axes have individual scales. This has the same effect as setting their tile sizes to 1 in tiled_axes.

  • tiled_axes – Tiled axes have blockwise scales, also known as subchannel quantization. The value is a mapping from each tiled axis to its tile size. If the tile size is a float, it is interpreted as “1 / tile_count” and the actual tile size is round(axis_size * tile_size).

  • calibration_method – The calibration method to use. The format is “<method>[,<args>]”, e.g. “absmax” or “fixed,-10,10”.

  • scale_dtype – The dtype of the scale. If not given, the scale uses the same dtype as the array. Note that the scale’s dtype determines the dequantized array’s dtype.
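
The rules for channelwise_axes, tiled_axes and calibration_method can be sketched in plain Python. The helpers below are hypothetical and are not part of Qwix. They follow the parameter descriptions above, with two extra assumptions: an axis that is neither channelwise nor tiled shares a single scale, and calibration arguments are kept as raw strings.

```python
def resolve_tile_sizes(shape, channelwise_axes=(), tiled_axes=None):
    """Map each quantized axis to an integer tile size (hypothetical helper)."""
    # A channelwise axis behaves like a tiled axis whose tile size is 1.
    tile_sizes = {axis: 1 for axis in channelwise_axes}
    for axis, tile_size in (tiled_axes or {}).items():
        if isinstance(tile_size, float):
            # A float means "1 / tile_count", so scale it by the axis size.
            tile_size = round(shape[axis] * tile_size)
        tile_sizes[axis] = tile_size
    return tile_sizes


def expected_scale_shape(shape, tile_sizes):
    # Assumption: an axis with no entry is covered by one tile spanning the axis.
    return tuple(size // tile_sizes.get(axis, size) for axis, size in enumerate(shape))


def parse_calibration_method(spec):
    """Split "<method>[,<args>]" into the method name and its raw string arguments."""
    method, *args = spec.split(",")
    return method, args


shape = (256, 1024)
sizes = resolve_tile_sizes(shape, channelwise_axes=(1,), tiled_axes={0: 128})
scale_shape = expected_scale_shape(shape, sizes)          # (2, 1024)
quarter = resolve_tile_sizes(shape, tiled_axes={0: 0.25})  # {0: 64}
method, args = parse_calibration_method("fixed,-10,10")    # ("fixed", ["-10", "10"])
```

With a float tile size of 0.25, axis 0 is split into four tiles of 64 rows each, which matches the "1 / tile_count" reading in the tiled_axes description.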

Returns:

The quantized array.
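
For intuition about the default 'absmax' calibration, here is a minimal NumPy sketch of generic symmetric absmax quantization to int8. It shows the general technique only: the rounding mode, the choice of qmax = 127 and the handling of all-zero slices are assumptions, and Qwix's own implementation may differ in each of them.

```python
import numpy as np


def absmax_quantize_sketch(array, qmax=127, channelwise_axes=()):
    """Symmetric int8 quantization with one scale per channelwise slice.

    Hypothetical helper, not the Qwix implementation.
    """
    # Reduce over every axis that is not channelwise, keeping rank for broadcasting.
    reduce_axes = tuple(a for a in range(array.ndim) if a not in channelwise_axes)
    absmax = np.max(np.abs(array), axis=reduce_axes, keepdims=True)
    # Assumption: use a scale of 1 for all-zero slices to avoid dividing by zero.
    scale = np.where(absmax == 0, 1, absmax / qmax).astype(array.dtype)
    qvalue = np.clip(np.round(array / scale), -qmax, qmax).astype(np.int8)
    return qvalue, scale


array = np.array([[1.27, -0.5, 0.1],
                  [-2.54, 1.0, 0.0]], dtype=np.float32)
qvalue, scale = absmax_quantize_sketch(array, channelwise_axes=(0,))
# Dequantize with the symmetric form of the QArray formula (no zero_point).
restored = qvalue.astype(scale.dtype) * scale
```

Each row gets its own scale, so the largest magnitude in every row maps to ±127 and the rounding error of any element is at most half a scale step.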

qwix.dequantize(array: QArray) -> Array

Dequantizes an array. The reverse of quantize.

Parameters:

array – The quantized array to dequantize.

Returns:

The dequantized array, whose dtype is the same as the scale’s dtype.