File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -236,6 +236,43 @@ def test(
236236
237237 model .reset_cache (cache_config )
238238
239+ # ---------------------------------------------------------------------------- #
240+ # Warmup
241+ # ---------------------------------------------------------------------------- #
242+ warmup_steps = 1
243+
244+ # Warmup prompt length: the longest prompt's length, capped at 64 ids.
245+ # It should be long enough to trigger the correct kernel paths,
246+ # but not so long that warmup becomes unnecessarily expensive.
247+ avg_prompt_len = min (64 , max (len (ids ) for ids in input_ids_list ))
248+
249+ # Use truncated versions of real prompts for warmup
250+ warmup_ids = [
251+ ids [:avg_prompt_len ] if len (ids ) >= avg_prompt_len else ids
252+ for ids in input_ids_list
253+ ]
254+
255+ input_ids_infini = infinicore .from_list (warmup_ids )
256+
257+ print ("=================== warmup start ===================" )
258+
259+ for _ in range (warmup_steps ):
260+ _ = model .generate (
261+ input_ids_infini ,
262+ GenerationConfig (
263+ max_new_tokens = 2 , # warmup decode kernel
264+ temperature = 1 ,
265+ top_k = 1 ,
266+ top_p = 0.8 ,
267+ ),
268+ _measure_and_log_time = False ,
269+ )
270+
271+ print ("=================== warmup done ====================" )
272+
273+ # Reset KV cache
274+ model .reset_cache (cache_config )
275+
239276 # ---------------------------------------------------------------------------- #
240277 # Generate
241278 # ---------------------------------------------------------------------------- #
You can’t perform that action at this time.
0 commit comments