forked from zhongkaifu/TensorSharp
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDiffusionGemmaModel.cs
More file actions
2494 lines (2277 loc) · 133 KB
/
Copy pathDiffusionGemmaModel.cs
File metadata and controls
2494 lines (2277 loc) · 133 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// Copyright (c) Zhongkai Fu. All rights reserved.
// https://github.com/zhongkaifu/TensorSharp
//
// This file is part of TensorSharp.
//
// TensorSharp is licensed under the BSD-3-Clause license found in the LICENSE file in the root directory of this source tree.
//
// TensorSharp is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD-3-Clause License for more details.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Numerics.Tensors;
using System.Runtime.InteropServices;
using System.Threading.Tasks;
using TensorSharp;
using TensorSharp.GGML;
using TensorSharp.MLX;
namespace TensorSharp.Models
{
/// <summary>
/// DiffusionGemma (architecture key <c>diffusion-gemma</c>) — a block text-diffusion
/// Mixture-of-Experts language model built on a Gemma-4 backbone.
///
/// This is fundamentally different from the autoregressive Gemma 4 model:
/// - A single <b>no-cache, bidirectional</b> forward is run over a concatenated
/// <c>[prompt | canvas]</c> sequence. Split point <c>P = n_tokens - canvas_length</c>.
/// - Prompt queries are causal (and sliding-window clipped on local layers) and never
/// attend to the canvas; canvas queries are bidirectional over prompt + canvas.
/// - Input embeddings are region-aware: prompt = <c>embed*sqrt(n_embd)</c>,
/// canvas = <c>rms_norm_noscale(embed*sqrt(n_embd) [+ self-conditioning])</c>.
/// - Each layer applies a region-aware per-layer scalar: prompt uses the encoder scalar,
/// canvas uses the decoder scalar.
///
/// Generation is performed by an iterative denoising loop (the EntropyBound sampler) which
/// lives in <see cref="DiffusionGemmaSampler"/>; this class exposes the per-step forward.
///
/// The Gemma-4 backbone shared with the AR model:
/// - Per-head Q/K RMSNorm, unweighted V RMSNorm, NeoX RoPE.
/// - Attention scale = 1.0 (the learnable Q/K norms absorb the 1/sqrt(d) factor).
/// - 5 sliding (local) layers : 1 global layer pattern (local head dim 256, global 512).
/// - Global layers omit the V projection (V == raw K projection).
/// - Dense gated-GELU MLP (shared expert) + 128-expert top-8 softmax MoE, summed per layer.
/// - Embedding scaling by sqrt(hidden), tied lm-head, final logit softcapping.
/// </summary>
public sealed class DiffusionGemmaModel : ModelBase
{
// ---- architecture configuration ----
// Read from GGUF metadata by ParseDiffusionConfig during construction. _hasVProj is allocated
// there but populated by LoadLayerScalars; the RoPE tables further down are built by PrecomputeRoPE.
private bool[] _isLocal; // per-layer: true = sliding-window (local), false = global; local when the GGUF pattern is missing/short
private int[] _kvHeads; // per-layer KV head count (falls back to Config.NumKVHeads)
private int[] _headDim; // per-layer head dim: _localHeadDim on local layers, _globalHeadDim on global ones
private bool[] _hasVProj; // per-layer: whether attn_v.weight exists (false on global layers, where V == raw K)
private int _localHeadDim; // {arch}.attention.key_length_swa (default 256)
private int _globalHeadDim; // {arch}.attention.key_length (default 512)
private int _slidingWindow; // {arch}.attention.sliding_window (default 1024); mirrored into Config.SlidingWindow
private float _ropeLocalBase; // {arch}.rope.freq_base_swa (default 10000)
private float _ropeGlobalBase; // taken from Config.RopeBase
private int _denseFfn; // dense MLP intermediate (feed_forward_length)
private int _expertFfn; // per-expert intermediate (expert_feed_forward_length)
private int _numExperts; // {arch}.expert_count
private int _numExpertsUsed; // {arch}.expert_used_count (top-K experts per token)
private float _finalLogitSoftcap; // {arch}.final_logit_softcapping; 0 disables the softcap
private int _canvasLength; // diffusion.canvas_length (default 256)
private int _maskTokenId; // tokenizer.ggml.mask_token_id (default 0)
// precomputed RoPE base frequencies per layer type (NeoX). Global folds in rope_freqs (NTK).
private float[] _ropeFreqsLocal; // length localHeadDim/2
private float[] _ropeFreqsGlobal; // length globalHeadDim/2; entry i is divided by rope_freqs[i] when present
// raw rope_freqs.weight (NOT folded), used by the on-device GGML freq-factors RoPE kernel.
private Tensor _ropeFreqsRawTensor; // null when the GGUF has no rope_freqs.weight
// ---- on-device glue-op caches (rebuilt only when N/P change, i.e. once per block) ----
// The int "...N"/"...P" fields record the sizes the cached tensors were built for; -1 means
// nothing has been built yet (NOTE(review): the rebuild logic lives outside this view — confirm).
// RoPE integer positions [N*numHeads], keyed by (numHeads, startPos).
private readonly Dictionary<long, Tensor> _ropePosCache = new();
private int _ropePosN = -1;
// global cos/sin tables (folded freqs) for the non-GGML RoPE paths (mlx fused / cuda raw).
private int _cosSinGlobalN = -1;
private float[] _cosGlobalHost, _sinGlobalHost;
private Tensor _cosGlobalTensor, _sinGlobalTensor;
// additive attention masks [maskHeads,N,N] (0 = attend, large-negative = masked), per layer type.
private int _maskN = -1, _maskP = -1;
private Tensor _maskLocal, _maskGlobal;
private const float MaskNeg = -1e30f; // large but finite "masked" value (presumably chosen to avoid inf-inf NaNs in softmax — confirm)
private const int NeoXRopeMode = 2; // presumably ggml's NeoX RoPE mode flag (GGML_ROPE_TYPE_NEOX) — confirm at the call site
// host cos/sin tables for the CPU raw-pointer RoPE path, cached by N.
private int _cosSinHostN = -1;
private float[] _cosLocalHost, _sinLocalHost;
// GPU backends keep the whole forward device-resident (Ops-based glue); CPU backends use the
// hand-tuned SIMD raw-pointer glue, which is faster on the CPU and avoids host<->device churn
// that doesn't exist there anyway.
// Set once in the constructor: true for GgmlMetal, GgmlCuda, Mlx and Cuda.
private bool _useDeviceGlue;
// ---- prompt-KV caching (PKV) — the headline llama.cpp/vLLM diffusion-gemma optimization ----
// The prompt's per-layer K/V do not depend on the canvas (prompt is causal and never attends to
// the canvas), so they are identical across every denoising step. PrefillPrompt computes them
// once per block and stores them here (head-first [kvHeads, P, hd]); DecodeCanvas then processes
// only the canvas each step and prepends the cached prompt K/V. GPU path only.
private Tensor[] _promptK, _promptV;
private int _promptLen = -1;
private bool _pkvEnabled; // only true when _useDeviceGlue; DIFFUSION_NO_PKV=1 disables it at construction
// decode-phase additive masks [maskHeads, C, P+C] (null when no masking is needed, i.e. when the
// prompt fits within the sliding window — the common case).
private int _decodeMaskP = -1, _decodeMaskC = -1;
private Tensor _decodeMaskLocal, _decodeMaskGlobal;
// per-layer output scalars (1.0 when the corresponding tensor is absent; filled by LoadLayerScalars)
private float[] _decScale; // layer_output_scale (canvas / decoder)
private float[] _encScale; // enc_layer_output_scale (prompt / encoder)
// fused-MoE stacked expert weights (one ggml_mul_mat_id dispatch per layer instead of 128 matmuls)
// Entries are null for layers whose stacked tensors are missing (see CacheStackedExpertWeights).
private StackedExpertWeights[] _stackedGateUp; // pre-fused [n_embd, 2*expertFfn, numExperts]
private StackedExpertWeights[] _stackedDown; // [expertFfn, n_embd, numExperts]
private float[][] _perExpertScale; // per-layer ffn_down_exps.scale (length numExperts); null entry when the layer has none
private bool _fusedMoeAvailable; // GGML backend AND both stacked tensors present on every layer
// MLX fused MoE via mlx_gather_qmm: one batched (expert-sorted) gather_qmm (gate_up) + GEGLU + one
// gather_qmm (down) over stacked MLX-affine experts — the MLX analogue of GGML's mul_mat_id.
// Correct (self-checked at runtime: cosine 1.0 vs the per-expert path), but measured NOT faster than
// the per-expert affine path for this model's MoE shape: 256 canvas tokens × top-8 over 128 experts
// leaves only ~16 tokens/expert, so the per-expert path already issues efficient grouped GEMMs and
// the MoE is small-GEMM compute-bound — gather_qmm's sort/unsort overhead offsets its dispatch
// savings (M=1 ~3900 ms, sorted ~3635 ms, per-expert affine ~3450 ms / decode step). Hence OFF by
// default. Opt in with TS_MLX_MOE_GATHER_QMM=1 (it can win for larger canvases / fewer experts where
// per-expert batches are bigger). The one-time self-check disables it permanently if it ever diverges.
private readonly bool _moeGatherQmmEnabled = Environment.GetEnvironmentVariable("TS_MLX_MOE_GATHER_QMM") == "1";
private bool _moeGatherQmmOk = true; // cleared if the one-time self-check diverges
private bool _moeGatherQmmChecked; // whether that self-check has already run
// MLX FULLY-ON-DEVICE fused MoE: routing (top-K) AND the expert FFN run on-device with NO host read,
// so the entire decode forward stays device-resident — the MLX port of GGML's fused single-graph
// decode. CORRECT (self-check cosine 1.0) but measured NOT faster (~4030 vs ~3450 ms/step for the
// per-expert affine path), because unlike GGML — where fusion into ONE C++ graph / Metal command
// buffer removes ALL per-op host/native overhead (its 4x: per-op 2650 -> fused 640 ms) — MLX still
// builds its graph op-by-op through the worker thread, so that overhead remains even when the forward
// collapses to one lazy eval; and a device-resident router forces the M=1 (GEMV) gather_qmm whose
// penalty offsets the (small) per-layer-sync saving. So OFF by default; opt in with
// TS_MLX_FUSED_DEVICE_MOE=1. Matching GGML on MLX needs mlx_compile of the whole forward or a custom
// Metal megakernel (i.e. reimplementing GGML's kernel) — not a cheap port. One-time self-check.
private readonly bool _moeFusedDeviceEnabled = Environment.GetEnvironmentVariable("TS_MLX_FUSED_DEVICE_MOE") == "1";
private bool _moeFusedDeviceOk = true; // cleared if the one-time self-check diverges
private bool _moeFusedDeviceChecked; // whether that self-check has already run
private Tensor _moeLhsGateConst, _moeLhsArangeConst; // host-built once: [N*K] gather_qmm lhs indices
private int _moeConstNK = -1; // N*K the constants above were built for (-1 = not built yet)
// cached ones tensors for unweighted RMSNorm, keyed by dim (created lazily by GetOnes)
private readonly Dictionary<int, Tensor> _onesByDim = new();
// self-conditioning (optional); on by default, DIFFUSION_NO_SC=1 disables it at construction
private bool _scEnabled;
// reusable output buffer for canvas logits (GGML readback path in ReadbackLogits; overwritten on every call)
private float[] _canvasLogits;
// timing: _swForward brackets ForwardCanvas; the _t* counters accumulate raw Stopwatch.GetTimestamp() tick deltas
private readonly Stopwatch _swForward = new();
private long _tEmbed, _tAttn, _tMoe, _tDense, _tLmHead, _tSc, _tRope, _tMoeRoute, _tMoeFfn;
private long _tScTopK, _tScDevice; // self-conditioning split: host top-K vs on-device gather+MLP
// When set (DIFFUSION_PROFILE=1) each timed section is drained before its timer stops so the
// per-section attribution is accurate (otherwise async GGML/Metal work inflates the section that
// happens to force the next host read).
private readonly bool _profile = Environment.GetEnvironmentVariable("DIFFUSION_PROFILE") == "1";
// DIFFUSION_STEPTIME=1; its consumers (and those of _stepT0) are outside this view.
private readonly bool _stepTime = Environment.GetEnvironmentVariable("DIFFUSION_STEPTIME") == "1";
private long _stepT0;
// Fused decode: transformer layers run as fused GGML graph dispatches instead of the per-op path.
// Per-LAYER fusion alone (attention + dense + MoE of one layer in one dispatch) is correct (output
// matches the per-op path) but doesn't speed up the decode, because `hidden` still round-trips host
// memory between layers, serialising them. The model-wide single graph (hidden stays on-device
// across all layers) is the real throughput win (~1.7x), so it is the default for GGML.
// Disable with DIFFUSION_NO_FUSED_DECODE=1. The lm_head tail fusion has its own switch
// (_fusedLmHeadTailDisabled below).
private readonly bool _fusedDecodeEnabled = Environment.GetEnvironmentVariable("DIFFUSION_NO_FUSED_DECODE") != "1";
private bool _fusedDecodeOk = true; // flipped false if the kernel ever rejects the layout
// Fused lm_head tail (output_norm + lm_head + softcap in one dispatch). Default ON for GGML; the
// separate small graph keeps it correct (unlike folding lm_head into the layer graph).
// Disable with DIFFUSION_NO_FUSED_LMHEAD_TAIL=1.
private readonly bool _fusedLmHeadTailDisabled = Environment.GetEnvironmentVariable("DIFFUSION_NO_FUSED_LMHEAD_TAIL") == "1";
private bool _fusedLmHeadTailOk = true;
// Segmented (per-layer) fused decode: run each layer as its own fused graph so layers whose
// weights are NOT device-resident stream through one bounded, reused staging buffer instead of
// all coexisting in VRAM. Selected automatically by PrepareCudaWeightResidency when the model
// does not fit; override with DIFFUSION_SEGMENTED_DECODE=1/0.
private bool _segmentedDecode;
// Reusable host buffer for the fused lm_head logits (268 MB at vocab 256K, C 256). A fresh
// array per step costs a large-object-heap allocation + zeroing every step; one pooled buffer
// is safe because each step's logits are fully consumed (sampling + self-conditioning read)
// before the next step's lm_head overwrites it — the same contract the batched scheduler
// already documents for the shared readback buffer. Allocated on the pinned object heap and
// cudaHostRegister'ed so the per-step 268 MB device->host logits download takes the fast DMA
// path instead of the pageable one.
private float[] _fusedLogitsBuffer;
// Host regions page-locked for fast per-step DMA (streamed weights + the logits buffer).
// cudaHostRegister'ed memory MUST be unregistered before it is unmapped/freed (the streamed
// weights live in the GGUF mmap), so Dispose undoes these.
private readonly List<IntPtr> _pinnedHostRegions = new();
// Streamed (non-resident) weights re-homed into page-locked PRIVATE host copies: Windows
// cannot cudaHostRegister file-mapped (GGUF mmap) pages, so the bytes are copied once into
// an owned allocation that CAN be page-locked — per-step uploads then run at DMA speed
// (~2x pageable mmap throughput). Maps original weight pointer -> pinned copy; the fused
// decode arg builder substitutes these. Freed (after unregistering) in Dispose.
// Only populated when DIFFUSION_PIN_STREAMED=1 (see PrepareCudaWeightResidency).
private readonly Dictionary<IntPtr, IntPtr> _pinnedStreamCopies = new();
// Upper bound, in bytes, on the logits tensor of a batched-decode lm_head. While [B*C, vocab]
// logits stay under this cap the lm_head is run once for the whole batch (a single weight read);
// above it each sequence gets its own [C, vocab] lm_head (one transient per sequence; re-reading
// the weight costs ~1ms, negligible next to a step's compute). The 300 MB default admits one
// canvas (~262 MB at vocab 256K, C 256) but pushes B>=2 onto the per-sequence path, which keeps
// batched decode inside the tight Metal headroom of a 24 GB machine running a 16.8 GB model.
// Override with DIFFUSION_LMHEAD_BATCH_CAP_MB.
private static readonly long LmHeadBatchByteCap = ResolveLmHeadBatchByteCap();

/// <summary>Reads DIFFUSION_LMHEAD_BATCH_CAP_MB (a positive integer number of megabytes; anything
/// else, including an unset variable, falls back to 300) and returns the cap in bytes.</summary>
private static long ResolveLmHeadBatchByteCap()
{
    const int DefaultCapMb = 300;
    string configured = Environment.GetEnvironmentVariable("DIFFUSION_LMHEAD_BATCH_CAP_MB");
    int megabytes = int.TryParse(configured, out int parsedMb) && parsedMb > 0 ? parsedMb : DefaultCapMb;
    return megabytes * 1024L * 1024L;
}
/// <summary>Profiling barrier: when DIFFUSION_PROFILE=1, reads a single element of
/// <paramref name="t"/> so pending async work producing it completes before the caller stops the
/// current section's timer. A no-op when profiling is off or <paramref name="t"/> is null.</summary>
private void ProfSync(Tensor t)
{
    if (!_profile || t == null)
    {
        return;
    }
    _ = t.GetElementsAsFloat(1);
}
/// <summary>Canvas positions per block (GGUF <c>diffusion.canvas_length</c>, default 256).</summary>
public int CanvasLength
{
    get { return _canvasLength; }
}
/// <summary>Mask token id (GGUF <c>tokenizer.ggml.mask_token_id</c>, default 0).</summary>
public int MaskTokenId
{
    get { return _maskTokenId; }
}
/// <summary>Vocabulary size from the model config (the width of each canvas logits row).</summary>
public int VocabSize
{
    get { return Config.VocabSize; }
}
/// <summary>Whether the canvas embedding adds the self-conditioning signal (it additionally needs
/// previous-step logits and a non-zero gate at call time).</summary>
public bool SelfConditioningEnabled
{
    get { return _scEnabled; }
    set { _scEnabled = value; }
}
/// <summary>True when prompt-KV caching is on; the sampler then calls <see cref="PrefillPrompt"/>
/// once per block and <see cref="DecodeCanvas"/> on every step. Caching requires the device-glue
/// (GPU) path, so setting this on a CPU backend leaves it false.</summary>
public bool SupportsPromptKvCache
{
    get { return _pkvEnabled; }
    set
    {
        // The cached prompt K/V live as device tensors, so CPU backends can never enable this.
        _pkvEnabled = _useDeviceGlue && value;
    }
}
public DiffusionGemmaModel(string ggufPath, BackendType backend) : base(ggufPath, backend)
{
    // --- metadata ---
    Config = new ModelConfig { Architecture = _gguf.GetString("general.architecture") };
    ParseBaseConfig();
    ParseDiffusionConfig();
    ParseTokenizer();

    // --- MLX K-quant handling ---
    // This model always denoises a multi-row canvas (C=256 per step). MLX's raw GGUF K-quant Metal
    // kernels target the rows==1 autoregressive case and run poorly here (~150 GFLOP/s), so K-quant
    // weights are materialized in MLX-native AFFINE form and every matmul (attention projections,
    // dense FFN, per-expert MoE, lm_head) goes through mlx_quantized_matmul instead. The repack is
    // lossless for K-quants (per-32 group scale=d*scaleByte, bias=-dmin*minByte is exactly ggml's
    // dequant). Opt out with TS_MLX_KQUANT_AFFINE=0.
    // The process-wide flag is deliberately left set after load: MLX weights are created lazily on
    // their first matmul and the per-matmul path keys off the created weight's Mode, so the flag has
    // to survive into the first forward. Side effect: another MLX model loaded later in the same
    // process also gets affine K-quants (lossless; its rows==1 decode might marginally prefer the
    // raw kernels). Irrelevant for a single-model process; TS_MLX_KQUANT_AFFINE=0 forces it off.
    bool preferAffineKQuant = backend == BackendType.Mlx
        && Environment.GetEnvironmentVariable("TS_MLX_KQUANT_AFFINE") != "0";
    if (preferAffineKQuant)
    {
        MlxQuantizedOps.PreferAffineKQuant = true;
    }

    // --- weights and derived tables ---
    LoadWeights();
    LoadLayerScalars();
    CacheStackedExpertWeights();
    PrecomputeRoPE();

    // --- execution strategy ---
    _useDeviceGlue = _backend switch
    {
        BackendType.GgmlMetal or BackendType.GgmlCuda or BackendType.Mlx or BackendType.Cuda => true,
        _ => false,
    };
    // Prompt-KV caching stores and concatenates K/V tensors, so it needs the device-resident path.
    bool pkvOptedOut = Environment.GetEnvironmentVariable("DIFFUSION_NO_PKV") == "1";
    _pkvEnabled = _useDeviceGlue && !pkvOptedOut;
    PrepareCudaWeightResidency();

    // Self-conditioning (as in the reference sampler) noticeably improves longer outputs and
    // converges in far fewer steps; its top-K soft embedding costs only ~tens of ms per step, so it
    // defaults on. DIFFUSION_NO_SC=1 turns it off.
    _scEnabled = Environment.GetEnvironmentVariable("DIFFUSION_NO_SC") != "1";

    string readySummary =
        $"DiffusionGemma ready: canvas_length={_canvasLength}, experts={_numExperts}/{_numExpertsUsed}, " +
        $"softcap={_finalLogitSoftcap}, mask_token={_maskTokenId}, self_conditioning={_scEnabled}, " +
        $"device_glue={_useDeviceGlue}, prompt_kv_cache={_pkvEnabled}";
    Console.WriteLine(readySummary);
}
/// <summary>Reads the diffusion-specific GGUF metadata (head dims, RoPE bases, FFN/expert sizes,
/// softcap, canvas length, mask token) and derives the per-layer local/global, head-dim and KV-head
/// tables. Layers without a sliding-window pattern entry are treated as local; layers without a
/// KV-head entry fall back to <c>Config.NumKVHeads</c>.</summary>
private void ParseDiffusionConfig()
{
    string arch = Config.Architecture;
    string Key(string suffix) => $"{arch}.{suffix}";

    _slidingWindow = (int)_gguf.GetUint32(Key("attention.sliding_window"), 1024);
    Config.SlidingWindow = _slidingWindow;
    _globalHeadDim = (int)_gguf.GetUint32(Key("attention.key_length"), 512);
    _localHeadDim = (int)_gguf.GetUint32(Key("attention.key_length_swa"), 256);
    _ropeGlobalBase = Config.RopeBase;
    _ropeLocalBase = _gguf.GetFloat32(Key("rope.freq_base_swa"), 10000f);

    if (Config.IntermediateSize > 0)
    {
        _denseFfn = Config.IntermediateSize;
    }
    else
    {
        _denseFfn = (int)_gguf.GetUint32(Key("feed_forward_length"), 0);
    }
    _expertFfn = (int)_gguf.GetUint32(Key("expert_feed_forward_length"), 0);
    _numExperts = (int)_gguf.GetUint32(Key("expert_count"), 0);
    _numExpertsUsed = (int)_gguf.GetUint32(Key("expert_used_count"), 0);
    _finalLogitSoftcap = _gguf.GetFloat32(Key("final_logit_softcapping"), 0f);
    _canvasLength = (int)_gguf.GetUint32("diffusion.canvas_length", 256);
    _maskTokenId = (int)_gguf.GetUint32("tokenizer.ggml.mask_token_id", 0);

    // Per-layer pattern: true = local (sliding-window), false = global.
    bool[] swaPattern = _gguf.GetBoolArray(Key("attention.sliding_window_pattern"));
    int[] kvArr = _gguf.GetInt32Array(Key("attention.head_count_kv"));

    int layerCount = Config.NumLayers;
    _isLocal = new bool[layerCount];
    _kvHeads = new int[layerCount];
    _headDim = new int[layerCount];
    _hasVProj = new bool[layerCount];
    for (int layer = 0; layer < layerCount; layer++)
    {
        bool hasPatternEntry = swaPattern != null && layer < swaPattern.Length;
        bool isLocal = !hasPatternEntry || swaPattern[layer];
        _isLocal[layer] = isLocal;
        _headDim[layer] = isLocal ? _localHeadDim : _globalHeadDim;

        bool hasKvEntry = kvArr != null && layer < kvArr.Length;
        _kvHeads[layer] = hasKvEntry ? kvArr[layer] : Config.NumKVHeads;
    }

    Console.WriteLine($"Model: {arch}, Layers={layerCount}, Hidden={Config.HiddenSize}, Heads={Config.NumHeads}");
    Console.WriteLine($"Head dims: local={_localHeadDim} global={_globalHeadDim}, RoPE local={_ropeLocalBase} global={_ropeGlobalBase}");
    Console.WriteLine($"Sliding window={_slidingWindow}, dense FFN={_denseFfn}, expert FFN={_expertFfn}");
}
/// <summary>Loads the per-layer decoder/encoder output scalars (1.0 when a layer has no such tensor)
/// and records which layers carry a V projection (global layers reuse the raw K projection as V).
/// Expects <c>_hasVProj</c> to have been allocated by ParseDiffusionConfig.</summary>
private void LoadLayerScalars()
{
    float ScalarOrOne(string tensorName) =>
        _weights.TryGetValue(tensorName, out var scalarTensor) ? scalarTensor.GetElementAsFloat(0) : 1f;

    int layerCount = Config.NumLayers;
    _decScale = new float[layerCount];
    _encScale = new float[layerCount];
    for (int layer = 0; layer < layerCount; layer++)
    {
        _decScale[layer] = ScalarOrOne($"blk.{layer}.layer_output_scale.weight");
        _encScale[layer] = ScalarOrOne($"blk.{layer}.enc_layer_output_scale.weight");

        string vWeightName = $"blk.{layer}.attn_v.weight";
        _hasVProj[layer] = _weights.ContainsKey(vWeightName) || _quantWeights.ContainsKey(vWeightName);
    }
}
/// <summary>Collects, per layer, the stacked gate_up / down expert tensors (null when absent) and the
/// optional per-expert down scales, then enables the fused MoE kernel only on GGML backends where
/// every layer has both stacks.</summary>
private void CacheStackedExpertWeights()
{
    int layerCount = Config.NumLayers;
    _stackedGateUp = new StackedExpertWeights[layerCount];
    _stackedDown = new StackedExpertWeights[layerCount];
    _perExpertScale = new float[layerCount][];

    int fusedLayers = 0;
    for (int layer = 0; layer < layerCount; layer++)
    {
        string prefix = $"blk.{layer}";
        _stackedExpertWeights.TryGetValue(prefix + ".ffn_gate_up_exps.weight", out var gateUp);
        _stackedExpertWeights.TryGetValue(prefix + ".ffn_down_exps.weight", out var down);
        _stackedGateUp[layer] = gateUp;
        _stackedDown[layer] = down;

        if (_weights.TryGetValue(prefix + ".ffn_down_exps.scale", out var scaleTensor))
        {
            float[] expertScales = new float[_numExperts];
            for (int expert = 0; expert < expertScales.Length; expert++)
            {
                expertScales[expert] = scaleTensor.GetElementAsFloat(expert);
            }
            _perExpertScale[layer] = expertScales;
        }

        if (gateUp != null && down != null)
        {
            fusedLayers++;
        }
    }

    _fusedMoeAvailable = IsGgmlBackend && fusedLayers == layerCount;
    Console.WriteLine($" Fused MoE FFN kernel available on {fusedLayers}/{layerCount} layers (enabled={_fusedMoeAvailable}).");
}
/// <summary>
/// CUDA VRAM residency plan. The fused decode reads EVERY weight every denoising step, and
/// ggml's allocator model means a monolithic whole-model graph needs all of them device-resident
/// at once — when the model is larger than VRAM, Windows WDDM transparently pages the
/// oversubscribed working set in and out of system RAM on every command submission (measured
/// ~4.9 s/step for a ~150 ms compute on a 16 GB GPU with a 16 GB model). The fix mirrors
/// llama.cpp's partial-offload discipline: never oversubscribe. We preload weights device-side
/// in priority order (lm_head/embedding first — it is also read by every step's lm_head tail —
/// then per-layer attention/dense weights, then MoE expert stacks by layer) until a budget of
/// free-VRAM-minus-headroom is reached, cap incidental device copies (prompt K/V, masks) with
/// the device-copy budget, and switch the decode to the SEGMENTED per-layer fused path so the
/// non-resident remainder streams through one bounded reused staging buffer (PCIe-speed
/// streaming, ~no paging) instead of joining the whole-model graph's working set.
/// No-op when everything fits (the whole-model fused graph stays the default) or off-CUDA.
/// </summary>
/// <remarks>
/// Tunables: DIFFUSION_VRAM_HEADROOM_MB (default 2048), DIFFUSION_DEVICE_COPY_BUDGET_MB (default
/// 768), DIFFUSION_PIN_STREAMED=1 (opt-in page-locked copies of streamed weights) and
/// DIFFUSION_SEGMENTED_DECODE=1/0 (force the segmented decode on/off). Preload and pinning are
/// best-effort: a failure streams the affected weight (or leaves it pageable) and is reported in
/// a summary line instead of being silently swallowed.
/// </remarks>
private void PrepareCudaWeightResidency()
{
    if (_backend != BackendType.GgmlCuda)
        return;
    EnsureQuantBackendAvailable(); // memory query / preloads must hit the CUDA backend
    if (!GgmlBasicOps.TryGetDeviceMemoryInfo(out long freeBytes, out _))
        return;
    long headroomMb = long.TryParse(Environment.GetEnvironmentVariable("DIFFUSION_VRAM_HEADROOM_MB"), out long hm) && hm > 0 ? hm : 2048;
    long copyBudgetMb = long.TryParse(Environment.GetEnvironmentVariable("DIFFUSION_DEVICE_COPY_BUDGET_MB"), out long cm) && cm > 0 ? cm : 768;
    // Opt-in (DIFFUSION_PIN_STREAMED=1): re-home streamed weights into page-locked private copies
    // (costs RAM equal to the streamed bytes). Measured neutral on Windows/WDDM (the pageable mmap
    // upload already ran near DMA speed once the file cache was warm), so default off; may pay off
    // on Linux.
    bool pinStreamed = Environment.GetEnvironmentVariable("DIFFUSION_PIN_STREAMED") == "1";
    // Negative when free VRAM is below the headroom: then nothing is preloaded and everything streams.
    long preloadBudget = freeBytes - headroomMb * 1024 * 1024;

    // Priority order: tied lm_head/embedding (read by the per-step lm_head tail), per-layer
    // non-expert weights (attention + dense MLP), then the per-layer expert stacks — the bulk
    // of the model; whatever does not fit streams per step.
    var priority = new List<(IntPtr key, IntPtr host, int type, long ne0, long ne1, long bytes)>();
    void AddQuant(string name)
    {
        if (_quantWeights.TryGetValue(name, out var qw) && qw.HasHostData)
            priority.Add((qw.CacheKey, qw.Data, qw.GgmlType, qw.Ne0, qw.Ne1, qw.RawBytes));
    }
    AddQuant("token_embd.weight");
    int L = Config.NumLayers;
    for (int l = 0; l < L; l++)
    {
        AddQuant($"blk.{l}.attn_q.weight");
        AddQuant($"blk.{l}.attn_k.weight");
        AddQuant($"blk.{l}.attn_v.weight");
        AddQuant($"blk.{l}.attn_output.weight");
        AddQuant($"blk.{l}.ffn_gate.weight");
        AddQuant($"blk.{l}.ffn_up.weight");
        AddQuant($"blk.{l}.ffn_down.weight");
    }
    for (int l = 0; l < L; l++)
    {
        var gu = _stackedGateUp[l];
        var dn = _stackedDown[l];
        if (gu != null) priority.Add((gu.Data, gu.Data, gu.GgmlType, gu.PerExpertNe0, gu.PerExpertNe1 * _numExperts, gu.TotalRawBytes));
        if (dn != null) priority.Add((dn.Data, dn.Data, dn.GgmlType, dn.PerExpertNe0, dn.PerExpertNe1 * _numExperts, dn.TotalRawBytes));
    }

    long preloadedBytes = 0;
    int preloadedCount = 0;
    long streamedBytes = 0;
    int streamedCount = 0;
    long pinnedBytes = 0;
    int preloadFailures = 0;
    int pinFailures = 0;
    string firstPreloadError = null;
    string firstPinError = null;

    // Copies one streamed weight into a page-locked private allocation so its per-step PCIe upload
    // takes the fast DMA path (the GGUF mmap itself cannot be page-locked on Windows). Returns the
    // number of bytes pinned (0 on any failure; the weight then keeps streaming from the mmap).
    // Registered copies are unregistered + freed in Dispose.
    long PinStreamedCopy(IntPtr key, IntPtr host, long bytes)
    {
        IntPtr copy = IntPtr.Zero;
        try
        {
            copy = GgmlBasicOps.AlignedAlloc(bytes);
            if (copy == IntPtr.Zero)
            {
                pinFailures++;
                firstPinError ??= "host allocation failed";
                return 0;
            }
            unsafe { Buffer.MemoryCopy((void*)host, (void*)copy, bytes, bytes); }
            if (!GgmlBasicOps.TryRegisterPinnedHostBuffer(copy, bytes))
            {
                // Drop our handle before freeing so the catch below can never free it twice.
                IntPtr unregistered = copy;
                copy = IntPtr.Zero;
                GgmlBasicOps.AlignedFree(unregistered);
                pinFailures++;
                firstPinError ??= "page-locking (host register) was refused";
                return 0;
            }
            _pinnedHostRegions.Add(copy);
            _pinnedStreamCopies[key] = copy;
            return bytes;
        }
        catch (Exception ex)
        {
            // Only free a copy that was never registered; registered regions are owned by Dispose.
            if (copy != IntPtr.Zero && !_pinnedHostRegions.Contains(copy))
                GgmlBasicOps.AlignedFree(copy);
            pinFailures++;
            firstPinError ??= ex.Message;
            return 0;
        }
    }

    foreach (var w in priority)
    {
        if (preloadedBytes + w.bytes <= preloadBudget)
        {
            try
            {
                GgmlBasicOps.PreloadQuantizedWeight(w.key, w.host, w.type, w.ne0, w.ne1, w.bytes);
                preloadedBytes += w.bytes;
                preloadedCount++;
                continue;
            }
            catch (Exception ex)
            {
                // Device allocation can fail despite the budget (e.g. fragmentation). Stream THIS
                // weight and keep going — a later, smaller weight may still fit. Recorded for the
                // summary below rather than silently swallowed.
                preloadFailures++;
                firstPreloadError ??= ex.Message;
            }
        }
        streamedBytes += w.bytes;
        streamedCount++;
        if (pinStreamed)
            pinnedBytes += PinStreamedCopy(w.key, w.host, w.bytes);
    }

    bool everythingFits = streamedCount == 0;
    string segEnv = Environment.GetEnvironmentVariable("DIFFUSION_SEGMENTED_DECODE");
    _segmentedDecode = segEnv == "1" || (segEnv != "0" && !everythingFits);
    if (!everythingFits)
    {
        // Cap incidental device copies (prompt K/V, decode masks, activations bound by per-op
        // kernels) so they cannot push VRAM past physical either. When everything fits there is
        // no oversubscription risk, so the legacy unlimited behaviour is kept.
        GgmlBasicOps.SetDeviceCopyBudget(copyBudgetMb * 1024 * 1024);
    }
    Console.WriteLine(
        $" CUDA weight residency: preloaded {preloadedBytes / 1024 / 1024} MB / {preloadedCount} tensors " +
        $"(free VRAM {freeBytes / 1024 / 1024} MB, headroom {headroomMb} MB); " +
        (everythingFits
            ? "model fully resident."
            : $"streaming {streamedBytes / 1024 / 1024} MB / {streamedCount} tensors per step " +
              $"({pinnedBytes / 1024 / 1024} MB page-locked); segmented decode={(_segmentedDecode ? "on" : "off")}, device-copy budget {copyBudgetMb} MB."));
    if (preloadFailures > 0)
        Console.WriteLine($" CUDA weight residency: {preloadFailures} preload(s) failed within the VRAM budget and are streamed instead (first error: {firstPreloadError}).");
    if (pinFailures > 0)
        Console.WriteLine($" CUDA weight residency: {pinFailures} streamed weight(s) could not be page-locked and stay pageable (first error: {firstPinError}).");
}
/// <summary>Builds the NeoX RoPE inverse-frequency tables: <c>1 / base^(2i/headDim)</c> for the
/// local layers, and the same for the global layers with each entry additionally divided by the
/// matching <c>rope_freqs.weight</c> factor when that tensor exists. The raw factor tensor (or
/// null) is kept for the on-device GGML freq-factors kernel.</summary>
private void PrecomputeRoPE()
{
    // Computed in double and rounded to float once, per entry.
    static float[] InverseFrequencies(double ropeBase, int headDim, float[] factors)
    {
        var table = new float[headDim / 2];
        for (int k = 0; k < table.Length; k++)
        {
            double theta = 1.0 / Math.Pow(ropeBase, 2.0 * k / headDim);
            if (factors != null && k < factors.Length)
            {
                theta /= factors[k];
            }
            table[k] = (float)theta;
        }
        return table;
    }

    _ropeFreqsLocal = InverseFrequencies(_ropeLocalBase, _localHeadDim, null);

    _weights.TryGetValue("rope_freqs.weight", out var factorTensor);
    _ropeFreqsRawTensor = factorTensor;
    float[] factors = factorTensor != null ? TensorToFloatArray(factorTensor) : null;
    _ropeFreqsGlobal = InverseFrequencies(_ropeGlobalBase, _globalHeadDim, factors);
}
/// <summary>Returns a cached float32 tensor of length <paramref name="dim"/> filled with 1.0
/// (the weight for unweighted RMSNorm), creating and caching it on first request.</summary>
private Tensor GetOnes(int dim)
{
    if (_onesByDim.TryGetValue(dim, out var cached))
    {
        return cached;
    }
    var ones = new Tensor(_allocator, DType.Float32, dim);
    Ops.Fill(ones, 1f);
    _onesByDim[dim] = ones;
    return ones;
}
// ===================================================================================
// Core per-step forward: runs the bidirectional [prompt|canvas] graph and returns the
// canvas logits [C, vocab] (after final softcap) as a flat float[C*vocab].
// scPrevLogits/scUse/prevTempInv drive self-conditioning (scPrevLogits is [C*vocab] raw
// logits from the previous step; scUse is the {0,1} gate; pass null/0 to disable SC).
// On GGML backends the returned array is a reused buffer that the next call overwrites.
// Throws ArgumentNullException when tokens is null and ArgumentOutOfRangeException when
// promptLen lies outside [0, tokens.Length].
// ===================================================================================
public unsafe float[] ForwardCanvas(int[] tokens, int promptLen,
    float[] scPrevLogits = null, float scUse = 0f, float prevTempInv = 1f)
{
    if (tokens == null)
        throw new ArgumentNullException(nameof(tokens));
    if (promptLen < 0 || promptLen > tokens.Length)
        throw new ArgumentOutOfRangeException(nameof(promptLen), promptLen,
            "promptLen must lie within [0, tokens.Length].");

    // try/finally: an exception mid-forward must not leave the cumulative forward timer running.
    _swForward.Start();
    try
    {
        int N = tokens.Length;
        int P = promptLen;
        int C = N - P;
        int D = Config.HiddenSize;

        // 1) embeddings, region-aware: every row is embed*sqrt(D); EmbedCanvasRegion then rewrites
        //    the canvas rows (rms_norm_noscale, plus the optional self-conditioning signal).
        long ts = Stopwatch.GetTimestamp();
        Tensor hidden = Embedding(tokens); // [N, D]
        Ops.Mul(hidden, hidden, MathF.Sqrt(D)); // embed_scale = sqrt(n_embd)
        EmbedCanvasRegion(hidden, P, C, scPrevLogits, scUse, prevTempInv);
        _tEmbed += Stopwatch.GetTimestamp() - ts;

        // 2) transformer stack (all glue ops are on-device; per-N caches are rebuilt lazily)
        for (int l = 0; l < Config.NumLayers; l++)
        {
            hidden = TransformerBlock(hidden, l, N, P, C);
        }

        // 3) final norm + tied lm-head over canvas positions only
        Tensor normed = RMSNormOp(hidden, "output_norm.weight");
        hidden.Dispose();
        Tensor canvasHidden;
        using (var view = normed.Narrow(0, P, C))
            canvasHidden = Ops.NewContiguous(view);
        normed.Dispose();
        ts = Stopwatch.GetTimestamp();
        Tensor logits = LinearForward(canvasHidden, "token_embd.weight"); // [C, vocab]
        canvasHidden.Dispose();
        _tLmHead += Stopwatch.GetTimestamp() - ts;

        // 4) final logit softcapping: cap * tanh(logits / cap); skipped when the cap is 0
        if (_finalLogitSoftcap > 0f)
        {
            Ops.Mul(logits, logits, 1f / _finalLogitSoftcap);
            Ops.Tanh(logits, logits);
            Ops.Mul(logits, logits, _finalLogitSoftcap);
        }

        // checked: C * vocab indexes a single float[]; fail loudly rather than wrap around.
        int total = checked(C * Config.VocabSize);
        float[] result = ReadbackLogits(logits, total);
        logits.Dispose();
        return result;
    }
    finally
    {
        _swForward.Stop();
    }
}
/// <summary>Copies the first <paramref name="total"/> floats of <paramref name="logits"/> to the host
/// in a single read. On GGML the storage is host-mapped, so the data is memcpy'd into the reusable
/// <c>_canvasLogits</c> buffer; that buffer is returned and is overwritten by the next call. On
/// cuda/mlx <c>GetElementsAsFloat</c> is the single read-only sync (GetFloatPtr would also force a
/// host->device re-upload), and a fresh array is returned.</summary>
private unsafe float[] ReadbackLogits(Tensor logits, int total)
{
    if (!IsGgmlBackend)
        return logits.GetElementsAsFloat(total);

    // Reallocate only when the size changes (null?.Length is null, which never equals total).
    if (_canvasLogits?.Length != total)
        _canvasLogits = new float[total];

    long byteCount = (long)total * sizeof(float);
    float* source = GetFloatPtr(logits);
    fixed (float* destination = _canvasLogits)
    {
        Buffer.MemoryCopy(source, destination, byteCount, byteCount);
    }
    return _canvasLogits;
}
/// <summary>In place, rewrites the canvas rows <c>[P, P+C)</c> of <paramref name="hidden"/> as an
/// unweighted RMSNorm of the (already scaled) embeddings. When self-conditioning is enabled, the
/// signal derived from <paramref name="scPrevLogits"/> is first scaled by <paramref name="scUse"/>
/// and added.</summary>
/// <param name="hidden">Token embeddings; only rows P..P+C-1 are modified.</param>
/// <param name="P">Prompt length (first canvas row).</param>
/// <param name="C">Canvas length.</param>
/// <param name="scPrevLogits">Previous step's canvas logits, or null to skip self-conditioning.</param>
/// <param name="scUse">Self-conditioning gate; 0 skips it.</param>
/// <param name="prevTempInv">Forwarded to ComputeSelfConditioning.</param>
private unsafe void EmbedCanvasRegion(Tensor hidden, int P, int C,
    float[] scPrevLogits, float scUse, float prevTempInv)
{
    int D = Config.HiddenSize;
    float eps = Config.Eps;
    using var canvasRows = hidden.Narrow(0, P, C); // [C, D] contiguous block

    bool addSelfConditioning = _scEnabled && scPrevLogits != null && scUse != 0f;
    if (addSelfConditioning)
    {
        long scStart = Stopwatch.GetTimestamp();
        using (var signal = ComputeSelfConditioning(scPrevLogits, C, prevTempInv)) // [C, D]
        {
            Ops.Mul(signal, signal, scUse); // apply the gate
            Ops.Add(canvasRows, canvasRows, signal);
            _tSc += Stopwatch.GetTimestamp() - scStart;
        }
    }

    Ops.RMSNorm(canvasRows, canvasRows, GetOnes(D), null, eps);
}
/// <summary>One decoder layer over the full [prompt|canvas] sequence: pre-norm attention with a
/// post-attention norm and residual, then the feed-forward (dense + MoE) sub-block and the
/// region-aware per-layer scalar. Consumes (disposes) <paramref name="hidden"/> and returns the
/// new hidden-state tensor.</summary>
private Tensor TransformerBlock(Tensor hidden, int layer, int N, int P, int C)
{
    string prefix = $"blk.{layer}";
    float eps = Config.Eps;

    long attnStart = Stopwatch.GetTimestamp();
    using var normedForAttention = RMSNormOp(hidden, $"{prefix}.attn_norm.weight");
    Tensor x = Attention(normedForAttention, layer, prefix, N, P);

    // Normalize the attention output, then add the residual; the incoming hidden is no longer needed.
    Ops.RMSNorm(x, x, _weights[$"{prefix}.post_attention_norm.weight"], null, eps);
    Ops.Add(x, x, hidden);
    hidden.Dispose();
    ProfSync(x);
    _tAttn += Stopwatch.GetTimestamp() - attnStart;

    x = FeedForward(x, layer, prefix, N);
    ApplyRegionScalar(x, layer, P, C); // separate prompt/canvas scaling
    return x;
}
/// <summary>Feed-forward sub-block shared by the unified / prefill / decode forwards. It computes
/// <c>post_ffw_norm_1(DenseMlp(x)) + post_ffw_norm_2(MoE(x))</c>, applies <c>post_ffw_norm</c>, and
/// adds the result into <paramref name="attnOut"/> in place. The same tensor is returned.</summary>
private Tensor FeedForward(Tensor attnOut, int layer, string prefix, int N)
{
    float eps = Config.Eps;

    long denseStart = Stopwatch.GetTimestamp();
    Tensor combined = DenseMlp(attnOut, prefix, N);
    Ops.RMSNorm(combined, combined, _weights[$"{prefix}.post_ffw_norm_1.weight"], null, eps);
    ProfSync(combined);
    _tDense += Stopwatch.GetTimestamp() - denseStart;

    long moeStart = Stopwatch.GetTimestamp();
    using (Tensor expertOut = MoEForward(attnOut, layer, prefix, N))
    {
        Ops.RMSNorm(expertOut, expertOut, _weights[$"{prefix}.post_ffw_norm_2.weight"], null, eps);
        Ops.Add(combined, combined, expertOut);
    }
    ProfSync(combined);
    _tMoe += Stopwatch.GetTimestamp() - moeStart;

    using (combined)
    {
        Ops.RMSNorm(combined, combined, _weights[$"{prefix}.post_ffw_norm.weight"], null, eps);
        Ops.Add(attnOut, attnOut, combined); // residual
    }
    return attnOut;
}
/// <summary>Multiplies the prompt rows <c>[0, P)</c> of <paramref name="x"/> by the layer's encoder
/// scalar and the canvas rows <c>[P, P+C)</c> by its decoder scalar, in place. Empty regions and
/// scalars equal to 1 are skipped.</summary>
private void ApplyRegionScalar(Tensor x, int layer, int P, int C)
{
    float encoderScale = _encScale[layer];
    float decoderScale = _decScale[layer];
    ScaleRows(x, 0, P, encoderScale);
    ScaleRows(x, P, C, decoderScale);

    static void ScaleRows(Tensor t, int firstRow, int rowCount, float factor)
    {
        if (rowCount <= 0 || factor == 1f)
            return;
        using var rows = t.Narrow(0, firstRow, rowCount);
        Ops.Mul(rows, rows, factor);
    }
}
/// <summary>Dense gated-GELU MLP: <c>down(gelu(gate(n)) * up(n))</c> with <c>n = ffn_norm(input)</c>.
/// Returns a new tensor owned by the caller. <paramref name="N"/> is not used here.</summary>
private Tensor DenseMlp(Tensor input, string prefix, int N)
{
    using var normed = RMSNormOp(input, $"{prefix}.ffn_norm.weight");
    using var activated = LinearForward(normed, $"{prefix}.ffn_gate.weight");
    using (var upProjection = LinearForward(normed, $"{prefix}.ffn_up.weight"))
    {
        Ops.GELUMul(activated, activated, upProjection); // activated = gelu(gate) * up
    }
    return LinearForward(activated, $"{prefix}.ffn_down.weight");
}
// ===================================================================================
// Attention: region-aware bidirectional, no KV cache. Q/K per-head RMSNorm (weighted),
// unweighted V RMSNorm, NeoX RoPE, attention scale = 1.0.
// Device path (_useDeviceGlue): Q/K/V come from ComputeQKVHeadFirst (positions 0..N-1), then a
// batched Q·Kᵀ, the precomputed additive region-aware mask, softmax and the value matmul, all
// through backend Ops without host round-trips.
// Host path: host-table RoPE plus the SIMD raw-pointer AttentionRegionAware kernel.
// ===================================================================================
/// <summary>Self-attention for one layer over the full [prompt|canvas] sequence, followed by the
/// output projection. Returns a new [N, hidden] tensor owned by the caller.</summary>
/// <param name="input">Normalized hidden states for all N positions.</param>
/// <param name="layer">Layer index (selects local/global attention, head dim, KV heads).</param>
/// <param name="prefix">Weight-name prefix, e.g. <c>blk.3</c>.</param>
/// <param name="N">Total sequence length (prompt + canvas).</param>
/// <param name="P">Prompt length; positions &gt;= P are canvas.</param>
private Tensor Attention(Tensor input, int layer, string prefix, int N, int P)
{
    bool local = _isLocal[layer];
    int hd = _headDim[layer];
    int qHeads = Config.NumHeads;
    int kvHeads = _kvHeads[layer];

    if (_useDeviceGlue)
    {
        // Reuse the exact projection / norm / RoPE sequence of the prefill/decode forwards, so the
        // unified and prompt-KV paths cannot drift apart.
        ComputeQKVHeadFirst(input, layer, prefix, N, 0, out Tensor qH, out Tensor kH, out Tensor vH);
        using (qH)
        using (kH)
        using (vH)
        {
            Tensor mask = GetAttentionMask(N, P, local); // cached; not owned here
            using var result = AttnCoreHeadFirst(qH, kH, vH, mask, N, N, qHeads, kvHeads, hd);
            return LinearForward(result, $"{prefix}.attn_output.weight");
        }
    }

    // CPU fast path: SIMD raw-pointer RoPE + region-aware attention (host-resident).
    float eps = Config.Eps;
    Tensor q = LinearForward(input, $"{prefix}.attn_q.weight"); // [N, qHeads*hd]
    Tensor k = LinearForward(input, $"{prefix}.attn_k.weight"); // [N, kvHeads*hd]
    Tensor v;
    if (_hasVProj[layer])
    {
        v = LinearForward(input, $"{prefix}.attn_v.weight");
    }
    else
    {
        v = new Tensor(_allocator, DType.Float32, k.Sizes); // layers without attn_v: V == raw K
        Ops.Copy(v, k);
    }

    // Per-head Q/K norm (weighted) and an unweighted V norm.
    using (var qr = q.View(N * qHeads, hd))
        Ops.RMSNorm(qr, qr, _weights[$"{prefix}.attn_q_norm.weight"], null, eps);
    using (var kr = k.View(N * kvHeads, hd))
        Ops.RMSNorm(kr, kr, _weights[$"{prefix}.attn_k_norm.weight"], null, eps);
    using (var vr = v.View(N * kvHeads, hd))
        Ops.RMSNorm(vr, vr, GetOnes(hd), null, eps);

    var (cos, sin) = GetCosSinHost(N, local);
    ApplyNeoXRoPERaw(q, qHeads, hd, N, cos, sin);
    ApplyNeoXRoPERaw(k, kvHeads, hd, N, cos, sin);

    using (var cpuResult = new Tensor(_allocator, DType.Float32, N, qHeads * hd))
    {
        AttentionRegionAware(q, k, v, cpuResult, N, P, qHeads, kvHeads, hd, local);
        q.Dispose(); k.Dispose(); v.Dispose();
        return LinearForward(cpuResult, $"{prefix}.attn_output.weight");
    }
}
/// <summary>On-device attention core shared by the unified, prefill and decode forwards.
/// It repeats each KV head <c>qHeads / kvHeads</c> times (GQA), computes <c>softmax(Q·Kᵀ + mask)·V</c>
/// with scale 1.0 (a null <paramref name="mask"/> means no masking), and returns a new flat
/// <c>[qLen, qHeads*hd]</c> tensor owned by the caller.</summary>
private Tensor AttnCoreHeadFirst(Tensor qH, Tensor kHfull, Tensor vHfull, Tensor mask,
    int qLen, int kvLen, int qHeads, int kvHeads, int hd)
{
    int repeats = qHeads / kvHeads;
    using var keys = ExpandKVHeads(kHfull, repeats, kvLen);      // [qHeads, kvLen, hd]
    using var values = ExpandKVHeads(vHfull, repeats, kvLen);    // [qHeads, kvLen, hd]
    using var keysT = keys.Transpose(1, 2);                      // [qHeads, hd, kvLen]

    using var probs = new Tensor(_allocator, DType.Float32, qHeads, qLen, kvLen);
    Ops.AddmmBatch(probs, 0f, probs, 1f, qH, keysT);             // raw scores
    if (mask != null)
        Ops.Add(probs, probs, mask);
    Ops.Softmax(probs, probs);

    using var context = new Tensor(_allocator, DType.Float32, qHeads, qLen, hd);
    Ops.AddmmBatch(context, 0f, context, 1f, probs, values);
    return ReshapeFromHeads(context, qHeads, qLen, hd);
}
/// <summary>SIMD raw-pointer region-aware attention for CPU backends (the contiguous allowed-key
/// window per query lets us score only the unmasked interval). Scale 1.0.</summary>
/// <remarks>Work is parallelized over query heads; each head allocates its own scores buffer, so
/// heads do not share mutable state. GQA: query head h reads KV head <c>h / (qHeads / kvHeads)</c>.
/// The softmax subtracts the per-row max before exponentiating.</remarks>
/// <param name="qT">Queries, addressed as row-major [N, qHeads, hd] through a raw pointer.</param>
/// <param name="kT">Keys, addressed as row-major [N, kvHeads, hd].</param>
/// <param name="vT">Values, addressed as row-major [N, kvHeads, hd].</param>
/// <param name="outT">Output, written as row-major [N, qHeads, hd]; its device cache is invalidated
/// afterwards.</param>
/// <param name="N">Sequence length (prompt + canvas).</param>
/// <param name="P">Prompt length; query rows &gt;= P are canvas rows.</param>
/// <param name="local">Selects the local (sliding-window) vs global allowed-key range.</param>
private unsafe void AttentionRegionAware(Tensor qT, Tensor kT, Tensor vT, Tensor outT,
int N, int P, int qHeads, int kvHeads, int hd, bool local)
{
float* q = GetFloatPtr(qT);
float* k = GetFloatPtr(kT);
float* v = GetFloatPtr(vT);
float* o = GetFloatPtr(outT);
int groupSize = qHeads / kvHeads;
int swa = _slidingWindow;
Parallel.For(0, qHeads, h =>
{
int kvHead = h / groupSize;
// Per-head scratch, indexed by absolute key position; only [klo, khi) is touched per query.
float[] scores = new float[N];
for (int qi = 0; qi < N; qi++)
{
bool qCanvas = qi >= P;
// Allowed keys form the half-open interval [klo, khi) (the loops below use kj < khi).
AllowedRange(qi, qCanvas, P, N, local, swa, out int klo, out int khi);
// Force a non-empty window so the softmax below is defined.
// NOTE(review): assumes AllowedRange always yields klo < N, otherwise this reads past the keys.
if (khi <= klo) khi = klo + 1;
float* qVec = q + ((long)qi * qHeads + h) * hd;
float maxScore = float.NegativeInfinity;
// Pass 1: raw dot-product scores (scale 1.0) and the running max.
for (int kj = klo; kj < khi; kj++)
{
float dot = VecDot(qVec, k + ((long)kj * kvHeads + kvHead) * hd, hd);
scores[kj] = dot;
if (dot > maxScore) maxScore = dot;
}
// Pass 2: max-shifted exponentials and their sum.
float sum = 0f;
for (int kj = klo; kj < khi; kj++)
{
float e = MathF.Exp(scores[kj] - maxScore);
scores[kj] = e;
sum += e;
}
float inv = sum > 0f ? 1f / sum : 0f;
// Pass 3: output row = sum_j softmax_j * V[j].
float* oVec = o + ((long)qi * qHeads + h) * hd;
VecZero(oVec, hd);
for (int kj = klo; kj < khi; kj++)
VecScaleAdd(oVec, v + ((long)kj * kvHeads + kvHead) * hd, scores[kj] * inv, hd);
}
});
// The output was written through a raw host pointer; drop any stale device-side copy.
InvalidateTensorDeviceCache(outT);
}
/// <summary>Returns the host cos/sin RoPE tables for positions 0..N-1 (local or global
/// frequencies). Both local and global tables are rebuilt whenever N changes.</summary>
/// <remarks><c>_cosGlobalHost</c>/<c>_sinGlobalHost</c> are also rebuilt by GetCosSinGlobalTensors,
/// which keys its cache on <c>_cosSinGlobalN</c>. Rebuilding them here therefore invalidates that
/// cache, so its device tensors and host tables are regenerated together for its own N.</remarks>
private (float[] cos, float[] sin) GetCosSinHost(int N, bool isLocal)
{
    if (_cosSinHostN != N)
    {
        BuildCosSin(N, _ropeFreqsLocal, out _cosLocalHost, out _sinLocalHost);
        BuildCosSin(N, _ropeFreqsGlobal, out _cosGlobalHost, out _sinGlobalHost);
        _cosSinHostN = N;
        // The shared global host tables now describe N positions; the device-table cache must
        // not assume they still match its length.
        _cosSinGlobalN = -1;
    }
    return isLocal ? (_cosLocalHost, _sinLocalHost) : (_cosGlobalHost, _sinGlobalHost);
}
// ---- on-device NeoX RoPE -----------------------------------------------------------
/// <summary>In-place NeoX RoPE on a flat [N, numHeads*hd] tensor for positions
/// <c>startPos..startPos+N-1</c>.</summary>
/// <remarks>
/// <list type="bullet">
/// <item>Local layers use a single base via <c>Ops.RoPEEx</c>.</item>
/// <item>Global layers use proportional/NTK RoPE (per-frequency <c>rope_freqs</c>). GGML backends
/// use the native freq-factors kernel; other backends use cos/sin tables, through the mlx fused
/// kernel when available and otherwise the host kernel.</item>
/// <item>The cached cos/sin tables start at position 0, so a nonzero <paramref name="startPos"/>
/// (prompt-KV decode) builds offset tables and uses the host kernel.</item>
/// </list>
/// </remarks>
private void ApplyRoPE(Tensor data, int numHeads, int hd, int N, int startPos, bool isLocal)
{
    Tensor pos = GetRoPEPositions(N, numHeads, startPos);
    using (var view = data.View(1, N, numHeads, hd))
    {
        if (isLocal)
        {
            Ops.RoPEEx(view, view, pos, hd, NeoXRopeMode, 0, _ropeLocalBase, 1f, 0f, 1f, 0f, 0f);
            return;
        }
        if (IsGgmlBackend && _ropeFreqsRawTensor != null)
        {
            GgmlBasicOps.RoPEExWithFreqFactors(view, view, pos, _ropeFreqsRawTensor,
                hd, NeoXRopeMode, 0, _ropeGlobalBase, 1f, 0f, 1f, 0f, 0f);
            return;
        }
    }

    // Non-GGML global: proportional RoPE via cos/sin tables (freqs already fold in rope_freqs).
    int half = hd / 2;
    if (startPos == 0)
    {
        var (cosT, sinT) = GetCosSinGlobalTensors(N, half);
        if (_backend == BackendType.Mlx &&
            MlxFusedOps.TryNeoXRoPEFlatInPlace(data, cosT, sinT, numHeads, N, hd, half))
            return;
        // GetCosSinGlobalTensors keeps the host copies in _cosGlobalHost/_sinGlobalHost.
        ApplyNeoXRoPERaw(data, numHeads, hd, N, _cosGlobalHost, _sinGlobalHost);
        return;
    }

    // Offset positions: build tables for rows 0..startPos+N-1 and keep the last N rows.
    // NOTE(review): assumes BuildCosSin lays tables out position-major with `half` entries per
    // row, matching the [N * half] tensors created from them in GetCosSinGlobalTensors.
    BuildCosSin(startPos + N, _ropeFreqsGlobal, out float[] cosAll, out float[] sinAll);
    long rowOffset = (long)startPos * half;
    int count = N * half;
    var cosRows = new float[count];
    var sinRows = new float[count];
    Array.Copy(cosAll, rowOffset, cosRows, 0L, count);
    Array.Copy(sinAll, rowOffset, sinRows, 0L, count);
    ApplyNeoXRoPERaw(data, numHeads, hd, N, cosRows, sinRows);
}
/// <summary>Returns a cached int tensor of length <c>N * numHeads</c> whose entry
/// <c>[s * numHeads + h]</c> is <c>startPos + s</c>. The cache is keyed by (numHeads, startPos) and
/// is flushed and disposed whenever N changes.</summary>
private Tensor GetRoPEPositions(int N, int numHeads, int startPos)
{
    if (_ropePosN != N)
    {
        // Every cached tensor is sized by N, so a new length invalidates all of them.
        foreach (var stale in _ropePosCache.Values)
            stale?.Dispose();
        _ropePosCache.Clear();
        _ropePosN = N;
    }

    long cacheKey = ((long)numHeads << 32) | (uint)startPos;
    if (_ropePosCache.TryGetValue(cacheKey, out var cached))
        return cached;

    int count = N * numHeads;
    var positions = new int[count];
    for (int idx = 0; idx < count; idx++)
        positions[idx] = startPos + idx / numHeads; // idx / numHeads == sequence row
    Tensor created = CreateIntTensor(positions, count);
    _ropePosCache[cacheKey] = created;
    return created;
}
/// <summary>Returns cached device tensors (each <c>N * half</c> floats) holding the global-layer
/// cos/sin tables for positions 0..N-1, rebuilding them when N changes. A rebuild also rewrites
/// the host copies <c>_cosGlobalHost</c>/<c>_sinGlobalHost</c>, which ApplyRoPE's host fallback
/// reads.</summary>
/// <remarks>GetCosSinHost writes the same host arrays under its own cache key
/// (<c>_cosSinHostN</c>). A rebuild here invalidates that key so it cannot return tables sized for
/// this N.</remarks>
private (Tensor cos, Tensor sin) GetCosSinGlobalTensors(int N, int half)
{
    if (_cosSinGlobalN != N)
    {
        // Invalidate first: if anything below throws, the next call rebuilds instead of returning
        // disposed tensors under a still-valid key.
        _cosSinGlobalN = -1;
        BuildCosSin(N, _ropeFreqsGlobal, out _cosGlobalHost, out _sinGlobalHost);
        _cosSinHostN = -1; // shared host tables now describe N positions
        _cosGlobalTensor?.Dispose();
        _sinGlobalTensor?.Dispose();
        _cosGlobalTensor = null;
        _sinGlobalTensor = null;
        _cosGlobalTensor = CreateFloatTensor(_cosGlobalHost, N * half);
        _sinGlobalTensor = CreateFloatTensor(_sinGlobalHost, N * half);
        _cosSinGlobalN = N;
    }
    return (_cosGlobalTensor, _sinGlobalTensor);
}
// ---- region-aware additive attention mask -----------------------------------------
/// <summary>Returns a cached additive attention mask of shape [maskHeads, N, N]: 0 where the key is
/// attended and <c>MaskNeg</c> where it is masked. A separate mask is cached for local and for
/// global layers, and both are dropped when N or P changes. maskHeads is 1 (broadcast) on GGML/MLX
/// and qHeads on cuda, which does not broadcast element-wise adds. The returned tensor is owned by
/// the cache; callers must not dispose it.</summary>
private Tensor GetAttentionMask(int N, int P, bool local)
{
    if (_maskN != N || _maskP != P)
    {
        _maskLocal?.Dispose(); _maskGlobal?.Dispose();
        _maskLocal = null; _maskGlobal = null;
        _maskN = N; _maskP = P;
    }

    Tensor existing = local ? _maskLocal : _maskGlobal;
    if (existing != null)
        return existing;

    int maskHeads = _backend == BackendType.Cuda ? Config.NumHeads : 1;
    long plane = (long)N * N;
    var data = new float[maskHeads * plane];

    // Fill head 0 row by row from each query's allowed key interval [lo, hi).
    for (int qi = 0; qi < N; qi++)
    {
        AllowedRange(qi, qi >= P, P, N, local, _slidingWindow, out int lo, out int hi);
        if (hi <= lo) hi = lo + 1; // keep at least one key visible
        long rowStart = (long)qi * N;
        for (int kj = 0; kj < N; kj++)
            data[rowStart + kj] = kj >= lo && kj < hi ? 0f : MaskNeg;
    }

    // Every head shares the same pattern; replicate head 0 into the remaining planes.
    for (int h = 1; h < maskHeads; h++)
        Array.Copy(data, 0L, data, h * plane, plane);

    Tensor built = CreateFloatTensor(data, maskHeads, N, N);
    if (local)
        _maskLocal = built;
    else
        _maskGlobal = built;
    return built;
}
// ===================================================================================
// Prompt-KV caching (PKV): PrefillPrompt computes the prompt's per-layer K/V once; DecodeCanvas
// then processes only the canvas each denoising step, reading the cached prompt K/V. This is the
// canonical llama.cpp/vLLM diffusion-gemma optimization — it removes the prompt's projection /
// attention / dense-MLP / MoE-matmul work from every step (computed once instead of S times),
// which is a large saving for long prompts (system prompt + chat history). GPU path only.
// ===================================================================================
/// <summary>Projects <paramref name="input"/> to Q/K/V and applies a weighted per-head RMSNorm to
/// Q and K and an unweighted RMSNorm to V. Q and K then get NeoX RoPE via ApplyRoPE for positions
/// <c>startPos..startPos+seqLen-1</c>. All three are returned head-first as
/// <c>[heads, seqLen, hd]</c> tensors owned by the caller. Shared by the unified, prefill and
/// decode forwards.</summary>
private void ComputeQKVHeadFirst(Tensor input, int layer, string prefix, int seqLen, int startPos,
    out Tensor qH, out Tensor kH, out Tensor vH)
{
    bool local = _isLocal[layer];
    int hd = _headDim[layer];
    int qHeads = Config.NumHeads;
    int kvHeads = _kvHeads[layer];
    float eps = Config.Eps;

    Tensor query = LinearForward(input, $"{prefix}.attn_q.weight");
    Tensor key = LinearForward(input, $"{prefix}.attn_k.weight");
    Tensor value;
    if (_hasVProj[layer])
    {
        value = LinearForward(input, $"{prefix}.attn_v.weight");
    }
    else
    {
        // Layers without an attn_v projection reuse the raw (pre-norm) K as V.
        value = new Tensor(_allocator, DType.Float32, key.Sizes);
        Ops.Copy(value, key);
    }

    using (var queryRows = query.View(seqLen * qHeads, hd))
    {
        Ops.RMSNorm(queryRows, queryRows, _weights[$"{prefix}.attn_q_norm.weight"], null, eps);
    }
    using (var keyRows = key.View(seqLen * kvHeads, hd))
    {
        Ops.RMSNorm(keyRows, keyRows, _weights[$"{prefix}.attn_k_norm.weight"], null, eps);
    }
    using (var valueRows = value.View(seqLen * kvHeads, hd))
    {
        Ops.RMSNorm(valueRows, valueRows, GetOnes(hd), null, eps);
    }

    ApplyRoPE(query, qHeads, hd, seqLen, startPos, local);
    ApplyRoPE(key, kvHeads, hd, seqLen, startPos, local);

    qH = ReshapeToHeads(query, qHeads, seqLen, hd);
    query.Dispose();
    kH = ReshapeToHeads(key, kvHeads, seqLen, hd);
    key.Dispose();
    vH = ReshapeToHeads(value, kvHeads, seqLen, hd);
    value.Dispose();
}
/// <summary>Runs the prompt through every layer once and stores each layer's prompt K/V
/// (head-first) in the prompt-KV store, recording the prompt length. The prompt uses the
/// scaled-embedding input, causal attention and the encoder per-layer scalar.</summary>
/// <param name="promptTokens">Prompt token ids.</param>
/// <exception cref="InvalidOperationException">Prompt-KV caching is disabled for this backend.</exception>
public void PrefillPrompt(int[] promptTokens)
{
    if (!_pkvEnabled)
        throw new InvalidOperationException("Prompt-KV caching is not enabled for this backend.");

    AllocPromptStore();
    long startTicks = Stopwatch.GetTimestamp();
    _promptLen = PrefillPromptInto(promptTokens, _promptK, _promptV);

    if (!_stepTime)
        return;

    // Read back one element of the last layer's K before stopping the clock; presumably this
    // waits for any queued device work so the timing covers the whole prefill.
    _ = _promptK[Config.NumLayers - 1].GetElementsAsFloat(1);
    long endTicks = Stopwatch.GetTimestamp();
    double ms = (endTicks - startTicks) * 1000.0 / Stopwatch.Frequency;
    Console.Error.WriteLine($"[prefill] {_promptLen} tokens in {ms:F1} ms = {_promptLen * 1000.0 / ms:F1} tok/s");
}