66
77from app .bootstrap .huggingface import model_slug
88from app .core .config import settings
9- from app .core .schema import Candidate , Session
9+ from app .core .schema import Candidate , Session , SteeringMode
1010
1111
1212def _color_from_candidate (candidate : Candidate ) -> tuple [str , str ]:
@@ -30,6 +30,15 @@ def parse_image_size(value: str) -> tuple[int, int]:
3030 raise ValueError (f"Invalid image size: { value !r} . Expected format WIDTHxHEIGHT." ) from exc
3131
3232
33+ def resolve_steering_mode (session : Session ) -> SteeringMode :
34+ """Resolve and validate the session steering mode used at generation time."""
35+
36+ mode = session .config .steering_mode
37+ if mode == SteeringMode .low_dimensional :
38+ return mode
39+ raise ValueError (f"Unsupported steering mode: { mode } " )
40+
41+
3342class GenerationEngine (Protocol ):
3443 """Protocol shared by generation backends used by the orchestrator."""
3544
@@ -55,6 +64,7 @@ def __init__(self, artifacts_dir: Path | None = None) -> None:
5564 def render_candidate (self , session : Session , candidate : Candidate ) -> Candidate :
5665 """Render one candidate to an SVG artifact and attach its public path."""
5766
67+ steering_mode = resolve_steering_mode (session )
5868 primary , secondary = _color_from_candidate (candidate )
5969 width , height = parse_image_size (session .config .image_size )
6070 path = self .artifacts_dir / f"{ candidate .id } .svg"
@@ -77,6 +87,7 @@ def render_candidate(self, session: Session, candidate: Candidate) -> Candidate:
7787<text x="40" y="330" fill="white" font-size="18" font-family="Arial">CFG: { session .config .guidance_scale :.2f} </text>
7888<text x="40" y="365" fill="white" font-size="18" font-family="Arial">Steps: { session .config .num_inference_steps } </text>
7989<text x="40" y="400" fill="white" font-size="18" font-family="Arial">Anchor strength: { session .config .anchor_strength :.2f} </text>
90+ <text x="40" y="435" fill="white" font-size="18" font-family="Arial">Steering mode: { escape (steering_mode .value )} </text>
8091</svg>"""
8192 path .write_text (svg , encoding = "utf-8" )
8293 candidate .image_path = f"/artifacts/{ path .name } "
@@ -88,6 +99,7 @@ def render_candidate(self, session: Session, candidate: Candidate) -> Candidate:
8899 "num_inference_steps" : session .config .num_inference_steps ,
89100 "model_source" : session .config .model_name ,
90101 "anchor_strength" : session .config .anchor_strength ,
102+ "steering_mode" : steering_mode .value ,
91103 }
92104 )
93105 return candidate
@@ -265,6 +277,7 @@ def _steering_offset(self, prompt_embeds, z, anchor_strength: float):
265277 def _encode_steered_embeddings (self , session : Session , candidate : Candidate ):
266278 """Encode prompt text, then apply a deterministic steering offset."""
267279
280+ steering_mode = resolve_steering_mode (session )
268281 pipe = self ._load_pipeline (self ._resolve_model_source (session ))
269282 prompt_embeds , negative_prompt_embeds = pipe .encode_prompt (
270283 prompt = session .prompt ,
@@ -273,7 +286,14 @@ def _encode_steered_embeddings(self, session: Session, candidate: Candidate):
273286 do_classifier_free_guidance = True ,
274287 negative_prompt = session .negative_prompt or "" ,
275288 )
276- steered_prompt_embeds = prompt_embeds + self ._steering_offset (prompt_embeds , candidate .z , session .config .anchor_strength )
289+ if steering_mode == SteeringMode .low_dimensional :
290+ steered_prompt_embeds = prompt_embeds + self ._steering_offset (
291+ prompt_embeds ,
292+ candidate .z ,
293+ session .config .anchor_strength ,
294+ )
295+ else :
296+ raise ValueError (f"Unsupported steering mode: { steering_mode } " )
277297 return steered_prompt_embeds , negative_prompt_embeds
278298
279299 def render_candidate (self , session : Session , candidate : Candidate ) -> Candidate :
@@ -309,7 +329,7 @@ def render_candidate(self, session: Session, candidate: Candidate) -> Candidate:
309329 "num_inference_steps" : num_inference_steps ,
310330 "model_source" : model_source ,
311331 "anchor_strength" : session .config .anchor_strength ,
312- "steering_mode" : session . config . steering_mode ,
332+ "steering_mode" : resolve_steering_mode ( session ). value ,
313333 }
314334 )
315335 return candidate
0 commit comments