diff --git a/packages/proxy/schema/index.test.ts b/packages/proxy/schema/index.test.ts index e6db0751..888765d1 100644 --- a/packages/proxy/schema/index.test.ts +++ b/packages/proxy/schema/index.test.ts @@ -344,6 +344,62 @@ describe("APISecretSchema compatibility", () => { }); }); + it("accepts resolved Vertex OAuth bearer metadata", () => { + const parsed = APISecretSchema.parse({ + secret: "google-access-token", + type: "vertex", + metadata: { + authType: "oauth_bearer", + auth_source: "google_workload_identity_federation", + connection_id: 123, + project: "vertex-project", + future_field: "future-value", + }, + }); + + expect(parsed.type).toBe("vertex"); + expect(parsed.metadata).toMatchObject({ + authType: "oauth_bearer", + connection_id: 123, + auth_source: "google_workload_identity_federation", + future_field: "future-value", + project: "vertex-project", + }); + }); + + it("accepts raw Vertex workload identity metadata", () => { + const parsed = APISecretSchema.parse({ + secret: "__VERTEX_WIF__", + type: "vertex", + metadata: { + authType: "workload_identity_federation", + project: "vertex-project", + workload_identity_provider: "//iam.googleapis.com/projects/123", + }, + }); + + expect(parsed.type).toBe("vertex"); + expect(parsed.metadata).toMatchObject({ + authType: "workload_identity_federation", + project: "vertex-project", + workload_identity_provider: "//iam.googleapis.com/projects/123", + }); + }); + + it("validates OIDC metadata only for raw Vertex workload identity metadata", () => { + const result = APISecretSchema.safeParse({ + secret: "__VERTEX_WIF__", + type: "vertex", + metadata: { + authType: "workload_identity_federation", + project: "vertex-project", + connection_id: 123, + }, + }); + + expect(result.success).toBe(false); + }); + it("defaults Anthropic auth metadata to api_key", () => { const parsed = APISecretSchema.parse({ secret: "anthropic-api-key", diff --git a/packages/proxy/schema/secrets.ts b/packages/proxy/schema/secrets.ts index 4c9eef03..5956af08 100644 --- a/packages/proxy/schema/secrets.ts +++ b/packages/proxy/schema/secrets.ts @@ -88,6 +88,15 @@ export const BedrockMetadataSchemaWithAuth = }); export type BedrockMetadata = z.infer; +export const VertexOIDCSecretMetadataSchema = z.object({ + connection_id: z.string().nullish(), + scopes: z.array(z.string()).nullish(), + workload_identity_provider: z.string().nullish(), +}); +export type VertexOIDCSecretMetadata = z.infer< + typeof VertexOIDCSecretMetadataSchema +>; + export const VertexMetadataSchema = BaseMetadataSchema.merge( z.object({ project: z.string().min(1, "Project cannot be empty"), @@ -97,10 +106,25 @@ export const VertexMetadataSchema = BaseMetadataSchema.merge( } return value; }, z.string().min(1, "Location cannot be empty").optional()), - authType: z.enum(["access_token", "service_account_key"]), + authType: z.enum([ + "access_token", + "oauth_bearer", + "service_account_key", + "workload_identity_federation", + ]), api_base: z.union([z.string().url(), z.string().length(0)]).nullish(), }), ).passthrough(); +export const VertexAPISecretMetadataSchema = z.union([ + VertexMetadataSchema.extend({ + authType: z.literal("workload_identity_federation"), + }) + .merge(VertexOIDCSecretMetadataSchema) + .passthrough(), + VertexMetadataSchema.extend({ + authType: z.enum(["access_token", "oauth_bearer", "service_account_key"]), + }).passthrough(), +]); export const DatabricksMetadataSchema = BaseMetadataSchema.merge( z.object({ @@ -213,7 +237,7 @@ export const APISecretSchema = z.union([ APISecretBaseSchema.merge( z.object({ type: z.literal("vertex"), - metadata: VertexMetadataSchema.nullish(), + metadata: VertexAPISecretMetadataSchema.nullish(), }), ).passthrough(), APISecretBaseSchema.merge( diff --git a/packages/proxy/src/proxy.ts b/packages/proxy/src/proxy.ts index 7d07102e..899c8ecf 100644 --- a/packages/proxy/src/proxy.ts +++ b/packages/proxy/src/proxy.ts @@ -19,6 +19,7 @@ import { MessageTypeToMessageType, modelProviderHasReasoning, ModelSpec, + VertexOIDCSecretMetadataSchema, VertexMetadataSchema, } from "@schema"; import { translateParams } from "../schema/translate"; @@ -2136,6 +2137,9 @@ async function fetchOpenAI( const { project, location, authType, api_base } = VertexMetadataSchema.parse(secret.metadata); + if (authType === "workload_identity_federation") { + VertexOIDCSecretMetadataSchema.parse(secret.metadata); + } const resolvedLocation = resolveVertexLocation({ metadataLocation: location, modelSpec, @@ -2625,6 +2629,9 @@ async function vertexEndpointInfo({ }): Promise { const { project, location, authType, api_base } = VertexMetadataSchema.parse(metadata); + if (authType === "workload_identity_federation") { + VertexOIDCSecretMetadataSchema.parse(metadata); + } const resolvedLocation = resolveVertexLocation({ metadataLocation: location, modelSpec,