diff --git a/echo-core-zig/src/inference/engine.zig b/echo-core-zig/src/inference/engine.zig
index eae5fd0..9bda35a 100644
--- a/echo-core-zig/src/inference/engine.zig
+++ b/echo-core-zig/src/inference/engine.zig
@@ -1016,3 +1016,57 @@ test "Engine.greedyNextToken correctly identifies token with max logit" {
     eng.logits[0] = 10.0;
     try std.testing.expectEqual(@as(u32, 0), eng.greedyNextToken());
 }
+
+test "Engine.greedyNextToken handles all negative logits" {
+    const tiny_cfg = makeTinyConfig(0, 0);
+    var sut = try Engine.init(tiny_cfg, null, std.testing.allocator);
+    defer sut.deinit(std.testing.allocator);
+    // Floor every slot, then raise four of them; no logit is positive.
+    @memset(sut.logits, -1000.0);
+    sut.logits[3] = -100.0;
+    sut.logits[2] = -20.0;
+    sut.logits[1] = -5.0; // the maximum: closest to 0
+    sut.logits[0] = -10.0;
+
+    const picked = sut.greedyNextToken();
+    try std.testing.expectEqual(@as(u32, 1), picked);
+}
+
+test "Engine.decodeStep advances seq_pos and returns logits" {
+    const tiny_cfg = makeTinyConfig(1, 8);
+    var sut = try Engine.init(tiny_cfg, null, std.testing.allocator);
+    defer sut.deinit(std.testing.allocator);
+
+    // Zero the weight pool before running any decode step.
+    @memset(sut.weight_pool, 0);
+    try std.testing.expectEqual(@as(u32, 0), sut.seq_pos);
+
+    // Decode token 0, then token 1; after each step seq_pos is one higher
+    // and the returned logits have vocab_size entries.
+    var token: u32 = 0;
+    while (token < 2) : (token += 1) {
+        const step_logits = try sut.decodeStep(token);
+        try std.testing.expectEqual(@as(usize, tiny_cfg.vocab_size), step_logits.len);
+        try std.testing.expectEqual(token + 1, sut.seq_pos);
+    }
+}
+
+test "Engine.prefill processes multiple tokens and advances seq_pos correctly" {
+    const tiny_cfg = makeTinyConfig(1, 8);
+    var sut = try Engine.init(tiny_cfg, null, std.testing.allocator);
+    defer sut.deinit(std.testing.allocator);
+
+    @memset(sut.weight_pool, 0);
+    try std.testing.expectEqual(@as(u32, 0), sut.seq_pos);
+
+    // Feed a three-token prompt in a single prefill call.
+    const prompt_tokens: [3]u32 = .{ 0, 1, 2 };
+    const out_logits = try sut.prefill(&prompt_tokens);
+    try std.testing.expectEqual(@as(usize, tiny_cfg.vocab_size), out_logits.len);
+    try std.testing.expectEqual(@as(u32, prompt_tokens.len), sut.seq_pos);
+
+    // The KV cache must exist and report the same number of tokens.
+    try std.testing.expect(sut.kv_cache != null);
+    const cached_len = sut.kv_cache.?.seqLen();
+    try std.testing.expectEqual(@as(u32, prompt_tokens.len), cached_len);
+}