Skip to content

Commit 3e1fb5a

Browse files
authored
fix: move produce tagged fields inside partition loop (#127)
1 parent 5162304 commit 3e1fb5a

3 files changed

Lines changed: 142 additions & 46 deletions

File tree

pkg/protocol/request.go

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -505,16 +505,16 @@ func ParseRequest(b []byte) (*RequestHeader, Request, error) {
505505
if err != nil {
506506
return nil, nil, fmt.Errorf("read produce records: %w", err)
507507
}
508+
if flexible {
509+
if err := reader.SkipTaggedFields(); err != nil {
510+
return nil, nil, fmt.Errorf("skip partition tags: %w", err)
511+
}
512+
}
508513
partitions = append(partitions, ProducePartition{
509514
Partition: index,
510515
Records: records,
511516
})
512517
}
513-
if flexible {
514-
if err := reader.SkipTaggedFields(); err != nil {
515-
return nil, nil, fmt.Errorf("skip partition tags: %w", err)
516-
}
517-
}
518518
if flexible {
519519
if err := reader.SkipTaggedFields(); err != nil {
520520
return nil, nil, fmt.Errorf("skip topic tags: %w", err)
@@ -1931,14 +1931,12 @@ func EncodeProduceRequest(header *RequestHeader, req *ProduceRequest, version in
19311931
} else {
19321932
w.BytesWithLength(part.Records)
19331933
}
1934+
if flexible {
1935+
w.WriteTaggedFields(0)
1936+
}
19341937
}
1935-
// Match the parser: two tagged-field blocks after the partition array.
1936-
// The Kafka protocol spec places per-partition tags inside the partition
1937-
// loop, but our parser (ParseRequest) reads them outside. Since the
1938-
// broker is always KafScale (not vanilla Kafka), this is intentional.
19391938
if flexible {
19401939
w.WriteTaggedFields(0)
1941-
w.WriteTaggedFields(0)
19421940
}
19431941
}
19441942
if flexible {

pkg/protocol/request_test.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,3 +1033,113 @@ func TestEncodeFetchRequest_KmsgValidation(t *testing.T) {
10331033
t.Fatalf("fetch offset: got %d, want 42", kmsgReq.Topics[0].Partitions[0].FetchOffset)
10341034
}
10351035
}
1036+
1037+
// TestProduceMultiPartitionFranzCompat tests byte-level compatibility with
1038+
// franz-go for multi-partition produce requests in both directions:
1039+
// - franz-go encodes → KafScale parses
1040+
// - KafScale encodes → franz-go decodes
1041+
func TestProduceMultiPartitionFranzCompat(t *testing.T) {
1042+
t.Run("franz-encode-kafscale-parse", func(t *testing.T) {
1043+
req := kmsg.NewPtrProduceRequest()
1044+
req.Version = 9
1045+
req.Acks = -1
1046+
req.TimeoutMillis = 3000
1047+
topic := kmsg.NewProduceRequestTopic()
1048+
topic.Topic = "orders"
1049+
for _, pi := range []int32{0, 1, 2} {
1050+
part := kmsg.NewProduceRequestTopicPartition()
1051+
part.Partition = pi
1052+
part.Records = []byte{byte(pi + 1), byte(pi + 2)}
1053+
topic.Partitions = append(topic.Partitions, part)
1054+
}
1055+
req.Topics = append(req.Topics, topic)
1056+
body := req.AppendTo(nil)
1057+
1058+
w := newByteWriter(len(body) + 32)
1059+
w.Int16(APIKeyProduce)
1060+
w.Int16(9)
1061+
w.Int32(55)
1062+
clientID := "kgo"
1063+
w.NullableString(&clientID)
1064+
w.WriteTaggedFields(0)
1065+
w.write(body)
1066+
1067+
_, parsed, err := ParseRequest(w.Bytes())
1068+
if err != nil {
1069+
t.Fatalf("ParseRequest: %v", err)
1070+
}
1071+
got, ok := parsed.(*ProduceRequest)
1072+
if !ok {
1073+
t.Fatalf("expected *ProduceRequest, got %T", parsed)
1074+
}
1075+
if len(got.Topics) != 1 {
1076+
t.Fatalf("topic count: got %d want 1", len(got.Topics))
1077+
}
1078+
if len(got.Topics[0].Partitions) != 3 {
1079+
t.Fatalf("partition count: got %d want 3", len(got.Topics[0].Partitions))
1080+
}
1081+
for pi, part := range got.Topics[0].Partitions {
1082+
if part.Partition != int32(pi) {
1083+
t.Fatalf("part[%d] index: got %d want %d", pi, part.Partition, pi)
1084+
}
1085+
want := []byte{byte(pi + 1), byte(pi + 2)}
1086+
if string(part.Records) != string(want) {
1087+
t.Fatalf("part[%d] records: got %x want %x", pi, part.Records, want)
1088+
}
1089+
}
1090+
})
1091+
1092+
t.Run("kafscale-encode-franz-parse", func(t *testing.T) {
1093+
header := &RequestHeader{
1094+
APIKey: APIKeyProduce,
1095+
APIVersion: 9,
1096+
CorrelationID: 66,
1097+
ClientID: strPtr("test"),
1098+
}
1099+
req := &ProduceRequest{
1100+
Acks: -1,
1101+
TimeoutMs: 3000,
1102+
Topics: []ProduceTopic{
1103+
{
1104+
Name: "orders",
1105+
Partitions: []ProducePartition{
1106+
{Partition: 0, Records: []byte{1, 2}},
1107+
{Partition: 1, Records: []byte{3, 4}},
1108+
{Partition: 2, Records: []byte{5, 6}},
1109+
},
1110+
},
1111+
},
1112+
}
1113+
encoded, err := EncodeProduceRequest(header, req, 9)
1114+
if err != nil {
1115+
t.Fatalf("encode: %v", err)
1116+
}
1117+
1118+
_, reader, err := ParseRequestHeader(encoded)
1119+
if err != nil {
1120+
t.Fatalf("ParseRequestHeader: %v", err)
1121+
}
1122+
bodyStart := len(encoded) - reader.remaining()
1123+
1124+
kmsgReq := kmsg.NewPtrProduceRequest()
1125+
kmsgReq.Version = 9
1126+
if err := kmsgReq.ReadFrom(encoded[bodyStart:]); err != nil {
1127+
t.Fatalf("kmsg.ReadFrom: %v", err)
1128+
}
1129+
if len(kmsgReq.Topics) != 1 {
1130+
t.Fatalf("topic count: got %d want 1", len(kmsgReq.Topics))
1131+
}
1132+
if len(kmsgReq.Topics[0].Partitions) != 3 {
1133+
t.Fatalf("partition count: got %d want 3", len(kmsgReq.Topics[0].Partitions))
1134+
}
1135+
for pi, part := range kmsgReq.Topics[0].Partitions {
1136+
if part.Partition != int32(pi) {
1137+
t.Fatalf("part[%d] index: got %d want %d", pi, part.Partition, pi)
1138+
}
1139+
want := []byte{byte(pi*2 + 1), byte(pi*2 + 2)}
1140+
if string(part.Records) != string(want) {
1141+
t.Fatalf("part[%d] records: got %x want %x", pi, part.Records, want)
1142+
}
1143+
}
1144+
})
1145+
}

pkg/protocol/response_test.go

Lines changed: 24 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1518,77 +1518,65 @@ func TestParseProduceResponseRoundTrip(t *testing.T) {
15181518
}
15191519

15201520
func TestEncodeProduceRequestRoundTrip(t *testing.T) {
1521-
header := &RequestHeader{
1522-
APIKey: APIKeyProduce,
1523-
APIVersion: 9,
1524-
CorrelationID: 77,
1525-
ClientID: strPtr("test-client"),
1526-
}
15271521
req := &ProduceRequest{
15281522
Acks: -1,
15291523
TimeoutMs: 5000,
15301524
Topics: []ProduceTopic{
15311525
{
15321526
Name: "orders",
15331527
Partitions: []ProducePartition{
1534-
{Partition: 0, Records: []byte{1, 2, 3, 4}},
1535-
{Partition: 1, Records: []byte{5, 6}},
1528+
{Partition: 0, Records: []byte{1, 2, 3}},
1529+
{Partition: 1, Records: []byte{4, 5}},
1530+
{Partition: 2, Records: []byte{6, 7, 8, 9}},
15361531
},
15371532
},
15381533
{
15391534
Name: "events",
15401535
Partitions: []ProducePartition{
1541-
{Partition: 0, Records: []byte{7, 8, 9}},
1536+
{Partition: 0, Records: []byte{10}},
1537+
{Partition: 3, Records: []byte{11, 12}},
15421538
},
15431539
},
15441540
},
15451541
}
15461542
for _, version := range []int16{3, 5, 7, 8, 9, 10} {
15471543
t.Run(fmt.Sprintf("v%d", version), func(t *testing.T) {
1548-
h := &RequestHeader{
1544+
header := &RequestHeader{
15491545
APIKey: APIKeyProduce,
15501546
APIVersion: version,
1551-
CorrelationID: header.CorrelationID,
1552-
ClientID: header.ClientID,
1547+
CorrelationID: 77,
1548+
ClientID: strPtr("test-client"),
15531549
}
1554-
encoded, err := EncodeProduceRequest(h, req, version)
1550+
encoded, err := EncodeProduceRequest(header, req, version)
15551551
if err != nil {
15561552
t.Fatalf("encode: %v", err)
15571553
}
1558-
parsedHeader, parsedReq, err := ParseRequest(encoded)
1554+
_, parsedReq, err := ParseRequest(encoded)
15591555
if err != nil {
15601556
t.Fatalf("parse: %v", err)
15611557
}
1562-
if parsedHeader.CorrelationID != h.CorrelationID {
1563-
t.Fatalf("correlation id: got %d want %d", parsedHeader.CorrelationID, h.CorrelationID)
1564-
}
1565-
produceReq, ok := parsedReq.(*ProduceRequest)
1558+
got, ok := parsedReq.(*ProduceRequest)
15661559
if !ok {
15671560
t.Fatalf("expected *ProduceRequest, got %T", parsedReq)
15681561
}
1569-
if produceReq.Acks != req.Acks {
1570-
t.Fatalf("acks: got %d want %d", produceReq.Acks, req.Acks)
1562+
if len(got.Topics) != len(req.Topics) {
1563+
t.Fatalf("topic count: got %d want %d", len(got.Topics), len(req.Topics))
15711564
}
1572-
if produceReq.TimeoutMs != req.TimeoutMs {
1573-
t.Fatalf("timeout: got %d want %d", produceReq.TimeoutMs, req.TimeoutMs)
1574-
}
1575-
if len(produceReq.Topics) != len(req.Topics) {
1576-
t.Fatalf("topic count: got %d want %d", len(produceReq.Topics), len(req.Topics))
1577-
}
1578-
for ti, topic := range produceReq.Topics {
1579-
if topic.Name != req.Topics[ti].Name {
1580-
t.Fatalf("topic[%d] name: got %q want %q", ti, topic.Name, req.Topics[ti].Name)
1565+
for ti, topic := range got.Topics {
1566+
want := req.Topics[ti]
1567+
if topic.Name != want.Name {
1568+
t.Fatalf("topic[%d] name: got %q want %q", ti, topic.Name, want.Name)
15811569
}
1582-
if len(topic.Partitions) != len(req.Topics[ti].Partitions) {
1583-
t.Fatalf("topic[%d] partition count: got %d want %d", ti, len(topic.Partitions), len(req.Topics[ti].Partitions))
1570+
if len(topic.Partitions) != len(want.Partitions) {
1571+
t.Fatalf("topic[%d] partition count: got %d want %d", ti, len(topic.Partitions), len(want.Partitions))
15841572
}
15851573
for pi, part := range topic.Partitions {
1586-
want := req.Topics[ti].Partitions[pi]
1587-
if part.Partition != want.Partition {
1588-
t.Fatalf("topic[%d].part[%d] index: got %d want %d", ti, pi, part.Partition, want.Partition)
1574+
wantPart := want.Partitions[pi]
1575+
if part.Partition != wantPart.Partition {
1576+
t.Fatalf("topic[%d].part[%d] index: got %d want %d", ti, pi, part.Partition, wantPart.Partition)
15891577
}
1590-
if string(part.Records) != string(want.Records) {
1591-
t.Fatalf("topic[%d].part[%d] records: got %v want %v", ti, pi, part.Records, want.Records)
1578+
if string(part.Records) != string(wantPart.Records) {
1579+
t.Fatalf("topic[%d].part[%d] records: got %x want %x", ti, pi, part.Records, wantPart.Records)
15921580
}
15931581
}
15941582
}

0 commit comments

Comments
 (0)