Standard Providers
This section details the built-in quantization providers available in Qwix.
Quantized Training (QT)
The QtProvider implements Quantized Training (QT). This provider performs
quantization on both the forward and backward passes during training.
- class qwix.QtProvider(rules: Sequence[QuantizationRule])
Quantization provider for Quantized Training (QT).
- conv_general_dilated(lhs: Array, rhs: Array, window_strides: Sequence[int], padding: str | Sequence[tuple[int, int]], lhs_dilation: Sequence[int] | None = None, rhs_dilation: Sequence[int] | None = None, dimension_numbers: tuple[str, str, str] | ConvDimensionNumbers | None = None, feature_group_count: int = 1, batch_group_count: int = 1, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, out_sharding=None) → Array
QT conv_general_dilated.
- dot_general(lhs: Array, rhs: Array, dimension_numbers: tuple[tuple[Sequence[int], Sequence[int]], tuple[Sequence[int], Sequence[int]]], precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, *, out_sharding=None) → Array
QT dot_general.
- einsum(einsum_str: str, *operands: Array, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, _dot_general: Callable[..., Array] = <function dot_general>, out_sharding=None) → Array
QT einsum.
- ragged_dot(lhs: Array, rhs: Array, group_sizes: Array, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, group_offset: Array | None = None) → Array
QT ragged_dot.
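To give a feel for what a quantized dot_general computes, here is a minimal NumPy sketch. It is not Qwix's implementation. It assumes symmetric per-tensor int8 quantization with an absmax scale (absmax is the default weight_calibration_method of QtRule) and covers only the forward matmul. Qwix's QT also quantizes the backward pass and supports channelwise and subchannel scales, which this sketch omits. The helper names are hypothetical.

```python
import numpy as np


def absmax_quantize(x, num_bits=8):
    """Symmetric per-tensor quantization with a single absmax scale."""
    qmax = 2 ** (num_bits - 1) - 1  # 127 for int8
    scale = float(np.max(np.abs(x))) / qmax
    if scale == 0.0:
        scale = 1.0  # all-zero input: any scale works, avoid dividing by zero
    q = np.clip(np.round(x / scale), -qmax, qmax).astype(np.int8)
    return q, scale


def quantized_dot(lhs, rhs):
    """Multiply int8 operands, accumulate in int32, then rescale to float."""
    lhs_q, lhs_scale = absmax_quantize(lhs)
    rhs_q, rhs_scale = absmax_quantize(rhs)
    acc = lhs_q.astype(np.int32) @ rhs_q.astype(np.int32)
    return acc.astype(np.float32) * np.float32(lhs_scale * rhs_scale)


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 16)).astype(np.float32)
w = rng.standard_normal((16, 8)).astype(np.float32)
y_q = quantized_dot(x, w)
rel_err = np.linalg.norm(y_q - x @ w) / np.linalg.norm(x @ w)
print(f"relative error: {rel_err:.4f}")
```

Because each output element is a sum of integer products, both scales can be applied once at the end as lhs_scale * rhs_scale.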
- class qwix.QtRule(*, module_path: str = '.*', op_names: Collection[str] = (), weight_qtype: str | type[Any] | dtype | SupportsDType | None = None, act_qtype: str | type[Any] | dtype | SupportsDType | None = None, tile_size: int | float | None = None, act_static_scale: bool | None = None, weight_calibration_method: str = 'absmax', act_calibration_method: str | None = None, act_batch_axes: Collection[int] = (0,), bwd_qtype: str | type[Any] | dtype | SupportsDType | None = None, bwd_calibration_method: str = 'absmax', bwd_weight_grad_tile_size: int | float | None = None, disable_channelwise_axes: bool = False, bwd_stochastic_rounding: str | None = None, channelwise_noise_axes: Sequence[int] = (0,), additional_qt_config: Mapping[str, Any] | None = None)
QuantizationRule with all settings specific to Quantized Training (QT).
- additional_qt_config: Mapping[str, Any] | None = None
- bwd_calibration_method: str = 'absmax'
- bwd_qtype: str | type[Any] | dtype | SupportsDType | None = None
- bwd_stochastic_rounding: str | None = None
- bwd_weight_grad_tile_size: int | float | None = None
- channelwise_noise_axes: Sequence[int] = (0,)
- disable_channelwise_axes: bool = False
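QtRule exposes bwd_stochastic_rounding and channelwise_noise_axes for the backward pass. The NumPy sketch below shows the general idea of stochastic rounding: a value rounds up with probability equal to its fractional part, so the rounded result is unbiased. It assumes uniform noise. It also assumes, without confirmation from this page, that channelwise_noise_axes names the axes along which independent noise is drawn, with that noise broadcast over the remaining axes. stochastic_round is a hypothetical helper, not a Qwix function.

```python
import numpy as np


def stochastic_round(x, rng, noise_axes=(0,)):
    """Round each value down or up at random so that E[result] == x.

    floor(x + u) with u ~ Uniform[0, 1) rounds up with probability equal to
    the fractional part of x. Noise is drawn only along `noise_axes` and
    broadcast over the other axes (an assumed reading of
    channelwise_noise_axes).
    """
    noise_shape = [x.shape[i] if i in noise_axes else 1 for i in range(x.ndim)]
    noise = rng.uniform(0.0, 1.0, size=noise_shape)
    return np.floor(x + noise)


rng = np.random.default_rng(0)
grad = np.array([[0.25, -1.5, 2.75]])
samples = np.stack([stochastic_round(grad, rng) for _ in range(20000)])
print(samples.mean(axis=0))  # close to grad on average

# With noise_axes=(0,), every element in a row shares the same noise draw.
rows = stochastic_round(np.full((3, 5), 0.5), rng)
```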
Post-Training Quantization (PTQ)
These APIs handle Post-Training Quantization, which quantizes a pre-trained model without requiring a full retraining loop.
- class qwix.PtqProvider(rules: Sequence[QuantizationRule], *, _qarray_module=<module 'qwix._src.core.qarray'>, _dot_general_fn=<function dot_general>, _einsum_fn=<function einsum>, _conv_general_dilated_fn=<function conv_general_dilated>)
Quantization provider for PTQ.
In PTQ mode, weights need to be pre-quantized. However, Qwix does not know how to quantize them until the actual ops are called. To solve this, we still initialize the original weights when the model is initialized, but replace them with the quantized weights when the ops are called.
This should be invisible to users in Flax Linen because module.init calls both the setup() and __call__() methods.
If memory usage is a concern, wrapping module.init with jit or eval_shape should avoid materializing the original weights.
NNX can use the same trick, so Qwix does not need to intercept nnx.Param.
This approach also allows the original floating-point weights to be supplied during apply, and they will still be quantized correctly. It can therefore serve as an alternative to quantize_params when partial param quantization is not needed.
- conv_general_dilated(lhs: Array, rhs: Array | WithAux[QArray], window_strides: Sequence[int], padding: str | Sequence[tuple[int, int]], lhs_dilation: Sequence[int] | None = None, rhs_dilation: Sequence[int] | None = None, dimension_numbers: tuple[str, str, str] | ConvDimensionNumbers | None = None, feature_group_count: int = 1, batch_group_count: int = 1, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, out_sharding=None) → Array
- dot(a: Array, b: Array | WithAux[QArray], precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, out_sharding=None)
Intercepts jax.numpy.dot.
- dot_general(lhs: Array, rhs: Array | WithAux[QArray], dimension_numbers: tuple[tuple[Sequence[int], Sequence[int]], tuple[Sequence[int], Sequence[int]]], precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, *, out_sharding: NamedSharding | None = None) → Array
- einsum(einsum_str: str, *operands: Array, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, _dot_general: Callable[..., Array] = <function dot_general>, out_sharding=None) → Array
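PtqProvider's lazy weight replacement can be illustrated with a small NumPy toy. ToyPtqDense is hypothetical and not part of Qwix. It assumes per-output-channel absmax int8 quantization and stores the quantized weight as an (int8 values, scale) tuple. In Qwix the replacement happens inside the intercepted ops rather than in a layer class.

```python
import numpy as np


class ToyPtqDense:
    """Hypothetical dense layer that quantizes its float kernel on first use."""

    def __init__(self, kernel):
        # Initialization still materializes the original float weight.
        self.params = {"kernel": kernel}

    def __call__(self, x):
        kernel = self.params["kernel"]
        if not isinstance(kernel, tuple):
            # First call: only at op time is it known how to quantize the
            # weight, so quantize it (one scale per output column) and
            # replace the float kernel with the quantized one.
            scale = np.max(np.abs(kernel), axis=0, keepdims=True) / 127.0
            scale = np.where(scale == 0, 1.0, scale)
            q = np.clip(np.round(kernel / scale), -127, 127).astype(np.int8)
            kernel = (q, scale)
            self.params["kernel"] = kernel
        q, scale = kernel
        # x @ (q * scale) == (x @ q) * scale because the scale is per column.
        return (x @ q.astype(np.float32)) * scale


rng = np.random.default_rng(0)
w = rng.standard_normal((8, 4)).astype(np.float32)
layer = ToyPtqDense(w)
x = rng.standard_normal((2, 8)).astype(np.float32)
y = layer(x)  # quantizes and replaces the float kernel
y_again = layer(x)  # reuses the already quantized kernel
```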
- qwix.quantize_params(params: Any, abstract_quantized_params: Any, quant_stats: Any = FrozenDict({}), *, allow_extra_params: bool = False, _qarray_module=<module 'qwix._src.core.qarray'>) → Any
Quantize the param tree for PTQ.
This function quantizes the param tree (weights) for PTQ. It does not need to run the model, which makes it useful when the original params are too large to fit in HBM.
- Parameters:
params – The floating-point param tree to quantize, usually generated by the original or QAT model. The tree does not need to be complete and can be a subtree of the whole param tree. In Linen (nn), the tree needs to be unboxed, i.e. with nn.unbox(). In NNX, the tree needs to be a pure dict, i.e. from nnx.to_pure_dict().
abstract_quantized_params – The param tree generated by the PTQ model, which can be abstract, with jax.ShapeDtypeStruct leaves instead of jax.Array. It carries the information on how to quantize each param. In Linen (nn), the tree may contain AxisMetadata. In NNX, this should be the PTQ model itself, possibly abstract.
quant_stats – The quantization statistics, which must be a pure dict of unboxed values. This is only used in static-range quantization (SRQ).
allow_extra_params – If True, allow the params to contain extra parameters that are not present in the abstract_quantized_params, e.g., params for loss computation that are not needed in PTQ.
_qarray_module – The qarray module to use. Useful for extending.
- Returns:
The quantized param tree, which has the same structure as the input params but with quantized leaves.
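quantize_params works offline and never runs the model. That property can be sketched on plain nested dicts. toy_quantize_params is a hypothetical stand-in. Its abstract tree does not hold jax.ShapeDtypeStruct or quantized-array leaves. Instead it uses the string "int8" to mark params to quantize and None for params kept in floating point. The sketch mirrors two documented behaviors: the input may be a subtree, and allow_extra_params tolerates params absent from the abstract tree. Passing such extras through unchanged is an assumption.

```python
import numpy as np


def toy_quantize_params(params, abstract, allow_extra_params=False):
    """Hypothetical sketch: quantize a (sub)tree of float params offline.

    `abstract` mirrors the quantized model's tree: a leaf of "int8" means
    "quantize this param", None means "keep it as is". No model is run.
    """
    out = {}
    for name, value in params.items():
        if name not in abstract:
            if not allow_extra_params:
                raise KeyError(f"param {name!r} not found in abstract tree")
            out[name] = value  # assumed: extra params pass through unchanged
            continue
        spec = abstract[name]
        if isinstance(value, dict):
            out[name] = toy_quantize_params(value, spec, allow_extra_params)
        elif spec == "int8":
            scale = max(float(np.max(np.abs(value))) / 127.0, 1e-12)
            qvalue = np.clip(np.round(value / scale), -127, 127).astype(np.int8)
            out[name] = {"qvalue": qvalue, "scale": scale}
        else:
            out[name] = value
    return out


params = {
    "dense": {"kernel": np.array([[0.5, -1.0], [0.25, 2.0]]), "bias": np.zeros(2)},
    "loss_head": {"w": np.ones(3)},  # not part of the quantized model
}
abstract = {"dense": {"kernel": "int8", "bias": None}}
qparams = toy_quantize_params(params, abstract, allow_extra_params=True)
```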
On-Device ML (ODML)
These APIs are specialized for converting JAX models to run on edge devices via LiteRT (formerly TensorFlow Lite). They handle specific constraints required by mobile hardware.
- class qwix.OdmlConversionProvider(rules: Sequence[QuantizationRule], params, quant_stats, **kwargs)
Quantization provider for ODML conversion.
This mode is similar to OdmlQatProvider, but all fake_quant ops are annotated by composites and the scales are computed statically in numpy.
Supported modes:
- Weight-only quantization.
- Static-range quantization.
Usage:
    import ai_edge_jax
    import qwix

    # fp_model, rules, and inputs are defined elsewhere.
    # The params can be from QAT or the FP model.
    params = ...
    # If using static-range quantization, quant_stats are needed and can be
    # obtained by either 1) QAT or 2) calibrating.
    quant_stats = ...
    # Apply OdmlConversionProvider to the model.
    conversion_model = qwix.quantize_model(
        fp_model, qwix.OdmlConversionProvider(rules, params, quant_stats)
    )
    # Convert and get the ODML model, which is an ai_edge_jax.model.TfLiteModel.
    odml_model = ai_edge_jax.convert(
        conversion_model.apply, {'params': params}, (inputs,)
    )
    # The odml_model can be exported or directly run.
    odml_model.export('/tmp/odml_model.tflite')
    odml_model(inputs)
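OdmlConversionProvider computes scales statically in NumPy, and static-range quantization relies on quant_stats gathered by QAT or calibration. The sketch below shows that idea in isolation. A scale is computed once from a calibrated absmax and then used in a fake_quant (quantize, then dequantize). The stats dict layout and the helper names are hypothetical, and symmetric int8 quantization is an assumption. The actual quant_stats format and the composite annotation are not shown.

```python
import numpy as np


def static_scale(stats, num_bits=8):
    """Compute a symmetric scale once, offline, from calibrated statistics.

    `stats` is a hypothetical dict holding the absmax observed for one
    activation during calibration or QAT.
    """
    qmax = 2 ** (num_bits - 1) - 1
    return max(float(stats["absmax"]), 1e-12) / qmax


def fake_quant(x, scale, num_bits=8):
    """Quantize then dequantize: float output restricted to the int grid."""
    qmax = 2 ** (num_bits - 1) - 1
    return np.clip(np.round(x / scale), -qmax, qmax) * scale


scale = static_scale({"absmax": 4.0})  # a constant baked into the model
x = np.array([0.1, -2.1, 5.0])
y = fake_quant(x, scale)  # 5.0 lies outside the calibrated range and is clipped
```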
- class qwix.OdmlQatProvider(rules: Sequence[QuantizationRule], *, disable_per_channel_weights: bool = False, fixed_range_for_inputs: tuple[float, float] | None = None, fixed_range_for_outputs: tuple[float, float] | None = None, strict: bool = True)
QAT provider for ODML.
Compared with the regular QAT provider, this provider:
- Quantizes more ops than just conv, einsum, and dot_general.
- Quantizes output activations via a delayed fake_quant.
- Supports limited per-channel quantization for weights.
- Does not support subchannel quantization.
- nn_param(module: Module, name: str, init_fn: Callable[..., Any], *init_args, unbox: bool = True, **init_kwargs) → Array | AxisMetadata[Array]
Intercepts nn.Module.param to associate weight_name aux_data.
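OdmlQatProvider supports limited per-channel weight quantization, which disable_per_channel_weights turns off. The NumPy sketch below is not Qwix's code. It shows why per-channel scales usually help. When output channels have very different magnitudes, a single per-tensor scale wastes most of the int8 range on the small channels. The sketch assumes symmetric absmax scales and does not model the limits the provider places on per-channel support.

```python
import numpy as np


def fake_quant_weights(w, per_channel, num_bits=8):
    """Symmetric absmax fake quantization of an (in_dim, out_dim) weight.

    per_channel=True: one scale per output channel (reduce over axis 0).
    per_channel=False: a single scale for the whole tensor.
    """
    qmax = 2 ** (num_bits - 1) - 1
    absmax = np.max(np.abs(w), axis=0 if per_channel else None, keepdims=True)
    scale = np.where(absmax == 0, 1.0, absmax / qmax)
    return np.clip(np.round(w / scale), -qmax, qmax) * scale


rng = np.random.default_rng(0)
# Four output channels whose magnitudes differ by up to 1000x.
w = rng.standard_normal((64, 4)) * np.array([0.01, 0.1, 1.0, 10.0])
err_tensor = np.mean((fake_quant_weights(w, per_channel=False) - w) ** 2)
err_channel = np.mean((fake_quant_weights(w, per_channel=True) - w) ** 2)
print(f"per-tensor MSE {err_tensor:.2e}, per-channel MSE {err_channel:.2e}")
```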
LoRA Quantization
These APIs combine Low-Rank Adaptation (LoRA) with quantization, allowing for memory-efficient fine-tuning of large models.
- class qwix.LoraProvider(rules=None, **kwargs)
Provider for (Q)LoRA.
LoraProvider inherits from PtqProvider, because the base model is frozen during LoRA training.
- conv_general_dilated(lhs: Array, rhs: Array | WithAux[QArray], window_strides: Sequence[int], padding: str | Sequence[tuple[int, int]], lhs_dilation: Sequence[int] | None = None, rhs_dilation: Sequence[int] | None = None, dimension_numbers: tuple[str, str, str] | ConvDimensionNumbers | None = None, feature_group_count: int = 1, batch_group_count: int = 1, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, out_sharding=None) → Array
LoRA conv_general_dilated.
- dot_general(lhs: Array, rhs: Array | WithAux[QArray], dimension_numbers: tuple[tuple[Sequence[int], Sequence[int]], tuple[Sequence[int], Sequence[int]]], precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, out_sharding: NamedSharding | None = None) → Array
LoRA dot_general.
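A LoRA-wrapped dot_general can be sketched in NumPy. The frozen base weight is combined with a trainable low-rank update built from lora_a and lora_b; in QLoRA the base weight would be the dequantized PTQ weight. The alpha / rank scaling follows the original LoRA paper and is an assumption about Qwix. Dropout is omitted, and lora_dense is a hypothetical helper.

```python
import numpy as np


def lora_dense(x, w_base, lora_a, lora_b, rank, alpha):
    """Frozen base matmul plus a scaled low-rank update (dropout omitted).

    Shapes: x (batch, in_dim), w_base (in_dim, out_dim),
    lora_a (in_dim, rank), lora_b (rank, out_dim).
    """
    base = x @ w_base  # frozen; not updated during LoRA training
    delta = (x @ lora_a) @ lora_b  # passes through a rank-sized bottleneck
    return base + (alpha / rank) * delta


rng = np.random.default_rng(0)
in_dim, out_dim, rank, alpha = 16, 8, 4, 8.0
x = rng.standard_normal((2, in_dim))
w_base = rng.standard_normal((in_dim, out_dim))
# Simple stand-in for LoraRule's variance-scaling default for lora_a.
lora_a = rng.standard_normal((in_dim, rank)) / np.sqrt(in_dim)
# Zero-initialized lora_b (LoraRule's default initializer is zeros), so the
# adapter starts as a no-op and the model initially matches the base model.
lora_b = np.zeros((rank, out_dim))
y0 = lora_dense(x, w_base, lora_a, lora_b, rank, alpha)
y1 = lora_dense(x, w_base, lora_a, lora_b + 0.1, rank, alpha)
```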
- class qwix.LoraRule(*, module_path: str = '.*', op_names: Collection[str] = (), weight_qtype: str | type[Any] | dtype | SupportsDType | None = None, act_qtype: str | type[Any] | dtype | SupportsDType | None = None, tile_size: int | float | None = None, act_static_scale: bool | None = None, weight_calibration_method: str = 'absmax', act_calibration_method: str | None = None, act_batch_axes: Collection[int] = (0,), rank: int, alpha: float, dropout: float = 0.0, lora_a_initializer: Callable[..., Array] = <function variance_scaling.<locals>.init>, lora_b_initializer: Callable[..., Array] = <function zeros>)
LoRA rules that match and configure the LoRA behavior.
- alpha: float
- dropout: float = 0.0
- lora_a_initializer(shape: Sequence[int | Any], dtype: Any | None = None, out_sharding: NamedSharding | PartitionSpec | None = None) → Array
- lora_b_initializer(shape: Sequence[int | Any], dtype: Any | None = None, out_sharding: NamedSharding | PartitionSpec | None = None) → Array
An initializer that returns a constant array full of zeros.
The key argument is ignored.
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
- rank: int
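The rank field sets the size of the low-rank bottleneck. The count below assumes lora_a has shape (in_dim, rank) and lora_b has shape (rank, out_dim), as in the original LoRA formulation. Under that assumption, the trainable parameter count for one dense weight works out as follows (lora_param_counts is a hypothetical helper):

```python
def lora_param_counts(in_dim, out_dim, rank):
    """Compare trainable params of a full update with a rank-`rank` LoRA update."""
    full = in_dim * out_dim  # fine-tuning the whole weight
    lora = rank * (in_dim + out_dim)  # lora_a: in_dim*rank, lora_b: rank*out_dim
    return full, lora


full, lora = lora_param_counts(4096, 4096, rank=8)
print(full, lora, lora / full)  # 16777216 65536 0.00390625
```

With rank 8 on a 4096 x 4096 weight, the adapter has 1/256 of the parameters of a full update. This reduction is where LoRA's memory savings during fine-tuning come from.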
- qwix.apply_lora_to_model(model: ModelType, provider: QuantizationProvider, *model_inputs: Any, methods: Collection[str] = ('__call__',), **model_inputs_kwargs: Any) → ModelType
Applies LoRA to a model.