Skip to content

Commit 353feab

Browse files
committed
[Fix] Normalizer and expand its scope
1 parent 5d872e8 commit 353feab

5 files changed

Lines changed: 289 additions & 43 deletions

File tree

codeflash/languages/javascript/normalizer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""JavaScript/TypeScript code normalizer using tree-sitter.
22
3-
Not currently wired into JavaScriptSupport.normalize_code — kept as a
4-
ready-to-use upgrade path when AST-based JS deduplication is needed.
3+
Wired into JavaScriptSupport.normalize_code for AST-based JS deduplication.
54
65
The old CodeNormalizer ABC (deleted from base.py) is preserved below for reference.
76
"""
@@ -236,8 +235,7 @@ def normalize_js_code(code: str, typescript: bool = False) -> str:
236235
Uses tree-sitter to parse and normalize variable names. Falls back to
237236
basic comment/whitespace stripping if tree-sitter is unavailable or parsing fails.
238237
239-
Not currently wired into JavaScriptSupport.normalize_code — kept as a
240-
ready-to-use upgrade path when AST-based JS deduplication is needed.
238+
Wired into JavaScriptSupport.normalize_code for AST-based JS deduplication.
241239
"""
242240
try:
243241
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage

codeflash/languages/javascript/support.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1207,20 +1207,29 @@ def find_function_node(node, target_name: str):
12071207
return node
12081208

12091209
# Check function declarations
1210-
if node.type in ("function_declaration", "function"):
1210+
if node.type in (
1211+
"function_declaration",
1212+
"function",
1213+
"generator_function_declaration",
1214+
"generator_function",
1215+
):
12111216
name_node = node.child_by_field_name("name")
12121217
if name_node:
12131218
name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
12141219
if name == target_name:
12151220
return node
12161221

1217-
# Check arrow functions assigned to variables
1218-
if node.type == "lexical_declaration":
1222+
# Check arrow functions and function expressions assigned to variables
1223+
if node.type in ("lexical_declaration", "variable_declaration"):
12191224
for child in node.children:
12201225
if child.type == "variable_declarator":
12211226
name_node = child.child_by_field_name("name")
12221227
value_node = child.child_by_field_name("value")
1223-
if name_node and value_node and value_node.type == "arrow_function":
1228+
if (
1229+
name_node
1230+
and value_node
1231+
and value_node.type in ("arrow_function", "function_expression", "generator_function")
1232+
):
12241233
name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
12251234
if name == target_name:
12261235
return value_node
@@ -1235,6 +1244,7 @@ def find_function_node(node, target_name: str):
12351244

12361245
func_node = find_function_node(tree.root_node, function_name)
12371246
if not func_node:
1247+
logger.debug("Could not find function '%s' in optimized code for body extraction", function_name)
12381248
return None
12391249

12401250
# Find the body node
@@ -1295,14 +1305,21 @@ def find_function_at_line(node, target_name: str, target_line: int):
12951305
if name == target_name and (node.start_point[0] + 1) == target_line:
12961306
return node
12971307

1298-
if node.type == "lexical_declaration":
1308+
if node.type in ("lexical_declaration", "variable_declaration"):
12991309
for child in node.children:
13001310
if child.type == "variable_declarator":
13011311
name_node = child.child_by_field_name("name")
13021312
value_node = child.child_by_field_name("value")
1303-
if name_node and value_node and value_node.type == "arrow_function":
1313+
if (
1314+
name_node
1315+
and value_node
1316+
and value_node.type in ("arrow_function", "function_expression", "generator_function")
1317+
):
13041318
name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
1305-
if name == target_name and (node.start_point[0] + 1) == target_line:
1319+
if name == target_name and (
1320+
(node.start_point[0] + 1) == target_line
1321+
or (value_node.start_point[0] + 1) == target_line
1322+
):
13061323
return value_node
13071324

13081325
for child in node.children:
@@ -1686,26 +1703,14 @@ def validate_syntax(self, source: str) -> bool:
16861703
return False
16871704

16881705
def normalize_code(self, source: str) -> str:
1689-
"""Normalize JavaScript code for deduplication.
1690-
1691-
Removes comments and normalizes whitespace.
1692-
1693-
Args:
1694-
source: Source code to normalize.
1695-
1696-
Returns:
1697-
Normalized source code.
1706+
"""Normalize JavaScript code for deduplication using tree-sitter."""
1707+
from codeflash.languages.javascript.normalizer import normalize_js_code
16981708

1699-
"""
1700-
# Simple normalization: remove extra whitespace
1701-
# A full implementation would use tree-sitter to strip comments
1702-
lines = source.splitlines()
1703-
normalized_lines = []
1704-
for line in lines:
1705-
stripped = line.strip()
1706-
if stripped and not stripped.startswith("//"):
1707-
normalized_lines.append(stripped)
1708-
return "\n".join(normalized_lines)
1709+
try:
1710+
is_ts = self.treesitter_language == TreeSitterLanguage.TYPESCRIPT
1711+
return normalize_js_code(source, typescript=is_ts)
1712+
except Exception:
1713+
return source
17091714

17101715
def generate_concolic_tests(
17111716
self, test_cfg: Any, project_root: Any, function_to_optimize: Any, function_to_optimize_ast: Any

tests/test_code_deduplication.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from codeflash.languages.javascript.normalizer import normalize_js_code
12
from codeflash.languages.python.normalizer import normalize_python_code as normalize_code
23

34

@@ -133,3 +134,74 @@ def safe_divide(a, b):
133134
assert normalize_code(code9) == normalize_code(code10)
134135

135136
assert normalize_code(code9) != normalize_code(code8)
137+
138+
139+
# === JavaScript deduplication tests ===
140+
141+
142+
def test_js_deduplicate_same_logic_different_vars():
143+
code1 = """
144+
function process(items) {
145+
const result = [];
146+
for (const item of items) {
147+
result.push(item * 2);
148+
}
149+
return result;
150+
}
151+
"""
152+
code2 = """
153+
function process(items) {
154+
const output = [];
155+
for (const val of items) {
156+
output.push(val * 2);
157+
}
158+
return output;
159+
}
160+
"""
161+
assert normalize_js_code(code1) == normalize_js_code(code2)
162+
163+
164+
def test_js_different_logic_not_deduplicated():
165+
code1 = """
166+
function compute(x) {
167+
return x + 1;
168+
}
169+
"""
170+
code2 = """
171+
function compute(x) {
172+
return x * 2;
173+
}
174+
"""
175+
assert normalize_js_code(code1) != normalize_js_code(code2)
176+
177+
178+
def test_js_deduplicate_whitespace_and_comments():
179+
code1 = """
180+
function add(a, b) {
181+
// fast path
182+
return a + b;
183+
}
184+
"""
185+
code2 = """
186+
function add(a, b) {
187+
/* optimized */
188+
return a + b;
189+
}
190+
"""
191+
assert normalize_js_code(code1) == normalize_js_code(code2)
192+
193+
194+
def test_ts_normalize():
195+
code1 = """
196+
function greet(name: string): string {
197+
const msg = "hello " + name;
198+
return msg;
199+
}
200+
"""
201+
code2 = """
202+
function greet(name: string): string {
203+
const result = "hello " + name;
204+
return result;
205+
}
206+
"""
207+
assert normalize_js_code(code1, typescript=True) == normalize_js_code(code2, typescript=True)

tests/test_languages/test_javascript_support.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -443,10 +443,10 @@ def test_syntax_error_types(self, js_support):
443443

444444

445445
class TestNormalizeCode:
446-
"""Tests for normalize_code method."""
446+
"""Tests for normalize_code method using tree-sitter normalizer."""
447447

448448
def test_removes_comments(self, js_support):
449-
"""Test that single-line comments are removed."""
449+
"""Test that comments are absent from normalized output."""
450450
code = """
451451
function add(a, b) {
452452
// Add two numbers
@@ -455,19 +455,43 @@ def test_removes_comments(self, js_support):
455455
"""
456456
normalized = js_support.normalize_code(code)
457457
assert "// Add two numbers" not in normalized
458-
assert "return a + b" in normalized
458+
assert "Add two numbers" not in normalized
459459

460-
def test_preserves_functionality(self, js_support):
461-
"""Test that code functionality is preserved."""
462-
code = """
463-
function add(a, b) {
464-
// Comment
465-
return a + b;
460+
def test_same_logic_different_vars_are_equal(self, js_support):
461+
"""Test that two functions with same logic but different variable names normalize identically."""
462+
code1 = """
463+
function process(items) {
464+
const result = [];
465+
for (const item of items) {
466+
result.push(item * 2);
467+
}
468+
return result;
466469
}
467470
"""
468-
normalized = js_support.normalize_code(code)
469-
assert "function add" in normalized
470-
assert "return" in normalized
471+
code2 = """
472+
function process(items) {
473+
const output = [];
474+
for (const val of items) {
475+
output.push(val * 2);
476+
}
477+
return output;
478+
}
479+
"""
480+
assert js_support.normalize_code(code1) == js_support.normalize_code(code2)
481+
482+
def test_different_logic_not_equal(self, js_support):
483+
"""Test that two functions with different logic produce different normalized forms."""
484+
code1 = """
485+
function compute(x) {
486+
return x + 1;
487+
}
488+
"""
489+
code2 = """
490+
function compute(x) {
491+
return x * 2;
492+
}
493+
"""
494+
assert js_support.normalize_code(code1) != js_support.normalize_code(code2)
471495

472496

473497
class TestExtractCodeContext:

0 commit comments

Comments
 (0)