From d1167ef2e6bd66dfe16765272c586194078c49ab Mon Sep 17 00:00:00 2001 From: Mark Molinaro Date: Sat, 13 Jun 2026 06:52:38 +0000 Subject: [PATCH 1/2] python: add fastbinary decode_binary_from_bytes Add a bytes-based fastbinary decode entry point so callers that already hold serialized thrift data can bypass TMemoryBuffer and protocol wrapper setup while reusing the existing struct decoder. Performance (50k iterations, warmed) | Workload | decode_binary (transport) | decode_binary_from_bytes | Speedup | |----------|---------------------------|--------------------------|---------| | simple (30B) | 3.59 us | 0.81 us | 4.43x | | 10-string (182B) | 9.52 us | 6.77 us | 1.41x | | complex (395B) | 12.03 us | 8.26 us | 1.46x | The transport overhead (TMemoryBuffer + TBinaryProtocolAccelerated + BytesIO) is a fixed ~2.6us per call. For small structs this dominates; for larger structs the per-field decode work catches up. --- lib/py/src/ext/module.cpp | 25 ++++++ lib/py/src/ext/protocol.h | 1 + lib/py/src/ext/protocol.tcc | 37 ++++++++- lib/py/src/ext/types.h | 7 ++ lib/py/test/thrift_TBinaryProtocol.py | 105 ++++++++++++++++++++++++++ 5 files changed, 171 insertions(+), 4 deletions(-) diff --git a/lib/py/src/ext/module.cpp b/lib/py/src/ext/module.cpp index e2b540e6f6a..98a91e20ef1 100644 --- a/lib/py/src/ext/module.cpp +++ b/lib/py/src/ext/module.cpp @@ -139,11 +139,36 @@ static PyObject* decode_compact(PyObject*, PyObject* args) { return decode_impl(args); } +static PyObject* decode_binary_from_bytes(PyObject*, PyObject* args) { + PyObject* bytes_obj = nullptr; + PyObject* typeargs = nullptr; + if (!PyArg_ParseTuple(args, "OO", &bytes_obj, &typeargs)) { + return nullptr; + } + if (!PyBytes_Check(bytes_obj)) { + PyErr_SetString(PyExc_TypeError, "first argument must be bytes"); + return nullptr; + } + + StructTypeArgs parsedargs; + if (!parse_struct_args(&parsedargs, typeargs)) { + return nullptr; + } + + BinaryProtocol protocol; + if (!protocol.prepareDecodeBufferFromBytes(bytes_obj)) { + return nullptr; + } + + return protocol.readStruct(Py_None, parsedargs.klass, parsedargs.spec); +} + static PyMethodDef ThriftFastBinaryMethods[] = { {"encode_binary", encode_binary, METH_VARARGS, ""}, {"decode_binary", decode_binary, METH_VARARGS, ""}, {"encode_compact", encode_compact, METH_VARARGS, ""}, {"decode_compact", decode_compact, METH_VARARGS, ""}, + {"decode_binary_from_bytes", decode_binary_from_bytes, METH_VARARGS, ""}, {nullptr, nullptr, 0, nullptr} /* Sentinel */ }; diff --git a/lib/py/src/ext/protocol.h b/lib/py/src/ext/protocol.h index 20911c89724..bad201f02c1 100644 --- a/lib/py/src/ext/protocol.h +++ b/lib/py/src/ext/protocol.h @@ -40,6 +40,7 @@ class ProtocolBase { inline virtual ~ProtocolBase(); bool prepareDecodeBufferFromTransport(PyObject* trans); + bool prepareDecodeBufferFromBytes(PyObject* bytes_obj); PyObject* readStruct(PyObject* output, PyObject* klass, PyObject* spec_seq); diff --git a/lib/py/src/ext/protocol.tcc b/lib/py/src/ext/protocol.tcc index 448fc6f105e..308fd28a1d3 100644 --- a/lib/py/src/ext/protocol.tcc +++ b/lib/py/src/ext/protocol.tcc @@ -301,9 +301,17 @@ bool ProtocolBase::readBytes(char** output, int len) { PyErr_Format(PyExc_ValueError, "attempted to read negative length: %d", len); return false; } - // TODO(dreiss): Don't fear the malloc. Think about taking a copy of - // the partial read instead of forcing the transport - // to prepend it to its buffer. + + if (input_.direct_buf) { + if (input_.direct_pos + static_cast(len) > input_.direct_size) { + PyErr_SetString(PyExc_EOFError, "read past end of buffer"); + return false; + } + + *output = const_cast(input_.direct_buf + input_.direct_pos); + input_.direct_pos += static_cast(len); + return true; + } int rlen = detail::read_buffer(input_.stringiobuf.get(), output, len); @@ -338,7 +346,7 @@ bool ProtocolBase::readBytes(char** output, int len) { template bool ProtocolBase::prepareDecodeBufferFromTransport(PyObject* trans) { - if (input_.stringiobuf) { + if (input_.stringiobuf || input_.direct_buf) { PyErr_SetString(PyExc_ValueError, "decode buffer is already initialized"); return false; } @@ -366,6 +374,27 @@ bool ProtocolBase::prepareDecodeBufferFromTransport(PyObject* trans) { return true; } +template +bool ProtocolBase::prepareDecodeBufferFromBytes(PyObject* bytes_obj) { + if (input_.stringiobuf || input_.direct_buf) { + PyErr_SetString(PyExc_ValueError, "decode buffer is already initialized"); + return false; + } + + char* buf = nullptr; + Py_ssize_t len = 0; + if (PyBytes_AsStringAndSize(bytes_obj, &buf, &len) < 0) { + return false; + } + + Py_INCREF(bytes_obj); + input_.direct_source.reset(bytes_obj); + input_.direct_buf = buf; + input_.direct_size = static_cast(len); + input_.direct_pos = 0; + return true; +} + template bool ProtocolBase::prepareEncodeBuffer() { output_ = detail::new_encode_buffer(INIT_OUTBUF_SIZE); diff --git a/lib/py/src/ext/types.h b/lib/py/src/ext/types.h index 2848b28f0ba..c6905c52c76 100644 --- a/lib/py/src/ext/types.h +++ b/lib/py/src/ext/types.h @@ -119,10 +119,17 @@ class ScopedPyObject { /** * A cache of the two key attributes of a CReadableTransport, * so we don't have to keep calling PyObject_GetAttr. + * Also supports reading directly from a bytes object. */ struct DecodeBuffer { ScopedPyObject stringiobuf; ScopedPyObject refill_callable; + ScopedPyObject direct_source; + const char* direct_buf; + size_t direct_size; + size_t direct_pos; + + DecodeBuffer() : direct_buf(nullptr), direct_size(0), direct_pos(0) {} }; #if PY_MAJOR_VERSION < 3 diff --git a/lib/py/test/thrift_TBinaryProtocol.py b/lib/py/test/thrift_TBinaryProtocol.py index d4269eb6175..a2841844eed 100644 --- a/lib/py/test/thrift_TBinaryProtocol.py +++ b/lib/py/test/thrift_TBinaryProtocol.py @@ -23,6 +23,7 @@ import _import_local_thrift # noqa from thrift.protocol.TBinaryProtocol import TBinaryProtocol +from thrift.protocol.TBinaryProtocol import TBinaryProtocolAcceleratedFactory from thrift.protocol.TProtocol import TProtocolException from thrift.transport import TTransport @@ -194,8 +195,60 @@ def testMessage(data, strict=True): return result +class SimpleStruct(object): + thrift_spec = ( + None, + (1, 11, "name", "UTF8", None), + (2, 8, "value", None, None), + (3, 2, "flag", None, None), + ) + + def __init__(self, name=None, value=None, flag=None): + self.name = name + self.value = value + self.flag = flag + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + + oprot.writeStructBegin("SimpleStruct") + if self.name is not None: + oprot.writeFieldBegin("name", 11, 1) + oprot.writeString(self.name) + oprot.writeFieldEnd() + if self.value is not None: + oprot.writeFieldBegin("value", 8, 2) + oprot.writeI32(self.value) + oprot.writeFieldEnd() + if self.flag is not None: + oprot.writeFieldBegin("flag", 2, 3) + oprot.writeBool(self.flag) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + @classmethod + def read(cls, iprot): + if ( + iprot._fast_decode is not None + and isinstance(iprot.trans, TTransport.CReadableTransport) + and cls.thrift_spec is not None + ): + return iprot._fast_decode(None, iprot, [cls, cls.thrift_spec]) + return iprot.readStruct(cls, cls.thrift_spec, False) + + class TestTBinaryProtocol(unittest.TestCase): + def setUp(self): + try: + from thrift.protocol import fastbinary + self._fastbinary = fastbinary + except ImportError: + self._fastbinary = None + def test_TBinaryProtocol_write_read(self): try: testNaked('Byte', 123) @@ -280,6 +333,58 @@ def test_TBinaryProtocol_write_read(self): print("Assertion fail") raise e + def _encode_accelerated_struct(self, value): + otrans = TTransport.TMemoryBuffer() + oproto = TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(otrans) + value.write(oproto) + return otrans.getvalue() + + def _decode_accelerated_struct(self, encoded): + itrans = TTransport.TMemoryBuffer(encoded) + iproto = TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(itrans) + return SimpleStruct.read(iproto) + + def test_decode_binary_from_bytes_matches_transport(self): + if self._fastbinary is None: + self.skipTest("C extension not available") + + original = SimpleStruct(name="transport-free", value=42, flag=True) + encoded = self._encode_accelerated_struct(original) + + decoded_transport = self._decode_accelerated_struct(encoded) + decoded_direct = self._fastbinary.decode_binary_from_bytes( + encoded, + [SimpleStruct, SimpleStruct.thrift_spec], + ) + + self.assertEqual(decoded_direct.name, decoded_transport.name) + self.assertEqual(decoded_direct.value, decoded_transport.value) + self.assertEqual(decoded_direct.flag, decoded_transport.flag) + + def test_decode_binary_from_bytes_rejects_non_bytes(self): + if self._fastbinary is None: + self.skipTest("C extension not available") + + with self.assertRaises(TypeError): + self._fastbinary.decode_binary_from_bytes( + "not-bytes", + [SimpleStruct, SimpleStruct.thrift_spec], + ) + + def test_decode_binary_from_bytes_rejects_truncated_input(self): + if self._fastbinary is None: + self.skipTest("C extension not available") + + encoded = self._encode_accelerated_struct( + SimpleStruct(name="trim me", value=7, flag=False) + ) + + with self.assertRaises(EOFError): + self._fastbinary.decode_binary_from_bytes( + encoded[:-1], + [SimpleStruct, SimpleStruct.thrift_spec], + ) + def test_TBinaryProtocol_no_strict_write_read(self): TMessageType = {"T_CALL": 1, "T_REPLY": 2, "T_EXCEPTION": 3, "T_ONEWAY": 4} test_data = [("short message name", TMessageType['T_CALL'], 0), From 6af2e54e8655ab3c59f26cdfd418c3621a8a439d Mon Sep 17 00:00:00 2001 From: Mark Molinaro Date: Sat, 13 Jun 2026 07:11:13 +0000 Subject: [PATCH 2/2] python: harden direct decode buffer bounds check Use an overflow-safe remaining-bytes comparison in the direct bytes decode path so extreme size_t values cannot wrap the position-plus-length check and bypass EOF enforcement. --- lib/py/src/ext/protocol.tcc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/py/src/ext/protocol.tcc b/lib/py/src/ext/protocol.tcc index 308fd28a1d3..d53ebe47bb9 100644 --- a/lib/py/src/ext/protocol.tcc +++ b/lib/py/src/ext/protocol.tcc @@ -303,13 +303,14 @@ bool ProtocolBase::readBytes(char** output, int len) { } if (input_.direct_buf) { - if (input_.direct_pos + static_cast(len) > input_.direct_size) { + size_t requested = static_cast(len); + if (input_.direct_pos > input_.direct_size || requested > (input_.direct_size - input_.direct_pos)) { PyErr_SetString(PyExc_EOFError, "read past end of buffer"); return false; } *output = const_cast(input_.direct_buf + input_.direct_pos); - input_.direct_pos += static_cast(len); + input_.direct_pos += requested; return true; }