From ff7fc4c4e7ae2bca92faae3d22a0995a1eda59d7 Mon Sep 17 00:00:00 2001 From: Alex Dubov Date: Fri, 8 May 2026 18:56:23 +1000 Subject: [PATCH] RecordBuilder: add Resize method The newly added Resize method addresses two relatively common needs: 1. All fields within the Builder may need to be resized to the same length. 2. As part of an error recovery process, fields may need to be truncated to the length of the shortest one, effectively discarding incomplete "rows". RecordBuilder.NewRecordBatch now performs row length validation prior to any action, to ensure reusability of RecordBuilder across errors. This also prevents hard-to-reason-about memory leaks. Fixes #796 Adds a unit test for RecordBuilder.Resize --- arrow/array/json_reader_test.go | 30 ++++++++++++ arrow/array/record.go | 45 ++++++++++++++--- arrow/array/record_test.go | 86 +++++++++++++++++++++++++++++++++ 3 files changed, 154 insertions(+), 7 deletions(-) diff --git a/arrow/array/json_reader_test.go b/arrow/array/json_reader_test.go index d309b0c39..c1807a9f2 100644 --- a/arrow/array/json_reader_test.go +++ b/arrow/array/json_reader_test.go @@ -17,6 +17,7 @@ package array_test import ( + "bufio" "bytes" "fmt" "strings" @@ -231,6 +232,35 @@ func generateJSONData(n int) []byte { return data } +func ndjsonToRecordBuilder(t *testing.T, recordBuilder *array.RecordBuilder, data string) { + scanner := bufio.NewScanner(strings.NewReader(data)) + + for scanner.Scan() { + if len(scanner.Bytes()) > 0 { + err := recordBuilder.UnmarshalJSON(scanner.Bytes()) + assert.NoError(t, err) + } + } + + assert.NoError(t, scanner.Err()) +} + +func recordBatchToNDJSON(t *testing.T, rec arrow.RecordBatch) string { + var sb strings.Builder + + arr := array.RecordToStructArray(rec) + defer arr.Release() + + for pos := range arr.Len() { + s, err := json.Marshal(arr.GetOneForMarshal(pos)) + assert.NoError(t, err) + sb.Write(s) + sb.WriteByte('\n') + } + + return sb.String() +} + func jsonArrayToNDJSON(data []byte) ([]byte, error) { var 
records []json.RawMessage if err := json.Unmarshal(data, &records); err != nil { diff --git a/arrow/array/record.go b/arrow/array/record.go index 0aaf771b5..6e3d9dfca 100644 --- a/arrow/array/record.go +++ b/arrow/array/record.go @@ -348,6 +348,38 @@ func (b *RecordBuilder) Reserve(size int) { } } +func (b *RecordBuilder) columnLenRange() (lower, upper int) { + if len(b.fields) > 0 { + lower = b.fields[0].Len() + upper = lower + + for _, f := range b.fields[1:] { + lower = min(lower, f.Len()) + upper = max(upper, f.Len()) + } + } + return +} + +// Resize adjusts the space allocated by all the field builders to n elements. +// If n is greater than an individual builder's Cap(), additional memory will be +// allocated. If n is smaller, the allocated memory may be reduced. +// +// As a special case, if n equals -1, all field builders will be resized +// to the length of the shortest one. Any other negative n is a no-op. +func (b *RecordBuilder) Resize(n int) { + if n >= 0 { + for _, f := range b.fields { + f.Resize(n) + } + } else if n == -1 { + lower, upper := b.columnLenRange() + if lower != upper { + b.Resize(lower) + } + } +} + // NewRecordBatch creates a new record batch from the memory buffers and resets the // RecordBuilder so it can be used to build a new record batch. // @@ -355,8 +387,12 @@ func (b *RecordBuilder) Reserve(size int) { // // NewRecordBatch panics if the fields' builder do not have the same length. 
func (b *RecordBuilder) NewRecordBatch() arrow.RecordBatch { + lower, upper := b.columnLenRange() + if lower != upper { + panic(fmt.Errorf("arrow/array: some fields have excessive number of rows (want at most %d, have %d)", lower, upper)) + } + cols := make([]arrow.Array, len(b.fields)) - rows := int64(0) defer func(cols []arrow.Array) { for _, col := range cols { @@ -369,14 +405,9 @@ func (b *RecordBuilder) NewRecordBatch() arrow.RecordBatch { for i, f := range b.fields { cols[i] = f.NewArray() - irow := int64(cols[i].Len()) - if i > 0 && irow != rows { - panic(fmt.Errorf("arrow/array: field %d has %d rows. want=%d", i, irow, rows)) - } - rows = irow } - return NewRecordBatch(b.schema, cols, rows) + return NewRecordBatch(b.schema, cols, int64(lower)) } // Deprecated: Use [NewRecordBatch] instead. diff --git a/arrow/array/record_test.go b/arrow/array/record_test.go index 5900efe7f..a3924382a 100644 --- a/arrow/array/record_test.go +++ b/arrow/array/record_test.go @@ -19,6 +19,7 @@ package array_test import ( "fmt" "reflect" + "strings" "testing" "github.com/apache/arrow-go/v18/arrow" @@ -531,6 +532,91 @@ func TestRecordBuilder(t *testing.T) { } } +func TestRecordBuilderResize(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + schema := arrow.NewSchema([]arrow.Field{ + {Name: "region", Type: arrow.BinaryTypes.String, Nullable: true}, + {Name: "model", Type: arrow.BinaryTypes.String}, + {Name: "sales", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, + }, nil) + + t.Run("truncate", func(t *testing.T) { + recordBuilder := array.NewRecordBuilder(mem, schema) + defer recordBuilder.Release() + + ndjsonToRecordBuilder(t, recordBuilder, jsondata) + + rb0 := recordBuilder.NewRecordBatch() + assert.Equal(t, int64(3), rb0.NumCols()) + assert.Equal(t, int64(16), rb0.NumRows()) + + data0 := recordBatchToNDJSON(t, rb0) + rb0.Release() + + ndjsonToRecordBuilder(t, recordBuilder, jsondata) + recordBuilder.Resize(8) + + 
rb1 := recordBuilder.NewRecordBatch() + assert.Equal(t, int64(3), rb1.NumCols()) + assert.Equal(t, int64(8), rb1.NumRows()) + + data1 := recordBatchToNDJSON(t, rb1) + rb1.Release() + + assert.True(t, strings.HasPrefix(data0, data1)) + }) + + t.Run("truncate_incomplete", func(t *testing.T) { + recordBuilder := array.NewRecordBuilder(mem, schema) + defer recordBuilder.Release() + + ndjsonToRecordBuilder(t, recordBuilder, jsondata) + + rb0 := recordBuilder.NewRecordBatch() + assert.Equal(t, int64(3), rb0.NumCols()) + assert.Equal(t, int64(16), rb0.NumRows()) + + data0 := recordBatchToNDJSON(t, rb0) + rb0.Release() + + ndjsonToRecordBuilder(t, recordBuilder, jsondata) + recordBuilder.Field(0).(*array.StringBuilder).Append("TN") + recordBuilder.Field(0).(*array.StringBuilder).Append("IL") + recordBuilder.Field(0).(*array.StringBuilder).Append("WI") + + panicked := false + func() { + defer func() { + if r := recover(); r != nil { + if strings.HasPrefix(r.(error).Error(), "arrow/array:") { + panicked = true + } else { + panic(r) + } + } + }() + + rb1 := recordBuilder.NewRecordBatch() + rb1.Release() + }() + + assert.True(t, panicked) + + recordBuilder.Resize(-1) + + rb2 := recordBuilder.NewRecordBatch() + assert.Equal(t, int64(3), rb2.NumCols()) + assert.Equal(t, int64(16), rb2.NumRows()) + + data2 := recordBatchToNDJSON(t, rb2) + rb2.Release() + + assert.Equal(t, data0, data2) + }) +} + type testMessage struct { Foo *testMessageFoo Bars []*testMessageBar