diff --git a/lib/py/src/ext/binary.h b/lib/py/src/ext/binary.h index dd7750b49a..9ccc87bc53 100644 --- a/lib/py/src/ext/binary.h +++ b/lib/py/src/ext/binary.h @@ -88,7 +88,7 @@ class BinaryProtocol : public ProtocolBase { return encodeValue(value, parsedspec.type, parsedspec.typeargs); } - void writeUuid(char* value) { + void writeUuid(const char* value) { writeBuffer(value, 16); } diff --git a/lib/py/src/ext/compact.h b/lib/py/src/ext/compact.h index 0d8946b344..7f9c017ebd 100644 --- a/lib/py/src/ext/compact.h +++ b/lib/py/src/ext/compact.h @@ -104,7 +104,7 @@ class CompactProtocol : public ProtocolBase { void writeFieldStop() { writeByte(0); } - void writeUuid(char* value) { + void writeUuid(const char* value) { writeBuffer(value, 16); } diff --git a/lib/py/src/ext/protocol.h b/lib/py/src/ext/protocol.h index 20911c8972..2e73723952 100644 --- a/lib/py/src/ext/protocol.h +++ b/lib/py/src/ext/protocol.h @@ -71,7 +71,7 @@ class ProtocolBase { return true; } - bool writeBuffer(char* data, size_t len); + bool writeBuffer(const char* data, size_t len); void writeByte(uint8_t val) { writeBuffer(reinterpret_cast(&val), 1); } diff --git a/lib/py/src/ext/protocol.tcc b/lib/py/src/ext/protocol.tcc index 448fc6f105..b09fe57424 100644 --- a/lib/py/src/ext/protocol.tcc +++ b/lib/py/src/ext/protocol.tcc @@ -89,7 +89,7 @@ PyObject* ProtocolBase::getEncodedValue() { } template -inline bool ProtocolBase::writeBuffer(char* data, size_t size) { +inline bool ProtocolBase::writeBuffer(const char* data, size_t size) { if (!PycStringIO) { PycString_IMPORT; } @@ -169,7 +169,7 @@ PyObject* ProtocolBase::getEncodedValue() { } template -inline bool ProtocolBase::writeBuffer(char* data, size_t size) { +inline bool ProtocolBase::writeBuffer(const char* data, size_t size) { size_t need = size + output_->pos; if (output_->buf.capacity() < need) { try { @@ -456,18 +456,32 @@ bool ProtocolBase::encodeValue(PyObject* value, TType type, PyObject* type case T_STRING: { ScopedPyObject nval; + Py_ssize_t len; if (PyUnicode_Check(value)) { +#if PY_VERSION_HEX >= 0x03030000 + const char* str = PyUnicode_AsUTF8AndSize(value, &len); + if (!str) { + return false; + } + if (!detail::check_ssize_t_32(len)) { + return false; + } + + impl()->writeI32(static_cast(len)); + return writeBuffer(str, static_cast(len)); +#else nval.reset(PyUnicode_AsUTF8String(value)); if (!nval) { return false; } +#endif } else { Py_INCREF(value); nval.reset(value); } - Py_ssize_t len = PyBytes_Size(nval.get()); + len = PyBytes_Size(nval.get()); if (!detail::check_ssize_t_32(len)) { return false; } diff --git a/lib/py/test/thrift_TBinaryProtocol.py b/lib/py/test/thrift_TBinaryProtocol.py index d4269eb617..b7e9b62399 100644 --- a/lib/py/test/thrift_TBinaryProtocol.py +++ b/lib/py/test/thrift_TBinaryProtocol.py @@ -22,7 +22,9 @@ import uuid import _import_local_thrift # noqa +from thrift.Thrift import TApplicationException from thrift.protocol.TBinaryProtocol import TBinaryProtocol +from thrift.protocol.TBinaryProtocol import TBinaryProtocolAcceleratedFactory from thrift.protocol.TProtocol import TProtocolException from thrift.transport import TTransport @@ -167,6 +169,16 @@ def testField(type, data): protocol.readStructEnd() +APPLICATION_EXCEPTION_TYPEARGS = [ + TApplicationException, + ( + None, + (1, 11, "message", "UTF8", None), + (2, 8, "type", None, None), + ), +] + + def testMessage(data, strict=True): message = {} message['name'] = data[0] @@ -196,6 +208,13 @@ def testMessage(data, strict=True): class TestTBinaryProtocol(unittest.TestCase): + def setUp(self): + try: + from thrift.protocol import fastbinary # noqa: F401 + self._has_fastbinary = True + except ImportError: + self._has_fastbinary = False + def test_TBinaryProtocol_write_read(self): try: testNaked('Byte', 123) @@ -280,6 +299,26 @@ def test_TBinaryProtocol_write_read(self): print("Assertion fail") raise e + def test_accelerated_utf8_roundtrip_on_application_exception(self): + if not self._has_fastbinary: + self.skipTest("C extension not available") + + original = TApplicationException( + type=TApplicationException.PROTOCOL_ERROR, + message=("snowman-\u2603-rocket-\U0001F680-" * 32), + ) + + otrans = TTransport.TMemoryBuffer() + oproto = TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(otrans) + oproto.trans.write(oproto._fast_encode(original, APPLICATION_EXCEPTION_TYPEARGS)) + + itrans = TTransport.TMemoryBuffer(otrans.getvalue()) + iproto = TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(itrans) + decoded = iproto._fast_decode(None, iproto, APPLICATION_EXCEPTION_TYPEARGS) + + self.assertEqual(decoded.message, original.message) + self.assertEqual(decoded.type, original.type) + def test_TBinaryProtocol_no_strict_write_read(self): TMessageType = {"T_CALL": 1, "T_REPLY": 2, "T_EXCEPTION": 3, "T_ONEWAY": 4} test_data = [("short message name", TMessageType['T_CALL'], 0),