@@ -1207,20 +1207,29 @@ def find_function_node(node, target_name: str):
12071207 return node
12081208
12091209 # Check function declarations
1210- if node .type in ("function_declaration" , "function" ):
1210+ if node .type in (
1211+ "function_declaration" ,
1212+ "function" ,
1213+ "generator_function_declaration" ,
1214+ "generator_function" ,
1215+ ):
12111216 name_node = node .child_by_field_name ("name" )
12121217 if name_node :
12131218 name = source_bytes [name_node .start_byte : name_node .end_byte ].decode ("utf8" )
12141219 if name == target_name :
12151220 return node
12161221
1217- # Check arrow functions assigned to variables
1218- if node .type == "lexical_declaration" :
1222+ # Check arrow functions and function expressions assigned to variables
1223+ if node .type in ( "lexical_declaration" , "variable_declaration" ) :
12191224 for child in node .children :
12201225 if child .type == "variable_declarator" :
12211226 name_node = child .child_by_field_name ("name" )
12221227 value_node = child .child_by_field_name ("value" )
1223- if name_node and value_node and value_node .type == "arrow_function" :
1228+ if (
1229+ name_node
1230+ and value_node
1231+ and value_node .type in ("arrow_function" , "function_expression" , "generator_function" )
1232+ ):
12241233 name = source_bytes [name_node .start_byte : name_node .end_byte ].decode ("utf8" )
12251234 if name == target_name :
12261235 return value_node
@@ -1235,6 +1244,7 @@ def find_function_node(node, target_name: str):
12351244
12361245 func_node = find_function_node (tree .root_node , function_name )
12371246 if not func_node :
1247+ logger .debug ("Could not find function '%s' in optimized code for body extraction" , function_name )
12381248 return None
12391249
12401250 # Find the body node
@@ -1295,14 +1305,21 @@ def find_function_at_line(node, target_name: str, target_line: int):
12951305 if name == target_name and (node .start_point [0 ] + 1 ) == target_line :
12961306 return node
12971307
1298- if node .type == "lexical_declaration" :
1308+ if node .type in ( "lexical_declaration" , "variable_declaration" ) :
12991309 for child in node .children :
13001310 if child .type == "variable_declarator" :
13011311 name_node = child .child_by_field_name ("name" )
13021312 value_node = child .child_by_field_name ("value" )
1303- if name_node and value_node and value_node .type == "arrow_function" :
1313+ if (
1314+ name_node
1315+ and value_node
1316+ and value_node .type in ("arrow_function" , "function_expression" , "generator_function" )
1317+ ):
13041318 name = source_bytes [name_node .start_byte : name_node .end_byte ].decode ("utf8" )
1305- if name == target_name and (node .start_point [0 ] + 1 ) == target_line :
1319+ if name == target_name and (
1320+ (node .start_point [0 ] + 1 ) == target_line
1321+ or (value_node .start_point [0 ] + 1 ) == target_line
1322+ ):
13061323 return value_node
13071324
13081325 for child in node .children :
@@ -1686,26 +1703,14 @@ def validate_syntax(self, source: str) -> bool:
16861703 return False
16871704
16881705 def normalize_code (self , source : str ) -> str :
1689- """Normalize JavaScript code for deduplication.
1690-
1691- Removes comments and normalizes whitespace.
1692-
1693- Args:
1694- source: Source code to normalize.
1695-
1696- Returns:
1697- Normalized source code.
1706+ """Normalize JavaScript code for deduplication using tree-sitter."""
1707+ from codeflash .languages .javascript .normalizer import normalize_js_code
16981708
1699- """
1700- # Simple normalization: remove extra whitespace
1701- # A full implementation would use tree-sitter to strip comments
1702- lines = source .splitlines ()
1703- normalized_lines = []
1704- for line in lines :
1705- stripped = line .strip ()
1706- if stripped and not stripped .startswith ("//" ):
1707- normalized_lines .append (stripped )
1708- return "\n " .join (normalized_lines )
1709+ try :
1710+ is_ts = self .treesitter_language == TreeSitterLanguage .TYPESCRIPT
1711+ return normalize_js_code (source , typescript = is_ts )
1712+ except Exception :
1713+ return source
17091714
17101715 def generate_concolic_tests (
17111716 self , test_cfg : Any , project_root : Any , function_to_optimize : Any , function_to_optimize_ast : Any
0 commit comments