Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions arrow/array/json_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package array_test

import (
"bufio"
"bytes"
"fmt"
"strings"
Expand Down Expand Up @@ -231,6 +232,35 @@ func generateJSONData(n int) []byte {
return data
}

func ndjsonToRecordBuilder(t *testing.T, recordBuilder *array.RecordBuilder, data string) {
scanner := bufio.NewScanner(strings.NewReader(data))

for scanner.Scan() {
if len(scanner.Bytes()) > 0 {
err := recordBuilder.UnmarshalJSON(scanner.Bytes())
assert.NoError(t, err)
}
}

assert.NoError(t, scanner.Err())
}

func recordBatchToNDJSON(t *testing.T, rec arrow.RecordBatch) string {
var sb strings.Builder

arr := array.RecordToStructArray(rec)
defer arr.Release()

for pos := range arr.Len() {
s, err := json.Marshal(arr.GetOneForMarshal(pos))
assert.NoError(t, err)
sb.Write(s)
sb.WriteByte('\n')
}

return sb.String()
}

func jsonArrayToNDJSON(data []byte) ([]byte, error) {
var records []json.RawMessage
if err := json.Unmarshal(data, &records); err != nil {
Expand Down
45 changes: 38 additions & 7 deletions arrow/array/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,15 +348,51 @@ func (b *RecordBuilder) Reserve(size int) {
}
}

func (b *RecordBuilder) columnLenRange() (lower, upper int) {
if len(b.fields) > 0 {
lower = b.fields[0].Len()
upper = lower

for _, f := range b.fields[1:] {
lower = min(lower, f.Len())
upper = max(upper, f.Len())
}
}
return
}

// Resize adjusts the space allocated by all the field builders to n elements.
// If n is greater than an individual builder Cap(), additional memory will be
// allocated. If n is smaller, the allocated memory may reduced.
//
// As a special case, if n equals to -1, all field builders will be resized
// to the size of the shortest one.
func (b *RecordBuilder) Resize(n int) {
if n >= 0 {
for _, f := range b.fields {
f.Resize(n)
}
} else if n == -1 {
lower, upper := b.columnLenRange()
if lower != upper {
b.Resize(lower)
}
}
}

// NewRecordBatch creates a new record batch from the memory buffers and resets the
// RecordBuilder so it can be used to build a new record batch.
//
// The returned RecordBatch must be Release()'d after use.
//
// NewRecordBatch panics if the fields' builder do not have the same length.
func (b *RecordBuilder) NewRecordBatch() arrow.RecordBatch {
lower, upper := b.columnLenRange()
if lower != upper {
panic(fmt.Errorf("arrow/array: some fields have excessive number of rows (want at most %d, have %d)", lower, upper))
}

cols := make([]arrow.Array, len(b.fields))
rows := int64(0)

defer func(cols []arrow.Array) {
for _, col := range cols {
Expand All @@ -369,14 +405,9 @@ func (b *RecordBuilder) NewRecordBatch() arrow.RecordBatch {

for i, f := range b.fields {
cols[i] = f.NewArray()
irow := int64(cols[i].Len())
if i > 0 && irow != rows {
panic(fmt.Errorf("arrow/array: field %d has %d rows. want=%d", i, irow, rows))
}
rows = irow
}

return NewRecordBatch(b.schema, cols, rows)
return NewRecordBatch(b.schema, cols, int64(lower))
}

// Deprecated: Use [NewRecordBatch] instead.
Expand Down
86 changes: 86 additions & 0 deletions arrow/array/record_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package array_test
import (
"fmt"
"reflect"
"strings"
"testing"

"github.com/apache/arrow-go/v18/arrow"
Expand Down Expand Up @@ -531,6 +532,91 @@ func TestRecordBuilder(t *testing.T) {
}
}

func TestRecordBuilderResize(t *testing.T) {
mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
defer mem.AssertSize(t, 0)

schema := arrow.NewSchema([]arrow.Field{
{Name: "region", Type: arrow.BinaryTypes.String, Nullable: true},
{Name: "model", Type: arrow.BinaryTypes.String},
{Name: "sales", Type: arrow.PrimitiveTypes.Float64, Nullable: true},
}, nil)

t.Run("truncate", func(t *testing.T) {
recordBuilder := array.NewRecordBuilder(mem, schema)
defer recordBuilder.Release()

ndjsonToRecordBuilder(t, recordBuilder, jsondata)

rb0 := recordBuilder.NewRecordBatch()
assert.Equal(t, int64(3), rb0.NumCols())
assert.Equal(t, int64(16), rb0.NumRows())

data0 := recordBatchToNDJSON(t, rb0)
rb0.Release()

ndjsonToRecordBuilder(t, recordBuilder, jsondata)
recordBuilder.Resize(8)

rb1 := recordBuilder.NewRecordBatch()
assert.Equal(t, int64(3), rb1.NumCols())
assert.Equal(t, int64(8), rb1.NumRows())

data1 := recordBatchToNDJSON(t, rb1)
rb1.Release()

assert.True(t, strings.HasPrefix(data0, data1))
})

t.Run("truncate_incomplete", func(t *testing.T) {
recordBuilder := array.NewRecordBuilder(mem, schema)
defer recordBuilder.Release()

ndjsonToRecordBuilder(t, recordBuilder, jsondata)

rb0 := recordBuilder.NewRecordBatch()
assert.Equal(t, int64(3), rb0.NumCols())
assert.Equal(t, int64(16), rb0.NumRows())

data0 := recordBatchToNDJSON(t, rb0)
rb0.Release()

ndjsonToRecordBuilder(t, recordBuilder, jsondata)
recordBuilder.Field(0).(*array.StringBuilder).Append("TN")
recordBuilder.Field(0).(*array.StringBuilder).Append("IL")
recordBuilder.Field(0).(*array.StringBuilder).Append("WI")

panicked := false
func() {
defer func() {
if r := recover(); r != nil {
if strings.HasPrefix(r.(error).Error(), "arrow/array:") {
panicked = true
} else {
panic(r)
}
}
}()

rb1 := recordBuilder.NewRecordBatch()
rb1.Release()
}()

assert.True(t, panicked)

recordBuilder.Resize(-1)

rb2 := recordBuilder.NewRecordBatch()
assert.Equal(t, int64(3), rb2.NumCols())
assert.Equal(t, int64(16), rb2.NumRows())

data2 := recordBatchToNDJSON(t, rb2)
rb2.Release()

assert.Equal(t, data0, data2)
})
}

type testMessage struct {
Foo *testMessageFoo
Bars []*testMessageBar
Expand Down