Skip to content

Commit 3c333cc

Browse files
Merge branch 'automatic-prompt-generator' of https://github.com/computational-cell-analytics/micro-sam into automatic-prompt-generator
2 parents 82d6501 + b2465da commit 3c333cc

1 file changed

Lines changed: 32 additions & 17 deletions

File tree

micro_sam/util.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -605,25 +605,40 @@ def get_model_names() -> Iterable:
605605
#
606606

607607

608-
def _to_image(input_):
609-
# we require the input to be uint8
610-
if input_.dtype != np.dtype("uint8"):
611-
# first normalize the input to [0, 1]
612-
input_ = input_.astype("float32") - input_.min()
613-
input_ = input_ / input_.max()
614-
# then bring to [0, 255] and cast to uint8
615-
input_ = (input_ * 255).astype("uint8")
616-
617-
if input_.ndim == 2:
618-
image = np.concatenate([input_[..., None]] * 3, axis=-1)
619-
elif input_.ndim == 3 and input_.shape[-1] == 3:
620-
image = input_
608+
def _to_image(image):
609+
input_ = image
610+
ndim = input_.ndim
611+
n_channels = 1 if ndim == 2 else input_.shape[-1]
612+
613+
# Map the input to three channels.
614+
if ndim == 2: # Grayscale image -> replicate channels.
615+
input_ = np.concatenate([input_[..., None]] * 3, axis=-1)
616+
elif ndim == 3 and n_channels == 1: # Grayscale image -> replicate channels.
617+
input_ = np.concatenate([input_] * 3, axis=-1)
618+
elif ndim == 3 and n_channels == 2: # Two channels -> add a zero channel.
619+
zero_channel = np.zeros(input_.shape[:2] + (1,), dtype=input_.dtype)
620+
input_ = np.concatenate([input_, zero_channel], axis=-1)
621+
elif input_.ndim == 3 and n_channels == 3: # RGB input -> do nothing.
622+
pass
623+
elif input_.ndim == 3 and n_channels > 3: # More than three channels -> select first three.
624+
warnings.warn(f"You provided an input with {n_channels} channels. Only the first three will be used.")
625+
input_ = input_[..., :3]
621626
else:
622-
raise ValueError(f"Invalid input image of shape {input_.shape}. Expect either 2D grayscale or 3D RGB image.")
627+
raise ValueError(
628+
f"Invalid input dimensionality {ndim}. Expect either a 2D input (=grayscale image) "
629+
"or a 3D input (= image with channels)."
630+
)
631+
assert input_.ndim == 3 and input_.shape[-1] == 3
632+
633+
# Normalize the input per channel and bring it to uint8.
634+
input_ = input_.astype("float32")
635+
input_ -= input_.min(axis=(0, 1))[None, None]
636+
input_ /= (input_.max(axis=(0, 1))[None, None] + 1e-7)
637+
input_ = (input_ * 255).astype("uint8")
623638

624-
# explicitly return a numpy array for compatibility with torchvision
625-
# because the input_ array could be something like dask array
626-
return np.array(image)
639+
# Explicitly return a numpy array for compatibility with torchvision
640+
# because the input_ array could be something like dask array.
641+
return np.array(input_)
627642

628643

629644
@torch.no_grad

0 commit comments

Comments
 (0)