Skip to content

Commit 4c35da3

Browse files
committed
Refresh sample bundle and tighten config runtime behavior
1 parent 553e247 commit 4c35da3

53 files changed

Lines changed: 1218 additions & 1604 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

app/core/config_yaml.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,10 +18,15 @@
1818
# updater: winner_average | winner_copy | linear_preference
1919
# feedback_mode: scalar_rating | pairwise | top_k | winner_only | approve_reject
2020
# seed_policy: fixed-per-round | fixed-per-candidate | fixed-per-candidate-role
21+
# steering_mode: currently low_dimensional
2122
# steering_dimension: low-dimensional steering vector size, for example 3 or 5
23+
# candidate_count: visible candidates per round
2224
# image_size: WIDTHxHEIGHT, for example 512x512
25+
# trust_radius: steering search radius around the current state
26+
# anchor_strength: strength of the steering offset applied to prompt embeddings
2327
# guidance_scale: classifier-free guidance strength, for example 7.5
2428
# num_inference_steps: diffusion denoising steps, for example 15 or 30
29+
# model_name: prepared local model id or Hugging Face model id
2530
"""
2631
)
2732

app/engine/generation.py

Lines changed: 23 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66

77
from app.bootstrap.huggingface import model_slug
88
from app.core.config import settings
9-
from app.core.schema import Candidate, Session
9+
from app.core.schema import Candidate, Session, SteeringMode
1010

1111

1212
def _color_from_candidate(candidate: Candidate) -> tuple[str, str]:
@@ -30,6 +30,15 @@ def parse_image_size(value: str) -> tuple[int, int]:
3030
raise ValueError(f"Invalid image size: {value!r}. Expected format WIDTHxHEIGHT.") from exc
3131

3232

33+
def resolve_steering_mode(session: Session) -> SteeringMode:
    """Return the session's configured steering mode after validating it.

    Reads ``session.config.steering_mode`` and accepts it only when it equals
    ``SteeringMode.low_dimensional``, the single mode handled here.

    Raises:
        ValueError: if the configured mode is anything else.
    """
    requested = session.config.steering_mode
    # Guard clause: reject everything except the one supported mode.
    if requested != SteeringMode.low_dimensional:
        raise ValueError(f"Unsupported steering mode: {requested}")
    return requested
40+
41+
3342
class GenerationEngine(Protocol):
3443
"""Protocol shared by generation backends used by the orchestrator."""
3544

@@ -55,6 +64,7 @@ def __init__(self, artifacts_dir: Path | None = None) -> None:
5564
def render_candidate(self, session: Session, candidate: Candidate) -> Candidate:
5665
"""Render one candidate to an SVG artifact and attach its public path."""
5766

67+
steering_mode = resolve_steering_mode(session)
5868
primary, secondary = _color_from_candidate(candidate)
5969
width, height = parse_image_size(session.config.image_size)
6070
path = self.artifacts_dir / f"{candidate.id}.svg"
@@ -77,6 +87,7 @@ def render_candidate(self, session: Session, candidate: Candidate) -> Candidate:
7787
<text x="40" y="330" fill="white" font-size="18" font-family="Arial">CFG: {session.config.guidance_scale:.2f}</text>
7888
<text x="40" y="365" fill="white" font-size="18" font-family="Arial">Steps: {session.config.num_inference_steps}</text>
7989
<text x="40" y="400" fill="white" font-size="18" font-family="Arial">Anchor strength: {session.config.anchor_strength:.2f}</text>
90+
<text x="40" y="435" fill="white" font-size="18" font-family="Arial">Steering mode: {escape(steering_mode.value)}</text>
8091
</svg>"""
8192
path.write_text(svg, encoding="utf-8")
8293
candidate.image_path = f"/artifacts/{path.name}"
@@ -88,6 +99,7 @@ def render_candidate(self, session: Session, candidate: Candidate) -> Candidate:
8899
"num_inference_steps": session.config.num_inference_steps,
89100
"model_source": session.config.model_name,
90101
"anchor_strength": session.config.anchor_strength,
102+
"steering_mode": steering_mode.value,
91103
}
92104
)
93105
return candidate
@@ -265,6 +277,7 @@ def _steering_offset(self, prompt_embeds, z, anchor_strength: float):
265277
def _encode_steered_embeddings(self, session: Session, candidate: Candidate):
266278
"""Encode prompt text, then apply a deterministic steering offset."""
267279

280+
steering_mode = resolve_steering_mode(session)
268281
pipe = self._load_pipeline(self._resolve_model_source(session))
269282
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
270283
prompt=session.prompt,
@@ -273,7 +286,14 @@ def _encode_steered_embeddings(self, session: Session, candidate: Candidate):
273286
do_classifier_free_guidance=True,
274287
negative_prompt=session.negative_prompt or "",
275288
)
276-
steered_prompt_embeds = prompt_embeds + self._steering_offset(prompt_embeds, candidate.z, session.config.anchor_strength)
289+
if steering_mode == SteeringMode.low_dimensional:
290+
steered_prompt_embeds = prompt_embeds + self._steering_offset(
291+
prompt_embeds,
292+
candidate.z,
293+
session.config.anchor_strength,
294+
)
295+
else:
296+
raise ValueError(f"Unsupported steering mode: {steering_mode}")
277297
return steered_prompt_embeds, negative_prompt_embeds
278298

279299
def render_candidate(self, session: Session, candidate: Candidate) -> Candidate:
@@ -309,7 +329,7 @@ def render_candidate(self, session: Session, candidate: Candidate) -> Candidate:
309329
"num_inference_steps": num_inference_steps,
310330
"model_source": model_source,
311331
"anchor_strength": session.config.anchor_strength,
312-
"steering_mode": session.config.steering_mode,
332+
"steering_mode": resolve_steering_mode(session).value,
313333
}
314334
)
315335
return candidate

app/engine/orchestrator.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -465,6 +465,16 @@ def _widen_first_round_candidates(session: Session, proposed_candidates: list[Ca
465465
boost_radius = min(max(session.config.trust_radius * 1.55, 0.34), 0.72)
466466
min_radius = min(max(session.config.trust_radius * 0.95, 0.24), boost_radius)
467467
for index, candidate in enumerate(proposed_candidates):
468+
if candidate.sampler_role == "exploit":
469+
exploit_radius = min(max(session.config.trust_radius * 0.35, 0.12), 0.24)
470+
boosted_z = clamp_vector(list(candidate.z), exploit_radius)
471+
candidate.z = boosted_z
472+
candidate.generation_params["first_round_diversity_boost"] = True
473+
candidate.generation_params["first_round_diversity_scale"] = 0.6
474+
candidate.generation_params["first_round_role_behavior"] = "keep_exploit_close"
475+
boosted_candidates.append(candidate)
476+
continue
477+
468478
spread_direction = Orchestrator._first_round_spread_direction(index, dimensions)
469479
scale = 1.15 + (0.1 * index)
470480
blended = [

app/frontend/templates/setup.html

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@ <h1>Start from your text prompt</h1>
4747
<input name="description" value="Initial real generation workflow">
4848
</label>
4949
<label>
50-
<span class="field-label">Session configuration (YAML) <span class="help-tip" tabindex="0" role="note" aria-label="YAML configuration help" data-tooltip="This YAML controls sampler, updater, feedback mode, seeds, candidate count, and generation settings for this one session.">?</span></span>
50+
<span class="field-label">Session configuration (YAML) <span class="help-tip" tabindex="0" role="note" aria-label="YAML configuration help" data-tooltip="This YAML controls sampler, updater, feedback mode, seeds, candidate count, and generation settings for this one session.">?</span> <a href="https://apartsinprojects.github.io/StableSteering/docs/configuration_manual.html" target="_blank" rel="noopener noreferrer">Open configuration manual</a></span>
5151
<textarea name="config_yaml" id="config-yaml-editor" class="yaml-editor" spellcheck="false">{{ config_yaml }}</textarea>
5252
</label>
5353
<p class="hint">All per-session strategy values live in this YAML document. Edit it before creating a session, or reload the default template.</p>

app/samplers/random_local.py

Lines changed: 61 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import math
4+
35
from app.core.schema import Candidate, Session
46
from app.samplers.base import clamp_vector, make_rng
57

@@ -14,22 +16,74 @@ def propose(self, session: Session, seed: int) -> list[Candidate]:
1416

1517
rng = make_rng(seed)
1618
candidates = []
19+
dimensions = max(1, len(session.current_z))
20+
exploit_radius = min(session.config.trust_radius * 0.28, 0.18)
21+
explore_radius = min(max(session.config.trust_radius * 0.9, 0.28), session.config.trust_radius)
1722
for index in range(session.config.candidate_count):
18-
offset = [rng.uniform(-0.35, 0.35) for _ in session.current_z]
19-
z = clamp_vector(
20-
[current + delta for current, delta in zip(session.current_z, offset, strict=False)],
21-
session.config.trust_radius,
22-
)
23+
if index == 0:
24+
role = "exploit"
25+
offset = [rng.uniform(-0.12, 0.12) for _ in session.current_z]
26+
z = clamp_vector(
27+
[current + delta for current, delta in zip(session.current_z, offset, strict=False)],
28+
exploit_radius,
29+
)
30+
else:
31+
role = "explore"
32+
direction = self._explore_direction(index - 1, dimensions)
33+
jitter = [rng.uniform(-0.08, 0.08) for _ in session.current_z]
34+
target_radius = min(explore_radius, max(explore_radius * (0.82 + (0.06 * ((index - 1) % 3))), 0.24))
35+
z = clamp_vector(
36+
[
37+
current + (axis * target_radius) + noise
38+
for current, axis, noise in zip(session.current_z, direction, jitter, strict=False)
39+
],
40+
session.config.trust_radius,
41+
)
42+
length = math.sqrt(sum(value * value for value in z))
43+
minimum_radius = min(max(session.config.trust_radius * 0.58, 0.22), session.config.trust_radius)
44+
if 0.0 < length < minimum_radius:
45+
z = clamp_vector([value * (minimum_radius / length) for value in z], session.config.trust_radius)
46+
2347
candidates.append(
2448
Candidate(
2549
round_id="",
2650
candidate_index=index,
2751
z=z,
28-
sampler_role="explore" if index else "exploit",
52+
sampler_role=role,
2953
predicted_score=sum(z),
3054
predicted_uncertainty=max(0.05, 0.3 - (0.03 * index)),
3155
seed=seed,
32-
generation_params={"image_size": session.config.image_size},
56+
generation_params={
57+
"image_size": session.config.image_size,
58+
"proposal_role_radius": exploit_radius if role == "exploit" else explore_radius,
59+
},
3360
)
3461
)
3562
return candidates
63+
64+
@staticmethod
def _explore_direction(index: int, dimensions: int) -> list[float]:
    """Return a separated, unit-length exploratory direction for one candidate slot.

    The direction is built deterministically from ``index``: a dominant
    component of magnitude 1.0 on axis ``index % dimensions``, plus smaller
    components (0.45, 0.22, 0.16) on the following axes when enough
    dimensions exist. Signs alternate with ``index`` so neighbouring slots
    point in different directions. The result is normalised to length 1.

    Args:
        index: zero-based slot number of the exploratory candidate.
        dimensions: length of the returned vector; must be at least 1.

    Returns:
        A list of ``dimensions`` floats with Euclidean norm 1.

    Raises:
        ValueError: if ``dimensions`` is smaller than 1.
    """
    # Fail with a clear message instead of the ZeroDivisionError / IndexError
    # that the modulo and indexing below would otherwise produce.
    if dimensions < 1:
        raise ValueError(f"dimensions must be at least 1, got {dimensions}")

    vector = [0.0] * dimensions

    # Dominant axis; its sign flips on every other slot.
    vector[index % dimensions] = 1.0 if index % 2 == 0 else -1.0

    # Secondary components are only added when a distinct axis is available,
    # so each axis receives at most one contribution.
    if dimensions > 1:
        vector[(index + 1) % dimensions] += -0.45 if index % 3 == 1 else 0.45
    if dimensions > 2:
        vector[(index + 2) % dimensions] += -0.22 if index % 4 >= 2 else 0.22
    if dimensions > 3:
        vector[(index + 3) % dimensions] += 0.16 if index % 2 == 0 else -0.16

    # The dominant axis always holds +/-1.0 and is never overwritten, so the
    # norm is at least 1.0 and the division below cannot hit zero.
    length = math.sqrt(sum(value * value for value in vector))
    return [value / length for value in vector]

docs/configuration_manual.md

Lines changed: 6 additions & 6 deletions
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)