@@ -605,25 +605,40 @@ def get_model_names() -> Iterable:
605605#
606606
607607
608- def _to_image (input_ ):
609- # we require the input to be uint8
610- if input_ .dtype != np .dtype ("uint8" ):
611- # first normalize the input to [0, 1]
612- input_ = input_ .astype ("float32" ) - input_ .min ()
613- input_ = input_ / input_ .max ()
614- # then bring to [0, 255] and cast to uint8
615- input_ = (input_ * 255 ).astype ("uint8" )
616-
617- if input_ .ndim == 2 :
618- image = np .concatenate ([input_ [..., None ]] * 3 , axis = - 1 )
619- elif input_ .ndim == 3 and input_ .shape [- 1 ] == 3 :
620- image = input_
608+ def _to_image (image ):
609+ input_ = image
610+ ndim = input_ .ndim
611+ n_channels = 1 if ndim == 2 else input_ .shape [- 1 ]
612+
613+ # Map the input to three channels.
614+ if ndim == 2 : # Grayscale image -> replicate channels.
615+ input_ = np .concatenate ([input_ [..., None ]] * 3 , axis = - 1 )
616+ elif ndim == 3 and n_channels == 1 : # Grayscale image -> replicate channels.
617+ input_ = np .concatenate ([input_ ] * 3 , axis = - 1 )
618+ elif ndim == 3 and n_channels == 2 : # Two channels -> add a zero channel.
619+ zero_channel = np .zeros (input_ .shape [:2 ] + (1 ,), dtype = input_ .dtype )
620+ input_ = np .concatenate ([input_ , zero_channel ], axis = - 1 )
621+ elif input_ .ndim == 3 and n_channels == 3 : # RGB input -> do nothing.
622+ pass
623+ elif input_ .ndim == 3 and n_channels > 3 : # More than three channels -> select first three.
624+ warnings .warn (f"You provided an input with { n_channels } channels. Only the first three will be used." )
625+ input_ = input_ [..., :3 ]
621626 else :
622- raise ValueError (f"Invalid input image of shape { input_ .shape } . Expect either 2D grayscale or 3D RGB image." )
627+ raise ValueError (
628+ f"Invalid input dimensionality { ndim } . Expect either a 2D input (=grayscale image) "
629+ "or a 3D input (= image with channels)."
630+ )
631+ assert input_ .ndim == 3 and input_ .shape [- 1 ] == 3
632+
633+ # Normalize the input per channel and bring it to uint8.
634+ input_ = input_ .astype ("float32" )
635+ input_ -= input_ .min (axis = (0 , 1 ))[None , None ]
636+ input_ /= (input_ .max (axis = (0 , 1 ))[None , None ] + 1e-7 )
637+ input_ = (input_ * 255 ).astype ("uint8" )
623638
624- # explicitly return a numpy array for compatibility with torchvision
625- # because the input_ array could be something like dask array
626- return np .array (image )
639+ # Explicitly return a numpy array for compatibility with torchvision
640+ # because the input_ array could be something like dask array.
641+ return np .array (input_ )
627642
628643
629644@torch .no_grad
0 commit comments