diff --git a/python/tako/core/compiler/__init__.py b/python/tako/core/compiler/__init__.py index 08591b5..44d10c6 100644 --- a/python/tako/core/compiler/__init__.py +++ b/python/tako/core/compiler/__init__.py @@ -38,6 +38,7 @@ def compile_proto( if isinstance(comp_constants, list): return comp_constants + # Q: Are conversions affected by skippable hash variants in any meaningful way? comp_conversions = conversions.compile( proto.name, proto.conversions, comp_types.types ) diff --git a/python/tako/core/compiler/types/__init__.py b/python/tako/core/compiler/types/__init__.py index e9d81d0..6274156 100644 --- a/python/tako/core/compiler/types/__init__.py +++ b/python/tako/core/compiler/types/__init__.py @@ -45,6 +45,11 @@ def compile( if errors: return errors + # Builds up a map from master to slave fields. This is the sort of reverse + # action that takes place during parsing. At parse time, a Seq's length is first + # read, then the sequence length is known. At build time, the Seq is built first + # and the previous field is later populated. This master_field_map seems focused + # on the "builder" problem master_field_map = master_fields.run(lowered, type_order) if isinstance(master_field_map, list): return master_field_map diff --git a/python/tako/core/compiler/types/fuse.py b/python/tako/core/compiler/types/fuse.py index c7abbdf..7c4f62f 100644 --- a/python/tako/core/compiler/types/fuse.py +++ b/python/tako/core/compiler/types/fuse.py @@ -100,6 +100,7 @@ def visit_fixed_variant(self, variant: mir.FixedVariant) -> lir.RootType: trivial=self.tmap[variant.name], name=variant.name, digest=self.digest_map[variant.name], + len_type=checked_cast(lir.Int, variant.len_type.accept(self)) if variant.len_type else None, tag_type=checked_cast(lir.Int, variant.tag_type.accept(self)), tags={ checked_cast(lir.Struct, sr.accept(self)): value diff --git a/python/tako/core/compiler/types/hash_expand.py b/python/tako/core/compiler/types/hash_expand.py index f2298de..8a13ef1 100644 --- a/python/tako/core/compiler/types/hash_expand.py +++ b/python/tako/core/compiler/types/hash_expand.py @@ -16,7 +16,7 @@ import dataclasses from tako.core.compiler.types import mir from tako.core.error import Error -from tako.core.repr_str import ReprStr +from tako.core.repr_str import ShallowReprStr from tako.util.qname import QName import hashlib @@ -58,20 +58,20 @@ class HashExpand( mir.VariantVisitor[t.Union[mir.FixedVariant, Error]], ): types: t.Dict[QName, mir.RootType] - - def digest(self, rt: mir.RootType) -> Digest: - x = rt.accept(ReprStr(self.types)) + + def shallow_digest(self, rt: mir.RootType) -> Digest: + x = rt.accept(ShallowReprStr(self.types)) return Digest(repr_str=x, repr_hash=sha256hex(x)) def visit_struct(self, root: mir.Struct) -> t.Union[HashExpandResult, Error]: - return HashExpandResult(self.digest(root), None) + return HashExpandResult(self.shallow_digest(root), None) def visit_variant(self, root: mir.Variant) -> t.Union[HashExpandResult, Error]: fixed = root.accept_v(self) if isinstance(fixed, Error): return fixed else: - return HashExpandResult(self.digest(fixed), fixed) + return HashExpandResult(self.shallow_digest(fixed), fixed) def visit_fixed_variant( self, variant: mir.FixedVariant @@ -87,7 +87,7 @@ def visit_hash_variant( tag_map: t.Dict[mir.StructRef, int] = {} inv_tag_map: t.Dict[int, mir.StructRef] = {} for type_ in variant.types(): - hash_hex = self.digest(type_.resolve(self.types)).repr_hash + hash_hex = self.shallow_digest(type_.resolve(self.types)).repr_hash short = int(hash_hex[:tag_width_hex_digits], 16) if short in inv_tag_map: return Error( @@ -96,7 +96,7 @@ def visit_hash_variant( tag_map[type_] = short inv_tag_map[short] = type_ - return mir.FixedVariant(variant.name, variant.tag_type, tag_map) + return mir.FixedVariant(variant.name, variant.tag_type, tag_map, variant.len_type) def visit_enum(self, root: mir.Enum) -> t.Union[HashExpandResult, Error]: - return HashExpandResult(self.digest(root), None) + return HashExpandResult(self.shallow_digest(root), None) diff --git a/python/tako/core/compiler/types/lir.py b/python/tako/core/compiler/types/lir.py index 8ed8729..969d2b2 100644 --- a/python/tako/core/compiler/types/lir.py +++ b/python/tako/core/compiler/types/lir.py @@ -147,6 +147,7 @@ class MasterField: @dataclasses.dataclass(frozen=True) class Variant(RootType): tag_type: Int + len_type: Optional[Int] tags: t.Dict[Struct, int] = dataclasses.field(compare=False) def accept_rtv(self, visitor: RootTypeVisitor[T]) -> T: diff --git a/python/tako/core/compiler/types/lower.py b/python/tako/core/compiler/types/lower.py index 123cb85..2caddeb 100644 --- a/python/tako/core/compiler/types/lower.py +++ b/python/tako/core/compiler/types/lower.py @@ -49,18 +49,23 @@ def visit_variant_def(self, type_: pt.VariantDef) -> mir.RootType: checked_cast(mir.StructRef, struct.accept(Lower())): value for struct, value in type_.variants.items() }, + False ) def visit_hash_variant_def(self, type_: pt.HashVariantDef) -> mir.RootType: + # Cast to the mir int types + mir_tag_type = checked_cast(mir.Int, type_.tag_type.accept(Lower())) + mir_len_type = checked_cast(mir.Int, type_.len_type.accept(Lower())) if type_.len_type else None return mir.HashVariant( type_.qualified_name(), - checked_cast(mir.Int, type_.tag_type.accept(Lower())), + mir_tag_type, set( [ checked_cast(mir.StructRef, struct.accept(Lower())) for struct in type_.hash_types ] ), + mir_len_type ) @@ -88,6 +93,7 @@ def visit_detached_variant(self, type_: pt.DetachedVariant) -> mir.Type: return mir.DetachedVariant( mir.VariantRef(type_.variant.qualified_name()), mir.FieldReference(type_.tag.name), + False ) def visit_virtual(self, type_: pt.Virtual) -> mir.Type: @@ -103,4 +109,4 @@ def visit_variant_def(self, type_: pt.VariantDef) -> mir.Type: return mir.VariantRef(type_.qualified_name()) def visit_hash_variant_def(self, type_: pt.HashVariantDef) -> mir.Type: - return mir.VariantRef(type_.qualified_name()) + return mir.HashVariantRef(type_.qualified_name()) diff --git a/python/tako/core/compiler/types/master_fields.py b/python/tako/core/compiler/types/master_fields.py index 401dcbb..b02fe81 100644 --- a/python/tako/core/compiler/types/master_fields.py +++ b/python/tako/core/compiler/types/master_fields.py @@ -28,7 +28,6 @@ class MasterField: master_field: str key_property: KeyProperty - @dataclasses.dataclass(frozen=True) class DeterminedField: determined_field: str @@ -79,7 +78,7 @@ def visit_enum( # Map from a field A to a field B where the value of A (slave field) is determined by the the value # of field B (master field). -# For example, if a struct has these fields: {"len": li32, "data", Seq(li32, this.len)}, +# For example, if a struct has these fields: {"len": li32, "data": Seq(li32, this.len)}, # then this function would return {"len": "data"} because the value of len is determined # by the value in data. # This is needed for generating builders: should a given field be included in the generated diff --git a/python/tako/core/compiler/types/mir.py b/python/tako/core/compiler/types/mir.py index 4339cbf..a259055 100644 --- a/python/tako/core/compiler/types/mir.py +++ b/python/tako/core/compiler/types/mir.py @@ -83,6 +83,7 @@ def visit_hash_variant(self, variant: HashVariant) -> T: @dataclasses.dataclass(frozen=True) class FixedVariant(Variant): tags: t.Dict[StructRef, int] + len_type: t.Optional[Int] def types(self) -> t.Iterable[StructRef]: return self.tags.keys() @@ -94,6 +95,7 @@ def accept_v(self, visitor: VariantVisitor[T]) -> T: @dataclasses.dataclass(frozen=True) class HashVariant(Variant): hash_types: t.Set[StructRef] + len_type: t.Optional[Int] def types(self) -> t.Iterable[StructRef]: return self.hash_types @@ -280,11 +282,11 @@ def accept(self, visitor: LengthVisitor[T]) -> T: class DetachedVariant(Type): variant: VariantRef tag: FieldReference + is_hash_tag: bool def accept(self, visitor: TypeVisitor[T]) -> T: return visitor.visit_detached_variant(self) - @dataclasses.dataclass(frozen=True) class Virtual(Type): inner: Type @@ -338,6 +340,15 @@ def resolve(self, context: t.Dict[QName, RootType]) -> Variant: def accept_r(self, visitor: RefVisitor[T]) -> T: return visitor.visit_variant_ref(self) +@dataclasses.dataclass(frozen=True) +class HashVariantRef(VariantRef): + def resolve(self, context: t.Dict[QName, RootType]) -> Variant: + return checked_cast(Variant, context[self.name]) + + def accept_r(self, visitor: RefVisitor[T]) -> T: + # TODO clean up this type warning + return visitor.visit_variant_ref(self) + @dataclasses.dataclass(frozen=True) class EnumRef(Ref): diff --git a/python/tako/core/compiler/types/size.py b/python/tako/core/compiler/types/size.py index 5652ac5..6275a48 100644 --- a/python/tako/core/compiler/types/size.py +++ b/python/tako/core/compiler/types/size.py @@ -63,8 +63,13 @@ def visit_struct(self, root: mir.Struct) -> RootSizeResult: def visit_variant(self, root: mir.Variant) -> RootSizeResult: target_size: t.Optional[st.Constant] = None + # Mark + if isinstance(root, mir.FixedVariant) and root.len_type: + print(f"Detected variant with embedded length: {root.name}") + return RootSizeResult(st.Dynamic(), None) for sr in root.types(): size = self.size_map[sr.name].size + print(f"Contained size of {root.name} member {sr.name} is {size}") if not isinstance(size, st.Constant): return RootSizeResult(st.Dynamic(), None) if target_size is None: diff --git a/python/tako/core/compiler/types/variant_expand.py b/python/tako/core/compiler/types/variant_expand.py index 99d0222..4b17a6d 100644 --- a/python/tako/core/compiler/types/variant_expand.py +++ b/python/tako/core/compiler/types/variant_expand.py @@ -34,11 +34,17 @@ class VariantExpand(mir.RootTypeVisitor[t.Optional[mir.Struct]]): def visit_struct(self, root: mir.Struct) -> t.Optional[mir.Struct]: new_fields: t.Dict[str, mir.Type] = {} for fname, ftype in root.fields.items(): - if isinstance(ftype, mir.VariantRef): + if isinstance(ftype, mir.HashVariantRef): new_fname = f"{fname}_injected_key_" new_fields[new_fname] = ftype.resolve(self.types).tag_type new_fields[fname] = mir.DetachedVariant( - variant=ftype, tag=mir.FieldReference(new_fname) + variant=ftype, tag=mir.FieldReference(new_fname), is_hash_tag = True + ) + elif isinstance(ftype, mir.VariantRef): + new_fname = f"{fname}_injected_key_" + new_fields[new_fname] = ftype.resolve(self.types).tag_type + new_fields[fname] = mir.DetachedVariant( + variant=ftype, tag=mir.FieldReference(new_fname), is_hash_tag = False ) else: new_fields[fname] = ftype diff --git a/python/tako/core/repr_str.py b/python/tako/core/repr_str.py index ecccdce..3e7d528 100644 --- a/python/tako/core/repr_str.py +++ b/python/tako/core/repr_str.py @@ -120,3 +120,91 @@ def visit_enum(self, root: mir.Enum) -> str: pairs = sorted([(value, name) for name, value in root.variants.items()]) variants = ",".join([f"{value}={name}" for value, name in pairs]) return f"Enum(name={root.name},underlying={root.underlying_type.accept(self)},variants={{{variants}}})" + +## This is used to construct the shallow hash. +@dataclasses.dataclass +class ShallowReprStr( + mir.RootTypeVisitor[str], + mir.VariantVisitor[str], + mir.SeqTypeVisitor[str], + mir.LengthVisitor[str], +): + types: t.Dict[QName, mir.RootType] + + def visit_int(self, type_: mir.Int) -> str: + return f"Int(width={type_.width},sign={type_.sign.name},endianness={type_.endianness.name})" + + def visit_float(self, type_: mir.Float) -> str: + return f"Float(width={type_.width},endianness={type_.endianness.name})" + + def repr_field_reference(self, fr: mir.FieldReference) -> str: + return f"FieldReference(name={fr.name})" + + def visit_seq(self, type_: mir.Seq) -> str: + # Note that the repr is done on the type before seq reduce + # -- there is only Seq, but not List, Vector, or Array. + # This is deliberate - the representation of the type is the intended to + # be a representation that is as simple as possible, but conveys everything + # needed to parse or serialize a type from a wire representation. + # List, Vector, and Array do not impact the wire representation, and hence + # are not used. + length = type_.length.accept(self) + return f"Seq(inner={type_.inner.accept(self)},length={length})" + + def visit_unbound_seq(self, type_: mir.UnboundSeq) -> str: + raise InternalError() + + def visit_fixed_length(self, length: mir.FixedLength) -> str: + return f"{length.length}" + + def visit_variable_length(self, length: mir.VariableLength) -> str: + return self.repr_field_reference(length.length) + + def visit_detached_variant(self, type_: mir.DetachedVariant) -> str: + variant = type_.variant.resolve(self.types) + return f"DetachedVariant(variant={variant.accept(self)},tag={self.repr_field_reference(type_.tag)})" + + def visit_virtual(self, type_: mir.Virtual) -> str: + # Virtual types contribute to the hash just like normal types, even though they have no + # effect on the wire representation. They allow a type to represent that there is + # some other data on the wire related to it -- virtual fields are parsed + # using the context of the rest of the type. + return f"Virtual(inner={type_.inner.accept(self)})" + + def visit_ref(self, type_: mir.Ref) -> str: + return type_.resolve(self.types).accept(self) + + def visit_struct(self, root: mir.Struct) -> str: + # Including the name of the fields is critical -- otherwise the struct Foo { x: int, y: int } + # is the same as the struct Foo { y: int, x: int }. + # Names provide meaning to fields. + # Note that no derrived information, like the size or offset of fields is included. That could all + # be computed from this and is extraneous. + l = [] + for name, type_ in root.fields.items(): + if isinstance(type_, mir.DetachedVariant) and type_.is_hash_tag: + l.append(f"{name}=HashVariant") + else: + l.append(f"{name}={type_.accept(self)}") + fields = ",".join(l) + return f"Struct(name={root.name},fields={{{fields}}})" + + def visit_variant(self, root: mir.Variant) -> str: + return root.accept_v(self) + + def visit_fixed_variant(self, root: mir.FixedVariant) -> str: + # Sort the variant by tag to ensure order doesn't matter + pairs = sorted([(tag, sr) for sr, tag in root.tags.items()]) + variants = ",".join( + [f"{tag}={value.resolve(self.types).accept(self)}" for tag, value in pairs] + ) + return f"Variant(name={root.name},tag_type={root.tag_type.accept(self)},variants={{{variants}}})" + + def visit_hash_variant(self, root: mir.HashVariant) -> str: + raise InternalError() + + def visit_enum(self, root: mir.Enum) -> str: + # Like variant + pairs = sorted([(value, name) for name, value in root.variants.items()]) + variants = ",".join([f"{value}={name}" for value, name in pairs]) + return f"Enum(name={root.name},underlying={root.underlying_type.accept(self)},variants={{{variants}}})" diff --git a/python/tako/core/sir.py b/python/tako/core/sir.py index 6836375..7a70d93 100644 --- a/python/tako/core/sir.py +++ b/python/tako/core/sir.py @@ -23,7 +23,6 @@ kir = constants.lir cir = conversions.lir - @dataclasses.dataclass(frozen=True) class Protocol: name: QName diff --git a/python/tako/core/types.py b/python/tako/core/types.py index 20743d5..e7d32e2 100644 --- a/python/tako/core/types.py +++ b/python/tako/core/types.py @@ -268,6 +268,7 @@ def make_variant(variants: t.Dict[StructDef, int]) -> VariantDef: class HashVariantDef(RootType): tag_type: Int hash_types: t.List[StructDef] + len_type: t.Optional[Int] def accept(self, visitor: TypeVisitor[T]) -> T: return visitor.visit_hash_variant_def(self) @@ -275,19 +276,27 @@ def accept(self, visitor: TypeVisitor[T]) -> T: def accept_rtv(self, visitor: RootTypeVisitor[T]) -> T: return visitor.visit_hash_variant_def(self) - class HashVariantHelper: def __getitem__( self, tag_type: Int ) -> t.Callable[[t.List[StructDef]], HashVariantDef]: def make_hash_variant(hash_types: t.List[StructDef]) -> HashVariantDef: - return HashVariantDef(tag_type, hash_types) + return HashVariantDef(tag_type, hash_types, None) + + return make_hash_variant + +class SkippableHashVariantHelper: + def __getitem__( + self, tag_len_type: t.Tuple[Int, ...] + ) -> t.Callable[[t.List[StructDef]], HashVariantDef]: + def make_hash_variant(hash_types: t.List[StructDef]) -> HashVariantDef: + return HashVariantDef(tag_len_type[0], hash_types, tag_len_type[1]) return make_hash_variant HashVariant = HashVariantHelper() - +SkippableHashVariant = SkippableHashVariantHelper() @dataclasses.dataclass(eq=False) class Seq(Type): diff --git a/python/tako/generators/lsir/lsir.py b/python/tako/generators/lsir/lsir.py index 6dc1913..f86379d 100644 --- a/python/tako/generators/lsir/lsir.py +++ b/python/tako/generators/lsir/lsir.py @@ -46,6 +46,7 @@ def generate_into(self, proto: Protocol, out_dir: Path, args: t.Any) -> None: with (out_dir / Path(str(proto.name) + ".json")).open("w") as out: json.dump(dform, out, indent=4) + # TODO this should produce a real value def list_outputs( self, proto_qname: QName, args: t.Any ) -> t.Generator[Path, None, None]: @@ -104,6 +105,7 @@ def visit_variant(self, root: tir.Variant) -> t.Dict[str, t.Any]: root, { "tag_type": root.tag_type.accept(TypeLsir()), + "len_type": root.len_type.accept(TypeLsir()) if root.len_type else None, "variants": { f"{struct.name}": value for struct, value in root.tags.items() }, diff --git a/python/test_types/actor/__init__.py b/python/test_types/actor/__init__.py new file mode 100644 index 0000000..aa37bc0 --- /dev/null +++ b/python/test_types/actor/__init__.py @@ -0,0 +1,9 @@ +from test_types.bakery import Bakery +from test_types.bakery.v1 import V1 +from test_types.bakery.v2 import V2 + +class BakeryActor: + ## TODO fix field access + ## TODO fix variant access + ## Today + Produces = [Bakery.Packet.payload[V2.Message]] diff --git a/python/test_types/actor/actor_compiler.py b/python/test_types/actor/actor_compiler.py new file mode 100644 index 0000000..1637f34 --- /dev/null +++ b/python/test_types/actor/actor_compiler.py @@ -0,0 +1,29 @@ +import importlib +from test_types.actor import BakeryActor + +module = "test_types.actor" + +live_module = importlib.import_module(module) + +a = BakeryActor.Produces[0] + +def print_keys(a): + [print(k) for k, v in a.__dict__.items()] +print("*******************************************************************") +print("*******************************************************************") +print("*******************************************************************") +print("*******************************************************************") +print("*******************************************************************") +print("Produces") +print(a) +## This synatx has to be better +print("Version1") +print_keys(BakeryActor.Produces[0].fields['payload']) +print("Version2") +print(BakeryActor.Produces[0].payload) +# b = BakeryActor.Consumes[0] +# +# print("Consumes") +# print(b) + + diff --git a/python/test_types/bakery/v4.py b/python/test_types/bakery/v4.py index 935dd40..d77ace1 100644 --- a/python/test_types/bakery/v4.py +++ b/python/test_types/bakery/v4.py @@ -27,9 +27,11 @@ class V4(Protocol): # Frosting flavor is actually the most important # I can't believe I forgot it frosting_flavor=Flavor, + # Forgot about the number of sprinkles + # sprinkle_quantity=li32, ) CakeOrder = Struct(layers=li32, shape=Shape, flavor=Flavor) - Order = Variant[u8]({CupcakeOrder: 0, CakeOrder: 1}) + Order = HashVariant[u8]([ CupcakeOrder, CakeOrder]) ErrorResponse = Struct() NewOrderRequest = Struct(name_len=u8, name=Seq(i8, this.name_len), order=Order) @@ -47,14 +49,3 @@ class V4(Protocol): } ) Message = Struct(msg=MessageVariant) - - conversions = [ - ConversionsFromPrior( - Prior, - VariantConversion( - src=MessageVariant, - target=Prior.MessageVariant, - mapping={CancelOrderRequest: None, CancelOrderResponse: None}, - ), - ) - ] diff --git a/python/test_types/organism.py b/python/test_types/organism.py new file mode 100644 index 0000000..b9cf1e6 --- /dev/null +++ b/python/test_types/organism.py @@ -0,0 +1,15 @@ +from tako.core.types import * + +class Organism(Protocol): + Felidae = Struct(num_lives=i8) + Canidae = Struct(balls_caught=i8) + Mammalia = HashVariant[li32]([Felidae, Canidae]) + Chordata = Struct(family=Mammalia, symmetry=i8) + + Phylum = HashVariant[li32]([Chordata]) + Animalia = Struct(motility_type=i8, phylum=Phylum) + # Attempt to make a dynamic length member of a hash variant, does it embed the length somehow? + Plantae = Struct(seed_count=li32, seeds = Seq(i8, this.seed_count)) + + Kingdom = SkippableHashVariant[li32, u8]([Animalia, Plantae]) + Organism = Struct(kingdom=Kingdom) diff --git a/shallow_hash.txt b/shallow_hash.txt new file mode 100644 index 0000000..bbb43ef --- /dev/null +++ b/shallow_hash.txt @@ -0,0 +1,31 @@ +[](const ::test_types::bakery::v4::Message&) { return static_cast<::std::uint32_t>(UINT32_C(2458638600)); } + +This is the hash of a v4 message, which inside has a hash variant. Hopefully it changes when i add to the inner HV + +[](const ::test_types::bakery::v4::Message&) { return static_cast<::std::uint32_t>(UINT32_C(3267984003)); } + +Indeed it does. + +I made ShallowReprHash. This might not be the most efficient way to do it but initially I just want to see it preserve the hashes + +[](const ::test_types::bakery::v4::Message&) { return static_cast<::std::uint32_t>(UINT32_C(2458638600)); } + +Seems like it's back to the old hash when I remove the new sprinkle_quantity from the cupcakeorder, great. + +Ok that was more work than expected, but before: + + 5 [](const ::test_types::bakery::v4::Message&) { return static_cast<::std::uint32_t>(UINT32_C(390615788)); } + +after changing a field on a struct only present in an HV + 5 [](const ::test_types::bakery::v4::Message&) { return static_cast<::std::uint32_t>(UINT32_C(390615788)); } + +after changing a field on a struct elsewhere in v4 +::std::uint32_t tag() const { + 1 return match( + 2 [](const ::test_types::bakery::v2::Message&) { return static_cast<::std::uint32_t>(UINT32_C(3831964682)); }, + 3 [](const ::test_types::bakery::v1::Message&) { return static_cast<::std::uint32_t>(UINT32_C(2782154402)); }, + 4 [](const ::test_types::bakery::v3::Message&) { return static_cast<::std::uint32_t>(UINT32_C(716972100)); }, + 5 [](const ::test_types::bakery::v4::Message&) { return static_cast<::std::uint32_t>(UINT32_C(1691428639)); } + 6 ); + +after changing a field on a struct elsewhere in v2