Source code for qwix._src.core.conv_general

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Quantized jax.lax.conv_general_dilated."""
# pylint: disable=line-too-long

from collections.abc import Sequence
from typing import Any
import jax
from jax import numpy as jnp
from qwix._src.core import numerics
from qwix._src.core import qarray


def get_how_to_quantize(
    *,
    dimension_numbers: jax.lax.ConvDimensionNumbers,
    for_lhs: bool,
    **kwargs: Any,
) -> qarray.HowToQuantize:
  """Gets how to quantize from conv's dimension_numbers.

  Use channelwise for batch dimension and out feature dimension.

  Args:
    dimension_numbers: The conv's dimension_numbers.
    for_lhs: Whether to quantize lhs or rhs.
    **kwargs: Additional keyword arguments to HowToQuantize.

  Returns:
    How to quantize lhs or rhs.
  """
  if for_lhs:
    channelwise_axes = [dimension_numbers.lhs_spec[0]]
  else:
    channelwise_axes = [dimension_numbers.rhs_spec[0]]
  return qarray.HowToQuantize(
      channelwise_axes=channelwise_axes,
      tiled_axes={},
      **kwargs,
  )


def get_transpose(
    dimension_numbers: jax.lax.ConvDimensionNumbers, for_lhs: bool
) -> list[int | None]:
  """Returns the transpose list for the given dimension_numbers."""
  transpose = [None] * len(dimension_numbers.out_spec)
  if for_lhs:
    # Only batch dimension can be channelwise thus transposed.
    transpose[dimension_numbers.out_spec[0]] = dimension_numbers.lhs_spec[0]
  else:
    # Only out feature dimension can be channelwise thus transposed.
    transpose[dimension_numbers.out_spec[1]] = dimension_numbers.rhs_spec[0]
  return transpose


def _slow_conv_general_dilated(
    lhs: qarray.MaybeQArray,
    rhs: qarray.MaybeQArray,
    window_strides: Sequence[int],
    padding: str | Sequence[tuple[int, int]],
    lhs_dilation: Sequence[int] | None = None,
    rhs_dilation: Sequence[int] | None = None,
    dimension_numbers: jax.lax.ConvGeneralDilatedDimensionNumbers = None,
    feature_group_count: int = 1,
    batch_group_count: int = 1,
    precision: jax.lax.PrecisionLike = None,
    preferred_element_type: jax.typing.DTypeLike | None = None,
    out_sharding=None,
) -> jax.Array:
  """Dequantizes first then computes in floating-point types."""
  if isinstance(lhs, qarray.QArray):
    lhs = qarray.dequantize(lhs)
  if isinstance(rhs, qarray.QArray):
    rhs = qarray.dequantize(rhs)
  return jax.lax.conv_general_dilated(
      lhs,
      rhs,
      window_strides,
      padding,
      lhs_dilation,
      rhs_dilation,
      dimension_numbers,
      feature_group_count,
      batch_group_count,
      precision,
      preferred_element_type,
      out_sharding,
  )


def _fast_conv_general_dilated(
    lhs: qarray.MaybeQArray,
    rhs: qarray.MaybeQArray,
    window_strides: Sequence[int],
    padding: str | Sequence[tuple[int, int]],
    lhs_dilation: Sequence[int] | None = None,
    rhs_dilation: Sequence[int] | None = None,
    dimension_numbers: jax.lax.ConvGeneralDilatedDimensionNumbers = None,
    feature_group_count: int = 1,
    batch_group_count: int = 1,
    preferred_element_type: jax.typing.DTypeLike | None = None,
    out_sharding=None,
) -> jax.Array:
  """Quantized jax.lax.conv_general_dilated. Both sides must be QArrays."""
  dimension_numbers = jax.lax.conv_dimension_numbers(
      lhs.shape, rhs.shape, dimension_numbers
  )
  preferred_element_type, result_type = qarray.get_accumulator_and_result_type(
      lhs, rhs, preferred_element_type=preferred_element_type
  )

  if isinstance(lhs, qarray.QArray):
    lhs_value = lhs.qvalue
    lhs_scale = lhs.scale
    lhs_zero_point = lhs.zero_point
    if qarray.get_tiled_axes(lhs):
      raise ValueError('subchannel is not supported for conv_general_dilated.')
  else:
    lhs_value = lhs
    lhs_scale = None
    lhs_zero_point = None

  if isinstance(rhs, qarray.QArray):
    rhs_value = rhs.qvalue
    rhs_scale = rhs.scale
    rhs_zero_point = rhs.zero_point
    if qarray.get_tiled_axes(rhs):
      raise ValueError('subchannel is not supported for conv_general_dilated.')
  else:
    rhs_value = rhs
    rhs_scale = None
    rhs_zero_point = None

  if rhs_zero_point is not None:
    raise ValueError('Asymmetric quantization for rhs is not supported.')

  res = jax.lax.conv_general_dilated(
      lhs_value,
      rhs_value,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      dimension_numbers=dimension_numbers,
      feature_group_count=feature_group_count,
      batch_group_count=batch_group_count,
      preferred_element_type=preferred_element_type,
      out_sharding=out_sharding,
  )
  if lhs_zero_point is not None:
    # TODO(zhuyunx): This value can be constant folded in SRQ scenarios.
    res -= jax.lax.conv_general_dilated(
        jnp.broadcast_to(lhs_zero_point, lhs_value.shape),
        rhs_value,
        window_strides=window_strides,
        padding=padding,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        dimension_numbers=dimension_numbers,
        feature_group_count=feature_group_count,
        batch_group_count=batch_group_count,
        preferred_element_type=preferred_element_type,
    )

  if lhs_scale is not None:
    transpose = get_transpose(dimension_numbers, for_lhs=True)
    lhs_scale = qarray.transpose_array(lhs_scale, transpose)
    res = qarray.call_with_generic_broadcast(jnp.multiply, res, lhs_scale)

  if rhs_scale is not None:
    transpose = get_transpose(dimension_numbers, for_lhs=False)
    rhs_scale = qarray.transpose_array(rhs_scale, transpose)
    res = qarray.call_with_generic_broadcast(jnp.multiply, res, rhs_scale)

  return res.astype(result_type)


[docs] def conv_general_dilated( lhs: qarray.MaybeQArray, rhs: qarray.MaybeQArray, window_strides: Sequence[int], padding: str | Sequence[tuple[int, int]], lhs_dilation: Sequence[int] | None = None, rhs_dilation: Sequence[int] | None = None, dimension_numbers: jax.lax.ConvGeneralDilatedDimensionNumbers = None, feature_group_count: int = 1, batch_group_count: int = 1, precision: jax.lax.PrecisionLike = None, preferred_element_type: jax.typing.DTypeLike | None = None, out_sharding=None, ) -> jax.Array: """Computes a general dilated convolution with support for ``QArray`` inputs. This function serves as a drop-in replacement for `jax.lax.conv_general_dilated <https://docs.jax.dev/en/latest/_autosummary/jax.lax.conv_general_dilated.html>`_. It automatically dispatches to a quantized implementation if inputs are compatible ``QArray``s. Otherwise, it dequantizes inputs and falls back to the standard floating-point JAX implementation. Args: lhs: The left-hand side, either a jax.Array or QArray. rhs: The right-hand side, either a jax.Array or QArray. window_strides: A sequence of integers specifying the stride of the convolution window. padding: The padding algorithm (e.g., 'SAME', 'VALID') or explicit padding amounts. lhs_dilation: Dilation factors for the input (lhs) spatial dimensions. rhs_dilation: Dilation factors for the kernel (rhs) spatial dimensions. dimension_numbers: A structure specifying the dimension layout for lhs, rhs, and the output. feature_group_count: The number of feature groups for grouped convolution. batch_group_count: The number of batch groups. precision: The numerical precision configuration for the computation. preferred_element_type: The target data type for accumulation. out_sharding: Optional sharding spec for the output array. Returns: An Array containing the convolution result. """ use_fast_path = True for x in (lhs, rhs): if isinstance(x, jax.Array) and numerics.should_quantize(x.dtype): use_fast_path = False break if use_fast_path: return _fast_conv_general_dilated( lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, batch_group_count, preferred_element_type, out_sharding, ) else: return _slow_conv_general_dilated( lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, batch_group_count, precision, preferred_element_type, out_sharding, )