Core Operations

This module provides the low-level mathematical operations for Qwix.

These functions are quantized counterparts of standard JAX operations. For example, qwix.dot_general mirrors jax.lax.dot_general but also accepts QArray inputs: it performs the underlying quantized arithmetic (such as integer matrix multiplication) and propagates the scales and zero-points to the result.
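The scale propagation described above can be illustrated in plain JAX. The sketch below is not Qwix's implementation: it assumes symmetric int8 quantization (no zero-points) with one scale per row of the left operand and one per column of the right operand, and the helper quantize_symmetric is hypothetical. Because each scale is constant along the contracting dimension, it can be pulled out of the sum, so the integer product needs only a single rescale at the end.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax


def quantize_symmetric(x, axis):
    """Hypothetical helper: symmetric int8 quantization.

    `axis` is the dimension reduced when computing the scale (the contracting
    dimension), so every remaining index gets its own scale.
    """
    absmax = jnp.max(jnp.abs(x), axis=axis, keepdims=True)
    scale = jnp.where(absmax == 0, 1.0, absmax / 127.0)  # avoid dividing by zero
    q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
    return q, scale


rng = np.random.default_rng(0)
x = jnp.asarray(rng.normal(size=(4, 8)), dtype=jnp.float32)  # [M, K]
w = jnp.asarray(rng.normal(size=(8, 3)), dtype=jnp.float32)  # [K, N]

qx, sx = quantize_symmetric(x, axis=1)  # sx: [M, 1], one scale per row
qw, sw = quantize_symmetric(w, axis=0)  # sw: [1, N], one scale per column

# Integer matrix multiplication with int32 accumulation.
acc = lax.dot_general(
    qx, qw,
    dimension_numbers=(((1,), (0,)), ((), ())),
    preferred_element_type=jnp.int32,
)
# The scales do not vary along K, so they factor out of the contraction.
y_quant = acc.astype(jnp.float32) * sx * sw  # [M, N]
y_float = x @ w                              # unquantized reference
```

With asymmetric quantization, the zero-points introduce extra cross terms that must also be corrected after the integer product; this sketch leaves them out.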

qwix.conv_general_dilated(lhs: Array | QArray, rhs: Array | QArray, window_strides: Sequence[int], padding: str | Sequence[tuple[int, int]], lhs_dilation: Sequence[int] | None = None, rhs_dilation: Sequence[int] | None = None, dimension_numbers: tuple[str, str, str] | ConvDimensionNumbers | None = None, feature_group_count: int = 1, batch_group_count: int = 1, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, out_sharding=None) → Array

Computes a general dilated convolution with support for QArray inputs.

This function serves as a drop-in replacement for jax.lax.conv_general_dilated.

It automatically dispatches to a quantized implementation if the inputs are compatible QArrays. Otherwise, it dequantizes the inputs and falls back to the standard floating-point JAX implementation.

Parameters:
  • lhs – The left-hand side, either a jax.Array or QArray.

  • rhs – The right-hand side, either a jax.Array or QArray.

  • window_strides – A sequence of integers specifying the stride of the convolution window.

  • padding – The padding algorithm (e.g., 'SAME', 'VALID') or explicit padding amounts.

  • lhs_dilation – Dilation factors for the input (lhs) spatial dimensions.

  • rhs_dilation – Dilation factors for the kernel (rhs) spatial dimensions.

  • dimension_numbers – A structure specifying the dimension layout for lhs, rhs, and the output.

  • feature_group_count – The number of feature groups for grouped convolution.

  • batch_group_count – The number of batch groups.

  • precision – The numerical precision configuration for the computation.

  • preferred_element_type – The target data type for accumulation.

  • out_sharding – Optional sharding spec for the output array.

Returns:

An Array containing the convolution result.
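Convolution is linear in each operand, so per-tensor scales factor out of it the same way they factor out of a matrix product. The following plain-JAX sketch checks that identity; it is not Qwix's quantized kernel. The scales s_lhs and s_rhs are made-up values, and the int8-range payloads are held in float32 so the example runs on any backend (every product and partial sum stays below 2**24, so the integer convolution is exact).

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

rng = np.random.default_rng(1)
# Integer payloads in the int8 range, stored as float32 for portability.
q_lhs = jnp.asarray(rng.integers(-127, 128, size=(1, 2, 6, 6)), dtype=jnp.float32)  # NCHW
q_rhs = jnp.asarray(rng.integers(-127, 128, size=(4, 2, 3, 3)), dtype=jnp.float32)  # OIHW
s_lhs, s_rhs = 0.02, 0.005  # hypothetical per-tensor scales


def conv(lhs, rhs):
    return lax.conv_general_dilated(
        lhs, rhs,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
        precision=lax.Precision.HIGHEST,
    )


# Reference path: dequantize both operands, then convolve in floating point.
y_dequant_first = conv(q_lhs * s_lhs, q_rhs * s_rhs)
# Quantized path: convolve the integer payloads, then apply the combined scale once.
y_scale_after = conv(q_lhs, q_rhs) * (s_lhs * s_rhs)
```

Per-output-channel kernel scales factor out the same way, since each output channel uses a single kernel slice.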

qwix.dot(a: Array | QArray, b: Array | QArray, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, out_sharding=None, *, _qwix_dot_general: Callable[..., Array] = dot_general)

Computes the dot product with support for QArray inputs.

This function serves as a drop-in replacement for jax.numpy.dot.

It automatically dispatches to a quantized implementation if the inputs are compatible QArrays. It also supports subchannel quantization where applicable.

Parameters:
  • a – The left-hand side, either a jax.Array or QArray.

  • b – The right-hand side, either a jax.Array or QArray.

  • precision – The numerical precision configuration for the computation.

  • preferred_element_type – The element type for accumulation, as in jax.numpy.dot. For QArray inputs, it controls the precision of the quantized accumulation.

  • out_sharding – Optional sharding spec for the output array.

  • _qwix_dot_general – Internal argument for dependency injection of the underlying dot_general implementation. Defaults to qwix.dot_general.

Returns:

The dot product of a and b.
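Subchannel quantization splits the contracting dimension into blocks and gives each block its own scale. A scale then factors out of its own block's partial product only, so the blocks are rescaled separately before being summed. The sketch below quantizes only the right operand (weight-only) with an arbitrary block size of 4; quantize_blocks is a hypothetical helper, and the code illustrates the technique in plain JAX rather than how Qwix lays out subchannel scales internally.

```python
import numpy as np
import jax.numpy as jnp


def quantize_blocks(x, block_size):
    """Hypothetical helper: int8-quantize a [K, N] matrix with one scale per (K-block, column)."""
    k, n = x.shape
    xb = x.reshape(k // block_size, block_size, n)        # [G, B, N]
    absmax = jnp.max(jnp.abs(xb), axis=1, keepdims=True)  # [G, 1, N]
    scale = jnp.where(absmax == 0, 1.0, absmax / 127.0)
    q = jnp.clip(jnp.round(xb / scale), -127, 127).astype(jnp.int8)
    return q, scale


rng = np.random.default_rng(2)
a = jnp.asarray(rng.normal(size=(4, 16)), dtype=jnp.float32)  # [M, K], left in float
b = jnp.asarray(rng.normal(size=(16, 3)), dtype=jnp.float32)  # [K, N], quantized

block = 4
qb, sb = quantize_blocks(b, block)            # qb: [G, B, N], sb: [G, 1, N]
a_blocks = a.reshape(4, 16 // block, block)   # [M, G, B]

# One partial product per K-block, computed on the integer payload.
partial = jnp.einsum("mgb,gbn->mgn", a_blocks, qb.astype(jnp.float32))  # [M, G, N]
# Rescale each block by its own scales, then sum over the blocks.
y_subchannel = jnp.sum(partial * sb[:, 0, :], axis=1)                   # [M, N]
y_float = a @ b                                                          # unquantized reference
```

Smaller blocks track local magnitude more closely, at the cost of storing more scales and applying more rescales.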

qwix.dot_general(lhs: Array | QArray, rhs: Array | QArray, dimension_numbers: tuple[tuple[Sequence[int], Sequence[int]], tuple[Sequence[int], Sequence[int]]], precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, **kwargs) → Array

Computes a general dot product with support for QArray inputs.

This function serves as a drop-in replacement for jax.lax.dot_general.

Parameters:
  • lhs – The left-hand side, either a jax.Array or QArray.

  • rhs – The right-hand side, either a jax.Array or QArray.

  • dimension_numbers – A tuple ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)), in the same format as jax.lax.dot_general.

  • precision – The precision for jax.lax.dot_general.

  • preferred_element_type – The preferred element type for jax.lax.dot_general.

  • **kwargs – Additional keyword arguments to dot_general.

Returns:

A floating-point jax.Array.
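Since qwix.dot_general takes the same dimension_numbers as jax.lax.dot_general, the convention can be checked with plain JAX. This sketch expresses a batched matrix multiplication through dimension_numbers and compares it with the equivalent einsum; the shapes are arbitrary.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

rng = np.random.default_rng(3)
lhs = jnp.asarray(rng.normal(size=(2, 5, 3)), dtype=jnp.float32)  # [batch, M, K]
rhs = jnp.asarray(rng.normal(size=(2, 3, 4)), dtype=jnp.float32)  # [batch, K, N]

# ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
dims = (((2,), (1,)), ((0,), (0,)))
out = lax.dot_general(lhs, rhs, dimension_numbers=dims)
# Output layout: batch dims first, then lhs free dims, then rhs free dims -> [batch, M, N].
```

For quantized operands, a scale can be pulled out of the contraction only if it is constant along the contracting dimensions, which is why per-channel scales are normally attached to batch or free (non-contracting) dimensions.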

qwix.einsum(*args, _qwix_dot_general=dot_general, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, **kwargs) → Array

Evaluates the Einstein summation convention on the operands, with support for QArray inputs.

This function serves as a drop-in replacement for jax.numpy.einsum.

Parameters:
  • *args – Positional arguments to einsum: the subscripts string followed by the operands.

  • _qwix_dot_general – Internal argument: the dot_general implementation used to compute each contraction.

  • preferred_element_type – The preferred element type for jax.lax.dot_general.

  • **kwargs – Keyword arguments to einsum.

Returns:

The result of the einsum, a floating-point jax.Array.
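The _qwix_dot_general argument follows the same injection pattern as the private _dot_general keyword of jax.numpy.einsum, which routes each pairwise contraction through a caller-supplied function. The sketch below uses that JAX keyword (private, so it may change between releases) with a hypothetical recording_dot_general that logs its dimension numbers and then delegates to jax.lax.dot_general. It shows the pattern only, not how qwix.einsum is wired internally.

```python
import jax.numpy as jnp
from jax import lax

calls = []


def recording_dot_general(lhs, rhs, dimension_numbers, *args, **kwargs):
    """Hypothetical injected dot_general: record the call, then delegate to lax."""
    calls.append(dimension_numbers)
    return lax.dot_general(lhs, rhs, dimension_numbers, *args, **kwargs)


x = jnp.ones((2, 5, 3), jnp.float32)
y = jnp.ones((2, 3, 4), jnp.float32)
# einsum plans the contraction and hands it to the injected function.
out = jnp.einsum("bmk,bkn->bmn", x, y, _dot_general=recording_dot_general)
```

Since einsum is traced under jit, the wrapper runs at trace time, so a cached second call with the same shapes does not record again.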

qwix.ragged_dot(lhs: Array | QArray, rhs: Array | QArray, group_sizes: Array, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, group_offset: Array | None = None) → Array

Computes a ragged dot product with support for QArray inputs.

This function serves as a drop-in replacement for jax.lax.ragged_dot.

It is a convenience wrapper around ragged_dot_general with standard matrix multiplication dimension numbers.

Parameters:
  • lhs – The left-hand side, either a jax.Array or QArray.

  • rhs – The right-hand side, either a jax.Array or QArray.

  • group_sizes – An array of integers specifying the size of each group in the ragged dimension.

  • precision – The numerical precision configuration for the computation.

  • preferred_element_type – The target data type for accumulation.

  • group_offset – Optional array indicating the group in group_sizes to start computing from, as in jax.lax.ragged_dot.

Returns:

An Array containing the result of the ragged dot product.
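In jax.lax.ragged_dot, which qwix.ragged_dot mirrors, consecutive rows of lhs are assigned to groups according to group_sizes, and each group is multiplied by its own slice of rhs. The sketch below checks that behavior against an explicit loop; the shapes and group sizes are arbitrary, and the group sizes sum to the number of rows.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

rng = np.random.default_rng(4)
sizes = [1, 3, 2]                    # rows per group; they sum to M here
m, k, n, g = sum(sizes), 3, 2, len(sizes)

lhs = jnp.asarray(rng.normal(size=(m, k)), dtype=jnp.float32)     # [M, K]
rhs = jnp.asarray(rng.normal(size=(g, k, n)), dtype=jnp.float32)  # [G, K, N], one matrix per group
group_sizes = jnp.array(sizes, dtype=jnp.int32)

out = lax.ragged_dot(lhs, rhs, group_sizes)  # [M, N]

# Explicit reference: each consecutive block of rows uses its group's matrix.
blocks, start = [], 0
for i, size in enumerate(sizes):
    blocks.append(lhs[start:start + size] @ rhs[i])
    start += size
ref = jnp.concatenate(blocks, axis=0)
```

This layout is common for mixture-of-experts layers, where tokens sorted by expert form the groups.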

qwix.ragged_dot_general(lhs: Array | QArray, rhs: Array | QArray, group_sizes: Array, dimension_numbers: RaggedDotDimensionNumbers, precision: None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset = None, preferred_element_type: str | type[Any] | dtype | SupportsDType | None = None, group_offset: Array | None = None) → Array

Computes a general ragged dot product with support for QArray inputs.

This function serves as a drop-in replacement for jax.lax.ragged_dot_general.

It automatically dispatches to a quantized implementation if the inputs are compatible QArrays. Otherwise, it dequantizes the inputs and falls back to the standard floating-point JAX implementation.

Parameters:
  • lhs – The left-hand side, either a jax.Array or QArray.

  • rhs – The right-hand side, either a jax.Array or QArray.

  • group_sizes – An array of integers specifying the size of each group in the ragged dimension.

  • dimension_numbers – A jax.lax.RaggedDotDimensionNumbers struct specifying the contracting, batch, and ragged dimensions.

  • precision – The numerical precision configuration for the computation.

  • preferred_element_type – The target data type for accumulation.

  • group_offset – Optional array indicating the group in group_sizes to start computing from, as in jax.lax.ragged_dot_general.

Returns:

An Array containing the result of the ragged dot product.