diff --git a/echo-core-zig/src/inference/engine.zig b/echo-core-zig/src/inference/engine.zig index eae5fd0..cd08f1e 100644 --- a/echo-core-zig/src/inference/engine.zig +++ b/echo-core-zig/src/inference/engine.zig @@ -1015,4 +1015,27 @@ test "Engine.greedyNextToken correctly identifies token with max logit" { // Test early in the slice eng.logits[0] = 10.0; try std.testing.expectEqual(@as(u32, 0), eng.greedyNextToken()); + + // Edge case: all negative logits + eng.logits[0] = -5.0; + eng.logits[1] = -10.0; + eng.logits[2] = -2.0; // max is -2.0 at index 2 + eng.logits[3] = -3.0; + try std.testing.expectEqual(@as(u32, 2), eng.greedyNextToken()); + + // Edge case: tie (should return the first one encountered) + eng.logits[0] = 1.0; + eng.logits[1] = 3.0; // max + eng.logits[2] = 3.0; // tie + eng.logits[3] = 2.0; + try std.testing.expectEqual(@as(u32, 1), eng.greedyNextToken()); +} + +test "Engine.decodeStep returns logits of expected length" { + const cfg = makeTinyConfig(0, 0); + var eng = try Engine.init(cfg, null, std.testing.allocator); + defer eng.deinit(std.testing.allocator); + + const logits = try eng.decodeStep(0); + try std.testing.expectEqual(@as(usize, cfg.vocab_size), logits.len); }