diff --git a/src/communication/rpc/messages-inl.hpp b/src/communication/rpc/messages-inl.hpp index 56413c086..00615d03a 100644 --- a/src/communication/rpc/messages-inl.hpp +++ b/src/communication/rpc/messages-inl.hpp @@ -51,9 +51,7 @@ BOOST_CLASS_EXPORT(distributed::DispatchPlanReq); BOOST_CLASS_EXPORT(distributed::ConsumePlanRes); // Remote pull. -BOOST_CLASS_EXPORT(distributed::RemotePullReqData); BOOST_CLASS_EXPORT(distributed::RemotePullReq); -BOOST_CLASS_EXPORT(distributed::RemotePullResData); BOOST_CLASS_EXPORT(distributed::RemotePullRes); BOOST_CLASS_EXPORT(distributed::EndRemotePullReq); BOOST_CLASS_EXPORT(distributed::EndRemotePullRes); diff --git a/src/database/graph_db_accessor.cpp b/src/database/graph_db_accessor.cpp index a5aa48ba8..9aefdab9a 100644 --- a/src/database/graph_db_accessor.cpp +++ b/src/database/graph_db_accessor.cpp @@ -524,8 +524,6 @@ std::vector GraphDbAccessor::IndexInfo() const { } return info; } -auto &GraphDbAccessor::remote_vertices() { return *remote_vertices_; } -auto &GraphDbAccessor::remote_edges() { return *remote_edges_; } template <> distributed::RemoteCache &GraphDbAccessor::remote_elements() { diff --git a/src/database/graph_db_accessor.hpp b/src/database/graph_db_accessor.hpp index b51fe8f6c..fc75aab07 100644 --- a/src/database/graph_db_accessor.hpp +++ b/src/database/graph_db_accessor.hpp @@ -552,8 +552,16 @@ class GraphDbAccessor { /* Returns a list of index names present in the database. */ std::vector IndexInfo() const; - auto &remote_vertices(); - auto &remote_edges(); + distributed::RemoteCache &remote_vertices() { + CHECK(remote_vertices_) + << "Attempting to get a remote cache in single-node Memgraph"; + return remote_vertices_.value(); + } + distributed::RemoteCache &remote_edges() { + CHECK(remote_edges_) + << "Attempting to get a remote cache in single-node Memgraph"; + return remote_edges_.value(); + } /** Gets remote_vertices or remote_edges, depending on type param. 
*/ template diff --git a/src/distributed/remote_cache.hpp b/src/distributed/remote_cache.hpp index 2a287e1cc..d8c1613c7 100644 --- a/src/distributed/remote_cache.hpp +++ b/src/distributed/remote_cache.hpp @@ -57,18 +57,17 @@ class RemoteCache { TRecord *&old_record, TRecord *&new_record) { std::lock_guard guard{lock_}; auto found = cache_.find(gid); - if (found == cache_.end()) { - rec_uptr old_record = + if (found != cache_.end()) { + old_record = found->second.first.get(); + new_record = found->second.second.get(); + } else { + auto remote = remote_data_clients_.RemoteElement(worker_id, tx_id, gid); - found = cache_ - .emplace(gid, - std::make_pair(nullptr, nullptr)) - .first; - found->second.first.swap(old_record); + old_record = remote.get(); + new_record = nullptr; + cache_[gid] = + std::make_pair(std::move(remote), nullptr); } - - old_record = found->second.first.get(); - new_record = found->second.second.get(); } void AdvanceCommand() { @@ -82,6 +81,19 @@ class RemoteCache { // though we'll have pointers to nothing. } + /** Sets the given records as (old, new) data for the given gid. */ + void emplace(gid::Gid gid, rec_uptr old_record, rec_uptr new_record) { + std::lock_guard guard{lock_}; + // We can't replace existing data because some accessors might be using it. + // TODO - consider if it's necessary and OK to copy just the data content.
+ auto found = cache_.find(gid); + if (found != cache_.end()) + return; + else + cache_[gid] = + std::make_pair(std::move(old_record), std::move(new_record)); + } + private: std::mutex lock_; distributed::RemoteDataRpcClients &remote_data_clients_; diff --git a/src/distributed/remote_produce_rpc_server.hpp b/src/distributed/remote_produce_rpc_server.hpp index 2ba2a440c..2f97654f2 100644 --- a/src/distributed/remote_produce_rpc_server.hpp +++ b/src/distributed/remote_produce_rpc_server.hpp @@ -52,8 +52,9 @@ class RemoteProduceRpcServer { auto success = cursor_->Pull(frame_, context_); if (success) { results.reserve(pull_symbols_.size()); - for (const auto &symbol : pull_symbols_) + for (const auto &symbol : pull_symbols_) { results.emplace_back(std::move(frame_[symbol])); + } } return std::make_pair(std::move(results), success); } @@ -79,7 +80,7 @@ class RemoteProduceRpcServer { plan_consumer_(plan_consumer) { remote_produce_rpc_server_.Register( [this](const RemotePullReq &req) { - return std::make_unique(RemotePull(req.member)); + return std::make_unique(RemotePull(req)); }); remote_produce_rpc_server_.Register([this]( @@ -101,7 +102,7 @@ class RemoteProduceRpcServer { ongoing_produces_; std::mutex ongoing_produces_lock_; - auto &GetOngoingProduce(const RemotePullReqData &req) { + auto &GetOngoingProduce(const RemotePullReq &req) { std::lock_guard guard{ongoing_produces_lock_}; auto found = ongoing_produces_.find({req.tx_id, req.plan_id}); if (found != ongoing_produces_.end()) { @@ -118,21 +119,22 @@ class RemoteProduceRpcServer { .first->second; } - RemotePullResData RemotePull(const RemotePullReqData &req) { + RemotePullResData RemotePull(const RemotePullReq &req) { auto &ongoing_produce = GetOngoingProduce(req); - RemotePullResData result; - result.pull_state = RemotePullState::CURSOR_IN_PROGRESS; + RemotePullResData result{db_.WorkerId(), req.send_old, req.send_new}; + + result.state_and_frames.pull_state = RemotePullState::CURSOR_IN_PROGRESS; for (int i = 0; 
i < req.batch_size; ++i) { // TODO exception handling (Serialization errors) // when full CRUD. Maybe put it in OngoingProduce::Pull auto pull_result = ongoing_produce.Pull(); if (!pull_result.second) { - result.pull_state = RemotePullState::CURSOR_EXHAUSTED; + result.state_and_frames.pull_state = RemotePullState::CURSOR_EXHAUSTED; break; } - result.frames.emplace_back(std::move(pull_result.first)); + result.state_and_frames.frames.emplace_back(std::move(pull_result.first)); } return result; diff --git a/src/distributed/remote_pull_produce_rpc_messages.hpp b/src/distributed/remote_pull_produce_rpc_messages.hpp index 67f3a1330..8c8233073 100644 --- a/src/distributed/remote_pull_produce_rpc_messages.hpp +++ b/src/distributed/remote_pull_produce_rpc_messages.hpp @@ -1,14 +1,17 @@ #pragma once #include +#include #include #include "boost/serialization/utility.hpp" #include "boost/serialization/vector.hpp" #include "communication/rpc/messages.hpp" +#include "distributed/serialization.hpp" #include "query/frontend/semantic/symbol.hpp" #include "query/parameters.hpp" +#include "storage/edges.hpp" #include "transactions/type.hpp" #include "utils/serialization.hpp" @@ -18,8 +21,8 @@ namespace distributed { /// master that requested it. constexpr int kDefaultBatchSize = 20; -/** Returnd along with a batch of results in the remote-pull RPC. Indicates the - * state of execution on the worker. */ +/// Returnd along with a batch of results in the remote-pull RPC. Indicates the +/// state of execution on the worker. 
enum class RemotePullState { CURSOR_EXHAUSTED, CURSOR_IN_PROGRESS, @@ -29,31 +32,51 @@ enum class RemotePullState { const std::string kRemotePullProduceRpcName = "RemotePullProduceRpc"; -struct RemotePullReqData { +struct RemotePullReq : public communication::rpc::Message { + RemotePullReq() {} + RemotePullReq(tx::transaction_id_t tx_id, int64_t plan_id, + const Parameters ¶ms, std::vector symbols, + int batch_size, bool send_old, bool send_new) + : tx_id(tx_id), + plan_id(plan_id), + params(params), + symbols(symbols), + batch_size(batch_size), + send_old(send_old), + send_new(send_new) {} + tx::transaction_id_t tx_id; int64_t plan_id; Parameters params; std::vector symbols; int batch_size; + // Indicates which of (old, new) records of a graph element should be sent. + bool send_old; + bool send_new; private: friend class boost::serialization::access; template void save(TArchive &ar, unsigned int) const { + ar << boost::serialization::base_object(*this); ar << tx_id; ar << plan_id; ar << params.size(); for (auto &kv : params) { ar << kv.first; + // Params never contain a vertex/edge, so save plan TypedValue. utils::SaveTypedValue(ar, kv.second); } ar << symbols; ar << batch_size; + ar << send_old; + ar << send_new; } template void load(TArchive &ar, unsigned int) { + ar >> boost::serialization::base_object(*this); ar >> tx_id; ar >> plan_id; size_t params_size; @@ -62,63 +85,277 @@ struct RemotePullReqData { int token_pos; ar >> token_pos; query::TypedValue param; + // Params never contain a vertex/edge, so load plan TypedValue. utils::LoadTypedValue(ar, param); params.Add(token_pos, param); } ar >> symbols; ar >> batch_size; + ar >> send_old; + ar >> send_new; } BOOST_SERIALIZATION_SPLIT_MEMBER() }; -struct RemotePullResData { - public: +/// The data returned to the end consumer (the RemotePull operator). Contains +/// only the relevant parts of the response, ready for use. 
+struct RemotePullData { RemotePullState pull_state; std::vector> frames; +}; + +/// The data of the remote pull response. Post-processing is required after +/// deserialization to initialize Vertex/Edge typed values in the frames +/// (possibly encapsulated in lists/maps) to their proper values. This requires +/// a GraphDbAccessor and therefore can't be done as part of deserialization. +/// +/// TODO - make it possible to inject a &GraphDbAccessor from the RemotePull +/// layer +/// all the way into RPC data deserialization to remove the requirement for +/// post-processing. The current approach of holding references to parts of the +/// frame (potentially embedded in lists/maps) is too error-prone. +struct RemotePullResData { + private: + // Temp cache for deserialized vertices and edges. These objects are created + // during deserialization. They are used immediately after, during + // post-processing. The vertex/edge data ownership gets transferred to the + // RemoteCache, and the `element_in_frame` pointer is used to set the + // appropriate accessor to the appropriate value. Not used on the side that + // generates the response. + template + struct GraphElementData { + using AddressT = storage::Address>; + using PtrT = std::unique_ptr; + + GraphElementData(AddressT address, PtrT old_record, PtrT new_record, + query::TypedValue *element_in_frame) + : global_address(address), + old_record(std::move(old_record)), + new_record(std::move(new_record)), + element_in_frame(element_in_frame) {} + + storage::Address> global_address; + std::unique_ptr old_record; + std::unique_ptr new_record; + // The position in frame is optional. This same structure is used for + // deserializing path elements, in which case the vertex/edge in question is + // not directly part of the frame. + query::TypedValue *element_in_frame; + }; + + // Same as `GraphElementData`, but for paths.
+ struct PathData { + PathData(query::TypedValue &path_in_frame) : path_in_frame(path_in_frame) {} + std::vector> vertices; + std::vector> edges; + query::TypedValue &path_in_frame; + }; + + public: + RemotePullResData() {} // Default constructor required for serialization. + RemotePullResData(int worker_id, bool send_old, bool send_new) + : worker_id(worker_id), send_old(send_old), send_new(send_new) {} + + RemotePullResData(const RemotePullResData &) = delete; + RemotePullResData &operator=(const RemotePullResData &) = delete; + RemotePullResData(RemotePullResData &&) = default; + RemotePullResData &operator=(RemotePullResData &&) = default; + + RemotePullData state_and_frames; + // Id of the worker on which the response is created, used for serializing + // vertices (converting local to global addresses). + int worker_id; + // Indicates which of (old, new) records of a graph element should be sent. + bool send_old; + bool send_new; + + // Temporary caches used between deserialization and post-processing + // (transferring the ownership of this data to a RemoteCache). + std::vector> vertices; + std::vector> edges; + std::vector paths; + + /// Saves a typed value that is a vertex/edge/path. + template + void SaveGraphElement(TArchive &ar, const query::TypedValue &value) const { + // Helper template function for storing a vertex or an edge. + auto save_element = [&ar, this](auto element_accessor) { + ar << element_accessor.GlobalAddress().raw(); + + // If both old and new are null, we need to reconstruct.
+ if (!(element_accessor.GetOld() || element_accessor.GetNew())) { + bool result = element_accessor.Reconstruct(); + CHECK(result) << "Attempting to serialize an element not visible to " + "current transaction."; + } + auto *old_rec = element_accessor.GetOld(); + if (send_old && old_rec) { + ar << true; + distributed::SaveElement(ar, *old_rec, worker_id); + } else { + ar << false; + } + if (send_new) { + // Must call SwitchNew as that will trigger a potentially necessary + // Reconstruct. + element_accessor.SwitchNew(); + auto *new_rec = element_accessor.GetNew(); + if (new_rec) { + ar << true; + distributed::SaveElement(ar, *new_rec, worker_id); + } else { + ar << false; + } + } else { + ar << false; + } + }; + switch (value.type()) { + case query::TypedValue::Type::Vertex: + save_element(value.ValueVertex()); + break; + case query::TypedValue::Type::Edge: + save_element(value.ValueEdge()); + break; + case query::TypedValue::Type::Path: { + auto &path = value.ValuePath(); + ar << path.size(); + save_element(path.vertices()[0]); + for (size_t i = 0; i < path.size(); ++i) { + save_element(path.edges()[i]); + save_element(path.vertices()[i + 1]); + } + break; + } + default: + LOG(FATAL) << "Unsupported graph element type: " << value.type(); + } + } + + /// Loads a typed value that is a vertex/edge/path. Part of the + /// deserialization process, populates the temporary data caches which are + /// processed later. + template + void LoadGraphElement(TArchive &ar, query::TypedValue::Type type, + query::TypedValue &value) { + auto load_edge = [](auto &ar) { + bool exists; + ar >> exists; + return exists ?
LoadVertex(ar) : nullptr; + }; + + switch (type) { + case query::TypedValue::Type::Vertex: { + Edges::VertexAddress::StorageT address; + ar >> address; + vertices.emplace_back(Edges::VertexAddress(address), load_vertex(ar), + load_vertex(ar), &value); + break; + } + case query::TypedValue::Type::Edge: { + Edges::VertexAddress::StorageT address; + ar >> address; + edges.emplace_back(Edges::EdgeAddress(address), load_edge(ar), + load_edge(ar), &value); + break; + } + case query::TypedValue::Type::Path: { + size_t path_size; + ar >> path_size; + + paths.emplace_back(value); + auto &path_data = paths.back(); + + Edges::VertexAddress::StorageT vertex_address; + Edges::EdgeAddress::StorageT edge_address; + ar >> vertex_address; + path_data.vertices.emplace_back(Edges::VertexAddress(vertex_address), + load_vertex(ar), load_vertex(ar), + nullptr); + for (size_t i = 0; i < path_size; ++i) { + ar >> edge_address; + path_data.edges.emplace_back(Edges::EdgeAddress(edge_address), + load_edge(ar), load_edge(ar), nullptr); + ar >> vertex_address; + path_data.vertices.emplace_back(Edges::VertexAddress(vertex_address), + load_vertex(ar), load_vertex(ar), + nullptr); + } + break; + } + default: + LOG(FATAL) << "Unsupported graph element type: " << type; + } + } +}; + +class RemotePullRes : public communication::rpc::Message { + public: + RemotePullRes() {} + RemotePullRes(RemotePullResData data) : data(std::move(data)) {} + + RemotePullResData data; private: friend class boost::serialization::access; template void save(TArchive &ar, unsigned int) const { - ar << pull_state; - ar << frames.size(); + ar << boost::serialization::base_object(*this); + ar << data.state_and_frames.pull_state; + ar << data.state_and_frames.frames.size(); // We need to indicate how many values are in each frame. // Assume all the frames have an equal number of elements. - ar << (frames.size() == 0 ? 
0 : frames[0].size()); - for (const auto &frame : frames) + ar << (data.state_and_frames.frames.size() == 0 + ? 0 + : data.state_and_frames.frames[0].size()); + for (const auto &frame : data.state_and_frames.frames) for (const auto &value : frame) { - utils::SaveTypedValue(ar, value); + utils::SaveTypedValue( + ar, value, [this](TArchive &ar, const query::TypedValue &value) { + data.SaveGraphElement(ar, value); + }); } } template void load(TArchive &ar, unsigned int) { - ar >> pull_state; + ar >> boost::serialization::base_object(*this); + ar >> data.state_and_frames.pull_state; size_t frame_count; ar >> frame_count; + data.state_and_frames.frames.reserve(frame_count); size_t frame_size; ar >> frame_size; for (size_t i = 0; i < frame_count; ++i) { - frames.emplace_back(); - auto ¤t_frame = frames.back(); + data.state_and_frames.frames.emplace_back(); + auto ¤t_frame = data.state_and_frames.frames.back(); + current_frame.reserve(frame_size); for (size_t j = 0; j < frame_size; ++j) { current_frame.emplace_back(); - utils::LoadTypedValue(ar, current_frame.back()); + utils::LoadTypedValue( + ar, current_frame.back(), + [this](TArchive &ar, query::TypedValue::TypedValue::Type type, + query::TypedValue &value) { + data.LoadGraphElement(ar, type, value); + }); } } } BOOST_SERIALIZATION_SPLIT_MEMBER() }; -RPC_SINGLE_MEMBER_MESSAGE(RemotePullReq, RemotePullReqData); -RPC_SINGLE_MEMBER_MESSAGE(RemotePullRes, RemotePullResData); - using RemotePullRpc = communication::rpc::RequestResponse; // TODO make a separate RPC for the continuation of an existing pull, as an -// optimization not to have to send the full RemotePullReqData pack every time. +// optimization not to have to send the full RemotePullReqData pack every +// time. 
using EndRemotePullReqData = std::pair; RPC_SINGLE_MEMBER_MESSAGE(EndRemotePullReq, EndRemotePullReqData); diff --git a/src/distributed/remote_pull_rpc_clients.hpp b/src/distributed/remote_pull_rpc_clients.hpp index 7168eb704..16799ddda 100644 --- a/src/distributed/remote_pull_rpc_clients.hpp +++ b/src/distributed/remote_pull_rpc_clients.hpp @@ -3,6 +3,7 @@ #include #include +#include "database/graph_db_accessor.hpp" #include "distributed/remote_pull_produce_rpc_messages.hpp" #include "distributed/rpc_worker_clients.hpp" #include "query/frontend/semantic/symbol.hpp" @@ -25,17 +26,53 @@ class RemotePullRpcClients { /// Calls a remote pull asynchroniously. IMPORTANT: take care not to call this /// function for the same (tx_id, worker_id, plan_id) before the previous call /// has ended. - std::future RemotePull( - tx::transaction_id_t tx_id, int worker_id, int64_t plan_id, + std::future RemotePull( + database::GraphDbAccessor &dba, int worker_id, int64_t plan_id, const Parameters ¶ms, const std::vector &symbols, int batch_size = kDefaultBatchSize) { - return clients_.ExecuteOnWorker( - worker_id, [tx_id, plan_id, ¶ms, &symbols, - batch_size](ClientPool &client_pool) { - return client_pool - .Call(RemotePullReqData{tx_id, plan_id, params, - symbols, batch_size}) - ->member; + return clients_.ExecuteOnWorker( + worker_id, + [&dba, plan_id, params, symbols, batch_size](ClientPool &client) { + auto result = + client.Call(dba.transaction_id(), plan_id, params, + symbols, batch_size, true, true); + + auto handle_vertex = [&dba](auto &v) { + dba.remote_vertices().emplace(v.global_address.gid(), + std::move(v.old_record), + std::move(v.new_record)); + if (v.element_in_frame) { + VertexAccessor va(v.global_address, dba); + *v.element_in_frame = va; + } + }; + auto handle_edge = [&dba](auto &e) { + dba.remote_edges().emplace(e.global_address.gid(), + std::move(e.old_record), + std::move(e.new_record)); + if (e.element_in_frame) { + EdgeAccessor ea(e.global_address, dba); + 
*e.element_in_frame = ea; + } + }; + for (auto &v : result->data.vertices) handle_vertex(v); + for (auto &e : result->data.edges) handle_edge(e); + for (auto &p : result->data.paths) { + handle_vertex(p.vertices[0]); + p.path_in_frame = + query::Path(VertexAccessor(p.vertices[0].global_address, dba)); + query::Path &path_in_frame = p.path_in_frame.ValuePath(); + for (size_t i = 0; i < p.edges.size(); ++i) { + handle_edge(p.edges[i]); + path_in_frame.Expand( + EdgeAccessor(p.edges[i].global_address, dba)); + handle_vertex(p.vertices[i + 1]); + path_in_frame.Expand( + VertexAccessor(p.vertices[i + 1].global_address, dba)); + } + } + + return std::move(result->data.state_and_frames); }); } @@ -68,5 +105,4 @@ class RemotePullRpcClients { private: RpcWorkerClients clients_; }; - } // namespace distributed diff --git a/src/distributed/serialization.hpp b/src/distributed/serialization.hpp index 91079c253..c44e172f9 100644 --- a/src/distributed/serialization.hpp +++ b/src/distributed/serialization.hpp @@ -80,6 +80,18 @@ void SaveEdge(TArchive &ar, const Edge &edge, int worker_id) { impl::SaveProperties(ar, edge.properties_); } +/// Alias for `SaveEdge` allowing for param type resolution. +template +void SaveElement(TArchive &ar, const Edge &record, int worker_id) { + return SaveEdge(ar, record, worker_id); +} + +/// Alias for `SaveVertex` allowing for param type resolution. 
+template +void SaveElement(TArchive &ar, const Vertex &record, int worker_id) { + return SaveVertex(ar, record, worker_id); +} + namespace impl { template diff --git a/src/query/interpret/eval.hpp b/src/query/interpret/eval.hpp index a4d6f9b83..58fb63771 100644 --- a/src/query/interpret/eval.hpp +++ b/src/query/interpret/eval.hpp @@ -429,7 +429,22 @@ class ExpressionEvaluator : public TreeVisitor { for (auto &kv : map) SwitchAccessors(kv.second); break; } - default: + case TypedValue::Type::Path: + switch (graph_view_) { + case GraphView::NEW: + value.ValuePath().SwitchNew(); + break; + case GraphView::OLD: + value.ValuePath().SwitchOld(); + break; + default: + LOG(FATAL) << "Unhandled GraphView enum"; + } + case TypedValue::Type::Null: + case TypedValue::Type::Bool: + case TypedValue::Type::String: + case TypedValue::Type::Int: + case TypedValue::Type::Double: break; } } diff --git a/src/query/path.hpp b/src/query/path.hpp index e8015daf6..f6a258d88 100644 --- a/src/query/path.hpp +++ b/src/query/path.hpp @@ -76,6 +76,18 @@ class Path { return os; } + /// Calls SwitchNew on all the elements of the path. + void SwitchNew() { + for (auto &v : vertices_) v.SwitchNew(); + for (auto &e : edges_) e.SwitchNew(); + } + + /// Calls SwitchOld on all the elements of the path. + void SwitchOld() { + for (auto &v : vertices_) v.SwitchOld(); + for (auto &e : edges_) e.SwitchOld(); + } + private: // Contains all the vertices in the path.
std::vector vertices_; diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 89a928731..c3d41560f 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -2584,7 +2584,7 @@ PullRemote::PullRemoteCursor::PullRemoteCursor(const PullRemote &self, void PullRemote::PullRemoteCursor::EndRemotePull() { if (remote_pull_ended_) return; db_.db().remote_pull_clients().EndAllRemotePulls(db_.transaction().id_, - self_.plan_id()); + self_.plan_id()); remote_pull_ended_ = true; } @@ -2595,13 +2595,16 @@ bool PullRemote::PullRemoteCursor::Pull(Frame &frame, Context &context) { while (worker_ids_.size() > 0 && results_.empty()) { last_pulled_worker_ = (last_pulled_worker_ + 1) % worker_ids_.size(); - auto remote_results = db_.db().remote_pull_clients().RemotePull( - db_.transaction().id_, worker_ids_[last_pulled_worker_], - self_.plan_id(), context.parameters_, self_.symbols()).get(); + auto remote_results = + db_.db() + .remote_pull_clients() + .RemotePull(db_, worker_ids_[last_pulled_worker_], self_.plan_id(), + context.parameters_, self_.symbols()) + .get(); auto get_results = [&]() { - for (auto &result : remote_results.frames) { - results_.emplace(std::move(result)); + for (auto &frame : remote_results.frames) { + results_.emplace(std::move(frame)); } }; diff --git a/src/query/typed_value.cpp b/src/query/typed_value.cpp index f1c193eed..886dedc81 100644 --- a/src/query/typed_value.cpp +++ b/src/query/typed_value.cpp @@ -114,14 +114,18 @@ TypedValue::operator PropertyValue() const { template <> \ type_param &TypedValue::Value() { \ if (type_ != Type::type_enum) \ - throw TypedValueException("Incompatible template param and type"); \ + throw TypedValueException( \ + "Incompatible template param '{}' and type '{}'", Type::type_enum, \ + type_); \ return field; \ } \ \ template <> \ const type_param &TypedValue::Value() const { \ if (type_ != Type::type_enum) \ - throw TypedValueException("Incompatible template param and type"); \ + 
throw TypedValueException( \ + "Incompatible template param '{}' and type '{}'", Type::type_enum, \ + type_); \ return field; \ } \ \ diff --git a/src/storage/address.hpp b/src/storage/address.hpp index f91bab9c2..d7455eec8 100644 --- a/src/storage/address.hpp +++ b/src/storage/address.hpp @@ -27,7 +27,6 @@ namespace storage { */ template class Address { - using Storage = uint64_t; static constexpr uint64_t kTypeMaskSize{1}; static constexpr uint64_t kTypeMask{(1ULL << kTypeMaskSize) - 1}; static constexpr uint64_t kWorkerIdSize{gid::kWorkerIdSize}; @@ -35,10 +34,12 @@ class Address { static constexpr uint64_t kRemote{1}; public: + using StorageT = uint64_t; + Address() {} // Constructor for raw address value - Address(Storage storage) : storage_(storage) {} + Address(StorageT storage) : storage_(storage) {} // Constructor for local Address. Address(TLocalObj *ptr) { @@ -51,7 +52,7 @@ class Address { // that is storing that vertex/edge Address(gid::Gid global_id, int worker_id) { CHECK(global_id < - (1ULL << (sizeof(Storage) * 8 - kWorkerIdSize - kTypeMaskSize))) + (1ULL << (sizeof(StorageT) * 8 - kWorkerIdSize - kTypeMaskSize))) << "Too large global id"; CHECK(worker_id < (1ULL << kWorkerIdSize)) << "Too larger worker id"; @@ -80,13 +81,13 @@ class Address { } /// Returns raw address value - Storage raw() const { return storage_; } + StorageT raw() const { return storage_; } bool operator==(const Address &other) const { return storage_ == other.storage_; } private: - Storage storage_{0}; + StorageT storage_{0}; }; } // namespace storage diff --git a/src/utils/serialization.hpp b/src/utils/serialization.hpp index 0dc7f8deb..de9f1722b 100644 --- a/src/utils/serialization.hpp +++ b/src/utils/serialization.hpp @@ -42,9 +42,17 @@ void load(TArchive &ar, std::experimental::optional &opt, unsigned int) { namespace utils { -/** Saves the given value into the given Boost archive. */ +/** + * Saves the given value into the given Boost archive. 
The optional + * `save_graph_element` function is called if the given `value` is a + * [Vertex|Edge|Path]. If that function is not provided, and `value` is one of + * those, an exception is thrown. + */ template -void SaveTypedValue(TArchive &ar, const query::TypedValue &value) { +void SaveTypedValue( + TArchive &ar, const query::TypedValue &value, + std::function + save_graph_element = nullptr) { ar << value.type(); switch (value.type()) { case query::TypedValue::Type::Null: @@ -65,7 +73,7 @@ void SaveTypedValue(TArchive &ar, const query::TypedValue &value) { const auto &values = value.Value>(); ar << values.size(); for (const auto &v : values) { - SaveTypedValue(ar, v); + SaveTypedValue(ar, v, save_graph_element); } return; } @@ -74,21 +82,32 @@ void SaveTypedValue(TArchive &ar, const query::TypedValue &value) { ar << map.size(); for (const auto &key_value : map) { ar << key_value.first; - SaveTypedValue(ar, key_value.second); + SaveTypedValue(ar, key_value.second, save_graph_element); } return; } case query::TypedValue::Type::Vertex: case query::TypedValue::Type::Edge: case query::TypedValue::Type::Path: - throw utils::BasicException("Unable to archive TypedValue of type: {}", - value.type()); + if (save_graph_element) { + save_graph_element(ar, value); + } else { + throw utils::BasicException("Unable to archive TypedValue of type: {}", + value.type()); + } } } -/** Loads a typed value into the given reference from the given archive. */ +/** Loads a typed value into the given reference from the given archive. The + * optional `load_graph_element` function is called if a [Vertex|Edge|Path] + * TypedValue should be unarchived. If that function is not provided, and + * `value` is one of those, an exception is thrown. 
+ */ template -void LoadTypedValue(TArchive &ar, query::TypedValue &value) { +void LoadTypedValue(TArchive &ar, query::TypedValue &value, + std::function + load_graph_element = nullptr) { query::TypedValue::Type type = query::TypedValue::Type::Null; ar >> type; switch (type) { @@ -119,37 +138,38 @@ void LoadTypedValue(TArchive &ar, query::TypedValue &value) { return; } case query::TypedValue::Type::List: { - std::vector values; + value = std::vector{}; + auto &list = value.ValueList(); size_t size; ar >> size; - values.reserve(size); + list.reserve(size); for (size_t i = 0; i < size; ++i) { - query::TypedValue tv; - LoadTypedValue(ar, tv); - values.emplace_back(tv); + list.emplace_back(); + LoadTypedValue(ar, list.back(), load_graph_element); } - value = values; return; } case query::TypedValue::Type::Map: { - std::map map; + value = std::map{}; + auto &map = value.ValueMap(); size_t size; ar >> size; for (size_t i = 0; i < size; ++i) { std::string key; ar >> key; - query::TypedValue v; - LoadTypedValue(ar, v); - map.emplace(key, v); + LoadTypedValue(ar, map[key], load_graph_element); } - value = map; return; } case query::TypedValue::Type::Vertex: case query::TypedValue::Type::Edge: case query::TypedValue::Type::Path: - throw utils::BasicException( - "Unexpected TypedValue type '{}' when loading from archive", type); + if (load_graph_element) { + load_graph_element(ar, type, value); + } else { + throw utils::BasicException( + "Unexpected TypedValue type '{}' when loading from archive", type); + } } } } // namespace utils diff --git a/tests/unit/distributed_graph_db.cpp b/tests/unit/distributed_graph_db.cpp index 4ec8a7942..cfb2e30a8 100644 --- a/tests/unit/distributed_graph_db.cpp +++ b/tests/unit/distributed_graph_db.cpp @@ -1,5 +1,7 @@ #include +#include #include +#include #include "gtest/gtest.h" @@ -235,8 +237,7 @@ TEST_F(DistributedGraphDbTest, DispatchPlan) { } TEST_F(DistributedGraphDbTest, RemotePullProduceRpc) { - database::GraphDb &db = master(); - 
database::GraphDbAccessor dba{db}; + database::GraphDbAccessor dba{master()}; Context ctx{dba}; SymbolGenerator symbol_generator{ctx.symbol_table_}; AstTreeStorage storage; @@ -263,8 +264,8 @@ TEST_F(DistributedGraphDbTest, RemotePullProduceRpc) { Parameters params; std::vector symbols{ctx.symbol_table_[*x_ne]}; auto remote_pull = [this, plan_id, ¶ms, &symbols]( - tx::transaction_id_t tx_id, int worker_id) { - return master().remote_pull_clients().RemotePull(tx_id, worker_id, plan_id, + database::GraphDbAccessor &dba, int worker_id) { + return master().remote_pull_clients().RemotePull(dba, worker_id, plan_id, params, symbols, 3); }; auto expect_first_batch = [](auto &batch) { @@ -287,13 +288,14 @@ TEST_F(DistributedGraphDbTest, RemotePullProduceRpc) { database::GraphDbAccessor dba_1{master()}; database::GraphDbAccessor dba_2{master()}; for (int worker_id : {1, 2}) { - auto tx1_batch1 = remote_pull(dba_1.transaction_id(), worker_id).get(); + // TODO flor, proper test async here. + auto tx1_batch1 = remote_pull(dba_1, worker_id).get(); expect_first_batch(tx1_batch1); - auto tx2_batch1 = remote_pull(dba_2.transaction_id(), worker_id).get(); + auto tx2_batch1 = remote_pull(dba_2, worker_id).get(); expect_first_batch(tx2_batch1); - auto tx2_batch2 = remote_pull(dba_2.transaction_id(), worker_id).get(); + auto tx2_batch2 = remote_pull(dba_2, worker_id).get(); expect_second_batch(tx2_batch2); - auto tx1_batch2 = remote_pull(dba_1.transaction_id(), worker_id).get(); + auto tx1_batch2 = remote_pull(dba_1, worker_id).get(); expect_second_batch(tx1_batch2); } master().remote_pull_clients().EndAllRemotePulls(dba_1.transaction_id(), @@ -302,6 +304,103 @@ TEST_F(DistributedGraphDbTest, RemotePullProduceRpc) { plan_id); } +TEST_F(DistributedGraphDbTest, RemotePullProduceRpcWithGraphElements) { + // Create some data on the master and both workers. 
Each edge (3 of them) and + // vertex (6 of them) will be uniquely identified with their worker id and + // sequence ID, so we can check we retrieved all. + storage::Property prop; + { + database::GraphDbAccessor dba{master()}; + prop = dba.Property("prop"); + auto create_data = [prop](database::GraphDbAccessor &dba, int worker_id) { + auto v1 = dba.InsertVertex(); + v1.PropsSet(prop, worker_id * 10); + auto v2 = dba.InsertVertex(); + v2.PropsSet(prop, worker_id * 10 + 1); + auto e12 = dba.InsertEdge(v1, v2, dba.EdgeType("et")); + e12.PropsSet(prop, worker_id * 10 + 2); + }; + create_data(dba, 0); + database::GraphDbAccessor dba_w1{worker1(), dba.transaction_id()}; + create_data(dba_w1, 1); + database::GraphDbAccessor dba_w2{worker2(), dba.transaction_id()}; + create_data(dba_w2, 2); + dba.Commit(); + } + + database::GraphDbAccessor dba{master()}; + Context ctx{dba}; + SymbolGenerator symbol_generator{ctx.symbol_table_}; + AstTreeStorage storage; + + // Query plan for: MATCH p = (n)-[r]->(m) return [n, r], m, p + // Use this query to test graph elements are transferred correctly in + // collections too.
+ auto n = MakeScanAll(storage, ctx.symbol_table_, "n"); + auto r_m = + MakeExpand(storage, ctx.symbol_table_, n.op_, n.sym_, "r", + EdgeAtom::Direction::OUT, {}, "m", false, GraphView::OLD); + auto p_sym = ctx.symbol_table_.CreateSymbol("p", true); + auto p = std::make_shared( + r_m.op_, p_sym, + std::vector{n.sym_, r_m.edge_sym_, r_m.node_sym_}); + auto return_n = IDENT("n"); + ctx.symbol_table_[*return_n] = n.sym_; + auto return_r = IDENT("r"); + ctx.symbol_table_[*return_r] = r_m.edge_sym_; + auto return_n_r = NEXPR("[n, r]", LIST(return_n, return_r)); + ctx.symbol_table_[*return_n_r] = ctx.symbol_table_.CreateSymbol("", true); + auto return_m = NEXPR("m", IDENT("m")); + ctx.symbol_table_[*return_m->expression_] = r_m.node_sym_; + ctx.symbol_table_[*return_m] = ctx.symbol_table_.CreateSymbol("", true); + auto return_p = NEXPR("p", IDENT("p")); + ctx.symbol_table_[*return_p->expression_] = p_sym; + ctx.symbol_table_[*return_p] = ctx.symbol_table_.CreateSymbol("", true); + auto produce = MakeProduce(p, return_n_r, return_m, return_p); + + auto check_result = [prop]( + int worker_id, + const std::vector> &frames) { + int offset = worker_id * 10; + ASSERT_EQ(frames.size(), 1); + auto &row = frames[0]; + ASSERT_EQ(row.size(), 3); + auto &list = row[0].ValueList(); + ASSERT_EQ(list.size(), 2); + ASSERT_EQ(list[0].ValueVertex().PropsAt(prop).Value(), offset); + ASSERT_EQ(list[1].ValueEdge().PropsAt(prop).Value(), offset + 2); + ASSERT_EQ(row[1].ValueVertex().PropsAt(prop).Value(), offset + 1); + auto &path = row[2].ValuePath(); + ASSERT_EQ(path.size(), 1); + ASSERT_EQ(path.vertices()[0].PropsAt(prop).Value(), offset); + ASSERT_EQ(path.edges()[0].PropsAt(prop).Value(), offset + 2); + ASSERT_EQ(path.vertices()[1].PropsAt(prop).Value(), offset + 1); + }; + + // Test that the plan works locally. 
+ auto results = CollectProduce(produce.get(), ctx.symbol_table_, dba); + check_result(0, results); + + const int plan_id = 42; + master().plan_dispatcher().DispatchPlan(plan_id, produce, ctx.symbol_table_); + + Parameters params; + std::vector symbols{ctx.symbol_table_[*return_n_r], + ctx.symbol_table_[*return_m], p_sym}; + auto remote_pull = [this, plan_id, ¶ms, &symbols]( + database::GraphDbAccessor &dba, int worker_id) { + return master().remote_pull_clients().RemotePull(dba, worker_id, plan_id, + params, symbols, 3); + }; + auto future_w1_results = remote_pull(dba, 1); + auto future_w2_results = remote_pull(dba, 2); + check_result(1, future_w1_results.get().frames); + check_result(2, future_w2_results.get().frames); + + master().remote_pull_clients().EndAllRemotePulls(dba.transaction_id(), + plan_id); +} + TEST_F(DistributedGraphDbTest, BuildIndexDistributed) { using GraphDbAccessor = database::GraphDbAccessor; storage::Label label;