Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions aqt/jax/v2/aqt_dot_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,12 @@ def __call__(
)
if self.rhs.use_fwd_quant:
assert fwd_quantized, msg
rhs_scale = rhs.qx.scale[0] # pytype: disable=attribute-error
ndim_diff = rhs.qx.qvalue.ndim - rhs_scale.ndim # pytype: disable=attribute-error
if ndim_diff > 0:
rhs_scale = jnp.expand_dims(rhs_scale, axis=list(range(ndim_diff)))
scale_t = transpose.rhs_scale_transpose_for_lhs_input(
rhs.qx.scale[0], dimension_numbers, lhs.shape # pytype: disable=attribute-error
rhs_scale, dimension_numbers, lhs.shape
)

# Cast rhs scales to lhs dtype when multiplying with lhs. This is to
Expand Down Expand Up @@ -483,6 +487,23 @@ def _postprocess_qtensor(
rhs_incomplete_qt = rhs_qt
rhs_qt = None

# Optimization: Remove leading 1s of rhs scales if the rhs dequant mode is
# OUTPUT.
# In this case, the rhs scales will be transposed to
# (rhs_ba, [1] * lhs_ra, rhs_ra).
# if rhs_ba is empty (which is mostly the case), then removing the leading
# 1s will not give any overhead.
# we need a validity checker for this!
# 1. rhs. 2. All ras are at the final dimension. 3. No ba
# (This could be softened). 4. OUTPUT.
# If all the above conditions are met, then we can just remove the ras,
# and we do not need to transpose.
# We do NOT need calibration axes - we only need the dimension numbers!

if self.rhs.dequant_mode == DequantMode.OUTPUT:
if utils.is_reducable(rhs.ndim, dimension_numbers):
rhs_incomplete_qt = rhs_incomplete_qt.remove_leading_ones_from_scale()

lhs_quantized, rhs_quantized = self.dg_quantizer.calculate_qvalue(
lhs, lhs_incomplete_qt, rhs, rhs_incomplete_qt
)
Expand Down Expand Up @@ -622,14 +643,18 @@ def _maybe_dequant(

out.scale.extend(extend_scale)
if cfg.rhs.dequant_mode == DequantMode.OUTPUT:
extend_scale = _get_scale_t(
qt=rhs_qt,
transpose_fn=transpose.rhs_scale_transpose_to_output,
dimension_numbers=dimension_numbers,
lhs_shape=lhs_qin.shape,
rhs_shape=rhs_qin.shape,
)
out.scale.extend(extend_scale)
if utils.is_reducable(rhs_qin.ndim, dimension_numbers):
# No need to transpose.
out.scale.extend(rhs_qt.scale)
else:
extend_scale = _get_scale_t(
qt=rhs_qt,
transpose_fn=transpose.rhs_scale_transpose_to_output,
dimension_numbers=dimension_numbers,
lhs_shape=lhs_qin.shape,
rhs_shape=rhs_qin.shape,
)
out.scale.extend(extend_scale)
return out


Expand Down
10 changes: 10 additions & 0 deletions aqt/jax/v2/aqt_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ def qvalue_astype(self, dtype) -> Self:
assert self.is_full(), _MSG_NO_QVALUE
return self.replace(qvalue=self.qvalue.astype(dtype)) # pytype: disable=attribute-error

def remove_leading_ones_from_scale(self) -> Self:
"""Utility function to remove leading 1s from the scale."""
if self.scale is None:
return self

ret = self.replace( # pytype: disable=attribute-error
scale=[utils.remove_leading_ones(s) for s in self.scale]
)
return ret

def __getitem__(self, idx: jax_typing.ArrayLike) -> Self:
"""Returns the indexed subtensor on the first axis."""
assert self.scale_t is None, 'scale_t is not supported in __getitem__'
Expand Down
63 changes: 53 additions & 10 deletions aqt/jax/v2/flax/aqt_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,19 @@ def _maybe_recover_scale_from_scale_t(
is_rhs: bool,
lhs_shape: Sequence[int],
rhs_shape: Sequence[int],
dequant_mode
) -> aqt_tensor.QTensor:
"""Recovers scale from scale_t if necessary."""
if qt is None or qt.scale is not None or qt.scale_t is None:
return qt

transpose_fn = transpose.lhs_recover_scale_from_scale_t
if is_rhs:
if dequant_mode == aqt_dot_general.DequantMode.OUTPUT:
if utils.is_reducable(len(rhs_shape), dimension_numbers):
return qt.replace(
scale=[scale_t for scale_t in qt.scale_t], scale_t=None
)
transpose_fn = transpose.rhs_recover_scale_from_scale_t

return qt.replace(
Expand All @@ -165,13 +171,17 @@ def _populate_scale_t(
is_rhs: bool,
lhs_shape: Sequence[int],
rhs_shape: Sequence[int],
dequant_mode
) -> aqt_tensor.QTensor:
"""Populates scale_t from scale."""
if qt.scale is None:
return qt

transpose_fn = transpose.lhs_scale_transpose_to_output
if is_rhs:
if dequant_mode == aqt_dot_general.DequantMode.OUTPUT:
if utils.is_reducable(len(rhs_shape), dimension_numbers):
return qt.replace(scale_t=[scale for scale in qt.scale])
transpose_fn = transpose.rhs_scale_transpose_to_output

return qt.replace(
Expand Down Expand Up @@ -234,7 +244,7 @@ def make_aqt_dg(
self,
lhs_shape,
rhs_shape,
dimension_numbers: tuple[Iterable[int], Iterable[int]],
dimension_numbers
):
if self.cfg is None:
return jax.lax.dot_general
Expand All @@ -251,11 +261,24 @@ def make_aqt_dg(
)
assert lhs_scale is not None
lhs_scale_shape = lhs_scale.shape
rhs_scale = transpose.rhs_scale_transpose_to_output(
jnp.zeros(rhs_scale_shape), dimension_numbers, lhs_shape, rhs_shape
)
assert rhs_scale is not None
rhs_scale_shape = rhs_scale.shape

if cfg.fwd.rhs.dequant_mode == aqt_dot_general.DequantMode.OUTPUT:
if utils.is_reducable(len(rhs_shape), dimension_numbers):
rhs_scale_shape = utils.remove_leading_ones(
jnp.zeros(rhs_scale_shape)
).shape
else:
rhs_scale = transpose.rhs_scale_transpose_to_output(
jnp.zeros(rhs_scale_shape), dimension_numbers, lhs_shape, rhs_shape
)
assert rhs_scale is not None
rhs_scale_shape = rhs_scale.shape
else:
rhs_scale = transpose.rhs_scale_transpose_to_output(
jnp.zeros(rhs_scale_shape), dimension_numbers, lhs_shape, rhs_shape
)
assert rhs_scale is not None
rhs_scale_shape = rhs_scale.shape
rhs_qm = self.rhs_quant_mode
lhs_qm = self.lhs_quant_mode

Expand Down Expand Up @@ -390,12 +413,22 @@ def ret_dg(
# Recover scale from scale_t, if necessary.
# The quantized tensor loaded from the legacy freezer has only scale_t.
lhs_qt = _maybe_recover_scale_from_scale_t(
lhs_qt, dimension_numbers, False, lhs_shape, rhs_shape
lhs_qt, dimension_numbers, False, lhs_shape, rhs_shape,
cfg.fwd.lhs.dequant_mode
)
rhs_qt = _maybe_recover_scale_from_scale_t(
rhs_qt, dimension_numbers, True, lhs_shape, rhs_shape
rhs_qt, dimension_numbers, True, lhs_shape, rhs_shape,
cfg.fwd.rhs.dequant_mode
)

# Optimize rhs scale.
if (
cfg.fwd.rhs.dequant_mode == aqt_dot_general.DequantMode.OUTPUT
and rhs_qt
):
if utils.is_reducable(rhs.ndim, dimension_numbers):
rhs_qt = rhs_qt.remove_leading_ones_from_scale()

cfg.apply_custom_vjp_on_jax = False
out, (out_lhs_qt, out_rhs_qt) = aqt_flax_dg_core.dg_core_flax_lifted(
lhs, rhs, lhs_qt, rhs_qt, dimension_numbers, self, cfg
Expand Down Expand Up @@ -436,11 +469,21 @@ def ret_dg(
# We need to populate the stored QTensor with scale_t.
if cfg.fwd.lhs.calibration_mode == calib_contracting_axis:
out_lhs_qt = _populate_scale_t(
out_lhs_qt, dimension_numbers, False, lhs_shape, rhs_shape
out_lhs_qt,
dimension_numbers,
False,
lhs_shape,
rhs_shape,
cfg.fwd.lhs.dequant_mode,
)
if cfg.fwd.rhs.calibration_mode == calib_contracting_axis:
out_rhs_qt = _populate_scale_t(
out_rhs_qt, dimension_numbers, True, lhs_shape, rhs_shape
out_rhs_qt,
dimension_numbers,
True,
lhs_shape,
rhs_shape,
cfg.fwd.rhs.dequant_mode,
)

if self.lhs_apply_quant_mode:
Expand Down
16 changes: 16 additions & 0 deletions aqt/jax/v2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,22 @@ def get_remaining_axes(
return ret


def is_reducable(rhs_ndim, dimension_numbers):

(_, rhs_ca), (_, rhs_ba) = dimension_numbers
rhs_ra = get_remaining_axes(rhs_ndim, rhs_ca, rhs_ba)
return rhs_ra and not rhs_ba and min(rhs_ra) == len(rhs_ca)


def remove_leading_ones(array: jnp.ndarray):
squeeze_axes = []
for i, dim in enumerate(array.shape):
if dim != 1:
break
squeeze_axes.append(i)
return jnp.squeeze(array, axis=tuple(squeeze_axes))


@flax_slots_dataclass
class Context:
key: jax.Array | None = dynamic_field()
Expand Down