Skip to content

Commit 92a2d70

Browse files
Updates for EGW to use ModelUsage and bug fixes (#2140)
Co-authored-by: Hazel <hazel.he@datastax.com>
1 parent 402b9e8 commit 92a2d70

14 files changed

Lines changed: 177 additions & 85 deletions

src/main/java/io/stargate/sgv2/jsonapi/service/embedding/configuration/EmbeddingProviderResponseValidation.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import jakarta.ws.rs.client.ClientResponseContext;
88
import jakarta.ws.rs.client.ClientResponseFilter;
99
import jakarta.ws.rs.core.MediaType;
10+
import jakarta.ws.rs.core.Response;
1011
import java.io.IOException;
1112
import java.nio.charset.StandardCharsets;
1213
import org.slf4j.Logger;
@@ -37,11 +38,19 @@ public class EmbeddingProviderResponseValidation implements ClientResponseFilter
3738
@Override
3839
public void filter(ClientRequestContext requestContext, ClientResponseContext responseContext)
3940
throws JsonApiException {
41+
4042
// If the status is 0, it means something went wrong (maybe a timeout). Directly return and pass
4143
// the error to the client
4244
if (responseContext.getStatus() == 0) {
4345
return;
4446
}
47+
48+
// only validate for successful responses, errors may return non-JSON content,
49+
// e.g. a HTTP 401 may just have "Unauthorized" in the response body
50+
if (responseContext.getStatusInfo().getFamily() != Response.Status.Family.SUCCESSFUL) {
51+
return;
52+
}
53+
4554
// Throw error if there is no response body
4655
if (!responseContext.hasEntity()) {
4756
throw EMBEDDING_PROVIDER_UNEXPECTED_RESPONSE.toApiException(

src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/AwsBedrockEmbeddingProvider.java

Lines changed: 61 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -93,73 +93,72 @@ public Uni<BatchedEmbeddingResponse> vectorize(
9393
AwsBasicCredentials.create(
9494
embeddingCredentials.accessId().get(), embeddingCredentials.secretId().get());
9595

96-
try (var bedrockClient =
96+
// NOTE: cannot put this client in a resource block for auto close because it will close
97+
// te connection pool before we pull the async result.
98+
var bedrockClient =
9799
BedrockRuntimeAsyncClient.builder()
98100
.credentialsProvider(StaticCredentialsProvider.create(awsCreds))
99101
.region(Region.of(vectorizeServiceParameters.get("region").toString()))
100-
.build()) {
101-
102-
long callStartNano = System.nanoTime();
103-
104-
// NOTE: need to use the AWS client for the request, not a Rest Easy, so we cannot use
105-
// all the features from the superclasses such as error mapping and building the model usage
106-
var bytesUsageTracker = new ByteUsageTracker();
107-
var bedrockFuture =
108-
bedrockClient
109-
.invokeModel(
110-
requestBuilder -> {
111-
try {
112-
var inputData =
113-
OBJECT_WRITER.writeValueAsBytes(
114-
new AwsBedrockEmbeddingRequest(texts.getFirst(), dimension));
115-
bytesUsageTracker.requestBytes = inputData.length;
116-
requestBuilder.body(SdkBytes.fromByteArray(inputData)).modelId(modelName());
117-
} catch (JsonProcessingException e) {
118-
throw ErrorCodeV1.EMBEDDING_REQUEST_ENCODING_ERROR.toApiException();
102+
.build();
103+
104+
long callStartNano = System.nanoTime();
105+
106+
// NOTE: need to use the AWS client for the request, not a Rest Easy, so we cannot use
107+
// all the features from the superclasses such as error mapping and building the model usage
108+
var bytesUsageTracker = new ByteUsageTracker();
109+
var bedrockFuture =
110+
bedrockClient
111+
.invokeModel(
112+
requestBuilder -> {
113+
try {
114+
var inputData =
115+
OBJECT_WRITER.writeValueAsBytes(
116+
new AwsBedrockEmbeddingRequest(texts.getFirst(), dimension));
117+
bytesUsageTracker.requestBytes = inputData.length;
118+
requestBuilder.body(SdkBytes.fromByteArray(inputData)).modelId(modelName());
119+
} catch (JsonProcessingException e) {
120+
throw ErrorCodeV1.EMBEDDING_REQUEST_ENCODING_ERROR.toApiException();
121+
}
122+
})
123+
.thenApply(
124+
rawResponse -> {
125+
try {
126+
// aws docs say do not need to close the stream
127+
var inputStream = rawResponse.body().asInputStream();
128+
var bedrockResponse =
129+
OBJECT_READER.readValue(inputStream, AwsBedrockEmbeddingResponse.class);
130+
long callDurationNano = System.nanoTime() - callStartNano;
131+
132+
try (var countingOut =
133+
new CountingOutputStream(OutputStream.nullOutputStream())) {
134+
inputStream.transferTo(countingOut);
135+
long responseSize = countingOut.getCount();
136+
bytesUsageTracker.responseBytes =
137+
responseSize > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) responseSize;
119138
}
120-
})
121-
.thenApply(
122-
rawResponse -> {
123-
try {
124-
// aws docs say do not need to close the stream
125-
var inputStream = rawResponse.body().asInputStream();
126-
var bedrockResponse =
127-
OBJECT_READER.readValue(inputStream, AwsBedrockEmbeddingResponse.class);
128-
long callDurationNano = System.nanoTime() - callStartNano;
129-
130-
try (var countingOut =
131-
new CountingOutputStream(OutputStream.nullOutputStream())) {
132-
inputStream.transferTo(countingOut);
133-
long responseSize = countingOut.getCount();
134-
bytesUsageTracker.responseBytes =
135-
responseSize > Integer.MAX_VALUE
136-
? Integer.MAX_VALUE
137-
: (int) responseSize;
138-
}
139-
140-
var modelUsage =
141-
createModelUsage(
142-
embeddingCredentials.tenantId(),
143-
ModelInputType.fromEmbeddingRequestType(embeddingRequestType),
144-
bedrockResponse.inputTextTokenCount(),
145-
bedrockResponse.inputTextTokenCount(),
146-
bytesUsageTracker.requestBytes,
147-
bytesUsageTracker.responseBytes,
148-
callDurationNano);
149-
150-
return new BatchedEmbeddingResponse(
151-
batchId, List.of(bedrockResponse.embedding), modelUsage);
152-
153-
} catch (IOException e) {
154-
throw ErrorCodeV1.EMBEDDING_RESPONSE_DECODING_ERROR.toApiException();
155-
}
156-
});
157139

158-
return Uni.createFrom()
159-
.completionStage(bedrockFuture)
160-
.onFailure(BedrockRuntimeException.class)
161-
.transform(throwable -> mapBedrockException((BedrockRuntimeException) throwable));
162-
}
140+
var modelUsage =
141+
createModelUsage(
142+
embeddingCredentials.tenantId(),
143+
ModelInputType.fromEmbeddingRequestType(embeddingRequestType),
144+
bedrockResponse.inputTextTokenCount(),
145+
bedrockResponse.inputTextTokenCount(),
146+
bytesUsageTracker.requestBytes,
147+
bytesUsageTracker.responseBytes,
148+
callDurationNano);
149+
150+
return new BatchedEmbeddingResponse(
151+
batchId, List.of(bedrockResponse.embedding), modelUsage);
152+
153+
} catch (IOException e) {
154+
throw ErrorCodeV1.EMBEDDING_RESPONSE_DECODING_ERROR.toApiException();
155+
}
156+
});
157+
158+
return Uni.createFrom()
159+
.completionStage(bedrockFuture)
160+
.onFailure(BedrockRuntimeException.class)
161+
.transform(throwable -> mapBedrockException((BedrockRuntimeException) throwable));
163162
}
164163

165164
private Throwable mapBedrockException(BedrockRuntimeException bedrockException) {

src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/HuggingFaceEmbeddingProvider.java

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
import io.stargate.sgv2.jsonapi.service.provider.ProviderHttpInterceptor;
1313
import jakarta.ws.rs.HeaderParam;
1414
import jakarta.ws.rs.POST;
15-
import jakarta.ws.rs.Path;
16-
import jakarta.ws.rs.PathParam;
1715
import jakarta.ws.rs.core.*;
1816
import java.net.URI;
1917
import java.util.List;
@@ -41,9 +39,13 @@ public HuggingFaceEmbeddingProvider(
4139
dimension,
4240
vectorizeServiceParameters);
4341

42+
var baseUrl = serviceConfig.getBaseUrl(modelName());
43+
// replace was added in https://github.com/stargate/data-api/pull/2108/files
44+
var actualUrl = replaceParameters(baseUrl, Map.of("modelId", modelName()));
45+
4446
huggingFaceClient =
4547
QuarkusRestClientBuilder.newBuilder()
46-
.baseUri(URI.create(serviceConfig.getBaseUrl(modelName())))
48+
.baseUri(URI.create(actualUrl))
4749
.readTimeout(requestProperties().readTimeoutMillis(), TimeUnit.MILLISECONDS)
4850
.build(HuggingFaceEmbeddingProviderClient.class);
4951
}
@@ -79,7 +81,7 @@ public Uni<BatchedEmbeddingResponse> vectorize(
7981
var accessToken = HttpConstants.BEARER_PREFIX_FOR_API_KEY + embeddingCredentials.apiKey().get();
8082

8183
long callStartNano = System.nanoTime();
82-
return retryHTTPCall(huggingFaceClient.embed(accessToken, modelName(), huggingFaceRequest))
84+
return retryHTTPCall(huggingFaceClient.embed(accessToken, huggingFaceRequest))
8385
.onItem()
8486
.transform(
8587
jakartaResponse -> {
@@ -136,12 +138,9 @@ public Uni<BatchedEmbeddingResponse> vectorize(
136138
@RegisterProvider(ProviderHttpInterceptor.class)
137139
public interface HuggingFaceEmbeddingProviderClient {
138140
@POST
139-
@Path("/{modelId}")
140141
@ClientHeaderParam(name = HttpHeaders.CONTENT_TYPE, value = MediaType.APPLICATION_JSON)
141142
Uni<Response> embed(
142-
@HeaderParam("Authorization") String accessToken,
143-
@PathParam("modelId") String modelId,
144-
HuggingFaceEmbeddingRequest request);
143+
@HeaderParam("Authorization") String accessToken, HuggingFaceEmbeddingRequest request);
145144
}
146145

147146
/**

src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/MistralEmbeddingProvider.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,18 @@ protected String errorMessageJsonPtr() {
7676
return "/message";
7777
}
7878

79+
/**
80+
* Mistral for 401 Unauthorized returns a response with no content type and just the text
81+
* "Unauthorized".
82+
*/
83+
@Override
84+
protected String responseErrorMessage(Response jakartaResponse) {
85+
if (jakartaResponse.getStatus() == Response.Status.UNAUTHORIZED.getStatusCode()) {
86+
return Response.Status.UNAUTHORIZED.getReasonPhrase();
87+
}
88+
return super.responseErrorMessage(jakartaResponse);
89+
}
90+
7991
@Override
8092
public Uni<BatchedEmbeddingResponse> vectorize(
8193
int batchId,

src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingProvider.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public VoyageAIEmbeddingProvider(
4040
int dimension,
4141
Map<String, Object> vectorizeServiceParameters) {
4242
super(
43-
ModelProvider.VERTEXAI,
43+
ModelProvider.VOYAGE_AI,
4444
providerConfig,
4545
modelConfig,
4646
serviceConfig,

src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelInputType.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,12 @@ public static Optional<ModelInputType> fromEmbeddingGateway(
3636
default -> Optional.empty();
3737
};
3838
}
39+
40+
public EmbeddingGateway.ModelUsage.InputType toEmbeddingGateway() {
41+
return switch (this) {
42+
case INPUT_TYPE_UNSPECIFIED -> EmbeddingGateway.ModelUsage.InputType.INPUT_TYPE_UNSPECIFIED;
43+
case INDEX -> EmbeddingGateway.ModelUsage.InputType.INDEX;
44+
case SEARCH -> EmbeddingGateway.ModelUsage.InputType.SEARCH;
45+
};
46+
}
3947
}

src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelType.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,12 @@ public static Optional<ModelType> fromEmbeddingGateway(
2323
default -> Optional.empty();
2424
};
2525
}
26+
27+
public EmbeddingGateway.ModelUsage.ModelType toEmbeddingGateway() {
28+
return switch (this) {
29+
case MODEL_TYPE_UNSPECIFIED -> EmbeddingGateway.ModelUsage.ModelType.MODEL_TYPE_UNSPECIFIED;
30+
case EMBEDDING -> EmbeddingGateway.ModelUsage.ModelType.EMBEDDING;
31+
case RERANKING -> EmbeddingGateway.ModelUsage.ModelType.RERANKING;
32+
};
33+
}
2634
}

src/main/java/io/stargate/sgv2/jsonapi/service/provider/ModelUsage.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,21 @@ public static ModelUsage fromEmbeddingGateway(EmbeddingGateway.ModelUsage grpcMo
124124
grpcModelUsage.getCallDurationNanos());
125125
}
126126

127+
public EmbeddingGateway.ModelUsage toEmbeddingGateway() {
128+
return EmbeddingGateway.ModelUsage.newBuilder()
129+
.setModelProvider(modelProvider.apiName())
130+
.setModelType(modelType.toEmbeddingGateway())
131+
.setModelName(modelName)
132+
.setTenantId(tenantId)
133+
.setInputType(inputType.toEmbeddingGateway())
134+
.setPromptTokens(promptTokens)
135+
.setTotalTokens(totalTokens)
136+
.setRequestBytes(requestBytes)
137+
.setResponseBytes(responseBytes)
138+
.setCallDurationNanos(durationNanos)
139+
.build();
140+
}
141+
127142
/**
128143
* Creates a new model usage that merges this and the other usage, to combine after batching.
129144
*

src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderBase.java

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import java.time.Duration;
1111
import java.util.Map;
1212
import java.util.concurrent.TimeoutException;
13+
import java.util.function.Predicate;
1314
import org.slf4j.Logger;
1415
import org.slf4j.LoggerFactory;
1516

@@ -32,6 +33,14 @@
3233
public abstract class ProviderBase {
3334
protected static final Logger LOGGER = LoggerFactory.getLogger(ProviderBase.class);
3435

36+
// There is not a MediaType for text/json in Jakarta
37+
private static final MediaType MEDIATYPE_TEXT_JSON = new MediaType("text", "json");
38+
39+
protected static final Predicate<MediaType> IS_JSON_MEDIA_TYPE =
40+
mediaType ->
41+
MediaType.APPLICATION_JSON_TYPE.isCompatible(mediaType)
42+
|| MEDIATYPE_TEXT_JSON.isCompatible(mediaType);
43+
3544
private final ModelProvider modelProvider;
3645
private final ModelType modelType;
3746

@@ -203,12 +212,15 @@ protected String responseErrorMessage(Response jakartaResponse) {
203212
MediaType contentType = jakartaResponse.getMediaType();
204213
String raw = jakartaResponse.readEntity(String.class);
205214

206-
if (contentType == null || !MediaType.APPLICATION_JSON_TYPE.isCompatible(contentType)) {
207-
LOGGER.error(
208-
"Non-JSON error response from model provider, modelProvider:{}, modelName: {}, raw:{}",
209-
modelProvider(),
210-
modelName(),
211-
raw);
215+
if (contentType == null || !IS_JSON_MEDIA_TYPE.test(contentType)) {
216+
// we have an error, only need a debug
217+
if (LOGGER.isDebugEnabled()) {
218+
LOGGER.debug(
219+
"Non-JSON error response from model provider, modelProvider:{}, modelName: {}, raw:{}",
220+
modelProvider(),
221+
modelName(),
222+
raw);
223+
}
212224
return raw;
213225
}
214226

src/main/java/io/stargate/sgv2/jsonapi/service/provider/ProviderHttpInterceptor.java

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,32 @@ public void filter(ClientRequestContext requestContext, ClientResponseContext re
4444
long receivedBytes = 0;
4545
long sentBytes = 0;
4646

47+
// we may still get called even if the request failed, and we do not get a valid HTTP response,
48+
// for sanity check that we have the things we need to for processing.
49+
boolean isValid =
50+
responseContext != null
51+
&& responseContext.getStatus() > 0
52+
&& responseContext.getHeaders() != null;
53+
54+
if (!isValid) {
55+
if (LOGGER.isWarnEnabled()) {
56+
LOGGER.warn(
57+
"filter() - Invalid responseContext, skipping sent/received bytes tracking. responseContext is null: {}, getStatus: {}, getHeaders: {}",
58+
responseContext == null,
59+
responseContext != null ? responseContext.getStatus() : "response null",
60+
responseContext != null ? responseContext.getHeaders() : "response null");
61+
}
62+
return;
63+
}
64+
4765
if (LOGGER.isTraceEnabled()) {
4866
LOGGER.trace(
49-
"ProviderHttpInterceptor.filter() - requestContext.getUri(): {}, requestContext.getHeaders(): {}",
67+
"filter() - requestContext.getUri(): {}, requestContext.getHeaders(): {}",
5068
requestContext.getUri(),
5169
requestContext.getStringHeaders());
5270

5371
LOGGER.trace(
54-
"ProviderHttpInterceptor.filter() - responseContext.getStatus(): {}, responseContext.getHeaders(): {}",
72+
"filter() - responseContext.getStatus(): {}, responseContext.getHeaders(): {}",
5573
responseContext.getStatus(),
5674
responseContext.getHeaders());
5775
}
@@ -98,6 +116,16 @@ public static int getReceivedBytes(Response jakartaResponse) {
98116

99117
private static int getHeaderInt(Response jakartaResponse, String headerName) {
100118

119+
if (jakartaResponse == null || jakartaResponse.getHeaders() == null) {
120+
// log at trace, because this should be detected in filter() method
121+
if (LOGGER.isTraceEnabled()) {
122+
LOGGER.trace(
123+
"getHeaderInt() - jakartaResponse or headers is null, returning 0 for headerName: {}",
124+
headerName);
125+
}
126+
return 0;
127+
}
128+
101129
var headerString = jakartaResponse.getHeaderString(headerName);
102130
if (headerString != null && !headerString.isBlank()) {
103131
try {

0 commit comments

Comments
 (0)