Skip to content

Commit 681bdda

Browse files
authored
Update javasitter.py
1 parent 96c8270 commit 681bdda

1 file changed

Lines changed: 85 additions & 73 deletions

File tree

cldk/analysis/java/treesitter/javasitter.py

Lines changed: 85 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,73 @@ def method_is_not_in_class(self, method_name: str, class_body: str) -> bool:
3333

3434
return method_name not in {method.node.text.decode() for method in methods_in_class}
3535

36+
def get_all_imports(self, source_code: str) -> Set[str]:
37+
"""Get a list of all the imports in a class.
38+
39+
Args:
40+
source_code (str): The source code to process.
41+
42+
Returns:
43+
Set[str]: A set of all the imports in the class.
44+
"""
45+
import_declerations: Captures = self.frame_query_and_capture_output(query="(import_declaration (scoped_identifier) @name)", code_to_process=source_code)
46+
return {capture.node.text.decode() for capture in import_declerations}
47+
48+
def get_pacakge_name(self, source_code: str) -> str:
49+
"""Get the package name from the source code.
50+
51+
Args:
52+
source_code (str): The source code to process.
53+
54+
Returns:
55+
str: The package name.
56+
"""
57+
package_name: Captures = self.frame_query_and_capture_output(query="((package_declaration) @name)", code_to_process=source_code)
58+
if package_name:
59+
return package_name[0].node.text.decode().replace("package ", "").replace(";", "")
60+
return None
61+
62+
def get_class_name(self, source_code: str) -> str:
63+
"""Get the class name from the source code.
64+
65+
Args:
66+
source_code (str): The source code to process.
67+
68+
Returns:
69+
str: The class name.
70+
"""
71+
class_name = self.frame_query_and_capture_output("(class_declaration name: (identifier) @name)", source_code)
72+
return class_name[0].node.text.decode()
73+
74+
def get_superclass(self, source_code: str) -> str:
75+
"""Get a list of all the superclasses in a class.
76+
77+
Args:
78+
source_code (str): The source code to process.
79+
80+
Returns:
81+
Set[str]: A set of all the superclasses in the class.
82+
"""
83+
superclass: Captures = self.frame_query_and_capture_output(query="(class_declaration (superclass (type_identifier) @superclass))", code_to_process=source_code)
84+
85+
if len(superclass) == 0:
86+
return ""
87+
88+
return superclass[0].node.text.decode()
89+
90+
def get_all_interfaces(self, source_code: str) -> Set[str]:
91+
"""Get a set of interfaces implemented by a class.
92+
93+
Args:
94+
source_code (str): The source code to process.
95+
96+
Returns:
97+
Set[str]: A set of all the interfaces implemented by the class.
98+
"""
99+
100+
interfaces = self.frame_query_and_capture_output("(class_declaration (super_interfaces (type_list (type_identifier) @interface)))", code_to_process=source_code)
101+
return {interface.node.text.decode() for interface in interfaces}
102+
36103
def frame_query_and_capture_output(self, query: str, code_to_process: str) -> Captures:
37104
"""Frame a query for the tree-sitter parser.
38105
@@ -232,79 +299,6 @@ def get_methods_with_annotations(self, source_class_code: str, annotations: List
232299
annotation_method_dict[annotation] = [method]
233300
return annotation_method_dict
234301

235-
def get_fields_with_annotations(self, source_class_code: str) -> Dict[str, Dict]:
236-
"""
237-
Returns a dictionary of field names and field bodies.
238-
239-
Parameters:
240-
-----------
241-
source_class_code : str
242-
String containing code for a java class.
243-
Returns:
244-
--------
245-
Dict[str,Dict]
246-
Dictionary with field names as keys and a dictionary of annotation and body as values.
247-
"""
248-
query = """
249-
(field_declaration
250-
(variable_declarator
251-
name: (identifier) @field_name
252-
)
253-
)
254-
"""
255-
captures: Captures = self.frame_query_and_capture_output(query, source_class_code)
256-
field_dict = {}
257-
for capture in captures:
258-
if capture.name == "field_name":
259-
field_name = capture.node.text.decode()
260-
inner_dict = {}
261-
annotation = None
262-
field_node = self.safe_ascend(capture.node, 2)
263-
body = field_node.text.decode()
264-
for fc in field_node.children:
265-
if fc.type == "modifiers":
266-
for mc in fc.children:
267-
if mc.type == "marker_annotation":
268-
annotation = mc.text.decode()
269-
inner_dict["annotation"] = annotation
270-
inner_dict["body"] = body
271-
field_dict[field_name] = inner_dict
272-
return field_dict
273-
274-
def get_field_accesses(self, source_class_code: str) -> Dict[str, list[list[int]]]:
275-
"""
276-
Returns a dictionary of field names with start and end positions of field accesses.
277-
278-
Parameters:
279-
-----------
280-
source_class_code : str
281-
String containing code for a java class.
282-
Returns:
283-
--------
284-
Dict[str, [[int, int], [int, int]]]
285-
Dictionary with field names as keys and a list of starting and ending line, and starting and ending column.
286-
"""
287-
query = """
288-
(field_access
289-
field:(identifier) @field_name
290-
)
291-
"""
292-
captures: Captures = self.frame_query_and_capture_output(query, source_class_code)
293-
field_dict = {}
294-
for capture in captures:
295-
if capture.name == "field_name":
296-
field_name = capture.node.text.decode()
297-
field_node = self.safe_ascend(capture.node, 2)
298-
start_line = field_node.start_point[0]
299-
start_column = field_node.start_point[1]
300-
end_line = field_node.end_point[0]
301-
end_column = field_node.end_point[1]
302-
start_list = [start_line, start_column]
303-
end_list = [end_line, end_column]
304-
position = [start_list, end_list]
305-
field_dict[field_name] = position
306-
return field_dict
307-
308302
def get_all_type_invocations(self, source_code: str) -> Set[str]:
309303
"""
310304
Given the source code, get all the type invocations in the source code.
@@ -322,6 +316,24 @@ def get_all_type_invocations(self, source_code: str) -> Set[str]:
322316
type_references: Captures = self.frame_query_and_capture_output("(type_identifier) @type_id", source_code)
323317
return {type_id.node.text.decode() for type_id in type_references}
324318

319+
def get_method_return_type(self, source_code: str) -> str:
320+
"""Get the return type of a method.
321+
322+
Parameters
323+
----------
324+
source_code : str
325+
The source code to process.
326+
327+
Returns
328+
-------
329+
str
330+
The return type of the method.
331+
"""
332+
333+
type_references: Captures = self.frame_query_and_capture_output("(method_declaration type: ((type_identifier) @type_id))", source_code)
334+
335+
return type_references[0].node.text.decode()
336+
325337
def get_lexical_tokens(self, code: str, filter_by_node_type: List[str] | None = None) -> List[str]:
326338
"""
327339
Get the lexical tokens given the code

0 commit comments

Comments
 (0)