diff --git a/rationai/resources/models.py b/rationai/resources/models.py
index 47be222..3a297a7 100644
--- a/rationai/resources/models.py
+++ b/rationai/resources/models.py
@@ -1,3 +1,5 @@
+from typing import cast
+
 import lz4.frame
 import numpy as np
 from httpx import USE_CLIENT_DEFAULT
@@ -27,8 +29,7 @@ def classify_image(
             (float | dict[str, float]): The classification result as a single float (for binary classification)
                 or probabilities for each class.
         """
-        data = image.tobytes()
-        compressed_data = lz4.frame.compress(data)
+        compressed_data = lz4.frame.compress(image.tobytes())
         response = self._post(model, data=compressed_data, timeout=timeout)
         response.raise_for_status()
         return response.json()
@@ -55,8 +56,7 @@ def segment_image(
         else:
             h, w = image.shape[:2]
 
-        data = image.tobytes()
-        compressed_data = lz4.frame.compress(data)
+        compressed_data = lz4.frame.compress(image.tobytes())
         response = self._post(model, data=compressed_data, timeout=timeout)
         response.raise_for_status()
 
@@ -64,6 +64,50 @@ def segment_image(
             lz4.frame.decompress(response.content), dtype=np.float16
         ).reshape(-1, h, w)
 
+    def embed_image[DType: np.generic](
+        self,
+        model: str,
+        image: Image | NDArray[np.uint8],
+        output_dtype: type[DType] = np.float32,  # type: ignore[assignment]
+        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+    ) -> NDArray[DType]:
+        """Compute an embedding vector for an image using the specified model.
+
+        Args:
+            model: The name of the model to use for embedding.
+            image: The image to embed. It must be uint8 RGB image.
+            output_dtype: Output numpy dtype for embeddings (e.g. np.float16, np.float32).
+            timeout: Optional timeout for the request.
+
+        Returns:
+            NDArray[DType]: The embedding array reshaped according to
+                the `x-output-shape` response header.
+
+        Raises:
+            ValueError: If the `x-output-shape` header is not a list of integers.
+        """
+        compressed_data = lz4.frame.compress(image.tobytes())
+        response = self._post(
+            model,
+            data=compressed_data,
+            headers={"x-output-dtype": np.dtype(output_dtype).name},
+            timeout=timeout,
+        )
+        response.raise_for_status()
+
+        payload = lz4.frame.decompress(response.content)
+        embedding = np.frombuffer(payload, dtype=output_dtype)
+
+        # The shape header is server-controlled text: parse it as a plain
+        # comma-separated list of ints rather than eval()-ing remote input.
+        response_shape = response.headers["x-output-shape"]
+        shape = tuple(
+            int(dim) for dim in response_shape.strip("()[] ").split(",") if dim.strip()
+        )
+        embedding = embedding.reshape(shape)
+
+        return cast("NDArray[DType]", embedding)
+
 
 class AsyncModels(AsyncAPIResource):
     async def classify_image(
@@ -83,8 +127,7 @@ async def classify_image(
             (float | dict[str, float]): The classification result as a single float (for binary classification)
                 or probabilities for each class.
         """
-        data = image.tobytes()
-        compressed_data = lz4.frame.compress(data)
+        compressed_data = lz4.frame.compress(image.tobytes())
         response = await self._post(model, data=compressed_data, timeout=timeout)
         response.raise_for_status()
         return response.json()
@@ -111,11 +154,54 @@ async def segment_image(
         else:
             h, w = image.shape[:2]
 
-        data = image.tobytes()
-        compressed_data = lz4.frame.compress(data)
+        compressed_data = lz4.frame.compress(image.tobytes())
         response = await self._post(model, data=compressed_data, timeout=timeout)
         response.raise_for_status()
 
         return np.frombuffer(
             lz4.frame.decompress(response.content), dtype=np.float16
         ).reshape(-1, h, w)
+
+    async def embed_image[DType: np.generic](
+        self,
+        model: str,
+        image: Image | NDArray[np.uint8],
+        output_dtype: type[DType] = np.float32,  # type: ignore[assignment]
+        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+    ) -> NDArray[DType]:
+        """Compute an embedding vector for an image using the specified model.
+
+        Args:
+            model: The name of the model to use for embedding.
+            image: The image to embed. It must be uint8 RGB image.
+            output_dtype: Output numpy dtype for embeddings (e.g. np.float16, np.float32).
+            timeout: Optional timeout for the request.
+
+        Returns:
+            NDArray[DType]: The embedding array reshaped according to
+                the `x-output-shape` response header.
+
+        Raises:
+            ValueError: If the `x-output-shape` header is not a list of integers.
+        """
+        compressed_data = lz4.frame.compress(image.tobytes())
+        response = await self._post(
+            model,
+            data=compressed_data,
+            headers={"x-output-dtype": np.dtype(output_dtype).name},
+            timeout=timeout,
+        )
+        response.raise_for_status()
+
+        payload = lz4.frame.decompress(response.content)
+        embedding = np.frombuffer(payload, dtype=output_dtype)
+
+        # The shape header is server-controlled text: parse it as a plain
+        # comma-separated list of ints rather than eval()-ing remote input.
+        response_shape = response.headers["x-output-shape"]
+        shape = tuple(
+            int(dim) for dim in response_shape.strip("()[] ").split(",") if dim.strip()
+        )
+        embedding = embedding.reshape(shape)
+
+        return cast("NDArray[DType]", embedding)