diff --git a/src/communication/rpc/messages-inl.hpp b/src/communication/rpc/messages-inl.hpp index b18ee14d5..c2c51e40b 100644 --- a/src/communication/rpc/messages-inl.hpp +++ b/src/communication/rpc/messages-inl.hpp @@ -5,6 +5,7 @@ #include "distributed/coordination_rpc_messages.hpp" #include "distributed/plan_rpc_messages.hpp" #include "distributed/remote_data_rpc_messages.hpp" +#include "distributed/remote_pull_produce_rpc_messages.hpp" #include "storage/concurrent_id_mapper_rpc_messages.hpp" #include "transactions/engine_rpc_messages.hpp" @@ -47,3 +48,11 @@ BOOST_CLASS_EXPORT(distributed::TxGidPair); // Distributed plan exchange. BOOST_CLASS_EXPORT(distributed::DispatchPlanReq); BOOST_CLASS_EXPORT(distributed::ConsumePlanRes); + +// Remote pull. +BOOST_CLASS_EXPORT(distributed::RemotePullReqData); +BOOST_CLASS_EXPORT(distributed::RemotePullReq); +BOOST_CLASS_EXPORT(distributed::RemotePullResData); +BOOST_CLASS_EXPORT(distributed::RemotePullRes); +BOOST_CLASS_EXPORT(distributed::EndRemotePullReq); +BOOST_CLASS_EXPORT(distributed::EndRemotePullRes); diff --git a/src/database/graph_db.cpp b/src/database/graph_db.cpp index 2437dd351..fee7a2148 100644 --- a/src/database/graph_db.cpp +++ b/src/database/graph_db.cpp @@ -8,6 +8,8 @@ #include "distributed/plan_dispatcher.hpp" #include "distributed/remote_data_rpc_clients.hpp" #include "distributed/remote_data_rpc_server.hpp" +#include "distributed/remote_produce_rpc_server.hpp" +#include "distributed/remote_pull_rpc_clients.hpp" #include "durability/paths.hpp" #include "durability/recovery.hpp" #include "durability/snapshooter.hpp" @@ -35,6 +37,19 @@ class PrivateBase : public GraphDb { durability::WriteAheadLog &wal() override { return wal_; } int WorkerId() const override { return config_.worker_id; } + distributed::RemotePullRpcClients &remote_pull_clients() override { + LOG(FATAL) << "Remote pull clients only available in master."; + } + distributed::RemoteProduceRpcServer &remote_produce_server() override { + LOG(FATAL) << "Remote produce server only available in worker."; + } + distributed::PlanConsumer &plan_consumer() override { + LOG(FATAL) << "Plan consumer only available in distributed worker."; + } + distributed::PlanDispatcher &plan_dispatcher() override { + LOG(FATAL) << "Plan dispatcher only available in distributed master."; + } + protected: Storage storage_{config_.worker_id}; durability::WriteAheadLog wal_{config_.worker_id, @@ -83,12 +98,6 @@ class SingleNode : public PrivateBase { distributed::RemoteDataRpcClients &remote_data_clients() override { LOG(FATAL) << "Remote data clients not available in single-node."; } - distributed::PlanDispatcher &plan_dispatcher() override { - LOG(FATAL) << "Plan Dispatcher not available in single-node."; - } - distributed::PlanConsumer &plan_consumer() override { - LOG(FATAL) << "Plan Consumer not available in single-node."; - } }; #define IMPL_DISTRIBUTED_GETTERS \ @@ -110,8 +119,8 @@ class Master : public PrivateBase { distributed::PlanDispatcher &plan_dispatcher() override { return plan_dispatcher_; } - distributed::PlanConsumer &plan_consumer() override { - LOG(FATAL) << "Plan Consumer not available in single-node."; + distributed::RemotePullRpcClients &remote_pull_clients() override { + return remote_pull_clients_; } communication::rpc::System system_{config_.master_endpoint}; @@ -123,6 +132,7 @@ class Master : public PrivateBase { distributed::RemoteDataRpcServer remote_data_server_{*this, system_}; distributed::RemoteDataRpcClients remote_data_clients_{coordination_}; distributed::PlanDispatcher plan_dispatcher_{coordination_}; + distributed::RemotePullRpcClients remote_pull_clients_{coordination_}; }; class Worker : public PrivateBase { @@ -137,8 +147,8 @@ class Worker : public PrivateBase { IMPL_GETTERS IMPL_DISTRIBUTED_GETTERS distributed::PlanConsumer &plan_consumer() override { return plan_consumer_; } - distributed::PlanDispatcher &plan_dispatcher() override { - LOG(FATAL) << "Plan Dispatcher not available in single-node."; + distributed::RemoteProduceRpcServer &remote_produce_server() override { + return remote_produce_server_; } communication::rpc::System system_{config_.worker_endpoint}; @@ -151,6 +161,8 @@ class Worker : public PrivateBase { distributed::RemoteDataRpcServer remote_data_server_{*this, system_}; distributed::RemoteDataRpcClients remote_data_clients_{coordination_}; distributed::PlanConsumer plan_consumer_{system_}; + distributed::RemoteProduceRpcServer remote_produce_server_{*this, system_, + plan_consumer_}; }; #undef IMPL_GETTERS @@ -204,6 +216,12 @@ distributed::PlanDispatcher &PublicBase::plan_dispatcher() { distributed::PlanConsumer &PublicBase::plan_consumer() { return impl_->plan_consumer(); } +distributed::RemotePullRpcClients &PublicBase::remote_pull_clients() { + return impl_->remote_pull_clients(); +} +distributed::RemoteProduceRpcServer &PublicBase::remote_produce_server() { + return impl_->remote_produce_server(); +} void PublicBase::MakeSnapshot() { const bool status = durability::MakeSnapshot( diff --git a/src/database/graph_db.hpp b/src/database/graph_db.hpp index a20da5fd8..5e56ec128 100644 --- a/src/database/graph_db.hpp +++ b/src/database/graph_db.hpp @@ -18,6 +18,8 @@ class RemoteDataRpcServer; class RemoteDataRpcClients; class PlanDispatcher; class PlanConsumer; +class RemotePullRpcClients; +class RemoteProduceRpcServer; } namespace database { @@ -88,7 +90,14 @@ class GraphDb { // Supported only in distributed master and worker, not in single-node. virtual distributed::RemoteDataRpcServer &remote_data_server() = 0; virtual distributed::RemoteDataRpcClients &remote_data_clients() = 0; + + // Supported only in distributed master. + virtual distributed::RemotePullRpcClients &remote_pull_clients() = 0; virtual distributed::PlanDispatcher &plan_dispatcher() = 0; + + // Supported only in distributed worker. + // TODO remove once end2end testing is possible. + virtual distributed::RemoteProduceRpcServer &remote_produce_server() = 0; virtual distributed::PlanConsumer &plan_consumer() = 0; GraphDb(const GraphDb &) = delete; @@ -121,6 +130,8 @@ class PublicBase : public GraphDb { distributed::RemoteDataRpcClients &remote_data_clients() override; distributed::PlanDispatcher &plan_dispatcher() override; distributed::PlanConsumer &plan_consumer() override; + distributed::RemotePullRpcClients &remote_pull_clients() override; + distributed::RemoteProduceRpcServer &remote_produce_server() override; protected: explicit PublicBase(std::unique_ptr<PrivateBase> impl); diff --git a/src/distributed/plan_consumer.cpp b/src/distributed/plan_consumer.cpp index 32e3d3b58..50babb596 100644 --- a/src/distributed/plan_consumer.cpp +++ b/src/distributed/plan_consumer.cpp @@ -5,19 +5,21 @@ namespace distributed { PlanConsumer::PlanConsumer(communication::rpc::System &system) : server_(system, kDistributedPlanServerName) { server_.Register<DistributedPlanRpc>([this](const DispatchPlanReq &req) { - plan_cache_.access().insert(req.plan_id_, - std::make_pair(req.plan_, req.symbol_table_)); + plan_cache_.access().insert( + req.plan_id_, + std::make_unique<PlanPack>( + req.plan_, req.symbol_table_, + std::move(const_cast<DispatchPlanReq &>(req).storage_))); return std::make_unique<ConsumePlanRes>(true); }); } -std::pair<std::shared_ptr<query::plan::LogicalOperator>, SymbolTable> -PlanConsumer::PlanForId(int64_t plan_id) { +PlanConsumer::PlanPack &PlanConsumer::PlanForId(int64_t plan_id) const { auto accessor = plan_cache_.access(); auto found = accessor.find(plan_id); CHECK(found != accessor.end()) << "Missing plan and symbol table for plan id!"; - return found->second; + return *found->second; } } // namespace distributed diff --git a/src/distributed/plan_consumer.hpp b/src/distributed/plan_consumer.hpp index b6a77c03d..daf5a02ee 100644 --- a/src/distributed/plan_consumer.hpp +++ b/src/distributed/plan_consumer.hpp @@ -9,24 +9,31 @@ namespace distributed { /** Handles plan consumption from master. Creates and holds a local cache of - * plans. Worker side. - */ + * plans. Worker side. */ class PlanConsumer { public: + struct PlanPack { + PlanPack(std::shared_ptr<query::plan::LogicalOperator> plan, + SymbolTable symbol_table, AstTreeStorage storage) + : plan(plan), + symbol_table(std::move(symbol_table)), + storage(std::move(storage)) {} + + std::shared_ptr<query::plan::LogicalOperator> plan; + SymbolTable symbol_table; + const AstTreeStorage storage; + }; + explicit PlanConsumer(communication::rpc::System &system); - /** - * Return cached plan and symbol table for a given plan id. - */ - std::pair<std::shared_ptr<query::plan::LogicalOperator>, SymbolTable> - PlanForId(int64_t plan_id); + /** Return cached plan and symbol table for a given plan id. */ + PlanPack &PlanForId(int64_t plan_id) const; private: communication::rpc::Server server_; - mutable ConcurrentMap< - int64_t, - std::pair<std::shared_ptr<query::plan::LogicalOperator>, SymbolTable>> - plan_cache_; + // TODO remove unique_ptr. This is to get it to work, emplacing into a + // ConcurrentMap is tricky. + mutable ConcurrentMap<int64_t, std::unique_ptr<PlanPack>> plan_cache_; }; } // namespace distributed diff --git a/src/distributed/remote_produce_rpc_server.hpp b/src/distributed/remote_produce_rpc_server.hpp new file mode 100644 index 000000000..2ba2a440c --- /dev/null +++ b/src/distributed/remote_produce_rpc_server.hpp @@ -0,0 +1,141 @@ +#pragma once + +#include <cstdint> +#include <map> +#include <mutex> +#include <utility> +#include <vector> + +#include "communication/rpc/server.hpp" +#include "database/graph_db.hpp" +#include "database/graph_db_accessor.hpp" +#include "distributed/plan_consumer.hpp" +#include "distributed/remote_pull_produce_rpc_messages.hpp" +#include "query/context.hpp" +#include "query/frontend/semantic/symbol_table.hpp" +#include "query/interpret/frame.hpp" +#include "query/parameters.hpp" +#include "query/plan/operator.hpp" +#include "query/typed_value.hpp" +#include "transactions/type.hpp" + +namespace distributed { + +/** + * Handles the execution of a plan on the worker, requested by the remote + * master. Assumes that (tx_id, plan_id) uniquely identifies an execution, and + * that there will never be parallel requests for the same execution thus + * identified. + */ +class RemoteProduceRpcServer { + /** Encapsulates an execution in progress. */ + class OngoingProduce { + public: + OngoingProduce(database::GraphDb &db, tx::transaction_id_t tx_id, + std::shared_ptr<query::plan::LogicalOperator> op, + query::SymbolTable symbol_table, Parameters parameters, + std::vector<query::Symbol> pull_symbols) + : dba_{db, tx_id}, + cursor_(op->MakeCursor(dba_)), + context_(dba_), + pull_symbols_(std::move(pull_symbols)), + frame_(symbol_table.max_position()) { + context_.symbol_table_ = std::move(symbol_table); + context_.parameters_ = std::move(parameters); + } + + /** Returns a vector of typed values (one for each `pull_symbol`), and a + * `bool` indicating if the pull was successful (or the cursor is + * exhausted). */ + std::pair<std::vector<query::TypedValue>, bool> Pull() { + std::vector<query::TypedValue> results; + auto success = cursor_->Pull(frame_, context_); + if (success) { + results.reserve(pull_symbols_.size()); + for (const auto &symbol : pull_symbols_) + results.emplace_back(std::move(frame_[symbol])); + } + return std::make_pair(std::move(results), success); + } + + private: + // TODO currently each OngoingProduce has it's own GDBA. There is no sharing + // of them in the same transaction. This should be correct, but it's + // inefficient in multi-command queries, and when a single query will get + // broken down into multiple parts. + database::GraphDbAccessor dba_; + std::unique_ptr<query::plan::Cursor> cursor_; + query::Context context_; + std::vector<query::Symbol> pull_symbols_; + query::Frame frame_; + }; + + public: + RemoteProduceRpcServer(database::GraphDb &db, + communication::rpc::System &system, + const distributed::PlanConsumer &plan_consumer) + : db_(db), + remote_produce_rpc_server_(system, kRemotePullProduceRpcName), + plan_consumer_(plan_consumer) { + remote_produce_rpc_server_.Register<RemotePullRpc>( + [this](const RemotePullReq &req) { + return std::make_unique<RemotePullRes>(RemotePull(req.member)); + }); + + remote_produce_rpc_server_.Register<EndRemotePullRpc>([this]( + const EndRemotePullReq &req) { + std::lock_guard<std::mutex> guard{ongoing_produces_lock_}; + auto it = ongoing_produces_.find(req.member); + CHECK(it != ongoing_produces_.end()) << "Failed to find ongoing produce"; + ongoing_produces_.erase(it); + return std::make_unique<EndRemotePullRes>(); + }); + } + + private: + database::GraphDb &db_; + communication::rpc::Server remote_produce_rpc_server_; + const distributed::PlanConsumer &plan_consumer_; + + std::map<std::pair<tx::transaction_id_t, int64_t>, OngoingProduce> + ongoing_produces_; + std::mutex ongoing_produces_lock_; + + auto &GetOngoingProduce(const RemotePullReqData &req) { + std::lock_guard<std::mutex> guard{ongoing_produces_lock_}; + auto found = ongoing_produces_.find({req.tx_id, req.plan_id}); + if (found != ongoing_produces_.end()) { + return found->second; + } + + auto &plan_pack = plan_consumer_.PlanForId(req.plan_id); + return ongoing_produces_ + .emplace(std::piecewise_construct, + std::forward_as_tuple(req.tx_id, req.plan_id), + std::forward_as_tuple(db_, req.tx_id, plan_pack.plan, + plan_pack.symbol_table, req.params, + req.symbols)) + .first->second; + } + + RemotePullResData RemotePull(const RemotePullReqData &req) { + auto &ongoing_produce = GetOngoingProduce(req); + + RemotePullResData result; + result.pull_state = RemotePullState::CURSOR_IN_PROGRESS; + + for (int i = 0; i < req.batch_size; ++i) { + // TODO exception handling (Serialization errors) + // when full CRUD. Maybe put it in OngoingProduce::Pull + auto pull_result = ongoing_produce.Pull(); + if (!pull_result.second) { + result.pull_state = RemotePullState::CURSOR_EXHAUSTED; + break; + } + result.frames.emplace_back(std::move(pull_result.first)); + } + + return result; + } +}; +} // namespace distributed diff --git a/src/distributed/remote_pull_produce_rpc_messages.hpp b/src/distributed/remote_pull_produce_rpc_messages.hpp new file mode 100644 index 000000000..67f3a1330 --- /dev/null +++ b/src/distributed/remote_pull_produce_rpc_messages.hpp @@ -0,0 +1,130 @@ +#pragma once + +#include <cstdint> +#include <string> + +#include "boost/serialization/utility.hpp" +#include "boost/serialization/vector.hpp" + +#include "communication/rpc/messages.hpp" +#include "query/frontend/semantic/symbol.hpp" +#include "query/parameters.hpp" +#include "transactions/type.hpp" +#include "utils/serialization.hpp" + +namespace distributed { + +/// The default number of results returned via RPC from remote execution to the +/// master that requested it. +constexpr int kDefaultBatchSize = 20; + +/** Returnd along with a batch of results in the remote-pull RPC. Indicates the + * state of execution on the worker. */ +enum class RemotePullState { + CURSOR_EXHAUSTED, + CURSOR_IN_PROGRESS, + SERIALIZATION_ERROR // future-proofing for full CRUD + // TODO in full CRUD other errors +}; + +const std::string kRemotePullProduceRpcName = "RemotePullProduceRpc"; + +struct RemotePullReqData { + tx::transaction_id_t tx_id; + int64_t plan_id; + Parameters params; + std::vector<query::Symbol> symbols; + int batch_size; + + private: + friend class boost::serialization::access; + + template <class TArchive> + void save(TArchive &ar, unsigned int) const { + ar << tx_id; + ar << plan_id; + ar << params.size(); + for (auto &kv : params) { + ar << kv.first; + utils::SaveTypedValue(ar, kv.second); + } + ar << symbols; + ar << batch_size; + } + + template <class TArchive> + void load(TArchive &ar, unsigned int) { + ar >> tx_id; + ar >> plan_id; + size_t params_size; + ar >> params_size; + for (size_t i = 0; i < params_size; ++i) { + int token_pos; + ar >> token_pos; + query::TypedValue param; + utils::LoadTypedValue(ar, param); + params.Add(token_pos, param); + } + ar >> symbols; + ar >> batch_size; + } + BOOST_SERIALIZATION_SPLIT_MEMBER() +}; + +struct RemotePullResData { + public: + RemotePullState pull_state; + std::vector<std::vector<query::TypedValue>> frames; + + private: + friend class boost::serialization::access; + + template <class TArchive> + void save(TArchive &ar, unsigned int) const { + ar << pull_state; + ar << frames.size(); + // We need to indicate how many values are in each frame. + // Assume all the frames have an equal number of elements. + ar << (frames.size() == 0 ? 0 : frames[0].size()); + for (const auto &frame : frames) + for (const auto &value : frame) { + utils::SaveTypedValue(ar, value); + } + } + + template <class TArchive> + void load(TArchive &ar, unsigned int) { + ar >> pull_state; + size_t frame_count; + ar >> frame_count; + size_t frame_size; + ar >> frame_size; + for (size_t i = 0; i < frame_count; ++i) { + frames.emplace_back(); + auto ¤t_frame = frames.back(); + for (size_t j = 0; j < frame_size; ++j) { + current_frame.emplace_back(); + utils::LoadTypedValue(ar, current_frame.back()); + } + } + } + BOOST_SERIALIZATION_SPLIT_MEMBER() +}; + +RPC_SINGLE_MEMBER_MESSAGE(RemotePullReq, RemotePullReqData); +RPC_SINGLE_MEMBER_MESSAGE(RemotePullRes, RemotePullResData); + +using RemotePullRpc = + communication::rpc::RequestResponse<RemotePullReq, RemotePullRes>; + +// TODO make a separate RPC for the continuation of an existing pull, as an +// optimization not to have to send the full RemotePullReqData pack every time. + +using EndRemotePullReqData = std::pair<tx::transaction_id_t, int64_t>; +RPC_SINGLE_MEMBER_MESSAGE(EndRemotePullReq, EndRemotePullReqData); +RPC_NO_MEMBER_MESSAGE(EndRemotePullRes); + +using EndRemotePullRpc = + communication::rpc::RequestResponse<EndRemotePullReq, EndRemotePullRes>; + +} // namespace distributed diff --git a/src/distributed/remote_pull_rpc_clients.hpp b/src/distributed/remote_pull_rpc_clients.hpp new file mode 100644 index 000000000..984987c21 --- /dev/null +++ b/src/distributed/remote_pull_rpc_clients.hpp @@ -0,0 +1,53 @@ +#pragma once + +#include <functional> +#include <vector> + +#include "distributed/remote_pull_produce_rpc_messages.hpp" +#include "distributed/rpc_worker_clients.hpp" +#include "query/frontend/semantic/symbol.hpp" +#include "query/parameters.hpp" +#include "transactions/type.hpp" + +namespace distributed { + +/** Provides means of calling for the execution of a plan on some remote worker, + * and getting the results of that execution. The results are returned in + * batches and are therefore accompanied with an enum indicator of the state of + * remote execution. */ +class RemotePullRpcClients { + public: + RemotePullRpcClients(Coordination &coordination) + : clients_(coordination, kRemotePullProduceRpcName) {} + + RemotePullResData RemotePull(tx::transaction_id_t tx_id, int worker_id, + int64_t plan_id, const Parameters ¶ms, + const std::vector<query::Symbol> &symbols, + int batch_size = kDefaultBatchSize) { + return std::move(clients_.GetClient(worker_id) + .Call<RemotePullRpc>(RemotePullReqData{ + tx_id, plan_id, params, symbols, batch_size}) + ->member); + } + + // Notifies all workers that the given transaction/plan is done. Otherwise the + // server is left with potentially unconsumed Cursors that never get deleted. + // + // TODO - maybe this needs to be done with hooks into the transactional + // engine, so that the Worker discards it's stuff when the relevant + // transaction are done. + // + // TODO - this will maybe need a per-worker granularity. + void EndRemotePull(tx::transaction_id_t tx_id, int64_t plan_id) { + auto futures = clients_.ExecuteOnWorkers<void>( + 0, [tx_id, plan_id](communication::rpc::Client &client) { + client.Call<EndRemotePullRpc>(EndRemotePullReqData{tx_id, plan_id}); + }); + for (auto &future : futures) future.wait(); + } + + private: + RpcWorkerClients clients_; +}; + +} // namespace distributed diff --git a/src/distributed/rpc_worker_clients.hpp b/src/distributed/rpc_worker_clients.hpp index 3dd7837c0..b459148af 100644 --- a/src/distributed/rpc_worker_clients.hpp +++ b/src/distributed/rpc_worker_clients.hpp @@ -3,6 +3,7 @@ #include <functional> #include <future> #include <type_traits> +#include <unordered_map> #include "communication/rpc/client.hpp" #include "distributed/coordination.hpp" diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 738dd30f9..a7862461b 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -196,5 +196,4 @@ Interpreter::MakeLogicalPlan(AstTreeStorage &ast_storage, return plan::MakeLogicalPlan(planning_context, context.parameters_, FLAGS_query_cost_planner); }; - } // namespace query diff --git a/src/query/parameters.hpp b/src/query/parameters.hpp index e27a53918..91cb5c84e 100644 --- a/src/query/parameters.hpp +++ b/src/query/parameters.hpp @@ -1,10 +1,4 @@ -// -// Copyright 2017 Memgraph -// Created by Florijan Stamenkovic on 08.03.17. -// - -#ifndef MEMGRAPH_PARAMETERS_HPP -#define MEMGRAPH_PARAMETERS_HPP +#pragma once #include <algorithm> #include <utility> @@ -16,6 +10,7 @@ * Encapsulates user provided parameters (and stripped literals) * and provides ways of obtaining them by position. */ +// TODO move to namespace query:: struct Parameters { public: /** @@ -57,10 +52,11 @@ struct Parameters { } /** Returns the number of arguments in this container */ - int size() const { return storage_.size(); } + auto size() const { return storage_.size(); } + + auto begin() const { return storage_.begin(); } + auto end() const { return storage_.end(); } private: std::vector<std::pair<int, query::TypedValue>> storage_; }; - -#endif // MEMGRAPH_PARAMETERS_HPP diff --git a/tests/unit/distributed_graph_db.cpp b/tests/unit/distributed_graph_db.cpp index 670f517c9..f075114bc 100644 --- a/tests/unit/distributed_graph_db.cpp +++ b/tests/unit/distributed_graph_db.cpp @@ -11,10 +11,19 @@ #include "distributed/plan_dispatcher.hpp" #include "distributed/remote_data_rpc_clients.hpp" #include "distributed/remote_data_rpc_server.hpp" +#include "distributed/remote_pull_rpc_clients.hpp" #include "io/network/endpoint.hpp" +#include "query/frontend/ast/ast.hpp" +#include "query/frontend/ast/cypher_main_visitor.hpp" +#include "query/frontend/semantic/symbol_generator.hpp" +#include "query/frontend/semantic/symbol_table.hpp" +#include "query/plan/planner.hpp" #include "query_plan_common.hpp" #include "transactions/engine_master.hpp" +#include "query_common.hpp" +#include "query_plan_common.hpp" + template <typename T> using optional = std::experimental::optional<T>; @@ -215,18 +224,79 @@ TEST_F(DistributedGraphDbTest, DispatchPlan) { master().plan_dispatcher().DispatchPlan(plan_id, scan_all.op_, symbol_table); std::this_thread::sleep_for(kRPCWaitTime); - { - auto cached = worker1().plan_consumer().PlanForId(plan_id); - EXPECT_NE(dynamic_cast<query::plan::ScanAll *>(cached.first.get()), - nullptr); - EXPECT_EQ(cached.second.max_position(), symbol_table.max_position()); - EXPECT_EQ(cached.second.table(), symbol_table.table()); - } - { - auto cached = worker2().plan_consumer().PlanForId(plan_id); - EXPECT_NE(dynamic_cast<query::plan::ScanAll *>(cached.first.get()), - nullptr); - EXPECT_EQ(cached.second.max_position(), symbol_table.max_position()); - EXPECT_EQ(cached.second.table(), symbol_table.table()); - } + auto check_for_worker = [plan_id, &symbol_table](auto &worker) { + auto &cached = worker.plan_consumer().PlanForId(plan_id); + EXPECT_NE(dynamic_cast<query::plan::ScanAll *>(cached.plan.get()), nullptr); + EXPECT_EQ(cached.symbol_table.max_position(), symbol_table.max_position()); + EXPECT_EQ(cached.symbol_table.table(), symbol_table.table()); + }; + check_for_worker(worker1()); + check_for_worker(worker2()); } + +TEST_F(DistributedGraphDbTest, RemotePullProduceRpc) { + database::GraphDb &db = master(); + database::GraphDbAccessor dba{db}; + Context ctx{dba}; + SymbolGenerator symbol_generator{ctx.symbol_table_}; + AstTreeStorage storage; + + // Query plan for: UNWIND [42, true, "bla", 1, 2] as x RETURN x + using namespace query; + auto list = + LIST(LITERAL(42), LITERAL(true), LITERAL("bla"), LITERAL(1), LITERAL(2)); + auto x = ctx.symbol_table_.CreateSymbol("x", true); + auto unwind = std::make_shared<plan::Unwind>(nullptr, list, x); + auto x_expr = IDENT("x"); + ctx.symbol_table_[*x_expr] = x; + auto x_ne = NEXPR("x", x_expr); + ctx.symbol_table_[*x_ne] = ctx.symbol_table_.CreateSymbol("x_ne", true); + auto produce = MakeProduce(unwind, x_ne); + + // Test that the plan works locally. + auto results = CollectProduce(produce.get(), ctx.symbol_table_, dba); + ASSERT_EQ(results.size(), 5); + + const int plan_id = 42; + master().plan_dispatcher().DispatchPlan(plan_id, produce, ctx.symbol_table_); + + auto remote_pull = [this, plan_id, &ctx, &x_ne](tx::transaction_id_t tx_id, + int worker_id) { + return master().remote_pull_clients().RemotePull( + tx_id, worker_id, plan_id, Parameters(), {ctx.symbol_table_[*x_ne]}, 3); + }; + auto expect_first_batch = [](auto &batch) { + EXPECT_EQ(batch.pull_state, + distributed::RemotePullState::CURSOR_IN_PROGRESS); + ASSERT_EQ(batch.frames.size(), 3); + ASSERT_EQ(batch.frames[0].size(), 1); + EXPECT_EQ(batch.frames[0][0].ValueInt(), 42); + EXPECT_EQ(batch.frames[1][0].ValueBool(), true); + EXPECT_EQ(batch.frames[2][0].ValueString(), "bla"); + }; + auto expect_second_batch = [](auto &batch) { + EXPECT_EQ(batch.pull_state, + distributed::RemotePullState::CURSOR_EXHAUSTED); + ASSERT_EQ(batch.frames.size(), 2); + ASSERT_EQ(batch.frames[0].size(), 1); + EXPECT_EQ(batch.frames[0][0].ValueInt(), 1); + EXPECT_EQ(batch.frames[1][0].ValueInt(), 2); + }; + + database::GraphDbAccessor dba_1{master()}; + database::GraphDbAccessor dba_2{master()}; + for (int worker_id : {1, 2}) { + auto tx1_batch1 = remote_pull(dba_1.transaction_id(), worker_id); + expect_first_batch(tx1_batch1); + auto tx2_batch1 = remote_pull(dba_2.transaction_id(), worker_id); + expect_first_batch(tx2_batch1); + auto tx2_batch2 = remote_pull(dba_2.transaction_id(), worker_id); + expect_second_batch(tx2_batch2); + auto tx1_batch2 = remote_pull(dba_1.transaction_id(), worker_id); + expect_second_batch(tx1_batch2); + } + master().remote_pull_clients().EndRemotePull(dba_1.transaction_id(), plan_id); + master().remote_pull_clients().EndRemotePull(dba_2.transaction_id(), plan_id); +} + +// TODO EndRemotePull test diff --git a/tests/unit/query_plan_common.hpp b/tests/unit/query_plan_common.hpp index 239aef091..7a52145ed 100644 --- a/tests/unit/query_plan_common.hpp +++ b/tests/unit/query_plan_common.hpp @@ -1,8 +1,3 @@ -// -// Copyright 2017 Memgraph -// Created by Florijan Stamenkovic on 14.03.17. -// - #pragma once #include <iterator>