diff --git a/arrow/array/json_reader_test.go b/arrow/array/json_reader_test.go index d309b0c3..c1807a9f 100644 --- a/arrow/array/json_reader_test.go +++ b/arrow/array/json_reader_test.go @@ -17,6 +17,7 @@ package array_test import ( + "bufio" "bytes" "fmt" "strings" @@ -231,6 +232,35 @@ func generateJSONData(n int) []byte { return data } +func ndjsonToRecordBuilder(t *testing.T, recordBuilder *array.RecordBuilder, data string) { + scanner := bufio.NewScanner(strings.NewReader(data)) + + for scanner.Scan() { + if len(scanner.Bytes()) > 0 { + err := recordBuilder.UnmarshalJSON(scanner.Bytes()) + assert.NoError(t, err) + } + } + + assert.NoError(t, scanner.Err()) +} + +func recordBatchToNDJSON(t *testing.T, rec arrow.RecordBatch) string { + var sb strings.Builder + + arr := array.RecordToStructArray(rec) + defer arr.Release() + + for pos := range arr.Len() { + s, err := json.Marshal(arr.GetOneForMarshal(pos)) + assert.NoError(t, err) + sb.Write(s) + sb.WriteByte('\n') + } + + return sb.String() +} + func jsonArrayToNDJSON(data []byte) ([]byte, error) { var records []json.RawMessage if err := json.Unmarshal(data, &records); err != nil { diff --git a/arrow/array/record.go b/arrow/array/record.go index 0aaf771b..6e3d9dfc 100644 --- a/arrow/array/record.go +++ b/arrow/array/record.go @@ -348,6 +348,38 @@ func (b *RecordBuilder) Reserve(size int) { } } +func (b *RecordBuilder) columnLenRange() (lower, upper int) { + if len(b.fields) > 0 { + lower = b.fields[0].Len() + upper = lower + + for _, f := range b.fields[1:] { + lower = min(lower, f.Len()) + upper = max(upper, f.Len()) + } + } + return +} + +// Resize adjusts the space allocated by all the field builders to n elements. +// If n is greater than an individual builder Cap(), additional memory will be +// allocated. If n is smaller, the allocated memory may reduced. +// +// As a special case, if n equals to -1, all field builders will be resized +// to the size of the shortest one. +func (b *RecordBuilder) Resize(n int) { + if n >= 0 { + for _, f := range b.fields { + f.Resize(n) + } + } else if n == -1 { + lower, upper := b.columnLenRange() + if lower != upper { + b.Resize(lower) + } + } +} + // NewRecordBatch creates a new record batch from the memory buffers and resets the // RecordBuilder so it can be used to build a new record batch. // @@ -355,8 +387,12 @@ func (b *RecordBuilder) Reserve(size int) { // // NewRecordBatch panics if the fields' builder do not have the same length. func (b *RecordBuilder) NewRecordBatch() arrow.RecordBatch { + lower, upper := b.columnLenRange() + if lower != upper { + panic(fmt.Errorf("arrow/array: some fields have excessive number of rows (want at most %d, have %d)", lower, upper)) + } + cols := make([]arrow.Array, len(b.fields)) - rows := int64(0) defer func(cols []arrow.Array) { for _, col := range cols { @@ -369,14 +405,9 @@ func (b *RecordBuilder) NewRecordBatch() arrow.RecordBatch { for i, f := range b.fields { cols[i] = f.NewArray() - irow := int64(cols[i].Len()) - if i > 0 && irow != rows { - panic(fmt.Errorf("arrow/array: field %d has %d rows. want=%d", i, irow, rows)) - } - rows = irow } - return NewRecordBatch(b.schema, cols, rows) + return NewRecordBatch(b.schema, cols, int64(lower)) } // Deprecated: Use [NewRecordBatch] instead. diff --git a/arrow/array/record_test.go b/arrow/array/record_test.go index 5900efe7..a3924382 100644 --- a/arrow/array/record_test.go +++ b/arrow/array/record_test.go @@ -19,6 +19,7 @@ package array_test import ( "fmt" "reflect" + "strings" "testing" "github.com/apache/arrow-go/v18/arrow" @@ -531,6 +532,91 @@ func TestRecordBuilder(t *testing.T) { } } +func TestRecordBuilderResize(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + schema := arrow.NewSchema([]arrow.Field{ + {Name: "region", Type: arrow.BinaryTypes.String, Nullable: true}, + {Name: "model", Type: arrow.BinaryTypes.String}, + {Name: "sales", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, + }, nil) + + t.Run("truncate", func(t *testing.T) { + recordBuilder := array.NewRecordBuilder(mem, schema) + defer recordBuilder.Release() + + ndjsonToRecordBuilder(t, recordBuilder, jsondata) + + rb0 := recordBuilder.NewRecordBatch() + assert.Equal(t, int64(3), rb0.NumCols()) + assert.Equal(t, int64(16), rb0.NumRows()) + + data0 := recordBatchToNDJSON(t, rb0) + rb0.Release() + + ndjsonToRecordBuilder(t, recordBuilder, jsondata) + recordBuilder.Resize(8) + + rb1 := recordBuilder.NewRecordBatch() + assert.Equal(t, int64(3), rb1.NumCols()) + assert.Equal(t, int64(8), rb1.NumRows()) + + data1 := recordBatchToNDJSON(t, rb1) + rb1.Release() + + assert.True(t, strings.HasPrefix(data0, data1)) + }) + + t.Run("truncate_incomplete", func(t *testing.T) { + recordBuilder := array.NewRecordBuilder(mem, schema) + defer recordBuilder.Release() + + ndjsonToRecordBuilder(t, recordBuilder, jsondata) + + rb0 := recordBuilder.NewRecordBatch() + assert.Equal(t, int64(3), rb0.NumCols()) + assert.Equal(t, int64(16), rb0.NumRows()) + + data0 := recordBatchToNDJSON(t, rb0) + rb0.Release() + + ndjsonToRecordBuilder(t, recordBuilder, jsondata) + recordBuilder.Field(0).(*array.StringBuilder).Append("TN") + recordBuilder.Field(0).(*array.StringBuilder).Append("IL") + recordBuilder.Field(0).(*array.StringBuilder).Append("WI") + + panicked := false + func() { + defer func() { + if r := recover(); r != nil { + if strings.HasPrefix(r.(error).Error(), "arrow/array:") { + panicked = true + } else { + panic(r) + } + } + }() + + rb1 := recordBuilder.NewRecordBatch() + rb1.Release() + }() + + assert.True(t, panicked) + + recordBuilder.Resize(-1) + + rb2 := recordBuilder.NewRecordBatch() + assert.Equal(t, int64(3), rb2.NumCols()) + assert.Equal(t, int64(16), rb2.NumRows()) + + data2 := recordBatchToNDJSON(t, rb2) + rb2.Release() + + assert.Equal(t, data0, data2) + }) +} + type testMessage struct { Foo *testMessageFoo Bars []*testMessageBar