Skip to content

Commit f7b7bb5

Browse files
committed
refactor get_program_slice to take nodeID
1 parent 2169730 commit f7b7bb5

2 files changed

Lines changed: 171 additions & 141 deletions

File tree

src/tools/mcp_tools.py

Lines changed: 97 additions & 112 deletions
Original file line number | Diff line number | Diff line change
@@ -2508,16 +2508,15 @@ async def check_method_reachability(
25082508
@mcp.tool()
25092509
async def get_program_slice(
25102510
session_id: str,
2511-
filename: str,
2512-
line_number: int,
2513-
call_name: Optional[str] = None,
2511+
node_id: Optional[str] = None,
2512+
location: Optional[str] = None,
25142513
include_dataflow: bool = True,
25152514
include_control_flow: bool = True,
25162515
max_depth: int = 5,
25172516
timeout: int = 60,
25182517
) -> Dict[str, Any]:
25192518
"""
2520-
Build a program slice from a specific line or call.
2519+
Build a program slice from a specific call node.
25212520
25222521
Creates a backward program slice showing all code that could affect the
25232522
execution at a specific point. This includes:
@@ -2526,12 +2525,16 @@ async def get_program_slice(
25262525
- Control flow: conditions that determine whether the call executes
25272526
- Call graph: functions called and their data dependencies
25282527
2528+
**Important**: Use node IDs (from list_calls) or specify exact locations to
2529+
avoid ambiguity, especially when multiple calls appear on the same line.
2530+
25292531
Args:
25302532
session_id: The session ID from create_cpg_session
2531-
filename: Source file containing the line of interest
2532-
line_number: Line number of the call/statement to slice from
2533-
call_name: Optional: specific call name to slice (e.g., "memcpy", "strcpy")
2534-
If not provided, slices all calls on that line
2533+
node_id: Preferred: Direct CPG node ID of the target call
2534+
(Get from list_calls or other query results)
2535+
Example: "12345"
2536+
location: Alternative: "filename:line_number" or "filename:line_number:call_name"
2537+
Example: "main.c:42" or "main.c:42:memcpy"
25352538
include_dataflow: Include dataflow (variable assignments) in slice (default: true)
25362539
include_control_flow: Include control dependencies (if/while conditions) (default: true)
25372540
max_depth: Maximum depth for dataflow tracking (default: 5)
@@ -2542,31 +2545,30 @@ async def get_program_slice(
25422545
"success": true,
25432546
"slice": {
25442547
"target_call": {
2548+
"node_id": "12345",
25452549
"name": "memcpy",
25462550
"code": "memcpy(buf, src, size)",
25472551
"filename": "main.c",
25482552
"lineNumber": 42,
2553+
"method": "process_data",
25492554
"arguments": ["buf", "src", "size"]
25502555
},
25512556
"dataflow": [
25522557
{
25532558
"variable": "buf",
2554-
"definition": "char buf[256]",
2559+
"code": "char buf[256]",
25552560
"filename": "main.c",
25562561
"lineNumber": 10,
2557-
"dependencies": [...]
2562+
"method": "process_data"
25582563
}
25592564
],
25602565
"control_dependencies": [
25612566
{
2562-
"condition": "if (user_input != NULL)",
2567+
"code": "if (user_input != NULL)",
25632568
"filename": "main.c",
2564-
"lineNumber": 35
2569+
"lineNumber": 35,
2570+
"method": "process_data"
25652571
}
2566-
],
2567-
"call_graph": [
2568-
{"from": "main", "to": "memcpy", "depth": 1},
2569-
{"from": "memcpy", "to": "__memcpy", "depth": 2}
25702572
]
25712573
},
25722574
"total_nodes": 15
@@ -2575,6 +2577,10 @@ async def get_program_slice(
25752577
try:
25762578
validate_session_id(session_id)
25772579

2580+
# Validate that we have proper node identification
2581+
if not node_id and not location:
2582+
raise ValidationError("Either node_id or location must be provided")
2583+
25782584
session_manager = services["session_manager"]
25792585
query_executor = services["query_executor"]
25802586

@@ -2587,159 +2593,138 @@ async def get_program_slice(
25872593

25882594
await session_manager.touch_session(session_id)
25892595

2590-
# Step 1: Find the target call(s) at the specified location
2591-
call_filter = f'.name("{call_name}")' if call_name else ""
2592-
target_query = (
2593-
f'cpg.call{call_filter}'
2594-
f'.where(_.file.name(".*{re.escape(filename)}.*"))'
2595-
f'.lineNumber({line_number})'
2596-
f'.map(c => (c.name, c.code, c.file.name.headOption.getOrElse("unknown"), c.lineNumber.getOrElse(-1), '
2597-
f'c.argument.l.map(_.code), c.method.name))'
2598-
)
2599-
2600-
target_result = await query_executor.execute_query(
2596+
# Step 1: Resolve target call node
2597+
target_call = None
2598+
2599+
if node_id:
2600+
# Direct node ID lookup - most efficient and unambiguous
2601+
query = (
2602+
f'cpg.id({node_id}).map(c => (c.id, c.name, c.code, c.file.name.headOption.getOrElse("unknown"), '
2603+
f'c.lineNumber.getOrElse(-1), c.method.name, c.argument.code.l)).headOption'
2604+
)
2605+
else:
2606+
# Parse location string to find call
2607+
parts = location.split(":")
2608+
if len(parts) < 2:
2609+
raise ValidationError("location must be in format 'filename:line' or 'filename:line:callname'")
2610+
2611+
filename = parts[0]
2612+
line_num = parts[1]
2613+
call_name = parts[2] if len(parts) > 2 else None
2614+
2615+
# Build query to find call at location
2616+
if call_name:
2617+
query = (
2618+
f'cpg.file.name(".*{re.escape(filename)}.*").call.name("{call_name}").lineNumber({line_num})'
2619+
f'.map(c => (c.id, c.name, c.code, c.file.name.headOption.getOrElse("unknown"), '
2620+
f'c.lineNumber.getOrElse(-1), c.method.name, c.argument.code.l)).headOption'
2621+
)
2622+
else:
2623+
query = (
2624+
f'cpg.file.name(".*{re.escape(filename)}.*").call.lineNumber({line_num})'
2625+
f'.map(c => (c.id, c.name, c.code, c.file.name.headOption.getOrElse("unknown"), '
2626+
f'c.lineNumber.getOrElse(-1), c.method.name, c.argument.code.l)).headOption'
2627+
)
2628+
2629+
result = await query_executor.execute_query(
26012630
session_id=session_id,
26022631
cpg_path="/workspace/cpg.bin",
2603-
query=target_query,
2604-
timeout=30,
2605-
limit=10,
2632+
query=query,
2633+
timeout=10,
2634+
limit=1,
26062635
)
26072636

2608-
if not target_result.success or not target_result.data:
2637+
if not result.success or not result.data or not result.data[0].get("_1"):
26092638
return {
26102639
"success": False,
26112640
"error": {
26122641
"code": "NOT_FOUND",
2613-
"message": f"No call found at {filename}:{line_number}"
2642+
"message": f"Call not found: node_id={node_id}, location={location}"
26142643
},
26152644
}
26162645

2617-
# Parse target call information
2618-
target_item = target_result.data[0]
2619-
if not isinstance(target_item, dict):
2620-
return {
2621-
"success": False,
2622-
"error": {"code": "PARSE_ERROR", "message": "Invalid target call data"},
2623-
}
2624-
2646+
item = result.data[0]
26252647
target_call = {
2626-
"name": target_item.get("_1", ""),
2627-
"code": target_item.get("_2", ""),
2628-
"filename": target_item.get("_3", ""),
2629-
"lineNumber": target_item.get("_4", -1),
2630-
"arguments": target_item.get("_5", []),
2631-
"method": target_item.get("_6", ""),
2648+
"node_id": item.get("_1"),
2649+
"name": item.get("_2", ""),
2650+
"code": item.get("_3", ""),
2651+
"filename": item.get("_4", ""),
2652+
"lineNumber": item.get("_5", -1),
2653+
"method": item.get("_6", ""),
2654+
"arguments": item.get("_7", []),
26322655
}
26332656

26342657
slice_result = {
26352658
"target_call": target_call,
26362659
"dataflow": [],
26372660
"control_dependencies": [],
2638-
"call_graph": [],
26392661
}
26402662

26412663
# Step 2: Get dataflow for arguments
26422664
if include_dataflow and target_call["arguments"]:
2643-
# Build query to track dataflow for each argument using reachableByFlows
26442665
for arg in target_call["arguments"]:
2645-
# Clean up argument (remove operators, casts, etc)
2646-
clean_arg = arg.strip()
2647-
if not clean_arg or clean_arg.isdigit() or clean_arg.startswith('"'):
2648-
continue # Skip literals
2666+
# Clean up argument
2667+
clean_arg = arg.strip().replace("\"", "")
2668+
if not clean_arg or clean_arg.isdigit() or clean_arg.startswith("(") or clean_arg.startswith("0x"):
2669+
continue
26492670

2650-
# Query for dataflow using reachableByFlows from identifiers to the call argument
2671+
# Find identifiers with this name and their definitions
26512672
dataflow_query = (
2652-
f'val sources = cpg.identifier.name("{re.escape(clean_arg)}").l\n'
2653-
f'val sink = cpg.call.where(_.file.name(".*{re.escape(filename)}.*"))'
2654-
f'.lineNumber({line_number})'
2655-
f'.argument.code(".*{re.escape(clean_arg)}.*").l\n'
2656-
f'if (sources.nonEmpty && sink.nonEmpty) {{\n'
2657-
f' sink.reachableByFlows(sources).map(flow => {{\n'
2658-
f' val elems = flow.elements\n'
2659-
f' elems.take(15).map(e => (e.code, e.file.name.headOption.getOrElse("unknown"), '
2660-
f'e.lineNumber.getOrElse(-1), e.method.name))\n'
2661-
f' }}).take({max_depth}).l.flatten\n'
2662-
f'}} else List()'
2673+
f'cpg.identifier.name("{re.escape(clean_arg)}").l.take(10).map(id => '
2674+
f'(id.code, id.file.name.headOption.getOrElse("unknown"), '
2675+
f'id.lineNumber.getOrElse(-1), id.method.name))'
26632676
)
26642677

26652678
dataflow_result = await query_executor.execute_query(
26662679
session_id=session_id,
26672680
cpg_path="/workspace/cpg.bin",
26682681
query=dataflow_query,
2669-
timeout=20,
2670-
limit=50,
2682+
timeout=15,
2683+
limit=20,
26712684
)
26722685

26732686
if dataflow_result.success and dataflow_result.data:
2674-
for item in dataflow_result.data:
2675-
if isinstance(item, dict):
2687+
for dflow_item in dataflow_result.data[:5]: # Limit to 5 per argument
2688+
if isinstance(dflow_item, dict):
26762689
slice_result["dataflow"].append({
26772690
"variable": clean_arg,
2678-
"definition": item.get("_1", ""),
2679-
"filename": item.get("_2", ""),
2680-
"lineNumber": item.get("_3", -1),
2681-
"method": item.get("_4", ""),
2691+
"code": dflow_item.get("_1", ""),
2692+
"filename": dflow_item.get("_2", ""),
2693+
"lineNumber": dflow_item.get("_3", -1),
2694+
"method": dflow_item.get("_4", ""),
26822695
})
26832696

26842697
# Step 3: Get control dependencies
26852698
if include_control_flow:
2686-
# Query for control dependencies using controlledBy
2699+
# Query using node ID for precise control dependency lookup
26872700
control_query = (
2688-
f'cpg.call.where(_.file.name(".*{re.escape(filename)}.*"))'
2689-
f'.lineNumber({line_number})'
2690-
f'.controlledBy'
2691-
f'.map(n => (n.code, n.file.name.headOption.getOrElse("unknown"), n.lineNumber.getOrElse(-1)))'
2692-
f'.dedup.take(20)'
2701+
f'cpg.id({target_call["node_id"]}).controlledBy.map(ctrl => '
2702+
f'(ctrl.code, ctrl.file.name.headOption.getOrElse("unknown"), '
2703+
f'ctrl.lineNumber.getOrElse(-1), ctrl.method.name)).dedup.take(20)'
26932704
)
26942705

26952706
control_result = await query_executor.execute_query(
26962707
session_id=session_id,
26972708
cpg_path="/workspace/cpg.bin",
26982709
query=control_query,
2699-
timeout=20,
2710+
timeout=15,
27002711
limit=20,
27012712
)
27022713

27032714
if control_result.success and control_result.data:
2704-
for item in control_result.data:
2705-
if isinstance(item, dict):
2715+
for ctrl_item in control_result.data:
2716+
if isinstance(ctrl_item, dict):
27062717
slice_result["control_dependencies"].append({
2707-
"condition": item.get("_1", ""),
2708-
"filename": item.get("_2", ""),
2709-
"lineNumber": item.get("_3", -1),
2710-
})
2711-
2712-
# Step 4: Get call graph (outgoing calls from the target call's method)
2713-
if target_call["method"]:
2714-
callgraph_query = (
2715-
f'cpg.method.name("{target_call["method"]}")'
2716-
f'.call.nameNot("<operator>.*")'
2717-
f'.map(c => ("{target_call["method"]}", c.name, 1))'
2718-
f'.dedup.take(30)'
2719-
)
2720-
2721-
callgraph_result = await query_executor.execute_query(
2722-
session_id=session_id,
2723-
cpg_path="/workspace/cpg.bin",
2724-
query=callgraph_query,
2725-
timeout=20,
2726-
limit=30,
2727-
)
2728-
2729-
if callgraph_result.success and callgraph_result.data:
2730-
for item in callgraph_result.data:
2731-
if isinstance(item, dict):
2732-
slice_result["call_graph"].append({
2733-
"from": item.get("_1", ""),
2734-
"to": item.get("_2", ""),
2735-
"depth": item.get("_3", 1),
2718+
"code": ctrl_item.get("_1", ""),
2719+
"filename": ctrl_item.get("_2", ""),
2720+
"lineNumber": ctrl_item.get("_3", -1),
2721+
"method": ctrl_item.get("_4", ""),
27362722
})
27372723

27382724
total_nodes = (
27392725
1 + # target call
27402726
len(slice_result["dataflow"]) +
2741-
len(slice_result["control_dependencies"]) +
2742-
len(slice_result["call_graph"])
2727+
len(slice_result["control_dependencies"])
27432728
)
27442729

27452730
return {

0 commit comments

Comments
 (0)