Core Operations

This module provides the low-level mathematical operations for Qwix.

These functions are quantized equivalents of standard JAX operations. For example, qwix.dot_general mirrors jax.lax.dot_general but also accepts QArray inputs. It performs the underlying quantized arithmetic (such as integer matrix multiplication) and propagates the scales and zero-points of its inputs into the result.
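The scale and zero-point bookkeeping can be illustrated with plain JAX. The sketch below is not Qwix's implementation and does not use QArray. It assumes per-tensor int8 quantization (asymmetric for the left operand, symmetric for the right), and the helpers quantize_asymmetric, quantize_symmetric and quantized_matmul are hypothetical names. Because lhs ≈ s_a·(q_a - z_a) and rhs ≈ s_b·q_b, the product equals s_a·s_b·(q_a·q_b - z_a·colsum(q_b)), so the integer matmul runs first and the scales and zero-point are applied afterwards.

```python
import jax
import jax.numpy as jnp


def quantize_asymmetric(x, bits=8):
    """Hypothetical per-tensor asymmetric quantization: x ~ scale * (q - zero_point)."""
    qmin, qmax = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    # Include 0 in the range so the zero point is an exact integer.
    x_min = jnp.minimum(x.min(), 0.0)
    x_max = jnp.maximum(x.max(), 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = jnp.clip(jnp.round(qmin - x_min / scale), qmin, qmax).astype(jnp.int32)
    q = jnp.clip(jnp.round(x / scale) + zero_point, qmin, qmax).astype(jnp.int8)
    return q, scale, zero_point


def quantize_symmetric(x, bits=8):
    """Hypothetical per-tensor symmetric quantization: x ~ scale * q (zero point 0)."""
    qmax = 2 ** (bits - 1) - 1
    scale = jnp.abs(x).max() / qmax
    q = jnp.clip(jnp.round(x / scale), -qmax, qmax).astype(jnp.int8)
    return q, scale


def quantized_matmul(a, b):
    """Approximates a @ b with int8 operands and int32 accumulation."""
    qa, sa, za = quantize_asymmetric(a)
    qb, sb = quantize_symmetric(b)
    # Integer matmul: contract a's axis 1 with b's axis 0, accumulate in int32.
    acc = jax.lax.dot_general(
        qa, qb, dimension_numbers=(((1,), (0,)), ((), ())),
        preferred_element_type=jnp.int32)
    # (qa - za) @ qb = qa @ qb - za * (column sums of qb)
    correction = za * qb.astype(jnp.int32).sum(axis=0)
    # The product of the operand scales maps the integer result back to floats.
    return (sa * sb) * (acc - correction).astype(jnp.float32)
```

Folding the zero-point into a single correction term keeps the inner loop purely integer, which is the property that makes quantized matmuls fast on integer hardware.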

qwix.conv_general_dilated(lhs: Array | QArray, rhs: Array | QArray, window_strides: Sequence[int], padding: str | Sequence[tuple[int, int]], lhs_dilation: Sequence[int] | None = None, rhs_dilation: Sequence[int] | None = None, dimension_numbers: tuple[str, str, str] | ConvDimensionNumbers | None = None, feature_group_count: int = 1, batch_group_count: int = 1, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, out_sharding=None) -> Array

Quantized jax.lax.conv_general_dilated. Dispatches to a fast or a slow implementation depending on the inputs.
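Because convolution is linear in each operand, the scales of symmetrically quantized inputs factor out of it just as they do for a matrix product. The following plain-JAX sketch assumes per-tensor symmetric int8 quantization and NHWC/HWIO layouts chosen for illustration; quantize_symmetric and quantized_conv are hypothetical names. It runs an integer convolution with int32 accumulation and rescales the result, and it says nothing about how Qwix chooses between its fast and slow paths.

```python
import jax
import jax.numpy as jnp


def quantize_symmetric(x, bits=8):
    """Hypothetical per-tensor symmetric quantization: x ~ scale * q."""
    qmax = 2 ** (bits - 1) - 1
    scale = jnp.abs(x).max() / qmax
    q = jnp.clip(jnp.round(x / scale), -qmax, qmax).astype(jnp.int8)
    return q, scale


def quantized_conv(lhs, rhs, window_strides, padding):
    """Approximates a float conv with an int8 conv accumulated in int32."""
    dn = ("NHWC", "HWIO", "NHWC")
    q_lhs, s_lhs = quantize_symmetric(lhs)
    q_rhs, s_rhs = quantize_symmetric(rhs)
    acc = jax.lax.conv_general_dilated(
        q_lhs, q_rhs, window_strides, padding,
        dimension_numbers=dn, preferred_element_type=jnp.int32)
    # conv(s_l * q_l, s_r * q_r) = s_l * s_r * conv(q_l, q_r)
    return (s_lhs * s_rhs) * acc.astype(jnp.float32)
```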

qwix.dot(a: Array | QArray, b: Array | QArray, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, out_sharding=None, *, _qwix_dot_general: Callable[..., Array] = dot_general)

jnp.dot with QArray support.
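The _qwix_dot_general parameter suggests that qwix.dot delegates its contraction to a dot_general-style function. As an illustration of that idea (an assumption, not Qwix's source), jnp.dot's contraction rule can be expressed through any function with jax.lax.dot_general's calling convention; dot_via_dot_general is a hypothetical name, and scalar operands are left out for brevity.

```python
import jax
import jax.numpy as jnp


def dot_via_dot_general(a, b, dot_general_fn=jax.lax.dot_general):
    """Hypothetical sketch: jnp.dot semantics through a pluggable dot_general."""
    if a.ndim == 0 or b.ndim == 0:
        raise ValueError("scalar operands are not handled in this sketch")
    # jnp.dot contracts the last axis of `a` with the second-to-last axis of `b`
    # (or with its only axis when `b` is 1-D); there are no batch axes.
    rhs_axis = b.ndim - 2 if b.ndim >= 2 else 0
    dimension_numbers = (((a.ndim - 1,), (rhs_axis,)), ((), ()))
    # dot_general orders the output as lhs free axes, then rhs free axes,
    # which matches jnp.dot's a.shape[:-1] + b.shape[:-2] + b.shape[-1:].
    return dot_general_fn(a, b, dimension_numbers)
```

Swapping dot_general_fn for a quantized implementation changes the arithmetic without touching the shape logic, which is the kind of separation the hook appears designed for.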

qwix.dot_general(lhs: Array | QArray, rhs: Array | QArray, dimension_numbers: tuple[tuple[Sequence[int], Sequence[int]], tuple[Sequence[int], Sequence[int]]], precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, **kwargs) -> Array

Quantized jax.lax.dot_general.

Parameters:
  • lhs – The left-hand side, either a jax.Array or QArray.

  • rhs – The right-hand side, either a jax.Array or QArray.

  • dimension_numbers – The dimension numbers passed to dot_general.

  • precision – The precision for jax.lax.dot_general.

  • preferred_element_type – The preferred element type for jax.lax.dot_general.

  • **kwargs – Additional keyword arguments to dot_general.

Returns:

A floating-point jax.Array.
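Since qwix.dot_general mirrors jax.lax.dot_general, dimension_numbers follows the same layout: ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)). The plain-JAX sketch below calls jax.lax.dot_general directly to show a batched matrix product and how preferred_element_type requests an int32 accumulator for int8 operands. In the output, batch axes come first, followed by the remaining lhs axes and then the remaining rhs axes.

```python
import jax
import jax.numpy as jnp

# Batched matmul: lhs [B, M, K] x rhs [B, K, N] -> [B, M, N].
# Contract lhs axis 2 with rhs axis 1; axis 0 of both is a batch axis.
BATCHED_MATMUL_DIMS = (((2,), (1,)), ((0,), (0,)))

lhs = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 4))
rhs = jax.random.normal(jax.random.PRNGKey(1), (2, 4, 5))
out = jax.lax.dot_general(lhs, rhs, BATCHED_MATMUL_DIMS)

# With int8 operands, preferred_element_type asks for an int32 result so the
# products are accumulated without int8 overflow.
q_lhs = jnp.full((2, 3, 4), 100, dtype=jnp.int8)
q_rhs = jnp.full((2, 4, 5), 100, dtype=jnp.int8)
q_out = jax.lax.dot_general(
    q_lhs, q_rhs, BATCHED_MATMUL_DIMS, preferred_element_type=jnp.int32)
# Every entry of q_out is 4 * 100 * 100 = 40000.
```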

qwix.einsum(*args, _qwix_dot_general=dot_general, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, **kwargs) -> Array

Quantized einsum that accepts QArray operands and returns a floating-point jax.Array.

Parameters:
  • *args – Arguments to einsum.

  • _qwix_dot_general – The dot_general function to use.

  • preferred_element_type – The preferred element type for jax.lax.dot_general.

  • **kwargs – Keyword arguments to einsum.

Returns:

The result of the einsum, a floating-point jax.Array.
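In JAX, jnp.einsum lowers each pairwise contraction to a dot_general call and exposes a private _dot_general keyword for replacing that function. Assuming qwix.einsum uses its _qwix_dot_general hook in a comparable way (this section does not say so), a custom dot_general can be injected as below; counting_dot_general is a hypothetical wrapper that records each contraction and delegates to jax.lax.dot_general. Because _dot_general is private, it may change between JAX releases.

```python
import jax
import jax.numpy as jnp

calls = []


def counting_dot_general(lhs, rhs, dimension_numbers, *args, **kwargs):
    """Hypothetical wrapper: records the requested contraction, then delegates."""
    calls.append(dimension_numbers)
    return jax.lax.dot_general(lhs, rhs, dimension_numbers, *args, **kwargs)


a = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 4))
b = jax.random.normal(jax.random.PRNGKey(1), (2, 4, 5))
# jnp.einsum jit-compiles internally, so the wrapper runs while tracing.
out = jnp.einsum("bmk,bkn->bmn", a, b, _dot_general=counting_dot_general)
```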

qwix.ragged_dot(lhs: Array | QArray, rhs: Array | QArray, group_sizes: Array, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, group_offset: Array | None = None) -> Array

Quantized jax.lax.ragged_dot.
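jax.lax.ragged_dot treats lhs of shape [m, k] as consecutive row groups whose sizes are given by group_sizes (shape [g], summing to m) and multiplies the rows of group i by rhs[i], where rhs has shape [g, k, n]. This is the layout typically used for mixture-of-experts layers. The loop below is a readable reference for those semantics, not how Qwix or JAX compute the result; ragged_dot_reference is a hypothetical name, and it assumes group_sizes sums exactly to m.

```python
import jax
import jax.numpy as jnp


def ragged_dot_reference(lhs, rhs, group_sizes):
    """Row group i of lhs ([m, k]) is multiplied by rhs[i] ([k, n])."""
    outputs, start = [], 0
    for i, size in enumerate(group_sizes.tolist()):
        outputs.append(lhs[start:start + size] @ rhs[i])
        start += size
    return jnp.concatenate(outputs, axis=0)  # [m, n]


lhs = jax.random.normal(jax.random.PRNGKey(0), (6, 4))
rhs = jax.random.normal(jax.random.PRNGKey(1), (3, 4, 5))
group_sizes = jnp.array([2, 1, 3], dtype=jnp.int32)  # rows 0-1, 2, 3-5
expected = jax.lax.ragged_dot(lhs, rhs, group_sizes)
```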

qwix.ragged_dot_general(lhs: Array | QArray, rhs: Array | QArray, group_sizes: Array, dimension_numbers: RaggedDotDimensionNumbers, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, group_offset: Array | None = None) -> Array

Quantized jax.lax.ragged_dot_general.