diff --git a/trees/cosigned/message.go b/trees/cosigned/message.go new file mode 100644 index 00000000000..8b2277630de --- /dev/null +++ b/trees/cosigned/message.go @@ -0,0 +1,122 @@ +// Package cosigned implements CosignedMessage from +// https://ietf-plants-wg.github.io/merkle-tree-certs/draft-ietf-plants-merkle-tree-certs.html#section-5.3.1. +package cosigned + +import ( + "bytes" + "crypto/sha256" + "errors" + "fmt" + + "golang.org/x/crypto/cryptobyte" +) + +// Message represents a CosignedMessage from +// https://ietf-plants-wg.github.io/merkle-tree-certs/draft-ietf-plants-merkle-tree-certs.html#section-5.3.1. +type Message struct { + CosignerName string + Timestamp uint64 + LogOrigin string + Start uint64 + End uint64 + SubtreeHash [sha256.Size]byte +} + +const subtreeLabel = "subtree/v1\n\x00" + +// Marshal encodes the Message as bytes. +// +// It errors if cosigner_name or log_origin are too long or too short. It does not validate semantic constraints, +// like start < end. +// +// https://ietf-plants-wg.github.io/merkle-tree-certs/draft-ietf-plants-merkle-tree-certs.html#section-5.3.1 +// opaque HashValue[HASH_SIZE]; +// +// struct { +// uint8 label[12] = "subtree/v1\n\0"; +// opaque cosigner_name<1..2^8-1>; +// uint64 timestamp; +// opaque log_origin<1..2^8-1>; +// uint64 start; +// uint64 end; +// HashValue subtree_hash; +// } CosignedMessage; +func (message *Message) Marshal() ([]byte, error) { + if len(message.CosignerName) < 1 || len(message.CosignerName) > 255 { + return nil, fmt.Errorf("invalid cosigner_name length %d", len(message.CosignerName)) + } + if len(message.LogOrigin) < 1 || len(message.LogOrigin) > 255 { + return nil, fmt.Errorf("invalid log_origin length %d", len(message.LogOrigin)) + } + + var b cryptobyte.Builder + b.AddBytes([]byte(subtreeLabel)) + b.AddUint8LengthPrefixed(func(child *cryptobyte.Builder) { + child.AddBytes([]byte(message.CosignerName)) + }) + b.AddUint64(message.Timestamp) + b.AddUint8LengthPrefixed(func(child *cryptobyte.Builder) { + child.AddBytes([]byte(message.LogOrigin)) + }) + b.AddUint64(message.Start) + b.AddUint64(message.End) + b.AddBytes(message.SubtreeHash[:]) + + return b.Bytes() +} + +// Unmarshal unmarshals the input bytes and returns a *Message. +func Unmarshal(input []byte) (*Message, error) { + var out Message + + s := cryptobyte.String(input) + var label []byte + if !s.ReadBytes(&label, len(subtreeLabel)) { + return nil, errors.New("invalid label") + } + if !bytes.Equal(label, []byte(subtreeLabel)) { + return nil, errors.New("label was not subtree/v1") + } + + var cosignerName cryptobyte.String + if !s.ReadUint8LengthPrefixed(&cosignerName) { + return nil, errors.New("invalid cosigner_name") + } + if len(cosignerName) < 1 { + return nil, errors.New("empty cosigner_name") + } + out.CosignerName = string(cosignerName) + + if !s.ReadUint64(&out.Timestamp) { + return nil, errors.New("invalid timestamp") + } + + var logOrigin cryptobyte.String + if !s.ReadUint8LengthPrefixed(&logOrigin) { + return nil, errors.New("invalid log_origin") + } + if len(logOrigin) < 1 { + return nil, errors.New("empty log_origin") + } + out.LogOrigin = string(logOrigin) + + if !s.ReadUint64(&out.Start) { + return nil, errors.New("invalid start") + } + + if !s.ReadUint64(&out.End) { + return nil, errors.New("invalid end") + } + + var subtreeHash []byte + if !s.ReadBytes(&subtreeHash, len(out.SubtreeHash)) { + return nil, errors.New("invalid subtree hash") + } + copy(out.SubtreeHash[:], subtreeHash) + + if !s.Empty() { + return nil, errors.New("trailing bytes") + } + + return &out, nil +} diff --git a/trees/cosigned/message_test.go b/trees/cosigned/message_test.go new file mode 100644 index 00000000000..b6503c21b5d --- /dev/null +++ b/trees/cosigned/message_test.go @@ -0,0 +1,126 @@ +package cosigned + +import ( + "encoding/hex" + "reflect" + "strings" + "testing" +) + +func TestMessageRoundtrip(t *testing.T) { + m := Message{ + CosignerName: "alpha", + Timestamp: 1234, + LogOrigin: "beta", + Start: 999, + End: 1000, + SubtreeHash: [32]byte{}, + } + + copy(m.SubtreeHash[:], []byte("0123456789abcdef0123456789abcdef")) + + out, err := m.Marshal() + if err != nil { + t.Fatalf("marshaling: %s", err) + } + + m2, err := Unmarshal(out) + if err != nil { + t.Fatalf("unmarshaling encoded message: %s", err) + } + + if !reflect.DeepEqual(m, *m2) { + t.Errorf("round-tripping message: got %#v, want %#v", m, *m2) + } +} + +func TestMarshalErrors(t *testing.T) { + m := Message{ + CosignerName: "Michigan", + Timestamp: 1337000, + LogOrigin: "Illinois", + Start: 9, + End: 87654321, + SubtreeHash: [32]byte{}, + } + + type testCase struct { + name, expected string + distorter func(target *Message) + } + + testCases := []testCase{ + {"short CosignerName", "invalid cosigner_name length 0", func(target *Message) { + target.CosignerName = "" + }}, + {"long CosignerName", "invalid cosigner_name length 256", func(target *Message) { + target.CosignerName = strings.Repeat("a", 256) + }}, + {"short LogOrigin", "invalid log_origin length 0", func(target *Message) { + target.LogOrigin = "" + }}, + {"long LogOrigin", "invalid log_origin length 256", func(target *Message) { + target.LogOrigin = strings.Repeat("a", 256) + }}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + m2 := m + tc.distorter(&m2) + _, err := m2.Marshal() + if err == nil { + t.Fatalf("got no error, want %q", tc.expected) + } + if err.Error() != tc.expected { + t.Errorf("marshal with short name: got %q, want %q", err, tc.expected) + } + }) + } +} + +func TestUnmarshalErrors(t *testing.T) { + m := Message{ + CosignerName: "Debut", + Timestamp: 55555, + LogOrigin: "Post", + Start: 11, + End: 22, + SubtreeHash: [32]byte{}, + } + + out, err := m.Marshal() + if err != nil { + t.Fatalf("marshal: %s", err) + } + t.Logf("%x", out) + + _, err = Unmarshal(out[:len(out)-1]) + if err == nil { + t.Errorf("unmarshal with short input: got no error") + } + + long := append(out, byte('a')) + _, err = Unmarshal(long) + if err == nil { + t.Errorf("unmarshal with trailing bytes: got no error") + } + + emptyCosigner, err := hex.DecodeString("737562747265652f76310a0000000000000000d90304506f7374000000000000000b00000000000000160000000000000000000000000000000000000000000000000000000000000000") + if err != nil { + t.Errorf("decoding hex: %s", err) + } + _, err = Unmarshal(emptyCosigner) + if err == nil { + t.Errorf("unmarshal with empty cosigner_name: got no error") + } + + emptyLogOrigin, err := hex.DecodeString("737562747265652f76310a00054465627574000000000000d90300000000000000000b00000000000000160000000000000000000000000000000000000000000000000000000000000000") + if err != nil { + t.Errorf("decoding hex: %s", err) + } + _, err = Unmarshal(emptyLogOrigin) + if err == nil { + t.Errorf("unmarshal with empty log_origin: got no error") + } +}