Skip to content

Commit 0f12508

Browse files
committed
Deepen paper package and sync docs
1 parent 36d8cf1 commit 0f12508

441 files changed

Lines changed: 145309 additions & 2448 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ StableSteering is built around that gap. It turns generation into a feedback loo
4444

4545
That makes the project useful both as:
4646

47-
- a research platform for studying human-in-the-loop steering
47+
- a research platform for iterative preference-guided steering
4848
- a concrete prototype for interactive generative workflows
4949

5050
## Current MVP

app/bootstrap/experiment_models.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from __future__ import annotations
2+
3+
from pathlib import Path
4+
5+
from huggingface_hub.utils import LocalEntryNotFoundError
6+
7+
from app.core.config import settings
8+
9+
_CLIP_CACHE: dict[tuple[str, str], tuple[object, object]] = {}
10+
_DINO_CACHE: dict[tuple[str, str], tuple[object, object]] = {}
11+
12+
13+
def huggingface_cache_dir() -> Path:
14+
path = settings.huggingface_cache_dir
15+
path.mkdir(parents=True, exist_ok=True)
16+
return path
17+
18+
19+
def get_clip_components(model_id: str, device: str, *, local_only: bool = True):
20+
key = (model_id, device)
21+
cached = _CLIP_CACHE.get(key)
22+
if cached is not None:
23+
return cached
24+
25+
from transformers import CLIPModel, CLIPProcessor
26+
27+
cache_dir = huggingface_cache_dir()
28+
try:
29+
model = CLIPModel.from_pretrained(
30+
model_id,
31+
cache_dir=str(cache_dir),
32+
local_files_only=local_only,
33+
).to(device)
34+
processor = CLIPProcessor.from_pretrained(
35+
model_id,
36+
cache_dir=str(cache_dir),
37+
local_files_only=local_only,
38+
)
39+
except (OSError, LocalEntryNotFoundError) as exc:
40+
if local_only:
41+
raise RuntimeError(
42+
f"CLIP model '{model_id}' is not available in the local cache. "
43+
"Run scripts/preload_experiment_models.py first."
44+
) from exc
45+
raise
46+
model.eval()
47+
_CLIP_CACHE[key] = (model, processor)
48+
return model, processor
49+
50+
51+
def get_dino_components(model_id: str, device: str, *, local_only: bool = True):
52+
key = (model_id, device)
53+
cached = _DINO_CACHE.get(key)
54+
if cached is not None:
55+
return cached
56+
57+
from transformers import AutoImageProcessor, AutoModel
58+
59+
cache_dir = huggingface_cache_dir()
60+
try:
61+
processor = AutoImageProcessor.from_pretrained(
62+
model_id,
63+
cache_dir=str(cache_dir),
64+
local_files_only=local_only,
65+
)
66+
model = AutoModel.from_pretrained(
67+
model_id,
68+
cache_dir=str(cache_dir),
69+
local_files_only=local_only,
70+
).to(device)
71+
except (OSError, LocalEntryNotFoundError) as exc:
72+
if local_only:
73+
raise RuntimeError(
74+
f"DINO model '{model_id}' is not available in the local cache. "
75+
"Run scripts/preload_experiment_models.py first."
76+
) from exc
77+
raise
78+
model.eval()
79+
_DINO_CACHE[key] = (processor, model)
80+
return processor, model

app/bootstrap/huggingface.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,26 @@ def prepare_huggingface_model(
7979
model_dir.mkdir(parents=True, exist_ok=True)
8080

8181
allow_patterns = build_allow_patterns(extra_patterns)
82+
manifest_path = model_dir / "prepare_manifest.json"
83+
model_index_path = model_dir / "model_index.json"
84+
if manifest_path.exists() and model_index_path.exists():
85+
try:
86+
manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
87+
except json.JSONDecodeError:
88+
manifest = None
89+
if (
90+
isinstance(manifest, dict)
91+
and str(manifest.get("model_id")) == model_id
92+
and manifest.get("revision") == revision
93+
and list(manifest.get("allow_patterns", [])) == allow_patterns
94+
):
95+
return {
96+
"model_id": model_id,
97+
"model_dir": str(model_dir),
98+
"snapshot_path": str(model_dir),
99+
"manifest_path": str(manifest_path),
100+
}
101+
82102
snapshot_path = snapshot_download(
83103
repo_id=model_id,
84104
revision=revision,

app/core/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class Settings(BaseSettings):
1414
data_dir: Path = Path("data")
1515
artifacts_dir_name: str = "artifacts"
1616
models_dir: Path = Path("models")
17+
huggingface_cache_dir: Path = Path("models") / "hf_cache"
1718
traces_dir_name: str = "traces"
1819
default_candidate_count: int = 4
1920
default_image_size: str = "512x512"

app/core/config_yaml.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
# Edit any of these values before creating a new session.
1515
# This YAML is reloaded fresh for each setup page visit or reset action.
1616
#
17-
# sampler: random_local | exploit_orthogonal | uncertainty_guided | axis_sweep | incumbent_mix | diversity_shell | line_search | plateau_escape | annealed_shell | spherical_cover
18-
# updater: winner_average | winner_copy | linear_preference | score_weighted_preference | contrastive_preference | softmax_preference | borda_preference | bradley_terry_preference
17+
# sampler: random_local | exploit_orthogonal | uncertainty_guided | axis_sweep | incumbent_mix | diversity_shell | line_search | plateau_escape | annealed_shell | spherical_cover | two_scale_cover | quality_diversity_mix
18+
# updater: winner_average | winner_copy | linear_preference | score_weighted_preference | contrastive_preference | softmax_preference | borda_preference | bradley_terry_preference | challenger_mixture_preference | plackett_luce_preference
1919
# feedback_mode: scalar_rating | pairwise | top_k | winner_only | approve_reject
2020
# seed_policy: fixed-per-round | fixed-per-candidate | fixed-per-candidate-role
2121
# steering_mode: currently low_dimensional

app/core/schema.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ class SamplerType(str, Enum):
7070
plateau_escape = "plateau_escape"
7171
annealed_shell = "annealed_shell"
7272
spherical_cover = "spherical_cover"
73+
two_scale_cover = "two_scale_cover"
74+
quality_diversity_mix = "quality_diversity_mix"
7375

7476

7577
class UpdaterType(str, Enum):
@@ -81,6 +83,8 @@ class UpdaterType(str, Enum):
8183
softmax_preference = "softmax_preference"
8284
borda_preference = "borda_preference"
8385
bradley_terry_preference = "bradley_terry_preference"
86+
challenger_mixture_preference = "challenger_mixture_preference"
87+
plackett_luce_preference = "plackett_luce_preference"
8488

8589

8690
class SteeringMode(str, Enum):

app/engine/orchestrator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,18 @@
3333
from app.samplers.incumbent_mix import IncumbentMixSampler
3434
from app.samplers.line_search import LineSearchSampler
3535
from app.samplers.plateau_escape import PlateauEscapeSampler
36+
from app.samplers.quality_diversity_mix import QualityDiversityMixSampler
3637
from app.samplers.random_local import RandomLocalSampler
3738
from app.samplers.spherical_cover import SphericalCoverSampler
39+
from app.samplers.two_scale_cover import TwoScaleCoverSampler
3840
from app.samplers.uncertainty import UncertaintyGuidedSampler
3941
from app.storage.repository import JsonRepository
4042
from app.updaters.contrastive_pref import ContrastivePreferenceUpdater
4143
from app.updaters.borda_pref import BordaPreferenceUpdater
4244
from app.updaters.bradley_terry_pref import BradleyTerryPreferenceUpdater
45+
from app.updaters.challenger_mixture import ChallengerMixturePreferenceUpdater
4346
from app.updaters.linear_pref import LinearPreferenceUpdater
47+
from app.updaters.plackett_luce_pref import PlackettLucePreferenceUpdater
4448
from app.updaters.softmax_pref import SoftmaxPreferenceUpdater
4549
from app.updaters.score_weighted import ScoreWeightedPreferenceUpdater
4650
from app.updaters.winner_average import WinnerAverageUpdater
@@ -71,6 +75,8 @@ def __init__(
7175
"plateau_escape": PlateauEscapeSampler(),
7276
"annealed_shell": AnnealedShellSampler(),
7377
"spherical_cover": SphericalCoverSampler(),
78+
"two_scale_cover": TwoScaleCoverSampler(),
79+
"quality_diversity_mix": QualityDiversityMixSampler(),
7480
}
7581
self.updaters = {
7682
"winner_copy": WinnerCopyUpdater(),
@@ -81,6 +87,8 @@ def __init__(
8187
"softmax_preference": SoftmaxPreferenceUpdater(),
8288
"borda_preference": BordaPreferenceUpdater(),
8389
"bradley_terry_preference": BradleyTerryPreferenceUpdater(),
90+
"challenger_mixture_preference": ChallengerMixturePreferenceUpdater(),
91+
"plackett_luce_preference": PlackettLucePreferenceUpdater(),
8492
}
8593

8694
@staticmethod
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
from __future__ import annotations
2+
3+
import math
4+
5+
from app.core.schema import Candidate, Session
6+
from app.samplers.base import clamp_vector, make_rng
7+
8+
9+
class QualityDiversityMixSampler:
10+
"""Sampler inspired by quality-diversity search with several complementary emitters."""
11+
12+
name = "quality_diversity_mix"
13+
14+
def propose(self, session: Session, seed: int) -> list[Candidate]:
15+
rng = make_rng(seed + 991)
16+
dimensions = max(1, len(session.current_z))
17+
base_direction = self._base_direction(session.current_z, dimensions)
18+
lateral_direction = self._orthogonal_direction(base_direction)
19+
cover_pool = [self._unit_vector([rng.uniform(-1.0, 1.0) for _ in range(dimensions)]) for _ in range(28)]
20+
far_directions = self._greedy_cover(cover_pool, max(2, session.config.candidate_count // 2))
21+
22+
medium = min(max(session.config.trust_radius * 0.42, 0.16), session.config.trust_radius)
23+
far = min(max(session.config.trust_radius * 0.82, 0.28), session.config.trust_radius)
24+
counter = min(max(session.config.trust_radius * 0.3, 0.12), session.config.trust_radius)
25+
26+
patterns: list[tuple[str, list[float], float]] = [
27+
("qd_refine", base_direction, medium * 0.62),
28+
("qd_forward", base_direction, medium),
29+
("qd_lateral_plus", lateral_direction, medium),
30+
("qd_far_cover_1", far_directions[0], far),
31+
("qd_lateral_minus", [-value for value in lateral_direction], medium),
32+
("qd_counter", [-value for value in base_direction], counter),
33+
]
34+
for index, direction in enumerate(far_directions[1:], start=2):
35+
patterns.append((f"qd_far_cover_{index + 1}", direction, far))
36+
37+
candidates: list[Candidate] = []
38+
for index in range(session.config.candidate_count):
39+
role, direction, radius = patterns[index % len(patterns)]
40+
jitter_scale = 0.014 if "refine" in role else 0.024 if "far_cover" not in role else 0.03
41+
jitter = [rng.uniform(-jitter_scale, jitter_scale) for _ in range(dimensions)]
42+
z = clamp_vector(
43+
[
44+
current + (axis * radius) + noise
45+
for current, axis, noise in zip(session.current_z, direction, jitter, strict=False)
46+
],
47+
session.config.trust_radius,
48+
)
49+
candidates.append(
50+
Candidate(
51+
round_id="",
52+
candidate_index=index,
53+
z=z,
54+
sampler_role=role,
55+
predicted_score=sum(z) + (0.01 if "far_cover" in role else 0.0),
56+
predicted_uncertainty=0.16 + (0.02 * index),
57+
seed=seed,
58+
generation_params={
59+
"image_size": session.config.image_size,
60+
"qd_radius": round(radius, 4),
61+
"qd_direction": [round(value, 4) for value in direction],
62+
"qd_emitter_role": role,
63+
},
64+
)
65+
)
66+
return candidates
67+
68+
@staticmethod
69+
def _base_direction(current_z: list[float], dimensions: int) -> list[float]:
70+
length = math.sqrt(sum(value * value for value in current_z))
71+
if length > 1e-8:
72+
return [value / length for value in current_z]
73+
direction = [0.0 for _ in range(dimensions)]
74+
direction[0] = 1.0
75+
if dimensions > 1:
76+
direction[1] = 0.35
77+
norm = math.sqrt(sum(value * value for value in direction))
78+
return [value / norm for value in direction]
79+
80+
@staticmethod
81+
def _orthogonal_direction(base_direction: list[float]) -> list[float]:
82+
dimensions = len(base_direction)
83+
if dimensions == 1:
84+
return [1.0]
85+
lateral = [0.0 for _ in range(dimensions)]
86+
lateral[0] = -base_direction[1]
87+
lateral[1] = base_direction[0]
88+
for index in range(2, dimensions):
89+
lateral[index] = base_direction[index] * (-0.45 if index % 2 == 0 else 0.45)
90+
length = math.sqrt(sum(value * value for value in lateral))
91+
if length == 0.0:
92+
lateral[1] = 1.0
93+
return lateral
94+
return [value / length for value in lateral]
95+
96+
@classmethod
97+
def _greedy_cover(cls, pool: list[list[float]], count: int) -> list[list[float]]:
98+
if not pool:
99+
return []
100+
selected = [pool[0]]
101+
remaining = pool[1:]
102+
while remaining and len(selected) < count:
103+
best_direction = max(
104+
remaining,
105+
key=lambda candidate: min(cls._angular_distance(candidate, prior) for prior in selected),
106+
)
107+
selected.append(best_direction)
108+
remaining = [candidate for candidate in remaining if candidate is not best_direction]
109+
return selected[:count]
110+
111+
@staticmethod
112+
def _angular_distance(left: list[float], right: list[float]) -> float:
113+
cosine = sum(a * b for a, b in zip(left, right, strict=False))
114+
cosine = max(-1.0, min(1.0, cosine))
115+
return math.acos(cosine)
116+
117+
@staticmethod
118+
def _unit_vector(values: list[float]) -> list[float]:
119+
norm = math.sqrt(sum(value * value for value in values))
120+
if norm == 0.0:
121+
fallback = [0.0 for _ in values]
122+
fallback[0] = 1.0
123+
return fallback
124+
return [value / norm for value in values]

0 commit comments

Comments
 (0)