Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tako/core/compiler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def compile_proto(
if isinstance(comp_constants, list):
return comp_constants

# Q: Are conversions affected by skippable hash variants in any meaningful way?
comp_conversions = conversions.compile(
proto.name, proto.conversions, comp_types.types
)
Expand Down
5 changes: 5 additions & 0 deletions python/tako/core/compiler/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def compile(
if errors:
return errors

# Builds up a map from master to slave fields. This is the sort of reverse
# action that takes place during parsing. At parse time, a Seq's length is first
# read, then the sequence length is known. At build time, the Seq is built first
# and the previous field is later populated. This master_field_map seems focused
# on the "builder" problem
master_field_map = master_fields.run(lowered, type_order)
if isinstance(master_field_map, list):
return master_field_map
Expand Down
1 change: 1 addition & 0 deletions python/tako/core/compiler/types/fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def visit_fixed_variant(self, variant: mir.FixedVariant) -> lir.RootType:
trivial=self.tmap[variant.name],
name=variant.name,
digest=self.digest_map[variant.name],
len_type=checked_cast(lir.Int, variant.len_type.accept(self)) if variant.len_type else None,
tag_type=checked_cast(lir.Int, variant.tag_type.accept(self)),
tags={
checked_cast(lir.Struct, sr.accept(self)): value
Expand Down
18 changes: 9 additions & 9 deletions python/tako/core/compiler/types/hash_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import dataclasses
from tako.core.compiler.types import mir
from tako.core.error import Error
from tako.core.repr_str import ReprStr
from tako.core.repr_str import ShallowReprStr
from tako.util.qname import QName
import hashlib

Expand Down Expand Up @@ -58,20 +58,20 @@ class HashExpand(
mir.VariantVisitor[t.Union[mir.FixedVariant, Error]],
):
types: t.Dict[QName, mir.RootType]

def digest(self, rt: mir.RootType) -> Digest:
x = rt.accept(ReprStr(self.types))
def shallow_digest(self, rt: mir.RootType) -> Digest:
x = rt.accept(ShallowReprStr(self.types))
return Digest(repr_str=x, repr_hash=sha256hex(x))

def visit_struct(self, root: mir.Struct) -> t.Union[HashExpandResult, Error]:
return HashExpandResult(self.digest(root), None)
return HashExpandResult(self.shallow_digest(root), None)

def visit_variant(self, root: mir.Variant) -> t.Union[HashExpandResult, Error]:
fixed = root.accept_v(self)
if isinstance(fixed, Error):
return fixed
else:
return HashExpandResult(self.digest(fixed), fixed)
return HashExpandResult(self.shallow_digest(fixed), fixed)

def visit_fixed_variant(
self, variant: mir.FixedVariant
Expand All @@ -87,7 +87,7 @@ def visit_hash_variant(
tag_map: t.Dict[mir.StructRef, int] = {}
inv_tag_map: t.Dict[int, mir.StructRef] = {}
for type_ in variant.types():
hash_hex = self.digest(type_.resolve(self.types)).repr_hash
hash_hex = self.shallow_digest(type_.resolve(self.types)).repr_hash
short = int(hash_hex[:tag_width_hex_digits], 16)
if short in inv_tag_map:
return Error(
Expand All @@ -96,7 +96,7 @@ def visit_hash_variant(
tag_map[type_] = short
inv_tag_map[short] = type_

return mir.FixedVariant(variant.name, variant.tag_type, tag_map)
return mir.FixedVariant(variant.name, variant.tag_type, tag_map, variant.len_type)

def visit_enum(self, root: mir.Enum) -> t.Union[HashExpandResult, Error]:
return HashExpandResult(self.digest(root), None)
return HashExpandResult(self.shallow_digest(root), None)
1 change: 1 addition & 0 deletions python/tako/core/compiler/types/lir.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ class MasterField:
@dataclasses.dataclass(frozen=True)
class Variant(RootType):
tag_type: Int
len_type: Optional[Int]
tags: t.Dict[Struct, int] = dataclasses.field(compare=False)

def accept_rtv(self, visitor: RootTypeVisitor[T]) -> T:
Expand Down
10 changes: 8 additions & 2 deletions python/tako/core/compiler/types/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,23 @@ def visit_variant_def(self, type_: pt.VariantDef) -> mir.RootType:
checked_cast(mir.StructRef, struct.accept(Lower())): value
for struct, value in type_.variants.items()
},
False
)

def visit_hash_variant_def(self, type_: pt.HashVariantDef) -> mir.RootType:
# Cast to the mir int types
mir_tag_type = checked_cast(mir.Int, type_.tag_type.accept(Lower()))
mir_len_type = checked_cast(mir.Int, type_.len_type.accept(Lower())) if type_.len_type else None
return mir.HashVariant(
type_.qualified_name(),
checked_cast(mir.Int, type_.tag_type.accept(Lower())),
mir_tag_type,
set(
[
checked_cast(mir.StructRef, struct.accept(Lower()))
for struct in type_.hash_types
]
),
mir_len_type
)


Expand Down Expand Up @@ -88,6 +93,7 @@ def visit_detached_variant(self, type_: pt.DetachedVariant) -> mir.Type:
return mir.DetachedVariant(
mir.VariantRef(type_.variant.qualified_name()),
mir.FieldReference(type_.tag.name),
False
)

def visit_virtual(self, type_: pt.Virtual) -> mir.Type:
Expand All @@ -103,4 +109,4 @@ def visit_variant_def(self, type_: pt.VariantDef) -> mir.Type:
return mir.VariantRef(type_.qualified_name())

def visit_hash_variant_def(self, type_: pt.HashVariantDef) -> mir.Type:
return mir.VariantRef(type_.qualified_name())
return mir.HashVariantRef(type_.qualified_name())
3 changes: 1 addition & 2 deletions python/tako/core/compiler/types/master_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ class MasterField:
master_field: str
key_property: KeyProperty


@dataclasses.dataclass(frozen=True)
class DeterminedField:
determined_field: str
Expand Down Expand Up @@ -79,7 +78,7 @@ def visit_enum(

# Map from a field A to a field B where the value of A (slave field) is determined by the the value
# of field B (master field).
# For example, if a struct has these fields: {"len": li32, "data", Seq(li32, this.len)},
# For example, if a struct has these fields: {"len": li32, "data": Seq(li32, this.len)},
# then this function would return {"len": "data"} because the value of len is determined
# by the value in data.
# This is needed for generating builders: should a given field be included in the generated
Expand Down
13 changes: 12 additions & 1 deletion python/tako/core/compiler/types/mir.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def visit_hash_variant(self, variant: HashVariant) -> T:
@dataclasses.dataclass(frozen=True)
class FixedVariant(Variant):
tags: t.Dict[StructRef, int]
len_type: t.Optional[Int]

def types(self) -> t.Iterable[StructRef]:
return self.tags.keys()
Expand All @@ -94,6 +95,7 @@ def accept_v(self, visitor: VariantVisitor[T]) -> T:
@dataclasses.dataclass(frozen=True)
class HashVariant(Variant):
hash_types: t.Set[StructRef]
len_type: t.Optional[Int]

def types(self) -> t.Iterable[StructRef]:
return self.hash_types
Expand Down Expand Up @@ -280,11 +282,11 @@ def accept(self, visitor: LengthVisitor[T]) -> T:
class DetachedVariant(Type):
variant: VariantRef
tag: FieldReference
is_hash_tag: bool

def accept(self, visitor: TypeVisitor[T]) -> T:
return visitor.visit_detached_variant(self)


@dataclasses.dataclass(frozen=True)
class Virtual(Type):
inner: Type
Expand Down Expand Up @@ -338,6 +340,15 @@ def resolve(self, context: t.Dict[QName, RootType]) -> Variant:
def accept_r(self, visitor: RefVisitor[T]) -> T:
return visitor.visit_variant_ref(self)

@dataclasses.dataclass(frozen=True)
class HashVariantRef(VariantRef):
def resolve(self, context: t.Dict[QName, RootType]) -> Variant:
return checked_cast(Variant, context[self.name])

def accept_r(self, visitor: RefVisitor[T]) -> T:
# TODO clean up this type warning
return visitor.visit_variant_ref(self)


@dataclasses.dataclass(frozen=True)
class EnumRef(Ref):
Expand Down
5 changes: 5 additions & 0 deletions python/tako/core/compiler/types/size.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,13 @@ def visit_struct(self, root: mir.Struct) -> RootSizeResult:

def visit_variant(self, root: mir.Variant) -> RootSizeResult:
target_size: t.Optional[st.Constant] = None
# Mark
if isinstance(root, mir.FixedVariant) and root.len_type:
print(f"Detected variant with embedded length: {root.name}")
return RootSizeResult(st.Dynamic(), None)
for sr in root.types():
size = self.size_map[sr.name].size
print(f"Contained size of {root.name} member {sr.name} is {size}")
if not isinstance(size, st.Constant):
return RootSizeResult(st.Dynamic(), None)
if target_size is None:
Expand Down
10 changes: 8 additions & 2 deletions python/tako/core/compiler/types/variant_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,17 @@ class VariantExpand(mir.RootTypeVisitor[t.Optional[mir.Struct]]):
def visit_struct(self, root: mir.Struct) -> t.Optional[mir.Struct]:
new_fields: t.Dict[str, mir.Type] = {}
for fname, ftype in root.fields.items():
if isinstance(ftype, mir.VariantRef):
if isinstance(ftype, mir.HashVariantRef):
new_fname = f"{fname}_injected_key_"
new_fields[new_fname] = ftype.resolve(self.types).tag_type
new_fields[fname] = mir.DetachedVariant(
variant=ftype, tag=mir.FieldReference(new_fname)
variant=ftype, tag=mir.FieldReference(new_fname), is_hash_tag = True
)
elif isinstance(ftype, mir.VariantRef):
new_fname = f"{fname}_injected_key_"
new_fields[new_fname] = ftype.resolve(self.types).tag_type
new_fields[fname] = mir.DetachedVariant(
variant=ftype, tag=mir.FieldReference(new_fname), is_hash_tag = False
)
else:
new_fields[fname] = ftype
Expand Down
88 changes: 88 additions & 0 deletions python/tako/core/repr_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,91 @@ def visit_enum(self, root: mir.Enum) -> str:
pairs = sorted([(value, name) for name, value in root.variants.items()])
variants = ",".join([f"{value}={name}" for value, name in pairs])
return f"Enum(name={root.name},underlying={root.underlying_type.accept(self)},variants={{{variants}}})"

## This is used to construct the shallow hash.
@dataclasses.dataclass
class ShallowReprStr(
mir.RootTypeVisitor[str],
mir.VariantVisitor[str],
mir.SeqTypeVisitor[str],
mir.LengthVisitor[str],
):
types: t.Dict[QName, mir.RootType]

def visit_int(self, type_: mir.Int) -> str:
return f"Int(width={type_.width},sign={type_.sign.name},endianness={type_.endianness.name})"

def visit_float(self, type_: mir.Float) -> str:
return f"Float(width={type_.width},endianness={type_.endianness.name})"

def repr_field_reference(self, fr: mir.FieldReference) -> str:
return f"FieldReference(name={fr.name})"

def visit_seq(self, type_: mir.Seq) -> str:
# Note that the repr is done on the type before seq reduce
# -- there is only Seq, but not List, Vector, or Array.
# This is deliberate - the representation of the type is the intended to
# be a representation that is as simple as possible, but conveys everything
# needed to parse or serialize a type from a wire representation.
# List, Vector, and Array do not impact the wire representation, and hence
# are not used.
length = type_.length.accept(self)
return f"Seq(inner={type_.inner.accept(self)},length={length})"

def visit_unbound_seq(self, type_: mir.UnboundSeq) -> str:
raise InternalError()

def visit_fixed_length(self, length: mir.FixedLength) -> str:
return f"{length.length}"

def visit_variable_length(self, length: mir.VariableLength) -> str:
return self.repr_field_reference(length.length)

def visit_detached_variant(self, type_: mir.DetachedVariant) -> str:
variant = type_.variant.resolve(self.types)
return f"DetachedVariant(variant={variant.accept(self)},tag={self.repr_field_reference(type_.tag)})"

def visit_virtual(self, type_: mir.Virtual) -> str:
# Virtual types contribute to the hash just like normal types, even though they have no
# effect on the wire representation. They allow a type to represent that there is
# some other data on the wire related to it -- virtual fields are parsed
# using the context of the rest of the type.
return f"Virtual(inner={type_.inner.accept(self)})"

def visit_ref(self, type_: mir.Ref) -> str:
return type_.resolve(self.types).accept(self)

def visit_struct(self, root: mir.Struct) -> str:
# Including the name of the fields is critical -- otherwise the struct Foo { x: int, y: int }
# is the same as the struct Foo { y: int, x: int }.
# Names provide meaning to fields.
# Note that no derrived information, like the size or offset of fields is included. That could all
# be computed from this and is extraneous.
l = []
for name, type_ in root.fields.items():
if isinstance(type_, mir.DetachedVariant) and type_.is_hash_tag:
l.append(f"{name}=HashVariant")
else:
l.append(f"{name}={type_.accept(self)}")
fields = ",".join(l)
return f"Struct(name={root.name},fields={{{fields}}})"

def visit_variant(self, root: mir.Variant) -> str:
return root.accept_v(self)

def visit_fixed_variant(self, root: mir.FixedVariant) -> str:
# Sort the variant by tag to ensure order doesn't matter
pairs = sorted([(tag, sr) for sr, tag in root.tags.items()])
variants = ",".join(
[f"{tag}={value.resolve(self.types).accept(self)}" for tag, value in pairs]
)
return f"Variant(name={root.name},tag_type={root.tag_type.accept(self)},variants={{{variants}}})"

def visit_hash_variant(self, root: mir.HashVariant) -> str:
raise InternalError()

def visit_enum(self, root: mir.Enum) -> str:
# Like variant
pairs = sorted([(value, name) for name, value in root.variants.items()])
variants = ",".join([f"{value}={name}" for value, name in pairs])
return f"Enum(name={root.name},underlying={root.underlying_type.accept(self)},variants={{{variants}}})"
1 change: 0 additions & 1 deletion python/tako/core/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
kir = constants.lir
cir = conversions.lir


@dataclasses.dataclass(frozen=True)
class Protocol:
name: QName
Expand Down
15 changes: 12 additions & 3 deletions python/tako/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,26 +268,35 @@ def make_variant(variants: t.Dict[StructDef, int]) -> VariantDef:
class HashVariantDef(RootType):
tag_type: Int
hash_types: t.List[StructDef]
len_type: t.Optional[Int]

def accept(self, visitor: TypeVisitor[T]) -> T:
return visitor.visit_hash_variant_def(self)

def accept_rtv(self, visitor: RootTypeVisitor[T]) -> T:
return visitor.visit_hash_variant_def(self)


class HashVariantHelper:
def __getitem__(
self, tag_type: Int
) -> t.Callable[[t.List[StructDef]], HashVariantDef]:
def make_hash_variant(hash_types: t.List[StructDef]) -> HashVariantDef:
return HashVariantDef(tag_type, hash_types)
return HashVariantDef(tag_type, hash_types, None)

return make_hash_variant

class SkippableHashVariantHelper:
def __getitem__(
self, tag_len_type: t.Tuple[Int, ...]
) -> t.Callable[[t.List[StructDef]], HashVariantDef]:
def make_hash_variant(hash_types: t.List[StructDef]) -> HashVariantDef:
return HashVariantDef(tag_len_type[0], hash_types, tag_len_type[1])

return make_hash_variant


HashVariant = HashVariantHelper()

SkippableHashVariant = SkippableHashVariantHelper()

@dataclasses.dataclass(eq=False)
class Seq(Type):
Expand Down
2 changes: 2 additions & 0 deletions python/tako/generators/lsir/lsir.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def generate_into(self, proto: Protocol, out_dir: Path, args: t.Any) -> None:
with (out_dir / Path(str(proto.name) + ".json")).open("w") as out:
json.dump(dform, out, indent=4)

# TODO this should produce a real value
def list_outputs(
self, proto_qname: QName, args: t.Any
) -> t.Generator[Path, None, None]:
Expand Down Expand Up @@ -104,6 +105,7 @@ def visit_variant(self, root: tir.Variant) -> t.Dict[str, t.Any]:
root,
{
"tag_type": root.tag_type.accept(TypeLsir()),
"len_type": root.len_type.accept(TypeLsir()) if root.len_type else None,
"variants": {
f"{struct.name}": value for struct, value in root.tags.items()
},
Expand Down
9 changes: 9 additions & 0 deletions python/test_types/actor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from test_types.bakery import Bakery
from test_types.bakery.v1 import V1
from test_types.bakery.v2 import V2

class BakeryActor:
## TODO fix field access
## TODO fix variant access
## Today
Produces = [Bakery.Packet.payload[V2.Message]]
Loading