From 806fa895c8ea2c3a584a5631dea6784fae66afeb Mon Sep 17 00:00:00 2001 From: DongHyun Choi Date: Thu, 6 Jun 2024 09:41:11 -0700 Subject: [PATCH] [EXPERIMENTAL] Remove leading 1s from the scale. PiperOrigin-RevId: 640928025 --- aqt/jax/v2/aqt_dot_general.py | 43 +++++++++++++++++++----- aqt/jax/v2/aqt_tensor.py | 10 ++++++ aqt/jax/v2/flax/aqt_flax.py | 63 +++++++++++++++++++++++++++++------ aqt/jax/v2/utils.py | 16 +++++++++ 4 files changed, 113 insertions(+), 19 deletions(-) diff --git a/aqt/jax/v2/aqt_dot_general.py b/aqt/jax/v2/aqt_dot_general.py index c7e16ff0..0290f655 100644 --- a/aqt/jax/v2/aqt_dot_general.py +++ b/aqt/jax/v2/aqt_dot_general.py @@ -382,8 +382,12 @@ def __call__( ) if self.rhs.use_fwd_quant: assert fwd_quantized, msg + rhs_scale = rhs.qx.scale[0] # pytype: disable=attribute-error + ndim_diff = rhs.qx.qvalue.ndim - rhs_scale.ndim # pytype: disable=attribute-error + if ndim_diff > 0: + rhs_scale = jnp.expand_dims(rhs_scale, axis=list(range(ndim_diff))) scale_t = transpose.rhs_scale_transpose_for_lhs_input( - rhs.qx.scale[0], dimension_numbers, lhs.shape # pytype: disable=attribute-error + rhs_scale, dimension_numbers, lhs.shape ) # Cast rhs scales to lhs dtype when multiplying with lhs. This is to @@ -483,6 +487,23 @@ def _postprocess_qtensor( rhs_incomplete_qt = rhs_qt rhs_qt = None + # Optimization: Remove leading 1s of rhs scales if the rhs dequant mode is + # OUTPUT. + # In this case, the rhs scales will be transposed to + # (rhs_ba, [1] * lhs_ra, rhs_ra). + # if rhs_ba is empty (which is mostly the case), then removing the leading + # 1s will not give any overhead. + # we need a validity checker for this! + # 1. rhs. 2. All ras are at the final dimension. 3. No ba + # (This could be softened). 4. OUTPUT. + # If all the above conditions are met, then we can just remove the ras, + # and we do not need to transpose. + # We do NOT need calibration axes - we only need the dimension numbers! + + if self.rhs.dequant_mode == DequantMode.OUTPUT: + if utils.is_reducable(rhs.ndim, dimension_numbers): + rhs_incomplete_qt = rhs_incomplete_qt.remove_leading_ones_from_scale() + lhs_quantized, rhs_quantized = self.dg_quantizer.calculate_qvalue( lhs, lhs_incomplete_qt, rhs, rhs_incomplete_qt ) @@ -622,14 +643,18 @@ def _maybe_dequant( out.scale.extend(extend_scale) if cfg.rhs.dequant_mode == DequantMode.OUTPUT: - extend_scale = _get_scale_t( - qt=rhs_qt, - transpose_fn=transpose.rhs_scale_transpose_to_output, - dimension_numbers=dimension_numbers, - lhs_shape=lhs_qin.shape, - rhs_shape=rhs_qin.shape, - ) - out.scale.extend(extend_scale) + if utils.is_reducable(rhs_qin.ndim, dimension_numbers): + # No need to transpose. + out.scale.extend(rhs_qt.scale) + else: + extend_scale = _get_scale_t( + qt=rhs_qt, + transpose_fn=transpose.rhs_scale_transpose_to_output, + dimension_numbers=dimension_numbers, + lhs_shape=lhs_qin.shape, + rhs_shape=rhs_qin.shape, + ) + out.scale.extend(extend_scale) return out diff --git a/aqt/jax/v2/aqt_tensor.py b/aqt/jax/v2/aqt_tensor.py index 294a1370..0c6fd498 100644 --- a/aqt/jax/v2/aqt_tensor.py +++ b/aqt/jax/v2/aqt_tensor.py @@ -127,6 +127,16 @@ def qvalue_astype(self, dtype) -> Self: assert self.is_full(), _MSG_NO_QVALUE return self.replace(qvalue=self.qvalue.astype(dtype)) # pytype: disable=attribute-error + def remove_leading_ones_from_scale(self) -> Self: + """Utility function to remove leading 1s from the scale.""" + if self.scale is None: + return self + + ret = self.replace( # pytype: disable=attribute-error + scale=[utils.remove_leading_ones(s) for s in self.scale] + ) + return ret + def __getitem__(self, idx: jax_typing.ArrayLike) -> Self: """Returns the indexed subtensor on the first axis.""" assert self.scale_t is None, 'scale_t is not supported in __getitem__' diff --git a/aqt/jax/v2/flax/aqt_flax.py b/aqt/jax/v2/flax/aqt_flax.py index ec6ad24d..73c529e5 100644 --- a/aqt/jax/v2/flax/aqt_flax.py +++ b/aqt/jax/v2/flax/aqt_flax.py @@ -141,6 +141,7 @@ def _maybe_recover_scale_from_scale_t( is_rhs: bool, lhs_shape: Sequence[int], rhs_shape: Sequence[int], + dequant_mode ) -> aqt_tensor.QTensor: """Recovers scale from scale_t if necessary.""" if qt is None or qt.scale is not None or qt.scale_t is None: @@ -148,6 +149,11 @@ def _maybe_recover_scale_from_scale_t( transpose_fn = transpose.lhs_recover_scale_from_scale_t if is_rhs: + if dequant_mode == aqt_dot_general.DequantMode.OUTPUT: + if utils.is_reducable(len(rhs_shape), dimension_numbers): + return qt.replace( + scale=[scale_t for scale_t in qt.scale_t], scale_t=None + ) transpose_fn = transpose.rhs_recover_scale_from_scale_t return qt.replace( @@ -165,6 +171,7 @@ def _populate_scale_t( is_rhs: bool, lhs_shape: Sequence[int], rhs_shape: Sequence[int], + dequant_mode ) -> aqt_tensor.QTensor: """Populates scale_t from scale.""" if qt.scale is None: @@ -172,6 +179,9 @@ def _populate_scale_t( transpose_fn = transpose.lhs_scale_transpose_to_output if is_rhs: + if dequant_mode == aqt_dot_general.DequantMode.OUTPUT: + if utils.is_reducable(len(rhs_shape), dimension_numbers): + return qt.replace(scale_t=[scale for scale in qt.scale]) transpose_fn = transpose.rhs_scale_transpose_to_output return qt.replace( @@ -234,7 +244,7 @@ def make_aqt_dg( self, lhs_shape, rhs_shape, - dimension_numbers: tuple[Iterable[int], Iterable[int]], + dimension_numbers ): if self.cfg is None: return jax.lax.dot_general @@ -251,11 +261,24 @@ def make_aqt_dg( ) assert lhs_scale is not None lhs_scale_shape = lhs_scale.shape - rhs_scale = transpose.rhs_scale_transpose_to_output( - jnp.zeros(rhs_scale_shape), dimension_numbers, lhs_shape, rhs_shape - ) - assert rhs_scale is not None - rhs_scale_shape = rhs_scale.shape + + if cfg.fwd.rhs.dequant_mode == aqt_dot_general.DequantMode.OUTPUT: + if utils.is_reducable(len(rhs_shape), dimension_numbers): + rhs_scale_shape = utils.remove_leading_ones( + jnp.zeros(rhs_scale_shape) + ).shape + else: + rhs_scale = transpose.rhs_scale_transpose_to_output( + jnp.zeros(rhs_scale_shape), dimension_numbers, lhs_shape, rhs_shape + ) + assert rhs_scale is not None + rhs_scale_shape = rhs_scale.shape + else: + rhs_scale = transpose.rhs_scale_transpose_to_output( + jnp.zeros(rhs_scale_shape), dimension_numbers, lhs_shape, rhs_shape + ) + assert rhs_scale is not None + rhs_scale_shape = rhs_scale.shape rhs_qm = self.rhs_quant_mode lhs_qm = self.lhs_quant_mode @@ -390,12 +413,22 @@ def ret_dg( # Recover scale from scale_t, if necessary. # The quantized tensor loaded from the legacy freezer has only scale_t. lhs_qt = _maybe_recover_scale_from_scale_t( - lhs_qt, dimension_numbers, False, lhs_shape, rhs_shape + lhs_qt, dimension_numbers, False, lhs_shape, rhs_shape, + cfg.fwd.lhs.dequant_mode ) rhs_qt = _maybe_recover_scale_from_scale_t( - rhs_qt, dimension_numbers, True, lhs_shape, rhs_shape + rhs_qt, dimension_numbers, True, lhs_shape, rhs_shape, + cfg.fwd.rhs.dequant_mode ) + # Optimize rhs scale. + if ( + cfg.fwd.rhs.dequant_mode == aqt_dot_general.DequantMode.OUTPUT + and rhs_qt + ): + if utils.is_reducable(rhs.ndim, dimension_numbers): + rhs_qt = rhs_qt.remove_leading_ones_from_scale() + cfg.apply_custom_vjp_on_jax = False out, (out_lhs_qt, out_rhs_qt) = aqt_flax_dg_core.dg_core_flax_lifted( lhs, rhs, lhs_qt, rhs_qt, dimension_numbers, self, cfg @@ -436,11 +469,21 @@ def ret_dg( # We need to populate the stored QTensor with scale_t. if cfg.fwd.lhs.calibration_mode == calib_contracting_axis: out_lhs_qt = _populate_scale_t( - out_lhs_qt, dimension_numbers, False, lhs_shape, rhs_shape + out_lhs_qt, + dimension_numbers, + False, + lhs_shape, + rhs_shape, + cfg.fwd.lhs.dequant_mode, ) if cfg.fwd.rhs.calibration_mode == calib_contracting_axis: out_rhs_qt = _populate_scale_t( - out_rhs_qt, dimension_numbers, True, lhs_shape, rhs_shape + out_rhs_qt, + dimension_numbers, + True, + lhs_shape, + rhs_shape, + cfg.fwd.rhs.dequant_mode, ) if self.lhs_apply_quant_mode: diff --git a/aqt/jax/v2/utils.py b/aqt/jax/v2/utils.py index 6757a8d4..ea58abdf 100644 --- a/aqt/jax/v2/utils.py +++ b/aqt/jax/v2/utils.py @@ -143,6 +143,22 @@ def get_remaining_axes( return ret +def is_reducable(rhs_ndim, dimension_numbers): + + (_, rhs_ca), (_, rhs_ba) = dimension_numbers + rhs_ra = get_remaining_axes(rhs_ndim, rhs_ca, rhs_ba) + return rhs_ra and not rhs_ba and min(rhs_ra) == len(rhs_ca) + + +def remove_leading_ones(array: jnp.ndarray): + squeeze_axes = [] + for i, dim in enumerate(array.shape): + if dim != 1: + break + squeeze_axes.append(i) + return jnp.squeeze(array, axis=tuple(squeeze_axes)) + + @flax_slots_dataclass class Context: key: jax.Array | None = dynamic_field()