Core Operations
This module provides the low-level mathematical operations for Qwix.
These functions are designed to be quantized equivalents of standard JAX
operations. For example, qwix.dot_general mirrors jax.lax.dot_general,
but is specifically engineered to accept QArray inputs. It performs the
underlying quantized arithmetic (such as integer matrix multiplication) and
correctly handles the propagation of scales and zero-points.
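The sketch below shows the kind of arithmetic described above using only JAX. Both operands are quantized to int8 with symmetric scales, which means the zero-point is 0. They are multiplied with int32 accumulation, and the scales are then applied to the integer result. The helpers `quantize_int8` and `quantized_matmul` are hypothetical names used for illustration. This is not Qwix's internal implementation. Asymmetric quantization with nonzero zero-points would need extra correction terms, which are not shown here.

```python
import jax
import jax.numpy as jnp


def quantize_int8(x, axis):
    """Symmetric int8 quantization with one scale per slice reduced over `axis`.

    Hypothetical helper for illustration; not part of the Qwix API.
    """
    # Largest magnitude per slice; keepdims=True so the scale broadcasts back.
    absmax = jnp.max(jnp.abs(x), axis=axis, keepdims=True)
    # Avoid dividing by zero for all-zero slices.
    scale = jnp.where(absmax == 0, 1.0, absmax / 127.0)
    q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
    return q, scale


def quantized_matmul(x, w):
    """Approximates x @ w with int8 operands and int32 accumulation."""
    qx, sx = quantize_int8(x, axis=1)  # qx: [m, k] int8, sx: [m, 1]
    qw, sw = quantize_int8(w, axis=0)  # qw: [k, n] int8, sw: [1, n]
    # Integer matmul; int32 accumulation avoids int8 overflow.
    acc = jax.lax.dot_general(
        qx, qw,
        dimension_numbers=(((1,), (0,)), ((), ())),
        preferred_element_type=jnp.int32,
    )
    # sx and sw are constant along the contracting axis k, so they factor out
    # of the sum: out[i, j] ~= sx[i] * sw[j] * acc[i, j].
    return acc.astype(jnp.float32) * sx * sw


x = jax.random.normal(jax.random.PRNGKey(0), (4, 64))
w = jax.random.normal(jax.random.PRNGKey(1), (64, 8))
approx = quantized_matmul(x, w)
max_error = jnp.max(jnp.abs(approx - x @ w))
```

The scales can only be applied after the integer matmul because each scale is constant along the contracting axis. This property is what "propagation of scales" relies on.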
- qwix.conv_general_dilated(lhs: Array | QArray, rhs: Array | QArray, window_strides: Sequence[int], padding: str | Sequence[tuple[int, int]], lhs_dilation: Sequence[int] | None = None, rhs_dilation: Sequence[int] | None = None, dimension_numbers: tuple[str, str, str] | ConvDimensionNumbers | None = None, feature_group_count: int = 1, batch_group_count: int = 1, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, out_sharding=None) -> Array
Computes a general dilated convolution with support for QArray inputs. This function serves as a drop-in replacement for jax.lax.conv_general_dilated.
It automatically dispatches to a quantized implementation if the inputs are compatible QArrays. Otherwise, it dequantizes the inputs and falls back to the standard floating-point JAX implementation.
- Parameters:
lhs – The left-hand side, either a jax.Array or QArray.
rhs – The right-hand side, either a jax.Array or QArray.
window_strides – A sequence of integers specifying the stride of the convolution window.
padding – The padding algorithm (e.g., ‘SAME’, ‘VALID’) or explicit padding amounts.
lhs_dilation – Dilation factors for the input (lhs) spatial dimensions.
rhs_dilation – Dilation factors for the kernel (rhs) spatial dimensions.
dimension_numbers – A structure specifying the dimension layout for lhs, rhs, and the output.
feature_group_count – The number of feature groups for grouped convolution.
batch_group_count – The number of batch groups.
precision – The numerical precision configuration for the computation.
preferred_element_type – The target data type for accumulation.
out_sharding – Optional sharding spec for the output array.
- Returns:
An Array containing the convolution result.
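qwix.conv_general_dilated is documented as a drop-in replacement for jax.lax.conv_general_dilated. The sketch below therefore uses the JAX function to show how `window_strides`, `padding`, `rhs_dilation`, and `dimension_numbers` fit together. It also checks the linearity property that lets a per-tensor input scale and a per-output-channel kernel scale be applied after the convolution instead of before it. The scale values and their granularities are illustrative assumptions. This page does not say which scale granularities Qwix supports for convolution.

```python
import functools

import jax
import jax.numpy as jnp

# NHWC activations and HWIO kernel, described with string dimension_numbers.
conv = functools.partial(
    jax.lax.conv_general_dilated,
    window_strides=(1, 1),
    padding="SAME",
    rhs_dilation=(2, 2),  # a 3x3 kernel covering a 5x5 window
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)

x = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 8, 3))   # N, H, W, C
w = jax.random.normal(jax.random.PRNGKey(1), (3, 3, 3, 16))  # H, W, I, O
y = conv(x, w)  # "SAME" padding with stride 1 keeps H and W unchanged

# Linearity: scales on the inputs can instead be applied to the output.
sx = 0.5                          # illustrative per-tensor input scale
sw = jnp.linspace(0.1, 1.6, 16)   # illustrative scale per output channel (O)
scaled_before = conv(x * sx, w * sw)  # sw broadcasts over the kernel's O axis
scaled_after = conv(x, w) * (sx * sw)  # the output's last axis is also O
```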
- qwix.dot(a: Array | QArray, b: Array | QArray, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, out_sharding=None, *, _qwix_dot_general: Callable[..., Array] = qwix.dot_general)
Computes the dot product with support for QArray inputs. This function serves as a drop-in replacement for jax.numpy.dot.
It automatically dispatches to a quantized implementation if the inputs are compatible QArrays. It also supports subchannel quantization where applicable.
- Parameters:
a – The left-hand side, either a jax.Array or QArray.
b – The right-hand side, either a jax.Array or QArray.
precision – The numerical precision configuration for the computation.
preferred_element_type – The element type for the accumulation. For quantized inputs, this controls the accumulation precision.
out_sharding – Optional sharding spec for the output array.
_qwix_dot_general – Internal argument for dependency injection of the underlying dot_general implementation. Defaults to qwix.dot_general.
- Returns:
The dot product of a and b.
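This page does not define subchannel quantization. A common interpretation, assumed here, is blockwise quantization along the contracting axis. Each block of `block_size` elements gets its own scale. The integer partial products are computed block by block, and the scaled partial results are summed in floating point. `subchannel_quantized_dot` is a hypothetical helper written in plain JAX. It is not Qwix's implementation.

```python
import jax
import jax.numpy as jnp


def subchannel_quantized_dot(a, b, block_size):
    """Approximates a @ b with one int8 scale per block of the contracting axis.

    Hypothetical helper for illustration; not Qwix's implementation.
    Assumes a.shape[1] (== b.shape[0]) is divisible by block_size.
    """
    m, k = a.shape
    n = b.shape[1]
    num_blocks = k // block_size
    # Split the contracting axis k into (num_blocks, block_size).
    a_blk = a.reshape(m, num_blocks, block_size)
    b_blk = b.reshape(num_blocks, block_size, n)
    # One scale per (row, block) of a and per (block, column) of b.
    sa = jnp.maximum(jnp.max(jnp.abs(a_blk), axis=2, keepdims=True), 1e-8) / 127.0
    sb = jnp.maximum(jnp.max(jnp.abs(b_blk), axis=1, keepdims=True), 1e-8) / 127.0
    qa = jnp.clip(jnp.round(a_blk / sa), -127, 127).astype(jnp.int8)  # [m, nb, bs]
    qb = jnp.clip(jnp.round(b_blk / sb), -127, 127).astype(jnp.int8)  # [nb, bs, n]
    # Integer partial products per block, accumulated in int32: [nb, m, n].
    partial = jnp.einsum("mgk,gkn->gmn", qa, qb, preferred_element_type=jnp.int32)
    # Per-block scale for each output element: [nb, m, 1] * [nb, 1, n].
    scale = jnp.transpose(sa, (1, 0, 2)) * sb
    # Scale each block's partial result, then sum over blocks in float32.
    return jnp.sum(partial.astype(jnp.float32) * scale, axis=0)


a = jax.random.normal(jax.random.PRNGKey(0), (4, 128))
b = jax.random.normal(jax.random.PRNGKey(1), (128, 8))
approx = subchannel_quantized_dot(a, b, block_size=32)
max_error = jnp.max(jnp.abs(approx - a @ b))
```

In per-channel quantization, each scale is constant over the whole contracting axis. Here the scales vary along that axis, so they cannot be factored out of the full sum. Instead, each block's partial product is scaled separately. Smaller blocks track local magnitudes more closely, at the cost of storing more scales.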
- qwix.dot_general(lhs: Array | QArray, rhs: Array | QArray, dimension_numbers: tuple[tuple[Sequence[int], Sequence[int]], tuple[Sequence[int], Sequence[int]]], precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, **kwargs) -> Array
Computes a general dot product with support for QArray inputs. This function serves as a drop-in replacement for jax.lax.dot_general.
- Parameters:
lhs – The left-hand side, either a jax.Array or QArray.
rhs – The right-hand side, either a jax.Array or QArray.
dimension_numbers – A tuple ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)), in the same format as jax.lax.dot_general.
precision – The precision for jax.lax.dot_general.
preferred_element_type – The preferred element type for jax.lax.dot_general.
**kwargs – Additional keyword arguments to dot_general.
- Returns:
A floating-point jax.Array.
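qwix.dot_general takes the same `dimension_numbers` as jax.lax.dot_general, so a batched matrix multiplication can be written as shown below. The example calls jax.lax.dot_general on plain arrays. Based on the drop-in description above, passing the same arguments to qwix.dot_general is expected to work the same way.

```python
import jax
import jax.numpy as jnp

lhs = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 4))  # [batch, m, k]
rhs = jax.random.normal(jax.random.PRNGKey(1), (2, 4, 5))  # [batch, k, n]

# ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
dimension_numbers = (((2,), (1,)), ((0,), (0,)))

out = jax.lax.dot_general(lhs, rhs, dimension_numbers)
# Output layout: batch dims, then lhs free dims, then rhs free dims -> (2, 3, 5).
reference = jnp.matmul(lhs, rhs)
```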
- qwix.einsum(*args, _qwix_dot_general=qwix.dot_general, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, **kwargs) -> Array
Computes an Einstein summation with support for QArray inputs. This function serves as a drop-in replacement for jax.numpy.einsum.
- Parameters:
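*args – Positional arguments to einsum, as accepted by jax.numpy.einsum (for example, the subscripts string followed by the operands).
_qwix_dot_general – Internal argument for dependency injection of the dot_general implementation to use. Defaults to qwix.dot_general.
preferred_element_type – The preferred element type for jax.lax.dot_general.
**kwargs – Additional keyword arguments passed to einsum.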
- Returns:
The result of the einsum, a floating-point jax.Array.
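The `preferred_element_type` argument matters most for integer inputs, because accumulating in the input type can overflow. The example below uses plain jax.numpy.einsum on hand-picked int8 values, not values produced by Qwix. It accumulates the int8 operands in int32. According to the signature above, qwix.einsum accepts the same `preferred_element_type` argument.

```python
import jax.numpy as jnp

# Hand-picked int8 values standing in for quantized operands.
qa = jnp.array([[100, 100], [-100, 50]], dtype=jnp.int8)
qb = jnp.array([[100, 1], [100, -1]], dtype=jnp.int8)

# 100 * 100 + 100 * 100 = 20000 does not fit in int8, so accumulate in int32.
out = jnp.einsum("ik,kj->ij", qa, qb, preferred_element_type=jnp.int32)
# out holds [[20000, 0], [-5000, -150]] with dtype int32
```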
- qwix.ragged_dot(lhs: Array | QArray, rhs: Array | QArray, group_sizes: Array, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, group_offset: Array | None = None) -> Array
Computes a ragged dot product with support for QArray inputs. This function serves as a drop-in replacement for jax.lax.ragged_dot.
It is a convenience wrapper around ragged_dot_general with standard matrix-multiplication dimension numbers.
- Parameters:
lhs – The left-hand side, either a jax.Array or QArray.
rhs – The right-hand side, either a jax.Array or QArray.
group_sizes – An array of integers specifying the size of each group in the ragged dimension.
precision – The numerical precision configuration for the computation.
preferred_element_type – The target data type for accumulation.
group_offset – Optional starting offset for the groups.
- Returns:
An Array containing the result of the ragged dot product.
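To clarify what the ragged dimension means, the sketch below gives reference semantics under the layout we understand jax.lax.ragged_dot to use. `lhs` has shape [m, k], `rhs` has shape [g, k, n], and `group_sizes` has shape [g] and sums to m. Consecutive rows of `lhs` form the groups, and the rows in group i are multiplied by `rhs[i]`. `ragged_dot_reference` is a hypothetical helper. It computes every group's product and then masks, so it is much less efficient than a real ragged kernel. It also ignores `group_offset`.

```python
import jax
import jax.numpy as jnp


def ragged_dot_reference(lhs, rhs, group_sizes):
    """Hypothetical reference for a ragged dot; lhs [m, k], rhs [g, k, n]."""
    ends = jnp.cumsum(group_sizes)          # [g] exclusive end row of each group
    starts = ends - group_sizes             # [g] first row of each group
    rows = jnp.arange(lhs.shape[0])         # [m]
    # mask[i, r] is True when row r belongs to group i.
    mask = (rows[None, :] >= starts[:, None]) & (rows[None, :] < ends[:, None])
    # Multiply every row by every group's matrix: [g, m, n].
    all_products = jnp.einsum("mk,gkn->gmn", lhs, rhs)
    # Keep each row's own group only, then collapse the group axis.
    return jnp.sum(jnp.where(mask[:, :, None], all_products, 0.0), axis=0)


lhs = jax.random.normal(jax.random.PRNGKey(0), (6, 4))
rhs = jax.random.normal(jax.random.PRNGKey(1), (3, 4, 5))
group_sizes = jnp.array([2, 0, 4], dtype=jnp.int32)  # group 1 is empty
result = ragged_dot_reference(lhs, rhs, group_sizes)
```

For inputs like these, jax.lax.ragged_dot(lhs, rhs, group_sizes) in recent JAX versions should produce the same [m, n] result. Note that in this sketch, rows past the last group are left as zeros.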
- qwix.ragged_dot_general(lhs: Array | QArray, rhs: Array | QArray, group_sizes: Array, dimension_numbers: RaggedDotDimensionNumbers, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, group_offset: Array | None = None) -> Array
Computes a general ragged dot product with support for QArray inputs. This function serves as a drop-in replacement for jax.lax.ragged_dot_general.
It automatically dispatches to a quantized implementation if the inputs are compatible QArrays. Otherwise, it dequantizes the inputs and falls back to the standard floating-point JAX implementation.
- Parameters:
lhs – The left-hand side, either a jax.Array or QArray.
rhs – The right-hand side, either a jax.Array or QArray.
group_sizes – An array of integers specifying the size of each group in the ragged dimension.
dimension_numbers – A jax.lax.RaggedDotDimensionNumbers struct specifying the contracting, batch, and ragged dimensions.
precision – The numerical precision configuration for the computation.
preferred_element_type – The target data type for accumulation.
group_offset – Optional starting offset for the groups.
- Returns:
An Array containing the result of the ragged dot product.