Skip to content
Merged
86 changes: 78 additions & 8 deletions rationai/resources/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import cast

import lz4.frame
import numpy as np
from httpx import USE_CLIENT_DEFAULT
Expand Down Expand Up @@ -27,8 +29,7 @@ def classify_image(
(float | dict[str, float]): The classification result as a single float
(for binary classification) or probabilities for each class.
"""
data = image.tobytes()
compressed_data = lz4.frame.compress(data)
compressed_data = lz4.frame.compress(image.tobytes())
response = self._post(model, data=compressed_data, timeout=timeout)
response.raise_for_status()
return response.json()
Expand All @@ -55,15 +56,50 @@ def segment_image(
else:
h, w = image.shape[:2]

data = image.tobytes()
compressed_data = lz4.frame.compress(data)
compressed_data = lz4.frame.compress(image.tobytes())
response = self._post(model, data=compressed_data, timeout=timeout)
response.raise_for_status()

return np.frombuffer(
lz4.frame.decompress(response.content), dtype=np.float16
).reshape(-1, h, w)

def embed_image[DType: np.generic](
self,
model: str,
image: Image | NDArray[np.uint8],
output_dtype: type[DType] = np.float32, # type: ignore[assignment]
timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
) -> NDArray[DType]:
"""Compute an embedding vector for an image using the specified model.

Args:
model: The name of the model to use for embedding.
image: The image to embed. It must be uint8 RGB image.
output_dtype: Output numpy dtype for embeddings (e.g. np.float16, np.float32).
timeout: Optional timeout for the request.

Returns:
NDArray[DType]: The embedding array reshaped according to
the `x-output-shape` response header.
"""
compressed_data = lz4.frame.compress(image.tobytes())
response = self._post(
model,
data=compressed_data,
headers={"x-output-dtype": np.dtype(output_dtype).name},
Comment thread
matejpekar marked this conversation as resolved.
timeout=timeout,
)
response.raise_for_status()

payload = lz4.frame.decompress(response.content)
embedding = np.frombuffer(payload, dtype=output_dtype)

response_shape = response.headers["x-output-shape"]
embedding = embedding.reshape(eval(response_shape))

return cast("NDArray[DType]", embedding)


class AsyncModels(AsyncAPIResource):
async def classify_image(
Expand All @@ -83,8 +119,7 @@ async def classify_image(
(float | dict[str, float]): The classification result as a single float
(for binary classification) or probabilities for each class.
"""
data = image.tobytes()
compressed_data = lz4.frame.compress(data)
compressed_data = lz4.frame.compress(image.tobytes())
response = await self._post(model, data=compressed_data, timeout=timeout)
response.raise_for_status()
return response.json()
Expand All @@ -111,11 +146,46 @@ async def segment_image(
else:
h, w = image.shape[:2]

data = image.tobytes()
compressed_data = lz4.frame.compress(data)
compressed_data = lz4.frame.compress(image.tobytes())
response = await self._post(model, data=compressed_data, timeout=timeout)
response.raise_for_status()

return np.frombuffer(
lz4.frame.decompress(response.content), dtype=np.float16
).reshape(-1, h, w)

async def embed_image[DType: np.generic](
self,
model: str,
image: Image | NDArray[np.uint8],
output_dtype: type[DType] = np.float32, # type: ignore[assignment]
timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
) -> NDArray[DType]:
"""Compute an embedding vector for an image using the specified model.

Args:
model: The name of the model to use for embedding.
image: The image to embed. It must be uint8 RGB image.
output_dtype: Output numpy dtype for embeddings (e.g. np.float16, np.float32).
timeout: Optional timeout for the request.

Returns:
NDArray[DType]: The embedding array reshaped according to
the `x-output-shape` response header.
"""
compressed_data = lz4.frame.compress(image.tobytes())
response = await self._post(
model,
data=compressed_data,
headers={"x-output-dtype": np.dtype(output_dtype).name},
timeout=timeout,
)
response.raise_for_status()

payload = lz4.frame.decompress(response.content)
embedding = np.frombuffer(payload, dtype=output_dtype)

response_shape = response.headers["x-output-shape"]
embedding = embedding.reshape(eval(response_shape))

return cast("NDArray[DType]", embedding)
Loading