From 5ce5861fd3204cbb09e1dc284907270283b7e1e0 Mon Sep 17 00:00:00 2001 From: Parv Ahuja Date: Mon, 8 Jun 2026 14:15:17 -0700 Subject: [PATCH] feat: add mcp client approval callbacks --- .changeset/mcp-client-approval-callbacks.md | 5 + src/mcp-sdk/client/McpClient.test-d.ts | 7 + src/mcp-sdk/client/McpClient.ts | 166 ++++++++++++-------- src/mcp-sdk/client/McpClient.unit.test.ts | 131 +++++++++++++++ 4 files changed, 242 insertions(+), 67 deletions(-) create mode 100644 .changeset/mcp-client-approval-callbacks.md create mode 100644 src/mcp-sdk/client/McpClient.unit.test.ts diff --git a/.changeset/mcp-client-approval-callbacks.md b/.changeset/mcp-client-approval-callbacks.md new file mode 100644 index 00000000..68f32d22 --- /dev/null +++ b/.changeset/mcp-client-approval-callbacks.md @@ -0,0 +1,5 @@ +--- +'mppx': patch +--- + +Added MCP client payment approval callbacks before credential creation. diff --git a/src/mcp-sdk/client/McpClient.test-d.ts b/src/mcp-sdk/client/McpClient.test-d.ts index 48371b59..82be9d6a 100644 --- a/src/mcp-sdk/client/McpClient.test-d.ts +++ b/src/mcp-sdk/client/McpClient.test-d.ts @@ -73,8 +73,15 @@ describe('McpClient.wrap', () => { }) expectTypeOf(wrapped.callTool).toBeCallableWith({ name: 'tool' }) + expectTypeOf(wrapped.callTool).toBeCallableWith(null, { name: 'tool' }) + expectTypeOf(wrapped.callTool).toBeCallableWith(() => true, { name: 'tool' }) expectTypeOf(wrapped.callTool).toBeCallableWith({ name: 'tool' }, {}) expectTypeOf(wrapped.callTool).toBeCallableWith({ name: 'tool' }, { timeout: 5000 }) + expectTypeOf(wrapped.callTool).toBeCallableWith( + async (challenge) => challenge.intent === 'charge', + { name: 'tool' }, + { timeout: 5000 }, + ) }) test('callTool result includes receipt', () => { diff --git a/src/mcp-sdk/client/McpClient.ts b/src/mcp-sdk/client/McpClient.ts index c369c7c5..608e65fa 100644 --- a/src/mcp-sdk/client/McpClient.ts +++ b/src/mcp-sdk/client/McpClient.ts @@ -11,6 +11,14 @@ import type * as z from '../../zod.js' type AnyClient = Method.Client +export type CallToolParameters = { + name: string + arguments?: Record + _meta?: Record +} + +export type OnPaymentRequired = (challenge: Challenge.Challenge) => boolean | Promise + /** * Result of a tool call with payment handling. * Extends the SDK's callTool return type with an optional payment receipt. @@ -55,69 +63,91 @@ export function wrap< const { methods } = config const paymentPreferences = AcceptPayment.resolve(methods) - return { - ...client, - async callTool(params, options) { - const context = options?.context - const timeout = options?.timeout - - try { - const result = await client.callTool( - params, - undefined, - timeout !== undefined ? { timeout } : undefined, + const callTool = (async ( + first: CallToolParameters | OnPaymentRequired | null | undefined, + second?: CallToolParameters | wrap.CallToolOptions, + third?: wrap.CallToolOptions, + ) => { + const hasApprovalArgument = typeof first === 'function' || first === null || first === undefined + const params = (hasApprovalArgument ? second : first) as CallToolParameters + const options = (hasApprovalArgument ? third : second) as + | wrap.CallToolOptions + | undefined + const onPaymentRequired = + first === null + ? undefined + : hasApprovalArgument + ? ((first as OnPaymentRequired | undefined) ?? config.onPaymentRequired) + : config.onPaymentRequired + const context = options?.context + const timeout = options?.timeout + + try { + const result = await client.callTool( + params, + undefined, + timeout !== undefined ? { timeout } : undefined, + ) + + return { + ...result, + receipt: result._meta?.[core_Mcp.receiptMetaKey] as core_Mcp.Receipt | undefined, + } + } catch (error) { + // Check if this is a payment required error + if (!isPaymentRequiredError(error)) throw error + + const challenges = (error.data as { challenges?: Challenge.Challenge[] })?.challenges + if (!challenges?.length) throw error + + const selected = AcceptPayment.selectChallenge( + challenges, + methods, + paymentPreferences.entries, + ) + if (!selected) { + const available = challenges.map((c) => `${c.method}.${c.intent}`).join(', ') + const installed = methods.map((m) => `${m.name}.${m.intent}`).join(', ') + throw new Error( + `No compatible payment method. Server offers: ${available}. Client has: ${installed}`, + { cause: error }, ) + } - return { - ...result, - receipt: result._meta?.[core_Mcp.receiptMetaKey] as core_Mcp.Receipt | undefined, - } - } catch (error) { - // Check if this is a payment required error - if (!isPaymentRequiredError(error)) throw error - - const challenges = (error.data as { challenges?: Challenge.Challenge[] })?.challenges - if (!challenges?.length) throw error - - const selected = AcceptPayment.selectChallenge( - challenges, - methods, - paymentPreferences.entries, - ) - if (!selected) { - const available = challenges.map((c) => `${c.method}.${c.intent}`).join(', ') - const installed = methods.map((m) => `${m.name}.${m.intent}`).join(', ') - throw new Error( - `No compatible payment method. Server offers: ${available}. Client has: ${installed}`, - { cause: error }, - ) - } - - const credential = await createCredential(selected.challenge, { - context, - methods, - }) - const parsed = Credential.deserialize(credential) - - const retryResult = await client.callTool( - { - ...params, - _meta: { - ...params._meta, - [core_Mcp.credentialMetaKey]: parsed, - }, - }, - undefined, - timeout !== undefined ? { timeout } : undefined, - ) + if (selected.challenge.expires) + Expires.assert(selected.challenge.expires, selected.challenge.id) - return { - ...retryResult, - receipt: retryResult._meta?.[core_Mcp.receiptMetaKey] as core_Mcp.Receipt | undefined, - } + if (onPaymentRequired) { + const approved = await onPaymentRequired(selected.challenge) + if (!approved) throw new Error('Payment declined.', { cause: error }) } - }, - } + + const credential = await createCredential(selected.challenge, { + context, + methods, + }) + const parsed = Credential.deserialize(credential) + + const retryResult = await client.callTool( + { + ...params, + _meta: { + ...params._meta, + [core_Mcp.credentialMetaKey]: parsed, + }, + }, + undefined, + timeout !== undefined ? { timeout } : undefined, + ) + + return { + ...retryResult, + receipt: retryResult._meta?.[core_Mcp.receiptMetaKey] as core_Mcp.Receipt | undefined, + } + } + }) as wrap.McpClient['callTool'] + + return { ...client, callTool } as wrap.McpClient } /** Union of all context types from all methods that have context schemas. */ @@ -133,6 +163,8 @@ export declare namespace wrap { type Config = { /** Array of methods to use. */ methods: methods + /** Optional approval hook called before creating a payment credential. */ + onPaymentRequired?: OnPaymentRequired } type McpClient< @@ -140,14 +172,14 @@ export declare namespace wrap { methods extends readonly AnyClient[] = readonly AnyClient[], > = Omit & { /** Call a tool with automatic payment handling. */ - callTool: ( - params: { - name: string - arguments?: Record - _meta?: Record - }, - options?: CallToolOptions, - ) => Promise + callTool: { + (params: CallToolParameters, options?: CallToolOptions): Promise + ( + onPaymentRequired: OnPaymentRequired | null | undefined, + params: CallToolParameters, + options?: CallToolOptions, + ): Promise + } } type CallToolOptions = { diff --git a/src/mcp-sdk/client/McpClient.unit.test.ts b/src/mcp-sdk/client/McpClient.unit.test.ts new file mode 100644 index 00000000..6a9feb4a --- /dev/null +++ b/src/mcp-sdk/client/McpClient.unit.test.ts @@ -0,0 +1,131 @@ +import type { Client } from '@modelcontextprotocol/sdk/client/index.js' +import { McpError } from '@modelcontextprotocol/sdk/types.js' +import { Challenge, Credential, Mcp as core_Mcp, Method } from 'mppx' +import { Methods } from 'mppx/tempo' +import { describe, expect, test, vi } from 'vp/test' + +import * as McpClient from './McpClient.js' + +describe('MCP client payment approval', () => { + test('calls an approval hook before creating a credential', async () => { + const challenge = Challenge.from({ + id: 'approval-test', + intent: 'charge', + method: 'tempo', + realm: 'api.example.com', + request: {}, + }) + const calls: unknown[] = [] + const client = { + async callTool(params: unknown) { + calls.push(params) + if (calls.length === 1) + throw new McpError(core_Mcp.paymentRequiredCode, 'Payment Required', { + challenges: [challenge], + httpStatus: 402, + }) + return { + _meta: { + [core_Mcp.receiptMetaKey]: { + method: 'tempo', + reference: 'test', + status: 'success', + timestamp: new Date().toISOString(), + }, + }, + content: [{ type: 'text', text: 'ok' }], + } + }, + } + const createCredential = vi.fn(async ({ challenge }: { challenge: Challenge.Challenge }) => + Credential.serialize({ + challenge, + payload: { signature: '0xsignature', type: 'transaction' }, + }), + ) + const onPaymentRequired = vi.fn(() => true) + const mcp = McpClient.wrap(client as unknown as Pick, { + methods: [Method.toClient(Methods.charge, { createCredential })], + }) + + const result = await mcp.callTool(onPaymentRequired, { name: 'paid_tool', arguments: {} }) + + expect(result.content).toEqual([{ type: 'text', text: 'ok' }]) + expect(onPaymentRequired).toHaveBeenCalledWith(challenge) + expect(createCredential).toHaveBeenCalledOnce() + expect(calls).toHaveLength(2) + }) + + test('does not create a credential when approval is denied', async () => { + const challenge = Challenge.from({ + id: 'denied-test', + intent: 'charge', + method: 'tempo', + realm: 'api.example.com', + request: {}, + }) + const client = { + async callTool() { + throw new McpError(core_Mcp.paymentRequiredCode, 'Payment Required', { + challenges: [challenge], + httpStatus: 402, + }) + }, + } + const createCredential = vi.fn(async ({ challenge }: { challenge: Challenge.Challenge }) => + Credential.serialize({ + challenge, + payload: { signature: '0xsignature', type: 'transaction' }, + }), + ) + const mcp = McpClient.wrap(client as unknown as Pick, { + methods: [Method.toClient(Methods.charge, { createCredential })], + }) + + await expect(mcp.callTool(() => false, { name: 'paid_tool' })).rejects.toThrow( + 'Payment declined.', + ) + expect(createCredential).not.toHaveBeenCalled() + }) + + test('allows null to bypass a config approval hook', async () => { + const challenge = Challenge.from({ + id: 'null-bypass-test', + intent: 'charge', + method: 'tempo', + realm: 'api.example.com', + request: {}, + }) + let calls = 0 + const client = { + async callTool() { + calls += 1 + if (calls === 1) + throw new McpError(core_Mcp.paymentRequiredCode, 'Payment Required', { + challenges: [challenge], + httpStatus: 402, + }) + return { + content: [{ type: 'text', text: 'ok' }], + } + }, + } + const createCredential = vi.fn(async ({ challenge }: { challenge: Challenge.Challenge }) => + Credential.serialize({ + challenge, + payload: { signature: '0xsignature', type: 'transaction' }, + }), + ) + const onPaymentRequired = vi.fn(() => false) + const mcp = McpClient.wrap(client as unknown as Pick, { + methods: [Method.toClient(Methods.charge, { createCredential })], + onPaymentRequired, + }) + + await expect(mcp.callTool(null, { name: 'paid_tool' })).resolves.toMatchObject({ + content: [{ type: 'text', text: 'ok' }], + }) + expect(onPaymentRequired).not.toHaveBeenCalled() + expect(createCredential).toHaveBeenCalledOnce() + }) +})