Make py::Object conversion to PyObject * explicit

Summary:
This fixes an issue in Py(Vertex|Edge)GetProperty and prevents any
further issues of that type at the cost of additional typing effort.

Reviewers: ipaljak, llugovic

Reviewed By: ipaljak, llugovic

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D2735
This commit is contained in:
Teon Banek 2020-03-23 13:23:57 +01:00
parent a8b81fdcde
commit b7738c64b3
4 changed files with 65 additions and 60 deletions

View File

@ -75,11 +75,18 @@ class [[nodiscard]] Object final {
return *this;
}
operator PyObject *() const { return ptr_; }
operator bool() const { return ptr_; }
/// Borrow the original `PyObject *`, the ownership is not transferred.
/// @sa Steal
explicit operator PyObject *() const { return ptr_; }
/// Borrow the original `PyObject *`, the ownership is not transferred.
/// @sa Steal
PyObject *Ptr() const { return ptr_; }
/// Release the ownership on this PyObject, i.e. we steal the reference.
/// @sa Ptr
PyObject *Steal() {
auto *p = ptr_;
ptr_ = nullptr;
@ -165,7 +172,7 @@ class [[nodiscard]] Object final {
Object CallMethod(std::string_view meth_name) const {
Object name(
PyUnicode_FromStringAndSize(meth_name.data(), meth_name.size()));
return Object(PyObject_CallMethodObjArgs(ptr_, name, nullptr));
return Object(PyObject_CallMethodObjArgs(ptr_, name.Ptr(), nullptr));
}
/// Equivalent to `obj.meth_name(*args)` in Python.
@ -177,14 +184,14 @@ class [[nodiscard]] Object final {
Object name(
PyUnicode_FromStringAndSize(meth_name.data(), meth_name.size()));
return Object(PyObject_CallMethodObjArgs(
ptr_, name, static_cast<PyObject *>(args)..., nullptr));
ptr_, name.Ptr(), static_cast<PyObject *>(args)..., nullptr));
}
};
/// Write Object to stream as if `str(o)` was called in Python.
inline std::ostream &operator<<(std::ostream &os, const Object &py_object) {
auto py_str = py_object.Str();
os << PyUnicode_AsUTF8(py_str);
os << PyUnicode_AsUTF8(py_str.Ptr());
return os;
}
@ -208,12 +215,12 @@ inline std::ostream &operator<<(std::ostream &os,
Object format_exception_fn(traceback_mod.GetAttr("format_exception"));
CHECK(format_exception_fn);
auto list = format_exception_fn.Call(
exc_info.type, exc_info.value ? exc_info.value : Py_None,
exc_info.traceback ? exc_info.traceback : Py_None);
exc_info.type, exc_info.value ? exc_info.value.Ptr() : Py_None,
exc_info.traceback ? exc_info.traceback.Ptr() : Py_None);
CHECK(list);
auto len = PyList_GET_SIZE(static_cast<PyObject *>(list));
auto len = PyList_GET_SIZE(list.Ptr());
for (Py_ssize_t i = 0; i < len; ++i) {
auto *py_str = PyList_GET_ITEM(static_cast<PyObject *>(list), i);
auto *py_str = PyList_GET_ITEM(list.Ptr(), i);
os << PyUnicode_AsUTF8(py_str);
}
return os;
@ -246,10 +253,10 @@ inline void RestoreError(ExceptionInfo exc_info) {
CHECK(py_path);
py::Object import_dir(PyUnicode_FromString(dir));
if (!import_dir) return py::FetchError();
int import_dir_in_path = PySequence_Contains(py_path, import_dir);
int import_dir_in_path = PySequence_Contains(py_path, import_dir.Ptr());
if (import_dir_in_path == -1) return py::FetchError();
if (import_dir_in_path == 1) return std::nullopt;
if (PyList_Append(py_path, import_dir) == -1) return py::FetchError();
if (PyList_Append(py_path, import_dir.Ptr()) == -1) return py::FetchError();
return std::nullopt;
}

View File

@ -385,7 +385,7 @@ bool PythonModule::Reload() {
procedures_.clear();
py_module_ =
WithModuleRegistration(&procedures_, [&](auto *module_def, auto *memory) {
return ReloadPyModule(py_module_, module_def);
return ReloadPyModule(py_module_.Ptr(), module_def);
});
if (py_module_) return true;
auto exc_info = py::FetchError().value();

View File

@ -409,7 +409,7 @@ py::Object MgpListToPyTuple(const mgp_list *list, PyGraph *py_graph) {
if (!elem) return nullptr;
// Explicitly convert `py_tuple`, which is `py::Object`, via static_cast.
// Then the macro will cast it to `PyTuple *`.
PyTuple_SET_ITEM(static_cast<PyObject *>(py_tuple), i, elem.Steal());
PyTuple_SET_ITEM(py_tuple.Ptr(), i, elem.Steal());
}
return py_tuple;
}
@ -437,7 +437,7 @@ std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result,
if (!py_mgp) return py::FetchError();
auto record_cls = py_mgp.GetAttr("Record");
if (!record_cls) return py::FetchError();
if (!PyObject_IsInstance(py_record, record_cls)) {
if (!PyObject_IsInstance(py_record.Ptr(), record_cls.Ptr())) {
std::stringstream ss;
ss << "Value '" << py_record << "' is not an instance of 'mgp.Record'";
const auto &msg = ss.str();
@ -451,16 +451,16 @@ std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result,
"Expected 'mgp.Record.fields' to be a 'dict'");
return py::FetchError();
}
py::Object items(PyDict_Items(fields));
py::Object items(PyDict_Items(fields.Ptr()));
if (!items) return py::FetchError();
auto *record = mgp_result_new_record(result);
if (!record) {
PyErr_NoMemory();
return py::FetchError();
}
Py_ssize_t len = PyList_GET_SIZE(static_cast<PyObject *>(items));
Py_ssize_t len = PyList_GET_SIZE(items.Ptr());
for (Py_ssize_t i = 0; i < len; ++i) {
auto *item = PyList_GET_ITEM(static_cast<PyObject *>(items), i);
auto *item = PyList_GET_ITEM(items.Ptr(), i);
if (!item) return py::FetchError();
CHECK(PyTuple_Check(item));
auto *key = PyTuple_GetItem(item, 0);
@ -504,10 +504,10 @@ std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result,
std::optional<py::ExceptionInfo> AddMultipleRecordsFromPython(
mgp_result *result, py::Object py_seq) {
Py_ssize_t len = PySequence_Size(py_seq);
Py_ssize_t len = PySequence_Size(py_seq.Ptr());
if (len == -1) return py::FetchError();
for (Py_ssize_t i = 0; i < len; ++i) {
py::Object py_record(PySequence_GetItem(py_seq, i));
py::Object py_record(PySequence_GetItem(py_seq.Ptr(), i));
if (!py_record) return py::FetchError();
auto maybe_exc = AddRecordFromPython(result, py_record);
if (maybe_exc) return maybe_exc;
@ -554,7 +554,7 @@ PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) {
}
auto py_cb = py::Object::FromBorrow(cb);
py::Object py_name(py_cb.GetAttr("__name__"));
const auto *name = PyUnicode_AsUTF8(py_name);
const auto *name = PyUnicode_AsUTF8(py_name.Ptr());
if (!name) return nullptr;
if (!IsValidIdentifierName(name)) {
PyErr_SetString(PyExc_ValueError,
@ -567,19 +567,19 @@ PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) {
[py_cb](const mgp_list *args, const mgp_graph *graph, mgp_result *result,
mgp_memory *memory) {
auto gil = py::EnsureGIL();
auto maybe_exc =
WithPyGraph(graph, memory,
[&](auto py_graph) -> std::optional<py::ExceptionInfo> {
py::Object py_args(MgpListToPyTuple(args, py_graph));
if (!py_args) return py::FetchError();
auto py_res = py_cb.Call(py_graph, py_args);
if (!py_res) return py::FetchError();
if (PySequence_Check(py_res)) {
return AddMultipleRecordsFromPython(result, py_res);
} else {
return AddRecordFromPython(result, py_res);
}
});
auto maybe_exc = WithPyGraph(
graph, memory,
[&](auto py_graph) -> std::optional<py::ExceptionInfo> {
py::Object py_args(MgpListToPyTuple(args, py_graph.Ptr()));
if (!py_args) return py::FetchError();
auto py_res = py_cb.Call(py_graph, py_args);
if (!py_res) return py::FetchError();
if (PySequence_Check(py_res.Ptr())) {
return AddMultipleRecordsFromPython(result, py_res);
} else {
return AddRecordFromPython(result, py_res);
}
});
if (maybe_exc) return SetErrorFromPython(result, *maybe_exc);
},
memory);
@ -748,8 +748,7 @@ PyObject *PyPropertiesIteratorGet(PyPropertiesIterator *self,
if (!py_name) return nullptr;
auto py_value = MgpValueToPyObject(*property->value, self->py_graph);
if (!py_value) return nullptr;
return PyTuple_Pack(2, static_cast<PyObject *>(py_name),
static_cast<PyObject *>(py_value));
return PyTuple_Pack(2, py_name.Ptr(), py_value.Ptr());
}
PyObject *PyPropertiesIteratorNext(PyPropertiesIterator *self,
@ -763,8 +762,7 @@ PyObject *PyPropertiesIteratorNext(PyPropertiesIterator *self,
if (!py_name) return nullptr;
auto py_value = MgpValueToPyObject(*property->value, self->py_graph);
if (!py_value) return nullptr;
return PyTuple_Pack(2, static_cast<PyObject *>(py_name),
static_cast<PyObject *>(py_value));
return PyTuple_Pack(2, py_name.Ptr(), py_value.Ptr());
}
static PyMethodDef PyPropertiesIteratorMethods[] = {
@ -884,7 +882,7 @@ PyObject *PyEdgeGetProperty(PyEdge *self, PyObject *args) {
}
auto py_prop_value = MgpValueToPyObject(*prop_value, self->py_graph);
mgp_value_destroy(prop_value);
return py_prop_value;
return py_prop_value.Steal();
}
static PyMethodDef PyEdgeMethods[] = {
@ -1105,7 +1103,7 @@ PyObject *PyVertexGetProperty(PyVertex *self, PyObject *args) {
}
auto py_prop_value = MgpValueToPyObject(*prop_value, self->py_graph);
mgp_value_destroy(prop_value);
return py_prop_value;
return py_prop_value.Steal();
}
static PyMethodDef PyVertexMethods[] = {
@ -1397,7 +1395,7 @@ auto WithMgpModule(mgp_module *module_def, const TFun &fun) {
CHECK(py_mgp_module) << "Expected '_mgp' to have attribute '_MODULE'";
// NOTE: This check is not thread safe, but this should only go through
// ModuleRegistry::LoadModuleLibrary which ought to serialize loading.
CHECK(py_mgp_module == Py_None)
CHECK(py_mgp_module.Ptr() == Py_None)
<< "Expected '_mgp._MODULE' to be None as we are just starting to "
"import a new module. Is some other thread also importing Python "
"modules?";
@ -1457,7 +1455,7 @@ py::Object MgpValueToPyObject(const mgp_value &value, PyGraph *py_graph) {
auto py_val = MgpValueToPyObject(val, py_graph);
if (!py_val) return nullptr;
// Unlike PyList_SET_ITEM, PyDict_SetItem does not steal the value.
if (PyDict_SetItemString(py_dict, key.c_str(), py_val) != 0)
if (PyDict_SetItemString(py_dict.Ptr(), key.c_str(), py_val.Ptr()) != 0)
return nullptr;
}
return py_dict;
@ -1534,7 +1532,7 @@ mgp_value *PyObjectToMgpValue(PyObject *o, mgp_memory *memory) {
ss << "'mgp' module is missing '" << mgp_type_name << "' type";
throw std::invalid_argument(ss.str());
}
int res = PyObject_IsInstance(obj, mgp_type);
int res = PyObject_IsInstance(obj, mgp_type.Ptr());
if (res == -1) {
PyErr_Clear();
std::stringstream ss;
@ -1655,7 +1653,7 @@ mgp_value *PyObjectToMgpValue(PyObject *o, mgp_memory *memory) {
PyErr_Clear();
throw std::invalid_argument("'mgp.Edge' is missing '_edge' attribute");
}
return PyObjectToMgpValue(edge, memory);
return PyObjectToMgpValue(edge.Ptr(), memory);
} else if (is_mgp_instance(o, "Vertex")) {
py::Object vertex(PyObject_GetAttrString(o, "_vertex"));
if (!vertex) {
@ -1663,14 +1661,14 @@ mgp_value *PyObjectToMgpValue(PyObject *o, mgp_memory *memory) {
throw std::invalid_argument(
"'mgp.Vertex' is missing '_vertex' attribute");
}
return PyObjectToMgpValue(vertex, memory);
return PyObjectToMgpValue(vertex.Ptr(), memory);
} else if (is_mgp_instance(o, "Path")) {
py::Object path(PyObject_GetAttrString(o, "_path"));
if (!path) {
PyErr_Clear();
throw std::invalid_argument("'mgp.Path' is missing '_path' attribute");
}
return PyObjectToMgpValue(path, memory);
return PyObjectToMgpValue(path.Ptr(), memory);
} else {
throw std::invalid_argument("Unsupported PyObject conversion");
}

View File

@ -30,17 +30,16 @@ TEST(PyModule, MgpValueToPyObject) {
auto gil = py::EnsureGIL();
py::Object py_graph(query::procedure::MakePyGraph(nullptr, &memory));
auto py_dict = query::procedure::MgpValueToPyObject(
*map_val, reinterpret_cast<query::procedure::PyGraph *>(
static_cast<PyObject *>(py_graph)));
*map_val, reinterpret_cast<query::procedure::PyGraph *>(py_graph.Ptr()));
mgp_value_destroy(map_val);
// We should now have in Python:
// {"list": (None, False, True, 42, 0.1, "some text")}
ASSERT_TRUE(PyDict_Check(py_dict));
EXPECT_EQ(PyDict_Size(py_dict), 1);
EXPECT_EQ(PyDict_Size(py_dict.Ptr()), 1);
PyObject *key = nullptr;
PyObject *value = nullptr;
Py_ssize_t pos = 0;
while (PyDict_Next(py_dict, &pos, &key, &value)) {
while (PyDict_Next(py_dict.Ptr(), &pos, &key, &value)) {
ASSERT_TRUE(PyUnicode_Check(key));
EXPECT_EQ(std::string(PyUnicode_AsUTF8(key)), "list");
ASSERT_TRUE(PyTuple_Check(value));
@ -117,12 +116,12 @@ TEST(PyModule, PyVertex) {
ASSERT_TRUE(py_graph);
// Convert from mgp_value to mgp.Vertex.
py::Object py_vertex_value(
query::procedure::MgpValueToPyObject(*vertex_value, py_graph));
query::procedure::MgpValueToPyObject(*vertex_value, py_graph.Ptr()));
ASSERT_TRUE(py_vertex_value);
AssertPickleAndCopyAreNotSupported(py_vertex_value.GetAttr("_vertex"));
AssertPickleAndCopyAreNotSupported(py_vertex_value.GetAttr("_vertex").Ptr());
// Convert from mgp.Vertex to mgp_value.
auto *new_vertex_value =
query::procedure::PyObjectToMgpValue(py_vertex_value, &memory);
query::procedure::PyObjectToMgpValue(py_vertex_value.Ptr(), &memory);
// Test for equality.
ASSERT_TRUE(new_vertex_value);
ASSERT_NE(new_vertex_value, vertex_value); // Pointer compare.
@ -175,12 +174,12 @@ TEST(PyModule, PyEdge) {
ASSERT_TRUE(py_graph);
// Convert from mgp_value to mgp.Edge.
py::Object py_edge_value(
query::procedure::MgpValueToPyObject(*edge_value, py_graph));
query::procedure::MgpValueToPyObject(*edge_value, py_graph.Ptr()));
ASSERT_TRUE(py_edge_value);
AssertPickleAndCopyAreNotSupported(py_edge_value.GetAttr("_edge"));
AssertPickleAndCopyAreNotSupported(py_edge_value.GetAttr("_edge").Ptr());
// Convert from mgp.Edge to mgp_value.
auto *new_edge_value =
query::procedure::PyObjectToMgpValue(py_edge_value, &memory);
query::procedure::PyObjectToMgpValue(py_edge_value.Ptr(), &memory);
// Test for equality.
ASSERT_TRUE(new_edge_value);
ASSERT_NE(new_edge_value, edge_value); // Pointer compare.
@ -227,12 +226,12 @@ TEST(PyModule, PyPath) {
ASSERT_TRUE(py_graph);
// We have setup the C structs, so create convert to PyObject.
py::Object py_path_value(
query::procedure::MgpValueToPyObject(*path_value, py_graph));
query::procedure::MgpValueToPyObject(*path_value, py_graph.Ptr()));
ASSERT_TRUE(py_path_value);
AssertPickleAndCopyAreNotSupported(py_path_value.GetAttr("_path"));
AssertPickleAndCopyAreNotSupported(py_path_value.GetAttr("_path").Ptr());
// Convert back to C struct and check equality.
auto *new_path_value =
query::procedure::PyObjectToMgpValue(py_path_value, &memory);
query::procedure::PyObjectToMgpValue(py_path_value.Ptr(), &memory);
ASSERT_TRUE(new_path_value);
ASSERT_NE(new_path_value, path_value); // Pointer compare.
ASSERT_TRUE(mgp_value_is_path(new_path_value));
@ -248,7 +247,8 @@ TEST(PyModule, PyObjectToMgpValue) {
auto gil = py::EnsureGIL();
py::Object py_value{Py_BuildValue("[i f s (i f s) {s i s f}]", 1, 1.0, "one",
2, 2.0, "two", "three", 3, "four", 4.0)};
mgp_value *value = query::procedure::PyObjectToMgpValue(py_value, &memory);
mgp_value *value =
query::procedure::PyObjectToMgpValue(py_value.Ptr(), &memory);
ASSERT_TRUE(mgp_value_is_list(value));
const mgp_list *list1 = mgp_value_get_list(value);
@ -301,7 +301,7 @@ int main(int argc, char **argv) {
CHECK(py_path);
py::Object import_dir(
PyUnicode_FromString(mgp_py_path.parent_path().c_str()));
if (PyList_Append(py_path, import_dir) != 0) {
if (PyList_Append(py_path, import_dir.Ptr()) != 0) {
auto exc_info = py::FetchError().value();
LOG(FATAL) << exc_info;
}