From da828c8f494e785d7132eea5f19aa39064847a07 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Wed, 8 Apr 2026 05:05:42 +0000 Subject: [PATCH] test: add missing tests for Engine inference functions - Add test for `Engine.decodeStep` to verify seq_pos and logits - Add test for `Engine.prefill` to verify multiple token processing - Add edge case test for `Engine.greedyNextToken` with negative logits Co-authored-by: ulac000000 <132948319+ulac000000@users.noreply.github.com> --- echo-core-zig/src/inference/engine.zig | 54 ++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/echo-core-zig/src/inference/engine.zig b/echo-core-zig/src/inference/engine.zig index eae5fd0..9bda35a 100644 --- a/echo-core-zig/src/inference/engine.zig +++ b/echo-core-zig/src/inference/engine.zig @@ -1016,3 +1016,57 @@ test "Engine.greedyNextToken correctly identifies token with max logit" { eng.logits[0] = 10.0; try std.testing.expectEqual(@as(u32, 0), eng.greedyNextToken()); } + +test "Engine.greedyNextToken handles all negative logits" { + const cfg = makeTinyConfig(0, 0); + var eng = try Engine.init(cfg, null, std.testing.allocator); + defer eng.deinit(std.testing.allocator); + + @memset(eng.logits, -1000.0); + eng.logits[0] = -10.0; + eng.logits[1] = -5.0; + eng.logits[2] = -20.0; + eng.logits[3] = -100.0; + + // -5.0 is the largest (closest to 0) + try std.testing.expectEqual(@as(u32, 1), eng.greedyNextToken()); +} + +test "Engine.decodeStep advances seq_pos and returns logits" { + const cfg = makeTinyConfig(1, 8); + var eng = try Engine.init(cfg, null, std.testing.allocator); + defer eng.deinit(std.testing.allocator); + + // Initialize weight_pool to zeros + @memset(eng.weight_pool, 0); + + try std.testing.expectEqual(@as(u32, 0), eng.seq_pos); + + const logits1 = try eng.decodeStep(0); + try std.testing.expectEqual(@as(usize, cfg.vocab_size), logits1.len); + try std.testing.expectEqual(@as(u32, 1), eng.seq_pos); + + const logits2 = try eng.decodeStep(1); + try std.testing.expectEqual(@as(usize, cfg.vocab_size), logits2.len); + try std.testing.expectEqual(@as(u32, 2), eng.seq_pos); +} + +test "Engine.prefill processes multiple tokens and advances seq_pos correctly" { + const cfg = makeTinyConfig(1, 8); + var eng = try Engine.init(cfg, null, std.testing.allocator); + defer eng.deinit(std.testing.allocator); + + @memset(eng.weight_pool, 0); + + try std.testing.expectEqual(@as(u32, 0), eng.seq_pos); + + const prompt = [_]u32{ 0, 1, 2 }; + const logits = try eng.prefill(&prompt); + + try std.testing.expectEqual(@as(usize, cfg.vocab_size), logits.len); + try std.testing.expectEqual(@as(u32, 3), eng.seq_pos); + + // Check that KV cache actually registered the tokens + try std.testing.expect(eng.kv_cache != null); + try std.testing.expectEqual(@as(u32, 3), eng.kv_cache.?.seqLen()); +}