diff --git a/contrib/models/Qwen3.6-27B/README.md b/contrib/models/Qwen3.6-27B/README.md new file mode 100644 index 00000000..650a7012 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/README.md @@ -0,0 +1,397 @@ +# Contrib Model: Qwen3.6-27B + +NeuronX Distributed Inference implementation of Qwen3.6-27B, a 27B parameter dense model from Alibaba Cloud with a hybrid DeltaNet + GQA attention architecture. + +## Relationship to Qwen3.5-27B + +Qwen3.6-27B is a **post-training update** of Qwen3.5-27B with improved agentic coding and thinking preservation. The models share **identical architecture** (`qwen3_5` model_type, `Qwen3_5ForConditionalGeneration`) -- only weights differ. This contrib reuses the same NxDI implementation as [Qwen3.5-27B](../Qwen3.5-27B/) (PR #128). Any code updates to Qwen3.5-27B should be propagated to this contrib and vice versa. + +### Config differences from Qwen3.5-27B + +| Field | Value | Impact | +|-------|-------|--------| +| `output_gate_type` | `"swish"` | **Ignored** -- not used by HF transformers or NxDI (gate uses sigmoid) | +| `language_model_only` | `false` | Informational, not used by model code | +| `bos_token_id` | `248044` | New but not architecture-relevant | +| `pad_token_id` | `null` | New at text_config level (already handled) | +| `partial_rotary_factor` | `0.25` | Already in rope_parameters, redundant copy | +| `transformers_version` | `4.57.1` | Updated from `4.57.0.dev0` | + +No architecture changes are required relative to the Qwen3.5-27B hybrid +implementation. This contrib packages the NxDI Qwen3.6-27B model code, +DeltaNet NKI kernels, FP8/vLLM serving helpers, and validation coverage for the +Qwen3.6 weights. 
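The field-level comparison in the table above can be reproduced with a few lines of Python. This is only an illustrative sketch: the two dictionaries are hand-written, flattened excerpts populated from the table (not the full `config.json` contents, which nest some of these fields under `text_config`), and the helper name `diff_config_fields` is hypothetical.

```python
# Hypothetical, flattened excerpts built from the table above.
# Real config.json files contain many more fields and nest some of them.
qwen35_excerpt = {
    "model_type": "qwen3_5",
    "transformers_version": "4.57.0.dev0",
}
qwen36_excerpt = {
    "model_type": "qwen3_5",
    "output_gate_type": "swish",
    "language_model_only": False,
    "bos_token_id": 248044,
    "pad_token_id": None,
    "partial_rotary_factor": 0.25,
    "transformers_version": "4.57.1",
}


def diff_config_fields(old, new):
    """Return (added, changed) field names between two config dicts."""
    added = sorted(key for key in new if key not in old)
    changed = sorted(key for key in new if key in old and new[key] != old[key])
    return added, changed


added, changed = diff_config_fields(qwen35_excerpt, qwen36_excerpt)
print("added:", added)
print("changed:", changed)
```

Running the same helper on the real `text_config` dictionaries of both checkpoints would be one way to confirm that no architecture-relevant field (layer counts, head counts, hidden size) shows up in either list.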
+ +## Model Family + +| Model | HuggingFace ID | Params | Instance | +|-------|----------------|--------|----------| +| **Qwen3.6-27B** | [`Qwen/Qwen3.6-27B`](https://huggingface.co/Qwen/Qwen3.6-27B) | 27B | trn2.3xlarge (TP=4) | + +**License:** Apache 2.0 + +## Architecture Details + +| Feature | Value | +|---------|-------| +| Layers | 64 (48 DeltaNet + 16 GQA) | +| Layer Pattern | [3 DeltaNet + 1 GQA] x 16 | +| Hidden Size | 5120 | +| GQA Attention | 24 heads, 4 KV heads, head_dim=256 | +| DeltaNet Attention | 48 value heads, 16 key heads, k_dim=v_dim=128 | +| Dense MLP | SwiGLU (gate_proj + up_proj: 5120 -> 17408, down_proj: 17408 -> 5120) | +| Position Encoding | Partial RoPE (25% of head_dim = 64 dims), mRoPE for VL | +| Vocabulary | 248,320 | +| Normalization | RMSNorm with +1 weight convention | +| Activation | SiLU gated MLP | + +### Unique Architecture Features + +- **Hybrid DeltaNet + GQA:** 48 of 64 layers use Gated DeltaNet (linear recurrent attention), 16 layers use standard GQA with KV cache. The pattern repeats every 4 layers: 3 DeltaNet + 1 GQA. +- **DeltaNet Linear Attention:** Uses the delta rule for recurrent state updates with gated decay. Per-step: `state *= exp(g); delta = (v - state^T @ k) * beta; state += outer(k, delta); output = state^T @ q`. Runs as a chunked algorithm for context encoding, per-token recurrence for token generation. +- **Custom NKI Kernels:** Three NKI kernels implement the DeltaNet forward pass on Neuron: a per-token recurrent kernel (TKG), a per-chunk kernel (legacy), and a fused single-kernel chunked forward (CTE). The fused kernel uses a Neumann series for intra-chunk correction with state persistence in SBUF across chunks. +- **GQA Output Gate:** Attention layers use a sigmoid output gate. `q_proj` is 2x sized and interleaved: `[head0_query | head0_gate | head1_query | ...]`. The gate is split during weight conversion and applied after attention. 
+- **Partial RoPE:** Only 25% of head_dim (64 of 256 dimensions) receives rotary embeddings. The remaining 192 dimensions are identity (no rotation). +- **+1 RMSNorm Convention:** HF weights use `output = norm(x) * (1 + weight)` where weight is initialized to zeros. Converted to standard `output = norm(x) * weight` during loading by adding 1.0 to all RMSNorm weights (except DeltaNet internal norms, which use standard convention). +- **Vision-Language Support:** Optional ViT encoder runs on CPU (HBM fully consumed by 27B text decoder). Vision embeddings are injected via a scatter mask at traced input positions. + +## Test Results + +### Unit Tests (CPU) + +| Test Module | Tests | Status | +|-------------|-------|--------| +| test_config.py | 26 | 26/26 PASS | +| test_weight_conversion.py | 16 | 16/16 PASS | +| test_hybrid_cache_manager.py | 13 | 13/13 PASS | +| test_deltanet_decay.py | 2 | 2/2 PASS | +| **Total** | **57** | **57/57 PASS** | + +Unit tests are architecture-level and do not depend on weights. Coverage includes config parsing, weight conversion, hybrid cache allocation/update behavior, and DeltaNet decay handling. + +### Quality Validation (Qwen3.6-27B, trn2.3xlarge, TP=4, SDK 2.29) + +7/7 text-only quality tests passed with `enable_thinking=False`: + +| Test | Expected | Result | +|------|----------|--------| +| Speed of light | 299,792,458 m/s | PASS | +| 17 * 23 | 391 | PASS | +| 60mph * 2.5h | 150 miles | PASS | +| is_prime function | Correct Python | PASS | +| French translation | Bonjour, comment allez-vous ? 
| PASS | +| Capital of Japan | Tokyo | PASS | +| sqrt(144) | 12 | PASS | + +## Performance Benchmarks + +### Qwen3.6-27B on trn2.3xlarge (TP=4, LNC=2, SDK 2.29, BF16) + +**TTFT (Time To First Token)** + +| Input Length | P50 (ms) | P95 (ms) | +|-------------|----------|----------| +| 16 tokens | 305.3 | 305.6 | +| 64 tokens | 305.4 | 305.9 | +| 128 tokens | 306.6 | 306.8 | +| 256 tokens | 306.2 | 306.3 | + +**TPOT / Throughput** + +| Output Length | TPOT P50 (ms) | tok/s P50 | E2E P50 (ms) | +|--------------|---------------|-----------|---------------| +| 16 | 54.3 | 18.4 | 1,121 | +| 32 | 54.4 | 18.4 | 1,993 | +| 64 | 54.2 | 18.5 | 3,720 | +| 128 | 54.2 | 18.5 | 4,912 | + +### Comparison with Qwen3.5-27B + +| Metric | Qwen3.5-27B | Qwen3.6-27B | Delta | +|--------|------------|------------|-------| +| TPOT P50 | 53 ms | 54.2 ms | +2.3% | +| Throughput | 18.9 tok/s | 18.5 tok/s | -2.1% | +| TTFT (128 tok) | 576 ms | 306.6 ms | -47% * | + +\* TTFT improvement is due to compilation config differences (256-token bucket vs 128-token bucket), not model differences. Architectural performance is equivalent. + +### Long-Context vLLM Baseline + +A 128K FP8-MLP artifact was validated on trn2.3xlarge (TP=4, LNC=2, SDK 2.29) +with the vLLM Neuron plugin, Qwen chunked prefill, and native vLLM APC enabled. 
+ +| Metric | Result | +|--------|--------| +| Max model length | 131,072 tokens | +| Context encoding bucket | 512 | +| Prefill throughput | 404-428 tok/s from 512 through 64K prompt tokens | +| Decode throughput | 26.3-26.6 tok/s | +| 64K quality | needle retrieval prompts returned all expected codes | +| State reset | repeated short-after-long validation passed after 32K and 64K requests | +| Peak Neuron device memory | ~53.25 GB decimal during the 64K eval | + +TTFT/TPOT details for the same 128K FP8/vLLM artifact: + +| Metric | Result | Notes | +|--------|--------|-------| +| Decode TPOT | ~37.6-38.0 ms/token | Derived from 26.3-26.6 tok/s decode | +| Cold 512-token TTFT | ~1.2-1.3s | Derived from measured prefill throughput plus one decode step | +| Cold 32K-token TTFT | ~76.6-81.1s | Derived from measured prefill throughput plus one decode step | +| Cold 64K-token TTFT | ~153-162s | Derived from measured prefill throughput plus one decode step | +| Warm APC latency, ~10.8K prompt | 1.36-2.38s | Exact-repeat, partial-prefix, and cross-prefix validation runs | +| Cold APC baseline, ~10.8K prompt | 25.17-26.68s | Same prompts with prefix cache disabled or cold | + +Native vLLM prefix caching/APC was also validated with exact greedy output +matches: + +| APC Scenario | Cold | Warm | Speedup | Result | +|--------------|------|------|---------|--------| +| Server exact-repeat, ~10.8K prompt tokens | 26.68s | 1.67s | 16.0x | exact text match | +| Offline exact-repeat | 26.19s | 2.38s | 11.0x | exact token-ID match | +| Offline partial-prefix reuse | 25.52s | 1.70s | 15.0x | exact token-ID match | +| Server cross-prefix reuse | 25.17s | 1.36s | 18.5x | exact text match | + +### Hybrid APC Follow-up Status + +Follow-up work on the `experimental` branch extended the baseline vLLM/APC +path toward Qwen3.6 Hybrid APC, where attention KV prefix reuse is only correct +when the matching DeltaNet recurrent/conv checkpoint is also available. 
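The correctness condition in the paragraph above can be written as a small predicate. This is a sketch under stated assumptions, not the scheduler code from the `experimental` branch: the function name, the representation of registered checkpoints as a set of prefix lengths, and the exact-length match are all hypothetical choices made for illustration.

```python
def usable_prefix_len(attention_hit_len, deltanet_checkpoint_lens):
    """Return how many prefix tokens may be reused, or 0 to recompute cold.

    attention_hit_len: number of prefix tokens with cached attention KV.
    deltanet_checkpoint_lens: prefix lengths that have a stored DeltaNet
        recurrent/conv checkpoint (hypothetical representation).
    """
    # Attention KV alone is not enough: the DeltaNet layers need their
    # recurrent/conv state at exactly the same prefix boundary.
    if attention_hit_len > 0 and attention_hit_len in deltanet_checkpoint_lens:
        return attention_hit_len
    return 0


print(usable_prefix_len(256, {256, 512}))  # 256
print(usable_prefix_len(256, set()))       # 0
```

The second call corresponds to the safety fallback: an attention prefix hit without a matching DeltaNet checkpoint is treated as a miss, so the request recomputes from the start.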
+ +What has been implemented and proven in that branch: + +- Scheduler-side safety gating prevents vLLM from reading an attention prefix + unless a matching GDN checkpoint is registered. +- Qwen request prep consumes scheduler-authorized, request-scoped restore keys + instead of relying on prefix length alone. +- The CTE restore path handles suffix-only execution over a restored prefix: + suffix tokens, slot mapping, `computed_context_lens`, `num_queries`, and GDN + restore metadata are kept aligned. +- BF16 single-request backed-prefix validation passes with cold/warm exactness + on the 2K checkpoint-boundary case. The proven shape restores a 256-token GDN + checkpoint, executes a 16-token suffix, and matches cold output. +- The safety fallback also passes: if attention KV has a prefix hit but no GDN + checkpoint exists, prefix reads are disabled and the request recomputes cold. + +Current blocker: + +- True generated-token batch-2 validation needs both `tkg_batch_size=2` and + `ctx_batch_size=2`. +- A batch-2 artifact with `tkg_batch_size=2` but `ctx_batch_size=1` failed in + vLLM-Neuron host-logits sampling because two prefills were packed into one + CTE row, then logits were reordered for two live request ids. +- Single-bucket `ctx_batch_size=2` / `tkg_batch_size=2` BF16 artifacts for CTE + bucket 256 and CTE bucket 512 compiled successfully. +- The combined multi-bucket artifact (`cte_buckets=256,512`, + `prefix_buckets=256,512`) started compiling and the TKG priority HLO passed, + but the smaller Trainium instance became SSH-unresponsive during all-HLO CTE + compilation. This appears to be a Neuron/NxDI compile-capacity or compile + orchestration issue, not a model-correctness failure. + +Expected outcome after the batch-2 artifact or an equivalent prefill-only +proof is available: + +- Batched Hybrid APC can preserve the same correctness rule as the + single-request path: + `usable_prefix_hit = attention_KV_prefix_hit AND matching_GDN_checkpoint_hit`. 
+- Warm repeated-prefix and partial-prefix requests should avoid replaying the + shared cold prefill while restoring the required GDN state. +- This is the path expected to turn the current exact single-request APC proof + into a measured cold-prefill performance win for batched serving. + +The fused CTE kernel and FP8 path are not the current correctness blockers. +The BF16 per-chunk CTE path is the reference path for Hybrid APC validation: +the fused BF16 CTE artifact has shown NaNs around token 105-106, and FP8 should +be revisited after the BF16 batch-2 serving contract is proven. + +### Key Observations + +- **BF16 TP=4 is HBM-limited:** The pure BF16 path is limited to short contexts on trn2.3xlarge. The validated 128K baseline uses MLP-only FP8 weights plus the hybrid cache manager. +- **DeltaNet enables efficient TKG:** Token generation uses O(1) per-token recurrence instead of O(n) KV cache attention for 48 of the 64 layers. +- **vLLM APC is high leverage:** Repeated-prefix requests avoid replaying long chunked prefill and are the largest observed latency win for chat/RAG-style workloads. +- **Performance equivalent to Qwen3.5-27B:** The BF16 TPOT difference is within measurement noise. This is expected, since the architectures are identical. 
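To make the O(1) per-token claim concrete, the per-step delta rule quoted under Unique Architecture Features can be sketched in plain Python. This is a single-head reference with list-based matrices, not the NKI kernel or the PyTorch path in this contrib; the function name `deltanet_step` and the toy inputs are hypothetical.

```python
import math


def deltanet_step(state, q, k, v, g, beta):
    """Apply one gated delta-rule update for a single head.

    state: k_dim x v_dim matrix as a list of rows; q, k: length k_dim;
    v: length v_dim; g: log-space gate, so exp(g) is the decay factor;
    beta: write strength. Returns (new_state, output); the input state
    is not mutated.
    """
    k_dim, v_dim = len(state), len(state[0])
    decay = math.exp(g)
    # state *= exp(g)
    state = [[s * decay for s in row] for row in state]
    # prediction = state^T @ k
    pred = [sum(state[i][j] * k[i] for i in range(k_dim)) for j in range(v_dim)]
    # delta = (v - state^T @ k) * beta
    delta = [(v[j] - pred[j]) * beta for j in range(v_dim)]
    # state += outer(k, delta)
    state = [
        [state[i][j] + k[i] * delta[j] for j in range(v_dim)]
        for i in range(k_dim)
    ]
    # output = state^T @ q
    out = [sum(state[i][j] * q[i] for i in range(k_dim)) for j in range(v_dim)]
    return state, out


state = [[0.0, 0.0], [0.0, 0.0]]
state, out = deltanet_step(
    state, q=[1.0, 0.0], k=[1.0, 0.0], v=[3.0, 4.0], g=0.0, beta=1.0
)
print(out)  # [3.0, 4.0]
```

The state has a fixed `k_dim x v_dim` shape (128 x 128 per head in this model) no matter how many tokens have been processed, so the work per generated token does not grow with context length, unlike attention over a growing KV cache.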
+ +## Usage + +### Text-Only (trn2.3xlarge, TP=4) + +```python +import json +import torch +from transformers import AutoTokenizer, GenerationConfig +from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig +from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter + +from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM + +model_path = "/path/to/Qwen3.6-27B" +compiled_path = "/scratch/qwen36_traced/" + +neuron_config = NeuronConfig( + tp_degree=4, + batch_size=1, + ctx_batch_size=1, + tkg_batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, + logical_nc_config=2, + enable_bucketing=False, + flash_decoding_enabled=False, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + save_sharded_checkpoint=True, +) + +# Read config.json directly (model_type 'qwen3_5' may not be +# registered in all transformers versions) +import os +with open(os.path.join(model_path, "config.json")) as f: + hf_config = json.load(f) +text_config = hf_config.get("text_config", hf_config) +config_dict = dict(text_config) +config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044) +config_dict.setdefault("tie_word_embeddings", False) + +config = Qwen35InferenceConfig( + neuron_config=neuron_config, + **config_dict, +) + +model = NeuronQwen35ForCausalLM(model_path, config) +model.compile(compiled_path) + +# Reload from compiled artifacts +model = NeuronQwen35ForCausalLM(compiled_path) +model.load(compiled_path) + +tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right") +gen_config = GenerationConfig( + do_sample=True, top_k=1, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, +) + +inputs = tokenizer("The capital of France is", return_tensors="pt") +gen_model = HuggingFaceGenerationAdapter(model) +outputs = gen_model.generate( + inputs.input_ids, + generation_config=gen_config, + attention_mask=inputs.attention_mask, + 
max_new_tokens=50, +) +print(tokenizer.decode(outputs[0], skip_special_tokens=True)) +``` + +### Vision-Language (trn2.3xlarge, TP=4) + +The VL pipeline uses the text decoder on Neuron and the vision encoder on CPU: + +```python +from src.modeling_qwen35_vl import NeuronQwen35VLForCausalLM, Qwen35VLInferenceConfig + +# Assumes `vl_config` (expected to be a Qwen35VLInferenceConfig) and +# `compiled_path` (output directory for traced artifacts) are already defined. +vl_model = NeuronQwen35VLForCausalLM( + model_path="/path/to/Qwen3.6-27B", + config=vl_config, +) +vl_model.compile(compiled_path) +vl_model.load(compiled_path) + +# See test/integration/test_model.py for a full VL usage example +``` + +### DeltaNet Kernel Selection + +The DeltaNet forward path can be controlled via environment variables: + +| Env Var | Forward Path | Use Case | +|---------|-------------|----------| +| `USE_NKI_FUSED=1` | Fused chunked NKI kernel | Best CTE performance (default for SDK 2.29) | +| `USE_NKI_CHUNKED=1` | Per-chunk NKI kernel | Legacy, superseded by fused | +| `USE_NKI=1` | Per-token NKI kernel | TKG (always used for token generation) | +| `DELTANET_SEQUENTIAL=1` | Sequential PyTorch | Debugging/reference | +| *(none)* | PyTorch chunked | Default fallback for CTE | + +## Caveats + +1. **BF16 HBM pressure at TP=4:** The pure BF16 model consumes nearly all HBM on trn2.3xlarge. Use the FP8/vLLM path for the validated 128K artifact, or a larger instance for additional batching/headroom. + +2. **SDK 2.29+ required:** The NKI DeltaNet kernels require NKI 0.3.0 (SDK 2.29). No library modifications are needed -- the code runs on the stock SDK 2.29 DLAMI (`/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/`). + +3. **No mini model test:** Unlike DeepSeek-V3, a mini model cannot be provided because DeltaNet layers require NKI kernels that only execute on Neuron devices. Integration tests require a trn2 instance with the full 27B weights. + +4. **Vision encoder runs on CPU:** The ViT cannot be placed on Neuron because HBM is fully consumed by the text decoder. This adds ~918ms latency per image. 
Future optimization: quantize the text decoder to free HBM, or use a larger instance. + +5. **Compilation time:** The short-context BF16 path compiles in roughly 13 minutes. The validated 128K FP8/vLLM artifact takes longer because it includes long-context cache shapes and presharded checkpoints. + +6. **+1 RMSNorm convention:** Qwen3.5/3.6 use `output = norm(x) * (1 + weight)` for most RMSNorm layers, but DeltaNet internal norms use the standard `output = norm(x) * weight`. The weight conversion handles this automatically, but custom weight loading must be aware of both conventions. + +7. **DeltaNet numerical stability:** DeltaNet kernels rely on normalized Q/K inputs and bounded decay handling. The chunked path includes regression coverage for decay handling; changes to the fused kernel should be validated against the CPU reference and long-context stress prompts. + +8. **Shared codebase with Qwen3.5-27B:** This contrib uses the same `Qwen35*` class names and `modeling_qwen35*.py` filenames as the [Qwen3.5-27B contrib](../Qwen3.5-27B/). This is intentional -- both models share the `qwen3_5` model_type. The code is identical; only the HuggingFace model ID and weights differ. + +## Maximum Sequence Length + +| seq_len | Path | Status | Notes | +|---------|------|--------|-------| +| 128 | BF16 NxDI | **PASS** | BF16 baseline/quality checks | +| 256 | BF16 NxDI | **PASS** | BF16 benchmark bucket | +| 512 | BF16 NxDI | **PASS** | 4 DeltaNet chunks | +| 65,536 | FP8/vLLM | **PASS** | chunked prefill, quality, and state-reset validation | +| 131,072 | FP8/vLLM | **PASS** | compiled and served with 512-token CTE bucket | + +For production long-context serving on trn2.3xlarge, use the FP8/vLLM artifact +and the 512-token context encoding bucket. Larger instances are recommended for +larger batches or additional serving headroom. 
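When sizing long-context requests on the FP8/vLLM path, the cold TTFT figures in the Long-Context vLLM Baseline section can be approximated from the measured throughput ranges. The sketch below follows the derivation stated there (prefill time plus one decode step). The function names are hypothetical, and two things are assumptions made for the arithmetic only: pairing the fastest prefill rate with the fastest decode rate, and treating each 512-token bucket as one prefill pass.

```python
import math


def estimate_cold_ttft_s(prompt_tokens, prefill_tok_s, decode_tok_s):
    """Approximate cold TTFT as prefill time plus one decode step."""
    return prompt_tokens / prefill_tok_s + 1.0 / decode_tok_s


def num_prefill_chunks(prompt_tokens, bucket=512):
    """Number of context-encoding passes, assuming one pass per bucket."""
    return math.ceil(prompt_tokens / bucket)


# Throughput ranges reported for the 128K FP8/vLLM artifact:
# prefill 404-428 tok/s, decode 26.3-26.6 tok/s.
for prompt_tokens in (512, 32 * 1024, 64 * 1024):
    fast = estimate_cold_ttft_s(prompt_tokens, 428.0, 26.6)
    slow = estimate_cold_ttft_s(prompt_tokens, 404.0, 26.3)
    chunks = num_prefill_chunks(prompt_tokens)
    print(f"{prompt_tokens:>6} tokens: {chunks:>3} chunks, TTFT ~{fast:.1f}-{slow:.1f} s")
```

For a 32K-token prompt this gives roughly 76.6-81.1 s, in line with the derived table entry. The decode step contributes under 40 ms, so cold TTFT at these lengths is dominated by prefill throughput.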
+ +## Compatibility Matrix + +| Instance | TP | LNC | Status | Notes | +|----------|-----|-----|--------|-------| +| trn2.3xlarge | 4 | 2 | **PASS** | BF16 short-context and FP8 128K vLLM/APC validated | +| trn2.48xlarge | 4 | 2 | Expected PASS | Untested for this contrib; use the same TP=4 artifact shape when compiling for trn2.3xlarge deployment | +| trn2u.48xlarge | 4 | 2 | Expected PASS | Untested for this contrib; same portability caveat as trn2.48xlarge | + +### SDK Configuration + +| Component | Version | +|-----------|---------| +| NxDI | 0.9.17334 | +| neuronx-cc | 2.24.5133 | +| torch | 2.9.1 | +| transformers | 4.57.6 | +| NKI | 0.3.0 | +| NXDI venv | `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/` | + +## Testing + +### Unit Tests (CPU only, no device needed) + +```bash +cd contrib/models/Qwen3.6-27B/ +# On DLAMI: source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate +pytest test/unit/ -v +``` + +Tests: config parsing (26), weight conversion (16), hybrid cache manager (13), and DeltaNet decay handling (2) = **57 tests**. + +### Integration Tests (needs trn2.3xlarge with 4 NeuronCores) + +```bash +cd contrib/models/Qwen3.6-27B/ + +QWEN35_MODEL_PATH=/mnt/models/Qwen3.6-27B \ +QWEN35_COMPILED_PATH=/mnt/models/qwen36_traced \ +pytest test/integration/test_model.py --capture=tee-sys +``` + +Tests: model loads, generates, coherence, top-token valid, capital test, TTFT, throughput, multi-prompt = **8 tests**. + +Note: The env var is `QWEN35_MODEL_PATH` (not `QWEN36`) because the code uses the `qwen3_5` model_type internally. 
+ +## Example Checkpoints + +- [`Qwen/Qwen3.6-27B`](https://huggingface.co/Qwen/Qwen3.6-27B) (BF16, ~52 GB) + +## Maintainer + +AWS Neuron + +**Last Updated:** 2026-04-23 diff --git a/contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py b/contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py new file mode 100644 index 00000000..fe7f45d5 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py @@ -0,0 +1,383 @@ +#!/usr/bin/env python3 +"""Minimal OpenAI-compatible HTTP server for the Qwen3.6-27B NxDI artifact. + +This intentionally avoids uvicorn/fastapi runtime dependencies so it can run in +the stock Neuron inference venv. It supports non-streaming: + - GET /health + - GET /v1/models + - POST /v1/completions + - POST /v1/chat/completions +""" + +import argparse +import json +import sys +import threading +import time +import traceback +import uuid +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from typing import Any, Dict, List + +import torch + + +def _json_response(handler: BaseHTTPRequestHandler, status: int, payload: Dict[str, Any]): + body = json.dumps(payload, ensure_ascii=False).encode("utf-8") + handler.send_response(status) + handler.send_header("Content-Type", "application/json") + handler.send_header("Content-Length", str(len(body))) + handler.send_header("Access-Control-Allow-Origin", "*") + handler.send_header("Access-Control-Allow-Headers", "authorization,content-type") + handler.send_header("Access-Control-Allow-Methods", "GET,POST,OPTIONS") + handler.end_headers() + handler.wfile.write(body) + + +def _error(handler: BaseHTTPRequestHandler, status: int, message: str): + _json_response( + handler, + status, + {"error": {"message": message, "type": "server_error", "code": status}}, + ) + + +def _first_text_prompt(prompt: Any) -> str: + if isinstance(prompt, str): + return prompt + if isinstance(prompt, list) and prompt: + return str(prompt[0]) + return str(prompt) + + +def _token_scalar(tokens: 
Any) -> int: + if hasattr(tokens, "detach"): + tokens = tokens.detach().cpu() + if tokens.ndim == 0: + return int(tokens.item()) + return int(tokens.reshape(-1)[0].item()) + + +class QwenOpenAIServer: + def __init__(self, args: argparse.Namespace): + self.args = args + self.model_id = args.model_id + self.lock = threading.Lock() + self._load_model() + + def _load_model(self): + if self.args.contrib_root not in sys.path: + sys.path.insert(0, self.args.contrib_root) + + from transformers import AutoTokenizer, GenerationConfig + from neuronx_distributed_inference.modules.generation.sampling import ( + prepare_sampling_params, + ) + from src.modeling_qwen35 import NeuronQwen35ForCausalLM + + print("Loading tokenizer from", self.args.model_path, flush=True) + self.tokenizer = AutoTokenizer.from_pretrained( + self.args.model_path, + padding_side="right", + ) + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + print("Loading NxDI artifact from", self.args.compiled_path, flush=True) + t0 = time.perf_counter() + self.model = NeuronQwen35ForCausalLM(self.args.compiled_path) + self.model.load(self.args.compiled_path) + self.model.reset() + self.prepare_sampling_params = prepare_sampling_params + self.GenerationConfig = GenerationConfig + print(f"Model loaded in {time.perf_counter() - t0:.2f}s", flush=True) + + def _chat_prompt(self, messages: List[Dict[str, Any]], enable_thinking: bool = False) -> str: + try: + return self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=enable_thinking, + ) + except TypeError: + return self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + except Exception: + lines = [] + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + lines.append(f"{role}: {content}") + lines.append("assistant:") + return "\n".join(lines) + + def _generate(self, prompt: str, body: 
Dict[str, Any]) -> Dict[str, Any]: + max_tokens = int(body.get("max_tokens", body.get("max_completion_tokens", 128)) or 128) + if max_tokens <= 0: + raise ValueError("max_tokens must be positive") + if max_tokens > self.args.max_new_tokens_limit: + raise ValueError( + f"max_tokens={max_tokens} exceeds server limit {self.args.max_new_tokens_limit}" + ) + + input_ids = torch.tensor( + [self.tokenizer(prompt, add_special_tokens=False).input_ids], + dtype=torch.long, + ) + prompt_tokens = int(input_ids.shape[1]) + if prompt_tokens <= 0: + raise ValueError("prompt must contain at least one token") + if prompt_tokens + max_tokens > self.args.seq_len: + raise ValueError( + f"prompt_tokens + max_tokens = {prompt_tokens + max_tokens} exceeds " + f"seq_len={self.args.seq_len}" + ) + + temperature = float(body.get("temperature", 0.0) or 0.0) + top_p = float(body.get("top_p", 1.0) or 1.0) + top_k = int(body.get("top_k", 1) or 1) + # NxDI's traced on-device sampler for this artifact uses do_sample=True. + # OpenAI temperature=0 means greedy, but passing literal 0 into that + # sampler divides logits by zero. top_k=1 with temperature=1 is the + # deterministic greedy path used by the validated HF adapter tests. 
+ sampler_temperature = temperature + if temperature <= 0.0: + sampler_temperature = 1.0 + top_p = 1.0 + top_k = 1 + sampling_params = self.prepare_sampling_params( + batch_size=1, + top_k=[top_k], + top_p=[top_p], + temperature=[sampler_temperature], + ) + seq_ids = torch.tensor([0], dtype=torch.int32) + + with self.lock: + if hasattr(self.model, "reset"): + self.model.reset() + t0 = time.perf_counter() + first_token = None + for start in range(0, prompt_tokens, self.args.chunk_size): + end = min(start + self.args.chunk_size, prompt_tokens) + valid = end - start + chunk_ids = input_ids[:, start:end] + attention_mask = torch.ones((1, valid), dtype=torch.long) + position_ids = torch.arange( + start, + end, + dtype=torch.long, + ).unsqueeze(0) + + with torch.no_grad(): + out = self.model( + input_ids=chunk_ids, + attention_mask=attention_mask, + position_ids=position_ids, + seq_ids=seq_ids, + sampling_params=sampling_params, + return_dict=True, + ) + first_token = _token_scalar(out.tokens) + + if first_token is None: + raise RuntimeError("prefill produced no token") + + new_ids = [] + current_token = first_token + vocab_size = len(self.tokenizer) + raw_eos_id = self.tokenizer.eos_token_id + eos_ids = ( + set(raw_eos_id) + if isinstance(raw_eos_id, (list, tuple, set)) + else {raw_eos_id} + ) + decode_ids = torch.empty((1, 1), dtype=torch.int32) + decode_position_ids = torch.empty((1, 1), dtype=torch.int32) + decode_attention_mask = torch.ones( + (1, prompt_tokens + max_tokens), + dtype=torch.int32, + ) + finish_reason = "length" + with torch.no_grad(): + for step in range(max_tokens): + if current_token in eos_ids: + finish_reason = "stop" + break + if current_token < 0 or current_token >= vocab_size: + raise RuntimeError(f"model generated invalid token id: {current_token}") + new_ids.append(current_token) + if step == max_tokens - 1: + break + + pos_value = prompt_tokens + step + decode_ids[0, 0] = current_token + decode_position_ids[0, 0] = pos_value + 
active_attention_mask = decode_attention_mask[:, : pos_value + 1] + out = self.model( + input_ids=decode_ids, + attention_mask=active_attention_mask, + position_ids=decode_position_ids, + seq_ids=seq_ids, + sampling_params=sampling_params, + return_dict=True, + ) + current_token = _token_scalar(out.tokens) + elapsed = time.perf_counter() - t0 + + invalid = [tok for tok in new_ids if tok < 0 or tok >= vocab_size] + if invalid: + raise RuntimeError(f"model generated invalid token ids: {invalid[:8]}") + + text = self.tokenizer.decode(new_ids, skip_special_tokens=True) + # OpenAI clients may send `stop` as a single string or a list of strings; + # normalize to a list so a bare string is not iterated character by character. + stop_strings = body.get("stop") or [] + if isinstance(stop_strings, str): + stop_strings = [stop_strings] + for stop in stop_strings: + # Skip empty strings: str.split("") raises ValueError. + if isinstance(stop, str) and stop and stop in text: + text = text.split(stop, 1)[0] + finish_reason = "stop" + + return { + "text": text, + "prompt_tokens": prompt_tokens, + "completion_tokens": len(new_ids), + "elapsed": elapsed, + "tokens": new_ids, + "finish_reason": finish_reason, + } + + +def make_handler(server_state: QwenOpenAIServer): + class Handler(BaseHTTPRequestHandler): + protocol_version = "HTTP/1.1" + + def log_message(self, fmt, *args): + print(f"{self.address_string()} - {fmt % args}", flush=True) + + def do_OPTIONS(self): + _json_response(self, 200, {}) + + def do_GET(self): + if self.path == "/health": + _json_response(self, 200, {"status": "ok", "model": server_state.model_id}) + elif self.path == "/v1/models": + _json_response( + self, + 200, + { + "object": "list", + "data": [ + { + "id": server_state.model_id, + "object": "model", + "created": int(time.time()), + "owned_by": "local", + } + ], + }, + ) + else: + _error(self, 404, f"unknown route: {self.path}") + + def do_POST(self): + try: + length = int(self.headers.get("content-length", "0")) + body = json.loads(self.rfile.read(length).decode("utf-8") or "{}") + if body.get("stream"): + raise ValueError("stream=true is not supported by this minimal server yet") + + if self.path == "/v1/completions": + result = server_state._generate(_first_text_prompt(body.get("prompt", "")), body) + _json_response( + self, + 200, + { + 
"id": f"cmpl-{uuid.uuid4().hex}", + "object": "text_completion", + "created": int(time.time()), + "model": server_state.model_id, + "choices": [ + { + "index": 0, + "text": result["text"], + "finish_reason": result["finish_reason"], + } + ], + "usage": { + "prompt_tokens": result["prompt_tokens"], + "completion_tokens": result["completion_tokens"], + "total_tokens": result["prompt_tokens"] + + result["completion_tokens"], + }, + "x_latency_seconds": result["elapsed"], + }, + ) + elif self.path == "/v1/chat/completions": + messages = body.get("messages") or [] + if not isinstance(messages, list): + raise ValueError("messages must be a list") + result = server_state._generate( + server_state._chat_prompt( + messages, + enable_thinking=bool(body.get("enable_thinking", False)), + ), + body, + ) + _json_response( + self, + 200, + { + "id": f"chatcmpl-{uuid.uuid4().hex}", + "object": "chat.completion", + "created": int(time.time()), + "model": server_state.model_id, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": result["text"], + }, + "finish_reason": result["finish_reason"], + } + ], + "usage": { + "prompt_tokens": result["prompt_tokens"], + "completion_tokens": result["completion_tokens"], + "total_tokens": result["prompt_tokens"] + + result["completion_tokens"], + }, + "x_latency_seconds": result["elapsed"], + }, + ) + else: + _error(self, 404, f"unknown route: {self.path}") + except Exception as exc: + traceback.print_exc() + _error(self, 500, str(exc)) + + return Handler + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", default="0.0.0.0") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--model-id", default="qwen3.6-27b-neuron") + parser.add_argument("--model-path", required=True) + parser.add_argument("--compiled-path", required=True) + parser.add_argument("--contrib-root", required=True) + parser.add_argument("--seq-len", type=int, default=65536) + 
parser.add_argument("--chunk-size", type=int, default=512) + parser.add_argument("--max-new-tokens-limit", type=int, default=512) + args = parser.parse_args() + + state = QwenOpenAIServer(args) + httpd = ThreadingHTTPServer((args.host, args.port), make_handler(state)) + print(f"Serving {args.model_id} on http://{args.host}:{args.port}", flush=True) + httpd.serve_forever() + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/contrib/models/Qwen3.6-27B/src/__init__.py b/contrib/models/Qwen3.6-27B/src/__init__.py new file mode 100644 index 00000000..7e79aa03 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/__init__.py @@ -0,0 +1,41 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from src.modeling_qwen35 import ( + NeuronGatedDeltaNet, + NeuronQwen35Attention, + NeuronQwen35DecoderLayer, + NeuronQwen35ForCausalLM, + NeuronQwen35Model, + Qwen35DecoderModelInstance, + Qwen35InferenceConfig, + Qwen35MLP, + Qwen35ModelWrapper, +) +from src.modeling_qwen35_vision import ( + NeuronQwen35VisionForImageEncoding, + NeuronQwen35VisionModel, +) +from src.modeling_qwen35_vl import ( + NeuronQwen35VLForCausalLM, + Qwen35VLInferenceConfig, +) + +__all__ = [ + # Text decoder + "NeuronGatedDeltaNet", + "NeuronQwen35Attention", + "NeuronQwen35DecoderLayer", + "NeuronQwen35ForCausalLM", + "NeuronQwen35Model", + "Qwen35DecoderModelInstance", + "Qwen35InferenceConfig", + "Qwen35MLP", + "Qwen35ModelWrapper", + # Vision encoder + "NeuronQwen35VisionForImageEncoding", + "NeuronQwen35VisionModel", + # Vision-language + "NeuronQwen35VLForCausalLM", + "Qwen35VLInferenceConfig", +] diff --git a/contrib/models/Qwen3.6-27B/src/modeling_qwen35.py b/contrib/models/Qwen3.6-27B/src/modeling_qwen35.py new file mode 100644 index 00000000..ed3c3f2c --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/modeling_qwen35.py @@ -0,0 +1,3245 @@ +""" +NxDI contrib: Qwen3.5-27B / Qwen3.6-27B (qwen3_5 -- dense model) + +Supports 
both Qwen3.5-27B and Qwen3.6-27B. These models share identical +architecture (qwen3_5 model_type). Qwen3.6-27B is a post-training update +with improved agentic coding and thinking preservation -- no architecture +changes, only weight differences. + +Hybrid DeltaNet + Standard Attention + Dense MLP architecture. +Adapted from Qwen3.5-35B-A3B (MoE) -- MoE removed, dense MLP added. + +48 of 64 layers use Gated DeltaNet (linear recurrent attention) +16 of 64 layers use standard GQA with KV cache + output gate +All 64 layers use a dense SwiGLU MLP (intermediate_size=17408) + +Architecture details: +- DeltaNet layers: separate in_proj_{qkv, z, a, b}, causal conv1d on QKV, gated delta rule +- Attention layers: q_proj doubled (Q + gate), partial RoPE (25% of head_dim), sigmoid output gate +- Dense MLP: standard SwiGLU (gate_proj, up_proj, down_proj) -- no MoE, no router, no experts +- KV cache: NxDI KVCacheManager for attention layers; DeltaNet layers store recurrent+conv + state as nn.Parameter buffers and return dummy KV tuples + +Config compatibility notes: +- Qwen3.6-27B adds output_gate_type="swish" to text_config. This field is + unused by both HF transformers and this NxDI code (gate uses sigmoid, as + confirmed across transformers v4.57.6, v5.6.0, and GitHub main). Safe to ignore. 
+""" + +import gc +import math +import logging +import os +import sys +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm + +try: + from neuronxcc.nki._private_kernels.attention import attention_isa_kernel +except ImportError: + from neuronxcc.nki.kernels.attention import attention_isa_kernel + +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, + RowParallelLinear, +) +from neuronx_distributed.utils import cpu_mode + +try: + from nki import jit as nki_jit # NKI 0.3.0+ (SDK 2.29) +except ImportError: + from torch_neuronx.xla_impl.ops import nki_jit # NKI 0.2.x (SDK 2.28) +from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeRMSNorm + +from src.nki_kernels.nki_deltanet import deltanet_recurrent_fwd as _deltanet_nki_kernel +from src.nki_kernels.nki_deltanet import ( + deltanet_recurrent_fwd_state as _deltanet_nki_kernel_state, +) +from src.nki_kernels.nki_deltanet_chunked import ( + deltanet_chunk_step as _deltanet_nki_chunk_step, +) +from src.nki_kernels.nki_deltanet_fused import ( + deltanet_fused_chunked_fwd as _deltanet_fused_kernel, +) +from src.nki_kernels.nki_deltanet_fused import ( + _make_lower_mask, + _make_lower_mask_diag, + _make_identity, +) + +from neuronx_distributed_inference.models.config import ( + InferenceConfig, + NeuronConfig, +) +from neuronx_distributed_inference.models.model_wrapper import ( + CONTEXT_ENCODING_MODEL_TAG, + TOKEN_GENERATION_MODEL_TAG, + DecoderModelInstance, + ModelWrapper, +) +from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, +) +from neuronx_distributed_inference.modules.attention.utils import 
RotaryEmbedding +from neuronx_distributed_inference.modules.kvcache.kv_cache_manager import KVCacheManager +from neuronx_distributed_inference.models.layer_boundary_marker import ( + ModuleMarkerEndWrapper, + ModuleMarkerStartWrapper, +) + +logger = logging.getLogger(__name__) + +try: + _flash_fwd_call = nki_jit()(attention_isa_kernel) +except TypeError: + from torch_neuronx.xla_impl.ops import nki_jit as _torch_xla_nki_jit + + _flash_fwd_call = _torch_xla_nki_jit()(attention_isa_kernel) + +# Option B: Direct nkilib flash attention for head_dim > 128 +USE_NKILIB_KERNEL = os.environ.get("USE_NKILIB_KERNEL", "0") == "1" + +_nkilib_flash_attn = None +if USE_NKILIB_KERNEL: + try: + import neuronxcc.nki as _nki + from neuronx_distributed_inference.modules.attention.attention_base import ( + peel_decorations as _peel_decorations, + get_platform_target as _get_platform_target, + ) + from neuronxcc.nki.compiler import ( + skip_middle_end_transformations as _skip_middle_end, + enable_stack_allocator as _enable_stack_allocator, + ) + + import importlib + + _fork_path = "/home/ubuntu/nki-library-fork/nkilib_src" + if os.path.isdir(_fork_path) and _fork_path not in sys.path: + sys.path.insert(0, _fork_path) + _to_remove = [k for k in sys.modules if k.startswith("nkilib")] + for k in _to_remove: + del sys.modules[k] + import nki.language as _stub_nl + import neuronxcc.nki.language as _real_nl + + for _attr in [ + "NKIObject", + "float8_e4m3fn", + "float8_e4m3fn_x4", + "float8_e5m2_x4", + "float4_e2m1fn_x4", + ]: + if not hasattr(_real_nl, _attr) and hasattr(_stub_nl, _attr): + setattr(_real_nl, _attr, getattr(_stub_nl, _attr)) + from nkilib.core.attention.attention_cte import ( + attention_cte as _attention_cte_raw, + _MAX_HEAD_DIM, + ) + + assert _MAX_HEAD_DIM == 256, ( + f"nkilib fork has _MAX_HEAD_DIM={_MAX_HEAD_DIM}, expected 256. " + f"System nkilib may have been loaded instead of fork." 
+ ) + logger.info( + f"Loaded nkilib attention_cte from fork (_MAX_HEAD_DIM={_MAX_HEAD_DIM})" + ) + + _raw_fn = _peel_decorations(_attention_cte_raw) + _platform = _get_platform_target() + _nkilib_flash_attn = _nki.jit( + _raw_fn, + mode="torchxla", + platform_target=_platform, + show_compiler_tb=True, + debug_kernel=True, + ) + _nkilib_flash_attn = _skip_middle_end(_nkilib_flash_attn) + _nkilib_flash_attn = _enable_stack_allocator( + _nkilib_flash_attn, log_level=logging.INFO + ) + logger.info("Option B: nkilib flash attention loaded for head_dim > 128") + except Exception as e: + logger.warning(f"Option B: Failed to load nkilib flash attention: {e}") + import traceback as _tb + + _tb.print_exc() + _nkilib_flash_attn = None + +# Option A: Detect if patch_attn_kernel was imported +NKILIB_PATCH_ACTIVE = False +try: + from importlib import import_module as _import_module + + _attn_mod = _import_module("neuronxcc.nki._pre_prod_kernels.attn_fwd") + if hasattr(_attn_mod, "_original_attention_nki_kernel_adapter"): + NKILIB_PATCH_ACTIVE = True + logger.info("Option A detected: _pre_prod_kernels patched with nkilib kernel") +except Exception: + pass + + +# ============================================================ +# Newton-Raphson Refined RMSNorm +# ============================================================ +USE_NEWTON_RMSNORM = os.environ.get("USE_NEWTON_RMSNORM") == "1" +USE_PYTHON_RMSNORM = os.environ.get("USE_PYTHON_RMSNORM") == "1" + + +class NewtonRMSNorm(nn.Module): + """RMSNorm with Newton-Raphson refined rsqrt for improved numerical accuracy.""" + + def __init__(self, hidden_size=None, eps=1e-6): + super().__init__() + self.weight = None + if hidden_size is not None: + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.hidden_size = hidden_size + self.variance_epsilon = eps + + def forward(self, hidden_states): + original_dtype = hidden_states.dtype + x = hidden_states.to(torch.float32) + variance = x.pow(2).mean(-1, keepdim=True) + y = 
torch.rsqrt(variance + self.variance_epsilon) + y = y * (3.0 - (variance + self.variance_epsilon) * y * y) * 0.5 + result = x * y + if self.weight is not None: + result = result * self.weight.float() + return result.to(original_dtype) + + +def get_rmsnorm_cls(): + if cpu_mode() or USE_PYTHON_RMSNORM: + return Qwen3MoeRMSNorm + return NewtonRMSNorm if USE_NEWTON_RMSNORM else CustomRMSNorm + + +def l2norm(x, dim=-1, eps=1e-6): + return F.normalize(x, p=2, dim=dim, eps=eps) + + +# ============================================================ +# Gated DeltaNet Module (Linear Recurrent Attention) +# ============================================================ + + +class NeuronGatedDeltaNet(nn.Module): + """ + Gated DeltaNet linear attention for Neuron. + + Replaces standard attention for 48 of 64 layers in Qwen3.5/3.6-27B. + Uses a chunk-based linear recurrence instead of KV cache. + + HF weight layout (27B dense -- scaled dimensions): + - in_proj_qkv.weight: (key_dim*2 + value_dim, hidden_size) = (10240, 5120) + - in_proj_z.weight: (value_dim, hidden_size) = (6144, 5120) + - in_proj_a.weight: (num_v_heads, hidden_size) = (48, 5120) + - in_proj_b.weight: (num_v_heads, hidden_size) = (48, 5120) + - conv1d.weight: (conv_dim, 1, conv_kernel_size) = (10240, 1, 4) + - A_log: (num_v_heads,) = (48,) + - dt_bias: (num_v_heads,) = (48,) + - norm.weight: (head_v_dim,) = (128,) + - out_proj.weight: (hidden_size, value_dim) = (5120, 6144) + """ + + def __init__(self, config, layer_idx: int): + super().__init__() + tc = config + + self.hidden_size = tc.hidden_size # 5120 + self.tp_degree = tc.neuron_config.tp_degree + self.global_num_v_heads = tc.linear_num_value_heads # 48 + self.global_num_k_heads = tc.linear_num_key_heads # 16 + self.head_k_dim = tc.linear_key_head_dim # 128 + self.head_v_dim = tc.linear_value_head_dim # 128 + if self.global_num_v_heads % self.tp_degree != 0: + raise ValueError( + f"linear_num_value_heads={self.global_num_v_heads} must be divisible " + f"by 
tp_degree={self.tp_degree}" + ) + if self.global_num_k_heads % self.tp_degree != 0: + raise ValueError( + f"linear_num_key_heads={self.global_num_k_heads} must be divisible " + f"by tp_degree={self.tp_degree}" + ) + self.num_v_heads = self.global_num_v_heads // self.tp_degree + self.num_k_heads = self.global_num_k_heads // self.tp_degree + self.global_key_dim = self.head_k_dim * self.global_num_k_heads # 2048 + self.global_value_dim = self.head_v_dim * self.global_num_v_heads # 6144 + self.key_dim = self.head_k_dim * self.num_k_heads # 512 at TP=4 + self.value_dim = self.head_v_dim * self.num_v_heads # 1536 at TP=4 + self.conv_kernel_size = tc.linear_conv_kernel_dim # 4 + self.layer_idx = layer_idx + self.rms_norm_eps = tc.rms_norm_eps + self.use_hybrid_cache_manager = getattr(tc, "use_hybrid_cache_manager", False) + self.use_qwen_hybrid_chunked_prefill = getattr( + tc, "use_qwen_hybrid_chunked_prefill", False + ) + self.use_qwen_hybrid_chunked_prefill_nki = getattr( + tc, "use_qwen_hybrid_chunked_prefill_nki", False + ) + + # KV cache dummy shape info + self.head_dim = tc.head_dim # 256 + tp_degree = tc.neuron_config.tp_degree + raw_kv_heads = tc.num_key_value_heads + if raw_kv_heads < tp_degree: + replicated_kv_heads = tp_degree + else: + replicated_kv_heads = raw_kv_heads + self.kv_heads_per_rank = replicated_kv_heads // tp_degree + + # Conv1d on concatenated QKV (NOT Z). Store the depthwise kernel in a + # ColumnParallelLinear parameter container so NxD's checkpoint sharder + # can split it by output channel. Forward still uses it as Conv1d + # weight after unsqueezing the singleton input-channel dimension. + self.global_conv_dim = self.global_key_dim * 2 + self.global_value_dim # 10240 + self.conv_dim = self.key_dim * 2 + self.value_dim # 2560 at TP=4 + self.conv1d_weight = ColumnParallelLinear( + self.conv_kernel_size, + self.global_conv_dim, + bias=False, + gather_output=False, + ) + + # Input/output projections are the large DeltaNet tensors. 
Shard them + # with tensor parallelism; convert_qwen35_hf_to_neuron_state_dict() + # reorders in_proj_qkv into per-rank [Q_local | K_local | V_local] + # blocks before NxD slices the output dimension. + self.in_proj_qkv = ColumnParallelLinear( + self.hidden_size, + self.global_key_dim * 2 + self.global_value_dim, + bias=False, + gather_output=False, + ) + self.in_proj_z = ColumnParallelLinear( + self.hidden_size, + self.global_value_dim, + bias=False, + gather_output=False, + ) + self.in_proj_b = ColumnParallelLinear( + self.hidden_size, + self.global_num_v_heads, + bias=False, + gather_output=False, + ) + self.in_proj_a = ColumnParallelLinear( + self.hidden_size, + self.global_num_v_heads, + bias=False, + gather_output=False, + ) + + # Same parameter-container pattern for per-value-head decay vectors. + # These are used as vectors in forward but sharded by output dim during + # checkpoint conversion/loading. + self.dt_bias_weight = ColumnParallelLinear( + 1, + self.global_num_v_heads, + bias=False, + gather_output=False, + ) + self.A_log_weight = ColumnParallelLinear( + 1, + self.global_num_v_heads, + bias=False, + gather_output=False, + ) + + # Output norm and projection + self.norm = Qwen3MoeRMSNorm(self.head_v_dim, eps=self.rms_norm_eps) + self.out_proj = RowParallelLinear( + self.global_value_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + ) + + # State buffers for CTE -> TKG carry-over + alloc_batch_size = getattr(config.neuron_config, "max_batch_size", 1) + self._phase_batch_size = getattr(config.neuron_config, "batch_size", 1) + self.recurrent_state_buffer = nn.Parameter( + torch.zeros( + alloc_batch_size, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + dtype=config.neuron_config.torch_dtype, + ), + requires_grad=False, + ) + self.conv_state_buffer = nn.Parameter( + torch.zeros( + alloc_batch_size, + self.conv_dim, + self.conv_kernel_size - 1, + dtype=config.neuron_config.torch_dtype, + ), + requires_grad=False, + ) + + def 
_conv1d_weight(self): + return self.conv1d_weight.weight.unsqueeze(1) + + def _dt_bias(self): + return self.dt_bias_weight.weight.squeeze(1) + + def _A_log(self): + return self.A_log_weight.weight.squeeze(1) + + def _recurrent_step(self, query, key, value, g, beta, recurrent_state): + """Single-step recurrent update for token generation.""" + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + scale = 1.0 / (query.shape[-1] ** 0.5) + query = query * scale + + q_t = query[:, :, 0] + k_t = key[:, :, 0] + v_t = value[:, :, 0] + g_t = g[:, :, 0].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, 0].unsqueeze(-1) + + new_state = recurrent_state * g_t + kv_mem = (new_state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + new_state = new_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + output = (new_state * q_t.unsqueeze(-1)).sum(dim=-2) + + return output.unsqueeze(2), new_state + + def _nki_recurrent_forward(self, query, key, value, g, beta): + """Full-sequence recurrent forward using NKI kernel for context encoding.""" + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + BH = B * H + query_flat = query.reshape(BH, S, k_dim).contiguous() + key_flat = key.reshape(BH, S, k_dim).contiguous() + value_flat = value.reshape(BH, S, v_dim).contiguous() + + g_flat = g.reshape(BH, S).unsqueeze(-1).expand(-1, -1, v_dim).contiguous() + beta_flat = beta.reshape(BH, S).unsqueeze(-1).expand(-1, -1, v_dim).contiguous() + + outputs = [] + states = [] + for bh in range(BH): + out_bh, state_bh = _deltanet_nki_kernel_state( + query_flat[bh], + key_flat[bh], + value_flat[bh], + g_flat[bh], + beta_flat[bh], + ) + outputs.append(out_bh) + states.append(state_bh) + + output = torch.stack(outputs, dim=0) + output = output.reshape(B, H, S, v_dim) + + final_state = torch.stack(states, dim=0) + final_state = final_state.reshape(B, H, k_dim, v_dim) 
+ + return output, final_state + + def _nki_chunked_forward( + self, query, key, value, g, beta, output_final_state=False, initial_state=None + ): + """Chunked NKI kernel forward for context encoding (prefill).""" + chunk_size = 128 + + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + pad_size = (chunk_size - S % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = S + pad_size + + num_chunks = total_seq_len // chunk_size + g_reshaped = g.reshape(B, H, num_chunks, chunk_size) + g_cs = g_reshaped.cumsum(dim=-1) + g_last_per_chunk = g_cs[:, :, :, -1:] + g_last_expanded = g_last_per_chunk.expand(-1, -1, -1, chunk_size) + + query_chunks = query.reshape(B, H, num_chunks, chunk_size, k_dim) + key_chunks = key.reshape(B, H, num_chunks, chunk_size, k_dim) + value_chunks = value.reshape(B, H, num_chunks, chunk_size, v_dim) + + beta_chunks = ( + beta.reshape(B, H, num_chunks, chunk_size) + .unsqueeze(-1) + .expand(-1, -1, -1, -1, v_dim) + ) + gc_chunks = g_cs.unsqueeze(-1).expand(-1, -1, -1, -1, v_dim) + gl_chunks = g_last_expanded.unsqueeze(-1).expand(-1, -1, -1, -1, v_dim) + + BH = B * H + query_chunks = query_chunks.reshape( + BH, num_chunks, chunk_size, k_dim + ).contiguous() + key_chunks = key_chunks.reshape(BH, num_chunks, chunk_size, k_dim).contiguous() + value_chunks = value_chunks.reshape( + BH, num_chunks, chunk_size, v_dim + ).contiguous() + beta_chunks = beta_chunks.reshape( + BH, num_chunks, chunk_size, v_dim + ).contiguous() + gc_chunks = gc_chunks.reshape(BH, num_chunks, chunk_size, v_dim).contiguous() + gl_chunks = gl_chunks.reshape(BH, num_chunks, chunk_size, v_dim).contiguous() + + device = query.device + lower_mask = torch.tril( + 
torch.ones(chunk_size, chunk_size, dtype=torch.float32, device=device), + diagonal=-1, + ) + identity_mat = torch.eye(chunk_size, dtype=torch.float32, device=device) + lower_mask_diag = torch.tril( + torch.ones(chunk_size, chunk_size, dtype=torch.float32, device=device), + diagonal=0, + ) + + initial_state_flat = None + if initial_state is not None: + initial_state_flat = initial_state.reshape(BH, k_dim, v_dim).float().contiguous() + + all_outputs = [] + all_states = [] + for bh in range(BH): + if initial_state_flat is None: + state = torch.zeros(k_dim, v_dim, dtype=torch.float32, device=device) + else: + state = initial_state_flat[bh] + + head_chunks = [] + for c_idx in range(num_chunks): + q_chunk = query_chunks[bh, c_idx].contiguous() + k_chunk = key_chunks[bh, c_idx].contiguous() + v_chunk = value_chunks[bh, c_idx].contiguous() + beta_chunk = beta_chunks[bh, c_idx].contiguous() + gc_chunk = gc_chunks[bh, c_idx].contiguous() + gl_chunk = gl_chunks[bh, c_idx].contiguous() + + out_chunk, state = _deltanet_nki_chunk_step( + q_chunk, + k_chunk, + v_chunk, + beta_chunk, + gc_chunk, + gl_chunk, + state, + lower_mask, + identity_mat, + lower_mask_diag, + ) + head_chunks.append(out_chunk) + + head_output = torch.cat(head_chunks, dim=0) + all_outputs.append(head_output) + all_states.append(state) + + output = torch.stack(all_outputs, dim=0) + output = output.reshape(B, H, total_seq_len, v_dim) + output = output[:, :, :S] + + if output_final_state: + final_state = torch.stack(all_states, dim=0) + last_recurrent_state = final_state.reshape(B, H, k_dim, v_dim) + else: + last_recurrent_state = None + + return output, last_recurrent_state + + def _fused_chunked_forward( + self, query, key, value, g, beta, output_final_state=False + ): + """Fused single-kernel chunked forward for CTE — SSD-style. + + Processes all chunks in a single NKI kernel call per (B,H) pair. + State persists in SBUF across chunks (no HBM round-trips). 
+ Cumsum of g computed in-kernel via tensor_tensor_scan. + + This is the optimized version of _nki_chunked_forward with: + 1. Single kernel call per (B,H) instead of B*H*num_chunks + 2. State in SBUF across all chunks (biggest perf win) + 3. In-kernel cumsum (avoids PyTorch cumsum overhead) + 4. tensor_scalar for broadcasts (no explicit loops) + """ + chunk_size = 128 + + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + # Pad sequence to multiple of chunk_size + pad_size = (chunk_size - S % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = S + pad_size + # Pass raw per-token log-decay. The fused NKI kernel forms decay as + # exp(cumsum(g)_i - cumsum(g)_j), so no pre-kernel clamp is needed. 
+ + BH = B * H + # Flatten to (BH, S, dim) for per-(b,h) kernel calls + query_flat = query.reshape(BH, total_seq_len, k_dim).contiguous() + key_flat = key.reshape(BH, total_seq_len, k_dim).contiguous() + value_flat = value.reshape(BH, total_seq_len, v_dim).contiguous() + + # g and beta: (BH, S) -> (BH, S, 1) for the kernel's (S, 1) input layout + g_flat = g.reshape(BH, total_seq_len).unsqueeze(-1).contiguous() + beta_flat = beta.reshape(BH, total_seq_len).unsqueeze(-1).contiguous() + + # Create constant mask tensors (shared across all B*H calls) + device = query.device + lower_mask = torch.tensor( + _make_lower_mask(), dtype=torch.float32, device=device + ) + identity_mat = torch.tensor( + _make_identity(), dtype=torch.float32, device=device + ) + lower_mask_diag = torch.tensor( + _make_lower_mask_diag(), dtype=torch.float32, device=device + ) + + all_outputs = [] + all_states = [] + for bh in range(BH): + out_bh, state_bh = _deltanet_fused_kernel( + query_flat[bh], # (S, 128) + key_flat[bh], # (S, 128) + value_flat[bh], # (S, 128) + g_flat[bh], # (S, 1) — RAW g, not cumsum + beta_flat[bh], # (S, 1) — sigmoid(b) + lower_mask, # (128, 128) + identity_mat, # (128, 128) + lower_mask_diag, # (128, 128) + ) + all_outputs.append(out_bh) + all_states.append(state_bh) + + output = torch.stack(all_outputs, dim=0) + output = output.reshape(B, H, total_seq_len, v_dim) + output = output[:, :, :S] + + if output_final_state: + final_state = torch.stack(all_states, dim=0) + last_recurrent_state = final_state.reshape(B, H, k_dim, v_dim) + else: + last_recurrent_state = None + + return output, last_recurrent_state + + def _sequential_forward(self, query, key, value, g, beta, output_final_state=False): + """Sequential full-sequence gated delta rule for CTE. + + Uses the same per-step recurrence as _recurrent_step but loops over the + full sequence. Avoids the slice-assignment loop in _chunk_forward that + may compile incorrectly on Neuron/XLA. 
+ """ + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + state = query.new_zeros(B, H, k_dim, v_dim) + all_outputs = [] + for t in range(S): + q_t = query[:, :, t] # (B, H, K) + k_t = key[:, :, t] # (B, H, K) + v_t = value[:, :, t] # (B, H, V) + beta_t = beta[:, :, t].unsqueeze(-1) # (B, H, 1) + g_t = g[:, :, t].exp().unsqueeze(-1).unsqueeze(-1) # (B, H, 1, 1) + + # Gated delta rule + state = state * g_t + kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2) # (B, H, V) + delta = (v_t - kv_mem) * beta_t # (B, H, V) + state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) # (B, H, K, V) + + o_t = (state * q_t.unsqueeze(-1)).sum(dim=-2) # (B, H, V) + all_outputs.append(o_t.unsqueeze(2)) + + output = torch.cat(all_outputs, dim=2) # (B, H, S, V) + final_state = state if output_final_state else None + return output, final_state + + def _chunk_forward( + self, query, key, value, g, beta, output_final_state=False, initial_state=None + ): + """Chunk-based forward for context encoding (prefill).""" + chunk_size = 64 + + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + pad_size = (chunk_size - S % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = S + pad_size + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + + num_chunks = total_seq_len // chunk_size + query = query.reshape(B, H, num_chunks, chunk_size, k_dim) + key = key.reshape(B, H, num_chunks, chunk_size, k_dim) + value = value.reshape(B, H, num_chunks, chunk_size, v_dim) + k_beta = k_beta.reshape(B, H, num_chunks, chunk_size, k_dim) + v_beta = 
v_beta.reshape(B, H, num_chunks, chunk_size, v_dim) + g = g.reshape(B, H, num_chunks, chunk_size) + + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=0, + ) + + g = g.cumsum(dim=-1) + decay_mask = (g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().tril() + + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + + if initial_state is None: + last_recurrent_state = torch.zeros( + B, H, k_dim, v_dim, dtype=query.dtype, device=query.device + ) + else: + last_recurrent_state = initial_state.to(dtype=query.dtype) + core_attn_out = torch.zeros_like(value) + mask2 = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=1, + ) + + for i in range(num_chunks): + q_i = query[:, :, i] + k_i = key[:, :, i] + v_i = value[:, :, i] + + attn_i = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_( + mask2, 0 + ) + + v_prime = k_cumdecay[:, :, i] @ last_recurrent_state + v_new = v_i - v_prime + + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn_i @ v_new + + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + ( + k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None] + ).transpose(-1, -2) + @ v_new + ) + + core_attn_out = core_attn_out.reshape(B, H, -1, v_dim) + core_attn_out = core_attn_out[:, :, :S] + + if not output_final_state: + last_recurrent_state = None + + return core_attn_out, last_recurrent_state + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask=None, + position_ids=None, + past_key_value=None, + 
**kwargs, + ): + """Forward pass compatible with NxDI decoder layer interface.""" + batch_size, seq_len, _ = hidden_states.shape + + seq_ids = kwargs.get("seq_ids", None) + qwen_chunked_prefill_active = ( + self.use_qwen_hybrid_chunked_prefill + and past_key_value is not None + and seq_len > 1 + ) + is_decode = past_key_value is not None and not qwen_chunked_prefill_active + + # Padding mask for DeltaNet: [B, S, 1] with 1.0 for real tokens, 0.0 for padding. + # Passed from get_model_output where it's computed from input_ids != pad_token_id. + # Embeddings are already zeroed for padding tokens; this mask additionally + # zeros the decay gate so the recurrent state is preserved unchanged + # through padding positions (no spurious decay). + valid_mask_1d = kwargs.get("deltanet_padding_mask", None) # [B, S, 1] or None + hybrid_cache_active = self.use_hybrid_cache_manager + recurrent_state_cache = None + conv_state_cache = None + if hybrid_cache_active and past_key_value is not None: + recurrent_state_cache, conv_state_cache = past_key_value + + # Project inputs + deltanet_fp32 = os.environ.get("DELTANET_FP32") == "1" + if deltanet_fp32 and isinstance(self.in_proj_qkv, nn.Linear): + hs_f32 = hidden_states.float() + qkv = F.linear(hs_f32, self.in_proj_qkv.weight.float()).to( + hidden_states.dtype + ) + z = F.linear(hs_f32, self.in_proj_z.weight.float()).to(hidden_states.dtype) + b = F.linear(hs_f32, self.in_proj_b.weight.float()).to(hidden_states.dtype) + a = F.linear(hs_f32, self.in_proj_a.weight.float()).to(hidden_states.dtype) + else: + qkv = self.in_proj_qkv(hidden_states) + z = self.in_proj_z(hidden_states) + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + + # Split QKV + query = qkv[..., : self.key_dim] + key = qkv[..., self.key_dim : self.key_dim * 2] + value = qkv[..., self.key_dim * 2 :] + + # Causal Conv1d on QKV + mixed = torch.cat([query, key, value], dim=-1) + mixed = mixed.transpose(1, 2) + + if is_decode: + if conv_state_cache is 
not None: + conv_state = conv_state_cache[:batch_size] + elif seq_ids is not None: + conv_state = torch.index_select(self.conv_state_buffer, 0, seq_ids) + else: + conv_state = self.conv_state_buffer[:batch_size] + conv_input = torch.cat([conv_state, mixed], dim=-1) + + w = self._conv1d_weight().squeeze(1) + conv_out = torch.zeros_like(mixed) + for k in range(self.conv_kernel_size): + conv_out = ( + conv_out + + w[:, k].unsqueeze(0).unsqueeze(-1) * conv_input[:, :, k : k + 1] + ) + mixed_post_conv = F.silu(conv_out) + + new_conv_state = torch.cat([conv_state[:, :, 1:], mixed], dim=-1) + alloc_bs = self.conv_state_buffer.shape[0] + if hybrid_cache_active: + new_conv_state = new_conv_state.to(self.conv_state_buffer.dtype) + elif seq_ids is not None: + # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement + # Add buffer dependency for input_output_alias + new_conv_state = ( + new_conv_state.to(self.conv_state_buffer.dtype) + + self.conv_state_buffer * 0 + ) + elif batch_size < alloc_bs: + new_conv_state = torch.cat( + [ + new_conv_state, + self.conv_state_buffer[batch_size:] * 0, + ], + dim=0, + ) + else: + new_conv_state = new_conv_state + self.conv_state_buffer * 0 + else: + if qwen_chunked_prefill_active and conv_state_cache is not None: + conv_state = conv_state_cache[:batch_size] + if position_ids is not None: + reset_mask = (position_ids[:, :1].long() == 0).to( + dtype=conv_state.dtype, device=conv_state.device + ) + conv_state = conv_state * (1.0 - reset_mask[:, None, :]) + conv_input = torch.cat([conv_state, mixed], dim=-1) + w = self._conv1d_weight().squeeze(1) + conv_out = torch.zeros_like(mixed) + for k in range(self.conv_kernel_size): + conv_out = conv_out + w[:, k].unsqueeze(0).unsqueeze(-1) * conv_input[ + :, :, k : k + seq_len + ] + mixed_post_conv = F.silu(conv_out) + if valid_mask_1d is not None: + state_len = self.conv_kernel_size - 1 + num_valid = valid_mask_1d.squeeze(-1).sum(dim=-1, keepdim=True).long() +
idx_base = num_valid.clamp(min=0) # window start in conv_input: state_len + (num_valid - state_len) + offsets = torch.arange(state_len, device=mixed.device).unsqueeze(0) + gather_idx = idx_base + offsets + gather_idx = gather_idx.unsqueeze(1).expand(-1, self.conv_dim, -1) + new_conv_state = torch.gather(conv_input, 2, gather_idx) + else: + new_conv_state = conv_input[:, :, -self.conv_kernel_size + 1 :].contiguous() + else: + mixed_post_conv = F.silu( + F.conv1d( + mixed, + self._conv1d_weight(), + bias=None, + padding=self.conv_kernel_size - 1, + groups=self.conv_dim, + )[:, :, :seq_len] + ) + + if valid_mask_1d is not None: + # valid_mask_1d is [B, S, 1]; count valid tokens per batch + num_valid = ( + valid_mask_1d.squeeze(-1).sum(dim=-1, keepdim=True).long() + ) # [B, 1] + idx_base = num_valid - (self.conv_kernel_size - 1) + idx_base = idx_base.clamp(min=0) + offsets = torch.arange(self.conv_kernel_size - 1, device=mixed.device).unsqueeze(0) + gather_idx = idx_base + offsets # [B, conv_kernel_size - 1] + gather_idx = gather_idx.unsqueeze(1).expand(-1, self.conv_dim, -1) + new_conv_state = torch.gather(mixed, 2, gather_idx) + else: + new_conv_state = mixed[:, :, -(self.conv_kernel_size - 1) :].contiguous() + + alloc_bs = self.conv_state_buffer.shape[0] + if hybrid_cache_active: + new_conv_state = new_conv_state.to(self.conv_state_buffer.dtype) + elif seq_ids is not None: + # BS=1 optimization: scatter to index 0 = direct replacement + new_conv_state = ( + new_conv_state.to(self.conv_state_buffer.dtype) + + self.conv_state_buffer * 0 + ) + elif batch_size < alloc_bs: + pad_size = alloc_bs - batch_size + new_conv_state = torch.cat( + [ + new_conv_state, + torch.zeros( + pad_size, + self.conv_dim, + self.conv_kernel_size - 1, + dtype=new_conv_state.dtype, + device=new_conv_state.device, + ), + ], + dim=0, + ) + new_conv_state = new_conv_state + self.conv_state_buffer * 0 + else: + new_conv_state = new_conv_state + self.conv_state_buffer * 0 + + mixed_post_conv = mixed_post_conv.transpose(1, 2) + + # Zero out conv1d output for padding positions.
+ # Conv1d with kernel_size=4 leaks real token info into the first + # few padding positions. Zeroing here ensures Q, K, V are exactly + # zero for all padding positions so the recurrence is unaffected. + if valid_mask_1d is not None: + mixed_post_conv = ( + mixed_post_conv * valid_mask_1d + ) # [B, S, conv_dim] * [B, S, 1] + + query = mixed_post_conv[..., : self.key_dim] + key = mixed_post_conv[..., self.key_dim : self.key_dim * 2] + value = mixed_post_conv[..., self.key_dim * 2 :] + + # Reshape to heads + query = query.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim) + key = key.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim) + value = value.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + + # Compute gating + beta = b.sigmoid() + g = -self._A_log().float().exp() * F.softplus(a.float() + self._dt_bias()) + + if valid_mask_1d is not None: + # Zero g for padding → alpha=exp(0)=1 → state preserved through padding + # Zero beta for padding → no state update from padding tokens + mask_2d = valid_mask_1d.squeeze(-1).float() # [B, S] + g = g * mask_2d.unsqueeze(-1) + beta = beta * mask_2d.unsqueeze(-1) + + # Expand K heads to match V heads (16 -> 48) using expand+reshape + if self.num_v_heads // self.num_k_heads > 1: + rep = self.num_v_heads // self.num_k_heads # 3 + query = ( + query.unsqueeze(3) + .expand(-1, -1, -1, rep, -1) + .reshape(batch_size, seq_len, self.num_v_heads, self.head_k_dim) + ) + key = ( + key.unsqueeze(3) + .expand(-1, -1, -1, rep, -1) + .reshape(batch_size, seq_len, self.num_v_heads, self.head_k_dim) + ) + + # Transpose to (B, H, S, dim) + query = query.transpose(1, 2).contiguous().float() + key = key.transpose(1, 2).contiguous().float() + value = value.transpose(1, 2).contiguous().float() + g = g.transpose(1, 2).contiguous().float() + beta = beta.transpose(1, 2).contiguous().float() + + if is_decode: + # TKG: single-step recurrent update + if recurrent_state_cache is not None: + recurrent_state = 
recurrent_state_cache[:batch_size].float() + elif seq_ids is not None: + recurrent_state = torch.index_select( + self.recurrent_state_buffer, 0, seq_ids + ).float() + else: + recurrent_state = self.recurrent_state_buffer[:batch_size].float() + + output, new_state = self._recurrent_step( + query, key, value, g, beta, recurrent_state + ) + new_state_bf16 = new_state.to(self.recurrent_state_buffer.dtype) + alloc_bs = self.recurrent_state_buffer.shape[0] + if hybrid_cache_active: + new_rec_state = new_state_bf16 + elif seq_ids is not None: + # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement + # Add buffer dependency for input_output_alias + new_rec_state = new_state_bf16 + self.recurrent_state_buffer * 0 + elif batch_size < alloc_bs: + new_rec_state = torch.cat( + [ + new_state_bf16, + self.recurrent_state_buffer[batch_size:] * 0, + ], + dim=0, + ) + else: + new_rec_state = new_state_bf16 + self.recurrent_state_buffer * 0 + else: + # CTE: fused NKI kernel by default (PyTorch _chunk_forward can hit + # neuronx-cc codegen ICE NCC_INLA001 with these DeltaNet dimensions). + # Override with env vars for debugging/benchmarking. 
+ use_nki_fused = os.environ.get("USE_NKI_FUSED", "1") != "0" + use_nki_chunked = os.environ.get("USE_NKI_CHUNKED") == "1" + use_nki = os.environ.get("USE_NKI") == "1" + use_sequential = os.environ.get("DELTANET_SEQUENTIAL") == "1" + use_pytorch_chunk = os.environ.get("USE_PYTORCH_CHUNK") == "1" + + if qwen_chunked_prefill_active and recurrent_state_cache is not None: + initial_state = recurrent_state_cache[:batch_size].float() + if position_ids is not None: + reset_mask = (position_ids[:, :1].long() == 0).to( + dtype=initial_state.dtype, device=initial_state.device + ) + initial_state = initial_state * (1.0 - reset_mask[:, :, None, None]) + if self.use_qwen_hybrid_chunked_prefill_nki: + output, final_state = self._nki_chunked_forward( + query, + key, + value, + g, + beta, + output_final_state=True, + initial_state=initial_state, + ) + else: + output, final_state = self._chunk_forward( + query, + key, + value, + g, + beta, + output_final_state=True, + initial_state=initial_state, + ) + elif use_pytorch_chunk: + output, final_state = self._chunk_forward( + query, key, value, g, beta, output_final_state=True + ) + elif use_nki_chunked: + output, final_state = self._nki_chunked_forward( + query, key, value, g, beta, output_final_state=True + ) + elif use_nki: + output, final_state = self._nki_recurrent_forward( + query, key, value, g, beta + ) + elif use_sequential: + output, final_state = self._sequential_forward( + query, key, value, g, beta, output_final_state=True + ) + elif use_nki_fused: + output, final_state = self._fused_chunked_forward( + query, key, value, g, beta, output_final_state=True + ) + else: + output, final_state = self._fused_chunked_forward( + query, key, value, g, beta, output_final_state=True + ) + + if final_state is not None: + final_state_bf16 = final_state.to(self.recurrent_state_buffer.dtype) + alloc_bs = self.recurrent_state_buffer.shape[0] + if hybrid_cache_active: + new_rec_state = final_state_bf16 + elif seq_ids is not None: + # BS=1 
optimization: scatter to index 0 of size-1 buffer = direct replacement + # Add buffer dependency for input_output_alias + new_rec_state = final_state_bf16 + self.recurrent_state_buffer * 0 + elif batch_size < alloc_bs: + new_rec_state = torch.cat( + [ + final_state_bf16, + torch.zeros( + alloc_bs - batch_size, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + dtype=final_state_bf16.dtype, + device=final_state_bf16.device, + ), + ], + dim=0, + ) + new_rec_state = new_rec_state + self.recurrent_state_buffer * 0 + else: + new_rec_state = final_state_bf16 + self.recurrent_state_buffer * 0 + else: + new_rec_state = self.recurrent_state_buffer * 1 + + # Output: norm, gate, project + output = output.to(hidden_states.dtype) + output = output.transpose(1, 2).contiguous() + output = output.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + output = self.norm(output) + z_gate = z.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + output = output * F.silu(z_gate) + output = output.reshape(batch_size, seq_len, self.value_dim) + output = self.out_proj(output) + + if hybrid_cache_active: + return output, (new_rec_state, new_conv_state), new_rec_state, new_conv_state + + # Return dummy KV for KVCacheManager + dummy_k = torch.zeros( + batch_size, + self.kv_heads_per_rank, + seq_len, + self.head_dim, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + dummy_v = torch.zeros_like(dummy_k) + + return output, (dummy_k, dummy_v), new_rec_state, new_conv_state + + +# ============================================================ +# InferenceConfig (Dense -- no MoE) +# ============================================================ + + +class Qwen35InferenceConfig(InferenceConfig): + """Config for Qwen3.5/3.6-27B (dense) with hybrid DeltaNet + Attention.""" + + def __init__(self, *args, **kwargs): + # Set defaults BEFORE super().__init__() because it calls validate_config() + # which checks get_required_attributes(). 
These can be overridden by + # kwargs or load_config. + + # Layer types for hybrid dispatch: [3 DeltaNet + 1 GQA] repeated. + if "layer_types" not in kwargs and not any( + hasattr(a, "layer_types") for a in args if hasattr(a, "__dict__") + ): + num_layers = kwargs.get("num_hidden_layers", 64) + if num_layers % 4 != 0: + raise ValueError( + f"Qwen3.5 hybrid layer count must be divisible by 4, got {num_layers}" + ) + layer_types = [] + for _ in range(num_layers // 4): + layer_types.extend( + [ + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + ] + ) + kwargs.setdefault("layer_types", layer_types) + + # DeltaNet-specific config defaults + kwargs.setdefault("linear_num_value_heads", 48) + kwargs.setdefault("linear_num_key_heads", 16) + kwargs.setdefault("linear_key_head_dim", 128) + kwargs.setdefault("linear_value_head_dim", 128) + kwargs.setdefault("linear_conv_kernel_dim", 4) + kwargs.setdefault("use_hybrid_cache_manager", False) + kwargs.setdefault("use_qwen_hybrid_chunked_prefill", False) + kwargs.setdefault("use_qwen_hybrid_chunked_prefill_nki", False) + + super().__init__(*args, **kwargs) + + # Attention output gate + self.attn_output_gate = getattr(self, "attn_output_gate", True) + + # Partial RoPE + self.partial_rotary_factor = getattr(self, "partial_rotary_factor", 0.25) + self.rope_dim = int(self.head_dim * self.partial_rotary_factor) # 64 + + # mRoPE (multimodal RoPE) for VL support + rope_params = getattr(self, "rope_parameters", {}) or {} + self.mrope_section = rope_params.get("mrope_section", [11, 11, 10]) + self.mrope_interleaved = rope_params.get("mrope_interleaved", True) + + # Standard HF config attributes expected by NxDI + if not hasattr(self, "output_attentions"): + self.output_attentions = False + if not hasattr(self, "output_hidden_states"): + self.output_hidden_states = False + + def get_required_attributes(self) -> List[str]: + return [ + "head_dim", + "hidden_act", + "hidden_size", + "intermediate_size", + 
"max_position_embeddings", + "num_attention_heads", + "num_hidden_layers", + "num_key_value_heads", + "rms_norm_eps", + "rope_theta", + "vocab_size", + # DeltaNet-specific + "linear_num_value_heads", + "linear_num_key_heads", + "linear_key_head_dim", + "linear_value_head_dim", + "linear_conv_kernel_dim", + "layer_types", + ] + + @classmethod + def get_neuron_config_cls(cls): + return NeuronConfig + + +# ============================================================ +# Attention (standard GQA for 16 of 64 layers) +# With output gate: q_proj is 2x sized, split into (query, gate) +# With partial RoPE: only first rope_dim dimensions get rotary +# ============================================================ + + +class Qwen35MRoPEEmbedding(nn.Module): + """Multimodal Rotary Position Embedding (mRoPE) for Qwen3.5. + + Handles 3D position information (temporal, height, width) for VL models. + Position IDs have shape (3, batch_size, seq_len) for T/H/W dimensions. + For text-only (2D position_ids), broadcasts to 3D with identical positions. + """ + + def __init__(self, config): + super().__init__() + self.head_dim = config.head_dim # 256 + self.rope_dim = config.rope_dim # 64 + self.mrope_section = config.mrope_section # [11, 11, 10] + self.mrope_interleaved = getattr(config, "mrope_interleaved", True) + self.rope_theta = config.rope_theta + + # Validate mrope_section sums to rope_dim // 2 = 32 + assert sum(self.mrope_section) == self.rope_dim // 2, ( + f"mrope_section {self.mrope_section} sums to {sum(self.mrope_section)}, " + f"expected {self.rope_dim // 2}" + ) + + def forward(self, x, position_ids_3d): + """Compute cos/sin from 3D position IDs. 
+ + Args: + x: hidden_states (for device/dtype inference) + position_ids_3d: (3, batch_size, seq_len) -- T, H, W positions + + Returns: + cos: (batch_size, seq_len, rope_dim) + sin: (batch_size, seq_len, rope_dim) + """ + device = x.device + dtype = torch.float32 + + if position_ids_3d.ndim == 2: + position_ids_3d = position_ids_3d[None, ...].expand( + 3, position_ids_3d.shape[0], -1 + ) + + inv_freq = 1.0 / ( + self.rope_theta + ** ( + torch.arange(0, self.rope_dim, 2, dtype=dtype, device=device) + / self.rope_dim + ) + ) + inv_freq = inv_freq[None, None, :, None].expand( + 3, position_ids_3d.shape[1], -1, 1 + ) + positions = position_ids_3d[:, :, None, :].float() + freqs = (inv_freq.float() @ positions).transpose(2, 3) + + # Match HF Qwen3.6 mRoPE layout exactly: start from the temporal + # frequencies, then splice H/W frequencies into interleaved positions. + freqs_t = freqs[0] + if self.mrope_interleaved: + for dim, offset in enumerate((1, 2), start=1): + length = self.mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + + emb = torch.cat((freqs_t, freqs_t), dim=-1) + cos = emb.cos().to(dtype=x.dtype) + sin = emb.sin().to(dtype=x.dtype) + + return cos, sin + + +class NeuronQwen35Attention(NeuronAttentionBase): + """Standard GQA attention for Qwen3.5 with output gate and partial RoPE. + + 24 Q heads, 4 KV heads (6:1 GQA), head_dim=256 for 27B dense. + q_proj is doubled (query + gate), split at load time. + Only first rope_dim=64 of head_dim=256 gets rotary encoding. + + Uses NeuronAttentionBase infrastructure for QKV projection, KV cache, + RoPE, and attention computation. Overrides forward() to insert the + sigmoid output gate between attention output and o_proj. 
+ """ + + def __init__(self, config): + # Partial RoPE: create mRoPE embedding with rope_dim (64) + self.rope_dim = config.rope_dim # 64 = head_dim * partial_rotary_factor + + # Create QK norm modules (will be passed to base class) + rms_norm_eps = config.rms_norm_eps + q_ln = get_rmsnorm_cls()(config.head_dim, rms_norm_eps) + k_ln = get_rmsnorm_cls()(config.head_dim, rms_norm_eps) + + # Partial RoPE: use standard RotaryEmbedding. + # For VL with 3D mRoPE positions, cos/sin are pre-computed externally in + # get_model_output() using Qwen35MRoPEEmbedding and passed as cos_cache/sin_cache. + rotary_emb = RotaryEmbedding( + self.rope_dim, # Only 64 dims get rotary embedding + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + super().__init__( + config=config, + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=config.head_dim, + rotary_emb=rotary_emb, + rms_norm_eps=rms_norm_eps, + use_qk_norm=False, + q_layernorm=q_ln, + k_layernorm=k_ln, + ) + + # Separate mRoPE module for VL 3D position_ids + self.mrope_emb = Qwen35MRoPEEmbedding(config) + + # Output gate projection: hidden_size -> num_heads * head_dim + # Populated from the second half of q_proj during state dict conversion. + self.output_gate_proj = ColumnParallelLinear( + config.hidden_size, + config.num_attention_heads * config.head_dim, + bias=False, + gather_output=False, + ) + + def apply_rotary_embedding( + self, Q, K, V, position_ids, cos_cache, sin_cache, use_polar_compatible_rope + ): + """Partial RoPE: only apply rotary embedding to first rope_dim dimensions. 
+ + Q shape: (B, H, S, head_dim) where head_dim=256 + cos/sin shape: (B, S, rope_dim) where rope_dim=64 (from RotaryEmbedding(dim=64)) + + Split Q/K along last dim into: + q_rope (first 64 dims) -- apply RoPE + q_pass (remaining 192 dims) -- pass through unchanged + """ + from neuronx_distributed_inference.modules.attention.utils import ( + apply_rotary_pos_emb, + ) + + if self.rotary_emb is not None: + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb(V, position_ids) + + # Split into rope and pass-through portions + Q_orig_dtype = Q.dtype + q_rope = Q[..., : self.rope_dim] # (B, H, S, 64) + q_pass = Q[..., self.rope_dim :] # (B, H, S, 192) + k_rope = K[..., : self.rope_dim] + k_pass = K[..., self.rope_dim :] + + # Apply RoPE only to the rope portion + q_rope, k_rope = apply_rotary_pos_emb(q_rope, k_rope, cos_cache, sin_cache) + + # Concatenate back (ensure bf16 is maintained) + Q = torch.cat([q_rope, q_pass], dim=-1).to(Q_orig_dtype) + K = torch.cat([k_rope, k_pass], dim=-1).to(Q_orig_dtype) + + return Q, K, cos_cache, sin_cache + + def perform_prefill(self, Q, K, V, q_len, bsz, attention_mask=None): + """Prefill path with NKI flash attention for head_dim=256.""" + head_dim = Q.shape[-1] + + # Option B: nkilib flash attention for head_dim > 128 + if _nkilib_flash_attn is not None: + q_contig = Q.contiguous() + k_contig = K.contiguous() + v_contig = V.contiguous() + scale = 1.0 / math.sqrt(head_dim) + result = _nkilib_flash_attn( + q_contig, k_contig, v_contig, scale=scale, use_causal_mask=True + ) + return result, None + + # Option A: kernel patched globally + if NKILIB_PATCH_ACTIVE: + return _flash_fwd_call(Q, K, V, use_causal_mask=True), None + + # Fallback: softmax path (use 3D tensors to avoid compiler ICE with 4D patterns) + if head_dim > 128: + # GQA: expand K/V heads to match Q heads + num_q_heads = Q.shape[1] + num_kv_heads = K.shape[1] + if num_q_heads != num_kv_heads: + kv_rep = num_q_heads // num_kv_heads + K = ( + 
K.unsqueeze(2) + .expand(-1, -1, kv_rep, -1, -1) + .reshape(bsz, num_q_heads, q_len, head_dim) + ) + V = ( + V.unsqueeze(2) + .expand(-1, -1, kv_rep, -1, -1) + .reshape(bsz, num_q_heads, q_len, head_dim) + ) + # Reshape to 3D (B*H, S, d) to avoid neuronx-cc codegen ICE with 4D + # attention weight tensors (NCC_INLA001: Expected 2D tensor but got 4D AP) + Q_3d = Q.reshape(bsz * num_q_heads, q_len, head_dim) + K_3d = K.reshape(bsz * num_q_heads, q_len, head_dim) + V_3d = V.reshape(bsz * num_q_heads, q_len, head_dim) + attn_weights = torch.bmm(Q_3d, K_3d.transpose(-1, -2)) / math.sqrt(head_dim) + # Build causal mask for 3D: (1, S, S) broadcast over B*H + causal_mask = torch.triu( + torch.full( + (q_len, q_len), + -65504.0, + dtype=attn_weights.dtype, + device=attn_weights.device, + ), + diagonal=1, + ).unsqueeze(0) + attn_weights = attn_weights + causal_mask + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + Q.dtype + ) + attn_output = torch.bmm(attn_weights, V_3d) + # Reshape back to 4D (B, H, S, d) + return attn_output.reshape(bsz, num_q_heads, q_len, head_dim), None + + return _flash_fwd_call(Q, K, V, use_causal_mask=True), None + + def perform_qwen_chunked_prefill(self, Q, K, V, past_key_value, position_ids): + """Exact chunked CTE over the full decode cache. + + The current chunk K/V tensors are scattered into the full cache at + absolute position_ids, then attention for this chunk is computed over + all cache positions up to the chunk end. This keeps full-attention + layers correct when model-local chunked prefill feeds context in + multiple CTE-bucket calls. 
+ """ + k_cache, v_cache = past_key_value + B, q_heads, q_len, head_dim = Q.shape + kv_heads = K.shape[1] + cache_len = k_cache.shape[2] + + pos = position_ids.long() + k_index = pos[:, None, :, None].expand(B, kv_heads, q_len, head_dim) + k_cache = torch.scatter(k_cache, dim=2, index=k_index, src=K.to(k_cache.dtype)) + v_cache = torch.scatter(v_cache, dim=2, index=k_index, src=V.to(v_cache.dtype)) + + if q_heads != kv_heads: + kv_rep = q_heads // kv_heads + K_full = ( + k_cache.unsqueeze(2) + .expand(-1, -1, kv_rep, -1, -1) + .reshape(B, q_heads, cache_len, head_dim) + ) + V_full = ( + v_cache.unsqueeze(2) + .expand(-1, -1, kv_rep, -1, -1) + .reshape(B, q_heads, cache_len, head_dim) + ) + else: + K_full = k_cache + V_full = v_cache + + attn_weights = torch.matmul(Q, K_full.transpose(-1, -2)) / math.sqrt(head_dim) + cache_positions = torch.arange(cache_len, device=position_ids.device).view(1, 1, 1, -1) + causal_mask = cache_positions <= pos[:, None, :, None] + attn_weights = attn_weights.masked_fill(~causal_mask, -65504.0) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(Q.dtype) + return torch.matmul(attn_weights, V_full) + + def forward( + self, + hidden_states, + attention_mask=None, + position_ids=None, + past_key_value=None, + cos_cache=None, + sin_cache=None, + rmsnorm=None, + adapter_ids=None, + active_mask=None, + **kwargs, + ): + """Forward with output gate applied BEFORE o_proj. + + Override NeuronAttentionBase.forward() to insert the sigmoid gate + between the attention output and o_proj, matching the HF reference: + gate = sigmoid(gate_proj(pre_attn_hidden)) + attn_output = attn_output * gate + attn_output = o_proj(attn_output) + """ + bsz, q_len, _ = hidden_states.shape + + # Use standard 2D position_ids for prep_qkv_tensors. 
+ rope_pos_ids = position_ids + + # Compute gate from input hidden states (before QKV projection) + gate = self.output_gate_proj(hidden_states) # (B, S, num_heads * head_dim) + + # Standard QKV prep (projections, QK norm, RoPE) + Q, K, V, cos_cache, sin_cache, _residual = self.prep_qkv_tensors( + rope_pos_ids, + hidden_states, + past_key_value, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + rmsnorm=rmsnorm, + ) + + qwen_chunked_prefill_active = ( + past_key_value is not None + and q_len > 1 + and getattr(self.config, "use_qwen_hybrid_chunked_prefill", False) + ) + + if past_key_value is None: + # Context encoding (prefill) + attn_output, _flash_strategy = self.perform_prefill( + Q, K, V, q_len, bsz, attention_mask + ) + elif qwen_chunked_prefill_active: + attn_output = self.perform_qwen_chunked_prefill( + Q, K, V, past_key_value, position_ids + ) + else: + # Token generation (decode) + tkg_mask = attention_mask + if tkg_mask is not None and tkg_mask.ndim == 2: + tkg_mask = tkg_mask.unsqueeze(1).unsqueeze(2) # (B, S) -> (B, 1, 1, S) + attn_output = self.compute_for_token_gen( + Q, K, V, position_ids, past_key_value, tkg_mask, active_mask + ) + + # attn_output is (B, H, S, head_dim) -- transpose to (B, S, H*head_dim) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + + # Apply sigmoid output gate BEFORE o_proj (matching HF reference) + attn_output = attn_output * torch.sigmoid(gate) + + # Apply o_proj + attn_output = self.get_o_proj()(attn_output, adapter_ids=adapter_ids) + + # Ensure K, V are in model dtype (bf16) for KV cache update + # (prevents mixed-precision dynamic-update-slice in neuronx-cc) + K = K.to(self.torch_dtype) + V = V.to(self.torch_dtype) + past_key_value = (K, V) + return attn_output, past_key_value, cos_cache, sin_cache + + +# ============================================================ +# Dense MLP (replaces MoE) +# 
============================================================ + + +class Qwen35MLP(nn.Module): + """Dense SwiGLU MLP for Qwen3.5/3.6-27B. + + gate_proj: hidden_size -> intermediate_size (5120 -> 17408) + up_proj: hidden_size -> intermediate_size (5120 -> 17408) + down_proj: intermediate_size -> hidden_size (17408 -> 5120) + + output = down_proj(silu(gate_proj(x)) * up_proj(x)) + """ + + def __init__(self, config): + super().__init__() + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=False, + gather_output=False, + ) + self.up_proj = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=False, + gather_output=False, + ) + self.down_proj = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=False, + input_is_parallel=True, + ) + + def forward(self, hidden_states): + gate = self.gate_proj(hidden_states) + up = self.up_proj(hidden_states) + hidden_states = F.silu(gate) * up + hidden_states = self.down_proj(hidden_states) + return hidden_states + + +# ============================================================ +# Decoder Layer (hybrid dispatch -- DeltaNet or GQA + Dense MLP) +# ============================================================ + + +class NeuronQwen35DecoderLayer(nn.Module): + """Hybrid decoder layer: dispatches to DeltaNet or standard attention. + Uses dense MLP for all layers (no MoE). 
+ """ + + def __init__(self, config: Qwen35InferenceConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_type = config.layer_types[layer_idx] + self.layer_idx = layer_idx + self.config = config + + # Attention (DeltaNet or standard GQA) + if self.layer_type == "linear_attention": + self.linear_attn = NeuronGatedDeltaNet(config, layer_idx) + else: + self.self_attn = NeuronQwen35Attention(config=config) + + # Dense MLP (all layers) + self.mlp = Qwen35MLP(config) + + self.input_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask=None, + position_ids=None, + past_key_value=None, + padding_mask=None, + cos_cache=None, + sin_cache=None, + **kwargs, + ): + residual = hidden_states + + hidden_states = ModuleMarkerStartWrapper()(hidden_states) + hidden_states = self.input_layernorm(hidden_states) + + if self.layer_type == "linear_attention": + # DeltaNet path + attn_out, dummy_kv, new_rec_state, new_conv_state = self.linear_attn( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + **kwargs, + ) + hidden_states = residual + attn_out + present_key_value = dummy_kv + deltanet_states = ( + None + if getattr(self.config, "use_hybrid_cache_manager", False) + else (new_rec_state, new_conv_state) + ) + else: + deltanet_states = None + # Standard attention path + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + cos_cache=cos_cache, + sin_cache=sin_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Dense MLP FFN + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + 
hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + hidden_states = ModuleMarkerEndWrapper()(hidden_states) + outputs = ( + hidden_states, + present_key_value, + cos_cache, + sin_cache, + None, + deltanet_states, + ) + return outputs + + +# ============================================================ +# Hybrid Cache Manager (opt-in) +# ============================================================ + + +class HybridDeltaNetCacheManager(KVCacheManager): + """Layer-type-aware cache manager for Qwen3.5/Qwen3.6 hybrid dense models.""" + + def __init__(self, config: Qwen35InferenceConfig, num_kv_head, **kwargs): + self.layer_types = list(config.layer_types) + self._validate_hybrid_config(config) + super().__init__(config, num_kv_head=num_kv_head, **kwargs) + + dtype = ( + config.neuron_config.attention_dtype + if config.neuron_config.attention_dtype is not None + else config.neuron_config.torch_dtype + ) + cache_dtype = getattr(self, "cache_dtype", dtype) + max_batch_size = ( + config.neuron_config.kv_cache_batch_size + + config.neuron_config.kv_cache_padding_size + ) + tp_degree = config.neuron_config.tp_degree + if config.linear_num_value_heads % tp_degree != 0: + raise ValueError( + f"linear_num_value_heads={config.linear_num_value_heads} must be divisible " + f"by tp_degree={tp_degree}" + ) + if config.linear_num_key_heads % tp_degree != 0: + raise ValueError( + f"linear_num_key_heads={config.linear_num_key_heads} must be divisible " + f"by tp_degree={tp_degree}" + ) + local_num_value_heads = config.linear_num_value_heads // tp_degree + local_num_key_heads = config.linear_num_key_heads // tp_degree + recurrent_shape = [ + max_batch_size, + local_num_value_heads, + config.linear_key_head_dim, + config.linear_value_head_dim, + ] + conv_dim = ( + 2 * local_num_key_heads * config.linear_key_head_dim + + local_num_value_heads * config.linear_value_head_dim + ) + conv_shape = [ + max_batch_size, + conv_dim, + config.linear_conv_kernel_dim - 
1, + ] + + params = [] + for layer_idx, layer_type in enumerate(self.layer_types): + if layer_type == "linear_attention": + params.append( + nn.Parameter(torch.zeros(recurrent_shape, dtype=dtype), requires_grad=False) + ) + params.append( + nn.Parameter(torch.zeros(conv_shape, dtype=dtype), requires_grad=False) + ) + else: + k_shape = self.k_shapes[layer_idx] if hasattr(self, "k_shapes") else self.k_shape + v_shape = self.v_shapes[layer_idx] if hasattr(self, "v_shapes") else self.v_shape + params.append( + nn.Parameter(torch.zeros(k_shape, dtype=cache_dtype), requires_grad=False) + ) + params.append( + nn.Parameter(torch.zeros(v_shape, dtype=cache_dtype), requires_grad=False) + ) + + self.past_key_values = nn.ParameterList(params) + + @staticmethod + def _validate_hybrid_config(config: Qwen35InferenceConfig): + nc = config.neuron_config + unsupported = [] + if nc.is_block_kv_layout: + unsupported.append("block KV layout") + if getattr(nc, "kv_quant_config", None) is not None or getattr(nc, "kv_cache_quant", False): + unsupported.append("KV cache quantization") + if nc.enable_fused_speculation or nc.speculation_length > 0 or nc.is_medusa: + unsupported.append("speculative decoding") + if getattr(nc, "enable_eagle_speculation", False) or getattr(nc, "is_eagle_draft", False): + unsupported.append("EAGLE speculation") + if nc.flash_decoding_enabled: + unsupported.append("flash decoding") + if nc.attention_dp_degree > 1: + unsupported.append("attention data parallelism") + if nc.kv_cache_tiling: + unsupported.append("KV cache tiling") + if nc.padding_side != "right": + unsupported.append("left padding") + if nc.is_continuous_batching: + unsupported.append("continuous batching") + if unsupported: + raise ValueError( + "HybridDeltaNetCacheManager v1 does not support: " + + ", ".join(unsupported) + ) + + def _is_deltanet_layer(self, idx: int) -> bool: + return self.layer_types[idx] == "linear_attention" + + def get_seq_length(self, past_key_values=None): + for idx, 
layer_type in enumerate(self.layer_types): + if layer_type != "linear_attention": + if past_key_values is None: + _, v_cache = self._fetch_cache(idx) + elif len(past_key_values) == len(self.past_key_values): + v_cache = past_key_values[2 * idx + 1] + else: + v_cache = past_key_values[idx][1] + return v_cache.shape[2] + return 0 + + def get_deltanet_state_by_layer_id(self, idx, kvcache_buffer=None, seq_ids=None): + recurrent_state, conv_state = self._fetch_cache(idx, kvcache_buffer) + if seq_ids is not None: + cache_idx = self.get_cache_update_index_for_seq_ids(seq_ids) + recurrent_state = torch.index_select(recurrent_state, dim=0, index=cache_idx) + conv_state = torch.index_select(conv_state, dim=0, index=cache_idx) + elif self.kv_cache_padding_size > 0: + recurrent_state = recurrent_state[: -self.kv_cache_padding_size] + conv_state = conv_state[: -self.kv_cache_padding_size] + return recurrent_state, conv_state + + def get_cache( + self, + seq_len: int, + skip_slice=False, + kvcache_buffer=None, + seq_ids=None, + windowed_context_encoding_window_idx=-1, + **kwargs, + ): + past_key_values = [] + for idx in range(len(self.past_key_values) // 2): + if self._is_deltanet_layer(idx): + past_key_values.append( + list(self.get_deltanet_state_by_layer_id(idx, kvcache_buffer, seq_ids)) + ) + else: + past_key_values.append( + list( + self.get_kv_by_layer_id( + idx=idx, + skip_slice=skip_slice, + seq_len=seq_len, + kvcache_buffer=kvcache_buffer, + seq_ids=seq_ids, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + **kwargs, + ) + ) + ) + return past_key_values + + def update_cache( + self, + is_for_context_encoding: bool, + seq_ids: torch.Tensor, + position_ids: torch.Tensor, + new_key_values: List[torch.Tensor], + seq_len: int, + scatter_index=None, + kv_active_mask=None, + kvcache_buffer=None, + windowed_context_encoding_window_idx: int = -1, + **kwargs, + ): + updated_cache = [] + for idx, kv_per_layer in enumerate(new_key_values): + if 
self._is_deltanet_layer(idx): + first_cache, second_cache = self.update_deltanet_state_by_layer_id( + idx=idx, + seq_ids=seq_ids, + state_per_layer=kv_per_layer, + kvcache_buffer=kvcache_buffer, + ) + elif kwargs.get("qwen_chunked_prefill_update", False): + first_cache, second_cache = self.update_qwen_chunked_kv_by_layer_id( + idx=idx, + seq_ids=seq_ids, + position_ids=position_ids, + kv_per_layer=kv_per_layer, + kvcache_buffer=kvcache_buffer, + valid_mask=kwargs.get("qwen_chunked_valid_mask", None), + ) + else: + first_cache, second_cache = self.update_kv_by_layer_id( + idx=idx, + is_for_context_encoding=is_for_context_encoding, + seq_ids=seq_ids, + position_ids=position_ids, + kv_per_layer=kv_per_layer, + seq_len=seq_len, + scatter_index=scatter_index, + kv_active_mask=kv_active_mask, + kvcache_buffer=kvcache_buffer, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + **kwargs, + ) + # (recurrent, conv) for DeltaNet layers; (K, V) for attention layers. + updated_cache.append(first_cache) + updated_cache.append(second_cache) + return updated_cache + + def update_qwen_chunked_kv_by_layer_id( + self, + idx: int, + seq_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_per_layer: Tuple[torch.Tensor, torch.Tensor], + kvcache_buffer=None, + valid_mask=None, + ): + latest_k, latest_v = kv_per_layer + k_cache, v_cache = self._fetch_cache(idx, kvcache_buffer) + latest_k = latest_k.to(k_cache.dtype) + latest_v = latest_v.to(v_cache.dtype) + + if seq_ids is not None: + cache_idx = self.get_cache_update_index_for_seq_ids(seq_ids) + selected_k = torch.index_select(k_cache, dim=0, index=cache_idx) + selected_v = torch.index_select(v_cache, dim=0, index=cache_idx) + else: + cache_idx = None + selected_k = k_cache[: latest_k.shape[0]] + selected_v = v_cache[: latest_v.shape[0]] + + pos = position_ids.long() + k_index = pos[:, None, :, None].expand_as(latest_k) + v_index = pos[:, None, :, None].expand_as(latest_v) + + if valid_mask is not None: + valid = valid_mask.to(torch.bool)[:, None, :, None] + old_k =
torch.gather(selected_k, dim=2, index=k_index) + old_v = torch.gather(selected_v, dim=2, index=v_index) + latest_k = torch.where(valid, latest_k, old_k) + latest_v = torch.where(valid, latest_v, old_v) + + updated_k = torch.scatter(selected_k, dim=2, index=k_index, src=latest_k) + updated_v = torch.scatter(selected_v, dim=2, index=v_index, src=latest_v) + + if cache_idx is not None: + k_row_index = cache_idx.view(-1, 1, 1, 1).expand_as(updated_k) + v_row_index = cache_idx.view(-1, 1, 1, 1).expand_as(updated_v) + k_cache = torch.scatter(k_cache, dim=0, index=k_row_index, src=updated_k) + v_cache = torch.scatter(v_cache, dim=0, index=v_row_index, src=updated_v) + return k_cache, v_cache + + if updated_k.shape[0] == k_cache.shape[0]: + return updated_k + k_cache * 0, updated_v + v_cache * 0 + + pad_rows = k_cache.shape[0] - updated_k.shape[0] + if pad_rows > 0: + updated_k = torch.cat([updated_k, k_cache[updated_k.shape[0] :] * 0], dim=0) + updated_v = torch.cat([updated_v, v_cache[updated_v.shape[0] :] * 0], dim=0) + return updated_k + k_cache * 0, updated_v + v_cache * 0 + + def update_deltanet_state_by_layer_id( + self, + idx: int, + seq_ids: torch.Tensor, + state_per_layer: Tuple[torch.Tensor, torch.Tensor], + kvcache_buffer=None, + ): + latest_recurrent, latest_conv = state_per_layer + recurrent_cache, conv_cache = self._fetch_cache(idx, kvcache_buffer) + latest_recurrent = latest_recurrent.to(recurrent_cache.dtype) + latest_conv = latest_conv.to(conv_cache.dtype) + + if latest_recurrent.shape[0] == recurrent_cache.shape[0] and seq_ids is None: + return ( + latest_recurrent + recurrent_cache * 0, + latest_conv + conv_cache * 0, + ) + + if seq_ids is not None: + cache_idx = self.get_cache_update_index_for_seq_ids(seq_ids) + recurrent_index = cache_idx.view(-1, 1, 1, 1).expand_as(latest_recurrent) + conv_index = cache_idx.view(-1, 1, 1).expand_as(latest_conv) + recurrent_cache = torch.scatter( + input=recurrent_cache, + dim=0, + index=recurrent_index, + 
src=latest_recurrent, + ) + conv_cache = torch.scatter( + input=conv_cache, + dim=0, + index=conv_index, + src=latest_conv, + ) + return recurrent_cache, conv_cache + + pad_size = recurrent_cache.shape[0] - latest_recurrent.shape[0] + if pad_size > 0: + latest_recurrent = torch.cat( + [latest_recurrent, recurrent_cache[latest_recurrent.shape[0] :] * 0], + dim=0, + ) + latest_conv = torch.cat( + [latest_conv, conv_cache[latest_conv.shape[0] :] * 0], + dim=0, + ) + return latest_recurrent + recurrent_cache * 0, latest_conv + conv_cache * 0 + + +# ============================================================ +# Model +# ============================================================ + + +class NeuronQwen35Model(NeuronBaseModel): + def setup_attr_for_model(self, config: Qwen35InferenceConfig): + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + def init_model(self, config: Qwen35InferenceConfig): + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + ) + self.layers = nn.ModuleList( + [ + NeuronQwen35DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = get_rmsnorm_cls()(self.hidden_size, eps=config.rms_norm_eps) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=False if self.on_device_sampling else True, + bias=False, + ) + + # mRoPE embedding for VL + self.mrope_emb = Qwen35MRoPEEmbedding(config) + + def 
init_inference_optimization(self, config: Qwen35InferenceConfig): + super().init_inference_optimization(config) + if getattr(config, "use_hybrid_cache_manager", False): + self.kv_mgr = HybridDeltaNetCacheManager( + config, + num_kv_head=self.num_key_value_heads, + global_rank=self.rank_util, + attention_chunk_size=self.attention_chunk_size, + sliding_window=self.sliding_window, + windowed_context_encoding_size=self.windowed_context_encoding_size, + layer_to_cache_size_mapping=self.layer_to_cache_size_mapping, + ) + + @property + def _deltanet_state_params(self): + """Return DeltaNet state nn.Parameters in alias order.""" + params = [] + for layer in self.layers: + if hasattr(layer, "linear_attn"): + params.append(layer.linear_attn.recurrent_state_buffer) + params.append(layer.linear_attn.conv_state_buffer) + return params + + def encode_vision_to_input(self, inputs_embeds, vision_embeddings, vision_mask): + """Scatter vision embeddings into text input embeddings at image token positions.""" + _, max_positions, embedding_dim = inputs_embeds.shape + h_new = inputs_embeds.clone() + vision_flat = vision_embeddings.view(-1, embedding_dim) + positions_flat = vision_mask.view(-1) + h_new.view(-1, embedding_dim).index_put_( + (positions_flat,), vision_flat, accumulate=False + ) + return h_new + + def get_model_output( + self, + input_ids=None, + seq_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + active_mask=None, + inputs_embeds=None, + prev_hidden=None, + adapter_ids=None, + rotary_position_ids=None, + update_cache=False, + is_for_context_encoding=False, + vision_embeddings=None, + vision_mask=None, + local_attn_mask=None, + windowed_context_encoding_window_idx=-1, + padding_mask=None, + **kwargs, + ): + """Override to collect DeltaNet state tensors from decoder layers.""" + batch_size, seq_length = input_ids.shape[:2] + if self.config.neuron_config.layer_boundary_markers: + input_ids = ModuleMarkerStartWrapper()(input_ids) + + 
past_key_values_length = 0 + if past_key_values is not None: + if hasattr(self.kv_mgr, "get_seq_length"): + past_key_values_length = self.kv_mgr.get_seq_length(past_key_values) + else: + past_key_values_length = past_key_values[0][1].shape[2] + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # CRITICAL: Zero out embeddings for padding tokens so DeltaNet recurrence + # is not polluted. DeltaNet has no attention mask -- it processes all + # sequence positions through a linear recurrence. Padding tokens have + # real embedding vectors which corrupt the recurrence state. + # The mask is [B, S, 1] float with 1.0 for real tokens, 0.0 for padding. + if ( + is_for_context_encoding + and attention_mask is not None + and attention_mask.ndim == 2 + ): + deltanet_padding_mask = attention_mask.unsqueeze(-1).to( + inputs_embeds.dtype + ) + else: + deltanet_padding_mask = ( + (input_ids != self.padding_idx).unsqueeze(-1).to(inputs_embeds.dtype) + ) + if is_for_context_encoding: + inputs_embeds = inputs_embeds * deltanet_padding_mask + + # Vision embedding injection. Text-only calls still pass dummy vision + # tensors to keep the traced input signature stable; those tensors have + # one dummy entry per text token and must not overwrite text embeddings. 
+ if (vision_embeddings is not None) and (vision_mask is not None): + if vision_embeddings.dtype != self.config.neuron_config.torch_dtype: + vision_embeddings = vision_embeddings.to( + self.config.neuron_config.torch_dtype + ) + has_real_vision_inputs = ( + vision_embeddings.ndim == 3 + and vision_mask.ndim == 3 + and vision_embeddings.shape[1] != seq_length + ) + if is_for_context_encoding and has_real_vision_inputs: + inputs_embeds = self.encode_vision_to_input( + inputs_embeds, vision_embeddings, vision_mask + ) + elif is_for_context_encoding and vision_embeddings.numel() > 0: + inputs_embeds = inputs_embeds + vision_embeddings.sum() * 0 + inputs_embeds = ( + inputs_embeds + vision_mask.sum().to(inputs_embeds.dtype) * 0 + ) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + hidden_states = inputs_embeds + + # Get KV cache for TKG and for model-local chunked CTE. 
+ use_qwen_chunked_prefill = ( + is_for_context_encoding + and getattr(self.config, "use_qwen_hybrid_chunked_prefill", False) + ) + cache_size = ( + self.config.neuron_config.seq_len + if use_qwen_chunked_prefill + else self.n_positions + ) + if (not is_for_context_encoding) or use_qwen_chunked_prefill: + if self.kv_mgr is not None: + past_key_values = self.kv_mgr.get_cache( + seq_ids=seq_ids, + seq_len=cache_size, + is_for_context_encoding=is_for_context_encoding, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + **kwargs, + ) + + # Decoder layers + next_decoder_cache = () + deltanet_state_tensors = [] + cos_cache = None + sin_cache = None + + # Convert 2D attention_mask to 4D causal mask for CTE + if ( + attention_mask is not None + and attention_mask.ndim == 2 + and is_for_context_encoding + ): + causal = torch.ones( + (seq_length, seq_length), + dtype=torch.bool, + device=attention_mask.device, + ).tril() + padding_4d = attention_mask[:, None, None, :].to(torch.bool) + attention_mask = (causal[None, None, :, :] & padding_4d).to( + attention_mask.dtype + ) + + # Pre-compute mRoPE cos/sin + if rotary_position_ids is not None and rotary_position_ids.ndim == 3: + cos_cache, sin_cache = self.mrope_emb(inputs_embeds, rotary_position_ids) + + for idx, decoder_layer in enumerate(self.layers): + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + layer_outputs = decoder_layer( + hidden_states, + seq_ids=seq_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + active_mask=active_mask, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + rotary_position_ids=rotary_position_ids, + kv_mgr=self.kv_mgr, + get_kv_per_layer=False, + update_kv_per_layer=False, + idx=idx, + is_for_context_encoding=is_for_context_encoding, + seq_len=cache_size, + residual=None, + local_mask=local_attn_mask, + 
windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + padding_mask=padding_mask, + deltanet_padding_mask=deltanet_padding_mask, + qwen_chunked_prefill_update=use_qwen_chunked_prefill, + qwen_chunked_valid_mask=deltanet_padding_mask.squeeze(-1) + if use_qwen_chunked_prefill + else None, + **kwargs, + ) + + hidden_states = layer_outputs[0] + kv = layer_outputs[1] + next_decoder_cache += (kv,) + cos_cache, sin_cache = layer_outputs[2:4] + + # Collect DeltaNet state tensors + deltanet_states = layer_outputs[5] if len(layer_outputs) > 5 else None + if deltanet_states is not None: + deltanet_state_tensors.append(deltanet_states[0]) + deltanet_state_tensors.append(deltanet_states[1]) + + # Update KV cache + if update_cache: + next_decoder_cache = self.kv_mgr.update_cache( + is_for_context_encoding=is_for_context_encoding, + seq_ids=seq_ids, + position_ids=position_ids, + new_key_values=next_decoder_cache, + seq_len=cache_size, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + qwen_chunked_prefill_update=use_qwen_chunked_prefill, + qwen_chunked_valid_mask=deltanet_padding_mask.squeeze(-1) + if use_qwen_chunked_prefill + else None, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + self._deltanet_updated_states = deltanet_state_tensors + + return (hidden_states, next_decoder_cache) + + def forward( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden=None, + adapter_ids=None, + accepted_indices=None, + current_length=None, + medusa_mask=None, + scatter_index=None, + slot_mapping=None, + active_block_table=None, + num_queries=None, + computed_context_lens=None, + tile_q_indices=None, + tile_block_tables=None, + tile_masks=None, + inputs_embeds=None, + kv_cache=None, + active_mask=None, + rotary_position_id=None, + vision_embeddings=None, + vision_mask=None, + ): + """Override base forward to append DeltaNet state tensors to output.""" + prev_hidden = 
self.set_none_if_empty(prev_hidden) + adapter_ids = self.set_none_if_empty(adapter_ids) + accepted_indices = self.set_none_if_empty(accepted_indices) + current_length = self.set_none_if_empty(current_length) + medusa_mask = self.set_none_if_empty(medusa_mask) + scatter_index = self.set_none_if_empty(scatter_index) + slot_mapping = self.set_none_if_empty(slot_mapping) + active_block_table = self.set_none_if_empty(active_block_table) + num_queries = self.set_none_if_empty(num_queries) + computed_context_lens = self.set_none_if_empty(computed_context_lens) + tile_q_indices = self.set_none_if_empty(tile_q_indices) + tile_block_tables = self.set_none_if_empty(tile_block_tables) + tile_masks = self.set_none_if_empty(tile_masks) + inputs_embeds = self.set_none_if_empty(inputs_embeds) + kv_cache = self.set_none_if_empty(kv_cache) + active_mask = self.set_none_if_empty(active_mask) + rotary_position_id = self.set_none_if_empty(rotary_position_id) + vision_embeddings = self.set_none_if_empty(vision_embeddings) + vision_mask = self.set_none_if_empty(vision_mask) + + is_for_context_encoding = position_ids.shape[-1] != 1 and not ( + hasattr(self.neuron_config, "speculation_length") + and position_ids.shape[-1] == self.neuron_config.speculation_length + ) + + seq_ids = seq_ids.to(torch.int32) + attn_mask = attention_mask + + hidden_states, updated_kv_cache = self.get_model_output( + input_ids=input_ids, + seq_ids=seq_ids, + attention_mask=attn_mask, + position_ids=position_ids, + active_mask=active_mask, + inputs_embeds=inputs_embeds, + adapter_ids=adapter_ids, + rotary_position_ids=rotary_position_id, + update_cache=True, + is_for_context_encoding=is_for_context_encoding, + padding_mask=None, + active_block_table=active_block_table, + scatter_index=slot_mapping + if getattr(self, "is_block_kv_layout", False) + else scatter_index, + vision_embeddings=vision_embeddings, + vision_mask=vision_mask, + ) + + batch_size = input_ids.shape[0] + if not getattr(self, "sliced_hidden", 
False): + if is_for_context_encoding: + if getattr(self.config, "use_qwen_hybrid_chunked_prefill", False): + if attention_mask is not None and attention_mask.ndim == 2: + index = ( + attention_mask.to(torch.long).sum(dim=1, keepdim=True) + - 1 + ).clamp(min=0) + else: + index = ( + (input_ids != self.padding_idx) + .sum(dim=1, keepdim=True) + .long() + - 1 + ).clamp(min=0) + else: + index = torch.max(position_ids, dim=1, keepdim=True).indices + index = index.unsqueeze(1).expand(batch_size, 1, self.hidden_size) + hidden_states = torch.gather(hidden_states, dim=1, index=index) + + logits = self.lm_head(hidden_states) + logits = logits.float() + + if hasattr(self.lm_head, "pad_size"): + if self.lm_head.gather_output: + rank_id = torch.tensor(0, device=logits.device, dtype=torch.int32) + world_size = 1 + else: + rank_id = self.rank_util.get_rank() + world_size = torch.distributed.get_world_size( + group=self.lm_head.tensor_parallel_group + ) + from neuronx_distributed_inference.models.model_base import ( + mask_padded_logits, + ) + + logits = mask_padded_logits( + logits, rank_id, world_size, pad_size=self.lm_head.pad_size + ) + + if self.on_device_sampling: + res = self._sample_on_device( + logits, sampling_params, False, is_for_context_encoding + ) + else: + res = logits + + outputs = [res] + if self.neuron_config.output_logits: + outputs += [logits] + outputs += updated_kv_cache + + # Append DeltaNet state tensors (for input_output_aliases) + if ( + not getattr(self.config, "use_hybrid_cache_manager", False) + and hasattr(self, "_deltanet_updated_states") + ): + outputs += self._deltanet_updated_states + + return outputs + + +# ============================================================ +# State Dict Converter (Dense -- no MoE weight handling) +# ============================================================ + + +def convert_qwen35_hf_to_neuron_state_dict(neuron_state_dict, config):
+ """Convert HF Qwen3.5/3.6-27B weights to NxDI format. + + Weight mappings per layer type: + + DeltaNet layers (linear_attention): + HF: layers.X.linear_attn.{in_proj_qkv, in_proj_z, in_proj_a, in_proj_b, + conv1d, A_log, dt_bias, norm, out_proj} + NxDI: projections keep names; conv1d/A_log/dt_bias are remapped into + ColumnParallelLinear parameter containers so NxD can shard them. + + Full attention layers: + HF: layers.X.self_attn.q_proj.weight: (12288, 5120) -- doubled for gate + NxDI: layers.X.self_attn.Wqkv.weight (fused Q+K+V, gate separated) + layers.X.self_attn.output_gate_proj.weight (gate part) + HF: layers.X.self_attn.{k_proj, v_proj, o_proj, q_norm, k_norm} + NxDI: layers.X.self_attn.{..., q_layernorm, k_layernorm} + + Dense MLP (all layers): + HF: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight + NxDI: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight (same names) + """ + # Add rank_util + neuron_state_dict["rank_util.rank"] = torch.arange( + 0, + config.neuron_config.tp_degree, + dtype=torch.int32, + ) + + def _reorder_deltanet_qkv_for_tp(qkv_weight: torch.Tensor) -> torch.Tensor: + """Pack [Q_all | K_all | V_all] into per-rank Q/K/V blocks. + + ColumnParallelLinear slices the first dimension contiguously. DeltaNet + needs each rank to receive its local query, key, and value heads + together, so the full HF tensor is repacked as: + [rank0 Q | rank0 K | rank0 V | rank1 Q | rank1 K | rank1 V | ...]. 
+ """ + tp_degree = config.neuron_config.tp_degree + num_k_heads = config.linear_num_key_heads + num_v_heads = config.linear_num_value_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim + if num_k_heads % tp_degree != 0: + raise ValueError( + f"linear_num_key_heads={num_k_heads} must be divisible by tp_degree={tp_degree}" + ) + if num_v_heads % tp_degree != 0: + raise ValueError( + f"linear_num_value_heads={num_v_heads} must be divisible by tp_degree={tp_degree}" + ) + + key_dim = num_k_heads * head_k_dim + value_dim = num_v_heads * head_v_dim + q_weight = qkv_weight[:key_dim].reshape(num_k_heads, head_k_dim, -1) + k_weight = qkv_weight[key_dim : 2 * key_dim].reshape(num_k_heads, head_k_dim, -1) + v_weight = qkv_weight[2 * key_dim : 2 * key_dim + value_dim].reshape( + num_v_heads, head_v_dim, -1 + ) + local_k_heads = num_k_heads // tp_degree + local_v_heads = num_v_heads // tp_degree + blocks = [] + for rank in range(tp_degree): + blocks.append( + q_weight[ + rank * local_k_heads : (rank + 1) * local_k_heads + ].reshape(-1, qkv_weight.shape[1]) + ) + blocks.append( + k_weight[ + rank * local_k_heads : (rank + 1) * local_k_heads + ].reshape(-1, qkv_weight.shape[1]) + ) + blocks.append( + v_weight[ + rank * local_v_heads : (rank + 1) * local_v_heads + ].reshape(-1, qkv_weight.shape[1]) + ) + return torch.cat(blocks, dim=0).contiguous() + + def _reorder_deltanet_qkv_channels_for_tp(channel_tensor: torch.Tensor) -> torch.Tensor: + """Repack a first-dimension Q/K/V channel tensor into TP rank blocks.""" + tp_degree = config.neuron_config.tp_degree + num_k_heads = config.linear_num_key_heads + num_v_heads = config.linear_num_value_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim + key_dim = num_k_heads * head_k_dim + value_dim = num_v_heads * head_v_dim + q_tensor = channel_tensor[:key_dim] + k_tensor = channel_tensor[key_dim : 2 * key_dim] + v_tensor = channel_tensor[2 * key_dim : 2 * 
key_dim + value_dim] + local_key_dim = key_dim // tp_degree + local_value_dim = value_dim // tp_degree + blocks = [] + for rank in range(tp_degree): + blocks.append(q_tensor[rank * local_key_dim : (rank + 1) * local_key_dim]) + blocks.append(k_tensor[rank * local_key_dim : (rank + 1) * local_key_dim]) + blocks.append( + v_tensor[rank * local_value_dim : (rank + 1) * local_value_dim] + ) + return torch.cat(blocks, dim=0).contiguous() + + # CRITICAL: Convert (1+weight) RMSNorm weights to standard RMSNorm weights. + # Qwen3.5 uses RMSNorm with `output = norm(x) * (1 + weight)` where weight + # is initialized to zeros. Standard NxDI RMSNorm uses `output = norm(x) * weight` + # where weight is initialized to ones. To convert: new_weight = old_weight + 1.0 + norm_keys_to_convert = [] + for l in range(config.num_hidden_layers): + norm_keys_to_convert.append(f"layers.{l}.input_layernorm.weight") + norm_keys_to_convert.append(f"layers.{l}.post_attention_layernorm.weight") + if config.layer_types[l] == "full_attention": + norm_keys_to_convert.append(f"layers.{l}.self_attn.q_norm.weight") + norm_keys_to_convert.append(f"layers.{l}.self_attn.k_norm.weight") + norm_keys_to_convert.append("norm.weight") + + for nk in norm_keys_to_convert: + if nk in neuron_state_dict: + old_val = neuron_state_dict[nk] + neuron_state_dict[nk] = old_val.float() + 1.0 + if "layers.0." in nk or nk == "norm.weight": + logger.debug( + f"[NORM FIX] {nk}: mean {old_val.float().mean():.4f} -> {neuron_state_dict[nk].mean():.4f}" + ) + else: + if "layers.0." 
in nk or nk == "norm.weight": + logger.warning(f"[NORM FIX] key not found: {nk}") + + for l in range(config.num_hidden_layers): + layer_type = config.layer_types[l] + + # === DeltaNet layers === + if layer_type == "linear_attention": + qkv_key = f"layers.{l}.linear_attn.in_proj_qkv.weight" + if qkv_key in neuron_state_dict and config.neuron_config.tp_degree > 1: + neuron_state_dict[qkv_key] = _reorder_deltanet_qkv_for_tp( + neuron_state_dict[qkv_key] + ) + + conv_key = f"layers.{l}.linear_attn.conv1d.weight" + conv_weight_key = f"layers.{l}.linear_attn.conv1d_weight.weight" + if conv_key in neuron_state_dict: + conv_weight = neuron_state_dict.pop(conv_key) + if config.neuron_config.tp_degree > 1: + conv_weight = _reorder_deltanet_qkv_channels_for_tp(conv_weight) + neuron_state_dict[conv_weight_key] = conv_weight.squeeze(1).contiguous() + + for vector_name in ("A_log", "dt_bias"): + vector_key = f"layers.{l}.linear_attn.{vector_name}" + vector_weight_key = f"layers.{l}.linear_attn.{vector_name}_weight.weight" + if vector_key in neuron_state_dict: + neuron_state_dict[vector_weight_key] = ( + neuron_state_dict.pop(vector_key).reshape(-1, 1).contiguous() + ) + + # === Attention layers === + if layer_type == "full_attention": + neuron_state_dict[f"layers.{l}.self_attn.rank_util.rank"] = torch.arange( + 0, + config.neuron_config.tp_degree, + dtype=torch.int32, + ) + + # QK norms: q_norm -> q_layernorm, k_norm -> k_layernorm + q_norm_key = f"layers.{l}.self_attn.q_norm.weight" + k_norm_key = f"layers.{l}.self_attn.k_norm.weight" + if q_norm_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.self_attn.q_layernorm.weight"] = ( + neuron_state_dict.pop(q_norm_key).detach().clone() + ) + if k_norm_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.self_attn.k_layernorm.weight"] = ( + neuron_state_dict.pop(k_norm_key).detach().clone() + ) + + # q_proj is doubled: (12288, 5120) = (num_heads * head_dim * 2, hidden) + # INTERLEAVED: [head0_query(256) | 
head0_gate(256) | head1_query(256) | ...] + q_proj_key = f"layers.{l}.self_attn.q_proj.weight" + if q_proj_key in neuron_state_dict: + q_proj_w = neuron_state_dict.pop(q_proj_key) + num_heads = config.num_attention_heads # 24 + head_dim = config.head_dim # 256 + q_proj_w = q_proj_w.reshape(num_heads, head_dim * 2, config.hidden_size) + query_w = q_proj_w[:, :head_dim, :] # (24, 256, 5120) + gate_w = q_proj_w[:, head_dim:, :] # (24, 256, 5120) + query_w = query_w.reshape( + num_heads * head_dim, config.hidden_size + ) # (6144, 5120) + gate_w = gate_w.reshape( + num_heads * head_dim, config.hidden_size + ) # (6144, 5120) + + neuron_state_dict[q_proj_key] = query_w + neuron_state_dict[f"layers.{l}.self_attn.output_gate_proj.weight"] = ( + gate_w + ) + + # Fuse QKV + if config.neuron_config.fused_qkv: + q_key = f"layers.{l}.self_attn.q_proj.weight" + k_key = f"layers.{l}.self_attn.k_proj.weight" + v_key = f"layers.{l}.self_attn.v_proj.weight" + if q_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = torch.cat( + [ + neuron_state_dict[q_key], + neuron_state_dict[k_key], + neuron_state_dict[v_key], + ] + ) + del neuron_state_dict[q_key] + del neuron_state_dict[k_key] + del neuron_state_dict[v_key] + + # Dense MLP: no weight conversion needed -- HF and NxDI use same names + # HF: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight + # NxDI: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight + + gc.collect() + + return neuron_state_dict + + +# ============================================================ +# Custom ModelWrapper and DecoderModelInstance for DeltaNet state aliasing +# ============================================================ + + +class Qwen35DecoderModelInstance(DecoderModelInstance): + """Custom DecoderModelInstance that adds DeltaNet state buffers to input_output_aliases.""" + + def get(self, bucket_rank, **kwargs): + """Override to add DeltaNet state aliases after KV cache aliases.""" + module, input_output_aliases = 
super().get(bucket_rank, **kwargs) + + num_output_from_trace = 1 if not self.neuron_config.output_logits else 2 + + if module.kv_mgr is not None: + num_kv = len(module.kv_mgr.past_key_values) + else: + num_kv = 0 + + state_start_idx = num_output_from_trace + num_kv + + if ( + not getattr(module.config, "use_hybrid_cache_manager", False) + and hasattr(module, "_deltanet_state_params") + ): + for i, param in enumerate(module._deltanet_state_params): + input_output_aliases[param] = state_start_idx + i + + return module, input_output_aliases + + +class Qwen35ModelWrapper(ModelWrapper): + """Custom ModelWrapper for VL support with mRoPE and vision inputs.""" + + def get_model_instance(self): + return Qwen35DecoderModelInstance( + model_cls=self.model_cls, + config=self.config, + **self.model_init_kwargs, + ) + + def input_generator(self): + """Generate inputs including mrope_position_ids, vision_embeddings, and vision_mask.""" + base_inputs = super().input_generator() + extended_inputs = [] + + for bucket_inputs in base_inputs: + input_ids = bucket_inputs[0] + batch_size = input_ids.shape[0] + n_active_tokens = input_ids.shape[1] + + is_cte = n_active_tokens > 1 + + if is_cte: + mrope_position_ids = ( + torch.arange(0, n_active_tokens, dtype=torch.int32) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, batch_size, -1) + .contiguous() + ) + + vision_embeddings = torch.zeros( + (batch_size, n_active_tokens, self.config.hidden_size), + dtype=self.config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + (batch_size, n_active_tokens, 1), + fill_value=n_active_tokens - 1, + dtype=torch.int32, + ) + else: + mrope_position_ids = torch.zeros((0,), dtype=torch.int32) + vision_embeddings = torch.zeros( + (0,), dtype=self.config.neuron_config.torch_dtype + ) + vision_mask = torch.zeros((0,), dtype=torch.int32) + + padded = list(bucket_inputs) + while len(padded) < 21: + padded.append(torch.zeros((0,), dtype=torch.int32)) + padded.append(mrope_position_ids) # position 21 + 
padded.append(vision_embeddings) # position 22 + padded.append(vision_mask) # position 23 + + extended_inputs.append(tuple(padded)) + + return extended_inputs + + def pad_inputs(self, *args, pad_type="first_fit"): + """Override to pad mrope_position_ids and vision inputs to bucket size.""" + orig_mrope = args[21] if len(args) >= 22 else None + orig_vis_emb = args[22] if len(args) >= 23 else None + orig_vis_mask = args[23] if len(args) >= 24 else None + + padded_args = super().pad_inputs(*args, pad_type=pad_type) + + if len(padded_args) >= 24 and orig_mrope is not None: + padded_seq_len = padded_args[0].shape[1] + batch_size = padded_args[0].shape[0] + is_cte = padded_seq_len > 1 + + if is_cte: + current_mrope = orig_mrope + current_vis_emb = orig_vis_emb + current_vis_mask = orig_vis_mask + + if ( + current_mrope.ndim == 3 + and current_mrope.shape[-1] != padded_seq_len + ): + orig_len = current_mrope.shape[-1] + pad_size = padded_seq_len - orig_len + last_pos = current_mrope[:, :, -1:] + pad_offsets = torch.arange( + 1, pad_size + 1, dtype=current_mrope.dtype + ) + pad_offsets = ( + pad_offsets.unsqueeze(0).unsqueeze(0).expand(3, batch_size, -1) + ) + mrope_pad = last_pos + pad_offsets + mrope_position_ids = torch.cat([current_mrope, mrope_pad], dim=-1) + elif current_mrope.ndim == 3: + mrope_position_ids = current_mrope + else: + mrope_position_ids = ( + torch.arange(0, padded_seq_len, dtype=torch.int32) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, batch_size, -1) + .contiguous() + ) + + if ( + current_vis_emb is not None + and current_vis_emb.ndim == 3 + and current_vis_emb.shape[1] < padded_seq_len + ): + pad_emb = torch.zeros( + ( + batch_size, + padded_seq_len - current_vis_emb.shape[1], + current_vis_emb.shape[2], + ), + dtype=current_vis_emb.dtype, + ) + vision_embeddings = torch.cat([current_vis_emb, pad_emb], dim=1) + elif current_vis_emb is not None and current_vis_emb.ndim == 3: + vision_embeddings = current_vis_emb[:, :padded_seq_len] + else: + 
vision_embeddings = torch.zeros( + (batch_size, padded_seq_len, self.config.hidden_size), + dtype=self.config.neuron_config.torch_dtype, + ) + + if ( + current_vis_mask is not None + and current_vis_mask.ndim == 3 + and current_vis_mask.shape[1] < padded_seq_len + ): + pad_mask = torch.full( + (batch_size, padded_seq_len - current_vis_mask.shape[1], 1), + fill_value=padded_seq_len - 1, + dtype=torch.int32, + ) + vision_mask = torch.cat([current_vis_mask, pad_mask], dim=1) + elif current_vis_mask is not None and current_vis_mask.ndim == 3: + vision_mask = current_vis_mask[:, :padded_seq_len] + else: + vision_mask = torch.full( + (batch_size, padded_seq_len, 1), + fill_value=padded_seq_len - 1, + dtype=torch.int32, + ) + + padded_args = ( + *padded_args[:21], + mrope_position_ids, + vision_embeddings, + vision_mask, + ) + + padded_args = list(padded_args) + padded_args[23] = padded_args[23].clamp(max=padded_seq_len - 1) + padded_args = tuple(padded_args) + + return padded_args + + +# ============================================================ +# Top-Level Model +# ============================================================ + + +class NeuronQwen35ForCausalLM(NeuronBaseForCausalLM): + _model_cls = NeuronQwen35Model + + def get_model_wrapper_cls(self): + """Return custom ModelWrapper with DeltaNet state aliasing.""" + return Qwen35ModelWrapper + + @staticmethod + def load_hf_model(model_path, **kwargs): + """Load HF model weights. + + The model is a VL model (Qwen3_5ForConditionalGeneration) but we + only need the text backbone. 
+ """ + from transformers import AutoModelForCausalLM + + kwargs.setdefault("trust_remote_code", True) + return AutoModelForCausalLM.from_pretrained(model_path, **kwargs) + + @classmethod + def get_config_cls(cls): + return Qwen35InferenceConfig + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict, config): + """Strip VL wrapper prefix and convert to NxDI format.""" + new_sd = {} + for k, v in state_dict.items(): + if k.startswith("language_model."): + new_k = k.replace("language_model.", "", 1) + new_sd[new_k] = v + elif k.startswith("model.language_model."): + new_k = k.replace("model.language_model.", "", 1) + new_sd[new_k] = v + elif k.startswith("model.visual") or k.startswith("visual"): + continue # Skip vision encoder + elif k.startswith("model."): + new_sd[k.replace("model.", "", 1)] = v + elif k.startswith("mtp."): + continue # Skip MTP + elif k.startswith("lm_head."): + new_sd[k] = v + else: + new_sd[k] = v + + return convert_qwen35_hf_to_neuron_state_dict(new_sd, config) + + def enable_context_encoding(self): + self.compile_tag = CONTEXT_ENCODING_MODEL_TAG + super().enable_context_encoding() + + def enable_token_generation(self): + self.compile_tag = TOKEN_GENERATION_MODEL_TAG + super().enable_token_generation() + + def _copy_past_key_values(self, outputs): + """Override to also copy DeltaNet state buffers on CPU.""" + super()._copy_past_key_values(outputs) + if getattr(self.config, "use_hybrid_cache_manager", False): + return + + num_output_from_trace = 1 + if ( + self.neuron_config.output_logits + and self.neuron_config.on_device_sampling_config + ): + num_output_from_trace = 2 + + if ( + hasattr(self, "token_generation_model") + and self.token_generation_model is not None + ): + tkg_model = self.token_generation_model.model + cte_model = self.context_encoding_model.model + else: + return + + if tkg_model.kv_mgr is not None: + num_kv = len(tkg_model.kv_mgr.past_key_values) + else: + num_kv = 0 + + state_start = num_output_from_trace + 
num_kv + + tkg_params = getattr(tkg_model, "_deltanet_state_params", []) + cte_params = getattr(cte_model, "_deltanet_state_params", []) + + if len(tkg_params) > 0 and state_start + len(tkg_params) <= len(outputs): + for i, (tkg_param, cte_param) in enumerate(zip(tkg_params, cte_params)): + new_state = outputs[state_start + i] + tkg_param.data = new_state + cte_param.data = new_state + + def get_required_kwargs(self): + """Return extra kwargs for HF generation loop.""" + return ["llava_args"] + + def _get_model_outputs( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden, + adapter_ids, + medusa_args, + llava_args, + slot_mapping=None, + block_table=None, + full_context_lens=None, + computed_context_lens=None, + tf_args=None, + ): + """Override to pass all 24 positional args explicitly.""" + is_prefill = self._is_prefill(position_ids) or ( + getattr(self.config, "use_qwen_hybrid_chunked_prefill", False) + and input_ids.shape[-1] > 1 + ) + + seq_len = input_ids.shape[1] + batch_size = input_ids.shape[0] + + if llava_args and len(llava_args) >= 2: + vision_embeddings = llava_args[0] + vision_mask = llava_args[1] + if len(llava_args) >= 3: + mrope_position_ids = llava_args[2] + else: + mrope_position_ids = None + elif is_prefill: + vision_embeddings = torch.zeros( + (batch_size, seq_len, self.config.hidden_size), + dtype=self.config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + (batch_size, seq_len, 1), + fill_value=seq_len - 1, + dtype=torch.int32, + ) + mrope_position_ids = None + else: + vision_embeddings = torch.zeros((0,), dtype=torch.float32) + vision_mask = torch.zeros((0,), dtype=torch.int32) + mrope_position_ids = None + + if is_prefill: + if mrope_position_ids is None: + mrope_position_ids = ( + torch.arange(0, seq_len, dtype=torch.int32) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, batch_size, -1) + .contiguous() + ) + else: + mrope_position_ids = torch.zeros((0,), dtype=torch.int32) + + empties 
= [torch.empty(0) for _ in range(14)] + + if is_prefill: + ctx_bs = self.context_encoding_model.neuron_config.batch_size + output_logits = [] + + for cb in range(0, batch_size, ctx_bs): + cb_end = min(cb + ctx_bs, batch_size) + actual_chunk = cb_end - cb + + chunk_input_ids = input_ids[cb:cb_end] + chunk_attn_mask = attention_mask[cb:cb_end] + chunk_pos_ids = position_ids[cb:cb_end] + chunk_seq_ids = seq_ids[cb:cb_end] + chunk_sampling = sampling_params[cb:cb_end] + chunk_prev_hidden = ( + prev_hidden[cb:cb_end] + if prev_hidden is not None + and hasattr(prev_hidden, "ndim") + and prev_hidden.ndim > 0 + and prev_hidden.shape[0] > 0 + else prev_hidden + ) + chunk_adapter_ids = ( + adapter_ids[cb:cb_end] + if adapter_ids is not None + and hasattr(adapter_ids, "ndim") + and adapter_ids.ndim > 0 + and adapter_ids.shape[0] > 0 + else adapter_ids + ) + + if mrope_position_ids.ndim == 3: + chunk_mrope = mrope_position_ids[:, cb:cb_end, :] + else: + chunk_mrope = mrope_position_ids + + if vision_embeddings.ndim == 3: + chunk_vis_emb = vision_embeddings[cb:cb_end] + chunk_vis_mask = vision_mask[cb:cb_end] + else: + chunk_vis_emb = vision_embeddings + chunk_vis_mask = vision_mask + + if actual_chunk < ctx_bs: + pad_n = ctx_bs - actual_chunk + chunk_input_ids = torch.cat( + [chunk_input_ids, chunk_input_ids[:1].expand(pad_n, -1)], dim=0 + ) + chunk_attn_mask = torch.cat( + [chunk_attn_mask, chunk_attn_mask[:1].expand(pad_n, -1)], dim=0 + ) + chunk_pos_ids = torch.cat( + [chunk_pos_ids, chunk_pos_ids[:1].expand(pad_n, -1)], dim=0 + ) + pad_seq = torch.arange( + batch_size, batch_size + pad_n, dtype=chunk_seq_ids.dtype + ) + chunk_seq_ids = torch.cat([chunk_seq_ids, pad_seq], dim=0) + chunk_sampling = torch.cat( + [chunk_sampling, chunk_sampling[:1].expand(pad_n, -1)], dim=0 + ) + if ( + chunk_prev_hidden is not None + and hasattr(chunk_prev_hidden, "ndim") + and chunk_prev_hidden.ndim > 0 + and chunk_prev_hidden.shape[0] > 0 + ): + chunk_prev_hidden = torch.cat( + [ + 
chunk_prev_hidden, + chunk_prev_hidden[:1].expand(pad_n, -1), + ], + dim=0, + ) + if ( + chunk_adapter_ids is not None + and hasattr(chunk_adapter_ids, "ndim") + and chunk_adapter_ids.ndim > 0 + and chunk_adapter_ids.shape[0] > 0 + ): + chunk_adapter_ids = torch.cat( + [ + chunk_adapter_ids, + chunk_adapter_ids[:1].expand(pad_n, -1), + ], + dim=0, + ) + if chunk_mrope.ndim == 3: + chunk_mrope = torch.cat( + [chunk_mrope, chunk_mrope[:, :1, :].expand(-1, pad_n, -1)], + dim=1, + ) + if chunk_vis_emb.ndim == 3: + chunk_vis_emb = torch.cat( + [ + chunk_vis_emb, + torch.zeros( + (pad_n,) + chunk_vis_emb.shape[1:], + dtype=chunk_vis_emb.dtype, + ), + ], + dim=0, + ) + chunk_vis_mask = torch.cat( + [ + chunk_vis_mask, + torch.full( + (pad_n,) + chunk_vis_mask.shape[1:], + fill_value=seq_len - 1, + dtype=chunk_vis_mask.dtype, + ), + ], + dim=0, + ) + + chunk_out = self.context_encoding_model( + chunk_input_ids, + chunk_attn_mask, + chunk_pos_ids, + chunk_seq_ids, + chunk_sampling, + chunk_prev_hidden, + chunk_adapter_ids, + *empties, + chunk_mrope, + chunk_vis_emb, + chunk_vis_mask, + ) + if actual_chunk < ctx_bs: + chunk_out = chunk_out[:actual_chunk] + output_logits.append(chunk_out) + + outputs = ( + torch.cat(output_logits, dim=0) + if len(output_logits) > 1 + else output_logits[0] + ) + self.kv_cache_populated = True + is_run_on_neuron = self.context_encoding_model.is_neuron() + else: + outputs = self.token_generation_model( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden, + adapter_ids, + *empties, + mrope_position_ids, + vision_embeddings, + vision_mask, + ) + is_run_on_neuron = self.token_generation_model.is_neuron() + + return outputs, is_run_on_neuron + + def get_compiler_args(self): + if self.compile_tag == CONTEXT_ENCODING_MODEL_TAG: + optimization_level = "-O1" + else: + optimization_level = "-O1" + + compiler_args = ( + "--enable-saturate-infinity " + "--enable-mixed-precision-accumulation " + f"--model-type 
transformer {optimization_level} " + "--auto-cast=none " + ) + return compiler_args diff --git a/contrib/models/Qwen3.6-27B/src/modeling_qwen35_vision.py b/contrib/models/Qwen3.6-27B/src/modeling_qwen35_vision.py new file mode 100644 index 00000000..761d7e95 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/modeling_qwen35_vision.py @@ -0,0 +1,819 @@ +""" +Qwen3.5-27B / Qwen3.6-27B (Dense) Vision Encoder for NeuronX Distributed Inference. + +Ports the Qwen3.5/3.6 ViT encoder to run on Neuron. The vision encoder +architecture is identical across Qwen3.5-27B and Qwen3.6-27B (same patch +embed, same rotary, same merger) -- only out_hidden_size changes vs the MoE +variant (5120 vs 2048, read from config). + +The vision encoder runs as a separate compiled model from the text decoder, +compiled and loaded via NeuronBaseForImageToText. +""" + +import logging +import math +import os +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# CRITICAL: Use finite negative value instead of -inf for Neuron attention masks. +# The Neuron compiler's bfloat16 handling of -inf produces NaN that bleeds from +# padding positions into ALL positions through the transformer layers. +# -65504.0 is large enough for softmax masking but avoids NaN overflow. 
+_MASK_NEG_INF = -65504.0 + +logger = logging.getLogger(__name__) + +# -- NxDI imports (available on Neuron instances) -- +try: + from neuronx_distributed_inference.models.application_base import ( + NeuronApplicationBase, + ) + from neuronx_distributed_inference.models.model_wrapper import ModelWrapper + from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, + ) + from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding + from neuronx_distributed.parallel_layers import layers as nxd_layers +except ImportError: + logger.warning( + "NxDI imports unavailable -- vision module can only be used on Neuron instances" + ) + +# -- HuggingFace imports for patch embed (runs on CPU) -- +try: + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( + Qwen3_5MoeVisionPatchEmbed, + Qwen3_5MoeVisionPatchMerger, + Qwen3_5MoeVisionRotaryEmbedding, + ) +except ImportError: + try: + # transformers 4.57+ uses Qwen3VL* class names + from transformers.models.qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLVisionPatchEmbed as Qwen3_5MoeVisionPatchEmbed, + Qwen3VLVisionPatchMerger as Qwen3_5MoeVisionPatchMerger, + Qwen3VLVisionRotaryEmbedding as Qwen3_5MoeVisionRotaryEmbedding, + ) + except ImportError: + try: + # Older transformers uses Qwen2VL* class names + from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + Qwen2VLVisionPatchEmbed as Qwen3_5MoeVisionPatchEmbed, + Qwen2VLVisionPatchMerger as Qwen3_5MoeVisionPatchMerger, + Qwen2VLVisionRotaryEmbedding as Qwen3_5MoeVisionRotaryEmbedding, + ) + except ImportError: + Qwen3_5MoeVisionPatchEmbed = None + Qwen3_5MoeVisionPatchMerger = None + Qwen3_5MoeVisionRotaryEmbedding = None + + +def apply_rotary_pos_emb_vision(q, k, cos, sin): + """Apply rotary position embeddings to vision Q and K tensors. 
+ + Uses rotate_half style (matching HF reference): + q_embed = (q * cos) + (rotate_half(q) * sin) + + Args: + q: (seq_len, num_heads, head_dim) + k: (seq_len, num_heads, head_dim) + cos: (seq_len, head_dim) + sin: (seq_len, head_dim) + """ + cos = cos.unsqueeze(-2) # (seq_len, 1, head_dim) + sin = sin.unsqueeze(-2) + + def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +class NeuronQwen35VisionAttention(nn.Module): + """Vision attention for Qwen3.5 MoE. + + Uses a fused QKV linear layer with bias; the output projection also has bias. + Non-causal attention with block-diagonal mask for variable-length images. + """ + + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.hidden_size // self.num_heads + self.scaling = self.head_dim**-0.5 + + # Fused QKV: (hidden_size -> 3 * hidden_size) with bias + self.qkv = nxd_layers.ColumnParallelLinear( + self.hidden_size, + 3 * self.hidden_size, + bias=True, + gather_output=True, + ) + self.proj = nxd_layers.RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + input_is_parallel=False, + ) + + def forward(self, hidden_states, attention_mask=None, position_embeddings=None): + """ + Args: + hidden_states: (seq_len, hidden_size) + attention_mask: (1, 1, seq_len, seq_len) block-diagonal mask + position_embeddings: (cos, sin) tuple + """ + seq_len = hidden_states.shape[0] + + # QKV projection + qkv = self.qkv(hidden_states) # (seq_len, 3 * hidden_size) + qkv = qkv.reshape(seq_len, 3, self.num_heads, self.head_dim) + qkv = qkv.permute(1, 0, 2, 3) # (3, seq_len, num_heads, head_dim) + q, k, v = qkv.unbind(0) # each (seq_len, num_heads, head_dim) + + # Apply rotary embeddings + if position_embeddings is not None: + 
cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + # Reshape for batched attention: (1, num_heads, seq_len, head_dim) + q = q.transpose(0, 1).unsqueeze(0) + k = k.transpose(0, 1).unsqueeze(0) + v = v.transpose(0, 1).unsqueeze(0) + + # Scaled dot-product attention + attn_weights = torch.matmul(q, k.transpose(-1, -2)) * self.scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + + # Reshape back: (seq_len, hidden_size) + attn_output = attn_output.squeeze(0).transpose(0, 1).reshape(seq_len, -1) + + # Output projection + attn_output = self.proj(attn_output) + return attn_output + + +class NeuronQwen35VisionMLP(nn.Module): + """Vision MLP with GELU activation.""" + + def __init__(self, config): + super().__init__() + self.linear_fc1 = nxd_layers.ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + gather_output=True, + ) + self.linear_fc2 = nxd_layers.RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + input_is_parallel=False, + ) + self.act_fn = nn.GELU() + + def forward(self, hidden_states): + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_states))) + + +class NeuronQwen35VisionBlock(nn.Module): + """Single vision transformer block: LayerNorm + Attention + LayerNorm + MLP.""" + + def __init__(self, config): + super().__init__() + self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.attn = NeuronQwen35VisionAttention(config) + self.mlp = NeuronQwen35VisionMLP(config) + + def forward(self, hidden_states, attention_mask=None, position_embeddings=None): + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + attention_mask=attention_mask, + position_embeddings=position_embeddings, + ) + hidden_states = 
hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class NeuronQwen35VisionModel(nn.Module): + """Qwen3.5 MoE Vision Encoder for Neuron. + + This is the nn.Module that gets compiled and traced onto Neuron. + Patch embedding, positional embedding, and rotary embedding are computed + on CPU in the ModelWrapper and passed as inputs. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.blocks = nn.ModuleList( + [NeuronQwen35VisionBlock(config) for _ in range(config.depth)] + ) + # Merger: spatial_merge_size^2 * hidden_size -> out_hidden_size + self.merger_norm = nn.LayerNorm(config.hidden_size, eps=1e-6) + merger_hidden = config.hidden_size * (config.spatial_merge_size**2) + self.merger_fc1 = nn.Linear(merger_hidden, merger_hidden) + self.merger_act = nn.GELU() + self.merger_fc2 = nn.Linear(merger_hidden, config.out_hidden_size) + + def forward(self, hidden_states, attention_mask=None, position_embeddings=None): + """ + Args: + hidden_states: (seq_len, hidden_size) -- after patch_embed + pos_embed + attention_mask: (1, 1, seq_len, seq_len) block-diagonal mask + position_embeddings: (cos, sin) tuple for rotary + + Returns: + vision_embeddings: (merged_seq_len, out_hidden_size) + """ + for block in self.blocks: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + ) + + # Apply merger: norm -> spatial merge -> fc1 -> gelu -> fc2 + hidden_states = self.merger_norm(hidden_states) + merge_size = self.config.spatial_merge_size + merged_hidden = self.config.hidden_size * (merge_size**2) + hidden_states = hidden_states.view(-1, merged_hidden) + hidden_states = self.merger_fc2(self.merger_act(self.merger_fc1(hidden_states))) + + return hidden_states + + +class CPUVisionModel(nn.Module): + """CPU-only vision encoder (pure PyTorch, no Neuron dependencies). 
+ + Used when HBM is insufficient to load the vision encoder on Neuron + alongside the text decoder (e.g., 27B dense model on trn2.3xlarge). + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.blocks = nn.ModuleList( + [self._make_block(config) for _ in range(config.depth)] + ) + self.merger_norm = nn.LayerNorm(config.hidden_size, eps=1e-6) + merger_hidden = config.hidden_size * (config.spatial_merge_size**2) + self.merger_fc1 = nn.Linear(merger_hidden, merger_hidden) + self.merger_act = nn.GELU() + self.merger_fc2 = nn.Linear(merger_hidden, config.out_hidden_size) + + @staticmethod + def _make_block(config): + """Build a single vision block with standard nn.Linear (no TP).""" + block = nn.Module() + block.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) + block.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) + + # Attention + attn = nn.Module() + attn.hidden_size = config.hidden_size + attn.num_heads = config.num_heads + attn.head_dim = config.hidden_size // config.num_heads + attn.scaling = attn.head_dim**-0.5 + attn.qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=True) + attn.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True) + block.attn = attn + + # MLP + mlp = nn.Module() + mlp.linear_fc1 = nn.Linear( + config.hidden_size, config.intermediate_size, bias=True + ) + mlp.linear_fc2 = nn.Linear( + config.intermediate_size, config.hidden_size, bias=True + ) + mlp.act_fn = nn.GELU() + block.mlp = mlp + + return block + + def _forward_attention(self, attn, hidden_states, attention_mask, cos, sin): + seq_len = hidden_states.shape[0] + qkv = attn.qkv(hidden_states).reshape(seq_len, 3, attn.num_heads, attn.head_dim) + qkv = qkv.permute(1, 0, 2, 3) + q, k, v = qkv.unbind(0) + + if cos is not None and sin is not None: + cos_u = cos.unsqueeze(-2) + sin_u = sin.unsqueeze(-2) + + def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), 
dim=-1) + + q = (q * cos_u) + (rotate_half(q) * sin_u) + k = (k * cos_u) + (rotate_half(k) * sin_u) + + q = q.transpose(0, 1).unsqueeze(0) + k = k.transpose(0, 1).unsqueeze(0) + v = v.transpose(0, 1).unsqueeze(0) + + attn_weights = torch.matmul(q, k.transpose(-1, -2)) * attn.scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + out = torch.matmul(attn_weights, v) + out = out.squeeze(0).transpose(0, 1).reshape(seq_len, -1) + return attn.proj(out) + + def forward(self, hidden_states, attention_mask, cos, sin): + for block in self.blocks: + hidden_states = hidden_states + self._forward_attention( + block.attn, block.norm1(hidden_states), attention_mask, cos, sin + ) + hidden_states = hidden_states + block.mlp.linear_fc2( + block.mlp.act_fn(block.mlp.linear_fc1(block.norm2(hidden_states))) + ) + + hidden_states = self.merger_norm(hidden_states) + merge_size = self.config.spatial_merge_size + merged_hidden = self.config.hidden_size * (merge_size**2) + hidden_states = hidden_states.view(-1, merged_hidden) + hidden_states = self.merger_fc2(self.merger_act(self.merger_fc1(hidden_states))) + return hidden_states + + +class NeuronQwen35VisionModelWrapper(ModelWrapper): + """Wraps the vision encoder for NxDI tracing. + + Handles CPU-side operations that cannot be traced: + - Patch embedding (Conv3d) + - Positional embedding (Embedding + bilinear interpolation) + - Rotary position embedding computation + - Vision attention mask construction (block-diagonal) + - Sequence length bucketing and padding/unpadding + + Supports three modes: + 1. NxDI traced model (parallel layers) -- standard NxDI compilation + 2. Pre-compiled standalone model -- loaded from torch_neuronx.trace() output + 3. 
CPU-only model -- for when HBM is full (e.g., 27B dense on trn2.3xlarge) + """ + + def __init__(self, config, model_cls=None, **kwargs): + if model_cls is not None: + super().__init__(config, model_cls, **kwargs) + else: + # Standalone mode: no NxDI model_cls + nn.Module.__init__(self) + self.vision_config = config + self._compiled_model = None # Set by load_compiled() -- single bucket + self._compiled_buckets = None # Set by load_compiled() -- multi-bucket dict + self._cpu_model = None # Set by load_cpu_model() + + # These HF modules run on CPU, outside the traced graph + if Qwen3_5MoeVisionPatchEmbed is not None: + self.patch_embed = Qwen3_5MoeVisionPatchEmbed(config) + self.pos_embed = nn.Embedding( + config.num_position_embeddings, config.hidden_size + ) + self.num_grid_per_side = int(config.num_position_embeddings**0.5) + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen3_5MoeVisionRotaryEmbedding(head_dim // 2) + else: + logger.warning("HF Qwen3.5 MoE vision classes not available") + + self.vision_seq_len_buckets = kwargs.get( + "vision_seq_len_buckets", [1024, 4096, 16384] + ) + + def load_compiled(self, compiled_model_path): + """Load pre-compiled standalone vision encoder(s). + + Supports two modes: + 1. Single .pt file: Legacy mode, loads one compiled model for one bucket size. + 2. Directory with multiple .pt files: Multi-bucket mode. Files must be named + 'vision_encoder_{bucket_size}.pt' (e.g., 'vision_encoder_256.pt'). + Falls back to single 'vision_encoder.pt' in the directory. + + Args: + compiled_model_path: Path to a .pt file or directory containing bucket .pt files. 
+ """ + import glob as glob_module + + logger.info(f"Loading pre-compiled vision encoder from {compiled_model_path}") + + if os.path.isfile(compiled_model_path): + # Single file mode (legacy) + self._compiled_model = torch.jit.load(compiled_model_path) + self._compiled_buckets = None + logger.info("Vision encoder loaded successfully (single bucket)") + elif os.path.isdir(compiled_model_path): + # Directory mode: look for bucket-specific files + bucket_files = sorted( + glob_module.glob( + os.path.join(compiled_model_path, "vision_encoder_*.pt") + ) + ) + if bucket_files: + self._compiled_buckets = {} + for bf in bucket_files: + # Extract bucket size from filename: vision_encoder_256.pt -> 256 + basename = os.path.basename(bf) + try: + bucket_size = int( + basename.replace("vision_encoder_", "").replace(".pt", "") + ) + self._compiled_buckets[bucket_size] = torch.jit.load(bf) + logger.info(f" Loaded vision bucket {bucket_size} from {bf}") + except ValueError: + logger.warning(f" Skipping unrecognized file: {bf}") + self._compiled_model = None + # Update vision_seq_len_buckets to match compiled buckets + self.vision_seq_len_buckets = sorted(self._compiled_buckets.keys()) + logger.info( + f"Vision encoder loaded with {len(self._compiled_buckets)} buckets: " + f"{self.vision_seq_len_buckets}" + ) + else: + # Fall back to single vision_encoder.pt in directory + single_path = os.path.join(compiled_model_path, "vision_encoder.pt") + if os.path.exists(single_path): + self._compiled_model = torch.jit.load(single_path) + self._compiled_buckets = None + logger.info( + "Vision encoder loaded successfully (single file in dir)" + ) + else: + raise FileNotFoundError( + f"No vision encoder files found in {compiled_model_path}" + ) + else: + raise FileNotFoundError( + f"Vision encoder path not found: {compiled_model_path}" + ) + + def load_vision_weights_from_hf(self, model_path): + """Load patch_embed and pos_embed weights from HF safetensors. 
+ + Args: + model_path: Path to HF model directory + """ + from pathlib import Path + from safetensors import safe_open + + st_files = sorted( + p + for p in Path(model_path).glob("*.safetensors") + if p.suffix == ".safetensors" + ) + loaded = 0 + for sf_path in st_files: + with safe_open(str(sf_path), framework="pt") as f: + for key in f.keys(): + if key == "model.visual.patch_embed.proj.weight": + self.patch_embed.proj.weight.data.copy_(f.get_tensor(key)) + loaded += 1 + elif key == "model.visual.patch_embed.proj.bias": + self.patch_embed.proj.bias.data.copy_(f.get_tensor(key)) + loaded += 1 + elif key == "model.visual.pos_embed.weight": + self.pos_embed.weight.data.copy_(f.get_tensor(key)) + loaded += 1 + logger.info(f"Loaded {loaded} CPU-side vision weight tensors from HF") + + def load_cpu_model(self, model_path): + """Load a CPU-only vision encoder from HF safetensors. + + Use this when HBM is insufficient for the Neuron-compiled vision encoder + (e.g., 27B dense model fills trn2.3xlarge HBM). 
+ + Args: + model_path: Path to HF model directory with safetensors + """ + from pathlib import Path + from safetensors import safe_open + + config = self.vision_config + cpu_model = CPUVisionModel(config) + + # Build key mapping from HF safetensors to CPU model + key_map = {} + for i in range(config.depth): + hf_pre = f"model.visual.blocks.{i}" + loc_pre = f"blocks.{i}" + for suffix in [ + "attn.qkv.weight", + "attn.qkv.bias", + "attn.proj.weight", + "attn.proj.bias", + "mlp.linear_fc1.weight", + "mlp.linear_fc1.bias", + "mlp.linear_fc2.weight", + "mlp.linear_fc2.bias", + "norm1.weight", + "norm1.bias", + "norm2.weight", + "norm2.bias", + ]: + key_map[f"{hf_pre}.{suffix}"] = f"{loc_pre}.{suffix}" + + key_map["model.visual.merger.norm.weight"] = "merger_norm.weight" + key_map["model.visual.merger.norm.bias"] = "merger_norm.bias" + key_map["model.visual.merger.linear_fc1.weight"] = "merger_fc1.weight" + key_map["model.visual.merger.linear_fc1.bias"] = "merger_fc1.bias" + key_map["model.visual.merger.linear_fc2.weight"] = "merger_fc2.weight" + key_map["model.visual.merger.linear_fc2.bias"] = "merger_fc2.bias" + + st_files = sorted(Path(model_path).glob("model*.safetensors")) + loaded = 0 + state_dict = cpu_model.state_dict() + + for sf_path in st_files: + with safe_open(str(sf_path), framework="pt") as f: + for key in f.keys(): + if key in key_map: + local_key = key_map[key] + if local_key in state_dict: + state_dict[local_key].copy_(f.get_tensor(key)) + loaded += 1 + + cpu_model.load_state_dict(state_dict) + cpu_model = cpu_model.to(torch.bfloat16).eval() + self._cpu_model = cpu_model + logger.info( + f"Loaded CPU vision encoder: {loaded} weights, " + f"{sum(p.numel() for p in cpu_model.parameters()) / 1e6:.1f}M params" + ) + + def _get_vision_bucket(self, seq_len): + """Find the smallest bucket that fits the sequence length.""" + for bucket in sorted(self.vision_seq_len_buckets): + if seq_len <= bucket: + return bucket + return self.vision_seq_len_buckets[-1] + + 
def rot_pos_emb(self, grid_thw): + """Compute rotary positional embeddings for vision tokens. + + Returns: (total_tokens, head_dim) tensor of rotary frequencies. + """ + merge_size = self.vision_config.spatial_merge_size + grid_thw_list = grid_thw.tolist() + + max_hw = max(max(h, w) for _, h, w in grid_thw_list) + freq_table = self.rotary_pos_emb(max_hw) + device = freq_table.device + + total_tokens = sum(t * h * w for t, h, w in grid_thw_list) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + offset = 0 + for num_frames, height, width in grid_thw_list: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) + block_cols = torch.arange(merged_w, device=device) + intra_row = torch.arange(merge_size, device=device) + intra_col = torch.arange(merge_size, device=device) + + row_idx = ( + block_rows[:, None, None, None] * merge_size + + intra_row[None, None, :, None] + ) + col_idx = ( + block_cols[None, :, None, None] * merge_size + + intra_col[None, None, None, :] + ) + + row_idx = row_idx.expand( + merged_h, merged_w, merge_size, merge_size + ).reshape(-1) + col_idx = col_idx.expand( + merged_h, merged_w, merge_size, merge_size + ).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + embeddings = freq_table[pos_ids] + embeddings = embeddings.flatten(1) + return embeddings + + def fast_pos_embed_interpolate(self, grid_thw): + """Bilinear interpolation of positional embeddings for variable resolution.""" + grid_thw_list = grid_thw.tolist() + grid_ts = [row[0] for row in grid_thw_list] + grid_hs = [row[1] for row in grid_thw_list] + grid_ws = [row[2] for row in grid_thw_list] + device = self.pos_embed.weight.device + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for 
t, h, w in grid_thw_list: + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) + weight_tensor = torch.tensor( + weight_list, dtype=self.pos_embed.weight.dtype, device=device + ) + pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split( + [h * w for h, w in zip(grid_hs, grid_ws)] + ) + + merge_size = self.vision_config.spatial_merge_size + patch_pos_embeds_permute = [] + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view( + t, h // merge_size, merge_size, w // merge_size, merge_size, -1 + ) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + + return torch.cat(patch_pos_embeds_permute) + + def _build_vision_attention_mask(self, grid_thw, seq_len, dtype): + 
"""Build block-diagonal attention mask for variable-length images. + + Each image gets its own attention block (no cross-image attention). + """ + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + # Build block-diagonal mask + mask = torch.full((seq_len, seq_len), _MASK_NEG_INF, dtype=dtype) + for i in range(len(cu_seqlens) - 1): + start = cu_seqlens[i].item() + end = cu_seqlens[i + 1].item() + mask[start:end, start:end] = 0.0 + + return mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) + + def forward(self, pixel_values, image_grid_thw): + """Run vision encoding (CPU preprocessing + Neuron traced model). + + Args: + pixel_values: Raw pixel values from HF processor + image_grid_thw: (num_images, 3) -- temporal, height, width in patches + + Returns: + vision_embeddings: (total_merged_tokens, out_hidden_size) + """ + # 1. Patch embedding (CPU, Conv3d) + hidden_states = self.patch_embed(pixel_values) + + # 2. Positional embedding (CPU, bilinear interpolation) + pos_embeds = self.fast_pos_embed_interpolate(image_grid_thw) + hidden_states = hidden_states + pos_embeds + + # 3. Rotary position embeddings (CPU) + rotary_pos_emb = self.rot_pos_emb(image_grid_thw) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + # 4. Vision attention mask (block-diagonal) + seq_len = hidden_states.shape[0] + attention_mask = self._build_vision_attention_mask( + image_grid_thw, seq_len, hidden_states.dtype + ) + + # 5. 
Bucket and pad for Neuron compilation + bucket_len = self._get_vision_bucket(seq_len) + cos, sin = position_embeddings + if seq_len < bucket_len: + pad_len = bucket_len - seq_len + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_len)) + cos = F.pad(cos, (0, 0, 0, pad_len)) + sin = F.pad(sin, (0, 0, 0, pad_len)) + # Extend mask with _MASK_NEG_INF for padded positions (NOT -inf, which causes NaN on Neuron) + mask = torch.full( + (1, 1, bucket_len, bucket_len), _MASK_NEG_INF, dtype=hidden_states.dtype + ) + mask[:, :, :seq_len, :seq_len] = attention_mask + attention_mask = mask + + # 6. Run vision model (Neuron compiled or CPU fallback) + if self._compiled_buckets is not None: + # Multi-bucket mode: select the compiled model for this bucket + if bucket_len not in self._compiled_buckets: + raise RuntimeError( + f"No compiled vision encoder for bucket size {bucket_len}. " + f"Available buckets: {sorted(self._compiled_buckets.keys())}. " + f"Input seq_len={seq_len} requires bucket {bucket_len}." 
+ ) + compiled_model = self._compiled_buckets[bucket_len] + vision_output = compiled_model( + hidden_states.to(torch.bfloat16), + attention_mask.to(torch.bfloat16), + cos.to(torch.bfloat16), + sin.to(torch.bfloat16), + ) + elif self._compiled_model is not None: + # Single compiled model (legacy) + vision_output = self._compiled_model( + hidden_states.to(torch.bfloat16), + attention_mask.to(torch.bfloat16), + cos.to(torch.bfloat16), + sin.to(torch.bfloat16), + ) + elif self._cpu_model is not None: + # CPU-only mode: run vision encoder on CPU (no bucketing/padding needed + # but we pad anyway for consistency with the same merger math) + with torch.no_grad(): + vision_output = self._cpu_model( + hidden_states.to(torch.bfloat16), + attention_mask.to(torch.bfloat16), + cos.to(torch.bfloat16), + sin.to(torch.bfloat16), + ) + else: + # NxDI traced model: takes (hidden_states, attention_mask, position_embeddings) + vision_output = self.model(hidden_states, attention_mask, (cos, sin)) + + # 7. Unpad: only keep valid merged tokens + merge_area = self.vision_config.spatial_merge_size**2 + total_merged_tokens = sum( + t + * (h // self.vision_config.spatial_merge_size) + * (w // self.vision_config.spatial_merge_size) + for t, h, w in image_grid_thw.tolist() + ) + vision_output = vision_output[:total_merged_tokens] + + return vision_output + + +class NeuronQwen35VisionForImageEncoding(NeuronApplicationBase): + """Standalone application class for vision encoding (for testing).""" + + model_cls = NeuronQwen35VisionModel + model_wrapper_cls = NeuronQwen35VisionModelWrapper + + @staticmethod + def prepare_input_args(image_path, processor): + """Prepare vision inputs from an image path. 
+ + Args: + image_path: Path to image file + processor: HF AutoProcessor + + Returns: + pixel_values, image_grid_thw + """ + from PIL import Image + + image = Image.open(image_path).convert("RGB") + inputs = processor(images=image, return_tensors="pt") + return inputs["pixel_values"], inputs["image_grid_thw"] diff --git a/contrib/models/Qwen3.6-27B/src/modeling_qwen35_vl.py b/contrib/models/Qwen3.6-27B/src/modeling_qwen35_vl.py new file mode 100644 index 00000000..e3afbb1b --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/modeling_qwen35_vl.py @@ -0,0 +1,662 @@ +""" +Qwen3.5-27B / Qwen3.6-27B Vision-Language Model Orchestrator for NeuronX Distributed Inference. + +This is the top-level VL model that wires together: +- The vision encoder (modeling_qwen35_vision.py) +- The text decoder (modeling_qwen35.py, dense model with vision injection) + +It handles: +- Multimodal RoPE (mRoPE) with interleaved layout +- Vision embedding injection via scatter_by_index_put +- Separate compilation and loading of vision and text models +- The CTE+TKG generation loop with vision inputs + +The architecture follows the NxDI NeuronBaseForImageToText pattern established +by Qwen3-VL in SDK 2.28, adapted for the Qwen3.5/3.6 dense model's unique features: +- No deepstack (Qwen3.5/3.6 does not use intermediate vision feature injection) +- DeltaNet linear attention layers in the text decoder +- Dense SwiGLU MLP layers in the text decoder +- Interleaved mRoPE (THWTHW... 
layout) instead of Qwen3-VL's section-based layout +""" + +import logging +import os +from typing import Optional + +import torch +import torch.nn.functional as F + +logger = logging.getLogger(__name__) + +# NxDI imports +try: + from neuronx_distributed_inference.models.image_to_text_model_base import ( + ImageToTextInferenceConfig, + NeuronBaseForImageToText, + ) + from neuronx_distributed_inference.models.config import NeuronConfig + + HAS_NXDI_VL = True +except ImportError: + HAS_NXDI_VL = False + logger.warning("NxDI VL base classes not available -- VL model requires SDK 2.28+") + +# Local imports +try: + from src.modeling_qwen35 import ( + NeuronQwen35ForCausalLM, + NeuronQwen35Model, + Qwen35InferenceConfig, + Qwen35ModelWrapper, + ) + from src.modeling_qwen35_vision import ( + NeuronQwen35VisionModel, + NeuronQwen35VisionModelWrapper, + ) +except ImportError: + from modeling_qwen35 import ( + NeuronQwen35ForCausalLM, + NeuronQwen35Model, + Qwen35InferenceConfig, + Qwen35ModelWrapper, + ) + from modeling_qwen35_vision import ( + NeuronQwen35VisionModel, + NeuronQwen35VisionModelWrapper, + ) + + +def get_rope_index( + input_ids, + image_grid_thw=None, + video_grid_thw=None, + attention_mask=None, + image_token_id=248056, + video_token_id=248057, + vision_start_token_id=248053, + spatial_merge_size=2, +): + """Compute 3D multimodal RoPE position IDs for Qwen3.5. + + Returns position_ids of shape (3, batch_size, seq_len) where: + - Axis 0: temporal position + - Axis 1: height position + - Axis 2: width position + + For text tokens, all 3 axes have the same sequential position. + For vision tokens, each axis encodes the spatial/temporal grid position. + + Also returns rope_deltas for use during TKG decoding. + + Adapted from HuggingFace Qwen3_5Model.get_rope_index(). 
+ """ + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave( + video_grid_thw, video_grid_thw[:, 0], dim=0 + ) + video_grid_thw[:, 0] = 1 + + image_grid_thw_list = ( + image_grid_thw.tolist() if image_grid_thw is not None else None + ) + video_grid_thw_list = ( + video_grid_thw.tolist() if video_grid_thw is not None else None + ) + + mrope_position_deltas = [] + total_input_ids = input_ids + + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + + position_ids = torch.zeros( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + + for i, ids in enumerate(total_input_ids): + ids = ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + + vision_start_indices = torch.argwhere(ids == vision_start_token_id).squeeze(1) + if len(vision_start_indices) > 0: + vision_tokens = ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + + input_tokens = ids.tolist() + llm_pos_ids_list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + + if ed_image < ed_video: + t, h, w = image_grid_thw_list[image_index] + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = video_grid_thw_list[video_index] + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t = t + llm_grid_h = h // spatial_merge_size + llm_grid_w = w // spatial_merge_size + + text_len = ed - st + st_idx = llm_pos_ids_list[-1].max() + 1 if 
len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( + position_ids.device + ) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + + mrope_position_deltas = torch.tensor( + mrope_position_deltas, device=input_ids.device + ).unsqueeze(1) + return position_ids, mrope_position_deltas + + +class Qwen35VLInferenceConfig: + """Configuration for the full VL model (text + vision). + + Wraps the existing Qwen35InferenceConfig for text and adds + vision-specific settings. + """ + + def __init__( + self, + text_config, + vision_config, + image_token_id=248056, + video_token_id=248057, + vision_start_token_id=248053, + vision_end_token_id=248054, + spatial_merge_size=2, + vision_seq_len_buckets=None, + **kwargs, + ): + """ + Args: + text_config: Qwen35InferenceConfig instance for the text decoder + vision_config: dict with vision encoder hyperparams (depth, hidden_size, etc.) 
+ image_token_id: Token ID for image placeholder tokens + video_token_id: Token ID for video placeholder tokens + vision_start_token_id: Token ID for <|vision_start|> + vision_end_token_id: Token ID for <|vision_end|> + spatial_merge_size: Side length of the patch merge window (2 = 2x2 = 4 patches per merged token) + vision_seq_len_buckets: List of vision sequence length buckets for compilation + """ + self.text_config = text_config + self.vision_config = vision_config + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_start_token_id = vision_start_token_id + self.vision_end_token_id = vision_end_token_id + self.spatial_merge_size = spatial_merge_size + self.vision_seq_len_buckets = vision_seq_len_buckets or [1024, 4096, 16384] + + +class NeuronQwen35VLForCausalLM: + """Top-level VL model for Qwen3.5/3.6-27B on Neuron. + + This class manages: + - Separate compilation/loading of vision encoder and text decoder + - CPU-side mRoPE computation + - Vision embedding injection into text decoder + - The CTE+TKG generation loop + + Note: This is NOT a NeuronBaseForImageToText subclass because the + text decoder (NeuronQwen35ForCausalLM) has extensive custom overrides + (DeltaNet state management, custom forward, custom ModelWrapper) that + don't fit the base class pattern. Instead, this class composes the two + models and handles the VL orchestration directly. 
+ """ + + def __init__(self, model_path, text_config, vision_config=None, processor=None): + """ + Args: + model_path: Path to HF model directory + text_config: Qwen35InferenceConfig for text decoder + vision_config: Qwen35VLInferenceConfig (or None for text-only) + processor: HF AutoProcessor for image preprocessing + """ + self.model_path = model_path + self.text_config = text_config + self.vl_config = self._vl_config = vision_config # generate() reads _vl_config, so define it even for text-only + self.processor = processor + + # Text decoder (existing implementation) + self.text_model = NeuronQwen35ForCausalLM( + model_path=model_path, config=text_config + ) + + # Vision encoder (only built if vision_config is provided) + self.vision_model_wrapper = None + if vision_config is not None: + self._init_vision_model(vision_config) + + # mRoPE state + self.rope_deltas = None + + def _init_vision_model(self, vl_config): + """Initialize the vision encoder wrapper.""" + from types import SimpleNamespace + + vision_cfg = SimpleNamespace(**vl_config.vision_config) + self.vision_model_wrapper = NeuronQwen35VisionModelWrapper( + config=vision_cfg, + model_cls=None, # Standalone mode (no NxDI parallel layers) + vision_seq_len_buckets=vl_config.vision_seq_len_buckets, + ) + self._vl_config = vl_config + + def compile(self, compiled_model_path): + """Compile the text decoder (the vision encoder is compiled separately). + + For the vision encoder, use compile_vision_encoder.py separately + (standalone torch_neuronx.trace compilation). Then use load() to + load the pre-compiled vision encoder. + """ + # Compile text decoder + text_path = os.path.join(compiled_model_path, "text_model") + os.makedirs(text_path, exist_ok=True) + self.text_model.compile(text_path) + + # Vision encoder is compiled separately via compile_vision_encoder.py + if self.vision_model_wrapper is not None: + logger.info( + "Vision encoder must be compiled separately using " + "compile_vision_encoder.py. Use load() to load the " + "pre-compiled vision encoder." 
+ ) + + def load(self, compiled_model_path, vision_compiled_path=None): + """Load both compiled models. + + Args: + compiled_model_path: Path to compiled text model (or parent dir) + vision_compiled_path: Path to compiled vision encoder .pt file. + If None, looks for 'vision_encoder.pt' in compiled_model_path. + """ + text_path = os.path.join(compiled_model_path, "text_model") + if os.path.exists(text_path): + self.text_model.load(text_path) + else: + # Backward compatibility: text model compiled at root + self.text_model.load(compiled_model_path) + + # Load vision encoder + if self.vision_model_wrapper is not None: + if vision_compiled_path is None: + vision_compiled_path = os.path.join( + compiled_model_path, "vision_encoder.pt" + ) + if os.path.exists(vision_compiled_path): + self.vision_model_wrapper.load_compiled(vision_compiled_path) + # Also load CPU-side weights (patch_embed, pos_embed) + self.vision_model_wrapper.load_vision_weights_from_hf(self.model_path) + logger.info("Vision encoder loaded from pre-compiled model") + else: + logger.warning( + f"No compiled vision encoder found at {vision_compiled_path}. " + "Vision encoding will not be available." + ) + + # Default Qwen3.5/3.6 stop token IDs (hard-coded here; pass eos_token_ids to generate() to override) + _DEFAULT_EOS_TOKEN_IDS = { + 248044, # <|endoftext|> -- text config eos_token_id + 248046, # <|im_end|> -- tokenizer eos_token / end of assistant turn + } + + def generate( + self, + input_ids, + attention_mask=None, + pixel_values=None, + image_grid_thw=None, + video_grid_thw=None, + max_new_tokens=32, + temperature=0.0, + top_p=1.0, + top_k=0, + eos_token_ids=None, + **kwargs, + ): + """Generate text from text and/or vision inputs. 
+ + Args: + input_ids: (batch_size, seq_len) token IDs + attention_mask: (batch_size, seq_len) attention mask + pixel_values: Vision pixel values from HF processor (or None for text-only) + image_grid_thw: (num_images, 3) grid dimensions + video_grid_thw: (num_videos, 3) grid dimensions + max_new_tokens: Maximum new tokens to generate + temperature: Sampling temperature (0.0 = greedy/argmax) + top_p: Nucleus sampling threshold (1.0 = disabled) + top_k: Top-k sampling (0 = disabled) + eos_token_ids: Set of token IDs to stop generation on + (default: {248044, 248046}) + + Returns: + generated_ids: (batch_size, seq_len + n_new) token IDs, n_new <= max_new_tokens (fewer if EOS is hit) + """ + if eos_token_ids is None: + eos_token_ids = self._DEFAULT_EOS_TOKEN_IDS + + # Reset text model state for a fresh generation. + # This ensures CTE runs (not TKG) even if a prior generate() was called. + # DeltaNet recurrent states don't need explicit zeroing because the CTE + # NKI kernel always starts from zero state. + self.text_model.reset() + + has_vision = pixel_values is not None and pixel_values.numel() > 0 + + # Step 1: Compute 3D mRoPE position IDs + if has_vision and self._vl_config is not None: + position_ids, self.rope_deltas = get_rope_index( + input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + attention_mask=attention_mask, + image_token_id=self._vl_config.image_token_id, + video_token_id=self._vl_config.video_token_id, + vision_start_token_id=self._vl_config.vision_start_token_id, + spatial_merge_size=self._vl_config.spatial_merge_size, + ) + else: + # Text-only: use standard sequential position IDs + seq_len = input_ids.shape[1] + position_ids = torch.arange(seq_len).unsqueeze(0) + self.rope_deltas = None + + # Step 2: Run vision encoder and prepare injection args + llava_args = [] + batch_size = input_ids.shape[0] + if has_vision and self.vision_model_wrapper is not None: + # The vision encoder processes both image and video frames identically + # (they share the same ViT 
architecture). The HF processor outputs a + # single pixel_values tensor for images, and video frames are treated + # as multiple images with temporal grid > 1. + vision_embeddings = self.vision_model_wrapper(pixel_values, image_grid_thw) + # vision_embeddings: (total_merged_tokens, out_hidden_size) + + # Build vision_mask: boolean mask of ALL vision token positions + # (both image_token_id and video_token_id placeholders) + image_token_id = self._vl_config.image_token_id + video_token_id = self._vl_config.video_token_id + vision_bool_mask = (input_ids == image_token_id) | ( + input_ids == video_token_id + ) # (BS, seq_len) + + # For batch_size=1 (primary path): extract positions from batch element 0. + # For batch_size>1: each element may have different image token positions; + # we'd need per-element scatter. Currently only batch_size=1 is supported + # for VL (the compiled model uses batch_size=1 for CTE). + if batch_size > 1: + logger.warning( + "VL generation with batch_size > 1 is not fully supported. " + "Using batch element 0 for vision scatter positions." + ) + + positions = ( + vision_bool_mask[0].nonzero(as_tuple=False).squeeze(-1) + ) # (n_vision_tokens,) + + # Reshape vision_embeddings to (1, n_vision_tokens, hidden_size) + n_vis = positions.shape[0] + hidden_size = vision_embeddings.shape[-1] + vis_emb = vision_embeddings[:n_vis].unsqueeze(0) # (1, n_vis, hidden) + + # Pad to match input sequence length for compiled graph compatibility + seq_len = input_ids.shape[1] + pad_limit = seq_len # Must match the bucket size + + # Pad vision_embeddings to (1, pad_limit, hidden_size) + if n_vis < pad_limit: + pad_emb = torch.zeros( + (1, pad_limit - n_vis, hidden_size), + dtype=vis_emb.dtype, + ) + vis_emb_padded = torch.cat([vis_emb, pad_emb], dim=1) + else: + vis_emb_padded = vis_emb[:, :pad_limit] + + # Pad positions to (1, pad_limit, 1) with a SAFE fill value. + # CRITICAL: fill_value must be a valid index (within [0, pad_limit-1]). 
+ # Using pad_limit-1 targets the last position (always a padding slot) + # so index_put_ scatters zero embeddings there harmlessly. + # NOTE: Do NOT use large sentinel values (e.g., 2**30) as they cause + # DGE out-of-bounds crashes in the Neuron runtime. + positions_padded = torch.full( + (1, pad_limit, 1), + fill_value=pad_limit - 1, + dtype=torch.int32, + ) + positions_padded[0, :n_vis, 0] = positions[:pad_limit].to(torch.int32) + + llava_args = [vis_emb_padded, positions_padded] + + # Append 3D mRoPE position IDs for the text model. + # position_ids shape: (3, batch_size, seq_len) from get_rope_index. + # _get_model_outputs receives this at slot 21 and pre-computes + # mRoPE cos/sin in get_model_output() for all decoder layers. + if position_ids.ndim == 3: + mrope_pos = position_ids[:, :, :seq_len].to(torch.int32).contiguous() + llava_args.append(mrope_pos) + else: + vision_embeddings = None + + # Step 3: Context encoding (prefill) + generated_ids = input_ids.clone() + + # CRITICAL: Always pass an explicit attention_mask for CTE. + # The base class _infer_attention_mask() assumes sequential position_ids + # (position_ids[i] >= i). When position_ids come from mRoPE temporal + # axis (non-sequential, e.g., all vision tokens share position 4), + # the inferred mask incorrectly masks out most of the sequence. + # Fix: provide a real all-ones mask for the actual token positions. + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + # For slot 2 (position_ids): use SEQUENTIAL positions regardless of mRoPE. + # Slot 2 is only used for: (1) logit position selection via torch.max(), + # (2) attention mask inference (which we bypass with explicit mask above). + # The actual RoPE computation uses slot 21 (rotary_position_ids) from + # _get_model_outputs, NOT slot 2. Using sequential slot 2 ensures + # correct logit selection and avoids any position_ids-related issues. 
+ seq_len = input_ids.shape[1] + cte_position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0) + + with torch.no_grad(): + output = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=cte_position_ids, + output_attentions=False, + output_hidden_states=False, + return_dict=False, + llava_args=llava_args, + ) + + logits = output[0] if isinstance(output, tuple) else output.logits + next_token = self._sample_token(logits[:, -1, :], temperature, top_p, top_k) + generated_ids = torch.cat([generated_ids, next_token.unsqueeze(-1)], dim=-1) + + # Check EOS after first token + if next_token.item() in eos_token_ids: + return generated_ids + + # Step 4: Token generation (TKG) loop + for _ in range(max_new_tokens - 1): + pos_ids = torch.tensor([[generated_ids.shape[1] - 1]]) + if self.rope_deltas is not None: + pos_ids = pos_ids + self.rope_deltas + + last_token = generated_ids[:, -1:] + with torch.no_grad(): + output = self.text_model( + input_ids=last_token, + position_ids=pos_ids, + output_attentions=False, + output_hidden_states=False, + return_dict=False, + ) + logits = output[0] if isinstance(output, tuple) else output.logits + next_token = self._sample_token(logits[:, -1, :], temperature, top_p, top_k) + generated_ids = torch.cat([generated_ids, next_token.unsqueeze(-1)], dim=-1) + + # Stop on EOS + if next_token.item() in eos_token_ids: + break + + return generated_ids + + @staticmethod + def _sample_token(logits, temperature=0.0, top_p=1.0, top_k=0): + """Sample a token from logits with optional temperature/top-p/top-k. + + Args: + logits: (batch_size, vocab_size) unnormalized logits + temperature: Sampling temperature. 0.0 = greedy (argmax). + top_p: Nucleus sampling threshold. 1.0 = disabled. + top_k: Top-k filtering. 0 = disabled. 
+ + Returns: + token_id: (batch_size,) sampled token IDs + """ + if temperature <= 0.0: + return torch.argmax(logits, dim=-1) + + # Apply temperature + logits = logits / temperature + + # Top-k filtering + if top_k > 0: + top_k = min(top_k, logits.shape[-1]) + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = float("-inf") + + # Top-p (nucleus) filtering + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum( + torch.softmax(sorted_logits, dim=-1), dim=-1 + ) + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift right so the first token above threshold is kept + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ + ..., :-1 + ].clone() + sorted_indices_to_remove[..., 0] = False + # Scatter back to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + -1, sorted_indices, sorted_indices_to_remove + ) + logits[indices_to_remove] = float("-inf") + + # Sample from the filtered distribution + probs = torch.softmax(logits, dim=-1) + return torch.multinomial(probs, num_samples=1).squeeze(-1) + + @staticmethod + def prepare_input_args(text_prompt, image_path, processor, role="user"): + """Prepare inputs for vision+text generation. 
+ + Args: + text_prompt: Text prompt string + image_path: Path to image file (or None for text-only) + processor: HF AutoProcessor + role: Message role (default "user") + + Returns: + input_ids, attention_mask, vision_inputs dict + """ + content = [] + if image_path is not None: + import base64 + from pathlib import Path + + image_data = Path(image_path).read_bytes() + b64 = base64.b64encode(image_data).decode("utf-8") + content.append( + { + "type": "image", + "url": f"data:image/jpeg;base64,{b64}", + } + ) + content.append({"type": "text", "text": text_prompt}) + + messages = [{"role": role, "content": content}] + inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_tensors="pt", + return_dict=True, + ) + + input_ids = inputs["input_ids"] + attention_mask = inputs.get("attention_mask", torch.ones_like(input_ids)) + + vision_inputs = {} + if "pixel_values" in inputs: + vision_inputs["pixel_values"] = inputs["pixel_values"] + if "image_grid_thw" in inputs: + vision_inputs["image_grid_thw"] = inputs["image_grid_thw"] + if "video_grid_thw" in inputs: + vision_inputs["video_grid_thw"] = inputs["video_grid_thw"] + + return input_ids, attention_mask, vision_inputs diff --git a/contrib/models/Qwen3.6-27B/src/nki_kernels/__init__.py b/contrib/models/Qwen3.6-27B/src/nki_kernels/__init__.py new file mode 100644 index 00000000..7e78cdb9 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/nki_kernels/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Custom NKI kernels for Qwen3.5-27B / Qwen3.6-27B DeltaNet layers. 
+ +Contains three kernel implementations: +- nki_deltanet: Per-token recurrent kernel (used for token generation) +- nki_deltanet_chunked: Per-chunk kernel (legacy, superseded by fused) +- nki_deltanet_fused: Fused single-kernel chunked forward (used for context encoding) +""" diff --git a/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet.py b/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet.py new file mode 100644 index 00000000..a9994d54 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet.py @@ -0,0 +1,334 @@ +"""NKI kernels for DeltaNet gated delta rule recurrent forward. + +NKI v3 (SDK 2.29, NKI 0.3.0). Processes a SINGLE (batch, head) pair per kernel call. +The caller loops over (B, H) in PyTorch and calls this kernel for each pair. + +Input layout: All inputs are 2D contiguous tensors (S, 128). +Each call processes one (batch, head) element's full sequence. + +k_dim = v_dim = 128, which matches SBUF tile partition dimension exactly. +g and beta are scalars per token, expanded to (S, 128) by the caller. + +Two kernel variants: + deltanet_recurrent_fwd -- returns output only (original) + deltanet_recurrent_fwd_state -- returns (output, final_state) for CTE->TKG carry-over +""" + +import nki +import nki.isa as nisa +import nki.language as nl + +# Partition dimension max (NeuronCore SBUF tile width) +P_MAX = 128 + +# Shuffle mask: broadcast partition 0 to all partitions in a 32-wide group +_BROADCAST_MASK = [0] * 32 + + +@nki.jit +def deltanet_recurrent_fwd( + query: nl.ndarray, # (S, 128) float32 + key: nl.ndarray, # (S, 128) float32 + value: nl.ndarray, # (S, 128) float32 + g_in: nl.ndarray, # (S, 128) float32, log-decay broadcast to 128 + beta_in: nl.ndarray, # (S, 128) float32, write gate broadcast to 128 +) -> nl.ndarray: + """NKI kernel for DeltaNet recurrent forward -- single (batch, head). + + Iterates over sequence tokens with sequential_range. + State matrix (128 x 128) lives in SBUF. 
+ + Args: + query: (S, 128) float32 + key: (S, 128) float32 + value: (S, 128) float32 + g_in: (S, 128) float32 + beta_in: (S, 128) float32 + + Returns: + output: (S, 128) float32 + """ + seq_len, dim = query.shape + + # Output tensor in HBM + output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm) + + # Stride: for 2D (S, D), dim0 stride = D=128, dim1 stride = 1 + seq_stride = dim + + # Initialize recurrent state in SBUF: (128, 128) + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=state, value=0.0) + + # Sequential loop over tokens (state-dependent) + for t in nl.sequential_range(seq_len): + tok_offset = t * seq_stride + + # ---- Load inputs for token t ---- + q_t = nl.ndarray((P_MAX, 1), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=q_t, + src=query.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + k_t = nl.ndarray((P_MAX, 1), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=k_t, + src=key.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + v_t = nl.ndarray((P_MAX, 1), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=v_t, + src=value.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + g_t = nl.ndarray((P_MAX, 1), dtype=g_in.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=g_t, + src=g_in.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + beta_t = nl.ndarray((P_MAX, 1), dtype=beta_in.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=beta_t, + src=beta_in.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + # ---- Step 1: Decay state -- state = state * exp(g_t) ---- + exp_g = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_g, op=nl.exp, data=g_t, bias=None, scale=1.0) + + state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=state_decayed, + data=state, + op0=nl.multiply, + operand0=exp_g, + engine=nisa.vector_engine, + ) + nisa.tensor_copy(dst=state, src=state_decayed) + + # ---- Step 2: Read 
memory -- kv_mem = state^T @ k_t ---- + kv_mem_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_mem_psum, stationary=state, moving=k_t) + kv_mem = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_mem, src=kv_mem_psum) + + # ---- Step 3: delta = (v_t - kv_mem) * beta_t ---- + v_sub = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_sub, data1=v_t, data2=kv_mem, op=nl.subtract) + + delta = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=delta, + data=v_sub, + op0=nl.multiply, + operand0=beta_t, + engine=nisa.vector_engine, + ) + + # ---- Step 4: state += outer(k_t, delta) ---- + # Broadcast multiply: outer[i,j] = k_t[i] * delta[j] + # 1) Transpose delta (128,1) -> (1,128) in PSUM + # 2) Copy PSUM (1,128) -> SBUF (128,128) -- partition broadcast + # 3) Multiply by k_t (128,1) which broadcasts across free dim + # This avoids the nc_matmul P=1 outer product (wastes 127/128 TE lanes). 
+ + # Transpose delta to get values along free dimension + delta_row_psum = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=delta_row_psum, data=delta) + + # Copy PSUM (1, 128) -> SBUF (1, 128) first (NKI 0.3.0 requires matching P dims) + delta_row_sb = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=delta_row_sb, src=delta_row_psum) + + # Broadcast (1, 128) SBUF -> (128, 128) SBUF via nc_stream_shuffle + # Each partition row gets the same delta values + delta_broadcast = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=delta_row_sb[0:1, 0:P_MAX], + dst=delta_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + # Element-wise multiply: outer[i,j] = delta_broadcast[i,j] * k_t[i,0] + # tensor_scalar broadcasts (P,1) k_t across all F columns + outer_prod = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=outer_prod, + data=delta_broadcast, + op0=nl.multiply, + operand0=k_t, + engine=nisa.vector_engine, + ) + + # Accumulate into state + state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=state_new, data1=state, data2=outer_prod, op=nl.add) + nisa.tensor_copy(dst=state, src=state_new) + + # ---- Step 5: o_t = state^T @ q_t ---- + o_t_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=o_t_psum, stationary=state, moving=q_t) + o_t = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=o_t, src=o_t_psum) + + # ---- Store output for token t ---- + nisa.dma_copy( + dst=output.ap(pattern=[[1, dim]], offset=tok_offset), + src=o_t, + ) + + return output + + +@nki.jit +def deltanet_recurrent_fwd_state( + query: nl.ndarray, # (S, 128) float32 + key: nl.ndarray, # (S, 128) float32 + value: nl.ndarray, # (S, 128) float32 + g_in: nl.ndarray, # (S, 128) 
float32, log-decay broadcast to 128 + beta_in: nl.ndarray, # (S, 128) float32, write gate broadcast to 128 +): + """NKI kernel for DeltaNet recurrent forward with final state output. + + Same recurrence as deltanet_recurrent_fwd, but ALSO writes the final + recurrent state (128, 128) to an output HBM buffer. This enables + CTE -> TKG state carry-over. + + Returns: + output: (S, 128) float32 -- per-token output + final_state: (128, 128) float32 -- recurrent state after last token + """ + seq_len, dim = query.shape + + # Output tensors in HBM + output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm) + final_state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm) + + # Stride: for 2D (S, D), dim0 stride = D=128, dim1 stride = 1 + seq_stride = dim + + # Initialize recurrent state in SBUF: (128, 128) + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=state, value=0.0) + + # Sequential loop over tokens (state-dependent) + for t in nl.sequential_range(seq_len): + tok_offset = t * seq_stride + + # ---- Load inputs for token t ---- + q_t = nl.ndarray((P_MAX, 1), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=q_t, + src=query.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + k_t = nl.ndarray((P_MAX, 1), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=k_t, + src=key.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + v_t = nl.ndarray((P_MAX, 1), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=v_t, + src=value.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + g_t = nl.ndarray((P_MAX, 1), dtype=g_in.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=g_t, + src=g_in.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + beta_t = nl.ndarray((P_MAX, 1), dtype=beta_in.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=beta_t, + src=beta_in.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + # ---- Step 1: Decay state -- state = state * exp(g_t) ---- + exp_g = nl.ndarray((P_MAX, 1), 
dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_g, op=nl.exp, data=g_t, bias=None, scale=1.0) + + state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=state_decayed, + data=state, + op0=nl.multiply, + operand0=exp_g, + engine=nisa.vector_engine, + ) + nisa.tensor_copy(dst=state, src=state_decayed) + + # ---- Step 2: Read memory -- kv_mem = state^T @ k_t ---- + kv_mem_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_mem_psum, stationary=state, moving=k_t) + kv_mem = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_mem, src=kv_mem_psum) + + # ---- Step 3: delta = (v_t - kv_mem) * beta_t ---- + v_sub = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_sub, data1=v_t, data2=kv_mem, op=nl.subtract) + + delta = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=delta, + data=v_sub, + op0=nl.multiply, + operand0=beta_t, + engine=nisa.vector_engine, + ) + + # ---- Step 4: state += outer(k_t, delta) ---- + # Broadcast multiply: outer[i,j] = k_t[i] * delta[j] + delta_row_psum = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=delta_row_psum, data=delta) + + # Copy PSUM (1, 128) -> SBUF (1, 128) first (NKI 0.3.0 requires matching P dims) + delta_row_sb = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=delta_row_sb, src=delta_row_psum) + + # Broadcast (1, 128) SBUF -> (128, 128) SBUF via nc_stream_shuffle + delta_broadcast = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=delta_row_sb[0:1, 0:P_MAX], + dst=delta_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + outer_prod = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=outer_prod, + 
data=delta_broadcast, + op0=nl.multiply, + operand0=k_t, + engine=nisa.vector_engine, + ) + + state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=state_new, data1=state, data2=outer_prod, op=nl.add) + nisa.tensor_copy(dst=state, src=state_new) + + # ---- Step 5: o_t = state^T @ q_t ---- + o_t_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=o_t_psum, stationary=state, moving=q_t) + o_t = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=o_t, src=o_t_psum) + + # ---- Store output for token t ---- + nisa.dma_copy( + dst=output.ap(pattern=[[1, dim]], offset=tok_offset), + src=o_t, + ) + + # ---- Write final state to HBM ---- + # state is (128, 128) in SBUF, copy to final_state in HBM + # Use dma_copy with full tile: P_MAX rows, dim cols + nisa.dma_copy(dst=final_state, src=state) + + return output, final_state diff --git a/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_chunked.py b/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_chunked.py new file mode 100644 index 00000000..f834c969 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_chunked.py @@ -0,0 +1,546 @@ +"""NKI per-chunk DeltaNet kernel for CTE (context encoding / prefill). + +Single-chunk kernel: processes one chunk (128 tokens) with a stable +triangular solve for intra-chunk correction. The caller loops over chunks in +PyTorch, passing state between calls. + +Each kernel call: + - Takes one chunk of data: q, k, v, beta, g_cumsum, g_last (all 128x128) + - Takes recurrent state_in (128x128) + - Returns chunk output (128x128) and state_out (128x128) + +No sequence-indexed DMA inside the kernel -- all inputs/outputs are full tiles. +This avoids the DMA OOB issue seen with nl.sequential_range + slice indexing +in the NxDI model compilation context. + +NKI v3 (SDK 2.29, NKI 0.3.0). Uses nki.* namespace. 
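Illustrative aside (not part of the kernel): a minimal NumPy sketch, under stated assumptions, of why a row-by-row solve can replace the Neumann product for the intra-chunk correction. For a strictly lower triangular A, forward substitution on N = I + A N yields inv(I - A), the same matrix as (I+A)(I+A^2)(I+A^4)... The helper names `forward_substitution` and `neumann_doubling` are hypothetical, and the sketch uses one plain solve rather than the two 64x64 blocks plus merge used by the kernel.

```python
import numpy as np


def forward_substitution(A):
    # Solve N = I + A @ N one row at a time.
    # A is strictly lower triangular, so row i only needs rows < i of N.
    n = A.shape[0]
    N = np.zeros((n, n))
    for i in range(n):
        N[i] = A[i] @ N
        N[i, i] += 1.0
    return N


def neumann_doubling(A):
    # Product (I + A)(I + A^2)(I + A^4)... via repeated squaring of A.
    # Invariant: N equals the sum of A^0 .. A^(covered - 1).
    n = A.shape[0]
    N = np.eye(n)
    power = A.copy()
    covered = 1
    while covered < n:
        N = N + N @ power
        power = power @ power
        covered *= 2
    return N


rng = np.random.default_rng(0)
A = np.tril(0.1 * rng.standard_normal((16, 16)), k=-1)
reference = np.linalg.inv(np.eye(16) - A)
assert np.allclose(forward_substitution(A), reference)
assert np.allclose(neumann_doubling(A), reference)
```

Both routes agree in exact arithmetic because A^n = 0 for a strictly lower triangular n x n matrix. The row-by-row form never forms high powers of A, which is the property the kernel comments rely on.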
+""" + +import nki +import nki.isa as nisa +import nki.language as nl + +P_MAX = 128 + +# Broadcast partition 0 to all partitions in a 32-wide group. +_BROADCAST_MASK = [0] * 32 + + +@nki.jit +def deltanet_chunk_step( + query, # (128, 128) float32 -- one chunk, l2-normed+scaled + key, # (128, 128) float32 -- one chunk, l2-normed + value, # (128, 128) float32 -- one chunk + beta_broadcast, # (128, 128) float32 -- write gate broadcast to 128 + g_cumsum, # (128, 128) float32 -- cumsum of g within chunk, broadcast + g_last, # (128, 128) float32 -- g_cumsum[-1], constant in chunk, broadcast + state_in, # (128, 128) float32 -- recurrent state from previous chunk + lower_mask, # (128, 128) float32 -- strict lower triangular + identity, # (128, 128) float32 -- identity matrix + lower_mask_diag, # (128, 128) float32 -- lower tri with diagonal +): + """Process one chunk of DeltaNet. + + Returns: + output: (128, 128) float32 -- chunk output + state_out: (128, 128) float32 -- updated recurrent state + """ + C, dim = query.shape # C = 128, dim = 128 + + # Output tensors in HBM + output = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.shared_hbm) + state_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm) + + # Load all inputs into SBUF + q_c = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=q_c, src=query) + + k_c = nl.ndarray((P_MAX, dim), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=k_c, src=key) + + v_c = nl.ndarray((P_MAX, dim), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=v_c, src=value) + + beta_c = nl.ndarray((P_MAX, dim), dtype=beta_broadcast.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=beta_c, src=beta_broadcast) + + gc_c = nl.ndarray((P_MAX, dim), dtype=g_cumsum.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=gc_c, src=g_cumsum) + + gl_c = nl.ndarray((P_MAX, dim), dtype=g_last.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=gl_c, src=g_last) + + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, 
buffer=nl.sbuf) + nisa.dma_copy(dst=state, src=state_in) + + # Load masks + eye = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=eye, src=identity) + + Lmask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask, src=lower_mask) + + Lmask_d = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask_d, src=lower_mask_diag) + + # ============================================================ + # k_beta = K * beta, v_beta = V * beta + # ============================================================ + k_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=k_beta, data1=k_c, data2=beta_c, op=nl.multiply) + + v_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_beta, data1=v_c, data2=beta_c, op=nl.multiply) + + # ============================================================ + # Stable decay factors from cumulative log-decay + # + # The caller passes g_cumsum and g_last broadcast to (128, 128). Extract + # one column and build pairwise decays as exp(gc[i] - gc[j]) so no + # individual exp(-gc[j]) term can overflow. 
+ # ============================================================ + gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=gc_p[0:P_MAX, 0:1], src=gc_c[0:P_MAX, 0:1]) + + gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=gl_p[0:P_MAX, 0:1], src=gl_c[0:P_MAX, 0:1]) + + exp_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + exp_gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_p[0:P_MAX, 0:1], + op=nl.exp, + data=gl_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + gc_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=gc_padded, value=0.0) + nisa.tensor_copy(dst=gc_padded[0:P_MAX, 0:1], src=gc_p[0:P_MAX, 0:1]) + + gc_row_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=gc_row_psum, data=gc_padded) + + gc_row = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=gc_row[0:1, 0:P_MAX], src=gc_row_psum[0:1, 0:P_MAX]) + + gc_row_broadcast = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gc_row[0:1, 0:P_MAX], + dst=gc_row_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + gc_col_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=gc_col_strict, + data=Lmask, + op0=nl.multiply, + operand0=gc_p, + engine=nisa.vector_engine, + ) + gc_row_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gc_row_strict, data1=gc_row_broadcast, data2=Lmask, op=nl.multiply + ) + g_diff_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=g_diff_strict, + data1=gc_col_strict, + 
data2=gc_row_strict, + op=nl.subtract, + ) + decay_strict_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=decay_strict_raw, + op=nl.exp, + data=g_diff_strict, + bias=None, + scale=1.0, + ) + decay_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=decay_strict, data1=decay_strict_raw, data2=Lmask, op=nl.multiply + ) + + gc_col_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=gc_col_diag, + data=Lmask_d, + op0=nl.multiply, + operand0=gc_p, + engine=nisa.vector_engine, + ) + gc_row_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gc_row_diag, data1=gc_row_broadcast, data2=Lmask_d, op=nl.multiply + ) + g_diff_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=g_diff_diag, + data1=gc_col_diag, + data2=gc_row_diag, + op=nl.subtract, + ) + decay_diag_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=decay_diag_raw, + op=nl.exp, + data=g_diff_diag, + bias=None, + scale=1.0, + ) + decay_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=decay_diag, data1=decay_diag_raw, data2=Lmask_d, op=nl.multiply + ) + + # ============================================================ + # Phase 1: Build A matrix (intra-chunk correction) + # QK = k_beta @ k^T -- contract over features + # ============================================================ + kb_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kb_T_psum, stationary=k_beta, moving=eye) + kb_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kb_T, src=kb_T_psum) + + k_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=k_T_psum, stationary=k_c, moving=eye) + k_T = nl.ndarray((P_MAX, P_MAX), 
dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_T, src=k_T_psum) + + QK_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=QK_psum, stationary=kb_T, moving=k_T) + QK = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK, src=QK_psum) + + # QK_decay[i,j] = QK[i,j] * exp(gc[i] - gc[j]) for i > j. + QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=QK_decay, data1=QK, data2=decay_strict, op=nl.multiply) + + # A = -QK_decay * lower_mask + neg_QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=neg_QK_decay, + data=QK_decay, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + A_mat = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=A_mat, data1=neg_QK_decay, data2=Lmask, op=nl.multiply) + + # ============================================================ + # Stable triangular solve: N = inv(I - A_mat) + # + # A_mat is strictly lower triangular. Solve two 64x64 diagonal blocks + # row-by-row, then merge the lower-left block. This is equivalent to the + # nilpotent Neumann series but avoids repeated squaring of A. 
+ # ============================================================ + P_acc = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=P_acc, value=0.0) + + A_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=A_T_psum, data=A_mat) + A_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=A_T, src=A_T_psum) + + col_mask_left_row = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=col_mask_left_row, value=0.0) + nisa.memset(dst=col_mask_left_row[0:1, 0:64], value=1.0) + col_mask_left = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=col_mask_left_row[0:1, 0:P_MAX], + dst=col_mask_left[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + col_mask_right_row = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=col_mask_right_row, value=0.0) + nisa.memset(dst=col_mask_right_row[0:1, 64:P_MAX], value=1.0) + col_mask_right = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=col_mask_right_row[0:1, 0:P_MAX], + dst=col_mask_right[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + block_row_mask_bottom = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=block_row_mask_bottom[0:P_MAX, 0:1], + src=Lmask_d[0:P_MAX, 64:65], + ) + + for solve_i in nl.static_range(64): + row_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=row_psum, stationary=A_T, moving=P_acc) + row_prod = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=row_prod, src=row_psum) + + row_with_eye = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=row_with_eye, data1=row_prod, data2=eye, 
op=nl.add) + + row_col_masked = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=row_col_masked, + data1=row_with_eye, + data2=col_mask_left, + op=nl.multiply, + ) + + row_mask = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=row_mask[0:P_MAX, 0:1], + src=eye[0:P_MAX, solve_i : solve_i + 1], + ) + row_update = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=row_update, + data=row_col_masked, + op0=nl.multiply, + operand0=row_mask, + engine=nisa.vector_engine, + ) + + P_next = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=P_next, data1=P_acc, data2=row_update, op=nl.add) + nisa.tensor_copy(dst=P_acc, src=P_next) + + for solve_i in nl.static_range(64): + row_idx = 64 + solve_i + + row_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=row_psum, stationary=A_T, moving=P_acc) + row_prod = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=row_prod, src=row_psum) + + row_with_eye = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=row_with_eye, data1=row_prod, data2=eye, op=nl.add) + + row_col_masked = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=row_col_masked, + data1=row_with_eye, + data2=col_mask_right, + op=nl.multiply, + ) + + row_mask = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=row_mask[0:P_MAX, 0:1], + src=eye[0:P_MAX, row_idx : row_idx + 1], + ) + row_update = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=row_update, + data=row_col_masked, + op0=nl.multiply, + operand0=row_mask, + engine=nisa.vector_engine, + ) + + P_next = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=P_next, data1=P_acc, data2=row_update, op=nl.add) + 
nisa.tensor_copy(dst=P_acc, src=P_next) + + N_diag_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=N_diag_T_psum, data=P_acc) + N_diag_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=N_diag_T, src=N_diag_T_psum) + + tmp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=tmp_psum, stationary=N_diag_T, moving=A_mat) + tmp = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=tmp, src=tmp_psum) + + tmp_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=tmp_T_psum, data=tmp) + tmp_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=tmp_T, src=tmp_T_psum) + + N21_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=N21_psum, stationary=tmp_T, moving=P_acc) + N21 = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=N21, src=N21_psum) + + N21_col_masked = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=N21_col_masked, + data1=N21, + data2=col_mask_left, + op=nl.multiply, + ) + N21_block = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=N21_block, + data=N21_col_masked, + op0=nl.multiply, + operand0=block_row_mask_bottom, + engine=nisa.vector_engine, + ) + nisa.tensor_tensor(dst=P_acc, data1=P_acc, data2=N21_block, op=nl.add) + + # ============================================================ + # Apply N: value_corr = N @ v_beta, k_cumdecay = N @ (k_beta * exp_gc) + # ============================================================ + N_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=N_T_psum, data=P_acc) + N_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=N_T, src=N_T_psum) + + vc_psum = nl.ndarray((P_MAX, dim), 
dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=vc_psum, stationary=N_T, moving=v_beta) + value_corr = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=value_corr, src=vc_psum) + + kb_exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=kb_exp_gc, + data=k_beta, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + kcd_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kcd_psum, stationary=N_T, moving=kb_exp_gc) + k_cumdecay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_cumdecay, src=kcd_psum) + + # ============================================================ + # Phase 2: Inter-chunk state propagation + # attn_intra = (q @ k^T) * decay_mask * lower_mask_diag + # ============================================================ + q_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=q_T_psum, stationary=q_c, moving=eye) + q_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=q_T, src=q_T_psum) + + qk_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=qk_psum, stationary=q_T, moving=k_T) + qk_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_raw, src=qk_psum) + + attn_intra = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=attn_intra, data1=qk_raw, data2=decay_diag, op=nl.multiply) + + # ============================================================ + # v_prime = k_cumdecay @ state + # ============================================================ + kcd_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kcd_T_psum, stationary=k_cumdecay, moving=eye) + kcd_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kcd_T, src=kcd_T_psum) + + 
vp_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=vp_psum, stationary=kcd_T, moving=state) + v_prime = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=v_prime, src=vp_psum) + + v_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_new, data1=value_corr, data2=v_prime, op=nl.subtract) + + # ============================================================ + # attn_inter = (q * exp(g_cumsum)) @ state + # ============================================================ + q_exp = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_exp, + data=q_c, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + qe_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=qe_T_psum, stationary=q_exp, moving=eye) + qe_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qe_T, src=qe_T_psum) + + ai_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=ai_psum, stationary=qe_T, moving=state) + attn_inter = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=attn_inter, src=ai_psum) + + # ============================================================ + # attn_intra @ v_new + # ============================================================ + ai_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=ai_T_psum, stationary=attn_intra, moving=eye) + ai_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=ai_T, src=ai_T_psum) + + intra_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=intra_psum, stationary=ai_T, moving=v_new) + intra_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=intra_out, src=intra_psum) + + # 
============================================================ + # chunk_output = attn_inter + intra_out + # ============================================================ + chunk_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=chunk_out, data1=attn_inter, data2=intra_out, op=nl.add) + + nisa.dma_copy(dst=output, src=chunk_out) + + # ============================================================ + # State update: state_new = state * exp(g_last) + # + (k * exp(g_last - gc))^T @ v_new + # ============================================================ + gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gl_minus_gc_p, + data1=gl_p, + data2=gc_p, + op=nl.subtract, + ) + exp_gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_minus_gc_p, + op=nl.exp, + data=gl_minus_gc_p, + bias=None, + scale=1.0, + ) + + k_raw_decay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_raw_decay, + data=k_c, + op0=nl.multiply, + operand0=exp_gl_minus_gc_p, + engine=nisa.vector_engine, + ) + + kv_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_psum, stationary=k_raw_decay, moving=v_new) + kv_outer = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_outer, src=kv_psum) + + state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=state_decayed, + data=state, + op0=nl.multiply, + operand0=exp_gl_p, + engine=nisa.vector_engine, + ) + + state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=state_new, data1=state_decayed, data2=kv_outer, op=nl.add) + + nisa.dma_copy(dst=state_out, src=state_new) + + return output, state_out diff --git a/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py 
b/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py new file mode 100644 index 00000000..6008ae5a --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py @@ -0,0 +1,595 @@ +"""Fused single-kernel DeltaNet chunked forward for CTE (context encoding). + +SSD-style architecture: processes ALL chunks for one (batch, head) pair in +a single NKI kernel call. State (128x128) persists in SBUF across chunks — +no HBM round-trips for inter-chunk state propagation. + +Key optimizations over nki_deltanet_chunked.py: + 1. Single kernel call per (B,H) instead of B*H*num_chunks calls + 2. State in SBUF across all chunks (no HBM state read/write per chunk) + 3. In-kernel cumsum via tensor_tensor_scan (no PyTorch cumsum) + 4. Masks and constants loaded once, reused across chunks + 5. Uses tensor_scalar for partition-broadcast (no explicit broadcast loops) + 6. nc_transpose (Vector Engine) for all 128x128 transposes instead of + nc_matmul(moving=eye) (Tensor Engine) — frees TE for actual math + +NKI 0.3.0 (SDK 2.29). k_dim = v_dim = 128 = P_MAX exactly. +Chunk size = 128 = P_MAX (one tile per chunk). 
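Illustrative aside (not part of the kernel): the in-kernel cumsum from item 3 can be read as a first-order scan, out[t] = data0[t] * out[t-1] + data1[t], with data0 set to ones. The NumPy sketch below only emulates that recurrence on the host and does not call NKI. The helper name `scan_multiply_add` and the sample decay values are assumptions made for illustration.

```python
import numpy as np


def scan_multiply_add(data0, data1, initial=0.0):
    # Emulates out[t] = data0[t] * out[t-1] + data1[t] along one row.
    out = np.empty_like(data1, dtype=np.float64)
    prev = initial
    for t in range(data1.shape[0]):
        prev = data0[t] * prev + data1[t]
        out[t] = prev
    return out


# Per-token log-decay for a tiny 4-token "chunk" (illustrative values).
g = np.log(np.array([0.9, 0.8, 0.7, 0.6]))
gc = scan_multiply_add(np.ones_like(g), g)

# With data0 == 1 the scan reduces to a cumulative sum.
assert np.allclose(gc, np.cumsum(g))
# exp(gc[-1]) is the total decay applied to the state across the chunk.
assert np.isclose(np.exp(gc[-1]), 0.9 * 0.8 * 0.7 * 0.6)
```

The last element of the scan plays the role of g_last, which is why the kernel can derive the state decay from the same row without a second pass.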
+ +Mathematical framework (same as nki_deltanet_chunked.py): + Per-chunk Neumann-series power-doubling for intra-chunk correction: + A = -QK_decay * lower_mask + N = (I+A)(I+A^2)(I+A^4)...(I+A^64) [6 rounds] + value_corr = N @ v_beta + k_cumdecay = N @ (k_beta * exp(gc)) + + Inter-chunk state propagation: + v_prime = k_cumdecay @ state + v_new = value_corr - v_prime + attn_inter = (q * exp(gc)) @ state + attn_intra = (q @ k^T) * decay_mask * lower_mask_diag + output = attn_inter + attn_intra @ v_new + state = exp(g_last) * (state + k_raw_decay^T @ v_new) +""" + +import numpy as np + +import nki +import nki.isa as nisa +import nki.language as nl + +P_MAX = 128 # Partition dim = chunk_size = k_dim = v_dim +CHUNK_SIZE = 128 + +# Broadcast partition 0 to all partitions in a 32-wide group +_BROADCAST_MASK = [0] * 32 + + +def _make_lower_mask(): + """Strict lower triangular (128x128) as numpy constant.""" + return np.tril(np.ones((CHUNK_SIZE, CHUNK_SIZE), dtype=np.float32), k=-1) + + +def _make_lower_mask_diag(): + """Lower triangular with diagonal (128x128) as numpy constant.""" + return np.tril(np.ones((CHUNK_SIZE, CHUNK_SIZE), dtype=np.float32), k=0) + + +def _make_identity(): + """Identity matrix (128x128) as numpy constant.""" + return np.eye(CHUNK_SIZE, dtype=np.float32) + + +@nki.jit +def deltanet_fused_chunked_fwd( + query: nl.ndarray, # (S, 128) float32 — l2-normed and scaled + key: nl.ndarray, # (S, 128) float32 — l2-normed + value: nl.ndarray, # (S, 128) float32 + g_in: nl.ndarray, # (S, 1) float32 — per-token log-decay (NOT cumsum) + beta_in: nl.ndarray, # (S, 1) float32 — per-token write gate + lower_mask: nl.ndarray, # (128, 128) float32 — strict lower tri + identity: nl.ndarray, # (128, 128) float32 — identity + lower_mask_diag: nl.ndarray, # (128, 128) float32 — lower tri with diag +): + """Fused chunked DeltaNet forward — single kernel call per (batch, head). 
+ + Processes all chunks sequentially within the kernel, keeping the recurrent + state (128x128) in SBUF across chunks. Returns per-token output and + final state. + + Input requirements: + - S must be divisible by 128 (pad before calling) + - query must be l2-normed and scaled by 1/sqrt(k_dim) + - key must be l2-normed + - g_in is RAW log-decay (cumsum computed in-kernel via tensor_tensor_scan) + - beta_in is sigmoid(b) (write gate) + + Returns: + output: (S, 128) float32 + final_state: (128, 128) float32 + """ + seq_len = query.shape[0] + dim = query.shape[1] # 128 + num_chunks = seq_len // CHUNK_SIZE + + # Output tensors in HBM + output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm) + final_state_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm) + + # ================================================================ + # Load constant masks into SBUF once (reused across all chunks) + # ================================================================ + eye = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=eye, src=identity) + + Lmask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask, src=lower_mask) + + Lmask_d = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask_d, src=lower_mask_diag) + + # Ones vector for cumsum scan: (1, CHUNK_SIZE) + ones_1xC = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=ones_1xC, value=1.0) + + # Zero initial for cumsum scan + zero_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=zero_11, value=0.0) + + # ================================================================ + # Initialize recurrent state in SBUF — persists across ALL chunks + # ================================================================ + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=state, value=0.0) + + # 
================================================================ + # Sequential chunk processing + # ================================================================ + for i_chunk in nl.sequential_range(num_chunks): + chunk_start = i_chunk * CHUNK_SIZE + + # ---- Load chunk data from HBM ---- + q_c = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=q_c, + src=query[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + k_c = nl.ndarray((P_MAX, dim), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=k_c, + src=key[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + v_c = nl.ndarray((P_MAX, dim), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=v_c, + src=value[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + # g: (CHUNK_SIZE, 1) — raw log-decay per token + g_chunk_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=g_chunk_p[0:CHUNK_SIZE, 0:1], + src=g_in[chunk_start : chunk_start + CHUNK_SIZE, 0:1], + ) + + # beta: (CHUNK_SIZE, 1) — write gate scalar per token + beta_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=beta_p[0:CHUNK_SIZE, 0:1], + src=beta_in[chunk_start : chunk_start + CHUNK_SIZE, 0:1], + ) + + # ---- In-kernel cumsum of g via tensor_tensor_scan ---- + # Need g as (1, CHUNK_SIZE) for scan along free dim. 
+ # Transpose: (CHUNK_SIZE, 1) -> (1, CHUNK_SIZE) via nc_transpose + g_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=g_padded, value=0.0) + nisa.tensor_copy( + dst=g_padded[0:CHUNK_SIZE, 0:1], + src=g_chunk_p[0:CHUNK_SIZE, 0:1], + ) + + g_tp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=g_tp_psum, data=g_padded) + + g_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=g_row[0:1, 0:CHUNK_SIZE], + src=g_tp_psum[0:1, 0:CHUNK_SIZE], + ) + + # cumsum: gc_row[t] = 1.0 * gc_row[t-1] + g_row[t] + gc_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor_scan( + dst=gc_row[0:1, 0:CHUNK_SIZE], + data0=ones_1xC[0:1, 0:CHUNK_SIZE], + data1=g_row[0:1, 0:CHUNK_SIZE], + initial=zero_11[0:1, 0:1], + op0=nl.multiply, + op1=nl.add, + ) + + # Transpose gc back to (CHUNK_SIZE, 1) partition layout + gc_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=gc_padded, value=0.0) + nisa.tensor_copy( + dst=gc_padded[0:1, 0:CHUNK_SIZE], + src=gc_row[0:1, 0:CHUNK_SIZE], + ) + + gc_tp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=gc_tp_psum, data=gc_padded) + + # gc_p: (P_MAX, 1) — cumulative sum of g per token in this chunk + gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=gc_p[0:CHUNK_SIZE, 0:1], + src=gc_tp_psum[0:CHUNK_SIZE, 0:1], + ) + + # g_last = gc[-1] (scalar) — needed for state decay + gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=gl_11[0:1, 0:1], + src=gc_row[0:1, CHUNK_SIZE - 1 : CHUNK_SIZE], + ) + + # ---- Compute exp(gc), exp(-gc), exp(g_last) as (P_MAX, 1) scalars ---- + # These (P_MAX, 1) tensors are used with tensor_scalar to broadcast + # across the free dimension without explicit (P_MAX, dim) copies. 
+ + exp_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + neg_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=neg_gc_p, + data=gc_p, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + exp_neg_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_neg_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=neg_gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + # exp(g_last): scalar, then broadcast to (P_MAX, 1) + exp_gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_11, + op=nl.exp, + data=gl_11, + bias=None, + scale=1.0, + ) + + exp_gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=exp_gl_11[0:1, 0:1], + dst=exp_gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) + + # ============================================================ + # k_beta = K * beta, v_beta = V * beta + # tensor_scalar broadcasts beta_p (P_MAX, 1) across free dim + # ============================================================ + k_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_beta, + data=k_c, + op0=nl.multiply, + operand0=beta_p, + engine=nisa.vector_engine, + ) + + v_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=v_beta, + data=v_c, + op0=nl.multiply, + operand0=beta_p, + engine=nisa.vector_engine, + ) + + # ============================================================ + # Phase 1: Build A matrix (intra-chunk correction) + # Transpose K and K_beta for matmul + # ============================================================ + kb_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + 
nisa.nc_transpose(dst=kb_T_psum, data=k_beta) + kb_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kb_T, src=kb_T_psum) + + k_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=k_T_psum, data=k_c) + k_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_T, src=k_T_psum) + + # QK = k_beta @ k^T, shape (tokens, tokens); nc_matmul contracts over the feature dim + QK_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=QK_psum, stationary=kb_T, moving=k_T) + QK = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK, src=QK_psum) + + # ============================================================ + # Decay mask: QK_decay[i,j] = QK[i,j] * exp(gc[i]) * exp(-gc[j]) + # + # Apply the strict causal mask before the split exp(gc) / exp(-gc) + # scaling. Upper-triangular entries are mathematically unused, but + # scaling them first can create very large finite values that poison + # later matmuls before the mask is applied. + # ============================================================ + QK_masked = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=QK_masked, data1=QK, data2=Lmask, op=nl.multiply) + + # Row scaling: QK_row[i,:] = QK_masked[i,:] * exp(gc[i]) + # Then transpose, column scale, transpose back. + # Uses tensor_scalar with (P_MAX,1) operand for row scaling. 
+ QK_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=QK_row, + data=QK_masked, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + # Transpose to scale columns (now rows in transposed view) + QK_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=QK_r_T_psum, data=QK_row) + QK_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK_r_T, src=QK_r_T_psum) + + QK_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=QK_r_T_col, + data=QK_r_T, + op0=nl.multiply, + operand0=exp_neg_gc_p, + engine=nisa.vector_engine, + ) + + # Transpose back + QK_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=QK_d_psum, data=QK_r_T_col) + QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK_decay, src=QK_d_psum) + + # A = -QK_decay * lower_mask + neg_QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=neg_QK_decay, + data=QK_decay, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + A_mat = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=A_mat, data1=neg_QK_decay, data2=Lmask, op=nl.multiply) + + # ============================================================ + # Neumann power-doubling: N = (I+A)(I+A^2)(I+A^4)...(I+A^{64}) + # = I + A + A^2 + ... + A^{127} + # 6 squaring rounds cover every power of A below 2^7 = 128. Because A is + # strictly lower-triangular, A^{128} = 0 for chunk=128, so the truncated + # series equals (I - A)^{-1} exactly. + # ============================================================ + P_acc = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=P_acc, data1=eye, data2=A_mat, op=nl.add) + + A_pow = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=A_pow, src=A_mat) + + for _round in nl.sequential_range(6): + # A_pow = A_pow^2: transpose A_pow, then matmul + Ap_T_psum = 
nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=Ap_T_psum, data=A_pow) + Ap_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=Ap_T, src=Ap_T_psum) + + Ap_sq_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=Ap_sq_psum, stationary=Ap_T, moving=A_pow) + nisa.tensor_copy(dst=A_pow, src=Ap_sq_psum) + + # P_acc = (I + A_pow) @ P_acc: transpose IpA, then matmul + IpA = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=IpA, data1=eye, data2=A_pow, op=nl.add) + + IpA_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=IpA_T_psum, data=IpA) + IpA_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=IpA_T, src=IpA_T_psum) + + Pacc_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=Pacc_psum, stationary=IpA_T, moving=P_acc) + nisa.tensor_copy(dst=P_acc, src=Pacc_psum) + + # ============================================================ + # Apply N: value_corr = N @ v_beta + # k_cumdecay = N @ (k_beta * exp(gc)) + # ============================================================ + N_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=N_T_psum, data=P_acc) + N_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=N_T, src=N_T_psum) + + vc_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=vc_psum, stationary=N_T, moving=v_beta) + value_corr = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=value_corr, src=vc_psum) + + # k_beta * exp(gc): row-scaled + kb_exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=kb_exp_gc, + data=k_beta, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + kcd_psum = nl.ndarray((P_MAX, 
dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kcd_psum, stationary=N_T, moving=kb_exp_gc) + k_cumdecay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_cumdecay, src=kcd_psum) + + # ============================================================ + # Phase 2: Inter-chunk state propagation + # attn_intra = (q @ k^T) * decay_mask * lower_mask_diag + # ============================================================ + q_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=q_T_psum, data=q_c) + q_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=q_T, src=q_T_psum) + + qk_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=qk_psum, stationary=q_T, moving=k_T) + qk_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_raw, src=qk_psum) + + # Mask before split scaling for the same reason as the A matrix above: + # upper-triangular decay factors are unused and can be numerically huge. 
+ qk_masked = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=qk_masked, data1=qk_raw, data2=Lmask_d, op=nl.multiply) + + # Row-scale by exp(gc) + qk_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=qk_row, + data=qk_masked, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + # Transpose, column-scale by exp(-gc), transpose back + qk_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=qk_r_T_psum, data=qk_row) + qk_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_r_T, src=qk_r_T_psum) + + qk_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=qk_r_T_col, + data=qk_r_T, + op0=nl.multiply, + operand0=exp_neg_gc_p, + engine=nisa.vector_engine, + ) + + qk_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=qk_d_psum, data=qk_r_T_col) + qk_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_decay, src=qk_d_psum) + + attn_intra = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=attn_intra, data1=qk_decay, data2=Lmask_d, op=nl.multiply + ) + + # ============================================================ + # v_prime = k_cumdecay @ state (state is in SBUF!) 
+ # ============================================================ + kcd_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=kcd_T_psum, data=k_cumdecay) + kcd_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kcd_T, src=kcd_T_psum) + + vp_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=vp_psum, stationary=kcd_T, moving=state) + v_prime = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=v_prime, src=vp_psum) + + v_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_new, data1=value_corr, data2=v_prime, op=nl.subtract) + + # ============================================================ + # attn_inter = (q * exp(gc)) @ state (state is in SBUF!) + # ============================================================ + q_exp = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_exp, + data=q_c, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + qe_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=qe_T_psum, data=q_exp) + qe_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qe_T, src=qe_T_psum) + + ai_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=ai_psum, stationary=qe_T, moving=state) + attn_inter = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=attn_inter, src=ai_psum) + + # ============================================================ + # attn_intra @ v_new + # ============================================================ + ai_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=ai_T_psum, data=attn_intra) + ai_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=ai_T, src=ai_T_psum) + + intra_psum 
= nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=intra_psum, stationary=ai_T, moving=v_new) + intra_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=intra_out, src=intra_psum) + + # ============================================================ + # chunk_output = attn_inter + intra_out + # ============================================================ + chunk_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=chunk_out, data1=attn_inter, data2=intra_out, op=nl.add) + + # Store output chunk to HBM + nisa.dma_copy( + dst=output[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + src=chunk_out, + ) + + # ============================================================ + # State update: state = exp(g_last) * (state + k_raw_decay^T @ v_new) + # state is updated IN-PLACE in SBUF — no HBM round-trip! + # ============================================================ + + # k_raw_decay contributes as exp(g_last) * (k * exp(-gc))^T @ v_new. + # Compute the equivalent stable form k * exp(g_last - gc), so the + # factor is always <= 1 for valid causal positions. 
+ exp_gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=exp_gl_minus_gc_p, + data1=exp_gl_p, + data2=exp_neg_gc_p, + op=nl.multiply, + ) + + # k_raw_decay = k * exp(g_last - gc) + k_raw_decay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_raw_decay, + data=k_c, + op0=nl.multiply, + operand0=exp_gl_minus_gc_p, + engine=nisa.vector_engine, + ) + + # k_raw_decay^T @ v_new → (dim, dim) outer product sum + # nc_matmul: result[M,N] = sum_K stationary[K,M] * moving[K,N] + # stationary=k_raw_decay (P_MAX, dim), moving=v_new (P_MAX, dim) + # Result: sum over tokens -> (dim, dim) + kv_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_psum, stationary=k_raw_decay, moving=v_new) + kv_outer = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_outer, src=kv_psum) + + # state = state * exp(g_last) + kv_outer + # tensor_scalar broadcasts exp_gl_p (P_MAX, 1) across free dim. 
+ state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=state_decayed, + data=state, + op0=nl.multiply, + operand0=exp_gl_p, + engine=nisa.vector_engine, + ) + nisa.tensor_tensor(dst=state, data1=state_decayed, data2=kv_outer, op=nl.add) + + # ---- Write final state to HBM ---- + nisa.dma_copy(dst=final_state_out, src=state) + + return output, final_state_out diff --git a/contrib/models/Qwen3.6-27B/test/__init__.py b/contrib/models/Qwen3.6-27B/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Qwen3.6-27B/test/integration/__init__.py b/contrib/models/Qwen3.6-27B/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py b/contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py new file mode 100644 index 00000000..1c06f89f --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python3 +"""Compile Qwen3.6-27B 64K with a scoped FP8 weight-quantization ablation. + +This script intentionally starts from the validated 64K hybrid/chunked-prefill +baseline and changes only weight quantization. The first supported mode is +``mlp_only``: MLP linear weights are converted to FP8 while attention, DeltaNet, +normalization, embeddings, lm_head, KV cache, and recurrent state remain BF16. 
+""" + +from __future__ import annotations + +import argparse +import gc +import json +import os +import sys +from pathlib import Path + +import torch + + +def _repo_root(path: str | None) -> Path: + if path: + return Path(path).expanduser().resolve() + return Path(__file__).resolve().parents[5] + + +def _load_text_config(model_path: Path) -> dict: + with (model_path / "config.json").open() as f: + full_config = json.load(f) + text_config = full_config.get("text_config", full_config) + config_dict = dict(text_config) + config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044) + if "rope_parameters" in text_config: + config_dict["rope_theta"] = text_config["rope_parameters"].get( + "rope_theta", 10000000 + ) + config_dict.setdefault("tie_word_embeddings", False) + return config_dict + + +def _mlp_only_modules_to_not_convert(num_layers: int) -> list[str]: + """Exclude numerically sensitive or unsupported modules from FP8 conversion.""" + modules = [ + "embed_tokens", + "model.embed_tokens", + "lm_head", + "norm", + "model.norm", + "rotary_emb", + "model.rotary_emb", + ] + for layer_idx in range(num_layers): + for prefix in ("layers", "model.layers"): + modules.extend( + [ + f"{prefix}.{layer_idx}.self_attn", + f"{prefix}.{layer_idx}.linear_attn", + f"{prefix}.{layer_idx}.input_layernorm", + f"{prefix}.{layer_idx}.post_attention_layernorm", + ] + ) + return modules + + +def _quantized_checkpoint_ready(path: Path) -> bool: + if path.is_file(): + return True + if path.is_dir(): + return any(path.iterdir()) + return False + + +def _is_mlp_weight(name: str) -> bool: + parts = name.split(".") + return ( + len(parts) >= 4 + and parts[-3] == "mlp" + and parts[-2] in {"gate_proj", "up_proj", "down_proj"} + and parts[-1] == "weight" + ) + + +def _scale_name(weight_name: str) -> str: + return weight_name[: -len(".weight")] + ".weight_scale" + + +def _clear_quantized_checkpoint_dir(path: Path) -> None: + path.mkdir(parents=True, exist_ok=True) + for child in 
path.iterdir(): + if child.name.endswith(".safetensors") or child.name.endswith(".json"): + child.unlink() + + +def _save_mlp_only_fp8_state_dict(model_path: Path, output_path: Path) -> None: + """Create a sharded FP8 checkpoint directly from HF safetensors. + + Loading the HF architecture requires a newer Transformers than the Neuron + venv uses internally. For this MLP-only ablation, we do not need model + execution: the checkpoint transform is a direct tensor rewrite. + """ + from safetensors.torch import load_file, save_file # noqa: WPS433 + from neuronx_distributed.quantization.quantization_utils import ( # noqa: WPS433 + quantize_fp8_per_channel, + ) + + index_path = model_path / "model.safetensors.index.json" + if index_path.exists(): + with index_path.open() as f: + source_index = json.load(f) + source_weight_map = source_index["weight_map"] + filenames = sorted(set(source_weight_map.values())) + elif (model_path / "model.safetensors").exists(): + source_weight_map = None + filenames = ["model.safetensors"] + else: + raise FileNotFoundError(f"No safetensors checkpoint found in {model_path}") + + _clear_quantized_checkpoint_dir(output_path) + output_weight_map: dict[str, str] = {} + total_size = 0 + quantized_count = 0 + + for filename in filenames: + shard = load_file(str(model_path / filename)) + output_shard = {} + for name, tensor in shard.items(): + if _is_mlp_weight(name): + weight, scale = quantize_fp8_per_channel( + tensor, + torch.float8_e4m3fn, + channel_axis=0, + ) + output_shard[name] = weight + output_shard[_scale_name(name)] = scale + output_weight_map[_scale_name(name)] = filename + total_size += weight.numel() * weight.element_size() + total_size += scale.numel() * scale.element_size() + quantized_count += 1 + else: + output_shard[name] = tensor + total_size += tensor.numel() * tensor.element_size() + output_weight_map[name] = filename + + save_file(output_shard, str(output_path / filename), metadata={"format": "pt"}) + del shard + del 
output_shard + gc.collect() + + if source_weight_map is not None: + with (output_path / "model.safetensors.index.json").open("w") as f: + json.dump( + { + "metadata": {"total_size": total_size}, + "weight_map": output_weight_map, + }, + f, + indent=2, + sort_keys=True, + ) + + print("MANUAL_FP8_MLP_WEIGHT_COUNT", quantized_count, flush=True) + + +def _build_config(args: argparse.Namespace): + from neuronx_distributed_inference.models.config import ( # noqa: WPS433 + NeuronConfig, + OnDeviceSamplingConfig, + ) + from src.modeling_qwen35 import Qwen35InferenceConfig # noqa: WPS433 + + model_path = Path(args.model_path).expanduser().resolve() + config_dict = _load_text_config(model_path) + num_layers = int(config_dict["num_hidden_layers"]) + modules_to_not_convert = _mlp_only_modules_to_not_convert(num_layers) + + neuron_config = NeuronConfig( + tp_degree=args.tp_degree, + batch_size=1, + ctx_batch_size=1, + tkg_batch_size=1, + seq_len=args.seq_len, + max_context_length=args.cte_bucket, + max_length=args.seq_len, + context_encoding_buckets=[args.cte_bucket], + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig( + do_sample=False, + top_k=1, + top_p=1.0, + temperature=1.0, + ), + enable_bucketing=False, + logical_nc_config=args.logical_nc_config, + save_sharded_checkpoint=True, + quantized=True, + quantized_checkpoints_path=str( + Path(args.quantized_checkpoints_path).expanduser().resolve() + ), + quantization_type="per_channel_symmetric", + quantization_dtype="f8e4m3", + modules_to_not_convert=modules_to_not_convert, + kv_cache_quant=False, + quantized_mlp_kernel_enabled=False, + activation_quantization_type=None, + ) + + config_dict.setdefault("use_hybrid_cache_manager", True) + config_dict.setdefault("use_qwen_hybrid_chunked_prefill", True) + config_dict.setdefault("use_qwen_hybrid_chunked_prefill_nki", True) + + inf_config = Qwen35InferenceConfig(neuron_config=neuron_config, **config_dict) + return inf_config, modules_to_not_convert + + 
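The `per_channel_symmetric` / `f8e4m3` settings in `_build_config` above, together with the `channel_axis=0` call in `_save_mlp_only_fp8_state_dict`, imply one scale per output row of each MLP weight. The following is a minimal sketch of that idea in plain Python, not the `quantize_fp8_per_channel` implementation. The helper names are hypothetical, the `448.0` limit is the finite maximum of `torch.float8_e4m3fn` (the Neuron library may clamp to a different value), and the rounding of the FP8 cast itself is deliberately not emulated.

```python
from typing import List, Tuple

# Assumption: finite max of torch.float8_e4m3fn; a hardware-specific
# limit may be lower, so it is exposed as a parameter below.
FP8_E4M3_MAX = 448.0


def quantize_per_channel(
    weight: List[List[float]], fp8_max: float = FP8_E4M3_MAX
) -> Tuple[List[List[float]], List[float]]:
    """Scale each row (output channel) so its largest magnitude maps to fp8_max.

    Returns the scaled rows and one scale per row. A real implementation
    would then cast the scaled rows to an FP8 dtype; that step is omitted.
    """
    scaled: List[List[float]] = []
    scales: List[float] = []
    for row in weight:
        max_abs = max(abs(x) for x in row)
        # An all-zero row has no dynamic range; use a neutral scale.
        scale = max_abs / fp8_max if max_abs > 0 else 1.0
        scales.append(scale)
        scaled.append([x / scale for x in row])
    return scaled, scales


def dequantize_per_channel(
    scaled: List[List[float]], scales: List[float]
) -> List[List[float]]:
    """Undo the per-row scaling (symmetric: no zero point)."""
    return [[x * s for x in row] for row, s in zip(scaled, scales)]
```

Because the scheme is symmetric, only the `weight_scale` tensor has to be stored next to each quantized weight, which matches the `.weight_scale` naming produced by `_scale_name`.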
+def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--repo-root", default=None) + parser.add_argument("--model-path", required=True) + parser.add_argument("--compiled-path", required=True) + parser.add_argument("--quantized-checkpoints-path", required=True) + parser.add_argument("--seq-len", type=int, default=65536) + parser.add_argument("--cte-bucket", type=int, default=512) + parser.add_argument("--tp-degree", type=int, default=4) + parser.add_argument("--logical-nc-config", type=int, default=2) + parser.add_argument("--force-quantize", action="store_true") + parser.add_argument("--quantize-only", action="store_true") + parser.add_argument("--load-after-compile", action="store_true") + args = parser.parse_args() + + repo = _repo_root(args.repo_root) + contrib_model_dir = repo / "contrib" / "models" / "Qwen3.6-27B" + sys.path.insert(0, str(repo)) + sys.path.insert(0, str(contrib_model_dir)) + + from src.modeling_qwen35 import NeuronQwen35ForCausalLM # noqa: WPS433 + + model_path = Path(args.model_path).expanduser().resolve() + compiled_path = Path(args.compiled_path).expanduser().resolve() + quantized_path = Path(args.quantized_checkpoints_path).expanduser().resolve() + + inf_config, modules_to_not_convert = _build_config(args) + + print("FP8_MODE mlp_only", flush=True) + print("MODEL_PATH", str(model_path), flush=True) + print("COMPILED_PATH", str(compiled_path), flush=True) + print("QUANTIZED_CHECKPOINTS_PATH", str(quantized_path), flush=True) + print("MODULES_TO_NOT_CONVERT_COUNT", len(modules_to_not_convert), flush=True) + print( + "CONTEXT_TRACE_SHAPE", + json.dumps( + { + "seq_len": args.seq_len, + "max_context_length": args.cte_bucket, + "context_encoding_buckets": [args.cte_bucket], + }, + sort_keys=True, + ), + flush=True, + ) + + if args.force_quantize or not _quantized_checkpoint_ready(quantized_path): + print("QUANTIZE_START manual_mlp_only", flush=True) + _save_mlp_only_fp8_state_dict(model_path, quantized_path) + 
print("QUANTIZE_DONE", flush=True) + else: + print("QUANTIZE_SKIP existing checkpoint found", flush=True) + + if args.quantize_only: + return 0 + + print("COMPILE_START", flush=True) + model = NeuronQwen35ForCausalLM(str(model_path), inf_config) + model.compile(str(compiled_path)) + del model + gc.collect() + print("COMPILE_DONE", flush=True) + + if args.load_after_compile: + model = NeuronQwen35ForCausalLM(str(compiled_path)) + model.load(str(compiled_path)) + print("LOAD_AFTER_COMPILE_OK", flush=True) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/contrib/models/Qwen3.6-27B/test/integration/test_model.py b/contrib/models/Qwen3.6-27B/test/integration/test_model.py new file mode 100644 index 00000000..b1128c12 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/integration/test_model.py @@ -0,0 +1,605 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Integration tests for Qwen3.6-27B on Neuron. + +Tests compilation, loading, inference accuracy, and performance using +the full 27B model with pre-downloaded HuggingFace weights on a trn2 instance. + +Qwen3.6-27B shares identical architecture with Qwen3.5-27B (qwen3_5 model_type). +These tests use the same Qwen35* classes and QWEN35_* env vars because the +underlying code is shared. + +Note: A mini model option is not provided because DeltaNet layers require NKI +kernels that only execute on Neuron devices, and the hybrid DeltaNet + GQA +architecture needs at least TP=4 for the full model to fit in HBM. 
+ +Environment variables: + QWEN35_MODEL_PATH Path to HF model weights (required) + QWEN35_COMPILED_PATH Path to compiled artifacts (default: /tmp/qwen35_27b_traced) + QWEN35_TP_DEGREE Tensor parallelism degree (default: 4) + QWEN35_SEQ_LEN Max sequence length (default: 128) + TTFT_THRESHOLD_MS Max TTFT in ms (default: 5000) + THROUGHPUT_THRESHOLD Min throughput in tok/s (default: 5.0) + +Prerequisites: + - trn2.3xlarge or larger with TP >= 4 NeuronCores available + - NXDI installed (neuronx_distributed_inference) + - HuggingFace weights downloaded to QWEN35_MODEL_PATH + - SDK 2.29+ (NKI 0.3.0 required for DeltaNet kernels) + +Usage: + # Full model (trn2.3xlarge, TP=4): + QWEN35_MODEL_PATH=/mnt/models/Qwen3.6-27B \\ + QWEN35_COMPILED_PATH=/mnt/models/qwen36_traced \\ + pytest test/integration/test_model.py --capture=tee-sys +""" + +import gc +import json +import os +import shutil +import subprocess +import sys +import time + +import pytest +import torch + +# Ensure the contrib root (Qwen3.6-27B/) is on sys.path +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +# ── Configuration from environment ────────────────────────────────────── + +MODEL_PATH = os.environ.get("QWEN35_MODEL_PATH", "") +COMPILED_PATH = os.environ.get("QWEN35_COMPILED_PATH", "/tmp/qwen35_27b_traced") +TP_DEGREE = int(os.environ.get("QWEN35_TP_DEGREE", "4")) +SEQ_LEN = int(os.environ.get("QWEN35_SEQ_LEN", "128")) +TTFT_THRESHOLD_MS = float(os.environ.get("TTFT_THRESHOLD_MS", "5000")) +THROUGHPUT_THRESHOLD = float(os.environ.get("THROUGHPUT_THRESHOLD", "5.0")) +USE_HYBRID_CACHE = os.environ.get("QWEN35_USE_HYBRID_CACHE", "0") == "1" +RECORD_HBM = os.environ.get("QWEN35_RECORD_HBM", "0") == "1" + +requires_model_path = pytest.mark.skipif( + not MODEL_PATH, + reason=( + "QWEN35_MODEL_PATH not set. Integration tests require the full 27B model " + "weights. 
Set QWEN35_MODEL_PATH=/path/to/Qwen3.6-27B to run these tests." + ), +) +requires_hbm_recording = pytest.mark.skipif( + not RECORD_HBM, + reason=( + "QWEN35_RECORD_HBM=1 not set. This optional test records Neuron HBM " + "usage for dummy-KV vs hybrid-cache comparisons." + ), +) + + +# ── Fixtures ──────────────────────────────────────────────────────────── + + +@pytest.fixture(scope="module") +def model_path(): + """Return path to model weights.""" + return MODEL_PATH + + +@pytest.fixture(scope="module") +def compiled_model(model_path): + """Compile and load the model on Neuron.""" + import json + + from neuronx_distributed_inference.models.config import ( + NeuronConfig, + OnDeviceSamplingConfig, + ) + from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM + + neuron_config = NeuronConfig( + tp_degree=TP_DEGREE, + batch_size=1, + ctx_batch_size=1, + tkg_batch_size=1, + seq_len=SEQ_LEN, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + enable_bucketing=False, + flash_decoding_enabled=False, + logical_nc_config=2, + save_sharded_checkpoint=True, + ) + + # Read config.json directly (model_type 'qwen3_5' may not be in + # AutoConfig registry for all transformers versions) + with open(os.path.join(model_path, "config.json")) as f: + full_config = json.load(f) + text_config = full_config.get("text_config", full_config) + + config_dict = dict(text_config) + config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044) + if "rope_parameters" in text_config: + config_dict["rope_theta"] = text_config["rope_parameters"].get( + "rope_theta", 10000000 + ) + config_dict.setdefault("tie_word_embeddings", False) + + inf_config = Qwen35InferenceConfig( + neuron_config=neuron_config, + use_hybrid_cache_manager=USE_HYBRID_CACHE, + **config_dict, + ) + + # Compile if no existing artifacts + compiled_path = COMPILED_PATH + neff_path = os.path.join(compiled_path, "model.pt") + if not os.path.exists(neff_path): + 
print(f"Compiling to {compiled_path}...") + model = NeuronQwen35ForCausalLM(model_path, inf_config) + model.compile(compiled_path) + del model + gc.collect() + + # Load + print(f"Loading from {compiled_path}...") + model = NeuronQwen35ForCausalLM(compiled_path) + model.load(compiled_path) + return model + + +@pytest.fixture(scope="module") +def tokenizer(model_path): + """Load tokenizer.""" + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(model_path, padding_side="right") + if tok.pad_token is None: + tok.pad_token = tok.eos_token + return tok + + +@pytest.fixture(scope="module") +def generation_config(tokenizer): + """Create generation config.""" + from transformers import GenerationConfig + + return GenerationConfig( + do_sample=True, + top_k=1, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + + +def _generate(model, tokenizer, generation_config, prompt, max_new_tokens=20): + """Generate text using the NXDI model.""" + import transformers + + from neuronx_distributed_inference.utils.hf_adapter import ( + HuggingFaceGenerationAdapter, + ) + + inputs = tokenizer(prompt, padding=True, return_tensors="pt") + gen_model = HuggingFaceGenerationAdapter(model) + gen_model.generation_config.transformers_version = transformers.__version__ + generation_config.transformers_version = transformers.__version__ + outputs = gen_model.generate( + inputs.input_ids, + generation_config=generation_config, + attention_mask=inputs.attention_mask, + max_new_tokens=max_new_tokens, + ) + return outputs[0].tolist(), tokenizer.decode(outputs[0], skip_special_tokens=True) + + +def _is_repetitive(text, max_repeat=5): + """Check for excessive word repetition.""" + words = text.split() + if len(words) < max_repeat: + return False + for i in range(len(words) - max_repeat + 1): + if len(set(words[i : i + max_repeat])) == 1: + return True + return False + + +def _parse_peak_neuron_memory(stdout): + peak_device = 0 + peak_tensors = 0 
+ samples = 0 + for line in stdout.splitlines(): + line = line.strip() + if not line: + continue + try: + report = json.loads(line) + except json.JSONDecodeError: + continue + for runtime in report.get("neuron_runtime_data", []): + memory_used = runtime.get("report", {}).get("memory_used", {}) + used = memory_used.get("neuron_runtime_used_bytes", {}) + peak_device = max(peak_device, int(used.get("neuron_device", 0) or 0)) + nc_usage = ( + used.get("usage_breakdown", {}).get("neuroncore_memory_usage", {}) + ) + tensor_bytes = sum( + int(core.get("tensors", 0) or 0) for core in nc_usage.values() + ) + peak_tensors = max(peak_tensors, tensor_bytes) + samples += 1 + return peak_device, peak_tensors, samples + + +def _capture_neuron_hbm(tmp_path, fn): + if shutil.which("neuron-monitor") is None: + pytest.skip("neuron-monitor is not available") + + monitor_config = { + "period": "0.5s", + "neuron_runtimes": [ + { + "tag_filter": ".*", + "metrics": [{"type": "memory_used", "period": "0.5s"}], + } + ], + } + config_path = tmp_path / "neuron-monitor.json" + config_path.write_text(json.dumps(monitor_config)) + + proc = subprocess.Popen( + ["neuron-monitor", "--config-file", str(config_path)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + try: + time.sleep(1.0) + result = fn() + time.sleep(1.0) + finally: + proc.terminate() + try: + stdout, stderr = proc.communicate(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() + stdout, stderr = proc.communicate(timeout=5) + + peak_device, peak_tensors, samples = _parse_peak_neuron_memory(stdout) + assert samples > 0, f"neuron-monitor produced no runtime samples: {stderr}" + assert peak_device > 0, "Expected non-zero Neuron device HBM usage" + return result, peak_device, peak_tensors, samples + + +# ── Smoke Tests ───────────────────────────────────────────────────────── + + +@requires_model_path +def test_model_loads(compiled_model): + """Model compiles and loads successfully.""" + assert 
compiled_model is not None + assert hasattr(compiled_model, "neuron_config") + print(" Model loaded successfully") + + +@requires_model_path +def test_model_generates(compiled_model, tokenizer, generation_config): + """Model generates at least 5 tokens.""" + tokens, text = _generate( + compiled_model, + tokenizer, + generation_config, + "Hello, I am a language model", + max_new_tokens=20, + ) + input_len = len(tokenizer.encode("Hello, I am a language model")) + new_tokens = len(tokens) - input_len + assert new_tokens >= 5, f"Expected >= 5 new tokens, got {new_tokens}" + print(f" Generated {new_tokens} tokens: {text[:100]}...") + + +# ── Accuracy Tests ────────────────────────────────────────────────────── + + +@requires_model_path +def test_output_coherence(compiled_model, tokenizer, generation_config): + """Output should contain multiple words and not be excessively repetitive.""" + _, text = _generate( + compiled_model, + tokenizer, + generation_config, + "The capital of France is", + max_new_tokens=30, + ) + generated = text[len("The capital of France is") :].strip() + words = generated.split() + assert len(words) >= 3, f"Expected >= 3 words, got {len(words)}: '{generated}'" + assert not _is_repetitive(generated), ( + f"Output is excessively repetitive: '{generated}'" + ) + print(f" Output coherent ({len(words)} words): {generated[:80]}...") + + +@requires_model_path +def test_top_token_valid(compiled_model, tokenizer, generation_config): + """First generated token should be a valid decodable token.""" + tokens, _ = _generate( + compiled_model, + tokenizer, + generation_config, + "Hello!", + max_new_tokens=1, + ) + input_len = len(tokenizer.encode("Hello!")) + first_new = tokens[input_len] + assert 0 <= first_new < len(tokenizer), ( + f"Token {first_new} out of vocab range" + ) + decoded = tokenizer.decode([first_new]) + assert len(decoded) > 0, f"Token {first_new} decoded to empty string" + print(f" First token: {first_new} -> '{decoded}'") + + 
+@requires_model_path +def test_olympics_prompt_no_invalid_tokens( + compiled_model, tokenizer, generation_config +): + """Regression test for NaN logits producing the int32-min token id.""" + prompt = "Give me a summary of the 2020 Olympics in 100 tokens." + tokens, _ = _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=32, + ) + input_len = len(tokenizer.encode(prompt)) + generated = tokens[input_len:] + invalid = [token for token in generated if token < 0 or token >= len(tokenizer)] + + assert len(generated) >= 5, f"Expected >= 5 generated tokens, got {len(generated)}" + assert not invalid, f"Generated invalid token ids: {invalid}" + + +@requires_model_path +def test_capital_of_france(compiled_model, tokenizer, generation_config): + """'The capital of France is' should produce 'Paris' in the response.""" + _, text = _generate( + compiled_model, + tokenizer, + generation_config, + "The capital of France is", + max_new_tokens=30, + ) + generated = text[len("The capital of France is") :].strip() + assert "paris" in generated.lower(), ( + f"Expected 'Paris' in output, got: '{generated}'" + ) + print(f" Capital of France: {generated}") + + +# ── Performance Tests ─────────────────────────────────────────────────── + + +@requires_model_path +def test_performance_ttft(compiled_model, tokenizer, generation_config): + """Time to first token should be within threshold.""" + prompt = "Hello, I am a language model" + + # Warmup + _generate(compiled_model, tokenizer, generation_config, prompt, max_new_tokens=1) + + # Measure + times = [] + for _ in range(3): + t0 = time.perf_counter() + _generate( + compiled_model, tokenizer, generation_config, prompt, max_new_tokens=1 + ) + times.append((time.perf_counter() - t0) * 1000) + + avg_ms = sum(times) / len(times) + print(f" TTFT: {avg_ms:.1f} ms (threshold: {TTFT_THRESHOLD_MS} ms)") + assert avg_ms < TTFT_THRESHOLD_MS, ( + f"TTFT {avg_ms:.1f}ms > threshold {TTFT_THRESHOLD_MS}ms" + ) + + 
+@requires_model_path +def test_performance_throughput(compiled_model, tokenizer, generation_config): + """Throughput should meet minimum threshold.""" + prompt = "Once upon a time" + num_new_tokens = 20 + + # Warmup + _generate(compiled_model, tokenizer, generation_config, prompt, max_new_tokens=5) + + # Measure + t0 = time.perf_counter() + tokens, _ = _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=num_new_tokens, + ) + elapsed = time.perf_counter() - t0 + + input_len = len(tokenizer.encode(prompt)) + actual_new = len(tokens) - input_len + throughput = actual_new / elapsed if elapsed > 0 else 0 + + print( + f" Throughput: {throughput:.1f} tok/s ({actual_new} tokens in {elapsed:.2f}s)" + ) + print(f" Threshold: {THROUGHPUT_THRESHOLD} tok/s") + assert throughput > THROUGHPUT_THRESHOLD, ( + f"Throughput {throughput:.1f} tok/s < threshold {THROUGHPUT_THRESHOLD}" + ) + + +@requires_model_path +@requires_hbm_recording +def test_hybrid_cache_hbm_snapshot(compiled_model, tokenizer, generation_config, tmp_path): + """Record peak Neuron HBM for dummy-KV vs hybrid-cache comparison runs.""" + prompt = "Give me a summary of the 2020 Olympics in 100 tokens." 
+ max_new_tokens = int(os.environ.get("QWEN35_HBM_NEW_TOKENS", "32")) + + (_, text), peak_device, peak_tensors, samples = _capture_neuron_hbm( + tmp_path, + lambda: _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=max_new_tokens, + ), + ) + + mode = "hybrid" if USE_HYBRID_CACHE else "dummy_kv" + print( + " HBM " + f"mode={mode} peak_device_bytes={peak_device} " + f"peak_tensor_bytes={peak_tensors} samples={samples}" + ) + assert len(text) > len(prompt) + + +# ── Multi-Prompt Quality Test ────────────────────────────────────────── + + +@requires_model_path +def test_multi_prompt_generation(compiled_model, tokenizer, generation_config): + """Multiple prompts should produce coherent outputs.""" + prompts = [ + "The capital of France is", + "def fibonacci(n):", + "The largest ocean on Earth is", + "To make a chocolate cake, you need", + ] + + for prompt in prompts: + _, text = _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=30, + ) + generated = text[len(prompt) :].strip() + words = generated.split() + assert len(words) >= 2, ( + f"Prompt '{prompt}' generated too few words: '{generated}'" + ) + assert not _is_repetitive(generated), ( + f"Prompt '{prompt}' produced repetitive output: '{generated}'" + ) + print(f" '{prompt[:30]}...' -> {generated[:60]}...") + + +# ── Standalone runner ─────────────────────────────────────────────────── + +if __name__ == "__main__": + print("=" * 60) + print("Qwen3.6-27B Integration Tests") + print("=" * 60) + + if not MODEL_PATH: + print("\nQWEN35_MODEL_PATH not set. 
Provide the model path to run tests:") + print(" QWEN35_MODEL_PATH=/path/to/Qwen3.6-27B \\") + print(" QWEN35_COMPILED_PATH=/mnt/models/qwen35_traced \\") + print(" python -m pytest test/integration/test_model.py --capture=tee-sys") + sys.exit(0) + + # Setup + from transformers import AutoTokenizer, GenerationConfig as GenConfig + + tok = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="right") + if tok.pad_token is None: + tok.pad_token = tok.eos_token + gen_cfg = GenConfig( + do_sample=True, + top_k=1, + pad_token_id=tok.pad_token_id, + eos_token_id=tok.eos_token_id, + ) + + # Build model + import json + + from neuronx_distributed_inference.models.config import ( + NeuronConfig, + OnDeviceSamplingConfig, + ) + from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM + + nc = NeuronConfig( + tp_degree=TP_DEGREE, + batch_size=1, + ctx_batch_size=1, + tkg_batch_size=1, + seq_len=SEQ_LEN, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + enable_bucketing=False, + flash_decoding_enabled=False, + logical_nc_config=2, + save_sharded_checkpoint=True, + ) + + with open(os.path.join(MODEL_PATH, "config.json")) as f: + full_config = json.load(f) + text_config = full_config.get("text_config", full_config) + config_dict = dict(text_config) + config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044) + if "rope_parameters" in text_config: + config_dict["rope_theta"] = text_config["rope_parameters"].get( + "rope_theta", 10000000 + ) + config_dict.setdefault("tie_word_embeddings", False) + ic = Qwen35InferenceConfig(neuron_config=nc, **config_dict) + + cp = COMPILED_PATH + if not os.path.exists(os.path.join(cp, "model.pt")): + print(f"Compiling to {cp}...") + m = NeuronQwen35ForCausalLM(MODEL_PATH, ic) + m.compile(cp) + del m + gc.collect() + + print(f"Loading from {cp}...") + model = NeuronQwen35ForCausalLM(cp) + model.load(cp) + + tests = [ + ("model_loads", lambda: test_model_loads(model)), + 
("model_generates", lambda: test_model_generates(model, tok, gen_cfg)), + ("output_coherence", lambda: test_output_coherence(model, tok, gen_cfg)), + ("top_token_valid", lambda: test_top_token_valid(model, tok, gen_cfg)), + ("capital_of_france", lambda: test_capital_of_france(model, tok, gen_cfg)), + ("performance_ttft", lambda: test_performance_ttft(model, tok, gen_cfg)), + ( + "performance_throughput", + lambda: test_performance_throughput(model, tok, gen_cfg), + ), + ( + "multi_prompt_generation", + lambda: test_multi_prompt_generation(model, tok, gen_cfg), + ), + ] + + passed = 0 + for name, fn in tests: + print(f"\n--- {name} ---") + try: + fn() + print(" PASS") + passed += 1 + except Exception as e: + print(f" FAIL: {e}") + + print(f"\n{'=' * 60}") + print(f"Results: {passed}/{len(tests)} passed") + print(f"{'=' * 60}") diff --git a/contrib/models/Qwen3.6-27B/test/unit/__init__.py b/contrib/models/Qwen3.6-27B/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Qwen3.6-27B/test/unit/test_config.py b/contrib/models/Qwen3.6-27B/test/unit/test_config.py new file mode 100644 index 00000000..571ad522 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/unit/test_config.py @@ -0,0 +1,201 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for Qwen3.5/3.6-27B inference configuration. + +CPU-only tests that validate config parsing, layer type setup, +DeltaNet parameter defaults, RoPE configuration, and Neuron config settings. +These tests are architecture-level and apply to both Qwen3.5-27B and Qwen3.6-27B. 
+""" + +import os +import sys +import unittest +from unittest.mock import MagicMock + +import torch + +# Ensure the contrib root (Qwen3.6-27B/) is on sys.path +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from src.modeling_qwen35 import ( + Qwen35InferenceConfig, + convert_qwen35_hf_to_neuron_state_dict, +) +from neuronx_distributed_inference.models.config import NeuronConfig + + +def _make_config(**overrides): + """Create a Qwen35InferenceConfig with reasonable defaults.""" + neuron_config = NeuronConfig( + tp_degree=overrides.pop("tp_degree", 4), + batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, + ) + defaults = dict( + hidden_size=5120, + num_hidden_layers=64, + num_attention_heads=24, + num_key_value_heads=4, + head_dim=256, + intermediate_size=17408, + vocab_size=248320, + rms_norm_eps=1e-6, + max_position_embeddings=131072, + rope_theta=10000, + hidden_act="silu", + # DeltaNet-specific + linear_num_value_heads=48, + linear_num_key_heads=16, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_conv_kernel_dim=4, + ) + defaults.update(overrides) + config = Qwen35InferenceConfig(neuron_config=neuron_config, **defaults) + return config + + +class TestConfigParsing(unittest.TestCase): + """Test basic config attribute initialization.""" + + def test_hidden_size(self): + config = _make_config() + self.assertEqual(config.hidden_size, 5120) + + def test_num_hidden_layers(self): + config = _make_config() + self.assertEqual(config.num_hidden_layers, 64) + + def test_num_attention_heads(self): + config = _make_config() + self.assertEqual(config.num_attention_heads, 24) + + def test_num_key_value_heads(self): + config = _make_config() + self.assertEqual(config.num_key_value_heads, 4) + + def test_head_dim(self): + config = _make_config() + self.assertEqual(config.head_dim, 256) + + def test_intermediate_size(self): + config = 
_make_config() + self.assertEqual(config.intermediate_size, 17408) + + def test_vocab_size(self): + config = _make_config() + self.assertEqual(config.vocab_size, 248320) + + def test_hidden_act(self): + config = _make_config() + self.assertEqual(config.hidden_act, "silu") + + +class TestLayerTypes(unittest.TestCase): + """Test hybrid layer type assignment (3 DeltaNet + 1 GQA) x 16.""" + + def test_layer_types_length(self): + config = _make_config() + self.assertEqual(len(config.layer_types), 64) + + def test_layer_types_pattern(self): + """Every 4th layer (3, 7, 11, ...) should be full_attention.""" + config = _make_config() + for i in range(64): + expected = "full_attention" if i % 4 == 3 else "linear_attention" + self.assertEqual(config.layer_types[i], expected, f"Layer {i} mismatch") + + def test_deltanet_layer_count(self): + config = _make_config() + dn_count = sum(1 for t in config.layer_types if t == "linear_attention") + self.assertEqual(dn_count, 48) + + def test_gqa_layer_count(self): + config = _make_config() + gqa_count = sum(1 for t in config.layer_types if t == "full_attention") + self.assertEqual(gqa_count, 16) + + +class TestDeltaNetConfig(unittest.TestCase): + """Test DeltaNet-specific configuration defaults.""" + + def test_linear_num_value_heads(self): + config = _make_config() + self.assertEqual(config.linear_num_value_heads, 48) + + def test_linear_num_key_heads(self): + config = _make_config() + self.assertEqual(config.linear_num_key_heads, 16) + + def test_linear_key_head_dim(self): + config = _make_config() + self.assertEqual(config.linear_key_head_dim, 128) + + def test_linear_value_head_dim(self): + config = _make_config() + self.assertEqual(config.linear_value_head_dim, 128) + + def test_linear_conv_kernel_dim(self): + config = _make_config() + self.assertEqual(config.linear_conv_kernel_dim, 4) + + +class TestRoPEConfig(unittest.TestCase): + """Test partial RoPE configuration.""" + + def test_partial_rotary_factor(self): + config = 
_make_config() + self.assertAlmostEqual(config.partial_rotary_factor, 0.25) + + def test_rope_dim(self): + """rope_dim = head_dim * partial_rotary_factor = 256 * 0.25 = 64.""" + config = _make_config() + self.assertEqual(config.rope_dim, 64) + + def test_attn_output_gate(self): + config = _make_config() + self.assertTrue(config.attn_output_gate) + + def test_mrope_section(self): + config = _make_config() + self.assertEqual(config.mrope_section, [11, 11, 10]) + + def test_mrope_interleaved(self): + config = _make_config() + self.assertTrue(config.mrope_interleaved) + + +class TestNeuronConfig(unittest.TestCase): + """Test Neuron-specific configuration settings.""" + + def test_neuron_config_cls(self): + """Qwen3.5/3.6-27B is dense -- uses NeuronConfig, NOT MoENeuronConfig.""" + self.assertEqual( + Qwen35InferenceConfig.get_neuron_config_cls(), + NeuronConfig, + ) + + def test_required_attributes(self): + config = _make_config() + required = config.get_required_attributes() + self.assertIn("hidden_size", required) + self.assertIn("num_hidden_layers", required) + self.assertIn("linear_num_value_heads", required) + self.assertIn("linear_key_head_dim", required) + self.assertIn("layer_types", required) + + def test_output_attentions_default(self): + config = _make_config() + self.assertFalse(config.output_attentions) + + def test_output_hidden_states_default(self): + config = _make_config() + self.assertFalse(config.output_hidden_states) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py b/contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py new file mode 100644 index 00000000..416a431a --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py @@ -0,0 +1,68 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for fused DeltaNet log-decay bounding.""" + +import os +import sys +import unittest + +import torch + +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from src.modeling_qwen35 import ( + FUSED_DELTANET_DECAY_MAX, + FUSED_DELTANET_DECAY_MIN, + _bound_fused_deltanet_log_decay, +) + + +def _chunked_cumsum(g, batch_size, num_heads, total_seq_len, chunk_size): + num_chunks = total_seq_len // chunk_size + return g.reshape(batch_size, num_heads, num_chunks, chunk_size).cumsum(dim=-1) + + +class TestFusedDeltaNetDecayBounding(unittest.TestCase): + def test_preserves_non_extreme_decay(self): + batch_size, num_heads, total_seq_len, chunk_size = 2, 3, 16, 8 + g = torch.full( + (batch_size, num_heads, total_seq_len), + -0.125, + dtype=torch.float32, + ) + + bounded = _bound_fused_deltanet_log_decay( + g, batch_size, num_heads, total_seq_len, chunk_size + ) + + torch.testing.assert_close(bounded, g) + + def test_bounds_per_chunk_cumulative_decay(self): + batch_size, num_heads, total_seq_len, chunk_size = 2, 3, 16, 8 + g = torch.full( + (batch_size, num_heads, total_seq_len), + -10.0, + dtype=torch.float32, + ) + + bounded = _bound_fused_deltanet_log_decay( + g, batch_size, num_heads, total_seq_len, chunk_size + ) + bounded_cumsum = _chunked_cumsum( + bounded, batch_size, num_heads, total_seq_len, chunk_size + ) + expected_cumsum = _chunked_cumsum( + g, batch_size, num_heads, total_seq_len, chunk_size + ).clamp(min=FUSED_DELTANET_DECAY_MIN, max=FUSED_DELTANET_DECAY_MAX) + + torch.testing.assert_close(bounded_cumsum, expected_cumsum) + self.assertGreaterEqual(float(bounded_cumsum.min()), FUSED_DELTANET_DECAY_MIN) + self.assertLessEqual(float(bounded_cumsum.max()), FUSED_DELTANET_DECAY_MAX) + self.assertTrue(torch.isfinite(bounded).all()) + + +if __name__ == "__main__": + unittest.main() diff --git 
a/contrib/models/Qwen3.6-27B/test/unit/test_hybrid_cache_manager.py b/contrib/models/Qwen3.6-27B/test/unit/test_hybrid_cache_manager.py new file mode 100644 index 00000000..fa887ca2 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/unit/test_hybrid_cache_manager.py @@ -0,0 +1,314 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import os +import sys +import unittest +from math import prod +from unittest.mock import patch + +import torch + +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from neuronx_distributed_inference.models.config import NeuronConfig +from src.modeling_qwen35 import HybridDeltaNetCacheManager, Qwen35InferenceConfig + + +def _make_config(**overrides): + neuron_overrides = overrides.pop("neuron_overrides", {}) + neuron_kwargs = dict( + tp_degree=overrides.pop("tp_degree", 4), + batch_size=1, + max_batch_size=2, + kv_cache_batch_size=2, + seq_len=16, + torch_dtype=torch.bfloat16, + ) + neuron_kwargs.update(neuron_overrides) + neuron_config = NeuronConfig(**neuron_kwargs) + defaults = dict( + hidden_size=5120, + num_hidden_layers=64, + num_attention_heads=24, + num_key_value_heads=4, + head_dim=256, + intermediate_size=17408, + vocab_size=248320, + rms_norm_eps=1e-6, + max_position_embeddings=131072, + rope_theta=10000, + hidden_act="silu", + tie_word_embeddings=False, + linear_num_value_heads=48, + linear_num_key_heads=16, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_conv_kernel_dim=4, + use_hybrid_cache_manager=True, + ) + defaults.update(overrides) + return Qwen35InferenceConfig(neuron_config=neuron_config, **defaults) + + +def _numel(shape): + return prod(int(dim) for dim in shape) + + +def _managed_cache_numel(mgr): + return sum(param.numel() for param in mgr.past_key_values) + + +def _deltanet_state_numel(config, max_batch_size): + recurrent = ( + 
max_batch_size + * config.linear_num_value_heads + * config.linear_key_head_dim + * config.linear_value_head_dim + ) + conv_dim = ( + 2 * config.linear_num_key_heads * config.linear_key_head_dim + + config.linear_num_value_heads * config.linear_value_head_dim + ) + conv = max_batch_size * conv_dim * (config.linear_conv_kernel_dim - 1) + return recurrent + conv + + +class TestHybridDeltaNetCacheManager(unittest.TestCase): + def test_allocates_per_layer_cache_shapes(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + self.assertEqual(len(mgr.past_key_values), config.num_hidden_layers * 2) + self.assertEqual(list(mgr.past_key_values[0].shape), [2, 48, 128, 128]) + self.assertEqual(list(mgr.past_key_values[1].shape), [2, 10240, 3]) + self.assertEqual(mgr.layer_types[3], "full_attention") + self.assertEqual(mgr.past_key_values[6].dim(), 4) + self.assertEqual(mgr.past_key_values[7].shape[2], 16) + + def test_get_cache_slices_only_full_attention_layers(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + cache = mgr.get_cache(seq_len=4, seq_ids=torch.tensor([1])) + recurrent_state, conv_state = cache[0] + full_k, full_v = cache[3] + + self.assertEqual(list(recurrent_state.shape), [1, 48, 128, 128]) + self.assertEqual(list(conv_state.shape), [1, 10240, 3]) + self.assertEqual(full_k.shape[0], 2) + self.assertEqual(full_v.shape[0], 2) + self.assertEqual(full_k.shape[2], 4) + self.assertEqual(full_v.shape[2], 4) + + def test_get_seq_length_uses_first_full_attention_layer(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + nested_cache = mgr.get_cache(seq_len=5, seq_ids=torch.tensor([0])) + flat_cache = [tensor for layer_cache in nested_cache for tensor in layer_cache] + + self.assertEqual(nested_cache[0][1].shape[2], 3) + self.assertEqual(mgr.get_seq_length(nested_cache), 5) + 
self.assertEqual(mgr.get_seq_length(flat_cache), 5) + + def test_get_cache_selects_deltanet_state_rows_by_seq_ids(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + with torch.no_grad(): + mgr.past_key_values[0][0, ...].fill_(7) + mgr.past_key_values[0][1, ...].fill_(13) + mgr.past_key_values[1][0, ...].fill_(17) + mgr.past_key_values[1][1, ...].fill_(19) + + recurrent_state, conv_state = mgr.get_cache( + seq_len=4, + seq_ids=torch.tensor([1, 0]), + )[0] + + self.assertTrue(torch.all(recurrent_state[0] == 13)) + self.assertTrue(torch.all(recurrent_state[1] == 7)) + self.assertTrue(torch.all(conv_state[0] == 19)) + self.assertTrue(torch.all(conv_state[1] == 17)) + + def test_deltanet_update_scatters_by_seq_id(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + recurrent = torch.ones((1, 48, 128, 128), dtype=torch.bfloat16) + conv = torch.ones((1, 10240, 3), dtype=torch.bfloat16) + + updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( + idx=0, + seq_ids=torch.tensor([1]), + state_per_layer=(recurrent, conv), + ) + + self.assertTrue(torch.all(updated_recurrent[0] == 0)) + self.assertTrue(torch.all(updated_conv[0] == 0)) + self.assertTrue(torch.all(updated_recurrent[1] == 1)) + self.assertTrue(torch.all(updated_conv[1] == 1)) + + def test_deltanet_full_batch_update_replaces_state_cache(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + recurrent = torch.ones((2, 48, 128, 128), dtype=torch.bfloat16) + conv = torch.ones((2, 10240, 3), dtype=torch.bfloat16) + recurrent[0].fill_(3) + recurrent[1].fill_(5) + conv[0].fill_(11) + conv[1].fill_(13) + + updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( + idx=0, + seq_ids=torch.tensor([0, 1]), + state_per_layer=(recurrent, conv), + ) + + self.assertTrue(torch.all(updated_recurrent[0] == 
3)) + self.assertTrue(torch.all(updated_recurrent[1] == 5)) + self.assertTrue(torch.all(updated_conv[0] == 11)) + self.assertTrue(torch.all(updated_conv[1] == 13)) + + def test_deltanet_update_maps_out_of_range_seq_id_to_padding_row(self): + config = _make_config(neuron_overrides={"kv_cache_padding_size": 1}) + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + recurrent = torch.ones((1, 48, 128, 128), dtype=torch.bfloat16) + conv = torch.ones((1, 10240, 3), dtype=torch.bfloat16) + + updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( + idx=0, + seq_ids=torch.tensor([99]), + state_per_layer=(recurrent, conv), + ) + + self.assertTrue(torch.all(updated_recurrent[0] == 0)) + self.assertTrue(torch.all(updated_recurrent[1] == 0)) + self.assertTrue(torch.all(updated_recurrent[2] == 1)) + self.assertTrue(torch.all(updated_conv[2] == 1)) + + def test_deltanet_state_shapes_do_not_scale_with_sequence_length(self): + short_config = _make_config(neuron_overrides={"seq_len": 128}) + long_config = _make_config(neuron_overrides={"seq_len": 2048}) + short_mgr = HybridDeltaNetCacheManager( + short_config, num_kv_head=short_config.num_key_value_heads + ) + long_mgr = HybridDeltaNetCacheManager( + long_config, num_kv_head=long_config.num_key_value_heads + ) + + self.assertEqual(short_mgr.past_key_values[0].shape, long_mgr.past_key_values[0].shape) + self.assertEqual(short_mgr.past_key_values[1].shape, long_mgr.past_key_values[1].shape) + self.assertLess(short_mgr.past_key_values[7].shape[2], long_mgr.past_key_values[7].shape[2]) + + def test_get_cache_trims_padding_row_without_seq_ids(self): + config = _make_config(neuron_overrides={"kv_cache_padding_size": 1}) + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + recurrent_state, conv_state = mgr.get_cache(seq_len=4)[0] + + self.assertEqual(list(recurrent_state.shape), [2, 48, 128, 128]) + self.assertEqual(list(conv_state.shape), [2, 10240, 3]) + + 
def test_update_cache_dispatches_deltanet_and_full_attention_layers(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + new_key_values = [] + for idx in range(4): + first = mgr.past_key_values[2 * idx] + second = mgr.past_key_values[2 * idx + 1] + new_key_values.append( + ( + torch.full_like(first, fill_value=idx + 1), + torch.full_like(second, fill_value=idx + 11), + ) + ) + + position_ids = torch.arange(16, dtype=torch.long).unsqueeze(0).expand(2, -1) + full_k_update = torch.full_like(mgr.past_key_values[6], fill_value=4) + full_v_update = torch.full_like(mgr.past_key_values[7], fill_value=14) + with patch.object( + mgr, "update_kv_by_layer_id", return_value=(full_k_update, full_v_update) + ) as update_kv: + updated = mgr.update_cache( + is_for_context_encoding=True, + seq_ids=torch.tensor([0, 1], dtype=torch.int32), + position_ids=position_ids, + new_key_values=new_key_values, + seq_len=16, + ) + + self.assertEqual(update_kv.call_count, 1) + self.assertEqual(update_kv.call_args.kwargs["idx"], 3) + self.assertTrue(torch.all(updated[0] == 1)) + self.assertTrue(torch.all(updated[1] == 11)) + self.assertTrue(torch.all(updated[6] == 4)) + self.assertTrue(torch.all(updated[7] == 14)) + + def test_managed_cache_removes_dummy_kv_for_deltanet_layers(self): + config = _make_config(neuron_overrides={"seq_len": 1024}) + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + max_batch_size = ( + config.neuron_config.kv_cache_batch_size + + config.neuron_config.kv_cache_padding_size + ) + full_kv_per_layer = _numel(mgr.k_shape) + _numel(mgr.v_shape) + deltanet_layers = config.layer_types.count("linear_attention") + legacy_total_numel = ( + full_kv_per_layer * config.num_hidden_layers + + _deltanet_state_numel(config, max_batch_size) * deltanet_layers + ) + expected_savings = full_kv_per_layer * deltanet_layers + + self.assertEqual( + legacy_total_numel - _managed_cache_numel(mgr), + 
expected_savings, + ) + self.assertLess(_managed_cache_numel(mgr), legacy_total_numel) + + def test_rejects_unsupported_hybrid_modes(self): + unsupported_cases = [ + ({"padding_side": "left"}, "left padding"), + ({"flash_decoding_enabled": True}, "flash decoding"), + ] + + for neuron_overrides, expected_error in unsupported_cases: + with self.subTest(expected_error=expected_error): + config = _make_config(neuron_overrides=neuron_overrides) + with self.assertRaisesRegex(ValueError, expected_error): + HybridDeltaNetCacheManager( + config, num_kv_head=config.num_key_value_heads + ) + + config = _make_config() + config.neuron_config.kv_cache_quant = True + with self.assertRaisesRegex(ValueError, "KV cache quantization"): + HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + config = _make_config( + neuron_overrides={ + "attention_dp_degree": 2, + "batch_size": 2, + "ctx_batch_size": 2, + "tkg_batch_size": 2, + "max_batch_size": 2, + "kv_cache_batch_size": 2, + "is_continuous_batching": True, + } + ) + with self.assertRaisesRegex(ValueError, "attention data parallelism"): + HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + config = _make_config() + config.neuron_config.kv_cache_tiling = True + with self.assertRaisesRegex(ValueError, "KV cache tiling"): + HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + def test_legacy_config_default_is_disabled(self): + config = _make_config(use_hybrid_cache_manager=False) + self.assertFalse(config.use_hybrid_cache_manager) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.6-27B/test/unit/test_weight_conversion.py b/contrib/models/Qwen3.6-27B/test/unit/test_weight_conversion.py new file mode 100644 index 00000000..252da3f4 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/unit/test_weight_conversion.py @@ -0,0 +1,436 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for Qwen3.5/3.6-27B HF-to-NxDI weight conversion. + +CPU-only tests that validate: +- RMSNorm (+1 convention) weight conversion +- GQA q_proj interleaved split (query + gate) +- QK norm key renaming (q_norm -> q_layernorm, k_norm -> k_layernorm) +- Fused QKV concatenation +- DeltaNet layer weights pass through unchanged +- VL wrapper prefix stripping +- rank_util injection + +These tests are architecture-level and apply to both Qwen3.5-27B and Qwen3.6-27B. +""" + +import os +import sys +import unittest + +import torch + +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from src.modeling_qwen35 import ( + Qwen35InferenceConfig, + NeuronQwen35ForCausalLM, + convert_qwen35_hf_to_neuron_state_dict, +) +from neuronx_distributed_inference.models.config import NeuronConfig + + +def _make_mini_config(num_layers=4, tp_degree=2, fused_qkv=True): + """Create a small Qwen35InferenceConfig for testing.""" + neuron_config = NeuronConfig( + tp_degree=tp_degree, + batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, + fused_qkv=fused_qkv, + ) + config = Qwen35InferenceConfig( + neuron_config=neuron_config, + hidden_size=256, + num_hidden_layers=num_layers, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=64, + intermediate_size=512, + vocab_size=1000, + rms_norm_eps=1e-6, + max_position_embeddings=4096, + rope_theta=10000, + hidden_act="silu", + linear_num_value_heads=8, + linear_num_key_heads=4, + linear_key_head_dim=32, + linear_value_head_dim=32, + linear_conv_kernel_dim=4, + ) + return config + + +def _make_mini_state_dict(config): + """Create a minimal HF-style state dict for conversion testing.""" + sd = {} + H = config.hidden_size # 256 + I = config.intermediate_size # 512 + V = config.vocab_size # 1000 + num_heads = config.num_attention_heads # 4 + num_kv = config.num_key_value_heads # 2 + 
head_dim = config.head_dim # 64 + + sd["embed_tokens.weight"] = torch.randn(V, H, dtype=torch.bfloat16) * 0.02 + sd["lm_head.weight"] = torch.randn(V, H, dtype=torch.bfloat16) * 0.02 + sd["norm.weight"] = torch.zeros(H, dtype=torch.bfloat16) # +1 convention: zeros + + for l in range(config.num_hidden_layers): + sd[f"layers.{l}.input_layernorm.weight"] = torch.zeros(H, dtype=torch.bfloat16) + sd[f"layers.{l}.post_attention_layernorm.weight"] = torch.zeros( + H, dtype=torch.bfloat16 + ) + + # Dense MLP (all layers) + sd[f"layers.{l}.mlp.gate_proj.weight"] = ( + torch.randn(I, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.mlp.up_proj.weight"] = ( + torch.randn(I, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.mlp.down_proj.weight"] = ( + torch.randn(H, I, dtype=torch.bfloat16) * 0.02 + ) + + if config.layer_types[l] == "full_attention": + # GQA layer: q_proj is interleaved [head0_q | head0_gate | head1_q | ...] + q_proj = ( + torch.randn(num_heads * head_dim * 2, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.self_attn.q_proj.weight"] = q_proj + sd[f"layers.{l}.self_attn.k_proj.weight"] = ( + torch.randn(num_kv * head_dim, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.self_attn.v_proj.weight"] = ( + torch.randn(num_kv * head_dim, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.self_attn.o_proj.weight"] = ( + torch.randn(H, num_heads * head_dim, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.self_attn.q_norm.weight"] = torch.zeros( + head_dim, dtype=torch.bfloat16 + ) + sd[f"layers.{l}.self_attn.k_norm.weight"] = torch.zeros( + head_dim, dtype=torch.bfloat16 + ) + else: + # DeltaNet layer: minimal required weights + key_dim = config.linear_num_key_heads * config.linear_key_head_dim # 128 + value_dim = ( + config.linear_num_value_heads * config.linear_value_head_dim + ) # 256 + conv_dim = key_dim * 2 + value_dim # 512 + sd[f"layers.{l}.linear_attn.in_proj_qkv.weight"] = ( + torch.randn(conv_dim, H, dtype=torch.bfloat16) * 0.02 + ) 
+ sd[f"layers.{l}.linear_attn.in_proj_z.weight"] = ( + torch.randn(value_dim, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.linear_attn.in_proj_a.weight"] = ( + torch.randn(config.linear_num_value_heads, H, dtype=torch.bfloat16) + * 0.02 + ) + sd[f"layers.{l}.linear_attn.in_proj_b.weight"] = ( + torch.randn(config.linear_num_value_heads, H, dtype=torch.bfloat16) + * 0.02 + ) + sd[f"layers.{l}.linear_attn.conv1d.weight"] = ( + torch.randn( + conv_dim, 1, config.linear_conv_kernel_dim, dtype=torch.bfloat16 + ) + * 0.02 + ) + sd[f"layers.{l}.linear_attn.A_log"] = torch.randn( + config.linear_num_value_heads, dtype=torch.bfloat16 + ) + sd[f"layers.{l}.linear_attn.dt_bias"] = torch.randn( + config.linear_num_value_heads, dtype=torch.bfloat16 + ) + sd[f"layers.{l}.linear_attn.norm.weight"] = ( + torch.randn(value_dim, dtype=torch.bfloat16) * 0.5 + ) + sd[f"layers.{l}.linear_attn.out_proj.weight"] = ( + torch.randn(H, value_dim, dtype=torch.bfloat16) * 0.02 + ) + + return sd + + +class TestNormConversion(unittest.TestCase): + """Test (+1 convention) RMSNorm weight conversion.""" + + def test_norm_weight_adds_one(self): + """Weights initialized to zero should become 1.0 after conversion.""" + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + # norm.weight was zeros -> should now be ones + torch.testing.assert_close( + result["norm.weight"], + torch.ones_like(result["norm.weight"]), + ) + + def test_input_layernorm_adds_one(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + for l in range(config.num_hidden_layers): + w = result[f"layers.{l}.input_layernorm.weight"] + self.assertTrue( + torch.allclose(w, torch.ones_like(w)), + f"Layer {l} input_layernorm not converted", + ) + + def test_post_attn_layernorm_adds_one(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = 
convert_qwen35_hf_to_neuron_state_dict(sd, config) + for l in range(config.num_hidden_layers): + w = result[f"layers.{l}.post_attention_layernorm.weight"] + self.assertTrue( + torch.allclose(w, torch.ones_like(w)), + f"Layer {l} post_attention_layernorm not converted", + ) + + def test_qk_norm_adds_one(self): + """Q/K norms on GQA layers should also get +1 applied.""" + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + q_w = result[f"layers.{l}.self_attn.q_layernorm.weight"] + k_w = result[f"layers.{l}.self_attn.k_layernorm.weight"] + self.assertTrue( + torch.allclose(q_w, torch.ones_like(q_w)), + f"Layer {l} q_layernorm not converted", + ) + self.assertTrue( + torch.allclose(k_w, torch.ones_like(k_w)), + f"Layer {l} k_layernorm not converted", + ) + + +class TestQProjSplit(unittest.TestCase): + """Test q_proj interleaved split into query + gate.""" + + def test_q_proj_split_shapes(self): + """q_proj (num_heads * head_dim * 2, H) -> separate query and gate.""" + config = _make_mini_config(fused_qkv=False) + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + # After split: q_proj should be (num_heads * head_dim, H) = (256, 256) + q_w = result[f"layers.{l}.self_attn.q_proj.weight"] + gate_w = result[f"layers.{l}.self_attn.output_gate_proj.weight"] + expected_shape = ( + config.num_attention_heads * config.head_dim, + config.hidden_size, + ) + self.assertEqual( + q_w.shape, expected_shape, f"Layer {l} q_proj shape wrong" + ) + self.assertEqual( + gate_w.shape, expected_shape, f"Layer {l} gate shape wrong" + ) + + def test_q_proj_deinterleave_correct(self): + """Verify the interleaved split correctly separates query and gate.""" + config = 
_make_mini_config(fused_qkv=False) + sd = _make_mini_state_dict(config) + + # Create a known pattern: head0 query is 1s, head0 gate is 2s, etc. + l = 3 # First full_attention layer (layer 3) + num_heads = config.num_attention_heads + head_dim = config.head_dim + H = config.hidden_size + + interleaved = torch.zeros(num_heads * head_dim * 2, H, dtype=torch.bfloat16) + for h in range(num_heads): + interleaved[h * head_dim * 2 : h * head_dim * 2 + head_dim, :] = float( + h + 1 + ) # query + interleaved[h * head_dim * 2 + head_dim : (h + 1) * head_dim * 2, :] = ( + float(h + 100) + ) # gate + + sd[f"layers.{l}.self_attn.q_proj.weight"] = interleaved + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + q_w = result[f"layers.{l}.self_attn.q_proj.weight"] + gate_w = result[f"layers.{l}.self_attn.output_gate_proj.weight"] + + for h in range(num_heads): + q_head = q_w[h * head_dim : (h + 1) * head_dim, :] + gate_head = gate_w[h * head_dim : (h + 1) * head_dim, :] + self.assertTrue( + torch.all(q_head == float(h + 1)), f"Head {h} query values wrong" + ) + self.assertTrue( + torch.all(gate_head == float(h + 100)), f"Head {h} gate values wrong" + ) + + +class TestQKNormRename(unittest.TestCase): + """Test q_norm -> q_layernorm and k_norm -> k_layernorm renaming.""" + + def test_old_keys_removed(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + self.assertNotIn(f"layers.{l}.self_attn.q_norm.weight", result) + self.assertNotIn(f"layers.{l}.self_attn.k_norm.weight", result) + + def test_new_keys_present(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + self.assertIn(f"layers.{l}.self_attn.q_layernorm.weight", 
result) + self.assertIn(f"layers.{l}.self_attn.k_layernorm.weight", result) + + +class TestFusedQKV(unittest.TestCase): + """Test fused QKV concatenation for attention layers.""" + + def test_fused_qkv_shape(self): + config = _make_mini_config(fused_qkv=True) + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + fused_key = f"layers.{l}.self_attn.Wqkv.weight" + self.assertIn(fused_key, result, f"Layer {l} missing Wqkv") + + q_dim = config.num_attention_heads * config.head_dim + k_dim = config.num_key_value_heads * config.head_dim + v_dim = config.num_key_value_heads * config.head_dim + expected_rows = q_dim + k_dim + v_dim + self.assertEqual(result[fused_key].shape[0], expected_rows) + + def test_fused_qkv_removes_individual_keys(self): + config = _make_mini_config(fused_qkv=True) + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + self.assertNotIn(f"layers.{l}.self_attn.q_proj.weight", result) + self.assertNotIn(f"layers.{l}.self_attn.k_proj.weight", result) + self.assertNotIn(f"layers.{l}.self_attn.v_proj.weight", result) + + +class TestDeltaNetPassthrough(unittest.TestCase): + """Test that DeltaNet layer weights pass through conversion unchanged.""" + + def test_deltanet_weights_unchanged(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + + # Record original DeltaNet weights + originals = {} + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "linear_attention": + key = f"layers.{l}.linear_attn.in_proj_qkv.weight" + originals[key] = sd[key].clone() + + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + for key, orig in originals.items(): + self.assertIn(key, result, f"Missing: {key}") + torch.testing.assert_close( + 
result[key], orig, msg=f"DeltaNet weight changed: {key}" + ) + + def test_deltanet_norm_not_converted(self): + """DeltaNet layers use standard RMSNorm (NOT +1 convention). + The norm weight should NOT be changed.""" + config = _make_mini_config() + sd = _make_mini_state_dict(config) + + # Set DeltaNet norm to a known non-zero value + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "linear_attention": + sd[f"layers.{l}.linear_attn.norm.weight"] = torch.full( + (config.linear_num_value_heads * config.linear_value_head_dim,), + 0.87, + dtype=torch.bfloat16, + ) + + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "linear_attention": + w = result[f"layers.{l}.linear_attn.norm.weight"] + # Should still be ~0.87, NOT 1.87 + self.assertTrue( + torch.allclose(w, torch.full_like(w, 0.87), atol=0.01), + f"Layer {l} DeltaNet norm was incorrectly modified", + ) + + +class TestRankUtil(unittest.TestCase): + """Test rank_util tensor injection.""" + + def test_rank_util_present(self): + config = _make_mini_config(tp_degree=4) + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + self.assertIn("rank_util.rank", result) + expected = torch.arange(0, 4, dtype=torch.int32) + torch.testing.assert_close(result["rank_util.rank"], expected) + + def test_gqa_layer_rank_util(self): + config = _make_mini_config(tp_degree=4) + sd = _make_mini_state_dict(config) + result = convert_qwen35_hf_to_neuron_state_dict(sd, config) + for l in range(config.num_hidden_layers): + if config.layer_types[l] == "full_attention": + key = f"layers.{l}.self_attn.rank_util.rank" + self.assertIn(key, result) + expected = torch.arange(0, 4, dtype=torch.int32) + torch.testing.assert_close(result[key], expected) + + +class TestVLPrefixStripping(unittest.TestCase): + """Test VL wrapper prefix stripping in convert_hf_to_neuron_state_dict.""" + + def 
test_language_model_prefix_stripped(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + + # Wrap with VL prefix + vl_sd = {} + for k, v in sd.items(): + vl_sd[f"language_model.{k}"] = v + vl_sd["visual.encoder.weight"] = torch.zeros(10) # should be skipped + vl_sd["mtp.something"] = torch.zeros(5) # should be skipped + + result = NeuronQwen35ForCausalLM.convert_hf_to_neuron_state_dict(vl_sd, config) + self.assertNotIn("visual.encoder.weight", result) + self.assertNotIn("mtp.something", result) + self.assertIn("norm.weight", result) + + def test_model_language_model_prefix_stripped(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + + vl_sd = {} + for k, v in sd.items(): + vl_sd[f"model.language_model.{k}"] = v + + result = NeuronQwen35ForCausalLM.convert_hf_to_neuron_state_dict(vl_sd, config) + self.assertIn("norm.weight", result) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.6-27B/vllm/README.md b/contrib/models/Qwen3.6-27B/vllm/README.md new file mode 100644 index 00000000..54d6f904 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/README.md @@ -0,0 +1,262 @@ +# Qwen3.6-27B vLLM on Neuron + +This folder contains the first-pass vLLM integration helpers for the +Qwen3.6-27B contrib model. + +The current goal is **vLLM serving through the Neuron/NxDI plugin** for the +validated Qwen3.6 artifact, including long prompts through vLLM's native +chunked-prefill scheduler. + +## Which vLLM Neuron Package? + +Use the vLLM-on-Neuron environment that matches the installed Neuron SDK first. +For SDK 2.29, the AWS Neuron guide lists the NxDI/vLLM plugin stack as +`vLLM 0.16.0` with plugin version `0.5.0`. The +`vllm-project/vllm-neuron` repository is useful source/reference material, but +its README currently describes a beta plugin path tied to older `vLLM 0.11.0` +and SDK 2.26.1. Do not downgrade the working SDK 2.29 environment just to use +that repository. 
+ +On a DLAMI, prefer the preinstalled vLLM/Neuron environment when available. If +the instance does not have one, install the Neuron-compatible vLLM plugin/fork +using the current AWS guide, then run the contrib registry patch below. + +## What Works First + +- Register the contrib `qwen3_5` text model with the NxDI model registry inside + the vLLM environment. +- Start vLLM with `VLLM_PLUGINS=neuron`. +- Load a small-context model or a precompiled artifact with + `NEURON_COMPILED_ARTIFACTS`. +- Run a short OpenAI-compatible smoke prompt. + +## Chunked Prefill Note + +The Neuron plugin disables vLLM chunked prefill by default and installs a custom +continuous-batching scheduler. For this Qwen3.6 artifact we need vLLM's native +chunked-prefill scheduler so prompts longer than the 512-token context graph are +fed to the precompiled model in 512-token chunks. The launcher sets +`DISABLE_NEURON_CUSTOM_SCHEDULER=1` when `--enable-vllm-chunked-prefill` is +passed. It also launches with `--generation-config vllm` so model +`generation_config.json` does not silently override deterministic sampling +defaults. + +## Install The Contrib Registry Patch + +Activate the vLLM/Neuron environment on the instance, then run: + +```bash +cd /home/ubuntu/inferentia-gdn +contrib/models/Qwen3.6-27B/vllm/install_qwen36_vllm.sh +``` + +If your vLLM environment is not in a standard location: + +```bash +contrib/models/Qwen3.6-27B/vllm/install_qwen36_vllm.sh \ + /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference +``` + +The installer only patches the active environment. It does not modify core repo +files. 
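
The registry patch is written to be safe to re-run: `patch_nxdi_registry.py` appends a marker-delimited block to the installed NxDI constants module and skips the write when the begin marker is already present. A minimal sketch of that pattern on an in-memory string follows; `append_once` and the sample strings are illustrative names, not part of the contrib.

```python
MARKER_BEGIN = "# QWEN36_CONTRIB_VLLM_REGISTER_BEGIN"
MARKER_END = "# QWEN36_CONTRIB_VLLM_REGISTER_END"


def append_once(text: str, block_body: str) -> tuple[str, bool]:
    """Append a marker-delimited block unless the begin marker already exists.

    Returns the (possibly unchanged) text and whether anything was appended.
    `append_once` is a hypothetical helper used only for illustration.
    """
    if MARKER_BEGIN in text:
        return text, False
    block = f"\n\n{MARKER_BEGIN}\n{block_body}\n{MARKER_END}\n"
    return text.rstrip() + block, True


# Stand-in for the contents of the installed constants module.
original = "MODEL_TYPES = {}\n"
patched, changed = append_once(original, "# registration code goes here")

# A second run finds the marker and leaves the text untouched.
repatched, changed_again = append_once(patched, "# registration code goes here")
```

Keying the check on a marker comment rather than on the registration code itself means the block can change between releases without a re-run stacking a second copy.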
+ +## Start vLLM + +Small-context compile/load path: + +```bash +contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --max-model-len 512 \ + --port 8000 +``` + +Precompiled artifact path: + +```bash +contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-bucket 512 \ + --port 8000 +``` + +Long-prompt precompiled artifact path: + +```bash +contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-bucket 512 \ + --block-size 256 \ + --enable-vllm-chunked-prefill \ + --port 8000 +``` + +Native vLLM prefix-cache experiment: + +```bash +contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-bucket 512 \ + --block-size 256 \ + --enable-vllm-chunked-prefill \ + --enable-prefix-caching \ + --mamba-cache-mode align \ + --port 8000 +``` + +Treat this as an experiment, not a production mode, until validation passes. +Standard vLLM APC reuses attention KV blocks; Qwen3.6 also needs DeltaNet +recurrent state and conv state at block boundaries. If native APC does not +produce exact greedy matches and a clear warm-hit speedup, the next step is a +hybrid APC path that caches those GDN states alongside attention KV. 
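
A toy recurrence shows why attention-only prefix caching is not enough for a hybrid model. The sketch below is not the DeltaNet kernel or vLLM's cache: it uses a scalar decayed sum as a stand-in for recurrent state and a hypothetical `snapshots` dict keyed by block boundary. It shows that resuming from a cached prefix is exact only when the state at that boundary was saved.

```python
from fractions import Fraction

BLOCK_SIZE = 4  # toy block size; the launcher examples above use 256


def step(state: Fraction, token: int) -> Fraction:
    """Toy recurrent update: decay the state, then add the new token."""
    return state * Fraction(1, 2) + token


def run(tokens, state=Fraction(0), start=0, snapshots=None):
    """Scan tokens from `start`, optionally saving state at block boundaries."""
    for pos in range(start, len(tokens)):
        state = step(state, tokens[pos])
        if snapshots is not None and (pos + 1) % BLOCK_SIZE == 0:
            snapshots[pos + 1] = state
    return state


prefix = [3, 1, 4, 1, 5, 9, 2, 6]  # two full blocks shared by both requests
request_a = prefix + [5, 3]
request_b = prefix + [8, 9, 7]

snapshots = {}
run(request_a, snapshots=snapshots)  # cold request fills the cache

# Warm request: restore the state saved at the end of the shared prefix.
cached_len = len(prefix)
warm = run(request_b, state=snapshots[cached_len], start=cached_len)
cold = run(request_b)

# Skipping the prefix without restoring its state gives a different answer.
stale = run(request_b, start=cached_len)
```

Here `warm` equals `cold`, while `stale` does not. This is the failure mode the exact-match checks are meant to catch.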
+ +Production chat proxy: + +```bash +contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-bucket 512 \ + --block-size 256 \ + --enable-vllm-chunked-prefill \ + --port 8001 +``` + +Then expose the guarded OpenAI-compatible endpoint on port 8000: + +```bash +python contrib/models/Qwen3.6-27B/vllm/qwen36_chat_proxy.py \ + --backend-url http://127.0.0.1:8001 \ + --port 8000 +``` + +The proxy forces `chat_template_kwargs={"enable_thinking": false}` for +`/v1/chat/completions` by default. It rejects raw `/v1/completions` because raw +prompts bypass the Qwen chat template and can pollute the hybrid model state. +It also hoists `system` and `developer` messages to a single leading `system` +message because the Qwen chat template rejects system messages that appear later +in the conversation. Use `--allow-thinking` or `--allow-completions` only for +explicit debugging. + +Offline long-prompt smoke: + +```bash +python contrib/models/Qwen3.6-27B/vllm/run_offline_inference.py \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-bucket 512 \ + --block-size 256 \ + --enable-vllm-chunked-prefill \ + --chat \ + --prompt "$(python - <<'PY' +print('Summarize this document in one paragraph. 
' + 'Neuron inference ' * 700) +PY +)" +``` + +Offline token-exact prefix-cache validation: + +```bash +python validation_scripts/qwen36_vllm_prefix_cache_offline.py \ + --repo-root /home/ubuntu/inferentia-gdn \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-bucket 512 \ + --block-size 256 \ + --enable-vllm-chunked-prefill \ + --mamba-cache-mode align +``` + +Offline partial-prefix validation: + +```bash +python validation_scripts/qwen36_vllm_prefix_cache_partial_offline.py \ + --repo-root /home/ubuntu/inferentia-gdn \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-bucket 512 \ + --block-size 256 \ + --enable-vllm-chunked-prefill \ + --mamba-cache-mode align +``` + +Server-side prefix-cache validation through the guarded proxy: + +```bash +python validation_scripts/qwen36_prefix_cache_validation.py \ + --base-url http://127.0.0.1:8000 \ + --model qwen3.6-27b-neuron-128k-fp8-mlp +``` + +The acceptance gate is strict: repeated greedy calls must produce identical +output, and warm-hit latency should be materially lower than cold-fill latency. +For hybrid Qwen3.6, prefix-cache validation is not complete until the GDN +recurrent/conv state behavior is proven, not just attention KV cache hits. 
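
The gate can be expressed as a small check. This is a sketch, not the logic of the validation scripts; `passes_prefix_cache_gate` and the `min_speedup` threshold are assumptions, since the text above only says warm hits should be materially faster.

```python
def passes_prefix_cache_gate(cold_output, warm_output, cold_seconds,
                             warm_seconds, min_speedup=2.0):
    """Return (passed, speedup) for one cold/warm pair of greedy requests.

    `min_speedup` is an assumed threshold, not a value from the contrib.
    """
    if warm_seconds <= 0:
        raise ValueError("warm_seconds must be positive")
    speedup = cold_seconds / warm_seconds
    exact_match = cold_output == warm_output
    return exact_match and speedup >= min_speedup, speedup


# Placeholder outputs and illustrative timings.
ok, speedup = passes_prefix_cache_gate("42", "42", 20.0, 2.0)

# A fast warm hit with different text must still fail the gate.
mismatch, _ = passes_prefix_cache_gate("42", "43", 20.0, 2.0)
```

Comparing token IDs instead of decoded text, where the harness exposes them, gives a stricter version of the same check.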
+ +Native APC validation run on Trn2 with the FP8 128K artifact: + +- server exact-repeat, `~10.8K` prompt tokens: `26.68s` cold to `1.67s` warm, + `16.0x` speedup, exact greedy text match; +- offline exact-repeat, token IDs exposed: `26.19s` cold to `2.38s` warm, + `11.0x` speedup, exact greedy token-ID match; +- offline partial-prefix reuse, token IDs exposed: `25.52s` no-cache target to + `1.70s` APC target after a different shared-prefix warmup request, `15.0x` + speedup, exact greedy token-ID match; +- server hardening, exact repeat: `25.38s` cold to `1.55s` warm, `16.35x` + speedup, exact text match; +- server hardening, cross-prefix reuse after unrelated prefix: `25.17s` cold to + `1.36s` warm, exact text match; +- shared-prefix concurrency at 1/2/4 requests returned all requested markers + exactly; the artifact still queues because it is compiled for `max_num_seqs=1`. + +Validation run on Trn2 with the FP8 128K artifact: + +- state-reset artifact: `/opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1`; +- OpenAI-compatible `/v1/chat/completions` behind the proxy passes focused + quality checks without callers passing `chat_template_kwargs`; +- repeated short-after-long validation passes after 32K and 64K requests, + confirming DeltaNet recurrent/conv state is reset for new requests; +- 32K and 64K needle retrieval prompts return all expected codes; +- measured prefill is `404-428 tok/s` from 512 through 64K prompt tokens; +- measured decode is `26.3-26.6 tok/s`; +- peak Neuron device memory is about `53.25 GB` decimal for the 64K eval. + +Raw `/v1/completions` prompts are not chat-templated and can pollute the hybrid +state if sent directly to the backend. Keep the backend private and expose the +proxy on the public port for production calls. 
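
A client call through the proxy might look like the following sketch. It uses only the standard library. The helper names are hypothetical, the model name is the one passed to the server-side validation script above (assumed to be the served model name), and the port matches the proxy example. Callers do not need to send `chat_template_kwargs`, because the proxy injects it.

```python
import json
import urllib.request

PROXY_URL = "http://127.0.0.1:8000"  # public proxy port from the example above
MODEL = "qwen3.6-27b-neuron-128k-fp8-mlp"  # assumed served model name


def build_chat_request(prompt: str, max_tokens: int = 64) -> urllib.request.Request:
    """Build a POST to /v1/chat/completions with greedy sampling."""
    payload = {
        "model": MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0,
        "max_tokens": max_tokens,
    }
    return urllib.request.Request(
        PROXY_URL + "/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def chat(prompt: str) -> str:
    """Send the request and return the first choice's text (needs a live proxy)."""
    with urllib.request.urlopen(build_chat_request(prompt)) as resp:
        body = json.load(resp)
    return body["choices"][0]["message"]["content"]


# Building the request does not touch the network; call chat(...) to send it.
request = build_chat_request("What is 17 * 23? Answer with the number only.")
```

Pointing clients at the proxy port rather than the backend port keeps every production request on the chat-templated, non-thinking path.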
+ +## Offline Smoke + +```bash +python contrib/models/Qwen3.6-27B/vllm/run_offline_inference.py \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-bucket 512 \ + --chat \ + --prompt "What is 17 * 23? Answer with the number only." +``` + +## Next Milestone + +Validate native vLLM prefix caching with the token-exact offline harness. If it +does not pass, implement hybrid APC by saving/restoring DeltaNet recurrent and +conv state at block boundaries. diff --git a/contrib/models/Qwen3.6-27B/vllm/hf_qwen35_config.py b/contrib/models/Qwen3.6-27B/vllm/hf_qwen35_config.py new file mode 100644 index 00000000..f764048a --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/hf_qwen35_config.py @@ -0,0 +1,68 @@ +"""Minimal Hugging Face config registration for Qwen3.5/Qwen3.6 vLLM smoke. + +The Neuron vLLM environment can lag upstream Transformers. vLLM validates the +HF config before the NxDI model registry gets a chance to instantiate the +contrib model, so register a permissive config class for the new model_type. 
+""" + +from __future__ import annotations + +from transformers import AutoConfig, PretrainedConfig + + +class Qwen35TextConfig(PretrainedConfig): + model_type = "qwen3_5_text" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +class Qwen35Config(PretrainedConfig): + model_type = "qwen3_5" + sub_configs = {"text_config": Qwen35TextConfig} + + def __init__(self, text_config=None, **kwargs): + if isinstance(text_config, dict): + text_config = Qwen35TextConfig(**text_config) + self.text_config = text_config + if text_config is not None: + for name, value in text_config.to_dict().items(): + if name not in {"architectures", "model_type"}: + kwargs.setdefault(name, value) + rope_parameters = getattr(text_config, "rope_parameters", None) + if isinstance(rope_parameters, dict): + kwargs.setdefault("rope_theta", rope_parameters.get("rope_theta")) + super().__init__(**kwargs) + + +def _is_registered(model_type: str) -> bool: + try: + AutoConfig.for_model(model_type) + except ValueError: + return False + return True + + +def register_qwen35_hf_config() -> None: + if not _is_registered(Qwen35TextConfig.model_type): + AutoConfig.register(Qwen35TextConfig.model_type, Qwen35TextConfig) + if not _is_registered(Qwen35Config.model_type): + AutoConfig.register(Qwen35Config.model_type, Qwen35Config) + + +def register_qwen35_vllm_architecture() -> None: + try: + from vllm.model_executor.models import ModelRegistry + except Exception: + return + + supported_archs = ModelRegistry.get_supported_archs() + qwen3_impl = "vllm.model_executor.models.qwen3:Qwen3ForCausalLM" + for arch in ("Qwen3_5ForConditionalGeneration", "Qwen3_5ForCausalLM"): + if arch not in supported_archs: + ModelRegistry.register_model(arch, qwen3_impl) + + +def register_qwen35_config() -> None: + register_qwen35_hf_config() + register_qwen35_vllm_architecture() diff --git a/contrib/models/Qwen3.6-27B/vllm/install_qwen36_vllm.sh b/contrib/models/Qwen3.6-27B/vllm/install_qwen36_vllm.sh new file mode 
100755 index 00000000..f21536eb --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/install_qwen36_vllm.sh @@ -0,0 +1,61 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CONTRIB_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +if [[ $# -gt 0 ]]; then + VENV="$1" +else + VENV="" + for candidate in \ + /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference \ + /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16 \ + /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13 \ + /opt/aws_neuronx_venv_pytorch_inference_vllm_0_12 \ + /opt/aws_neuronx_venv_pytorch_inference_vllm_0_11 + do + if [[ -x "${candidate}/bin/python" ]]; then + VENV="${candidate}" + break + fi + done +fi + +if [[ -z "${VENV}" || ! -x "${VENV}/bin/python" ]]; then + echo "ERROR: Could not find a vLLM/Neuron Python environment." >&2 + echo "Usage: $0 /path/to/venv" >&2 + exit 1 +fi + +PYTHON="${VENV}/bin/python" +export PATH="${VENV}/bin:${PATH}" +export PYTHONPATH="${CONTRIB_ROOT}:${PYTHONPATH:-}" + +echo "vLLM/Neuron env : ${VENV}" +echo "Contrib root : ${CONTRIB_ROOT}" + +"${PYTHON}" "${SCRIPT_DIR}/patch_nxdi_registry.py" --contrib-root "${CONTRIB_ROOT}" + +"${PYTHON}" - <<'PY' +import importlib.util +from neuronx_distributed_inference.utils.constants import MODEL_TYPES + +if importlib.util.find_spec("vllm") is None: + raise RuntimeError("vLLM is not installed in this environment") + +if importlib.util.find_spec("vllm_neuron") is None: + print( + "WARNING: vllm_neuron package was not found. If this environment uses " + "an AWS vLLM fork with built-in Neuron support this may be fine; " + "otherwise install the Neuron vLLM plugin that matches this SDK.", + ) + +for key in ("qwen3_5", "qwen3_5_text"): + assert key in MODEL_TYPES, f"{key} missing from MODEL_TYPES" + assert "causal-lm" in MODEL_TYPES[key], f"{key}/causal-lm missing" +print("Qwen3.6 vLLM registry verification OK") +PY + +echo "Installation complete." 
+echo "Remember to set PYTHONPATH=${CONTRIB_ROOT} when starting vLLM." diff --git a/contrib/models/Qwen3.6-27B/vllm/patch_nxdi_registry.py b/contrib/models/Qwen3.6-27B/vllm/patch_nxdi_registry.py new file mode 100644 index 00000000..91fe41c5 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/patch_nxdi_registry.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +"""Register Qwen3.6 contrib model in the installed NxDI registry. + +This patches the active Python environment, not the repository checkout. The +runtime still needs PYTHONPATH to include contrib/models/Qwen3.6-27B so that +`src.modeling_qwen35` can be imported by the vLLM process. +""" + +from __future__ import annotations + +import argparse +from pathlib import Path + + +MARKER_BEGIN = "# QWEN36_CONTRIB_VLLM_REGISTER_BEGIN" +MARKER_END = "# QWEN36_CONTRIB_VLLM_REGISTER_END" + +REGISTRATION_BLOCK = f""" + +{MARKER_BEGIN} +# Registered by contrib/models/Qwen3.6-27B/vllm/install_qwen36_vllm.sh. +# Requires PYTHONPATH to include the Qwen3.6-27B contrib directory at runtime. 
+try: + from src.modeling_qwen35 import ( + NeuronQwen35ForCausalLM as _Qwen36ContribForCausalLM, + ) +except Exception: + _Qwen36ContribForCausalLM = None + +if _Qwen36ContribForCausalLM is not None: + MODEL_TYPES.setdefault("qwen3_5", {{}})["causal-lm"] = _Qwen36ContribForCausalLM + MODEL_TYPES.setdefault("qwen3_5_text", {{}})["causal-lm"] = _Qwen36ContribForCausalLM +{MARKER_END} +""" + + +def _constants_path() -> Path: + import neuronx_distributed_inference.utils.constants as constants # noqa: WPS433 + + return Path(constants.__file__).resolve() + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--contrib-root", required=True) + parser.add_argument("--dry-run", action="store_true") + args = parser.parse_args() + + contrib_root = Path(args.contrib_root).expanduser().resolve() + if not (contrib_root / "src" / "modeling_qwen35.py").exists(): + raise FileNotFoundError(f"Qwen3.6 contrib root looks invalid: {contrib_root}") + + path = _constants_path() + text = path.read_text() + if MARKER_BEGIN in text: + print(f"Registry already patched: {path}") + return 0 + + patched = text.rstrip() + REGISTRATION_BLOCK + "\n" + print(f"Patch target: {path}") + if args.dry_run: + print("Dry run; no files written") + return 0 + + path.write_text(patched) + print("Patched NxDI MODEL_TYPES with qwen3_5 and qwen3_5_text") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/contrib/models/Qwen3.6-27B/vllm/qwen36_chat_proxy.py b/contrib/models/Qwen3.6-27B/vllm/qwen36_chat_proxy.py new file mode 100644 index 00000000..d8bd0bda --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/qwen36_chat_proxy.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +"""Small OpenAI-compatible guard proxy for Qwen3.6 vLLM serving. + +The upstream Qwen3.6 chat template defaults to thinking mode. 
For this Neuron +artifact the production-safe chat path is non-thinking mode, so this proxy +injects ``chat_template_kwargs={"enable_thinking": false}`` for chat requests. +It also blocks raw completions by default because they are not chat-templated. +""" + +from __future__ import annotations + +import argparse +import json +import os +import urllib.error +import urllib.request +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from typing import Any + + +def _json_response(handler: BaseHTTPRequestHandler, status: int, payload: dict[str, Any]): + body = json.dumps(payload).encode("utf-8") + handler.send_response(status) + handler.send_header("Content-Type", "application/json") + handler.send_header("Content-Length", str(len(body))) + handler.end_headers() + handler.wfile.write(body) + + +def _message_text(content: Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, dict): + text = item.get("text") + if isinstance(text, str): + parts.append(text) + elif isinstance(item, str): + parts.append(item) + return "\n".join(parts) + return str(content) + + +def _normalize_messages_for_qwen(messages: Any) -> Any: + """Make common OpenAI message layouts acceptable to the Qwen chat template.""" + if not isinstance(messages, list): + return messages + + system_parts: list[str] = [] + normal_messages: list[Any] = [] + for message in messages: + if not isinstance(message, dict): + normal_messages.append(message) + continue + + role = message.get("role") + if role in {"system", "developer"}: + system_parts.append(_message_text(message.get("content", ""))) + else: + normal_messages.append(message) + + if not system_parts: + return messages + + system_message = { + "role": "system", + "content": "\n\n".join(part for part in system_parts if part), + } + return [system_message, *normal_messages] + + +class Qwen36ProxyHandler(BaseHTTPRequestHandler): + 
backend_url: str = "http://127.0.0.1:8001"
+    force_disable_thinking: bool = True
+    allow_completions: bool = False
+
+    def log_message(self, fmt: str, *args):  # noqa: D401
+        print(f"{self.address_string()} - {fmt % args}", flush=True)
+
+    def _forward(self, method: str, body: bytes | None = None):
+        headers = {
+            key: value
+            for key, value in self.headers.items()
+            if key.lower() not in {"host", "content-length", "connection"}
+        }
+        url = self.backend_url.rstrip("/") + self.path
+        req = urllib.request.Request(url, data=body, headers=headers, method=method)
+        try:
+            with urllib.request.urlopen(req, timeout=None) as resp:
+                response_body = resp.read()
+                self.send_response(resp.status)
+                for key, value in resp.headers.items():
+                    if key.lower() in {"transfer-encoding", "connection"}:
+                        continue
+                    self.send_header(key, value)
+                self.end_headers()
+                self.wfile.write(response_body)
+        except urllib.error.HTTPError as exc:
+            error_body = exc.read()
+            self.send_response(exc.code)
+            for key, value in exc.headers.items():
+                if key.lower() in {"transfer-encoding", "connection"}:
+                    continue
+                self.send_header(key, value)
+            self.end_headers()
+            self.wfile.write(error_body)
+
+    def do_GET(self):  # noqa: N802
+        self._forward("GET")
+
+    def do_POST(self):  # noqa: N802
+        length = int(self.headers.get("Content-Length", "0") or "0")
+        raw_body = self.rfile.read(length) if length else b""
+
+        if self.path == "/v1/completions" and not self.allow_completions:
+            _json_response(
+                self,
+                400,
+                {
+                    "error": {
+                        "message": (
+                            "Raw /v1/completions is disabled for Qwen3.6. "
+                            "Use /v1/chat/completions so the Qwen chat template "
+                            "and non-thinking mode are applied."
+                        ),
+                        "type": "invalid_request_error",
+                        "code": "qwen36_chat_required",
+                    }
+                },
+            )
+            return
+
+        if self.path == "/v1/chat/completions" and raw_body:
+            try:
+                payload = json.loads(raw_body)
+            except json.JSONDecodeError:
+                self._forward("POST", raw_body)
+                return
+
+            template_kwargs = payload.get("chat_template_kwargs")
+            if not isinstance(template_kwargs, dict):
+                template_kwargs = {}
+            if self.force_disable_thinking:
+                template_kwargs["enable_thinking"] = False
+            else:
+                template_kwargs.setdefault("enable_thinking", False)
+            payload["chat_template_kwargs"] = template_kwargs
+            payload["messages"] = _normalize_messages_for_qwen(payload.get("messages"))
+            raw_body = json.dumps(payload).encode("utf-8")
+
+        self._forward("POST", raw_body)
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", default="0.0.0.0")
+    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--backend-url", default=os.getenv("QWEN36_BACKEND_URL", "http://127.0.0.1:8001"))
+    parser.add_argument("--allow-completions", action="store_true")
+    parser.add_argument("--allow-thinking", action="store_true")
+    args = parser.parse_args()
+
+    Qwen36ProxyHandler.backend_url = args.backend_url
+    Qwen36ProxyHandler.allow_completions = args.allow_completions
+    Qwen36ProxyHandler.force_disable_thinking = not args.allow_thinking
+
+    server = ThreadingHTTPServer((args.host, args.port), Qwen36ProxyHandler)
+    print(
+        "Qwen3.6 proxy listening on "
+        f"{args.host}:{args.port}, backend={args.backend_url}, "
+        f"allow_completions={args.allow_completions}, "
+        f"force_disable_thinking={not args.allow_thinking}",
+        flush=True,
+    )
+    server.serve_forever()
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/contrib/models/Qwen3.6-27B/vllm/run_offline_inference.py b/contrib/models/Qwen3.6-27B/vllm/run_offline_inference.py
new file mode 100644
index 00000000..8c0eb06f
--- /dev/null
+++
b/contrib/models/Qwen3.6-27B/vllm/run_offline_inference.py
@@ -0,0 +1,159 @@
+#!/usr/bin/env python3
+"""Offline vLLM smoke runner for Qwen3.6-27B on Neuron."""
+
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import sys
+import time
+from pathlib import Path
+
+
+def _contrib_root(repo_root: str | None) -> Path:
+    if repo_root:
+        return Path(repo_root).expanduser().resolve() / "contrib" / "models" / "Qwen3.6-27B"
+    return Path(__file__).resolve().parents[1]
+
+
+def _override_config(args: argparse.Namespace) -> dict:
+    neuron_config = {
+        "tp_degree": args.tensor_parallel_size,
+        "batch_size": args.max_num_seqs,
+        "ctx_batch_size": 1,
+        "tkg_batch_size": args.max_num_seqs,
+        "seq_len": args.seq_len,
+        "max_length": args.seq_len,
+        "max_context_length": args.cte_bucket,
+        "context_encoding_buckets": [args.cte_bucket],
+        "token_generation_buckets": [args.seq_len],
+        "enable_bucketing": False,
+        "logical_nc_config": args.logical_nc_config,
+        "torch_dtype": "bfloat16",
+        "save_sharded_checkpoint": True,
+    }
+    if args.enable_vllm_chunked_prefill:
+        neuron_config.update(
+            {
+                "is_block_kv_layout": True,
+                "chunked_prefill_config": {
+                    "max_num_seqs": args.max_num_seqs,
+                    "tkg_model_enabled": True,
+                    "kernel_q_tile_size": 128,
+                    "kernel_kv_tile_size": 1024,
+                },
+            }
+        )
+    return {
+        "max_prompt_length": args.cte_bucket,
+        "override_neuron_config": neuron_config,
+    }
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--repo-root", default=None)
+    parser.add_argument("--model-path", required=True)
+    parser.add_argument("--compiled-artifacts", default=None)
+    parser.add_argument("--prompt", default="What is 17 * 23?
Answer with the number only.")
+    parser.add_argument("--chat", action="store_true")
+    parser.add_argument("--enable-vllm-chunked-prefill", action="store_true")
+    parser.add_argument("--enable-prefix-caching", action="store_true")
+    parser.add_argument("--mamba-cache-mode", default=None)
+    parser.add_argument("--mamba-cache-dtype", default=None)
+    parser.add_argument("--mamba-ssm-cache-dtype", default=None)
+    parser.add_argument("--max-tokens", type=int, default=64)
+    parser.add_argument("--temperature", type=float, default=0.0)
+    parser.add_argument("--top-k", type=int, default=1)
+    parser.add_argument("--tensor-parallel-size", type=int, default=4)
+    parser.add_argument("--logical-nc-config", type=int, default=2)
+    parser.add_argument("--max-num-seqs", type=int, default=1)
+    parser.add_argument("--max-model-len", type=int, default=512)
+    parser.add_argument("--seq-len", type=int, default=512)
+    parser.add_argument("--cte-bucket", type=int, default=512)
+    parser.add_argument("--block-size", type=int, default=256)
+    args = parser.parse_args()
+
+    contrib_root = _contrib_root(args.repo_root)
+    script_dir = Path(__file__).resolve().parent
+    sys.path.insert(0, str(script_dir))
+    sys.path.insert(0, str(contrib_root))
+    os.environ["PYTHONPATH"] = (
+        f"{script_dir}:{contrib_root}:{os.environ.get('PYTHONPATH', '')}"
+    )
+    os.environ.setdefault("VLLM_NEURON_FRAMEWORK", "neuronx-distributed-inference")
+    os.environ.setdefault("VLLM_PLUGINS", "neuron")
+    if args.enable_vllm_chunked_prefill:
+        os.environ["DISABLE_NEURON_CUSTOM_SCHEDULER"] = "1"
+    if args.compiled_artifacts:
+        os.environ["NEURON_COMPILED_ARTIFACTS"] = str(
+            Path(args.compiled_artifacts).expanduser().resolve()
+        )
+
+    from hf_qwen35_config import register_qwen35_config  # noqa: WPS433
+
+    register_qwen35_config()
+
+    from vllm import LLM, SamplingParams  # noqa: WPS433
+
+    prompt = args.prompt
+    if args.chat:
+        from transformers import AutoTokenizer  # noqa: WPS433
+
+        tokenizer = AutoTokenizer.from_pretrained(
+
args.model_path,
+            trust_remote_code=True,
+        )
+        prompt = tokenizer.apply_chat_template(
+            [{"role": "user", "content": args.prompt}],
+            tokenize=False,
+            add_generation_prompt=True,
+            enable_thinking=False,
+        )
+
+    additional_config = _override_config(args)
+    print("VLLM_QWEN36_CONFIG", json.dumps(additional_config, sort_keys=True), flush=True)
+
+    llm_kwargs = {
+        "model": str(Path(args.model_path).expanduser().resolve()),
+        "trust_remote_code": True,
+        "dtype": "bfloat16",
+        "tensor_parallel_size": args.tensor_parallel_size,
+        "max_num_seqs": args.max_num_seqs,
+        "max_model_len": args.max_model_len,
+        "enable_prefix_caching": args.enable_prefix_caching,
+        "enable_chunked_prefill": args.enable_vllm_chunked_prefill,
+        "additional_config": additional_config,
+    }
+    if args.mamba_cache_mode is not None:
+        llm_kwargs["mamba_cache_mode"] = args.mamba_cache_mode
+    if args.mamba_cache_dtype is not None:
+        llm_kwargs["mamba_cache_dtype"] = args.mamba_cache_dtype
+    if args.mamba_ssm_cache_dtype is not None:
+        llm_kwargs["mamba_ssm_cache_dtype"] = args.mamba_ssm_cache_dtype
+    if args.enable_vllm_chunked_prefill:
+        llm_kwargs["max_num_batched_tokens"] = args.cte_bucket
+        llm_kwargs["block_size"] = args.block_size
+    llm = LLM(**llm_kwargs)
+
+    sampling = SamplingParams(
+        temperature=args.temperature,
+        top_k=args.top_k,
+        max_tokens=args.max_tokens,
+    )
+    start = time.perf_counter()
+    outputs = llm.generate([prompt], sampling)
+    elapsed = time.perf_counter() - start
+    text = outputs[0].outputs[0].text
+    token_ids = outputs[0].outputs[0].token_ids
+
+    print("PROMPT", prompt)
+    print("OUTPUT", text)
+    print("TOKENS", list(token_ids))
+    print("ELAPSED_SECONDS", f"{elapsed:.3f}")
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/contrib/models/Qwen3.6-27B/vllm/serve_qwen36.py b/contrib/models/Qwen3.6-27B/vllm/serve_qwen36.py
new file mode 100644
index 00000000..85e12ef6
--- /dev/null
+++ b/contrib/models/Qwen3.6-27B/vllm/serve_qwen36.py
@@ -0,0
+1,21 @@
+#!/usr/bin/env python3
+"""vLLM CLI wrapper that registers Qwen3.6 aliases before validation."""
+
+from __future__ import annotations
+
+import sys
+
+from hf_qwen35_config import register_qwen35_config
+
+
+def main() -> int:
+    register_qwen35_config()
+
+    from vllm.entrypoints.cli.main import main as vllm_main
+
+    sys.argv = ["vllm", "serve", *sys.argv[1:]]
+    return int(vllm_main() or 0)
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/contrib/models/Qwen3.6-27B/vllm/sitecustomize.py b/contrib/models/Qwen3.6-27B/vllm/sitecustomize.py
new file mode 100644
index 00000000..dcec3056
--- /dev/null
+++ b/contrib/models/Qwen3.6-27B/vllm/sitecustomize.py
@@ -0,0 +1,9 @@
+"""Auto-register Qwen3.5/Qwen3.6 HF config when this folder is on PYTHONPATH.
+
+Do not import vLLM here. Neuron helper commands such as libneuronpjrt-path run
+inside Python subprocesses and expect clean stdout.
+"""
+
+from hf_qwen35_config import register_qwen35_hf_config
+
+register_qwen35_hf_config()
diff --git a/contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh b/contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh
new file mode 100755
index 00000000..46342690
--- /dev/null
+++ b/contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh
@@ -0,0 +1,147 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+MODEL_PATH=""
+COMPILED_ARTIFACTS=""
+MAX_MODEL_LEN="512"
+SEQ_LEN="512"
+CTE_BUCKET="512"
+TP_DEGREE="4"
+LNC="2"
+MAX_NUM_SEQS="1"
+PORT="8000"
+HOST="0.0.0.0"
+ENABLE_CHUNKED_PREFILL="0"
+ENABLE_PREFIX_CACHING="0"
+MAMBA_CACHE_MODE=""
+MAMBA_CACHE_DTYPE=""
+MAMBA_SSM_CACHE_DTYPE=""
+BLOCK_SIZE=""
+
+while [[ $# -gt 0 ]]; do
+  case "$1" in
+    --model-path) MODEL_PATH="$2"; shift 2 ;;
+    --compiled-artifacts) COMPILED_ARTIFACTS="$2"; shift 2 ;;
+    --max-model-len) MAX_MODEL_LEN="$2"; shift 2 ;;
+    --seq-len) SEQ_LEN="$2"; shift 2 ;;
+    --cte-bucket) CTE_BUCKET="$2"; shift 2 ;;
+    --tensor-parallel-size) TP_DEGREE="$2"; shift 2 ;;
+    --logical-nc-config) LNC="$2"; shift 2 ;;
+
--max-num-seqs) MAX_NUM_SEQS="$2"; shift 2 ;;
+    --enable-vllm-chunked-prefill) ENABLE_CHUNKED_PREFILL="1"; shift ;;
+    --enable-prefix-caching) ENABLE_PREFIX_CACHING="1"; shift ;;
+    --disable-prefix-caching|--no-enable-prefix-caching) ENABLE_PREFIX_CACHING="0"; shift ;;
+    --mamba-cache-mode) MAMBA_CACHE_MODE="$2"; shift 2 ;;
+    --mamba-cache-dtype) MAMBA_CACHE_DTYPE="$2"; shift 2 ;;
+    --mamba-ssm-cache-dtype) MAMBA_SSM_CACHE_DTYPE="$2"; shift 2 ;;
+    --block-size) BLOCK_SIZE="$2"; shift 2 ;;
+    --host) HOST="$2"; shift 2 ;;
+    --port) PORT="$2"; shift 2 ;;
+    *) echo "Unknown argument: $1" >&2; exit 2 ;;
+  esac
+done
+
+if [[ -z "${MODEL_PATH}" ]]; then
+  echo "ERROR: --model-path is required" >&2
+  exit 2
+fi
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+CONTRIB_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
+export PYTHONPATH="${SCRIPT_DIR}:${CONTRIB_ROOT}:${PYTHONPATH:-}"
+export VLLM_NEURON_FRAMEWORK="neuronx-distributed-inference"
+export VLLM_PLUGINS="${VLLM_PLUGINS:-neuron}"
+
+if [[ -n "${COMPILED_ARTIFACTS}" ]]; then
+  export NEURON_COMPILED_ARTIFACTS="${COMPILED_ARTIFACTS}"
+fi
+if [[ -z "${BLOCK_SIZE}" ]]; then
+  BLOCK_SIZE="256"
+fi
+if [[ "${ENABLE_CHUNKED_PREFILL}" == "1" ]]; then
+  export DISABLE_NEURON_CUSTOM_SCHEDULER="1"
+fi
+
+ADDITIONAL_CONFIG="$(
+  python3 - <