Skip to content

Commit 2f78e73

Browse files
committed
add equals generation
1 parent 58b68e7 commit 2f78e73

2 files changed

Lines changed: 56 additions & 0 deletions

File tree

src/method_generator.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,34 @@ def generate_setter(field: Field) -> str:
106106
return setter
107107

108108

109+
def generate_equals(class_name: str, fields: List[Field]) -> str:
110+
equals = [
111+
"",
112+
f"{indent_lvl1}@Override",
113+
f"{indent_lvl1}public boolean equals(Object obj) {{",
114+
115+
f"{indent_lvl2}if (this == obj)",
116+
f"{indent_lvl3}return true;",
117+
118+
f"{indent_lvl2}if (!(obj instanceof {class_name}))",
119+
f"{indent_lvl3}return false;",
120+
121+
f"{indent_lvl2}{class_name} that = ({class_name}) obj;"
122+
]
123+
124+
for i, field in enumerate(fields):
125+
_validate_java_identifier(field.name)
126+
getter_name = "get" + field.name[0].upper() + field.name[1:]
127+
semicolon = ";" if i == (len(fields) - 1) else ""
128+
if i == 0:
129+
equals.append(f"{indent_lvl2}return Objects.equals({getter_name}(), that.{getter_name}()){semicolon}")
130+
else:
131+
equals.append(f"{indent_lvl2} && Objects.equals({getter_name}(), that.{getter_name}()){semicolon}")
132+
equals.append(f"{indent_lvl1}}}")
133+
134+
return "\n".join(equals)
135+
136+
109137
def generate_hash_code(fields: List[Field]) -> str:
110138
hash_code = [
111139
"",

tests/test_method_generator.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,31 @@ def test_generate_hash_code_invalid_name():
160160
attr = Field(name=name, type="String")
161161
with pytest.raises(ValueError):
162162
generate_hash_code([attr])
163+
164+
165+
def test_generate_equals():
166+
class_name = "MyClass"
167+
attr1 = field_exampleAttribute_int
168+
attr2 = field_someName_String
169+
attr3 = field_customData_CustomObject
170+
attributes = [attr1, attr2, attr3]
171+
expected = """
172+
@Override
173+
public boolean equals(Object obj) {
174+
if (this == obj)
175+
return true;
176+
if (!(obj instanceof MyClass))
177+
return false;
178+
MyClass that = (MyClass) obj;
179+
return Objects.equals(getExampleAttribute(), that.getExampleAttribute())
180+
&& Objects.equals(getSomeName(), that.getSomeName())
181+
&& Objects.equals(getCustomData(), that.getCustomData());
182+
}"""
183+
assert generate_equals(class_name, attributes) == expected
184+
185+
186+
def test_generate_hash_code_invalid_name():
187+
for name in illegal_names:
188+
attr = Field(name=name, type="String")
189+
with pytest.raises(ValueError):
190+
generate_equals("MyClass", [attr])

0 commit comments

Comments
 (0)