From 313ef9efd9e77f9ef7dfc748861657ebd44ad232 Mon Sep 17 00:00:00 2001 From: moscowmule2240 Date: Thu, 4 Jun 2026 14:36:41 +0900 Subject: [PATCH] feat(inference): add average_samples flag to auto_regressive_inference By default (average_samples=True) behavior is unchanged: predictions are averaged over the sample dimension. With average_samples=False the per-sample paths are returned with shape (batch, sample_count, total_seq, n_feat), letting callers obtain the full Monte Carlo distribution (probabilistic forecasting / uncertainty) without re-implementing the autoregressive loop. Backward compatible. Adds a docstring for the flag and tests/test_average_samples.py (fake model, no weight download) covering the shape behavior and mean-over-samples consistency. --- model/kronos.py | 21 +++++++- tests/test_average_samples.py | 92 +++++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 2 deletions(-) create mode 100644 tests/test_average_samples.py diff --git a/model/kronos.py b/model/kronos.py index ce4494ee..06c29e46 100644 --- a/model/kronos.py +++ b/model/kronos.py @@ -386,7 +386,23 @@ def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_l return x -def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False): +def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False, average_samples=True): + """Autoregressively sample future tokens and decode them to the feature space. + + Generates ``sample_count`` Monte Carlo trajectories in parallel. + + Args: + average_samples (bool): If True (default), average over the sample + dimension and return shape ``(batch, total_seq, n_feat)`` -- the + original, backward-compatible behavior. If False, keep the + per-sample trajectories and return shape + ``(batch, sample_count, total_seq, n_feat)`` so callers can access + the full predictive distribution (e.g. probabilistic forecasting / + uncertainty estimation). + + Returns: + np.ndarray: Decoded predictions; shape depends on ``average_samples``. + """ with torch.no_grad(): x = torch.clip(x, -clip, clip) @@ -464,7 +480,8 @@ def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context z = tokenizer.decode(input_tokens, half=True) z = z.reshape(-1, sample_count, z.size(1), z.size(2)) preds = z.cpu().numpy() - preds = np.mean(preds, axis=1) + if average_samples: + preds = np.mean(preds, axis=1) return preds diff --git a/tests/test_average_samples.py b/tests/test_average_samples.py new file mode 100644 index 00000000..2503b497 --- /dev/null +++ b/tests/test_average_samples.py @@ -0,0 +1,92 @@ +"""Tests for the ``average_samples`` flag of ``auto_regressive_inference``. + +Uses minimal fake tokenizer / model and a stubbed ``sample_from_logits`` so the +generation loop runs without downloading any pretrained weights. Verifies that: + - ``average_samples=False`` keeps the per-sample dimension; + - ``average_samples=True`` (default) collapses it (backward-compatible shape); + - the averaged output equals the mean over samples of the per-sample output. +""" +from dataclasses import dataclass + +import numpy as np +import pytest + +torch = pytest.importorskip("torch") + +import model.kronos as kronos_mod # noqa: E402 +from model.kronos import auto_regressive_inference # noqa: E402 + +_BATCH = 1 +_SEQ = 4 +_PRED = 2 +_SAMPLES = 3 +_FEAT = 6 +_STAMP = 4 +_MAX_CTX = 8 # >= seq + pred so the buffer never rolls (simple path) +_VOCAB = 8 + + +@dataclass +class _FakeTokenizer: + def encode(self, x, half=False): + bsz, slen = x.size(0), x.size(1) + return [ + torch.zeros(bsz, slen, dtype=torch.long), + torch.zeros(bsz, slen, dtype=torch.long), + ] + + def decode(self, input_tokens, half=False): + pre = input_tokens[0] + return torch.zeros(pre.size(0), pre.size(1), _FEAT) + + +@dataclass +class _FakeModel: + def decode_s1(self, pre, post, stamp): + bsz, slen = pre.size(0), pre.size(1) + return torch.zeros(bsz, slen, _VOCAB), torch.zeros(bsz, slen, 4) + + def decode_s2(self, context, sample_pre): + return torch.zeros(context.size(0), context.size(1), _VOCAB) + + +def _fake_sample_from_logits(logits, temperature=1.0, top_k=0, top_p=1.0, sample_logits=True): + return torch.zeros(logits.size(0), 1, dtype=torch.long) + + +def _inputs(): + return ( + torch.zeros(_BATCH, _SEQ, _FEAT), + torch.zeros(_BATCH, _SEQ, _STAMP), + torch.zeros(_BATCH, _PRED, _STAMP), + ) + + +@pytest.fixture +def patched_sampler(monkeypatch): + monkeypatch.setattr(kronos_mod, "sample_from_logits", _fake_sample_from_logits) + + +def _run(average_samples): + x, x_stamp, y_stamp = _inputs() + return auto_regressive_inference( + _FakeTokenizer(), _FakeModel(), x, x_stamp, y_stamp, + max_context=_MAX_CTX, pred_len=_PRED, sample_count=_SAMPLES, + average_samples=average_samples, + ) + + +def test_average_samples_false_keeps_sample_dimension(patched_sampler): + preds = _run(average_samples=False) + assert preds.shape == (_BATCH, _SAMPLES, _SEQ + _PRED, _FEAT) + + +def test_average_samples_true_collapses_sample_dimension(patched_sampler): + preds = _run(average_samples=True) + assert preds.shape == (_BATCH, _SEQ + _PRED, _FEAT) + + +def test_averaged_output_is_mean_over_samples(patched_sampler): + paths = _run(average_samples=False) + averaged = _run(average_samples=True) + np.testing.assert_allclose(paths.mean(axis=1), averaged, rtol=1e-6, atol=1e-6)