Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions lib/py/src/ext/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,36 @@ static PyObject* decode_compact(PyObject*, PyObject* args) {
return decode_impl<CompactProtocol>(args);
}

static PyObject* decode_binary_from_bytes(PyObject*, PyObject* args) {
PyObject* bytes_obj = nullptr;
PyObject* typeargs = nullptr;
if (!PyArg_ParseTuple(args, "OO", &bytes_obj, &typeargs)) {
return nullptr;
}
if (!PyBytes_Check(bytes_obj)) {
PyErr_SetString(PyExc_TypeError, "first argument must be bytes");
return nullptr;
}

StructTypeArgs parsedargs;
if (!parse_struct_args(&parsedargs, typeargs)) {
return nullptr;
}

BinaryProtocol protocol;
if (!protocol.prepareDecodeBufferFromBytes(bytes_obj)) {
return nullptr;
}

return protocol.readStruct(Py_None, parsedargs.klass, parsedargs.spec);
}
Comment on lines +142 to +164

static PyMethodDef ThriftFastBinaryMethods[] = {
{"encode_binary", encode_binary, METH_VARARGS, ""},
{"decode_binary", decode_binary, METH_VARARGS, ""},
{"encode_compact", encode_compact, METH_VARARGS, ""},
{"decode_compact", decode_compact, METH_VARARGS, ""},
{"decode_binary_from_bytes", decode_binary_from_bytes, METH_VARARGS, ""},
{nullptr, nullptr, 0, nullptr} /* Sentinel */
};
Comment thread
markjm marked this conversation as resolved.

Expand Down
1 change: 1 addition & 0 deletions lib/py/src/ext/protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class ProtocolBase {
inline virtual ~ProtocolBase();

bool prepareDecodeBufferFromTransport(PyObject* trans);
bool prepareDecodeBufferFromBytes(PyObject* bytes_obj);

PyObject* readStruct(PyObject* output, PyObject* klass, PyObject* spec_seq);

Expand Down
38 changes: 34 additions & 4 deletions lib/py/src/ext/protocol.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,18 @@ bool ProtocolBase<Impl>::readBytes(char** output, int len) {
PyErr_Format(PyExc_ValueError, "attempted to read negative length: %d", len);
return false;
}
// TODO(dreiss): Don't fear the malloc. Think about taking a copy of
// the partial read instead of forcing the transport
// to prepend it to its buffer.

if (input_.direct_buf) {
size_t requested = static_cast<size_t>(len);
if (input_.direct_pos > input_.direct_size || requested > (input_.direct_size - input_.direct_pos)) {
PyErr_SetString(PyExc_EOFError, "read past end of buffer");
return false;
}

*output = const_cast<char*>(input_.direct_buf + input_.direct_pos);
input_.direct_pos += requested;
return true;
}
Comment thread
markjm marked this conversation as resolved.

int rlen = detail::read_buffer(input_.stringiobuf.get(), output, len);

Expand Down Expand Up @@ -338,7 +347,7 @@ bool ProtocolBase<Impl>::readBytes(char** output, int len) {

template <typename Impl>
bool ProtocolBase<Impl>::prepareDecodeBufferFromTransport(PyObject* trans) {
if (input_.stringiobuf) {
if (input_.stringiobuf || input_.direct_buf) {
PyErr_SetString(PyExc_ValueError, "decode buffer is already initialized");
return false;
}
Expand Down Expand Up @@ -366,6 +375,27 @@ bool ProtocolBase<Impl>::prepareDecodeBufferFromTransport(PyObject* trans) {
return true;
}

template <typename Impl>
bool ProtocolBase<Impl>::prepareDecodeBufferFromBytes(PyObject* bytes_obj) {
if (input_.stringiobuf || input_.direct_buf) {
PyErr_SetString(PyExc_ValueError, "decode buffer is already initialized");
return false;
}

char* buf = nullptr;
Py_ssize_t len = 0;
if (PyBytes_AsStringAndSize(bytes_obj, &buf, &len) < 0) {
return false;
}

Py_INCREF(bytes_obj);
input_.direct_source.reset(bytes_obj);
input_.direct_buf = buf;
input_.direct_size = static_cast<size_t>(len);
input_.direct_pos = 0;
return true;
}

template <typename Impl>
bool ProtocolBase<Impl>::prepareEncodeBuffer() {
output_ = detail::new_encode_buffer(INIT_OUTBUF_SIZE);
Expand Down
7 changes: 7 additions & 0 deletions lib/py/src/ext/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,17 @@ class ScopedPyObject {
/**
* A cache of the two key attributes of a CReadableTransport,
* so we don't have to keep calling PyObject_GetAttr.
* Also supports reading directly from a bytes object.
*/
struct DecodeBuffer {
ScopedPyObject stringiobuf;
ScopedPyObject refill_callable;
ScopedPyObject direct_source;
const char* direct_buf;
size_t direct_size;
size_t direct_pos;

DecodeBuffer() : direct_buf(nullptr), direct_size(0), direct_pos(0) {}
};

#if PY_MAJOR_VERSION < 3
Expand Down
100 changes: 100 additions & 0 deletions lib/py/test/thrift_TBinaryProtocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import _import_local_thrift # noqa
from thrift.protocol.TBinaryProtocol import TBinaryProtocol
from thrift.protocol.TBinaryProtocol import TBinaryProtocolAcceleratedFactory
from thrift.protocol.TProtocol import TProtocolException
from thrift.transport import TTransport

Expand Down Expand Up @@ -194,8 +195,55 @@ def testMessage(data, strict=True):
return result


class SimpleStruct(object):
thrift_spec = (
None,
(1, 11, "name", "UTF8", None),
(2, 8, "value", None, None),
(3, 2, "flag", None, None),
)

def __init__(self, name=None, value=None, flag=None):
self.name = name
self.value = value
self.flag = flag

def write(self, oprot):
if oprot._fast_encode is not None and self.thrift_spec is not None:
oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec]))
return

oprot.writeStructBegin("SimpleStruct")
if self.name is not None:
oprot.writeFieldBegin("name", 11, 1)
oprot.writeString(self.name)
oprot.writeFieldEnd()
if self.value is not None:
oprot.writeFieldBegin("value", 8, 2)
oprot.writeI32(self.value)
oprot.writeFieldEnd()
if self.flag is not None:
oprot.writeFieldBegin("flag", 2, 3)
oprot.writeBool(self.flag)
oprot.writeFieldEnd()
oprot.writeFieldStop()
oprot.writeStructEnd()

@classmethod
def read(cls, iprot):
# Accelerated path only: tests construct iprot with fallback=False.
return iprot._fast_decode(None, iprot, [cls, cls.thrift_spec])


class TestTBinaryProtocol(unittest.TestCase):

def setUp(self):
try:
from thrift.protocol import fastbinary
self._fastbinary = fastbinary
except ImportError:
self._fastbinary = None

def test_TBinaryProtocol_write_read(self):
try:
testNaked('Byte', 123)
Expand Down Expand Up @@ -280,6 +328,58 @@ def test_TBinaryProtocol_write_read(self):
print("Assertion fail")
raise e

def _encode_accelerated_struct(self, value):
otrans = TTransport.TMemoryBuffer()
oproto = TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(otrans)
value.write(oproto)
return otrans.getvalue()

def _decode_accelerated_struct(self, encoded):
itrans = TTransport.TMemoryBuffer(encoded)
iproto = TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(itrans)
return SimpleStruct.read(iproto)

def test_decode_binary_from_bytes_matches_transport(self):
if self._fastbinary is None:
self.skipTest("C extension not available")

original = SimpleStruct(name="transport-free", value=42, flag=True)
encoded = self._encode_accelerated_struct(original)

decoded_transport = self._decode_accelerated_struct(encoded)
decoded_direct = self._fastbinary.decode_binary_from_bytes(
encoded,
[SimpleStruct, SimpleStruct.thrift_spec],
)

self.assertEqual(decoded_direct.name, decoded_transport.name)
self.assertEqual(decoded_direct.value, decoded_transport.value)
self.assertEqual(decoded_direct.flag, decoded_transport.flag)

def test_decode_binary_from_bytes_rejects_non_bytes(self):
if self._fastbinary is None:
self.skipTest("C extension not available")

with self.assertRaises(TypeError):
self._fastbinary.decode_binary_from_bytes(
"not-bytes",
[SimpleStruct, SimpleStruct.thrift_spec],
)

def test_decode_binary_from_bytes_rejects_truncated_input(self):
if self._fastbinary is None:
self.skipTest("C extension not available")

encoded = self._encode_accelerated_struct(
SimpleStruct(name="trim me", value=7, flag=False)
)

with self.assertRaises(EOFError):
self._fastbinary.decode_binary_from_bytes(
encoded[:-1],
[SimpleStruct, SimpleStruct.thrift_spec],
)

def test_TBinaryProtocol_no_strict_write_read(self):
TMessageType = {"T_CALL": 1, "T_REPLY": 2, "T_EXCEPTION": 3, "T_ONEWAY": 4}
test_data = [("short message name", TMessageType['T_CALL'], 0),
Expand Down