Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions python/tako/core/compiler/types/hash_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import dataclasses
from tako.core.compiler.types import mir
from tako.core.error import Error
from tako.core.repr_str import ReprStr
from tako.core.repr_str import ShallowReprStr
from tako.util.qname import QName
import hashlib

Expand Down Expand Up @@ -59,19 +59,19 @@ class HashExpand(
):
types: t.Dict[QName, mir.RootType]

def digest(self, rt: mir.RootType) -> Digest:
x = rt.accept(ReprStr(self.types))
def shallow_digest(self, rt: mir.RootType) -> Digest:
x = rt.accept(ShallowReprStr(self.types))
return Digest(repr_str=x, repr_hash=sha256hex(x))

def visit_struct(self, root: mir.Struct) -> t.Union[HashExpandResult, Error]:
return HashExpandResult(self.digest(root), None)
return HashExpandResult(self.shallow_digest(root), None)

def visit_variant(self, root: mir.Variant) -> t.Union[HashExpandResult, Error]:
fixed = root.accept_v(self)
if isinstance(fixed, Error):
return fixed
else:
return HashExpandResult(self.digest(fixed), fixed)
return HashExpandResult(self.shallow_digest(fixed), fixed)

def visit_fixed_variant(
self, variant: mir.FixedVariant
Expand All @@ -87,7 +87,7 @@ def visit_hash_variant(
tag_map: t.Dict[mir.StructRef, int] = {}
inv_tag_map: t.Dict[int, mir.StructRef] = {}
for type_ in variant.types():
hash_hex = self.digest(type_.resolve(self.types)).repr_hash
hash_hex = self.shallow_digest(type_.resolve(self.types)).repr_hash
short = int(hash_hex[:tag_width_hex_digits], 16)
if short in inv_tag_map:
return Error(
Expand All @@ -99,4 +99,4 @@ def visit_hash_variant(
return mir.FixedVariant(variant.name, variant.tag_type, tag_map)

def visit_enum(self, root: mir.Enum) -> t.Union[HashExpandResult, Error]:
return HashExpandResult(self.digest(root), None)
return HashExpandResult(self.shallow_digest(root), None)
3 changes: 2 additions & 1 deletion python/tako/core/compiler/types/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def visit_detached_variant(self, type_: pt.DetachedVariant) -> mir.Type:
return mir.DetachedVariant(
mir.VariantRef(type_.variant.qualified_name()),
mir.FieldReference(type_.tag.name),
False
)

def visit_virtual(self, type_: pt.Virtual) -> mir.Type:
Expand All @@ -103,4 +104,4 @@ def visit_variant_def(self, type_: pt.VariantDef) -> mir.Type:
return mir.VariantRef(type_.qualified_name())

def visit_hash_variant_def(self, type_: pt.HashVariantDef) -> mir.Type:
return mir.VariantRef(type_.qualified_name())
return mir.HashVariantRef(type_.qualified_name())
11 changes: 10 additions & 1 deletion python/tako/core/compiler/types/mir.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,11 @@ def accept(self, visitor: LengthVisitor[T]) -> T:
class DetachedVariant(Type):
variant: VariantRef
tag: FieldReference
is_hash_tag: bool

def accept(self, visitor: TypeVisitor[T]) -> T:
return visitor.visit_detached_variant(self)


@dataclasses.dataclass(frozen=True)
class Virtual(Type):
inner: Type
Expand Down Expand Up @@ -338,6 +338,15 @@ def resolve(self, context: t.Dict[QName, RootType]) -> Variant:
def accept_r(self, visitor: RefVisitor[T]) -> T:
return visitor.visit_variant_ref(self)

@dataclasses.dataclass(frozen=True)
class HashVariantRef(VariantRef):

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the purpose of this so that we can differentiate between the two?

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, it's either this or a field on variantref

def resolve(self, context: t.Dict[QName, RootType]) -> Variant:
return checked_cast(Variant, context[self.name])

def accept_r(self, visitor: RefVisitor[T]) -> T:
# TODO clean up this type warning
return visitor.visit_variant_ref(self)


@dataclasses.dataclass(frozen=True)
class EnumRef(Ref):
Expand Down
10 changes: 8 additions & 2 deletions python/tako/core/compiler/types/variant_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,17 @@ class VariantExpand(mir.RootTypeVisitor[t.Optional[mir.Struct]]):
def visit_struct(self, root: mir.Struct) -> t.Optional[mir.Struct]:
new_fields: t.Dict[str, mir.Type] = {}
for fname, ftype in root.fields.items():
if isinstance(ftype, mir.VariantRef):
if isinstance(ftype, mir.HashVariantRef):
new_fname = f"{fname}_injected_key_"
new_fields[new_fname] = ftype.resolve(self.types).tag_type
new_fields[fname] = mir.DetachedVariant(
variant=ftype, tag=mir.FieldReference(new_fname)
variant=ftype, tag=mir.FieldReference(new_fname), is_hash_tag = True
)
elif isinstance(ftype, mir.VariantRef):
new_fname = f"{fname}_injected_key_"
new_fields[new_fname] = ftype.resolve(self.types).tag_type
new_fields[fname] = mir.DetachedVariant(
variant=ftype, tag=mir.FieldReference(new_fname), is_hash_tag = False
)
else:
new_fields[fname] = ftype
Expand Down
88 changes: 88 additions & 0 deletions python/tako/core/repr_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,91 @@ def visit_enum(self, root: mir.Enum) -> str:
pairs = sorted([(value, name) for name, value in root.variants.items()])
variants = ",".join([f"{value}={name}" for value, name in pairs])
return f"Enum(name={root.name},underlying={root.underlying_type.accept(self)},variants={{{variants}}})"

## This is used to construct the shallow hash.
@dataclasses.dataclass
class ShallowReprStr(
mir.RootTypeVisitor[str],
mir.VariantVisitor[str],
mir.SeqTypeVisitor[str],
mir.LengthVisitor[str],
):
types: t.Dict[QName, mir.RootType]

def visit_int(self, type_: mir.Int) -> str:
# Render an integer type as text: its width plus the `.name` of its sign
# and of its endianness.
return f"Int(width={type_.width},sign={type_.sign.name},endianness={type_.endianness.name})"

def visit_float(self, type_: mir.Float) -> str:
# Render a float type as text: its width plus the `.name` of its endianness.
return f"Float(width={type_.width},endianness={type_.endianness.name})"

def repr_field_reference(self, fr: mir.FieldReference) -> str:
# Helper (not a visitor method): a field reference is rendered by its name only.
return f"FieldReference(name={fr.name})"

def visit_seq(self, type_: mir.Seq) -> str:
# Render a sequence as text: the repr of its element type followed by the
# repr of its length.
#
# Note that the repr is done on the type before seq reduce
# -- there is only Seq, but not List, Vector, or Array.
# This is deliberate - the representation of the type is intended to be
# as simple as possible while still conveying everything
# needed to parse or serialize a type from a wire representation.
# List, Vector, and Array do not impact the wire representation, and hence
# are not used.
length = type_.length.accept(self)
return f"Seq(inner={type_.inner.accept(self)},length={length})"

def visit_unbound_seq(self, type_: mir.UnboundSeq) -> str:
# Always raises InternalError: this visitor defines no string form for an
# UnboundSeq.
# NOTE(review): presumably unbound sequences are eliminated by an earlier
# pass and should never reach this visitor -- confirm.
raise InternalError()

def visit_fixed_length(self, length: mir.FixedLength) -> str:
# A fixed length is rendered as the bare length value, with no wrapper text.
return f"{length.length}"

def visit_variable_length(self, length: mir.VariableLength) -> str:
# A variable length is rendered as the field reference stored in
# `length.length`, using repr_field_reference.
return self.repr_field_reference(length.length)

def visit_detached_variant(self, type_: mir.DetachedVariant) -> str:
# Resolve the variant reference against self.types and embed the resolved
# variant's full repr together with the tag's field reference.
# `type_.is_hash_tag` is not consulted here; the only check for it is in
# visit_struct, which does not call this method for hash-tagged fields.
variant = type_.variant.resolve(self.types)
return f"DetachedVariant(variant={variant.accept(self)},tag={self.repr_field_reference(type_.tag)})"

def visit_virtual(self, type_: mir.Virtual) -> str:
# Render a virtual type by wrapping the repr of its inner type.
#
# Virtual types contribute to the hash just like normal types, even though they have no
# effect on the wire representation. They allow a type to represent that there is
# some other data on the wire related to it -- virtual fields are parsed
# using the context of the rest of the type.
return f"Virtual(inner={type_.inner.accept(self)})"

def visit_ref(self, type_: mir.Ref) -> str:
# Resolve the reference through self.types and return the repr of the target.
# The reference itself adds nothing to the output string.
return type_.resolve(self.types).accept(self)

def visit_struct(self, root: mir.Struct) -> str:
# Including the name of the fields is critical -- otherwise the struct Foo { x: int, y: int }
# is the same as the struct Foo { y: int, x: int }.
# Names provide meaning to fields.
# Note that no derived information, like the size or offset of fields, is included. That could all
# be computed from this and is extraneous.
l = []
for name, type_ in root.fields.items():
if isinstance(type_, mir.DetachedVariant) and type_.is_hash_tag:
l.append(f"{name}=HashVariant")

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This crucially prevents recursion and keeps the hash variants shallow?

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep!

@johnlevidy johnlevidy Feb 5, 2025

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's still important to include the type, but that's exactly right: "i dont need to include the inner types in the hash b/c access is protected by the hash"

else:
l.append(f"{name}={type_.accept(self)}")
fields = ",".join(l)
return f"Struct(name={root.name},fields={{{fields}}})"

def visit_variant(self, root: mir.Variant) -> str:
# Delegate to this object's variant-visitor methods via accept_v.
# NOTE(review): presumably this dispatches to visit_fixed_variant or
# visit_hash_variant -- confirm against mir.Variant.accept_v.
return root.accept_v(self)

def visit_fixed_variant(self, root: mir.FixedVariant) -> str:
# Render a fixed variant as text: its name, the repr of its tag type, and one
# `tag=<repr>` entry per member.
#
# Sort the variant by tag to ensure order doesn't matter
# `root.tags` is iterated as (struct ref, tag) pairs and flipped to
# (tag, struct ref) so that sorting is driven by the tag.
pairs = sorted([(tag, sr) for sr, tag in root.tags.items()])
# Each member reference is resolved through self.types and rendered by this visitor.
variants = ",".join(
[f"{tag}={value.resolve(self.types).accept(self)}" for tag, value in pairs]
)
return f"Variant(name={root.name},tag_type={root.tag_type.accept(self)},variants={{{variants}}})"

def visit_hash_variant(self, root: mir.HashVariant) -> str:
# Always raises InternalError: this visitor defines no string form for a
# HashVariant.
# NOTE(review): presumably hash variants are converted to fixed variants
# before a repr is requested -- confirm against the callers.
raise InternalError()

def visit_enum(self, root: mir.Enum) -> str:
# Render an enum as text: its name, the repr of its underlying type, and one
# `value=name` entry per member.
#
# Like variant: `root.variants` is iterated as (name, value) pairs and sorted
# by value, so the order of the mapping does not affect the output.
pairs = sorted([(value, name) for name, value in root.variants.items()])
variants = ",".join([f"{value}={name}" for value, name in pairs])
return f"Enum(name={root.name},underlying={root.underlying_type.accept(self)},variants={{{variants}}})"
15 changes: 3 additions & 12 deletions python/test_types/bakery/v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ class V4(Protocol):
# Frosting flavor is actually the most important
# I can't believe I forgot it
frosting_flavor=Flavor,
# Forgot about the number of sprinkles
# sprinkle_quantity=li32,
)
CakeOrder = Struct(layers=li32, shape=Shape, flavor=Flavor)
Order = Variant[u8]({CupcakeOrder: 0, CakeOrder: 1})
Order = HashVariant[u8]([ CupcakeOrder, CakeOrder])

ErrorResponse = Struct()
NewOrderRequest = Struct(name_len=u8, name=Seq(i8, this.name_len), order=Order)
Expand All @@ -47,14 +49,3 @@ class V4(Protocol):
}
)
Message = Struct(msg=MessageVariant)

conversions = [
ConversionsFromPrior(
Prior,
VariantConversion(
src=MessageVariant,
target=Prior.MessageVariant,
mapping={CancelOrderRequest: None, CancelOrderResponse: None},
),
)
]
31 changes: 31 additions & 0 deletions shallow_hash.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
[](const ::test_types::bakery::v4::Message&) { return static_cast<::std::uint32_t>(UINT32_C(2458638600)); }

This is the hash of a v4 message, which contains a hash variant. Hopefully it changes when I add a field to a struct inside the HV.

[](const ::test_types::bakery::v4::Message&) { return static_cast<::std::uint32_t>(UINT32_C(3267984003)); }

Indeed it does.

I made ShallowReprHash. This might not be the most efficient way to do it but initially I just want to see it preserve the hashes

[](const ::test_types::bakery::v4::Message&) { return static_cast<::std::uint32_t>(UINT32_C(2458638600)); }

Seems like it's back to the old hash when I remove the new sprinkle_quantity field from CupcakeOrder, great.

Ok that was more work than expected, but before:

5 [](const ::test_types::bakery::v4::Message&) { return static_cast<::std::uint32_t>(UINT32_C(390615788)); }

after changing a field on a struct only present in an HV
5 [](const ::test_types::bakery::v4::Message&) { return static_cast<::std::uint32_t>(UINT32_C(390615788)); }

after changing a field on a struct elsewhere in v4
::std::uint32_t tag() const {
1 return match(
2 [](const ::test_types::bakery::v2::Message&) { return static_cast<::std::uint32_t>(UINT32_C(3831964682)); },
3 [](const ::test_types::bakery::v1::Message&) { return static_cast<::std::uint32_t>(UINT32_C(2782154402)); },
4 [](const ::test_types::bakery::v3::Message&) { return static_cast<::std::uint32_t>(UINT32_C(716972100)); },
5 [](const ::test_types::bakery::v4::Message&) { return static_cast<::std::uint32_t>(UINT32_C(1691428639)); }
6 );

after changing a field on a struct elsewhere in v2