Skip to content

Commit 762ae05

Browse files
authored
[LLADA2] documentation fixes (#13333)
documentation fixes
1 parent 5d207e7 commit 762ae05

4 files changed

Lines changed: 29 additions & 19 deletions

File tree

docs/source/en/_toctree.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -580,8 +580,6 @@
580580
title: Latent Diffusion
581581
- local: api/pipelines/ledits_pp
582582
title: LEDITS++
583-
- local: api/pipelines/llada2
584-
title: LLaDA2
585583
- local: api/pipelines/longcat_image
586584
title: LongCat-Image
587585
- local: api/pipelines/lumina2
@@ -672,6 +670,10 @@
672670
- local: api/pipelines/z_image
673671
title: Z-Image
674672
title: Image
673+
- sections:
674+
- local: api/pipelines/llada2
675+
title: LLaDA2
676+
title: Text
675677
- sections:
676678
- local: api/pipelines/allegro
677679
title: Allegro

docs/source/en/api/pipelines/llada2.md

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
2626
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
2727

2828
model_id = "inclusionAI/LLaDA2.1-mini"
29-
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto")
29+
model = AutoModelForCausalLM.from_pretrained(
30+
model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
31+
)
3032
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
3133
scheduler = BlockRefinementScheduler()
3234

@@ -46,18 +48,21 @@ print(output.texts[0])
4648

4749
## Callbacks
4850

49-
Callbacks run after each refinement step and can inspect or modify the current tokens.
51+
Callbacks run after each refinement step. Pass `callback_on_step_end_tensor_inputs` to select which tensors are
52+
included in `callback_kwargs`. In the current implementation, `block_x` (the sequence window being refined) and
53+
`transfer_index` (mask-filling commit mask) are provided; return `{"block_x": ...}` from the callback to replace the
54+
window.
5055

5156
```py
5257
def on_step_end(pipe, step, timestep, callback_kwargs):
53-
cur_x = callback_kwargs["cur_x"]
54-
# Inspect or modify `cur_x` here.
55-
return {"cur_x": cur_x}
58+
block_x = callback_kwargs["block_x"]
59+
# Inspect or modify `block_x` here.
60+
return {"block_x": block_x}
5661

5762
out = pipe(
5863
prompt="Write a short poem.",
5964
callback_on_step_end=on_step_end,
60-
callback_on_step_end_tensor_inputs=["cur_x"],
65+
callback_on_step_end_tensor_inputs=["block_x"],
6166
)
6267
```
6368

@@ -68,11 +73,13 @@ LLaDA2.1 models support two modes:
6873
| Mode | `threshold` | `editing_threshold` | `max_post_steps` |
6974
|------|-------------|---------------------|------------------|
7075
| Quality | 0.7 | 0.5 | 16 |
71-
| Speed | 0.5 | 0.0 | 16 |
76+
| Speed | 0.5 | `None` | 16 |
77+
78+
Pass `editing_threshold=None`, `0.0`, or a negative value to turn off post-mask editing.
7279

73-
For LLaDA2.0 models, disable editing by passing `editing_threshold=None`.
80+
For LLaDA2.0 models, disable editing by passing `editing_threshold=None` or `0.0`.
7481

75-
For all models: `block_length=32`, `temperature=0.0`, `steps=32`.
82+
For all models: `block_length=32`, `temperature=0.0`, `num_inference_steps=32`.
7683

7784
## LLaDA2Pipeline
7885
[[autodoc]] LLaDA2Pipeline

src/diffusers/pipelines/llada2/pipeline_llada2.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -273,10 +273,10 @@ def __call__(
273273
threshold (`float`):
274274
Confidence threshold for committing tokens.
275275
editing_threshold (`float`, *optional*):
276-
Confidence threshold for editing already-committed (non-mask) tokens. When set, after all mask tokens
277-
in a block are resolved, the pipeline continues refining: if the model predicts a different token with
278-
confidence above this threshold, the existing token is replaced. Set to `None` or a negative value to
279-
disable editing. Defaults to `0.5`.
276+
Confidence threshold for editing already-committed (non-mask) tokens. When positive, after all mask
277+
tokens in a block are resolved, the pipeline continues refining: if the model predicts a different
278+
token with confidence above this threshold, the existing token is replaced. Set to `None`, `0.0`, or a
279+
negative value to disable editing. Defaults to `0.5`.
280280
max_post_steps (`int`):
281281
Maximum number of additional refinement iterations after all mask tokens in a block are resolved. Only
282282
used when `editing_threshold` is enabled. Defaults to `16`.
@@ -373,7 +373,7 @@ def __call__(
373373
self._num_timesteps = num_inference_steps * max(num_blocks - prefill_blocks, 0)
374374

375375
finished = torch.zeros((batch_size,), device=device, dtype=torch.bool)
376-
editing_enabled = editing_threshold is not None and editing_threshold >= 0.0
376+
editing_enabled = editing_threshold is not None and editing_threshold > 0.0
377377
global_step = 0
378378

379379
# 5. Block-wise refinement loop

src/diffusers/schedulers/scheduling_block_refinement.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class BlockRefinementScheduler(SchedulerMixin, ConfigMixin):
5757
the number of refinement steps.
5858
5959
Optionally supports editing: after all mask tokens are resolved, tokens can be replaced if the model predicts a
60-
different token with confidence above `editing_threshold`.
60+
different token with confidence above a positive `editing_threshold` (`None`, `0.0`, or negative disables editing).
6161
"""
6262

6363
order = 1
@@ -208,7 +208,8 @@ def step(
208208
threshold (`float`, *optional*):
209209
Confidence threshold for committing tokens. Defaults to config value.
210210
editing_threshold (`float`, *optional*):
211-
Confidence threshold for editing non-mask tokens. Defaults to config value.
211+
Confidence threshold for editing non-mask tokens; must be positive to enable editing. Defaults to
212+
config value.
212213
minimal_topk (`int`, *optional*):
213214
Minimum tokens to commit per step. Defaults to config value.
214215
prompt_mask (`torch.BoolTensor`, *optional*):
@@ -268,7 +269,7 @@ def step(
268269
transfer_index[b, idx] = True
269270

270271
# --- Editing transfer (non-mask, non-prompt positions) ---
271-
editing_enabled = editing_threshold is not None and editing_threshold >= 0.0
272+
editing_enabled = editing_threshold is not None and editing_threshold > 0.0
272273
editing_transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool)
273274
if editing_enabled:
274275
if prompt_mask is None:

0 commit comments

Comments (0)