Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions echo-core-zig/src/inference/engine.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1016,3 +1016,57 @@ test "Engine.greedyNextToken correctly identifies token with max logit" {
eng.logits[0] = 10.0;
try std.testing.expectEqual(@as(u32, 0), eng.greedyNextToken());
}

test "Engine.greedyNextToken handles all negative logits" {
const cfg = makeTinyConfig(0, 0);
var eng = try Engine.init(cfg, null, std.testing.allocator);
defer eng.deinit(std.testing.allocator);

@memset(eng.logits, -1000.0);
eng.logits[0] = -10.0;
eng.logits[1] = -5.0;
eng.logits[2] = -20.0;
eng.logits[3] = -100.0;

// -5.0 is the largest (closest to 0)
try std.testing.expectEqual(@as(u32, 1), eng.greedyNextToken());
}

test "Engine.decodeStep advances seq_pos and returns logits" {
const cfg = makeTinyConfig(1, 8);
var eng = try Engine.init(cfg, null, std.testing.allocator);
defer eng.deinit(std.testing.allocator);

// Initialize weight_pool to zeros
@memset(eng.weight_pool, 0);

try std.testing.expectEqual(@as(u32, 0), eng.seq_pos);

const logits1 = try eng.decodeStep(0);
try std.testing.expectEqual(@as(usize, cfg.vocab_size), logits1.len);
try std.testing.expectEqual(@as(u32, 1), eng.seq_pos);

const logits2 = try eng.decodeStep(1);
try std.testing.expectEqual(@as(usize, cfg.vocab_size), logits2.len);
try std.testing.expectEqual(@as(u32, 2), eng.seq_pos);
}

test "Engine.prefill processes multiple tokens and advances seq_pos correctly" {
const cfg = makeTinyConfig(1, 8);
var eng = try Engine.init(cfg, null, std.testing.allocator);
defer eng.deinit(std.testing.allocator);

@memset(eng.weight_pool, 0);

try std.testing.expectEqual(@as(u32, 0), eng.seq_pos);

const prompt = [_]u32{ 0, 1, 2 };
const logits = try eng.prefill(&prompt);

try std.testing.expectEqual(@as(usize, cfg.vocab_size), logits.len);
try std.testing.expectEqual(@as(u32, 3), eng.seq_pos);

// Check that KV cache actually registered the tokens
try std.testing.expect(eng.kv_cache != null);
try std.testing.expectEqual(@as(u32, 3), eng.kv_cache.?.seqLen());
}