Skip to content

Commit 48857a2

Browse files
Yuqi-Du and amorton authored
Usage metering for embedding and reranking (#2008)
Co-authored-by: Aaron Morton <aaron.morton@datastax.com>
1 parent 81d5741 commit 48857a2

72 files changed

Lines changed: 4110 additions & 2707 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/CreateCollectionCommand.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import io.stargate.sgv2.jsonapi.api.model.command.CollectionOnlyCommand;
66
import io.stargate.sgv2.jsonapi.api.model.command.CommandName;
77
import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants;
8-
import io.stargate.sgv2.jsonapi.config.constants.RerankingConstants;
8+
import io.stargate.sgv2.jsonapi.config.constants.ServiceDescConstants;
99
import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1;
1010
import io.stargate.sgv2.jsonapi.service.schema.collections.DocumentPath;
1111
import io.stargate.sgv2.jsonapi.service.schema.naming.NamingRules;
@@ -276,28 +276,28 @@ public record RerankServiceDesc(
276276
description = "Registered reranking service provider",
277277
type = SchemaType.STRING,
278278
implementation = String.class)
279-
@JsonProperty(RerankingConstants.RerankingService.PROVIDER)
279+
@JsonProperty(ServiceDescConstants.PROVIDER)
280280
String provider,
281281
@Schema(
282282
description = "Registered reranking service model",
283283
type = SchemaType.STRING,
284284
implementation = String.class)
285-
@JsonProperty(RerankingConstants.RerankingService.MODEL_NAME)
285+
@JsonProperty(ServiceDescConstants.MODEL_NAME)
286286
String modelName,
287287
@Valid
288288
@Nullable
289289
@Schema(
290290
description = "Authentication config for chosen reranking service",
291291
type = SchemaType.OBJECT)
292-
@JsonProperty(RerankingConstants.RerankingService.AUTHENTICATION)
292+
@JsonProperty(ServiceDescConstants.AUTHENTICATION)
293293
@JsonInclude(JsonInclude.Include.NON_NULL)
294294
Map<String, String> authentication,
295295
@Nullable
296296
@Schema(
297297
description =
298298
"Optional parameters that match the messageTemplate provided for the reranking provider",
299299
type = SchemaType.OBJECT)
300-
@JsonProperty(RerankingConstants.RerankingService.PARAMETERS)
300+
@JsonProperty(ServiceDescConstants.PARAMETERS)
301301
@JsonInclude(JsonInclude.Include.NON_NULL)
302302
Map<String, Object> parameters) {
303303

src/main/java/io/stargate/sgv2/jsonapi/api/model/command/impl/VectorizeConfig.java

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import com.fasterxml.jackson.annotation.JsonProperty;
55
import io.stargate.sgv2.jsonapi.config.constants.VectorConstants;
66
import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1;
7-
import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants;
7+
import io.stargate.sgv2.jsonapi.service.embedding.operation.HuggingFaceDedicatedEmbeddingProvider;
8+
import io.stargate.sgv2.jsonapi.service.provider.ModelProvider;
89
import jakarta.validation.Valid;
910
import jakarta.validation.constraints.*;
1011
import java.util.*;
@@ -48,24 +49,30 @@ public VectorizeConfig(
4849
String modelName,
4950
Map<String, String> authentication,
5051
Map<String, Object> parameters) {
52+
5153
if (provider == null) {
5254
throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException(
5355
"'provider' in required property for 'vector.service' Object value");
5456
}
57+
5558
this.provider = provider;
59+
5660
// HuggingfaceDedicated does not need user to specify model explicitly
5761
// If user specifies modelName other than endpoint-defined-model, will error out
5862
// By default, huggingfaceDedicated provider use endpoint-defined-model as placeholder
59-
if (ProviderConstants.HUGGINGFACE_DEDICATED.equals(provider)) {
63+
if (ModelProvider.HUGGINGFACE_DEDICATED.apiName().equals(provider)) {
6064
if (modelName == null) {
61-
modelName = ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL;
62-
} else if (!modelName.equals(ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL)) {
65+
modelName =
66+
HuggingFaceDedicatedEmbeddingProvider.HUGGINGFACE_DEDICATED_ENDPOINT_DEFINED_MODEL;
67+
} else if (!modelName.equals(
68+
HuggingFaceDedicatedEmbeddingProvider.HUGGINGFACE_DEDICATED_ENDPOINT_DEFINED_MODEL)) {
6369
throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException(
6470
"'modelName' is not needed for embedding provider %s explicitly, only '%s' is accepted",
65-
ProviderConstants.HUGGINGFACE_DEDICATED,
66-
ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL);
71+
ModelProvider.HUGGINGFACE_DEDICATED,
72+
HuggingFaceDedicatedEmbeddingProvider.HUGGINGFACE_DEDICATED_ENDPOINT_DEFINED_MODEL);
6773
}
6874
}
75+
6976
this.modelName = modelName;
7077
if (authentication != null && !authentication.isEmpty()) {
7178
Map<String, String> updatedAuth = new HashMap<>();

src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentials.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,16 @@
66
* EmbeddingCredentials is a record that holds the embedding provider credentials for the embedding
77
* service passed as header.
88
*
9+
* <p>Includes the tenantId, so we can fully identify the usage when creating the {@link
10+
* io.stargate.sgv2.jsonapi.service.provider.ModelUsage}
11+
*
12+
* @param tenantId - Tenant Id that called the API.
913
* @param apiKey - API token for the embedding service
1014
* @param accessId - Access Id used for AWS Bedrock embedding service
1115
* @param secretId - Secret Id used for AWS Bedrock embedding service
1216
*/
1317
public record EmbeddingCredentials(
14-
Optional<String> apiKey, Optional<String> accessId, Optional<String> secretId) {}
18+
String tenantId,
19+
Optional<String> apiKey,
20+
Optional<String> accessId,
21+
Optional<String> secretId) {}

src/main/java/io/stargate/sgv2/jsonapi/api/request/EmbeddingCredentialsSupplier.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,14 @@ public EmbeddingCredentials create(
7979
&& collectionSupportsNoneAuth) {
8080
var authToken = requestContext.getHttpHeaders().getHeader(this.authTokenHeaderName);
8181
return new EmbeddingCredentials(
82-
Optional.ofNullable(authToken), Optional.empty(), Optional.empty());
82+
requestContext.getTenantId().orElse(""),
83+
Optional.ofNullable(authToken),
84+
Optional.empty(),
85+
Optional.empty());
8386
}
8487

8588
return new EmbeddingCredentials(
89+
requestContext.getTenantId().orElse(""),
8690
Optional.ofNullable(embeddingApi),
8791
Optional.ofNullable(accessId),
8892
Optional.ofNullable(secretId));

src/main/java/io/stargate/sgv2/jsonapi/api/request/RequestContext.java

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import com.fasterxml.uuid.Generators;
44
import com.fasterxml.uuid.NoArgGenerator;
5+
import com.google.common.annotations.VisibleForTesting;
56
import io.stargate.sgv2.jsonapi.api.request.tenant.DataApiTenantResolver;
67
import io.stargate.sgv2.jsonapi.api.request.token.DataApiTokenResolver;
78
import io.stargate.sgv2.jsonapi.config.constants.HttpConstants;
@@ -35,20 +36,25 @@ public class RequestContext {
3536

3637
private final String userAgent;
3738

38-
/**
39-
* Constructor that will be useful in the offline library mode, where only the tenant will be set
40-
* and accessed.
41-
*
42-
* @param tenantId Tenant Id
43-
*/
44-
public RequestContext(Optional<String> tenantId) {
39+
/** FOR TESTING ONLY - so we can bypass pulling things from the headers; still messy, getting better */
40+
@VisibleForTesting
41+
public RequestContext(
42+
Optional<String> tenantId,
43+
Optional<String> cassandraToken,
44+
RerankingCredentials rerankingCredentials,
45+
String userAgent) {
4546
this.tenantId = tenantId;
46-
cassandraToken = Optional.empty();
47-
embeddingCredentialsSupplier = null;
48-
rerankingCredentials = null;
49-
httpHeaders = null;
47+
this.cassandraToken = cassandraToken;
48+
embeddingCredentialsSupplier =
49+
new EmbeddingCredentialsSupplier(
50+
HttpConstants.AUTHENTICATION_TOKEN_HEADER_NAME,
51+
HttpConstants.EMBEDDING_AUTHENTICATION_TOKEN_HEADER_NAME,
52+
HttpConstants.EMBEDDING_AUTHENTICATION_ACCESS_ID_HEADER_NAME,
53+
HttpConstants.EMBEDDING_AUTHENTICATION_SECRET_ID_HEADER_NAME);
54+
this.rerankingCredentials = rerankingCredentials;
55+
this.userAgent = userAgent;
56+
this.httpHeaders = new HttpHeaderAccess(io.vertx.core.MultiMap.caseInsensitiveMultiMap());
5057
requestId = generateRequestId();
51-
userAgent = null;
5258
}
5359

5460
@Inject
@@ -77,11 +83,14 @@ public RequestContext(
7783
HeaderBasedRerankingKeyResolver.resolveRerankingKey(routingContext);
7884
rerankingCredentials =
7985
rerankingApiKeyFromHeader
80-
.map(apiKey -> new RerankingCredentials(Optional.of(apiKey)))
86+
.map(apiKey -> new RerankingCredentials(this.tenantId.orElse(""), Optional.of(apiKey)))
8187
.orElse(
8288
this.cassandraToken
83-
.map(cassandraToken -> new RerankingCredentials(Optional.of(cassandraToken)))
84-
.orElse(new RerankingCredentials(Optional.empty())));
89+
.map(
90+
cassandraToken ->
91+
new RerankingCredentials(
92+
this.tenantId.orElse(""), Optional.of(cassandraToken)))
93+
.orElse(new RerankingCredentials(this.tenantId.orElse(""), Optional.empty())));
8594
}
8695

8796
private static String generateRequestId() {

src/main/java/io/stargate/sgv2/jsonapi/api/request/RerankingCredentials.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,8 @@
77
* resolved from the request header 'reranking-api-key', if it is not present, then we will use the
88
* cassandra token as the reranking api key. Note, both cassandra token and reranking-api-key could
99
* be absent in Data API request, although it is invalid for authentication.
10+
*
11+
* <p>Includes the tenantId, so we can fully identify the usage when creating the {@link
12+
* io.stargate.sgv2.jsonapi.service.provider.ModelUsage}
1013
*/
11-
public record RerankingCredentials(Optional<String> apiKey) {}
14+
public record RerankingCredentials(String tenantId, Optional<String> apiKey) {}

src/main/java/io/stargate/sgv2/jsonapi/api/v1/CollectionResource.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ public Uni<RestResponse<CommandResult>> postCommand(
255255

256256
if (vectorColDef != null && vectorColDef.vectorizeDefinition() != null) {
257257
embeddingProvider =
258-
embeddingProviderFactory.getConfiguration(
258+
embeddingProviderFactory.create(
259259
requestContext.getTenantId(),
260260
requestContext.getCassandraToken(),
261261
vectorColDef.vectorizeDefinition().provider(),

src/main/java/io/stargate/sgv2/jsonapi/config/constants/RerankingConstants.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,4 @@ interface CollectionRerankingOptions {
66
String ENABLED = "enabled";
77
String SERVICE = ServiceDescConstants.SERVICE;
88
}
9-
10-
interface RerankingService extends ServiceDescConstants {}
119
}

src/main/java/io/stargate/sgv2/jsonapi/config/constants/ServiceDescConstants.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package io.stargate.sgv2.jsonapi.config.constants;
22

33
/** Common service description constants shared between vector and reranking */
4-
interface ServiceDescConstants {
4+
public interface ServiceDescConstants {
55
String SERVICE = "service";
66
String PROVIDER = "provider";
77
String MODEL_NAME = "modelName";

src/main/java/io/stargate/sgv2/jsonapi/service/embedding/DataVectorizer.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ public DataVectorizer(
5252
SchemaObject schemaObject) {
5353
this.embeddingProvider = embeddingProvider;
5454
this.nodeFactory = nodeFactory;
55-
this.embeddingCredentials = embeddingCredentials;
55+
this.embeddingCredentials =
56+
Objects.requireNonNull(embeddingCredentials, "embeddingCredentials must not be null");
5657
this.schemaObject = schemaObject;
5758
}
5859

@@ -175,7 +176,7 @@ public Uni<float[]> vectorize(String vectorizeContent) {
175176
List.of(vectorizeContent),
176177
embeddingCredentials,
177178
EmbeddingProvider.EmbeddingRequestType.INDEX)
178-
.map(EmbeddingProvider.Response::embeddings);
179+
.map(EmbeddingProvider.BatchedEmbeddingResponse::embeddings);
179180
return vectors
180181
.onItem()
181182
.transform(
@@ -303,7 +304,7 @@ private Uni<List<float[]>> vectorizeTexts(
303304

304305
return embeddingProvider
305306
.vectorize(1, textsToVectorize, embeddingCredentials, requestType)
306-
.map(EmbeddingProvider.Response::embeddings)
307+
.map(EmbeddingProvider.BatchedEmbeddingResponse::embeddings)
307308
.onItem()
308309
.transform(
309310
vectorData -> {

0 commit comments

Comments
 (0)