diff --git a/cpp/src/arrow/ipc/message.cc b/cpp/src/arrow/ipc/message.cc index 84ee62fe9e8d..fc7c2baec519 100644 --- a/cpp/src/arrow/ipc/message.cc +++ b/cpp/src/arrow/ipc/message.cc @@ -563,7 +563,25 @@ Status DecodeMessage(MessageDecoder* decoder, io::InputStream* file) { } auto metadata_length = decoder->next_required_size(); + + // "ARRO" (first 4 bytes of kArrowMagicBytes) as little-endian int32. + constexpr int32_t kArrowMagicPrefix = 0x4F525241; + + // Did we misinterpret the metadata as a length? + if (metadata_length == kArrowMagicPrefix) { + constexpr std::string_view kRemainingMagic = + internal::kArrowMagicBytes.substr(sizeof(int32_t)); + ARROW_ASSIGN_OR_RAISE(auto peek, file->Read(kRemainingMagic.size())); + if (peek->size() >= static_cast(kRemainingMagic.size()) && + std::string_view(reinterpret_cast(peek->data()), + kRemainingMagic.size()) == kRemainingMagic) { + return Status::Invalid( + "This appears to be an Arrow IPC file. " + "Try the IPC file reader instead of the IPC stream reader."); + } + } ARROW_ASSIGN_OR_RAISE(auto metadata, file->Read(metadata_length)); + if (metadata->size() != metadata_length) { return Status::Invalid("Expected to read ", metadata_length, " metadata bytes, but ", "only read ", metadata->size()); diff --git a/cpp/src/arrow/ipc/read_write_test.cc b/cpp/src/arrow/ipc/read_write_test.cc index 15cf0258b2ee..7a7c13093f80 100644 --- a/cpp/src/arrow/ipc/read_write_test.cc +++ b/cpp/src/arrow/ipc/read_write_test.cc @@ -2265,6 +2265,53 @@ TEST(TestRecordBatchStreamReader, MalformedInput) { ASSERT_RAISES(Invalid, RecordBatchStreamReader::Open(&garbage_reader)); } +TEST(TestRecordBatchStreamReader, OpenFileFormatSuggestsFileReader) { + std::shared_ptr batch; + ASSERT_OK(MakeIntRecordBatch(&batch)); + + FileWriterHelper helper; + ASSERT_OK(helper.Init(batch->schema(), IpcWriteOptions::Defaults())); + ASSERT_OK(helper.WriteBatch(batch)); + ASSERT_OK(helper.Finish()); + + io::BufferReader reader(helper.buffer_); + // Check we mention using the file_reader when we detect file format + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("Try the IPC file reader"), + RecordBatchStreamReader::Open(&reader)); +} + +TEST(TestRecordBatchStreamReader, CorruptDataDoesNotSuggestFileReader) { + // Continuation marker + metadata_length = 100, then 8 bytes of non-magic data. + const std::string corrupt( + "\xff\xff\xff\xff" + "\x64\x00\x00\x00" + "ABABABAB", + 16); + auto buffer = std::make_shared(corrupt); + io::BufferReader reader(buffer); + // Validate that we don't suggest file reader when file is just corrupt + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::Not(::testing::HasSubstr("Try the IPC file reader")), + RecordBatchStreamReader::Open(&reader)); +} + +TEST(TestRecordBatchFileReader, OpenStreamFormatSuggestsStreamReader) { + std::shared_ptr batch; + ASSERT_OK(MakeIntRecordBatch(&batch)); + + StreamWriterHelper helper; + ASSERT_OK(helper.Init(batch->schema(), IpcWriteOptions::Defaults())); + ASSERT_OK(helper.WriteBatch(batch)); + ASSERT_OK(helper.Finish()); + + auto buf_reader = std::make_shared(helper.buffer_); + // Check we mention using the stream_reader when we detect stream format + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("use the IPC stream reader"), + RecordBatchFileReader::Open(buf_reader.get(), helper.buffer_->size())); +} + class EndlessCollectListener : public CollectListener { public: EndlessCollectListener() : CollectListener(), decoder_(nullptr) {} diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index 580081384308..68b0d1068c2c 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -1890,7 +1890,9 @@ class RecordBatchFileReaderImpl : public RecordBatchFileReader { const auto magic_start = buffer->data() + sizeof(int32_t); if (std::string_view(reinterpret_cast(magic_start), kMagicSize) != kArrowMagicBytes) { - return Status::Invalid("Not an Arrow file"); + return Status::Invalid( + "Not an Arrow file. If this is an Arrow IPC stream, use " + "the IPC stream reader instead."); } int32_t footer_length = bit_util::FromLittleEndian(