Source code for qwix._src.providers.qt

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Quantized training (QT) support."""

import dataclasses
import functools
from typing import Any, Callable, Mapping, Sequence

import jax
from jax import numpy as jnp
from qwix._src import averaging
from qwix._src import flax_util
from qwix._src import qconfig
from qwix._src.core import conv_general_qt
from qwix._src.core import dot_general_qt
from qwix._src.core import ragged_dot_qt
from qwix._src.core import stochastic_rounding


@dataclasses.dataclass(frozen=True, kw_only=True)
class QtRule(qconfig.QuantizationRule):
  """QuantizationRule with all settings specific to Quantized Training (QT).

  Attributes:
    bwd_qtype: Type the gradients are quantized to in the backward pass. This
      doesn't affect the residuals, which reuse the quantization from the
      forward pass.
    bwd_calibration_method: Method used to calibrate the gradients in the
      backward pass.
    bwd_weight_grad_tile_size: Enables subchannel for contraction axes when
      calculating the gradient of weights in the backward pass. Note that the
      tiling is actually applied to the incoming gradient and the residual
      activation rather than to any "weight".
    disable_channelwise_axes: If True, disable channelwise axes for both the
      forward and the backward passes.
    bwd_stochastic_rounding: Stochastic rounding method used for the
      gradients. Currently supports "uniform" and "low_bit_uniform".
    channelwise_noise_axes: Axes that receive channelwise noise for stochastic
      rounding. By default, noise is generated for the 0th dimension and
      broadcast over the remaining dimensions.
    additional_qt_config: Overrides for any fields in DotGeneralQtConfig or
      ConvGeneralQtConfig. This is highly experimental and subject to change
      with no backward compatibility guarantees.
  """

  bwd_qtype: jax.typing.DTypeLike | None = None
  bwd_calibration_method: str = 'absmax'
  bwd_weight_grad_tile_size: int | float | None = None
  disable_channelwise_axes: bool = False
  bwd_stochastic_rounding: str | None = None
  channelwise_noise_axes: Sequence[int] = (0,)
  additional_qt_config: Mapping[str, Any] | None = None
class QtProvider(qconfig.QuantizationProvider):
  """Quantization provider for Quantized Training (QT).

  The provider supplies QT replacements for `jax.lax.dot_general`,
  `jax.numpy.einsum`, `jax.lax.conv_general_dilated` and `jax.lax.ragged_dot`
  (see `get_intercept_map`). Each replacement forwards to the original JAX
  function when no rule applies to the current op or when the rule does not set
  `weight_qtype`.
  """

  def _init_rule(self, rule: qconfig.QuantizationRule) -> QtRule:
    """Initializes a rule and upgrades it to a `QtRule` if necessary.

    Args:
      rule: The rule to initialize.

    Returns:
      The rule returned by the parent `_init_rule`, converted to a `QtRule`
      (with QT-specific fields at their defaults) if it was not one already.
    """
    rule = super()._init_rule(rule)
    if not isinstance(rule, QtRule):
      # Copy the fields shallowly instead of using dataclasses.asdict(), which
      # recurses into nested dataclasses/containers and deep-copies every
      # value, and would thus alter structured field values. Fields declared
      # with init=False cannot be passed to the constructor, so skip them.
      rule = QtRule(**{
          field.name: getattr(rule, field.name)
          for field in dataclasses.fields(rule)
          if field.init
      })
    return rule

  def dot_general(
      self,
      lhs: jax.Array,
      rhs: jax.Array,
      dimension_numbers: jax.lax.DotDimensionNumbers,
      precision: jax.lax.PrecisionLike = None,
      preferred_element_type: jax.typing.DTypeLike | None = None,
      *,
      out_sharding=None,
  ) -> jax.Array:
    """QT dot_general.

    Args:
      lhs: The left-hand side operand.
      rhs: The right-hand side operand.
      dimension_numbers: Forwarded to the underlying dot_general.
      precision: Only used when falling back to `jax.lax.dot_general`.
      preferred_element_type: Only used when falling back to
        `jax.lax.dot_general`.
      out_sharding: Only used when falling back to `jax.lax.dot_general`.

    Returns:
      The result of `dot_general_qt.dot_general_qt` when the current rule sets
      `weight_qtype`, otherwise the result of `jax.lax.dot_general`.
    """
    rule, op_id = self._get_current_rule_and_op_id('dot_general')
    if rule is None or rule.weight_qtype is None:
      return jax.lax.dot_general(
          lhs,
          rhs,
          dimension_numbers,
          precision=precision,
          preferred_element_type=preferred_element_type,
          out_sharding=out_sharding,
      )
    config = self._create_dot_general_qt_config(rule, op_id, lhs, rhs)
    return dot_general_qt.dot_general_qt(lhs, rhs, dimension_numbers, config)

  def einsum(
      self,
      einsum_str: str,
      *operands: jax.Array,
      precision: jax.lax.PrecisionLike = None,
      preferred_element_type: jax.typing.DTypeLike | None = None,
      _dot_general: Callable[..., jax.Array] = jax.lax.dot_general,  # pylint: disable=invalid-name
      out_sharding=None,
  ) -> jax.Array:
    """QT einsum.

    Args:
      einsum_str: The einsum subscript string.
      *operands: The einsum operands. Exactly two are required on the quantized
        path.
      precision: Forwarded to `jnp.einsum`.
      preferred_element_type: Forwarded to `jnp.einsum`. It is dropped by the
        quantized dot_general (see the TODO below).
      _dot_general: The dot_general implementation used on the fallback path.
        On the quantized path it is replaced by a QT dot_general.
      out_sharding: Forwarded to `jnp.einsum`.

    Returns:
      The einsum result.

    Raises:
      ValueError: If the op is quantized and `einsum_str` is not a string or
        the number of operands is not two.
    """
    rule, op_id = self._get_current_rule_and_op_id('einsum')
    if rule is None or rule.weight_qtype is None:
      return jnp.einsum(
          einsum_str,
          *operands,
          precision=precision,
          preferred_element_type=preferred_element_type,
          _dot_general=_dot_general,
          out_sharding=out_sharding,
      )
    if not isinstance(einsum_str, str) or len(operands) != 2:
      raise ValueError(f'Unsupported einsum format: {einsum_str=} {operands=}')

    def custom_dot_general(
        lhs,
        rhs,
        dimension_numbers,
        precision,
        preferred_element_type,
        **kwargs,
    ):
      # TODO(dangyi): support preferred_element_type.
      del precision, preferred_element_type, kwargs
      return dot_general_qt.dot_general_qt(
          lhs,
          rhs,
          dimension_numbers,
          # lhs and rhs might be flipped by einsum so we cannot use the operands
          # from the einsum call.
          self._create_dot_general_qt_config(rule, op_id, lhs, rhs),
      )

    # NOTE(review): presumably jit is disabled so that custom_dot_general runs
    # outside of any jit that jnp.einsum applies internally -- confirm.
    with jax.disable_jit():
      return jnp.einsum(
          einsum_str,
          *operands,
          precision=precision,
          preferred_element_type=preferred_element_type,
          _dot_general=custom_dot_general,
          out_sharding=out_sharding,
      )

  def conv_general_dilated(
      self,
      lhs: jax.Array,
      rhs: jax.Array,
      window_strides: Sequence[int],
      padding: str | Sequence[tuple[int, int]],
      lhs_dilation: Sequence[int] | None = None,
      rhs_dilation: Sequence[int] | None = None,
      dimension_numbers: jax.lax.ConvGeneralDilatedDimensionNumbers = None,
      feature_group_count: int = 1,
      batch_group_count: int = 1,
      precision: jax.lax.PrecisionLike = None,
      preferred_element_type: jax.typing.DTypeLike | None = None,
      out_sharding=None,
  ) -> jax.Array:
    """QT conv_general_dilated.

    Args:
      lhs: The left-hand side operand, assumed to be an activation.
      rhs: The right-hand side operand, assumed to be a weight.
      window_strides: Forwarded to the underlying convolution.
      padding: Forwarded to the underlying convolution.
      lhs_dilation: Forwarded to the underlying convolution.
      rhs_dilation: Forwarded to the underlying convolution.
      dimension_numbers: Forwarded to the underlying convolution.
      feature_group_count: Forwarded to the underlying convolution.
      batch_group_count: Forwarded to the underlying convolution.
      precision: Only used when falling back to `jax.lax.conv_general_dilated`.
      preferred_element_type: Only used when falling back to
        `jax.lax.conv_general_dilated`.
      out_sharding: Forwarded to the underlying convolution.

    Returns:
      The result of `conv_general_qt.conv_general_qt` when the current rule
      sets `weight_qtype`, otherwise the result of
      `jax.lax.conv_general_dilated`.

    Raises:
      ValueError: If the op is quantized and the rule sets `tile_size`.
    """
    rule, op_id = self._get_current_rule_and_op_id('conv_general_dilated')
    if rule is None or rule.weight_qtype is None:
      return jax.lax.conv_general_dilated(
          lhs,
          rhs,
          window_strides,
          padding,
          lhs_dilation=lhs_dilation,
          rhs_dilation=rhs_dilation,
          dimension_numbers=dimension_numbers,
          feature_group_count=feature_group_count,
          batch_group_count=batch_group_count,
          precision=precision,
          preferred_element_type=preferred_element_type,
          out_sharding=out_sharding,
      )
    if rule.tile_size:
      raise ValueError('subchannel is not supported for conv_general_dilated.')
    config = self._create_conv_general_qt_config(rule, op_id, lhs, rhs)
    return conv_general_qt.conv_general_qt(
        lhs,
        rhs,
        config,
        window_strides,
        padding,
        lhs_dilation,
        rhs_dilation,
        dimension_numbers,
        feature_group_count,
        batch_group_count,
        out_sharding,
    )

  def ragged_dot(
      self,
      lhs: jax.Array,
      rhs: jax.Array,
      group_sizes: jax.Array,
      precision: jax.lax.PrecisionLike = None,
      preferred_element_type: jax.typing.DTypeLike | None = None,
      group_offset: jax.Array | None = None,
  ) -> jax.Array:
    """QT ragged_dot.

    Args:
      lhs: The left-hand side operand, assumed to be an activation.
      rhs: The right-hand side operand, assumed to be a weight.
      group_sizes: Forwarded to the underlying ragged_dot.
      precision: Forwarded to the underlying ragged_dot.
      preferred_element_type: Forwarded to the underlying ragged_dot.
      group_offset: Forwarded to the underlying ragged_dot.

    Returns:
      The result of `ragged_dot_qt.ragged_dot_qt` when the current rule sets
      `weight_qtype`, otherwise the result of `jax.lax.ragged_dot`.
    """
    # The op id is not needed because no quant stats are collected here.
    rule, _ = self._get_current_rule_and_op_id('ragged_dot')
    if rule is None or rule.weight_qtype is None:
      return jax.lax.ragged_dot(
          lhs,
          rhs,
          group_sizes,
          precision=precision,
          preferred_element_type=preferred_element_type,
          group_offset=group_offset,
      )
    config = self._create_ragged_dot_qt_config(rule)
    return ragged_dot_qt.ragged_dot_qt(
        lhs,
        rhs,
        group_sizes,
        config,
        precision,
        preferred_element_type,
        group_offset,
    )

  def get_intercept_map(self):
    """Used for interception.

    Returns:
      The parent's intercept map extended with the QT replacements for
      `jax.lax.conv_general_dilated`, `jax.lax.dot_general`,
      `jax.numpy.einsum` and `jax.lax.ragged_dot`.
    """
    return super().get_intercept_map() | {
        'jax.lax.conv_general_dilated': self.conv_general_dilated,
        'jax.lax.dot_general': self.dot_general,
        'jax.numpy.einsum': self.einsum,
        'jax.lax.ragged_dot': self.ragged_dot,
    }

  def _collect_quant_stat(
      self,
      name: str,
      batch_axes: tuple[int, ...],
      calibration: averaging.Calibration,
  ) -> averaging.Calibration:
    """Collects the quantization statistics.

    Args:
      name: Name of the variable in the 'quant_stats' collection.
      batch_axes: Axes over which the calibration values are averaged.
      calibration: The calibration computed for the current input.

    Returns:
      The calibration obtained from the aggregated statistics.
    """
    # Calculate the mean over the batch axes.
    calibration = jax.tree.map(
        lambda x: x.mean(axis=batch_axes, keepdims=True), calibration
    )

    aggregator = averaging.SimpleMovingAverage()
    quant_stat = flax_util.get_or_create_variable(
        'quant_stats', name, lambda: aggregator.init(calibration)
    )

    if flax_util.should_update_quant_stats():
      quant_stat.value = aggregator.update(quant_stat.value, calibration)

    return aggregator.get_calibration(quant_stat.value, calibration)

  def _create_conv_general_qt_config(
      self,
      rule: qconfig.QuantizationRule,
      op_id: str,
      lhs: jax.Array,
      rhs: jax.Array,
  ) -> conv_general_qt.ConvGeneralQtConfig:
    """Creates a ConvGeneralQtConfig for conv_general_dilated.

    Args:
      rule: The rule for the current op. Must be a `QtRule`.
      op_id: Id of the current op, used to name the collected quant stats.
      lhs: The left-hand side operand. Unused, as it is assumed to be an
        activation.
      rhs: The right-hand side operand. Must be a param.

    Returns:
      The config for `conv_general_qt.conv_general_qt`.
    """
    assert isinstance(rule, QtRule), '_init_rule should have been called.'

    # Assume LHS is an activation and RHS is a weight.
    del lhs
    lhs_collect_quant_stat = None
    if rule.act_qtype is not None and rule.act_static_scale:
      lhs_collect_quant_stat = functools.partial(
          self._collect_quant_stat, f'{op_id}_lhs', rule.act_batch_axes
      )
    assert flax_util.find_param(rhs) is not None

    # NOTE(review): rule.additional_qt_config is not applied here, unlike in
    # _create_dot_general_qt_config, although the comment on that field also
    # mentions ConvGeneralQtConfig -- confirm whether this is intended.
    return conv_general_qt.ConvGeneralQtConfig(
        # fwd configs.
        lhs_qtype=rule.act_qtype,
        rhs_qtype=rule.weight_qtype,
        lhs_calibration_method=rule.act_calibration_method,
        rhs_calibration_method=rule.weight_calibration_method,
        lhs_collect_quant_stat=lhs_collect_quant_stat,
        rhs_collect_quant_stat=None,
        lhs_disable_channelwise_axes=rule.disable_channelwise_axes,
        rhs_disable_channelwise_axes=rule.disable_channelwise_axes,
        # bwd configs.
        dlhs_grad_qtype=rule.bwd_qtype,
        drhs_grad_qtype=rule.bwd_qtype,
        dlhs_grad_calibration_method=rule.bwd_calibration_method,
        drhs_grad_calibration_method=rule.bwd_calibration_method,
        dlhs_grad_disable_channelwise_axes=rule.disable_channelwise_axes,
        drhs_grad_disable_channelwise_axes=rule.disable_channelwise_axes,
    )

  def _create_dot_general_qt_config(
      self,
      rule: qconfig.QuantizationRule,
      op_id: str,
      lhs: jax.Array,
      rhs: jax.Array,
  ) -> dot_general_qt.DotGeneralQtConfig:
    """Creates a DotGeneralQtConfig for dot_general and einsum.

    An operand is treated as a weight if `flax_util.find_param` finds a param
    for it, and as an activation otherwise. At most one operand may be a
    weight.

    Args:
      rule: The rule for the current op. Must be a `QtRule`.
      op_id: Id of the current op, used to name the collected quant stats.
      lhs: The left-hand side operand.
      rhs: The right-hand side operand.

    Returns:
      The config for `dot_general_qt.dot_general_qt`, with any fields in
      `rule.additional_qt_config` overriding the derived values.
    """
    assert isinstance(rule, QtRule), '_init_rule should have been called.'

    # LHS configs based on whether it's a weight or an activation.
    lhs_qtype = None
    lhs_calibration_method = None
    lhs_is_weight = flax_util.find_param(lhs) is not None
    lhs_collect_quant_stat = None

    if lhs_is_weight:
      if rule.weight_qtype is not None:
        lhs_qtype = rule.weight_qtype
        lhs_calibration_method = rule.weight_calibration_method
    elif rule.act_qtype is not None:
      lhs_qtype = rule.act_qtype
      lhs_calibration_method = rule.act_calibration_method
      if rule.act_static_scale:
        lhs_collect_quant_stat = functools.partial(
            self._collect_quant_stat, f'{op_id}_lhs', rule.act_batch_axes
        )

    # RHS configs based on whether it's a weight or an activation.
    rhs_qtype = None
    rhs_calibration_method = None
    rhs_is_weight = flax_util.find_param(rhs) is not None
    rhs_collect_quant_stat = None

    if rhs_is_weight:
      assert not lhs_is_weight, 'lhs and rhs cannot be both weights.'
      if rule.weight_qtype is not None:
        rhs_qtype = rule.weight_qtype
        rhs_calibration_method = rule.weight_calibration_method
    elif rule.act_qtype is not None:
      rhs_qtype = rule.act_qtype
      rhs_calibration_method = rule.act_calibration_method
      if rule.act_static_scale:
        rhs_collect_quant_stat = functools.partial(
            self._collect_quant_stat, f'{op_id}_rhs', rule.act_batch_axes
        )

    # bwd config, which is only enabled when bwd_qtype is set.
    dlhs_tile_size = None
    drhs_tile_size = None
    bwd_stochastic_rounding_noise_fn = None
    if rule.bwd_qtype is not None:
      # The weight-gradient tile size only applies to the gradient of the
      # operand that is a weight.
      if lhs_is_weight:
        dlhs_tile_size = rule.bwd_weight_grad_tile_size
      if rhs_is_weight:
        drhs_tile_size = rule.bwd_weight_grad_tile_size
      if rule.bwd_stochastic_rounding is not None:
        bwd_stochastic_rounding_noise_fn = stochastic_rounding.get_noise_fn(
            method=rule.bwd_stochastic_rounding,
            key=flax_util.make_rng('stochastic_rounding'),
            channelwise_noise_axes=rule.channelwise_noise_axes,
        )

    qt_config = dot_general_qt.DotGeneralQtConfig(
        # fwd configs.
        lhs_qtype=lhs_qtype,
        rhs_qtype=rhs_qtype,
        tile_size=rule.tile_size,
        lhs_calibration_method=lhs_calibration_method,
        rhs_calibration_method=rhs_calibration_method,
        lhs_collect_quant_stat=lhs_collect_quant_stat,
        rhs_collect_quant_stat=rhs_collect_quant_stat,
        lhs_disable_channelwise_axes=rule.disable_channelwise_axes,
        rhs_disable_channelwise_axes=rule.disable_channelwise_axes,
        # dlhs configs.
        dlhs_grad_qtype=rule.bwd_qtype,
        dlhs_grad_calibration_method=rule.bwd_calibration_method,
        dlhs_tile_size=dlhs_tile_size,
        dlhs_stochastic_rounding_noise_fn=bwd_stochastic_rounding_noise_fn,
        dlhs_grad_disable_channelwise_axes=rule.disable_channelwise_axes,
        # drhs configs.
        drhs_grad_qtype=rule.bwd_qtype,
        drhs_grad_calibration_method=rule.bwd_calibration_method,
        drhs_tile_size=drhs_tile_size,
        drhs_stochastic_rounding_noise_fn=bwd_stochastic_rounding_noise_fn,
        drhs_grad_disable_channelwise_axes=rule.disable_channelwise_axes,
    )
    # User-provided overrides take precedence over the derived values.
    if rule.additional_qt_config:
      qt_config = dataclasses.replace(qt_config, **rule.additional_qt_config)
    return qt_config

  def _create_ragged_dot_qt_config(
      self,
      rule: qconfig.QuantizationRule,
  ) -> ragged_dot_qt.RaggedDotQtConfig:
    """Creates a RaggedDotQtConfig for ragged_dot.

    Args:
      rule: The rule for the current op. Must be a `QtRule`.

    Returns:
      The config for `ragged_dot_qt.ragged_dot_qt`.
    """
    assert isinstance(rule, QtRule), '_init_rule should have been called.'

    # Assume LHS is an activation and RHS is a weight.
    return ragged_dot_qt.RaggedDotQtConfig(
        # fwd configs.
        lhs_qtype=rule.act_qtype,
        rhs_qtype=rule.weight_qtype,
        # bwd configs.
        dlhs_grad_qtype=rule.bwd_qtype,
        drhs_grad_qtype=rule.bwd_qtype,
    )