Core Operations
This module provides the low-level mathematical operations for Qwix.
These functions are quantized counterparts of standard JAX operations. For
example, qwix.dot_general mirrors jax.lax.dot_general but also accepts QArray
inputs. It performs the underlying quantized arithmetic (such as integer
matrix multiplication) and propagates the scales and zero-points so that the
result is a floating-point jax.Array.
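The scale and zero-point bookkeeping can be made concrete with a minimal sketch in plain JAX. This is not Qwix's implementation. It assumes per-tensor asymmetric quantization for the left-hand side and per-output-channel symmetric quantization for the right-hand side. The helpers quantize_asymmetric, quantize_symmetric_per_channel and quantized_matmul are hypothetical. The product is computed with integer arithmetic, corrected for the left-hand zero-point, and rescaled to floating point at the end.

```python
import jax
import jax.numpy as jnp


def quantize_asymmetric(x, num_bits=8):
    """Hypothetical helper. Per-tensor asymmetric quantization: x ≈ scale * (q - zero_point)."""
    qmin, qmax = 0, 2**num_bits - 1
    lo = jnp.minimum(x.min(), 0.0)  # include 0 in the range so 0 is exactly representable
    hi = jnp.maximum(x.max(), 0.0)
    scale = jnp.maximum(hi - lo, 1e-12) / (qmax - qmin)
    zero_point = jnp.round(qmin - lo / scale).astype(jnp.int32)
    # Values lie in [0, 255]; int32 storage keeps the sketch simple.
    q = jnp.clip(jnp.round(x / scale) + zero_point, qmin, qmax).astype(jnp.int32)
    return q, scale, zero_point


def quantize_symmetric_per_channel(w, num_bits=8):
    """Hypothetical helper. One scale per output column: w ≈ scale * q."""
    qmax = 2 ** (num_bits - 1) - 1
    scale = jnp.maximum(jnp.max(jnp.abs(w), axis=0, keepdims=True), 1e-12) / qmax  # (1, N)
    q = jnp.clip(jnp.round(w / scale), -qmax, qmax).astype(jnp.int32)
    return q, scale


def quantized_matmul(qx, sx, zx, qw, sw):
    """Hypothetical helper. Integer (M, K) x (K, N) product, zero-point correction, rescale."""
    dims = (((1,), (0,)), ((), ()))  # contract lhs axis 1 with rhs axis 0, no batch axes
    acc = jax.lax.dot_general(qx, qw, dims, preferred_element_type=jnp.int32)
    # sum_k (qx - zx) * qw  ==  sum_k qx * qw  -  zx * sum_k qw
    acc = acc - zx * jnp.sum(qw, axis=0, keepdims=True)
    # The combined scale sx * sw has shape (1, N) and broadcasts over the rows.
    return acc.astype(jnp.float32) * (sx * sw)


kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (4, 64))
w = jax.random.normal(kw, (64, 8))

qx, sx, zx = quantize_asymmetric(x)
qw, sw = quantize_symmetric_per_channel(w)
approx = quantized_matmul(qx, sx, zx, qw, sw)
exact = jnp.dot(x, w, precision=jax.lax.Precision.HIGHEST)
rel_err = jnp.linalg.norm(approx - exact) / jnp.linalg.norm(exact)
print(f"relative error: {float(rel_err):.4f}")
```

The zero-point correction only needs the column sums of the integer right-hand side, so the inner contraction stays a plain integer matmul. This is one common way to organize the arithmetic, not necessarily the one Qwix uses.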
- qwix.conv_general_dilated(lhs: Array | QArray, rhs: Array | QArray, window_strides: Sequence[int], padding: str | Sequence[tuple[int, int]], lhs_dilation: Sequence[int] | None = None, rhs_dilation: Sequence[int] | None = None, dimension_numbers: tuple[str, str, str] | ConvDimensionNumbers | None = None, feature_group_count: int = 1, batch_group_count: int = 1, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, out_sharding=None) → Array
Quantized jax.lax.conv_general_dilated. Dispatches to a fast or a slow implementation depending on the inputs.
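The same pattern can be sketched for convolution. This plain-JAX example does not reproduce the fast or slow path that Qwix dispatches to. It quantizes both operands symmetrically per tensor to int8 using the hypothetical helpers quantize_symmetric and quantized_conv. It then runs jax.lax.conv_general_dilated with an int32 accumulator and multiplies by the product of the two scales.

```python
import jax
import jax.numpy as jnp


def quantize_symmetric(x, num_bits=8):
    """Hypothetical helper. Per-tensor symmetric quantization to int8: x ≈ scale * q."""
    qmax = 2 ** (num_bits - 1) - 1
    scale = jnp.maximum(jnp.max(jnp.abs(x)), 1e-12) / qmax
    q = jnp.clip(jnp.round(x / scale), -qmax, qmax).astype(jnp.int8)
    return q, scale


def quantized_conv(qx, sx, qw, sw, window_strides, padding, dimension_numbers):
    """Hypothetical helper. int8 convolution accumulated in int32, then rescaled to float32."""
    acc = jax.lax.conv_general_dilated(
        qx, qw, window_strides, padding,
        dimension_numbers=dimension_numbers,
        preferred_element_type=jnp.int32,  # avoid overflowing an int8 accumulator
    )
    return acc.astype(jnp.float32) * (sx * sw)


kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (1, 8, 8, 3))  # NHWC
w = jax.random.normal(kw, (3, 3, 3, 4))  # HWIO
dn = ("NHWC", "HWIO", "NHWC")

qx, sx = quantize_symmetric(x)
qw, sw = quantize_symmetric(w)
approx = quantized_conv(qx, sx, qw, sw, (1, 1), "SAME", dn)
exact = jax.lax.conv_general_dilated(
    x, w, (1, 1), "SAME", dimension_numbers=dn,
    precision=jax.lax.Precision.HIGHEST,
)
```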
- qwix.dot(a: Array | QArray, b: Array | QArray, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, out_sharding=None, *, _qwix_dot_general: Callable[..., Array] = <function dot_general>)
jnp.dot with QArray support.
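qwix.dot takes a keyword-only _qwix_dot_general argument that defaults to a dot_general function. This suggests that the jnp.dot semantics are expressed as a dot_general call. The plain-JAX sketch below shows how jnp.dot maps onto jax.lax.dot_general for operands with ndim >= 1, with the dot_general implementation injectable in the same spirit. The names dot_via_dot_general and logging_dot_general are hypothetical, and this is not Qwix's code.

```python
import jax
import jax.numpy as jnp


def dot_via_dot_general(a, b, *, _dot_general=jax.lax.dot_general):
    """Hypothetical helper: jnp.dot semantics for ndim >= 1 as one dot_general call."""
    if a.ndim == 0 or b.ndim == 0:
        raise ValueError("scalar operands are not handled in this sketch")
    # jnp.dot contracts the last axis of `a` with the second-to-last axis of `b`
    # (or with its only axis when `b` is 1-D); there are no batch axes.
    rhs_contract = 0 if b.ndim == 1 else b.ndim - 2
    dims = (((a.ndim - 1,), (rhs_contract,)), ((), ()))
    return _dot_general(a, b, dims)


calls = []


def logging_dot_general(lhs, rhs, dimension_numbers, **kwargs):
    """Hypothetical stand-in for a custom dot_general: records the call, then delegates."""
    calls.append(dimension_numbers)
    return jax.lax.dot_general(lhs, rhs, dimension_numbers, **kwargs)


ka, kb = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(ka, (2, 3, 4))
b = jax.random.normal(kb, (5, 4, 6))
out = dot_via_dot_general(a, b, _dot_general=logging_dot_general)
print(out.shape)  # (2, 3, 5, 6), the same shape jnp.dot(a, b) produces
```

Passing the contraction function as a parameter lets a wrapper swap in a quantized dot_general without changing how the dimension numbers are derived.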
- qwix.dot_general(lhs: Array | QArray, rhs: Array | QArray, dimension_numbers: tuple[tuple[Sequence[int], Sequence[int]], tuple[Sequence[int], Sequence[int]]], precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, **kwargs) → Array
Quantized jax.lax.dot_general.
- Parameters:
lhs – The left-hand side, either a jax.Array or QArray.
rhs – The right-hand side, either a jax.Array or QArray.
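dimension_numbers – The dimension numbers passed to dot_general, of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)), as in jax.lax.dot_general.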
precision – The precision for jax.lax.dot_general.
preferred_element_type – The preferred element type for jax.lax.dot_general.
**kwargs – Additional keyword arguments to dot_general.
- Returns:
A floating-point jax.Array.
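The dimension_numbers tuple follows the jax.lax.dot_general convention. The example below uses plain jax.Array operands and jax.lax.dot_general itself, not qwix.dot_general. It shows the format for a batched matmul whose contraction axis is last on both operands. Because qwix.dot_general mirrors this signature, the same tuple is assumed to carry over to QArray inputs.

```python
import jax
import jax.numpy as jnp

kx, ky = jax.random.split(jax.random.PRNGKey(0))
lhs = jax.random.normal(kx, (2, 3, 5))  # (batch, m, k)
rhs = jax.random.normal(ky, (2, 4, 5))  # (batch, n, k)

# ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)):
# contract the k axes (axis 2 on both sides) and treat axis 0 of both as batch.
dimension_numbers = (((2,), (2,)), ((0,), (0,)))
out = jax.lax.dot_general(lhs, rhs, dimension_numbers,
                          preferred_element_type=jnp.float32)

# Output axes are ordered: batch axes, then lhs free axes, then rhs free axes.
print(out.shape)  # (2, 3, 4)
```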
- qwix.einsum(*args, _qwix_dot_general=<function dot_general>, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, **kwargs) → Array
Quantized einsum that accepts QArray operands and returns a floating-point jax.Array.
- Parameters:
*args – Arguments to einsum.
_qwix_dot_general – The dot_general function to use.
preferred_element_type – The preferred element type for jax.lax.dot_general.
**kwargs – Keyword arguments to einsum.
- Returns:
The result of the einsum, a floating-point jax.Array.
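qwix.einsum returns a floating-point result and exposes preferred_element_type for the underlying dot_general. The arithmetic this implies can be sketched in plain JAX with jnp.einsum. Here int8 operands are contracted with an int32 accumulator, and the product of the scales is applied afterwards. The operands and scales are made up for illustration, and this is not qwix.einsum's implementation.

```python
import jax
import jax.numpy as jnp

kx, kw = jax.random.split(jax.random.PRNGKey(0))
# Integer values in the int8 range, standing in for already-quantized operands.
qx = jax.random.randint(kx, (2, 3, 64), -127, 128).astype(jnp.int8)  # (b, m, k)
qw = jax.random.randint(kw, (64, 5), -127, 128).astype(jnp.int8)     # (k, n)

# Accumulate int8 x int8 products in int32; an int8 accumulator would overflow.
acc = jnp.einsum("bmk,kn->bmn", qx, qw, preferred_element_type=jnp.int32)

# Hypothetical scales: per-tensor for qx, per-output-channel for qw.
sx = jnp.float32(0.02)
sw = jnp.linspace(0.01, 0.05, 5, dtype=jnp.float32)  # shape (n,)
out = acc.astype(jnp.float32) * (sx * sw)             # broadcasts over (b, m, n)
```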
- qwix.ragged_dot(lhs: Array | QArray, rhs: Array | QArray, group_sizes: Array, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, group_offset: Array | None = None) → Array
Quantized jax.lax.ragged_dot.
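In jax.lax.ragged_dot, the rows of lhs are partitioned into consecutive groups by group_sizes, and the rows in group i are multiplied by rhs[i]. The plain-JAX sketch below checks this against a loop-based reference (the hypothetical ragged_dot_reference). It also shows why per-group scale propagation works: a hypothetical per-group, per-output-channel rhs scale can be pulled out of the product and applied per row afterwards. This is an illustration, not Qwix's code, and it requires a JAX version that provides jax.lax.ragged_dot.

```python
import jax
import jax.numpy as jnp


def ragged_dot_reference(lhs, rhs, group_sizes):
    """Hypothetical helper. Loop reference: lhs rows in group i are multiplied by rhs[i]."""
    outputs, start = [], 0
    for i, size in enumerate(group_sizes.tolist()):
        outputs.append(lhs[start:start + size] @ rhs[i])
        start += size
    return jnp.concatenate(outputs, axis=0)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
group_sizes = jnp.array([2, 0, 3], dtype=jnp.int32)  # 3 groups covering 5 rows
lhs = jax.random.normal(k1, (5, 4))                   # (m, k)
rhs = jax.random.normal(k2, (3, 4, 6))                # (g, k, n)

out = jax.lax.ragged_dot(lhs, rhs, group_sizes)       # (m, n)
ref = ragged_dot_reference(lhs, rhs, group_sizes)

# Hypothetical per-group, per-output-channel rhs scales, shape (g, n).
rhs_scale = jax.random.uniform(k3, (3, 6), minval=0.5, maxval=2.0)
# Expand to one scale row per lhs row, following the group layout.
row_scale = jnp.repeat(rhs_scale, group_sizes, axis=0, total_repeat_length=5)  # (m, n)
scaled_in = jax.lax.ragged_dot(lhs, rhs * rhs_scale[:, None, :], group_sizes)
scaled_out = out * row_scale
```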
- qwix.ragged_dot_general(lhs: Array | QArray, rhs: Array | QArray, group_sizes: Array, dimension_numbers: RaggedDotDimensionNumbers, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, group_offset: Array | None = None) → Array
Quantized jax.lax.ragged_dot_general.