Skip to content

Commit 5324787

Browse files
committed
Enhance visualization script for success or not
1 parent 5829349 commit 5324787

2 files changed

Lines changed: 40 additions & 7 deletions

File tree

-5.89 KB
Loading

experiments/wikipedia/visualization.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,20 @@ def load_logs(log_dir, scenario_filter=None):
5656

5757
trajectory = []
5858
kinematics_series = []
59+
is_success = False
5960

6061
with open(f, 'r') as file:
6162
for line in file:
6263
try:
6364
entry = json.loads(line)
6465

66+
# Check for Success/Failure signals
67+
msg = entry.get('msg', '')
68+
if "Stable state found" in msg or "Navigation Complete" in msg:
69+
is_success = True
70+
elif "Inner Loop exhausted" in msg or "Immediate stop for experiment comparison" in msg:
71+
is_success = False
72+
6573
# 1. State Trajectory
6674
state = entry.get('state')
6775
if state and isinstance(state, dict):
@@ -88,7 +96,8 @@ def load_logs(log_dir, scenario_filter=None):
8896
'timestamp': timestamp,
8997
'path': trajectory,
9098
'kinematics': kinematics_series,
91-
'label': f"{policy}" # Simplified label
99+
'label': f"{policy}", # Simplified label
100+
'success': is_success
92101
})
93102

94103
return data
@@ -241,6 +250,7 @@ def run_visualization(scenario_id, title_suffix, output_filename, custom_landmar
241250
'cot': {'color': 'tab:green', 'ls': '-.', 'marker': '*', 'label': 'Baseline B (LLM CoT)'}
242251
}
243252

253+
seen_policies = set()
244254
for run in runs:
245255
policy = run['policy']
246256
path = run['path']
@@ -255,18 +265,40 @@ def run_visualization(scenario_id, title_suffix, output_filename, custom_landmar
255265
jitter = np.random.normal(0, 0.04, path_coords.shape)
256266
path_coords_jittered = path_coords + jitter
257267

268+
# Handle Legend Deduplication
269+
lbl = style['label']
270+
if lbl in seen_policies:
271+
lbl = None
272+
else:
273+
seen_policies.add(lbl)
274+
258275
# Plot Line
259276
plt.plot(path_coords_jittered[:, 0], path_coords_jittered[:, 1],
260-
color=style['color'], label=style['label'], linestyle=style['ls'], linewidth=2, alpha=0.6)
277+
color=style['color'], label=lbl, linestyle=style['ls'], linewidth=2, alpha=0.6)
261278

262279
# Plot Markers
263280
plt.scatter(path_coords_jittered[:, 0], path_coords_jittered[:, 1],
264281
color=style['color'], s=50, marker=style['marker'], alpha=0.7)
265282

266-
# Annotate Steps (First and Last few, or strided to avoid clutter)
267-
# Showing start and end is usually most important
268-
plt.text(path_coords_jittered[0, 0], path_coords_jittered[0, 1], "Start", fontsize=9, fontweight='bold', color='black')
269-
plt.text(path_coords_jittered[-1, 0], path_coords_jittered[-1, 1], f"End ({len(path)})", fontsize=9, fontweight='bold', color=style['color'])
283+
# Visual Indication of Success/Failure
284+
is_success = run.get('success', False)
285+
end_marker = '*' if is_success else 'X'
286+
287+
# Plot distinct end marker
288+
plt.scatter(path_coords_jittered[-1, 0], path_coords_jittered[-1, 1],
289+
color=style['color'], s=200, marker=end_marker, edgecolors='black', zorder=10)
290+
291+
# Annotate Start (One-time, unjittered)
292+
if runs and runs[0]['path']:
293+
start_node = runs[0]['path'][0]
294+
if start_node in node_map:
295+
start_coord = node_map[start_node]
296+
plt.text(start_coord[0], start_coord[1], "Start", fontsize=10, fontweight='bold', color='black',
297+
bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2), zorder=20)
298+
299+
# Add Custom Legend for Success/Fail
300+
plt.scatter([], [], color='white', marker='*', s=200, edgecolors='black', label='Success')
301+
plt.scatter([], [], color='white', marker='X', s=200, edgecolors='black', label='Fail')
270302

271303
# Highlight Landmarks
272304
key_landmarks = custom_landmarks if custom_landmarks else []
@@ -281,7 +313,8 @@ def run_visualization(scenario_id, title_suffix, output_filename, custom_landmar
281313
plt.title(f"Trajectory Analysis: {title_suffix}", fontsize=16)
282314
plt.xlabel("Semantic Dimension 1 (PCA)")
283315
plt.ylabel("Semantic Dimension 2 (PCA)")
284-
plt.legend(loc='best')
316+
# Legend top-left inside graph
317+
plt.legend(loc='upper left', framealpha=0.9)
285318
plt.grid(True, alpha=0.2)
286319

287320
output_path = os.path.join(OUTPUT_DIR, output_filename)

0 commit comments

Comments
 (0)