@@ -33,6 +33,73 @@ def method_is_not_in_class(self, method_name: str, class_body: str) -> bool:
3333
3434 return method_name not in {method .node .text .decode () for method in methods_in_class }
3535
36+ def get_all_imports (self , source_code : str ) -> Set [str ]:
37+ """Get a list of all the imports in a class.
38+
39+ Args:
40+ source_code (str): The source code to process.
41+
42+ Returns:
43+ Set[str]: A set of all the imports in the class.
44+ """
45+ import_declerations : Captures = self .frame_query_and_capture_output (query = "(import_declaration (scoped_identifier) @name)" , code_to_process = source_code )
46+ return {capture .node .text .decode () for capture in import_declerations }
47+
48+ def get_pacakge_name (self , source_code : str ) -> str :
49+ """Get the package name from the source code.
50+
51+ Args:
52+ source_code (str): The source code to process.
53+
54+ Returns:
55+ str: The package name.
56+ """
57+ package_name : Captures = self .frame_query_and_capture_output (query = "((package_declaration) @name)" , code_to_process = source_code )
58+ if package_name :
59+ return package_name [0 ].node .text .decode ().replace ("package " , "" ).replace (";" , "" )
60+ return None
61+
62+ def get_class_name (self , source_code : str ) -> str :
63+ """Get the class name from the source code.
64+
65+ Args:
66+ source_code (str): The source code to process.
67+
68+ Returns:
69+ str: The class name.
70+ """
71+ class_name = self .frame_query_and_capture_output ("(class_declaration name: (identifier) @name)" , source_code )
72+ return class_name [0 ].node .text .decode ()
73+
74+ def get_superclass (self , source_code : str ) -> str :
75+ """Get a list of all the superclasses in a class.
76+
77+ Args:
78+ source_code (str): The source code to process.
79+
80+ Returns:
81+ Set[str]: A set of all the superclasses in the class.
82+ """
83+ superclass : Captures = self .frame_query_and_capture_output (query = "(class_declaration (superclass (type_identifier) @superclass))" , code_to_process = source_code )
84+
85+ if len (superclass ) == 0 :
86+ return ""
87+
88+ return superclass [0 ].node .text .decode ()
89+
90+ def get_all_interfaces (self , source_code : str ) -> Set [str ]:
91+ """Get a set of interfaces implemented by a class.
92+
93+ Args:
94+ source_code (str): The source code to process.
95+
96+ Returns:
97+ Set[str]: A set of all the interfaces implemented by the class.
98+ """
99+
100+ interfaces = self .frame_query_and_capture_output ("(class_declaration (super_interfaces (type_list (type_identifier) @interface)))" , code_to_process = source_code )
101+ return {interface .node .text .decode () for interface in interfaces }
102+
36103 def frame_query_and_capture_output (self , query : str , code_to_process : str ) -> Captures :
37104 """Frame a query for the tree-sitter parser.
38105
@@ -232,79 +299,6 @@ def get_methods_with_annotations(self, source_class_code: str, annotations: List
232299 annotation_method_dict [annotation ] = [method ]
233300 return annotation_method_dict
234301
235- def get_fields_with_annotations (self , source_class_code : str ) -> Dict [str , Dict ]:
236- """
237- Returns a dictionary of field names and field bodies.
238-
239- Parameters:
240- -----------
241- source_class_code : str
242- String containing code for a java class.
243- Returns:
244- --------
245- Dict[str,Dict]
246- Dictionary with field names as keys and a dictionary of annotation and body as values.
247- """
248- query = """
249- (field_declaration
250- (variable_declarator
251- name: (identifier) @field_name
252- )
253- )
254- """
255- captures : Captures = self .frame_query_and_capture_output (query , source_class_code )
256- field_dict = {}
257- for capture in captures :
258- if capture .name == "field_name" :
259- field_name = capture .node .text .decode ()
260- inner_dict = {}
261- annotation = None
262- field_node = self .safe_ascend (capture .node , 2 )
263- body = field_node .text .decode ()
264- for fc in field_node .children :
265- if fc .type == "modifiers" :
266- for mc in fc .children :
267- if mc .type == "marker_annotation" :
268- annotation = mc .text .decode ()
269- inner_dict ["annotation" ] = annotation
270- inner_dict ["body" ] = body
271- field_dict [field_name ] = inner_dict
272- return field_dict
273-
274- def get_field_accesses (self , source_class_code : str ) -> Dict [str , list [list [int ]]]:
275- """
276- Returns a dictionary of field names with start and end positions of field accesses.
277-
278- Parameters:
279- -----------
280- source_class_code : str
281- String containing code for a java class.
282- Returns:
283- --------
284- Dict[str, [[int, int], [int, int]]]
285- Dictionary with field names as keys and a list of starting and ending line, and starting and ending column.
286- """
287- query = """
288- (field_access
289- field:(identifier) @field_name
290- )
291- """
292- captures : Captures = self .frame_query_and_capture_output (query , source_class_code )
293- field_dict = {}
294- for capture in captures :
295- if capture .name == "field_name" :
296- field_name = capture .node .text .decode ()
297- field_node = self .safe_ascend (capture .node , 2 )
298- start_line = field_node .start_point [0 ]
299- start_column = field_node .start_point [1 ]
300- end_line = field_node .end_point [0 ]
301- end_column = field_node .end_point [1 ]
302- start_list = [start_line , start_column ]
303- end_list = [end_line , end_column ]
304- position = [start_list , end_list ]
305- field_dict [field_name ] = position
306- return field_dict
307-
308302 def get_all_type_invocations (self , source_code : str ) -> Set [str ]:
309303 """
310304 Given the source code, get all the type invocations in the source code.
@@ -322,6 +316,24 @@ def get_all_type_invocations(self, source_code: str) -> Set[str]:
322316 type_references : Captures = self .frame_query_and_capture_output ("(type_identifier) @type_id" , source_code )
323317 return {type_id .node .text .decode () for type_id in type_references }
324318
319+ def get_method_return_type (self , source_code : str ) -> str :
320+ """Get the return type of a method.
321+
322+ Parameters
323+ ----------
324+ source_code : str
325+ The source code to process.
326+
327+ Returns
328+ -------
329+ str
330+ The return type of the method.
331+ """
332+
333+ type_references : Captures = self .frame_query_and_capture_output ("(method_declaration type: ((type_identifier) @type_id))" , source_code )
334+
335+ return type_references [0 ].node .text .decode ()
336+
325337 def get_lexical_tokens (self , code : str , filter_by_node_type : List [str ] | None = None ) -> List [str ]:
326338 """
327339 Get the lexical tokens given the code
0 commit comments