Skip to content

Commit 7eae64e

Browse files
authored
Add show_points option to show_anns for overlaying point prompts on annotations (#511)
- Add show_points parameter to show_anns() that renders point prompts as star markers on top of the annotation image - Store point_coords and point_labels in predict_inst() so show_anns() can automatically use them - Points are drawn as filled 5-pointed stars with white edge outlines, matching the style of the show_points() method
1 parent 351fdff commit 7eae64e

1 file changed

Lines changed: 75 additions & 0 deletions

File tree

samgeo/samgeo3.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2678,6 +2678,12 @@ def show_anns(
26782678
blend: bool = True,
26792679
alpha: float = 0.5,
26802680
font_scale: float = 0.8,
2681+
show_points: bool = False,
2682+
point_coords: Optional[List[List[float]]] = None,
2683+
point_labels: Optional[List[int]] = None,
2684+
foreground_color: Tuple[int, int, int] = (0, 128, 0),
2685+
background_color: Tuple[int, int, int] = (255, 0, 0),
2686+
point_size: int = 15,
26812687
**kwargs: Any,
26822688
) -> None:
26832689
"""Show the annotations (objects with random color) on the input image.
@@ -2696,6 +2702,20 @@ def show_anns(
26962702
only annotations will be shown on a white background.
26972703
alpha (float): The alpha value for the annotations.
26982704
font_scale (float): The font scale for labels. Defaults to 0.8.
2705+
show_points (bool): Whether to show point prompts on the image.
2706+
If True, uses stored point_coords/point_labels from predict_inst()
2707+
or the explicitly provided ones. Defaults to False.
2708+
point_coords (List[List[float]], optional): Point coordinates to
2709+
display. If None and show_points is True, uses the stored
2710+
point_coords from the last predict_inst() or show_canvas() call.
2711+
point_labels (List[int], optional): Labels for the points (1 for
2712+
foreground, 0 for background). If None and show_points is True,
2713+
uses the stored point_labels.
2714+
foreground_color (Tuple[int, int, int]): RGB color for foreground
2715+
points (label=1). Defaults to dark green (0, 128, 0).
2716+
background_color (Tuple[int, int, int]): RGB color for background
2717+
points (label=0). Defaults to red (255, 0, 0).
2718+
point_size (int): Size of star markers in pixels. Defaults to 15.
26992719
**kwargs: Additional keyword arguments (kept for backward compatibility).
27002720
"""
27012721

@@ -2715,6 +2735,45 @@ def show_anns(
27152735
font_scale=font_scale,
27162736
)
27172737

2738+
# Overlay point prompts if requested
2739+
if show_points:
2740+
coords = point_coords
2741+
labels = point_labels
2742+
if coords is None and hasattr(self, "point_coords"):
2743+
coords = self.point_coords
2744+
if labels is None and hasattr(self, "point_labels"):
2745+
labels = self.point_labels
2746+
2747+
if coords is not None:
2748+
coords_arr = np.array(coords)
2749+
if labels is None:
2750+
labels = [1] * len(coords_arr)
2751+
labels_arr = np.array(labels)
2752+
2753+
for pt, lbl in zip(coords_arr, labels_arr):
2754+
x, y = int(pt[0]), int(pt[1])
2755+
color = foreground_color if lbl == 1 else background_color
2756+
2757+
# Build a 5-pointed star polygon
2758+
def _star_pts(cx, cy, r_outer, r_inner, n=5):
2759+
pts = []
2760+
for i in range(2 * n):
2761+
r = r_outer if i % 2 == 0 else r_inner
2762+
angle = np.pi / 2 + i * np.pi / n
2763+
pts.append(
2764+
[
2765+
int(cx + r * np.cos(angle)),
2766+
int(cy - r * np.sin(angle)),
2767+
]
2768+
)
2769+
return np.array(pts, dtype=np.int32)
2770+
2771+
star_pts = _star_pts(x, y, point_size, point_size * 0.4)
2772+
# Filled color star
2773+
cv2.fillPoly(blended, [star_pts], color)
2774+
# White edge outline
2775+
cv2.polylines(blended, [star_pts], True, (255, 255, 255), 2)
2776+
27182777
if output is not None:
27192778
# Save directly using OpenCV
27202779
cv2.imwrite(output, cv2.cvtColor(blended, cv2.COLOR_RGB2BGR))
@@ -3227,6 +3286,22 @@ def predict_inst(
32273286
self.scores = list(scores) if isinstance(scores, np.ndarray) else [scores]
32283287
self.logits = logits
32293288

3289+
# Store point prompts for visualization
3290+
if point_coords is not None:
3291+
self.point_coords = (
3292+
point_coords.tolist()
3293+
if isinstance(point_coords, np.ndarray)
3294+
else point_coords
3295+
)
3296+
self.point_labels = (
3297+
point_labels.tolist()
3298+
if isinstance(point_labels, np.ndarray)
3299+
else point_labels
3300+
)
3301+
else:
3302+
self.point_coords = None
3303+
self.point_labels = None
3304+
32303305
return masks, scores, logits
32313306

32323307
def predict_inst_batch(

0 commit comments

Comments
 (0)