# Source code for qwix._src.core.ragged_dot

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Quantized jax.lax.ragged_dot and jax.lax.ragged_dot_general."""
# pylint: disable=line-too-long

from collections.abc import Collection, Sequence
import jax
from jax import numpy as jnp
from qwix._src.core import numerics
from qwix._src.core import qarray


# RaggedDotDimensionNumbers that specify the simple case (i.e., qwix.ragged_dot):
#   * lhs axis 1 is contracted with rhs axis 1, and there are no batch axes;
#   * lhs axis 0 is the ragged dimension;
#   * rhs axis 0 is the group dimension.
# ragged_dot() in this module passes this constant to ragged_dot_general().
_BASIC_RAGGED_DOT_DIMENSION_NUMBERS = jax.lax.RaggedDotDimensionNumbers(
    dot_dimension_numbers=(((1,), (1,)), ((), ())),
    lhs_ragged_dimensions=[0],
    rhs_group_dimensions=[0],
)


def _apply_group_channelwise_scale(
    rhs_scale: jax.Array,
    lhs_shape: tuple[int, ...],
    group_sizes: jax.Array,
    dimension_numbers: jax.lax.RaggedDotDimensionNumbers,
    precision: jax.lax.PrecisionLike,
    group_offset: jax.Array | None,
) -> jax.Array:
  """Expands the group dimension of rhs_scale using a gather-like op."""
  (lhs_ca, _), _ = dimension_numbers.dot_dimension_numbers

  # Create a `jnp.ones` tensor with the same rank and layout as lhs_val,
  # but with contracting dimensions set to size 1.
  ones_shape = list(lhs_shape)
  for contracting_axis in lhs_ca:
    ones_shape[contracting_axis] = 1
  lhs_ones = jnp.ones(tuple(ones_shape), rhs_scale.dtype)

  return jax.lax.ragged_dot_general(
      lhs_ones,
      rhs_scale,
      group_sizes,
      dimension_numbers,
      precision=precision,
      group_offset=group_offset,
  )


def _apply_tiling(
    contracting_axes: Sequence[int],
    batch_axes: Sequence[int],
    tiled_axes: Collection[int],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
  """Renumbers dot dimension numbers after tiled axes have been split.

  A tiled contracting axis becomes two adjacent axes: the leading one is
  treated as an extra batch axis and the trailing one stays a contracting axis.
  Every axis positioned after a tiled axis therefore moves one slot to the
  right per preceding tiled axis.

  Args:
    contracting_axes: The original contracting axes.
    batch_axes: The original batch axes.
    tiled_axes: The tiled axes. Must be a subset of contracting_axes.

  Returns:
    A tuple of (new_ca, new_ba, sum_axes), where new_ba holds the shifted
    original batch axes followed by the extra batch axes, and sum_axes are the
    positions of those extra axes within new_ba.
  """
  shifted_batch = []
  for axis in batch_axes:
    shifted_batch.append(axis + sum(1 for t in tiled_axes if t < axis))

  shifted_contracting = []
  extra_batch = []
  for axis in contracting_axes:
    # Counting with `<=` skips over the leading half of the axis's own split.
    shifted_contracting.append(axis + sum(1 for t in tiled_axes if t <= axis))
    if axis in tiled_axes:
      # The leading half of the split is appended after the existing batch
      # axes. Any position would do as long as lhs and rhs agree on the order.
      extra_batch.append(axis + sum(1 for t in tiled_axes if t < axis))

  first_extra = len(batch_axes)
  extra_positions = range(first_extra, first_extra + len(extra_batch))
  return (
      tuple(shifted_contracting),
      tuple(shifted_batch + extra_batch),
      tuple(extra_positions),
  )


def _ragged_get_scale_transpose(
    dimension_numbers: jax.lax.RaggedDotDimensionNumbers,
    ndims: tuple[int, int],
) -> tuple[list[int | None], list[int | None]]:
  """Calculates the transpose permutation for lhs_scale and rhs_scale."""
  (lhs_ca, rhs_ca), (lhs_ba, rhs_ba) = dimension_numbers.dot_dimension_numbers
  lhs_ragged_dims = dimension_numbers.lhs_ragged_dimensions
  rhs_group_dims = dimension_numbers.rhs_group_dimensions

  lhs_remaining_dims = sorted(
      set(range(ndims[0])) - set(lhs_ca) - set(lhs_ba) - set(lhs_ragged_dims)
  )
  rhs_remaining_dims = sorted(
      set(range(ndims[1])) - set(rhs_ca) - set(rhs_ba) - set(rhs_group_dims)
  )

  lhs_scale_transpose = (
      list(lhs_ba)
      + list(lhs_ragged_dims)
      + list(lhs_remaining_dims)
      + [None] * len(rhs_remaining_dims)
  )
  rhs_scale_transpose = (
      list(rhs_ba)
      + [None] * (len(lhs_ragged_dims) + len(lhs_remaining_dims))
      + list(rhs_remaining_dims)
  )

  return lhs_scale_transpose, rhs_scale_transpose


def _fast_ragged_dot_general(
    lhs: qarray.MaybeQArray,
    rhs: qarray.MaybeQArray,
    group_sizes: jax.Array,
    dimension_numbers: jax.lax.RaggedDotDimensionNumbers,
    precision: jax.lax.PrecisionLike = None,
    preferred_element_type: jax.typing.DTypeLike | None = None,
    group_offset: jax.Array | None = None,
):
  """Quantized ragged_dot_general with a fast path.

  The ragged dot is computed directly on the quantized values, and the scales
  are multiplied into the result afterwards. Tiled contracting axes are split
  into (tile_count, tile_size); the tile_count axes are carried as extra batch
  axes and summed out at the end. Only `qvalue` and `scale` are read from
  QArray operands; `zero_point` is not consulted here.

  Args:
    lhs: The left-hand side, either a jax.Array or QArray.
    rhs: The right-hand side, either a jax.Array or QArray.
    group_sizes: The group sizes forwarded to jax.lax.ragged_dot_general.
    dimension_numbers: The ragged dot dimension numbers, expressed in terms of
      the axes of the unsplit operands.
    precision: The precision forwarded to jax.lax.ragged_dot_general.
    preferred_element_type: The requested accumulation type; resolved through
      qarray.get_accumulator_and_result_type.
    group_offset: The group offset forwarded to jax.lax.ragged_dot_general.

  Returns:
    The ragged dot result, cast to the result type reported by
    qarray.get_accumulator_and_result_type.

  Raises:
    ValueError: If two contracting axes that are contracted with each other
      are both tiled but with different tile sizes.
  """
  if isinstance(lhs, qarray.QArray):
    lhs_val = lhs.qvalue
    lhs_scale = lhs.scale
    lhs_tiled_axes = qarray.get_tiled_axes(lhs)
  else:
    lhs_val = lhs
    lhs_scale = None
    lhs_tiled_axes = {}
  if isinstance(rhs, qarray.QArray):
    rhs_val = rhs.qvalue
    rhs_scale = rhs.scale
    rhs_tiled_axes = qarray.get_tiled_axes(rhs)
  else:
    rhs_val = rhs
    rhs_scale = None
    rhs_tiled_axes = {}

  (lhs_ca, rhs_ca), (lhs_ba, rhs_ba) = dimension_numbers.dot_dimension_numbers

  # Figure out the tiled axes to use for the dot_general. For greater
  # flexibility, we allow a non-tiled axis to be contracted with a tiled axis.
  # However, if both axes are tiled, their tile sizes must be the same.
  lhs_tiled_ca = {}
  rhs_tiled_ca = {}
  for l, r in zip(lhs_ca, rhs_ca):
    lhs_tile_size = lhs_tiled_axes.get(l)
    rhs_tile_size = rhs_tiled_axes.get(r)
    if lhs_tile_size and rhs_tile_size and lhs_tile_size != rhs_tile_size:
      raise ValueError(
          'Contracting axes must be tiled with the same tile size.'
          f' {lhs_tiled_axes=} {rhs_tiled_axes=} {dimension_numbers=}'
      )
    if lhs_tile_size or rhs_tile_size:
      lhs_tiled_ca[l] = lhs_tile_size or rhs_tile_size
      rhs_tiled_ca[r] = lhs_tile_size or rhs_tile_size

  # Split lhs/rhs_value for tiled axes.
  lhs_val = qarray.split_axis(lhs_val, lhs_tiled_ca)
  rhs_val = qarray.split_axis(rhs_val, rhs_tiled_ca)

  # Splitting a tiled axis in two (see _apply_tiling) moves every later axis
  # one slot to the right. The ragged and group dimensions must follow the
  # same renumbering as the batch axes, otherwise they would point at the
  # wrong axis whenever a tiled contracting axis precedes them.
  lhs_ragged_dims = [
      d + sum(t < d for t in lhs_tiled_ca)
      for d in dimension_numbers.lhs_ragged_dimensions
  ]
  rhs_group_dims = [
      d + sum(t < d for t in rhs_tiled_ca)
      for d in dimension_numbers.rhs_group_dimensions
  ]

  lhs_ca, lhs_ba, sum_axes = _apply_tiling(lhs_ca, lhs_ba, lhs_tiled_ca)
  rhs_ca, rhs_ba, _ = _apply_tiling(rhs_ca, rhs_ba, rhs_tiled_ca)
  dot_dimension_numbers = (lhs_ca, rhs_ca), (lhs_ba, rhs_ba)
  dimension_numbers = jax.lax.RaggedDotDimensionNumbers(
      dot_dimension_numbers=dot_dimension_numbers,
      lhs_ragged_dimensions=lhs_ragged_dims,
      rhs_group_dimensions=rhs_group_dims,
  )

  preferred_element_type, result_type = qarray.get_accumulator_and_result_type(
      lhs, rhs, preferred_element_type=preferred_element_type
  )

  out = jax.lax.ragged_dot_general(
      lhs_val,
      rhs_val,
      group_sizes,
      dimension_numbers,
      precision=precision,
      preferred_element_type=preferred_element_type,
      group_offset=group_offset,
  )

  lhs_scale_transpose, rhs_scale_transpose = _ragged_get_scale_transpose(
      dimension_numbers, (len(lhs_val.shape), len(rhs_val.shape))
  )
  if lhs_scale is not None:
    # Split the scale with tile size 1 so that it has the same rank and axis
    # numbering as the split value.
    lhs_scale = qarray.split_axis(lhs_scale, {a: 1 for a in lhs_tiled_ca})
    lhs_scale = qarray.transpose_array(lhs_scale, lhs_scale_transpose)
    out = qarray.call_with_generic_broadcast(jnp.multiply, out, lhs_scale)
  if rhs_scale is not None:
    rhs_scale = qarray.split_axis(rhs_scale, {a: 1 for a in rhs_tiled_ca})
    # Check if the scale has a group dimension that needs special handling.
    # A scale that differs per group cannot simply be broadcast against the
    # output, so it is expanded through a ragged dot of its own.
    if (
        dimension_numbers.rhs_group_dimensions
        and rhs_scale.shape[dimension_numbers.rhs_group_dimensions[0]] > 1
    ):
      rhs_scale = _apply_group_channelwise_scale(
          rhs_scale,
          lhs_val.shape,
          group_sizes,
          dimension_numbers,
          precision,
          group_offset,
      )
    else:
      rhs_scale = qarray.transpose_array(rhs_scale, rhs_scale_transpose)
    out = qarray.call_with_generic_broadcast(jnp.multiply, out, rhs_scale)

  if sum_axes:
    # [tile_count1, tile_count2, ..., M, N] -> [M, N]
    out = jnp.sum(out, axis=sum_axes)

  return out.astype(result_type)


def _slow_ragged_dot_general(
    lhs: qarray.MaybeQArray,
    rhs: qarray.MaybeQArray,
    group_sizes: jax.Array,
    dimension_numbers: jax.lax.RaggedDotDimensionNumbers,
    **kwargs,
):
  """Dequantizes any QArray operand, then calls jax.lax.ragged_dot_general.

  Args:
    lhs: The left-hand side, either a jax.Array or QArray.
    rhs: The right-hand side, either a jax.Array or QArray.
    group_sizes: The group sizes forwarded to jax.lax.ragged_dot_general.
    dimension_numbers: The ragged dot dimension numbers, forwarded as is.
    **kwargs: Extra keyword arguments forwarded to jax.lax.ragged_dot_general.

  Returns:
    The result of jax.lax.ragged_dot_general on the dequantized operands.
  """
  if isinstance(lhs, qarray.QArray):
    lhs = qarray.dequantize(lhs)
  if isinstance(rhs, qarray.QArray):
    rhs = qarray.dequantize(rhs)
  result = jax.lax.ragged_dot_general(
      lhs, rhs, group_sizes, dimension_numbers, **kwargs
  )
  return result


def ragged_dot_general(
    lhs: qarray.MaybeQArray,
    rhs: qarray.MaybeQArray,
    group_sizes: jax.Array,
    dimension_numbers: jax.lax.RaggedDotDimensionNumbers,
    precision: jax.lax.PrecisionLike = None,
    preferred_element_type: jax.typing.DTypeLike | None = None,
    group_offset: jax.Array | None = None,
) -> jax.Array:
  """Computes a general ragged dot product with support for ``QArray`` inputs.

  Drop-in replacement for `jax.lax.ragged_dot_general
  <https://docs.jax.dev/en/latest/_autosummary/jax.lax.ragged_dot_general.html>`_.
  The quantized implementation is used when neither operand forces
  dequantization; otherwise the operands are dequantized and the regular JAX
  implementation is called.

  Args:
    lhs: The left-hand side, either a jax.Array or QArray.
    rhs: The right-hand side, either a jax.Array or QArray.
    group_sizes: An array of integers specifying the size of each group in the
      ragged dimension.
    dimension_numbers: A ``jax.lax.RaggedDotDimensionNumbers`` struct
      specifying the contracting, batch, and ragged dimensions.
    precision: The numerical precision configuration for the computation.
    preferred_element_type: The target data type for accumulation.
    group_offset: Optional starting offset for the groups.

  Returns:
    An Array containing the result of the ragged dot product.
  """

  def _forces_dequantize(operand) -> bool:
    """Returns whether this operand rules out the quantized path."""
    if isinstance(operand, qarray.QArray):
      return operand.zero_point is not None
    # Always dequantize on inputs if any of the operands is in bf16/fp32,
    # because XLA is able to fuse the dequantize and the matmul. The slow
    # path is usually not slower than the fast path, since both use fp
    # matmul, and will be significantly faster when subchannel or zero_point
    # is used.
    return bool(numerics.should_quantize(operand.dtype))

  if any(_forces_dequantize(operand) for operand in (lhs, rhs)):
    implementation = _slow_ragged_dot_general
  else:
    implementation = _fast_ragged_dot_general

  return implementation(
      lhs,
      rhs,
      group_sizes,
      dimension_numbers,
      precision=precision,
      preferred_element_type=preferred_element_type,
      group_offset=group_offset,
  )


def ragged_dot(
    lhs: qarray.MaybeQArray,
    rhs: qarray.MaybeQArray,
    group_sizes: jax.Array,
    precision: jax.lax.PrecisionLike = None,
    preferred_element_type: jax.typing.DTypeLike | None = None,
    group_offset: jax.Array | None = None,
) -> jax.Array:
  """Computes a ragged dot product with support for ``QArray`` inputs.

  Drop-in replacement for `jax.lax.ragged_dot
  <https://docs.jax.dev/en/latest/_autosummary/jax.lax.ragged_dot.html>`_.
  This simply calls ``ragged_dot_general`` with the module's basic dimension
  numbers (``_BASIC_RAGGED_DOT_DIMENSION_NUMBERS``).

  Args:
    lhs: The left-hand side, either a jax.Array or QArray.
    rhs: The right-hand side, either a jax.Array or QArray.
    group_sizes: An array of integers specifying the size of each group in the
      ragged dimension.
    precision: The numerical precision configuration for the computation.
    preferred_element_type: The target data type for accumulation.
    group_offset: Optional starting offset for the groups.

  Returns:
    An Array containing the result of the ragged dot product.
  """
  result = ragged_dot_general(
      lhs,
      rhs,
      group_sizes,
      _BASIC_RAGGED_DOT_DIMENSION_NUMBERS,
      precision=precision,
      preferred_element_type=preferred_element_type,
      group_offset=group_offset,
  )
  return result