Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions lib/py/src/ext/protocol.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <cStringIO.h>
#else
#include <algorithm>
#include <cstring>
#endif

namespace apache {
Expand Down Expand Up @@ -120,8 +121,10 @@ inline bool input_check(PyObject* input) {

inline EncodeBuffer* new_encode_buffer(size_t size) {
EncodeBuffer* buffer = new EncodeBuffer;
buffer->buf.reserve(size);
buffer->pos = 0;
if (!buffer->init(size)) {
delete buffer;
return nullptr;
}
return buffer;
}

Expand Down Expand Up @@ -165,21 +168,18 @@ inline bool ProtocolBase<Impl>::isUtf8(PyObject* typeargs) {

template <typename Impl>
PyObject* ProtocolBase<Impl>::getEncodedValue() {
return PyBytes_FromStringAndSize(output_->buf.data(), output_->buf.size());
return PyBytes_FromStringAndSize(output_->data, output_->size);
}

template <typename Impl>
inline bool ProtocolBase<Impl>::writeBuffer(char* data, size_t size) {
size_t need = size + output_->pos;
if (output_->buf.capacity() < need) {
try {
output_->buf.reserve(need);
} catch (std::bad_alloc&) {
PyErr_SetString(PyExc_MemoryError, "Failed to allocate write buffer");
return false;
}
if (!output_->ensure(size)) {
PyErr_SetString(PyExc_MemoryError, "Failed to allocate write buffer");
return false;
}
std::copy(data, data + size, std::back_inserter(output_->buf));

memcpy(output_->data + output_->size, data, size);
output_->size += size;
return true;
}

Expand Down
63 changes: 61 additions & 2 deletions lib/py/src/ext/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
#if PY_MAJOR_VERSION >= 3

#include <vector>
#include <limits>
#include <stdlib.h>

// TODO: better macros
#define PyInt_AsLong(v) PyLong_AsLong(v)
Expand Down Expand Up @@ -131,8 +133,65 @@ typedef PyObject EncodeBuffer;
#else
extern const char* refill_signature;
struct EncodeBuffer {
std::vector<char> buf;
size_t pos;
char* data;
size_t size;
size_t capacity;

EncodeBuffer() : data(nullptr), size(0), capacity(0) {}
EncodeBuffer(const EncodeBuffer&) = delete;
EncodeBuffer& operator=(const EncodeBuffer&) = delete;

~EncodeBuffer() {
if (data) {
free(data);
}
}
Comment on lines 135 to +148

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks bot


bool init(size_t initial_capacity) {
if (initial_capacity == 0) {
data = nullptr;
size = 0;
capacity = 0;
return true;
}

data = static_cast<char*>(malloc(initial_capacity));
if (!data) {
return false;
}
size = 0;
capacity = initial_capacity;
return true;
}
Comment on lines +150 to +165

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks bot


bool ensure(size_t additional) {
if (additional > (std::numeric_limits<size_t>::max)() - size) {
return false;
}

size_t needed = size + additional;
if (needed <= capacity) {
return true;
}

size_t new_capacity = capacity == 0 ? needed : capacity;
while (new_capacity < needed) {
if (new_capacity > (std::numeric_limits<size_t>::max)() / 2) {
new_capacity = needed;
break;
}
new_capacity *= 2;
}

char* new_data = static_cast<char*>(realloc(data, new_capacity));
if (!new_data) {
return false;
}

data = new_data;
capacity = new_capacity;
return true;
}
Comment on lines +167 to +194

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks bot

};
#endif

Expand Down
43 changes: 43 additions & 0 deletions lib/py/test/thrift_TBinaryProtocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
import uuid

import _import_local_thrift # noqa
from thrift.Thrift import TApplicationException
from thrift.protocol.TBinaryProtocol import TBinaryProtocol
from thrift.protocol.TBinaryProtocol import TBinaryProtocolAcceleratedFactory
from thrift.protocol.TProtocol import TProtocolException
from thrift.transport import TTransport

Expand Down Expand Up @@ -167,6 +169,13 @@ def testField(type, data):
protocol.readStructEnd()


APPLICATION_EXCEPTION_THRIFT_SPEC = (
None,
(1, 11, "message", "UTF8", None),
(2, 8, "type", None, None),
)


def testMessage(data, strict=True):
message = {}
message['name'] = data[0]
Expand Down Expand Up @@ -196,6 +205,13 @@ def testMessage(data, strict=True):

class TestTBinaryProtocol(unittest.TestCase):

def setUp(self):
try:
from thrift.protocol import fastbinary # noqa: F401
self._has_fastbinary = True
except ImportError:
self._has_fastbinary = False

def test_TBinaryProtocol_write_read(self):
try:
testNaked('Byte', 123)
Expand Down Expand Up @@ -280,6 +296,33 @@ def test_TBinaryProtocol_write_read(self):
print("Assertion fail")
raise e

def test_accelerated_large_message_roundtrip(self):
if not self._has_fastbinary:
self.skipTest("C extension not available")

original = TApplicationException(
type=TApplicationException.INTERNAL_ERROR,
message="x" * 8192,
)

otrans = TTransport.TMemoryBuffer()
oproto = TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(otrans)
oproto.trans.write(oproto._fast_encode(
original,
[TApplicationException, APPLICATION_EXCEPTION_THRIFT_SPEC],
))

itrans = TTransport.TMemoryBuffer(otrans.getvalue())
iproto = TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(itrans)
decoded = iproto._fast_decode(
None,
iproto,
[TApplicationException, APPLICATION_EXCEPTION_THRIFT_SPEC],
)

self.assertEqual(decoded.message, original.message)
self.assertEqual(decoded.type, original.type)

def test_TBinaryProtocol_no_strict_write_read(self):
TMessageType = {"T_CALL": 1, "T_REPLY": 2, "T_EXCEPTION": 3, "T_ONEWAY": 4}
test_data = [("short message name", TMessageType['T_CALL'], 0),
Expand Down
Loading