diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 37ac18fdf..cab86ff61 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -2613,7 +2613,8 @@ bool PullRemote::PullRemoteCursor::Pull(Frame &frame, Context &context) { // Get locally stored results from workers in a round-robin fasion. int num_workers = worker_ids_.size(); for (int i = 0; i < num_workers; ++i) { - int worker_id_index = (last_pulled_worker_id_index_ + i) % num_workers; + int worker_id_index = + (last_pulled_worker_id_index_ + i + 1) % num_workers; int worker_id = worker_ids_[worker_id_index]; if (!remote_results_[worker_id].empty()) { @@ -2623,10 +2624,20 @@ bool PullRemote::PullRemoteCursor::Pull(Frame &frame, Context &context) { } } - // If there are no remote results available, pull and return local results. - if (!have_remote_results && self_.pull_local() && - input_cursor_->Pull(frame, context)) { - return true; + if (!have_remote_results) { + // If we didn't find any remote results and there aren't any remote + // pulls, we've exhausted all remote results. Make sure we signal that to + // workers and exit the loop. + if (remote_pulls_.empty()) { + EndRemotePull(); + break; + } + + // If there are no remote results available, pull and return local + // results. + if (self_.pull_local() && input_cursor_->Pull(frame, context)) { + return true; + } } } @@ -2634,9 +2645,8 @@ bool PullRemote::PullRemoteCursor::Pull(Frame &frame, Context &context) { if (!have_remote_results) { if (self_.pull_local() && input_cursor_->Pull(frame, context)) { return true; - } else { - return false; } + return false; } int pull_from_worker_id = worker_ids_[last_pulled_worker_id_index_]; diff --git a/tests/unit/distributed_graph_db.cpp b/tests/unit/distributed_graph_db.cpp index 4f079a35f..f93321595 100644 --- a/tests/unit/distributed_graph_db.cpp +++ b/tests/unit/distributed_graph_db.cpp @@ -13,20 +13,18 @@ #include "distributed/remote_data_rpc_clients.hpp" #include "distributed/remote_data_rpc_server.hpp" #include "distributed/remote_pull_rpc_clients.hpp" +#include "distributed_common.hpp" #include "io/network/endpoint.hpp" #include "query/frontend/ast/ast.hpp" #include "query/frontend/ast/cypher_main_visitor.hpp" #include "query/frontend/semantic/symbol_generator.hpp" #include "query/frontend/semantic/symbol_table.hpp" -#include "query/plan/planner.hpp" -#include "query_plan_common.hpp" -#include "transactions/engine_master.hpp" - -#include "distributed_common.hpp" #include "query/interpreter.hpp" +#include "query/plan/planner.hpp" #include "query/typed_value.hpp" #include "query_common.hpp" #include "query_plan_common.hpp" +#include "transactions/engine_master.hpp" using namespace distributed; using namespace database; @@ -341,27 +339,3 @@ TEST_F(DistributedGraphDbTest, WorkerOwnedDbAccessors) { VertexAccessor v_in_w2{v_ga, dba_w2}; EXPECT_EQ(v_in_w2.PropsAt(prop).Value(), 42); } - -TEST_F(DistributedGraphDbTest, RemotePullTest) { - using Interpreter = query::Interpreter; - std::map params = {}; - - GraphDbAccessor dba(master()); - - ResultStreamFaker result; - Interpreter interpreter_; - interpreter_("OPTIONAL MATCH(n) UNWIND(RANGE(0, 20)) AS X RETURN 1", dba, - params, false) - .PullAll(result); - - // Three instances (master + 2 workers) with 21 result each. - uint expected_results = 3U * 21; - ASSERT_EQ(result.GetHeader().size(), 1U); - EXPECT_EQ(result.GetHeader()[0], "1"); - ASSERT_EQ(result.GetResults().size(), expected_results); - - for (uint i = 0; i < expected_results; ++i) { - ASSERT_EQ(result.GetResults()[i].size(), 1U); - ASSERT_EQ(result.GetResults()[i][0].Value(), 1); - } -} diff --git a/tests/unit/distributed_interpretation.cpp b/tests/unit/distributed_interpretation.cpp new file mode 100644 index 000000000..42fb6665e --- /dev/null +++ b/tests/unit/distributed_interpretation.cpp @@ -0,0 +1,49 @@ +#include "gtest/gtest.h" + +#include "database/graph_db.hpp" +#include "distributed_common.hpp" +#include "query/interpreter.hpp" +#include "query_common.hpp" +#include "query_plan_common.hpp" + +using namespace distributed; +using namespace database; + +TEST_F(DistributedGraphDbTest, RemotePullTest) { + using Interpreter = query::Interpreter; + std::map params = {}; + + GraphDbAccessor dba(master()); + + ResultStreamFaker result; + Interpreter interpreter_; + interpreter_("OPTIONAL MATCH(n) UNWIND(RANGE(0, 20)) AS X RETURN 1", dba, + params, false) + .PullAll(result); + + // Three instances (master + 2 workers) with 21 result each. + uint expected_results = 3U * 21; + ASSERT_EQ(result.GetHeader().size(), 1U); + EXPECT_EQ(result.GetHeader()[0], "1"); + ASSERT_EQ(result.GetResults().size(), expected_results); + + for (uint i = 0; i < expected_results; ++i) { + ASSERT_EQ(result.GetResults()[i].size(), 1U); + ASSERT_EQ(result.GetResults()[i][0].Value(), 1); + } +} + +TEST_F(DistributedGraphDbTest, RemotePullNoResultsTest) { + using Interpreter = query::Interpreter; + std::map params = {}; + + GraphDbAccessor dba(master()); + + ResultStreamFaker result; + Interpreter interpreter_; + interpreter_("MATCH (n) RETURN n", dba, params, false).PullAll(result); + + ASSERT_EQ(result.GetHeader().size(), 1U); + EXPECT_EQ(result.GetHeader()[0], "n"); + ASSERT_EQ(result.GetResults().size(), 0U); +}