Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/maxtext/configs/models/deepseek4-284b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ num_experts_per_tok: 6
mlp_activations_limit: 10
shared_experts: 1
routed_score_func: "sqrtsoftplus"
norm_topk_prob: true
routed_bias: true
routed_scaling_factor: 1.5


# --- Attention configuration ---
attention_type: 'compressed'
Expand All @@ -62,3 +66,4 @@ rope_type: "default"
rope_max_timescale: 10000 # Main RoPE theta
compressed_rope_max_timescale: 160000 # Compressed RoPE theta
max_position_embeddings: 1048576
original_max_position_embeddings: 65536
91 changes: 79 additions & 12 deletions src/maxtext/layers/mhc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from jax.sharding import Mesh
from maxtext.common.common_types import Array, Config
from maxtext.common.common_types import HyperConnectionType
from maxtext.layers.initializers import default_bias_init, default_scalar_init, nd_dense_init
from maxtext.layers.initializers import default_bias_init, default_scalar_init, nd_dense_init, variable_to_logically_partitioned
from maxtext.layers import nnx_wrappers
from maxtext.layers.normalizations import RMSNorm


Expand Down Expand Up @@ -61,21 +62,18 @@ def sinkhorn(t, iters=20):
# Use float32 precision for numerical stability during normalization
initial_dtype = t.dtype
t = t.astype(jnp.float32)
eps = 1e-5

# Column-wise normalization (axis=-2) - positive and sum up to 1 across columns
# Equivalent to t = exp(t) / jnp.sum(jnp.exp(t), axis=-2)
t = jax.nn.softmax(t, axis=-2)
t = jax.nn.softmax(t, axis=-1) + eps
t = t / (jnp.sum(t, axis=-2, keepdims=True) + eps)

def body_fun(i, val):
# L1 Normalization: val / sum(val) with clipping of denominator
# Normalize rows (axis -1)
val = val / jnp.clip(jnp.sum(val, axis=-1, keepdims=True), min=1e-12)
# Normalize columns (axis -2)
val = val / jnp.clip(jnp.sum(val, axis=-2, keepdims=True), min=1e-12)
val = val / (jnp.sum(val, axis=-1, keepdims=True) + eps)
val = val / (jnp.sum(val, axis=-2, keepdims=True) + eps)
return val

# Use lax.fori_loop for an efficient, JIT-friendly loop
t = jax.lax.fori_loop(0, iters, body_fun, t)
t = jax.lax.fori_loop(0, iters - 1, body_fun, t)
return t.astype(initial_dtype)


Expand Down Expand Up @@ -224,7 +222,7 @@ def res_mapping(self, x: Array):
output = sinkhorn(intermediate, self.sinkhorn_iterations)
return output

def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: int):
def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: float, eps: float = 0.0):
"""Helper function for both pre and post mappings."""
# In MaxText, we match weight precision to activations before Matmul
alpha = jnp.asarray(alpha, self.dtype)
Expand All @@ -233,7 +231,7 @@ def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale
# Apply projection: (b, s, k*d) @ (k*d, k) -> (b, s, k)
h = jnp.einsum("bsm,mk -> bsk", x, alpha, precision=self.matmul_precision)
intermediate = alpha_scale * h + beta[None, None, :]
output = scale * jax.nn.sigmoid(intermediate)
output = scale * jax.nn.sigmoid(intermediate) + eps
return output

def __call__(
Expand Down Expand Up @@ -269,6 +267,7 @@ def __call__(
self.pre_alpha[...],
self.pre_beta[...],
1.0,
eps=1e-5,
)
layer_input = jnp.einsum("bskd,bsk -> bsd", x, pre_mapping, precision=self.matmul_precision)

Expand Down Expand Up @@ -307,3 +306,71 @@ def __call__(
res_mapping = self.res_mapping(norm_x)
res_out = jnp.einsum("bskd,bskm -> bsmd", x, res_mapping, precision=self.matmul_precision)
return res_out + post_out, metadata


class DeepSeek4HyperHead(nnx.Module):
"""Final HC-stream collapse; used by DeepSeek V4 before the shared RMSNorm."""

def __init__(
self,
config: Config,
mesh: Mesh,
rngs: nnx.Rngs,
):
self.config = config
self.mesh = mesh
self.rngs = rngs
self.dtype = config.dtype
self.weight_dtype = config.weight_dtype
self.mhc_expansion_rate = config.mhc_expansion_rate
self.emb_dim = config.emb_dim
self.eps = 1e-6

# Weight matrices
weight_init = nd_dense_init(1.0, "fan_in", "normal")
self.hc_fn = nnx.Param(
weight_init(
rngs.params(),
(self.mhc_expansion_rate * self.emb_dim, self.mhc_expansion_rate),
self.weight_dtype,
in_axis=0,
out_axis=1,
),
out_sharding=("activation_embed", None),
)
self.hc_base = nnx.Param(
default_bias_init(rngs.params(), (self.mhc_expansion_rate,), self.weight_dtype),
out_sharding=(None,),
)
self.hc_scale = nnx.Param(
default_scalar_init(rngs.params(), (1,), self.weight_dtype),
out_sharding=(None,),
)

def __call__(self, x: Array) -> Array:
# x shape: [batch, length, k, d]
b, s, k, d = x.shape
assert k == self.mhc_expansion_rate
assert d == self.emb_dim

flat = jnp.reshape(x, (b, s, k * d))
flat_f32 = flat.astype(jnp.float32)
variance = jnp.mean(jnp.square(flat_f32), axis=-1, keepdims=True)
flat_norm = flat_f32 * jax.lax.rsqrt(variance + self.eps)

hc_fn = jnp.asarray(self.hc_fn[...], jnp.float32)
hc_base = jnp.asarray(self.hc_base[...], jnp.float32)
hc_scale = jnp.asarray(self.hc_scale[...], jnp.float32)

mixes = jnp.einsum("bsm,mk->bsk", flat_norm, hc_fn, precision=jax.lax.Precision(self.config.matmul_precision))
pre = jax.nn.sigmoid(mixes * hc_scale[None, None, :] + hc_base[None, None, :]) + self.eps

x_f32 = x.astype(jnp.float32)
out = jnp.sum(pre[:, :, :, None] * x_f32, axis=2)
return out.astype(self.dtype)


DeepSeek4HyperHeadToLinen = nnx_wrappers.to_linen_class(
DeepSeek4HyperHead,
base_metadata_fn=variable_to_logically_partitioned,
)
15 changes: 8 additions & 7 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,14 +699,15 @@ def get_topk(self, gate_logits, pre_bias_logits, rngs=None, input_ids=None):
else:
top_k_weights, top_k_indices = jax.lax.top_k(gate_logits, self.num_experts_per_tok)

if self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK:
if self.config.decoder_block in (ctypes.DecoderBlockType.DEEPSEEK, ctypes.DecoderBlockType.DEEPSEEK4):
top_k_weights = self.deepseek_scale_weights(top_k_weights)
elif self.config.decoder_block not in (ctypes.DecoderBlockType.LLAMA4, ctypes.DecoderBlockType.GEMMA4):
top_k_weights = jax.nn.softmax(top_k_weights.astype(jnp.float32), axis=-1).astype(self.dtype)
else:
if self.config.decoder_block not in (ctypes.DecoderBlockType.LLAMA4, ctypes.DecoderBlockType.GEMMA4):
top_k_weights = jax.nn.softmax(top_k_weights.astype(jnp.float32), axis=-1).astype(self.dtype)

# Normalization of router weights (e.g. used by Qwen3, Gemma4).
if self.config.norm_topk_prob:
top_k_weights /= top_k_weights.sum(axis=-1, keepdims=True)
# Normalization of router weights (e.g. used by Qwen3, Gemma4).
if self.config.norm_topk_prob:
top_k_weights /= top_k_weights.sum(axis=-1, keepdims=True)

return top_k_weights, top_k_indices

Expand Down Expand Up @@ -793,7 +794,7 @@ def apply_ffn_activation(self, layer_w0, layer_w1):
layer_act = self.activation_fn(layer_w0 * 1.702)
glu = jnp.multiply(layer_w0, layer_act)
intermediate_layer = jnp.multiply(glu, (layer_w1 + 1))
elif self.config.decoder_block == ctypes.DecoderBlockType.DEEPSEEK and self.config.mlp_activations_limit > 0.0:
elif self.config.decoder_block in (ctypes.DecoderBlockType.DEEPSEEK, ctypes.DecoderBlockType.DEEPSEEK4) and self.config.mlp_activations_limit > 0.0:
# DeepSeek V4 uses bounds to clip the SwiGLU activations
layer_w0 = jnp.clip(layer_w0, min=None, max=self.config.mlp_activations_limit)
layer_w1 = jnp.clip(layer_w1, min=-self.config.mlp_activations_limit, max=self.config.mlp_activations_limit)
Expand Down
Loading
Loading