Skip to content

Commit 28945a9

Browse files
committed
issue/224 - feat: add warmup before InfiniLM generation
1 parent 0879747 commit 28945a9

1 file changed

Lines changed: 37 additions & 0 deletions

File tree

examples/jiuge.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,43 @@ def test(
236236

237237
model.reset_cache(cache_config)
238238

239+
# ---------------------------------------------------------------------------- #
240+
# Warmup
241+
# ---------------------------------------------------------------------------- #
242+
warmup_steps = 1
243+
244+
# Choose a length that approximates the real workload.
245+
# It should be long enough to trigger the correct kernel paths,
246+
# but not so long that warmup becomes unnecessarily expensive.
247+
avg_prompt_len = min(64, max(len(ids) for ids in input_ids_list))
248+
249+
# Warm up with the real prompts, each truncated to at most avg_prompt_len tokens
250+
warmup_ids = [
251+
ids[:avg_prompt_len] if len(ids) >= avg_prompt_len else ids
252+
for ids in input_ids_list
253+
]
254+
255+
input_ids_infini = infinicore.from_list(warmup_ids)
256+
257+
print("=================== warmup start ===================")
258+
259+
for _ in range(warmup_steps):
260+
_ = model.generate(
261+
input_ids_infini,
262+
GenerationConfig(
263+
max_new_tokens=2, # warmup decode kernel
264+
temperature=1,
265+
top_k=1,
266+
top_p=0.8,
267+
),
268+
_measure_and_log_time=False,
269+
)
270+
271+
print("=================== warmup done ====================")
272+
273+
# Reset the KV cache so state from the warmup pass does not carry into the real generation below
274+
model.reset_cache(cache_config)
275+
239276
# ---------------------------------------------------------------------------- #
240277
# Generate
241278
# ---------------------------------------------------------------------------- #

0 commit comments

Comments
 (0)