@@ -26,7 +26,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
2626from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
2727
2828model_id = "inclusionAI/LLaDA2.1-mini"
29- model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto")
29+ model = AutoModelForCausalLM.from_pretrained(
30+     model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
31+ )
3032tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
3133scheduler = BlockRefinementScheduler()
3234
@@ -46,18 +48,21 @@ print(output.texts[0])
4648
4749## Callbacks
4850
49- Callbacks run after each refinement step and can inspect or modify the current tokens.
51+ Callbacks run after each refinement step. Pass `callback_on_step_end_tensor_inputs` to select which tensors are
52+ included in `callback_kwargs`. In the current implementation, `block_x` (the sequence window being refined) and
53+ `transfer_index` (mask-filling commit mask) are provided; return `{"block_x": ...}` from the callback to replace the
54+ window.
5055
5156``` py
5257def on_step_end(pipe, step, timestep, callback_kwargs):
53-     cur_x = callback_kwargs["cur_x"]
54-     # Inspect or modify `cur_x` here.
55-     return {"cur_x": cur_x}
58+     block_x = callback_kwargs["block_x"]
59+     # Inspect or modify `block_x` here.
60+     return {"block_x": block_x}
5661
5762out = pipe(
5863    prompt="Write a short poem.",
5964    callback_on_step_end=on_step_end,
60-     callback_on_step_end_tensor_inputs=["cur_x"],
65+     callback_on_step_end_tensor_inputs=["block_x"],
6166)
6267```
6368
@@ -68,11 +73,13 @@ LLaDA2.1 models support two modes:
6873| Mode | `threshold` | `editing_threshold` | `max_post_steps` |
6974|------|-------------|---------------------|------------------|
7075| Quality | 0.7 | 0.5 | 16 |
71- | Speed | 0.5 | 0.0 | 16 |
76+ | Speed | 0.5 | `None` | 16 |
77+
78+ Pass `editing_threshold=None`, `0.0`, or a negative value to turn off post-mask editing.
7279
73- For LLaDA2.0 models, disable editing by passing `editing_threshold=None`.
80+ For LLaDA2.0 models, disable editing by passing `editing_threshold=None` or `0.0`.
7481
75- For all models: `block_length=32`, `temperature=0.0`, `steps=32`.
82+ For all models: `block_length=32`, `temperature=0.0`, `num_inference_steps=32`.
7683
7784## LLaDA2Pipeline
7885[[ autodoc]] LLaDA2Pipeline
0 commit comments