From 5a74954844594d83d71056ee9d406df1a93a380a Mon Sep 17 00:00:00 2001 From: John Date: Mon, 27 Jan 2025 11:46:12 -0500 Subject: [PATCH 1/2] Adding support for shallow hashes --- .../tako/core/compiler/types/hash_expand.py | 8 +- python/tako/core/compiler/types/lower.py | 3 +- python/tako/core/compiler/types/mir.py | 11 ++- .../core/compiler/types/variant_expand.py | 10 ++- python/tako/core/repr_str.py | 88 +++++++++++++++++++ python/test_types/bakery/v4.py | 15 +--- shallow_hash.txt | 31 +++++++ 7 files changed, 148 insertions(+), 18 deletions(-) create mode 100644 shallow_hash.txt diff --git a/python/tako/core/compiler/types/hash_expand.py b/python/tako/core/compiler/types/hash_expand.py index f2298de..84eb063 100644 --- a/python/tako/core/compiler/types/hash_expand.py +++ b/python/tako/core/compiler/types/hash_expand.py @@ -16,7 +16,7 @@ import dataclasses from tako.core.compiler.types import mir from tako.core.error import Error -from tako.core.repr_str import ReprStr +from tako.core.repr_str import ReprStr, ShallowReprStr from tako.util.qname import QName import hashlib @@ -59,6 +59,10 @@ class HashExpand( ): types: t.Dict[QName, mir.RootType] + def shallow_digest(self, rt: mir.RootType) -> Digest: + x = rt.accept(ShallowReprStr(self.types)) + return Digest(repr_str=x, repr_hash=sha256hex(x)) + def digest(self, rt: mir.RootType) -> Digest: x = rt.accept(ReprStr(self.types)) return Digest(repr_str=x, repr_hash=sha256hex(x)) @@ -87,7 +91,7 @@ def visit_hash_variant( tag_map: t.Dict[mir.StructRef, int] = {} inv_tag_map: t.Dict[int, mir.StructRef] = {} for type_ in variant.types(): - hash_hex = self.digest(type_.resolve(self.types)).repr_hash + hash_hex = self.shallow_digest(type_.resolve(self.types)).repr_hash short = int(hash_hex[:tag_width_hex_digits], 16) if short in inv_tag_map: return Error( diff --git a/python/tako/core/compiler/types/lower.py b/python/tako/core/compiler/types/lower.py index 123cb85..2ec5172 100644 --- a/python/tako/core/compiler/types/lower.py +++ b/python/tako/core/compiler/types/lower.py @@ -88,6 +88,7 @@ def visit_detached_variant(self, type_: pt.DetachedVariant) -> mir.Type: return mir.DetachedVariant( mir.VariantRef(type_.variant.qualified_name()), mir.FieldReference(type_.tag.name), + False ) def visit_virtual(self, type_: pt.Virtual) -> mir.Type: @@ -103,4 +104,4 @@ def visit_variant_def(self, type_: pt.VariantDef) -> mir.Type: return mir.VariantRef(type_.qualified_name()) def visit_hash_variant_def(self, type_: pt.HashVariantDef) -> mir.Type: - return mir.VariantRef(type_.qualified_name()) + return mir.HashVariantRef(type_.qualified_name()) diff --git a/python/tako/core/compiler/types/mir.py b/python/tako/core/compiler/types/mir.py index 4339cbf..7b1f842 100644 --- a/python/tako/core/compiler/types/mir.py +++ b/python/tako/core/compiler/types/mir.py @@ -280,11 +280,11 @@ def accept(self, visitor: LengthVisitor[T]) -> T: class DetachedVariant(Type): variant: VariantRef tag: FieldReference + is_hash_tag: bool def accept(self, visitor: TypeVisitor[T]) -> T: return visitor.visit_detached_variant(self) - @dataclasses.dataclass(frozen=True) class Virtual(Type): inner: Type @@ -338,6 +338,15 @@ def resolve(self, context: t.Dict[QName, RootType]) -> Variant: def accept_r(self, visitor: RefVisitor[T]) -> T: return visitor.visit_variant_ref(self) +@dataclasses.dataclass(frozen=True) +class HashVariantRef(VariantRef): + def resolve(self, context: t.Dict[QName, RootType]) -> Variant: + return checked_cast(Variant, context[self.name]) + + def accept_r(self, visitor: RefVisitor[T]) -> T: + # TODO clean up this type warning + return visitor.visit_variant_ref(self) + @dataclasses.dataclass(frozen=True) class EnumRef(Ref): diff --git a/python/tako/core/compiler/types/variant_expand.py b/python/tako/core/compiler/types/variant_expand.py index 99d0222..4b17a6d 100644 --- a/python/tako/core/compiler/types/variant_expand.py +++ b/python/tako/core/compiler/types/variant_expand.py @@ -34,11 +34,17 @@ class VariantExpand(mir.RootTypeVisitor[t.Optional[mir.Struct]]): def visit_struct(self, root: mir.Struct) -> t.Optional[mir.Struct]: new_fields: t.Dict[str, mir.Type] = {} for fname, ftype in root.fields.items(): - if isinstance(ftype, mir.VariantRef): + if isinstance(ftype, mir.HashVariantRef): new_fname = f"{fname}_injected_key_" new_fields[new_fname] = ftype.resolve(self.types).tag_type new_fields[fname] = mir.DetachedVariant( - variant=ftype, tag=mir.FieldReference(new_fname) + variant=ftype, tag=mir.FieldReference(new_fname), is_hash_tag = True + ) + elif isinstance(ftype, mir.VariantRef): + new_fname = f"{fname}_injected_key_" + new_fields[new_fname] = ftype.resolve(self.types).tag_type + new_fields[fname] = mir.DetachedVariant( + variant=ftype, tag=mir.FieldReference(new_fname), is_hash_tag = False ) else: new_fields[fname] = ftype diff --git a/python/tako/core/repr_str.py b/python/tako/core/repr_str.py index ecccdce..3e7d528 100644 --- a/python/tako/core/repr_str.py +++ b/python/tako/core/repr_str.py @@ -120,3 +120,91 @@ def visit_enum(self, root: mir.Enum) -> str: pairs = sorted([(value, name) for name, value in root.variants.items()]) variants = ",".join([f"{value}={name}" for value, name in pairs]) return f"Enum(name={root.name},underlying={root.underlying_type.accept(self)},variants={{{variants}}})" + +## This is used to construct the shallow hash. +@dataclasses.dataclass +class ShallowReprStr( + mir.RootTypeVisitor[str], + mir.VariantVisitor[str], + mir.SeqTypeVisitor[str], + mir.LengthVisitor[str], +): + types: t.Dict[QName, mir.RootType] + + def visit_int(self, type_: mir.Int) -> str: + return f"Int(width={type_.width},sign={type_.sign.name},endianness={type_.endianness.name})" + + def visit_float(self, type_: mir.Float) -> str: + return f"Float(width={type_.width},endianness={type_.endianness.name})" + + def repr_field_reference(self, fr: mir.FieldReference) -> str: + return f"FieldReference(name={fr.name})" + + def visit_seq(self, type_: mir.Seq) -> str: + # Note that the repr is done on the type before seq reduce + # -- there is only Seq, but not List, Vector, or Array. + # This is deliberate - the representation of the type is the intended to + # be a representation that is as simple as possible, but conveys everything + # needed to parse or serialize a type from a wire representation. + # List, Vector, and Array do not impact the wire representation, and hence + # are not used. + length = type_.length.accept(self) + return f"Seq(inner={type_.inner.accept(self)},length={length})" + + def visit_unbound_seq(self, type_: mir.UnboundSeq) -> str: + raise InternalError() + + def visit_fixed_length(self, length: mir.FixedLength) -> str: + return f"{length.length}" + + def visit_variable_length(self, length: mir.VariableLength) -> str: + return self.repr_field_reference(length.length) + + def visit_detached_variant(self, type_: mir.DetachedVariant) -> str: + variant = type_.variant.resolve(self.types) + return f"DetachedVariant(variant={variant.accept(self)},tag={self.repr_field_reference(type_.tag)})" + + def visit_virtual(self, type_: mir.Virtual) -> str: + # Virtual types contribute to the hash just like normal types, even though they have no + # effect on the wire representation. They allow a type to represent that there is + # some other data on the wire related to it -- virtual fields are parsed + # using the context of the rest of the type. + return f"Virtual(inner={type_.inner.accept(self)})" + + def visit_ref(self, type_: mir.Ref) -> str: + return type_.resolve(self.types).accept(self) + + def visit_struct(self, root: mir.Struct) -> str: + # Including the name of the fields is critical -- otherwise the struct Foo { x: int, y: int } + # is the same as the struct Foo { y: int, x: int }. + # Names provide meaning to fields. + # Note that no derrived information, like the size or offset of fields is included. That could all + # be computed from this and is extraneous. + l = [] + for name, type_ in root.fields.items(): + if isinstance(type_, mir.DetachedVariant) and type_.is_hash_tag: + l.append(f"{name}=HashVariant") + else: + l.append(f"{name}={type_.accept(self)}") + fields = ",".join(l) + return f"Struct(name={root.name},fields={{{fields}}})" + + def visit_variant(self, root: mir.Variant) -> str: + return root.accept_v(self) + + def visit_fixed_variant(self, root: mir.FixedVariant) -> str: + # Sort the variant by tag to ensure order doesn't matter + pairs = sorted([(tag, sr) for sr, tag in root.tags.items()]) + variants = ",".join( + [f"{tag}={value.resolve(self.types).accept(self)}" for tag, value in pairs] + ) + return f"Variant(name={root.name},tag_type={root.tag_type.accept(self)},variants={{{variants}}})" + + def visit_hash_variant(self, root: mir.HashVariant) -> str: + raise InternalError() + + def visit_enum(self, root: mir.Enum) -> str: + # Like variant + pairs = sorted([(value, name) for name, value in root.variants.items()]) + variants = ",".join([f"{value}={name}" for value, name in pairs]) + return f"Enum(name={root.name},underlying={root.underlying_type.accept(self)},variants={{{variants}}})" diff --git a/python/test_types/bakery/v4.py b/python/test_types/bakery/v4.py index 935dd40..d77ace1 100644 --- a/python/test_types/bakery/v4.py +++ b/python/test_types/bakery/v4.py @@ -27,9 +27,11 @@ class V4(Protocol): # Frosting flavor is actually the most important # I can't believe I forgot it frosting_flavor=Flavor, + # Forgot about the number of sprinkles + # sprinkle_quantity=li32, ) CakeOrder = Struct(layers=li32, shape=Shape, flavor=Flavor) - Order = Variant[u8]({CupcakeOrder: 0, CakeOrder: 1}) + Order = HashVariant[u8]([ CupcakeOrder, CakeOrder]) ErrorResponse = Struct() NewOrderRequest = Struct(name_len=u8, name=Seq(i8, this.name_len), order=Order) @@ -47,14 +49,3 @@ class V4(Protocol): } ) Message = Struct(msg=MessageVariant) - - conversions = [ - ConversionsFromPrior( - Prior, - VariantConversion( - src=MessageVariant, - target=Prior.MessageVariant, - mapping={CancelOrderRequest: None, CancelOrderResponse: None}, - ), - ) - ] diff --git a/shallow_hash.txt b/shallow_hash.txt new file mode 100644 index 0000000..bbb43ef --- /dev/null +++ b/shallow_hash.txt @@ -0,0 +1,31 @@ +[](const ::test_types::bakery::v4::Message&) { return static_cast<::std::uint32_t>(UINT32_C(2458638600)); } + +This is the hash of a v4 message, which inside has a hash variant. Hopefully it changes when i add to the inner HV + +[](const ::test_types::bakery::v4::Message&) { return static_cast<::std::uint32_t>(UINT32_C(3267984003)); } + +Indeed it does. + +I made ShallowReprHash. This might not be the most efficient way to do it but initially I just want to see it preserve the hashes + +[](const ::test_types::bakery::v4::Message&) { return static_cast<::std::uint32_t>(UINT32_C(2458638600)); } + +Seems like it's back to the old hash when I remove the new sprinkle_quantity from the cupcakeorder, great. + +Ok that was more work than expected, but before: + + 5 [](const ::test_types::bakery::v4::Message&) { return static_cast<::std::uint32_t>(UINT32_C(390615788)); } + +after changing a field on a struct only present in an HV + 5 [](const ::test_types::bakery::v4::Message&) { return static_cast<::std::uint32_t>(UINT32_C(390615788)); } + +after changing a field on a struct elsewhere in v4 +::std::uint32_t tag() const { + 1 return match( + 2 [](const ::test_types::bakery::v2::Message&) { return static_cast<::std::uint32_t>(UINT32_C(3831964682)); }, + 3 [](const ::test_types::bakery::v1::Message&) { return static_cast<::std::uint32_t>(UINT32_C(2782154402)); }, + 4 [](const ::test_types::bakery::v3::Message&) { return static_cast<::std::uint32_t>(UINT32_C(716972100)); }, + 5 [](const ::test_types::bakery::v4::Message&) { return static_cast<::std::uint32_t>(UINT32_C(1691428639)); } + 6 ); + +after changing a field on a struct elsewhere in v2 From 45b3a058e420057700b024844055e90abd92f455 Mon Sep 17 00:00:00 2001 From: John Date: Tue, 28 Jan 2025 20:31:30 -0500 Subject: [PATCH 2/2] i think this fixes hash expansion to always be shallow --- python/tako/core/compiler/types/hash_expand.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/python/tako/core/compiler/types/hash_expand.py b/python/tako/core/compiler/types/hash_expand.py index 84eb063..d034b96 100644 --- a/python/tako/core/compiler/types/hash_expand.py +++ b/python/tako/core/compiler/types/hash_expand.py @@ -16,7 +16,7 @@ import dataclasses from tako.core.compiler.types import mir from tako.core.error import Error -from tako.core.repr_str import ReprStr, ShallowReprStr +from tako.core.repr_str import ShallowReprStr from tako.util.qname import QName import hashlib @@ -63,19 +63,15 @@ def shallow_digest(self, rt: mir.RootType) -> Digest: x = rt.accept(ShallowReprStr(self.types)) return Digest(repr_str=x, repr_hash=sha256hex(x)) - def digest(self, rt: mir.RootType) -> Digest: - x = rt.accept(ReprStr(self.types)) - return Digest(repr_str=x, repr_hash=sha256hex(x)) - def visit_struct(self, root: mir.Struct) -> t.Union[HashExpandResult, Error]: - return HashExpandResult(self.digest(root), None) + return HashExpandResult(self.shallow_digest(root), None) def visit_variant(self, root: mir.Variant) -> t.Union[HashExpandResult, Error]: fixed = root.accept_v(self) if isinstance(fixed, Error): return fixed else: - return HashExpandResult(self.digest(fixed), fixed) + return HashExpandResult(self.shallow_digest(fixed), fixed) def visit_fixed_variant( self, variant: mir.FixedVariant @@ -103,4 +99,4 @@ def visit_hash_variant( return mir.FixedVariant(variant.name, variant.tag_type, tag_map) def visit_enum(self, root: mir.Enum) -> t.Union[HashExpandResult, Error]: - return HashExpandResult(self.digest(root), None) + return HashExpandResult(self.shallow_digest(root), None)