Skip to content

Commit e06d193

Browse files
committed
refactor(compile, evaluate, types, lambda): introduce FnNode and enhance control flow handling
1 parent 775a906 commit e06d193

5 files changed

Lines changed: 333 additions & 97 deletions

File tree

src/compile.ts

Lines changed: 185 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import type { ASTNode } from "./ast-types";
22
import { generate, transformIdentifiers, transformPlaceholders } from "./generate";
33
import { serializeArgumentToAST } from "./proxy-variable";
4-
import type { BranchNode, CompiledData, CompiledExpression, ExprValue, JumpNode, PhiNode } from "./types";
4+
import type { BranchNode, CompiledData, CompiledExpression, ExprValue, FnNode, JumpNode, PhiNode } from "./types";
55
import { getVariableId } from "./variable";
66

77
const ALLOWED_GLOBALS = new Set([
@@ -46,6 +46,20 @@ const ALLOWED_GLOBALS = new Set([
4646
*/
4747
export interface CompileOptions {}
4848

49+
/**
50+
* 编译上下文(用于 lambda 参数分配)
51+
*/
52+
interface CompileCtx {
53+
/** 下一个可用的参数名索引(全局递增) */
54+
nextParamIndex: number;
55+
/** 当前正在编译的表达式列表(栈顶为当前 lambda 或顶层) */
56+
expressionStack: CompiledExpression[][];
57+
/** 下一个全局表达式索引 */
58+
nextIndex: number;
59+
/** 变量数量(用于索引偏移) */
60+
variableCount: number;
61+
}
62+
4963
/**
5064
* 将 Proxy Expression 编译为可序列化的 JSON 结构
5165
*
@@ -106,6 +120,9 @@ export function compile<TResult>(
106120
// 已经转换为 $[N] 的跳过
107121
if (name.startsWith("$[") && /^\$\[\d+\]$/.test(name)) return name;
108122

123+
// lambda 参数名 _N 跳过
124+
if (/^_\d+$/.test(name)) return name;
125+
109126
const index = variableToIndex.get(name);
110127
if (index !== undefined) return `$[${index}]`;
111128

@@ -119,70 +136,191 @@ export function compile<TResult>(
119136
}
120137

121138
// 生成编译后的表达式(短路求值总是启用)
122-
const expressions: CompiledExpression[] = [];
123-
let nextIndex = variableOrder.length;
139+
const topLevelExprs: CompiledExpression[] = [];
140+
const ctx: CompileCtx = {
141+
nextParamIndex: 0,
142+
expressionStack: [topLevelExprs],
143+
nextIndex: variableOrder.length,
144+
variableCount: variableOrder.length,
145+
};
124146

125-
function compileAst(node: ASTNode): number {
126-
if (node.type === "BinaryExpr" && (node.operator === "||" || node.operator === "&&" || node.operator === "??")) {
127-
return compileShortCircuit(node);
128-
}
129-
if (node.type === "ConditionalExpr") {
130-
return compileConditional(node);
147+
compileAst(transformed, ctx);
148+
149+
return [variableOrder, ...topLevelExprs];
150+
}
151+
152+
function currentExprs(ctx: CompileCtx): CompiledExpression[] {
153+
return ctx.expressionStack[ctx.expressionStack.length - 1]!;
154+
}
155+
156+
/**
157+
* 提取 AST 中所有 ArrowFunctionExpr 节点,编译为 FnNode,
158+
* 并将原始位置替换为 $[N] 标识符引用。
159+
*/
160+
function extractAndCompileArrowFunctions(node: ASTNode, ctx: CompileCtx): ASTNode {
161+
switch (node.type) {
162+
case "ArrowFunctionExpr": {
163+
// 编译这个箭头函数为 FnNode,返回 $[N] 引用
164+
const idx = compileArrowFunction(node, ctx);
165+
return { type: "Identifier", name: `$[${idx}]` };
131166
}
132-
const exprStr = generate(node);
133-
expressions.push(exprStr);
134-
return nextIndex++;
135-
}
136167

137-
function compileShortCircuit(node: ASTNode & { type: "BinaryExpr" }): number {
138-
const leftIdx = compileAst(node.left);
168+
case "BinaryExpr":
169+
return {
170+
...node,
171+
left: extractAndCompileArrowFunctions(node.left, ctx),
172+
right: extractAndCompileArrowFunctions(node.right, ctx),
173+
};
174+
175+
case "UnaryExpr":
176+
return {
177+
...node,
178+
argument: extractAndCompileArrowFunctions(node.argument, ctx),
179+
};
180+
181+
case "ConditionalExpr":
182+
return {
183+
...node,
184+
test: extractAndCompileArrowFunctions(node.test, ctx),
185+
consequent: extractAndCompileArrowFunctions(node.consequent, ctx),
186+
alternate: extractAndCompileArrowFunctions(node.alternate, ctx),
187+
};
139188

140-
const branchConditions: Record<string, string> = {
141-
"||": `$[${leftIdx}]`,
142-
"&&": `!$[${leftIdx}]`,
143-
"??": `$[${leftIdx}]!=null`,
144-
};
189+
case "MemberExpr":
190+
return {
191+
...node,
192+
object: extractAndCompileArrowFunctions(node.object, ctx),
193+
property: node.computed ? extractAndCompileArrowFunctions(node.property, ctx) : node.property,
194+
};
145195

146-
const branchIdx = expressions.length;
147-
expressions.push(["br", branchConditions[node.operator], 0] as BranchNode);
148-
nextIndex++;
196+
case "CallExpr":
197+
return {
198+
...node,
199+
callee: extractAndCompileArrowFunctions(node.callee, ctx),
200+
arguments: node.arguments.map((arg) => extractAndCompileArrowFunctions(arg, ctx)),
201+
};
149202

150-
compileAst(node.right);
151-
const skipCount = expressions.length - branchIdx - 1;
152-
(expressions[branchIdx] as BranchNode)[2] = skipCount;
203+
case "ArrayExpr":
204+
return {
205+
...node,
206+
elements: node.elements.map((el) => extractAndCompileArrowFunctions(el, ctx)),
207+
};
153208

154-
const phiIdx = nextIndex++;
155-
expressions.push(["phi"] as PhiNode);
209+
case "ObjectExpr":
210+
return {
211+
...node,
212+
properties: node.properties.map((prop) => ({
213+
...prop,
214+
key: prop.computed ? extractAndCompileArrowFunctions(prop.key, ctx) : prop.key,
215+
value: extractAndCompileArrowFunctions(prop.value, ctx),
216+
})),
217+
};
156218

157-
return phiIdx;
219+
default:
220+
return node;
158221
}
222+
}
159223

160-
function compileConditional(node: ASTNode & { type: "ConditionalExpr" }): number {
161-
const testIdx = compileAst(node.test);
224+
function compileAst(node: ASTNode, ctx: CompileCtx): number {
225+
if (node.type === "BinaryExpr" && (node.operator === "||" || node.operator === "&&" || node.operator === "??")) {
226+
return compileShortCircuit(node, ctx);
227+
}
228+
if (node.type === "ConditionalExpr") {
229+
return compileConditional(node, ctx);
230+
}
231+
if (node.type === "ArrowFunctionExpr") {
232+
return compileArrowFunction(node, ctx);
233+
}
162234

163-
const branchIdx = expressions.length;
164-
expressions.push(["br", `$[${testIdx}]`, 0] as BranchNode);
165-
nextIndex++;
235+
// 提取并编译嵌套的箭头函数,替换为 $[N] 引用
236+
const processed = extractAndCompileArrowFunctions(node, ctx);
166237

167-
compileAst(node.alternate);
238+
const exprStr = generate(processed);
239+
currentExprs(ctx).push(exprStr);
240+
return ctx.nextIndex++;
241+
}
168242

169-
const jmpIdx = expressions.length;
170-
expressions.push(["jmp", 0] as JumpNode);
171-
nextIndex++;
243+
function compileShortCircuit(node: ASTNode & { type: "BinaryExpr" }, ctx: CompileCtx): number {
244+
const exprs = currentExprs(ctx);
245+
const leftIdx = compileAst(node.left, ctx);
172246

173-
compileAst(node.consequent);
174-
const thenEndIdx = expressions.length;
247+
const branchConditions: Record<string, string> = {
248+
"||": `$[${leftIdx}]`,
249+
"&&": `!$[${leftIdx}]`,
250+
"??": `$[${leftIdx}]!=null`,
251+
};
175252

176-
(expressions[branchIdx] as BranchNode)[2] = jmpIdx - branchIdx;
177-
(expressions[jmpIdx] as JumpNode)[1] = thenEndIdx - jmpIdx - 1;
253+
const branchIdx = exprs.length;
254+
exprs.push(["br", branchConditions[node.operator], 0] as BranchNode);
255+
ctx.nextIndex++;
178256

179-
const phiIdx = nextIndex++;
180-
expressions.push(["phi"] as PhiNode);
257+
compileAst(node.right, ctx);
258+
const skipCount = exprs.length - branchIdx - 1;
259+
(exprs[branchIdx] as BranchNode)[2] = skipCount;
181260

182-
return phiIdx;
261+
const phiIdx = ctx.nextIndex++;
262+
exprs.push(["phi"] as PhiNode);
263+
264+
return phiIdx;
265+
}
266+
267+
function compileConditional(node: ASTNode & { type: "ConditionalExpr" }, ctx: CompileCtx): number {
268+
const exprs = currentExprs(ctx);
269+
const testIdx = compileAst(node.test, ctx);
270+
271+
const branchIdx = exprs.length;
272+
exprs.push(["br", `$[${testIdx}]`, 0] as BranchNode);
273+
ctx.nextIndex++;
274+
275+
compileAst(node.alternate, ctx);
276+
277+
const jmpIdx = exprs.length;
278+
exprs.push(["jmp", 0] as JumpNode);
279+
ctx.nextIndex++;
280+
281+
compileAst(node.consequent, ctx);
282+
const thenEndIdx = exprs.length;
283+
284+
(exprs[branchIdx] as BranchNode)[2] = jmpIdx - branchIdx;
285+
(exprs[jmpIdx] as JumpNode)[1] = thenEndIdx - jmpIdx - 1;
286+
287+
const phiIdx = ctx.nextIndex++;
288+
exprs.push(["phi"] as PhiNode);
289+
290+
return phiIdx;
291+
}
292+
293+
function compileArrowFunction(node: ASTNode & { type: "ArrowFunctionExpr" }, ctx: CompileCtx): number {
294+
const paramCount = node.params.length;
295+
296+
// 0. Claim this FnNode's global index first
297+
const fnIndex = ctx.nextIndex++;
298+
299+
// 1. 为参数分配 _N 名称
300+
const paramMapping = new Map<symbol, string>();
301+
for (const param of node.params) {
302+
if (param.type === "Placeholder") {
303+
const paramName = `_${ctx.nextParamIndex++}`;
304+
paramMapping.set(param.id, paramName);
305+
}
183306
}
184307

185-
compileAst(transformed);
308+
// 2. 将函数体中的 lambda 参数 Placeholder 转换为 _N 标识符
309+
const transformedBody = transformPlaceholders(node.body, (id) => {
310+
return paramMapping.get(id) ?? null;
311+
});
312+
313+
// 3. 编译函数体到新的表达式列表
314+
const lambdaStmts: CompiledExpression[] = [];
315+
ctx.expressionStack.push(lambdaStmts);
316+
317+
compileAst(transformedBody, ctx);
318+
319+
ctx.expressionStack.pop();
320+
321+
// 4. 构造 FnNode
322+
const fnNode: FnNode = ["fn", paramCount, ...lambdaStmts];
323+
currentExprs(ctx).push(fnNode);
186324

187-
return [variableOrder, ...expressions];
325+
return fnIndex;
188326
}

0 commit comments

Comments
 (0)