From 4ebe6500829181a78fe21c5f8d817433701dca07 Mon Sep 17 00:00:00 2001 From: David Burton Date: Tue, 5 May 2026 18:48:19 -0400 Subject: [PATCH] feat(inference): add bf16 mixed-precision via KronosPredictor.amp_dtype Wraps the body of auto_regressive_inference in torch.autocast, gated by a new amp_dtype argument on KronosPredictor and the inference function. "bfloat16" enables bf16 autocast on the active device; None (default) keeps the existing FP32 path bit-exact. The dtype is validated eagerly in KronosPredictor.__init__ so a typo fails at construction rather than on the first predict call. z.float() before .cpu().numpy() handles the case where bf16 autocast leaves the decoded tensor in bf16 (numpy has no bf16 dtype). --- model/kronos.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/model/kronos.py b/model/kronos.py index ce4494ee..24b14dc9 100644 --- a/model/kronos.py +++ b/model/kronos.py @@ -386,8 +386,20 @@ def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_l return x -def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False): - with torch.no_grad(): +def _resolve_amp_dtype(amp_dtype): + if amp_dtype is None: + return torch.float32, False + if amp_dtype == "bfloat16": + return torch.bfloat16, True + raise ValueError( + f"Unsupported amp_dtype {amp_dtype!r}; expected 'bfloat16' or None." + ) + + +def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False, amp_dtype=None): + autocast_dtype, amp_enabled = _resolve_amp_dtype(amp_dtype) + + with torch.no_grad(), torch.autocast(device_type=x.device.type, dtype=autocast_dtype, enabled=amp_enabled): x = torch.clip(x, -clip, clip) device = x.device @@ -396,7 +408,7 @@ def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device) x_token = tokenizer.encode(x, half=True) - + initial_seq_len = x.size(1) batch_size = x_token[0].size(0) total_seq_len = initial_seq_len + pred_len @@ -463,7 +475,7 @@ def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context ] z = tokenizer.decode(input_tokens, half=True) z = z.reshape(-1, sample_count, z.size(1), z.size(2)) - preds = z.cpu().numpy() + preds = z.float().cpu().numpy() preds = np.mean(preds, axis=1) return preds @@ -481,7 +493,7 @@ def calc_time_stamps(x_timestamp): class KronosPredictor: - def __init__(self, model, tokenizer, device=None, max_context=512, clip=5): + def __init__(self, model, tokenizer, device=None, max_context=512, clip=5, amp_dtype=None): self.tokenizer = tokenizer self.model = model self.max_context = max_context @@ -490,7 +502,11 @@ def __init__(self, model, tokenizer, device=None, max_context=512, clip=5): self.vol_col = 'volume' self.amt_vol = 'amount' self.time_cols = ['minute', 'hour', 'weekday', 'day', 'month'] - + + # Validate eagerly so the wrong value fails at construction, not on first predict. + _resolve_amp_dtype(amp_dtype) + self.amp_dtype = amp_dtype + # Auto-detect device if not specified if device is None: if torch.cuda.is_available(): @@ -499,7 +515,7 @@ def __init__(self, model, tokenizer, device=None, max_context=512, clip=5): device = "mps" else: device = "cpu" - + self.device = device self.tokenizer = self.tokenizer.to(self.device) @@ -512,7 +528,7 @@ def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, y_stamp_tensor = torch.from_numpy(np.array(y_stamp).astype(np.float32)).to(self.device) preds = auto_regressive_inference(self.tokenizer, self.model, x_tensor, x_stamp_tensor, y_stamp_tensor, self.max_context, pred_len, - self.clip, T, top_k, top_p, sample_count, verbose) + self.clip, T, top_k, top_p, sample_count, verbose, amp_dtype=self.amp_dtype) preds = preds[:, -pred_len:, :] return preds