|
13 | 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | 14 | # See the License for the specific language governing permissions and |
15 | 15 | # limitations under the License. |
| 16 | +import enum |
16 | 17 |
|
17 | 18 | import math |
18 | 19 | import struct |
19 | 20 | from typing import Union, Optional |
20 | 21 |
|
21 | 22 | from scalecodec.base import ScaleType, ScaleBytes, ScalePrimitive, ScaleTypeDef |
22 | 23 | from scalecodec.constants import TYPE_DECOMP_MAX_RECURSIVE |
23 | | -from scalecodec.exceptions import ScaleEncodeException, ScaleDecodeException, ScaleDeserializeException |
| 24 | +from scalecodec.exceptions import ScaleEncodeException, ScaleDecodeException, ScaleDeserializeException, \ |
| 25 | + ScaleSerializeException |
24 | 26 |
|
25 | 27 |
|
26 | 28 | class UnsignedInteger(ScalePrimitive): |
@@ -230,6 +232,8 @@ def decode(self, data) -> dict: |
230 | 232 | return value |
231 | 233 |
|
232 | 234 | def serialize(self, value: dict) -> dict: |
| 235 | + if value is None: |
| 236 | + raise ScaleSerializeException('Value cannot be None') |
233 | 237 | return {k: obj.value for k, obj in value.items()} |
234 | 238 |
|
235 | 239 | def deserialize(self, value: dict) -> dict: |
@@ -400,6 +404,12 @@ def deserialize(self, value: Union[str, dict]) -> tuple: |
400 | 404 | if type(value) is str: |
401 | 405 | value = {value: None} |
402 | 406 |
|
| 407 | + if isinstance(value, enum.Enum): |
| 408 | + value = {value.name: None} |
| 409 | + |
| 410 | + if len(list(value.items())) != 1: |
| 411 | + raise ScaleDeserializeException("Only one variant can be specified for enums") |
| 412 | + |
403 | 413 | enum_key, enum_value = list(value.items())[0] |
404 | 414 |
|
405 | 415 | for idx, (variant_name, variant_obj) in enumerate(self.variants.items()): |
|
0 commit comments