Skip to content

Commit 8dccaea

Browse files
Fix enum from dict (#130)
* Fix enum from dict * Fix type checking
1 parent ebc67b4 commit 8dccaea

4 files changed

Lines changed: 32 additions & 20 deletions

File tree

betterproto2/src/betterproto2/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,8 @@ def _value_from_dict(value: Any, meta: FieldMetadata, field_type: type) -> Any:
611611

612612
if meta.proto_type == TYPE_ENUM:
613613
if isinstance(value, str):
614+
if (int_value := field_type.betterproto_renamed_proto_names_to_value().get(value)) is not None:
615+
return field_type(int_value)
614616
return field_type.from_string(value)
615617
if isinstance(value, int):
616618
return field_type(value)

betterproto2/src/betterproto2/enum_.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,12 @@
1-
import sys
21
from enum import EnumMeta, IntEnum
32

43
from typing_extensions import Self
54

65

76
class _EnumMeta(EnumMeta):
87
def __new__(metacls, cls, bases, classdict):
9-
# Find the proto names if defined
10-
if sys.version_info >= (3, 11):
11-
proto_names = classdict.pop("betterproto_proto_names", {})
12-
classdict._member_names.pop("betterproto_proto_names", None)
13-
else:
14-
proto_names = {}
15-
if "betterproto_proto_names" in classdict:
16-
proto_names = classdict.pop("betterproto_proto_names")
17-
classdict._member_names.remove("betterproto_proto_names")
18-
198
enum_class = super().__new__(metacls, cls, bases, classdict)
9+
proto_names = enum_class.betterproto_value_to_renamed_proto_names() # type: ignore[reportAttributeAccessIssue]
2010

2111
# Attach extra info to each enum member
2212
for member in enum_class:
@@ -32,6 +22,14 @@ class Enum(IntEnum, metaclass=_EnumMeta):
3222
def proto_name(self) -> str | None:
3323
return self._proto_name # type: ignore[reportAttributeAccessIssue]
3424

25+
@classmethod
26+
def betterproto_value_to_renamed_proto_names(cls) -> dict[int, str]:
27+
return {}
28+
29+
@classmethod
30+
def betterproto_renamed_proto_names_to_value(cls) -> dict[str, int]:
31+
return {}
32+
3533
@classmethod
3634
def _missing_(cls, value):
3735
# If the given value is not an integer, let the standard enum implementation raise an error

betterproto2/tests/test_enum.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,13 @@ def test_enum_to_dict() -> None:
9292
no_striping=NoStriping.NO_STRIPING_A,
9393
)
9494

95-
print(ArithmeticOperator.PLUS.proto_name)
96-
9795
assert msg.to_dict() == {
9896
"arithmeticOperator": "ARITHMETIC_OPERATOR_PLUS", # The original proto name must be preserved
9997
"noStriping": "NO_STRIPING_A",
10098
}
10199

100+
assert EnumMessage.from_dict(msg.to_dict()) == msg
101+
102102

103103
def test_unknown_variant_to_dict() -> None:
104104
from tests.outputs.enum.enum import NewVersion, NewVersionMessage, OldVersionMessage

betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,25 @@ class {{ enum.py_name | add_to_all }}(betterproto2.Enum):
3232
{% endif %}
3333

3434
{% if enum.has_renamed_entries %}
35-
betterproto_proto_names = {
36-
{% for entry in enum.entries %}
37-
{% if entry.proto_name != entry.name %}
38-
{{ entry.value }}: "{{ entry.proto_name }}",
39-
{% endif %}
40-
{% endfor %}
41-
}
35+
@classmethod
36+
def betterproto_value_to_renamed_proto_names(cls) -> dict[int, str]:
37+
return {
38+
{% for entry in enum.entries %}
39+
{% if entry.proto_name != entry.name %}
40+
{{ entry.value }}: "{{ entry.proto_name }}",
41+
{% endif %}
42+
{% endfor %}
43+
}
44+
45+
@classmethod
46+
def betterproto_renamed_proto_names_to_value(cls) -> dict[str, int]:
47+
return {
48+
{% for entry in enum.entries %}
49+
{% if entry.proto_name != entry.name %}
50+
"{{ entry.proto_name }}": {{ entry.value }},
51+
{% endif %}
52+
{% endfor %}
53+
}
4254
{% endif %}
4355

4456
{% endfor %}

0 commit comments

Comments
 (0)