diff --git a/__tests__/index.test.ts b/__tests__/index.test.ts index 3875857..68af786 100644 --- a/__tests__/index.test.ts +++ b/__tests__/index.test.ts @@ -1054,6 +1054,92 @@ describe("map() over RPC", () => { }); }); +describe("record-replay closure over RPC", () => { + it("passes a function as a closure inside a map() callback", async () => { + await using harness = new TestHarness(new TestTarget()); + let counter = new RpcStub(new Counter(0)); + + expect(await harness.stub.map(stub => { + return stub.callFunction(y => counter.increment(y), 3); + })).toStrictEqual({result: 3}); + }); + + it("encodes a closure's literal return value as its terminal instruction", async () => { + await using harness = new TestHarness(new TestTarget()); + let counter = new RpcStub(new Counter(0)); + + expect(await harness.stub.map(stub => { + return stub.callFunction(y => { + counter.increment(y); + return 42; + }, 3); + })).toStrictEqual({result: 42}); + expect(await counter.value).toBe(3); + }); + + it("supports invoking a received closure multiple times", async () => { + class DoubleCaller extends RpcTarget { + async callTwice(fn: RpcStub<(x: number) => Promise>) { + return [await fn(1), await fn(2)]; + } + } + + await using harness = new TestHarness(new DoubleCaller()); + let counter = new RpcStub(new Counter(10)); + + expect(await harness.stub.map(stub => { + return stub.callTwice(y => counter.increment(y)); + })).toStrictEqual([11, 13]); + }); + + it("supports nested closures", async () => { + class Passthrough extends RpcTarget { + call(fn: RpcStub<(x: number) => number | Promise>, x: number) { + return fn(x); + } + } + + await using harness = new TestHarness(new Passthrough()); + let counter = new RpcStub(new Counter(0)); + + expect(await harness.stub.map(stub => { + return stub.call(y => { + return stub.call(z => { + counter.increment(y); + return counter.increment(z); + }, 7); + }, 5); + })).toBe(12); + }); + + it("supports dup() to stash a closure past its param payload's lifetime", async () => { + class Stasher extends RpcTarget { + private stashed: any; + + stashFn(fn: any) { this.stashed = fn.dup(); } + async invokeStashed(x: number) { return await this.stashed(x); } + release() { + this.stashed?.[Symbol.dispose](); + this.stashed = undefined; + } + } + + await using harness = new TestHarness(new Stasher()); + let stub = harness.stub; + let counter = new RpcStub(new Counter(100)); + + await stub.map(stub => { + return stub.stashFn((y: number) => counter.increment(y)); + }); + + expect(await stub.invokeStashed(5)).toBe(105); + expect(await stub.invokeStashed(3)).toBe(108); + + await stub.release(); + }); + +}); + describe("stub disposal over RPC", () => { it("disposes remote RpcTarget when stub is disposed", async () => { let targetDisposedCount = 0; diff --git a/__tests__/test-util.ts b/__tests__/test-util.ts index 1559230..9053901 100644 --- a/__tests__/test-util.ts +++ b/__tests__/test-util.ts @@ -35,7 +35,7 @@ export class TestTarget extends RpcTarget { return { result: self.square(i) }; } - async callFunction(func: RpcStub<(i: number) => Promise>, i: number) { + async callFunction(func: RpcStub<(i: number) => number | Promise>, i: number) { return { result: await func(i) }; } diff --git a/src/core.ts b/src/core.ts index 41c3221..459d998 100644 --- a/src/core.ts +++ b/src/core.ts @@ -4,6 +4,7 @@ import type { RpcTargetBranded, __RPC_TARGET_BRAND } from "./types.js"; import { WORKERS_MODULE_SYMBOL } from "./symbols.js" +import type { Importer } from "./serialize.js"; // Polyfill Symbol.dispose for browsers that don't support it yet if (!Symbol.dispose) { @@ -155,7 +156,13 @@ function mapNotLoaded(): never { // map() is implemented in `map.ts`. We can't import it here because it would create an import // cycle, so instead we define two hook functions that map.ts will overwrite when it is imported. -export let mapImpl: MapImpl = { applyMap: mapNotLoaded, sendMap: mapNotLoaded }; +export let mapImpl: MapImpl = { + applyMap: mapNotLoaded, + sendMap: mapNotLoaded, + evaluateCaptures: mapNotLoaded, + serializeClosure: mapNotLoaded, + evaluateClosure: mapNotLoaded +}; type MapImpl = { // Applies a map function to an input value (usually an array). @@ -166,6 +173,12 @@ type MapImpl = { // Implements the .map() method of RpcStub. sendMap(hook: StubHook, path: PropertyPath, func: (value: RpcPromise) => unknown) : RpcPromise; + + evaluateCaptures(captures: unknown[], importer: Importer): StubHook[]; + + evaluateClosure(captures: StubHook[], instructions: unknown[]): (arg: unknown) => unknown; + + serializeClosure(func: (value: RpcPromise) => unknown): unknown[]; } function streamNotLoaded(): never { diff --git a/src/map.ts b/src/map.ts index b161456..5406df9 100644 --- a/src/map.ts +++ b/src/map.ts @@ -5,6 +5,8 @@ import { StubHook, PropertyPath, RpcPayload, RpcStub, RpcPromise, withCallInterceptor, ErrorStubHook, mapImpl, PayloadStubHook, unwrapStubAndPath, unwrapStubNoProperties } from "./core.js"; import { Devaluator, Exporter, Importer, ExportId, ImportId, Evaluator } from "./serialize.js"; +const AsyncFunction = (async function () {}).constructor; + let currentMapBuilder: MapBuilder | undefined; // We use this type signature when building the instructions for type checking purposes. It @@ -16,27 +18,18 @@ export type MapInstruction = class MapBuilder implements Exporter { private context: - | {parent: undefined, captures: StubHook[], subject: StubHook, path: PropertyPath} - | {parent: MapBuilder, captures: number[], subject: number, path: PropertyPath}; + | {parent: undefined, captures: StubHook[]} + | {parent: MapBuilder, captures: number[]}; private captureMap: Map = new Map(); private instructions: MapInstruction[] = []; + exportFunctionAsClosure = true; - constructor(subject: StubHook, path: PropertyPath) { + constructor() { if (currentMapBuilder) { - this.context = { - parent: currentMapBuilder, - captures: [], - subject: currentMapBuilder.capture(subject), - path - }; + this.context = { parent: currentMapBuilder, captures: [] }; } else { - this.context = { - parent: undefined, - captures: [], - subject, - path - }; + this.context = { parent: undefined, captures: [] }; } currentMapBuilder = this; @@ -50,7 +43,8 @@ class MapBuilder implements Exporter { return new MapVariableHook(this, 0); } - makeOutput(result: RpcPayload): StubHook { + // Devalue the callback's return and push it as the terminal instruction. + private finalize(result: RpcPayload): void { let devalued: unknown; try { devalued = Devaluator.devaluate(result.value, undefined, this, result); @@ -61,19 +55,31 @@ class MapBuilder implements Exporter { // The result is the final instruction. This doesn't actually fit our MapInstruction type // signature, so we cheat a bit. this.instructions.push(devalued); + } + + finalizeAsRemap(result: RpcPayload, subject: StubHook, path: PropertyPath): StubHook { + this.finalize(result); if (this.context.parent) { + const subjectIdx = this.context.parent.capture(subject); this.context.parent.instructions.push( - ["remap", this.context.subject, this.context.path, + ["remap", subjectIdx, path, this.context.captures.map(cap => ["import", cap]), this.instructions] ); return new MapVariableHook(this.context.parent, this.context.parent.instructions.length); } else { - return this.context.subject.map(this.context.path, this.context.captures, this.instructions); + return subject.map(path, this.context.captures, this.instructions); } } + finalizeAsClosure(result: RpcPayload): unknown[] { + this.finalize(result); + return ["closure", + this.context.captures.map(cap => ["import", cap]), + this.instructions]; + } + pushCall(hook: StubHook, path: PropertyPath, params: RpcPayload): StubHook { let devalued = Devaluator.devaluate(params.value, undefined, this, params); // HACK: Since the args is an array, devaluator will wrap in a second array. Need to unwrap. @@ -154,8 +160,15 @@ class MapBuilder implements Exporter { } }; -mapImpl.sendMap = (hook: StubHook, path: PropertyPath, func: (promise: RpcPromise) => unknown) => { - let builder = new MapBuilder(hook, path); +mapImpl.serializeClosure = (func: (promise: RpcPromise) => unknown): unknown[] => { + if (func.length !== 1) { + throw new Error("Only single-argument functions can be serialized as closures."); + } + if (Object.getPrototypeOf(func) === AsyncFunction.prototype) { + throw new Error("RPC closures cannot be async functions."); + } + + let builder = new MapBuilder(); let result: RpcPayload; try { result = RpcPayload.fromAppReturn(withCallInterceptor(builder.pushCall.bind(builder), () => { @@ -165,17 +178,28 @@ mapImpl.sendMap = (hook: StubHook, path: PropertyPath, func: (promise: RpcPromis builder.unregister(); } - // Detect misuse: Map callbacks cannot be async. - if (result instanceof Promise) { - // Squelch unhandled rejections from the map function itself -- it'll probably just throw - // something about pulling a MapVariableHook. - result.catch(err => {}); + return builder.finalizeAsClosure(result); +} - // Throw an understandable error. +mapImpl.sendMap = (hook: StubHook, path: PropertyPath, func: (promise: RpcPromise) => unknown) => { + if (Object.getPrototypeOf(func) === AsyncFunction.prototype) { throw new Error("RPC map() callbacks cannot be async."); } - return new RpcPromise(builder.makeOutput(result), []); + if (currentMapBuilder) { + currentMapBuilder.capture(hook); + } + let builder = new MapBuilder(); + let result: RpcPayload; + try { + result = RpcPayload.fromAppReturn(withCallInterceptor(builder.pushCall.bind(builder), () => { + return func(new RpcPromise(builder.makeInput(), [])); + })); + } finally { + builder.unregister(); + } + + return new RpcPromise(builder.finalizeAsRemap(result, hook, path), []); } function throwMapperBuilderUseError(): never { @@ -350,4 +374,54 @@ mapImpl.applyMap = (input: unknown, parent: object | undefined, owner: RpcPayloa } } +mapImpl.evaluateCaptures = (rawCaptures: unknown[], importer: Importer) => { + return rawCaptures.map(cap => { + if (!(cap instanceof Array) || + cap.length !== 2 || + (cap[0] !== "import" && cap[0] !== "export") || + typeof cap[1] !== "number") { + throw new TypeError(`unknown map capture: ${JSON.stringify(cap)}`); + } + + if (cap[0] === "export") { + return importer.importStub(cap[1]); + } else { + let exp = importer.getExport(cap[1]); + if (!exp) { + throw new Error(`no such entry on exports table: ${cap[1]}`); + } + return exp.dup(); + } + }); +} + +mapImpl.evaluateClosure = (captures: StubHook[], instructions: unknown[]): (arg: unknown) => Promise => { + let disposed = false; + const dispose = () => { + disposed = true; + for (let cap of captures) { + cap.dispose(); + } + } + + const fn = (arg: unknown): Promise => { + if (disposed) { + throw new Error("Attempted to call a closure after it was disposed."); + } + const payload = applyMapToElement(arg, undefined, null, captures, instructions); + return payload.deliverResolve(); + } + + fn.dup = () => { + if (disposed) { + throw new Error("Attempted to dup a disposed closure."); + } + return mapImpl.evaluateClosure(captures.map(cap => cap.dup()), instructions); + } + + fn[Symbol.dispose] = dispose; + + return fn; +} + export function forceInitMap() {} diff --git a/src/serialize.ts b/src/serialize.ts index 11bf830..622e5e8 100644 --- a/src/serialize.ts +++ b/src/serialize.ts @@ -2,7 +2,7 @@ // Licensed under the MIT license found in the LICENSE.txt file or at: // https://opensource.org/license/mit -import { StubHook, RpcPayload, typeForRpc, RpcStub, RpcPromise, LocatedPromise, RpcTarget, unwrapStubAndPath, streamImpl, PromiseStubHook, PayloadStubHook } from "./core.js"; +import { StubHook, RpcPayload, typeForRpc, RpcStub, RpcPromise, LocatedPromise, RpcTarget, unwrapStubAndPath, streamImpl, PromiseStubHook, PayloadStubHook, mapImpl } from "./core.js"; export type ImportId = number; export type ExportId = number; @@ -10,6 +10,7 @@ export type ExportId = number; // ======================================================================================= export interface Exporter { + exportFunctionAsClosure?: boolean; exportStub(hook: StubHook): ExportId; exportPromise(hook: StubHook): ExportId; getImport(hook: StubHook): ImportId | undefined; @@ -350,6 +351,10 @@ export class Devaluator { if (!this.source) { throw new Error("Can't serialize RPC stubs in this context."); } + if (kind === 'function' && this.exporter.exportFunctionAsClosure) { + // Serialize as a closure (record-replay) + return mapImpl.serializeClosure(value as ((arg: any) => any)); + } let hook = this.source.getHookForRpcTarget(value, parent); return this.devaluateHook("export", hook); @@ -725,24 +730,7 @@ export class Evaluator { break; // report error below } - let captures: StubHook[] = value[3].map(cap => { - if (!(cap instanceof Array) || - cap.length !== 2 || - (cap[0] !== "import" && cap[0] !== "export") || - typeof cap[1] !== "number") { - throw new TypeError(`unknown map capture: ${JSON.stringify(cap)}`); - } - - if (cap[0] === "export") { - return this.importer.importStub(cap[1]); - } else { - let exp = this.importer.getExport(cap[1]); - if (!exp) { - throw new Error(`no such entry on exports table: ${cap[1]}`); - } - return exp.dup(); - } - }); + let captures: StubHook[] = mapImpl.evaluateCaptures(value[3], this.importer); let instructions = value[4]; @@ -753,6 +741,21 @@ export class Evaluator { return promise; } + case "closure": { + if (value.length !== 3 || + !(value[1] instanceof Array) || + !(value[2] instanceof Array)) { + break; // report error below + } + const captures = mapImpl.evaluateCaptures(value[1], this.importer); + const instructions = value[2]; + // Tie each capture's lifetime to the containing payload: + for (const cap of captures) { + this.hooks.push(cap); + } + return mapImpl.evaluateClosure(captures, instructions); + } + case "export": case "promise": // It's an "export" from the perspective of the sender, i.e. they sent us a new object