diff --git a/src/database/graph_db.cpp b/src/database/graph_db.cpp index 7248e0fe9..ac9fd03e7 100644 --- a/src/database/graph_db.cpp +++ b/src/database/graph_db.cpp @@ -99,6 +99,7 @@ class SingleNode : public PrivateBase { StorageGc storage_gc_{storage_, tx_engine_, config_.gc_cycle_sec}; TypemapPack typemap_pack_; database::SingleNodeCounters counters_; + std::vector GetWorkerIds() const override { return {0}; } distributed::RemoteDataRpcServer &remote_data_server() override { LOG(FATAL) << "Remote data server not available in single-node."; } @@ -123,6 +124,9 @@ class SingleNode : public PrivateBase { }; #define IMPL_DISTRIBUTED_GETTERS \ + std::vector GetWorkerIds() const override { \ + return coordination_.GetWorkerIds(); \ + } \ distributed::RemoteDataRpcServer &remote_data_server() override { \ return remote_data_server_; \ } \ @@ -251,6 +255,9 @@ ConcurrentIdMapper &PublicBase::property_mapper() { database::Counters &PublicBase::counters() { return impl_->counters(); } void PublicBase::CollectGarbage() { impl_->CollectGarbage(); } int PublicBase::WorkerId() const { return impl_->WorkerId(); } +std::vector PublicBase::GetWorkerIds() const { + return impl_->GetWorkerIds(); +} distributed::RemoteDataRpcServer &PublicBase::remote_data_server() { return impl_->remote_data_server(); } diff --git a/src/database/graph_db.hpp b/src/database/graph_db.hpp index b5167aa64..42f7eba5a 100644 --- a/src/database/graph_db.hpp +++ b/src/database/graph_db.hpp @@ -2,6 +2,7 @@ #include #include +#include #include "database/counters.hpp" #include "database/storage.hpp" @@ -90,6 +91,7 @@ class GraphDb { virtual database::Counters &counters() = 0; virtual void CollectGarbage() = 0; virtual int WorkerId() const = 0; + virtual std::vector GetWorkerIds() const = 0; // Supported only in distributed master and worker, not in single-node. virtual distributed::RemoteDataRpcServer &remote_data_server() = 0; @@ -134,6 +136,7 @@ class PublicBase : public GraphDb { database::Counters &counters() override; void CollectGarbage() override; int WorkerId() const override; + std::vector GetWorkerIds() const override; distributed::RemoteDataRpcServer &remote_data_server() override; distributed::RemoteDataRpcClients &remote_data_clients() override; distributed::PlanDispatcher &plan_dispatcher() override; diff --git a/src/distributed/coordination.hpp b/src/distributed/coordination.hpp index 55146d0fb..4d669dc93 100644 --- a/src/distributed/coordination.hpp +++ b/src/distributed/coordination.hpp @@ -14,7 +14,7 @@ class Coordination { /** Gets the connected worker ids - should only be called on a master * instance*/ - virtual std::vector GetWorkerIds() = 0; + virtual std::vector GetWorkerIds() const = 0; }; } // namespace distributed diff --git a/src/distributed/coordination_master.cpp b/src/distributed/coordination_master.cpp index d74dac068..46fb3e541 100644 --- a/src/distributed/coordination_master.cpp +++ b/src/distributed/coordination_master.cpp @@ -57,7 +57,7 @@ Endpoint MasterCoordination::GetEndpoint(int worker_id) { return found->second; } -std::vector MasterCoordination::GetWorkerIds() { +std::vector MasterCoordination::GetWorkerIds() const { std::vector worker_ids; for (auto worker : workers_) worker_ids.push_back(worker.first); return worker_ids; diff --git a/src/distributed/coordination_master.hpp b/src/distributed/coordination_master.hpp index 59343af76..221e3bc92 100644 --- a/src/distributed/coordination_master.hpp +++ b/src/distributed/coordination_master.hpp @@ -36,7 +36,7 @@ class MasterCoordination : public Coordination { Endpoint GetEndpoint(int worker_id) override; /** Returns all workers id, this includes master id(0) */ - std::vector GetWorkerIds() override; + std::vector GetWorkerIds() const override; private: communication::rpc::Server server_; diff --git a/src/distributed/coordination_worker.cpp b/src/distributed/coordination_worker.cpp index d9a66c077..b9d485427 100644 --- a/src/distributed/coordination_worker.cpp +++ b/src/distributed/coordination_worker.cpp @@ -54,7 +54,7 @@ void WorkerCoordination::WaitForShutdown() { std::this_thread::sleep_for(100ms); }; -std::vector WorkerCoordination::GetWorkerIds() { +std::vector WorkerCoordination::GetWorkerIds() const { LOG(FATAL) << "Unimplemented worker ids discovery on worker"; }; } // namespace distributed diff --git a/src/distributed/coordination_worker.hpp b/src/distributed/coordination_worker.hpp index c6fc832e6..0c24c947d 100644 --- a/src/distributed/coordination_worker.hpp +++ b/src/distributed/coordination_worker.hpp @@ -30,7 +30,7 @@ class WorkerCoordination : public Coordination { /** Shouldn't be called on worker for now! * TODO fix this */ - std::vector GetWorkerIds() override; + std::vector GetWorkerIds() const override; /** Starts listening for a remote shutdown command (issued by the master). * Blocks the calling thread until that has finished. */ diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 704f8e22d..155bbbf33 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -94,9 +94,11 @@ std::unique_ptr Once::MakeCursor(database::GraphDbAccessor &) const { void Once::OnceCursor::Reset() { did_pull_ = false; } -CreateNode::CreateNode(NodeAtom *node_atom, - const std::shared_ptr &input) - : node_atom_(node_atom), input_(input ? input : std::make_shared()) {} +CreateNode::CreateNode(const std::shared_ptr &input, + NodeAtom *node_atom, bool on_random_worker) + : input_(input ? input : std::make_shared()), + node_atom_(node_atom), + on_random_worker_(on_random_worker) {} ACCEPT_WITH_INPUT(CreateNode) @@ -118,7 +120,17 @@ CreateNode::CreateNodeCursor::CreateNodeCursor(const CreateNode &self, bool CreateNode::CreateNodeCursor::Pull(Frame &frame, Context &context) { if (input_cursor_->Pull(frame, context)) { - Create(frame, context); + if (self_.on_random_worker_) { + auto worker_ids = context.db_accessor_.db().GetWorkerIds(); + auto worker_id = worker_ids[rand_(gen_) % worker_ids.size()]; + if (worker_id == context.db_accessor_.db().WorkerId()) { + CreateLocally(frame, context); + } else { + CreateOnWorker(worker_id, frame, context); + } + } else { + CreateLocally(frame, context); + } return true; } return false; @@ -126,7 +138,8 @@ bool CreateNode::CreateNodeCursor::Pull(Frame &frame, Context &context) { void CreateNode::CreateNodeCursor::Reset() { input_cursor_->Reset(); } -void CreateNode::CreateNodeCursor::Create(Frame &frame, Context &context) { +void CreateNode::CreateNodeCursor::CreateLocally(Frame &frame, + Context &context) { auto new_node = db_.InsertVertex(); for (auto label : self_.node_atom_->labels_) new_node.add_label(label); @@ -139,6 +152,29 @@ void CreateNode::CreateNodeCursor::Create(Frame &frame, Context &context) { frame[context.symbol_table_.at(*self_.node_atom_->identifier_)] = new_node; } +void CreateNode::CreateNodeCursor::CreateOnWorker(int worker_id, Frame &frame, + Context &context) { + std::unordered_map properties; + + // Evaluator should use the latest accessors, as modified in this query, when + // setting properties on new nodes. + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db_, GraphView::NEW); + for (auto &kv : self_.node_atom_->properties_) { + auto value = kv.second->Accept(evaluator); + if (!value.IsPropertyValue()) { + throw QueryRuntimeException("'{}' cannot be used as a property value.", + value.type()); + } + properties.emplace(kv.first.second, std::move(value)); + } + + auto new_node = context.db_accessor_.InsertVertexIntoRemote( + worker_id, self_.node_atom_->labels_, properties); + + frame[context.symbol_table_.at(*self_.node_atom_->identifier_)] = new_node; +} + CreateExpand::CreateExpand(NodeAtom *node_atom, EdgeAtom *edge_atom, const std::shared_ptr &input, Symbol input_symbol, bool existing_node) @@ -349,10 +385,10 @@ std::unique_ptr ScanAllByLabelPropertyRange::MakeCursor( context.symbol_table_, db, graph_view_); auto convert = [&evaluator](const auto &bound) -> std::experimental::optional> { - if (!bound) return std::experimental::nullopt; - return std::experimental::make_optional(utils::Bound( - bound.value().value()->Accept(evaluator), bound.value().type())); - }; + if (!bound) return std::experimental::nullopt; + return std::experimental::make_optional(utils::Bound( + bound.value().value()->Accept(evaluator), bound.value().type())); + }; return db.Vertices(label_, property_, convert(lower_bound()), convert(upper_bound()), graph_view_ == GraphView::NEW); }; @@ -1894,7 +1930,7 @@ bool ExpandUniquenessFilter::ExpandUniquenessFilterCursor::Pull( for (const auto &previous_symbol : self_.previous_symbols_) { TypedValue &previous_value = frame[previous_symbol]; // This shouldn't raise a TypedValueException, because the planner - // makes sure these are all of the expected type. In case they are not, + // makes sure these are all of the expected type. In case they are not // an error should be raised long before this code is executed. if (ContainsSame(previous_value, expand_value)) return false; } diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index c90be8035..f89e4b6d2 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -237,15 +238,16 @@ class Once : public LogicalOperator { class CreateNode : public LogicalOperator { public: /** - * - * @param node_atom @c NodeAtom with information on how to create a node. * @param input Optional. If @c nullptr, then a single node will be * created (a single successful @c Cursor::Pull from this op's @c Cursor). * If a valid input, then a node will be created for each * successful pull from the given input. + * @param node_atom @c NodeAtom with information on how to create a node. + * @param on_random_worker If the node should be created locally or on random + * worker. */ - CreateNode(NodeAtom *node_atom, - const std::shared_ptr &input); + CreateNode(const std::shared_ptr &input, NodeAtom *node_atom, + bool on_random_worker); bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; std::unique_ptr MakeCursor( database::GraphDbAccessor &db) const override; @@ -257,8 +259,9 @@ class CreateNode : public LogicalOperator { private: CreateNode() {} - NodeAtom *node_atom_ = nullptr; std::shared_ptr input_; + NodeAtom *node_atom_ = nullptr; + bool on_random_worker_; class CreateNodeCursor : public Cursor { public: @@ -271,10 +274,15 @@ class CreateNode : public LogicalOperator { database::GraphDbAccessor &db_; const std::unique_ptr input_cursor_; - /** - * Creates a single node and places it in the frame. - */ - void Create(Frame &, Context &); + // For random worker choosing in distributed. + std::mt19937 gen_{std::random_device{}()}; + std::uniform_int_distribution rand_; + + /** Creates a single node locally and places it in the frame. */ + void CreateLocally(Frame &, Context &); + + /** Creates a single node on the given worker and places it in the frame. */ + void CreateOnWorker(int worker_id, Frame &, Context &); }; friend class boost::serialization::access; @@ -286,6 +294,7 @@ class CreateNode : public LogicalOperator { ar &boost::serialization::base_object(*this); ar &input_; SavePointer(ar, node_atom_); + ar &on_random_worker_; } template @@ -293,6 +302,7 @@ class CreateNode : public LogicalOperator { ar &boost::serialization::base_object(*this); ar &input_; LoadPointer(ar, node_atom_); + ar &on_random_worker_; } }; diff --git a/src/query/plan/rule_based_planner.cpp b/src/query/plan/rule_based_planner.cpp index c2798b867..398cc4ff5 100644 --- a/src/query/plan/rule_based_planner.cpp +++ b/src/query/plan/rule_based_planner.cpp @@ -595,7 +595,7 @@ std::unique_ptr GenCreateForPattern( std::unordered_set &bound_symbols) { auto base = [&](NodeAtom *node) -> std::unique_ptr { if (bound_symbols.insert(symbol_table.at(*node->identifier_)).second) - return std::make_unique(node, std::move(input_op)); + return std::make_unique(std::move(input_op), node, false); else return std::move(input_op); }; diff --git a/tests/unit/distributed_graph_db.cpp b/tests/unit/distributed_graph_db.cpp index 790416938..68d4d48a6 100644 --- a/tests/unit/distributed_graph_db.cpp +++ b/tests/unit/distributed_graph_db.cpp @@ -388,3 +388,31 @@ TEST_F(DistributedGraphDbTest, Synchronize) { // TODO test without advance command? } + +TEST_F(DistributedGraphDbTest, Create) { + // Query: UNWIND range(0, 1000) as x CREATE () + auto &db = master(); + GraphDbAccessor dba{db}; + Context ctx{dba}; + SymbolGenerator symbol_generator{ctx.symbol_table_}; + AstTreeStorage storage; + auto range = FN("range", LITERAL(0), LITERAL(1000)); + auto x = ctx.symbol_table_.CreateSymbol("x", true); + auto unwind = std::make_shared(nullptr, range, x); + auto node = NODE("n"); + ctx.symbol_table_[*node->identifier_] = + ctx.symbol_table_.CreateSymbol("n", true); + auto create = std::make_shared(unwind, node, true); + PullAll(create, dba, ctx.symbol_table_); + dba.Commit(); + + auto vertex_count = [](database::GraphDb &db) { + database::GraphDbAccessor dba{db}; + auto vertices = dba.Vertices(false); + return std::distance(vertices.begin(), vertices.end()); + }; + + EXPECT_GT(vertex_count(master()), 200); + EXPECT_GT(vertex_count(worker(1)), 200); + EXPECT_GT(vertex_count(worker(2)), 200); +} diff --git a/tests/unit/query_plan_accumulate_aggregate.cpp b/tests/unit/query_plan_accumulate_aggregate.cpp index 2976ccaf9..48cec6390 100644 --- a/tests/unit/query_plan_accumulate_aggregate.cpp +++ b/tests/unit/query_plan_accumulate_aggregate.cpp @@ -94,7 +94,7 @@ TEST(QueryPlan, AccumulateAdvance) { auto node = NODE("n"); auto sym_n = symbol_table.CreateSymbol("n", true); symbol_table[*node->identifier_] = sym_n; - auto create = std::make_shared(node, nullptr); + auto create = std::make_shared(nullptr, node, false); auto accumulate = std::make_shared( create, std::vector{sym_n}, advance); auto match = MakeScanAll(storage, symbol_table, "m", accumulate); diff --git a/tests/unit/query_plan_bag_semantics.cpp b/tests/unit/query_plan_bag_semantics.cpp index f3353edff..1c3cb8e75 100644 --- a/tests/unit/query_plan_bag_semantics.cpp +++ b/tests/unit/query_plan_bag_semantics.cpp @@ -95,7 +95,7 @@ TEST(QueryPlan, CreateLimit) { auto n = MakeScanAll(storage, symbol_table, "n1"); auto m = NODE("m"); symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m", true); - auto c = std::make_shared(m, n.op_); + auto c = std::make_shared(n.op_, m, false); auto skip = std::make_shared(c, LITERAL(1)); EXPECT_EQ(1, PullAll(skip, dba, symbol_table)); diff --git a/tests/unit/query_plan_create_set_remove_delete.cpp b/tests/unit/query_plan_create_set_remove_delete.cpp index 56afc8f92..a4d2ca4d3 100644 --- a/tests/unit/query_plan_create_set_remove_delete.cpp +++ b/tests/unit/query_plan_create_set_remove_delete.cpp @@ -1,8 +1,3 @@ -// -// Copyright 2017 Memgraph -// Created by Florijan Stamenkovic on 14.03.17. -// - #include #include #include @@ -36,7 +31,7 @@ TEST(QueryPlan, CreateNodeWithAttributes) { node->labels_.emplace_back(label); node->properties_[property] = LITERAL(42); - auto create = std::make_shared(node, nullptr); + auto create = std::make_shared(nullptr, node, false); PullAll(create, dba, symbol_table); dba.AdvanceCommand(); @@ -71,7 +66,7 @@ TEST(QueryPlan, CreateReturn) { node->labels_.emplace_back(label); node->properties_[property] = LITERAL(42); - auto create = std::make_shared(node, nullptr); + auto create = std::make_shared(nullptr, node, false); auto named_expr_n = NEXPR("n", IDENT("n")); symbol_table[*named_expr_n] = symbol_table.CreateSymbol("named_expr_n", true); symbol_table[*named_expr_n->expression_] = sym_n; @@ -134,7 +129,7 @@ TEST(QueryPlan, CreateExpand) { r->edge_types_.emplace_back(edge_type); r->properties_[property] = LITERAL(3); - auto create_op = std::make_shared(n, nullptr); + auto create_op = std::make_shared(nullptr, n, false); auto create_expand = std::make_shared(m, r, create_op, n_sym, cycle); PullAll(create_expand, dba, symbol_table); @@ -189,7 +184,7 @@ TEST(QueryPlan, MatchCreateNode) { auto m = NODE("m"); symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m", true); // creation op - auto create_node = std::make_shared(m, n_scan_all.op_); + auto create_node = std::make_shared(n_scan_all.op_, m, false); EXPECT_EQ(CountIterable(dba.Vertices(false)), 3); PullAll(create_node, dba, symbol_table); @@ -846,7 +841,7 @@ TEST(QueryPlan, MergeNoInput) { auto node = NODE("n"); auto sym_n = symbol_table.CreateSymbol("n", true); symbol_table[*node->identifier_] = sym_n; - auto create = std::make_shared(node, nullptr); + auto create = std::make_shared(nullptr, node, false); auto merge = std::make_shared(nullptr, create, create); EXPECT_EQ(0, CountIterable(dba.Vertices(false)));