Remove root tracking from AST storage

Summary: Up till now, `AstStorage` also took care of tracking the root of the `Query` and loading of cloning of `Query` nodes would change that root. This felt out of place because sometimes `AstStorage` is used only for storing expressions, and we don't even have an entire query in the storage. This diff removes that feature from `AstStorage`. Now its only functionality is owning AST nodes and assigning unique IDs to them.

Reviewers: teon.banek, llugovic

Reviewed By: teon.banek

Subscribers: mferencevic, pullbot

Differential Revision: https://phabricator.memgraph.io/D1646
This commit is contained in:
Marin Tomic 2018-10-10 15:19:34 +02:00
parent baae40fcc6
commit 285e02d5ec
24 changed files with 882 additions and 830 deletions

View File

@ -57,10 +57,10 @@ DistributedInterpreter::DistributedInterpreter(database::Master *db)
: plan_dispatcher_(&db->plan_dispatcher()) {}
std::unique_ptr<LogicalPlan> DistributedInterpreter::MakeLogicalPlan(
AstStorage ast_storage, Context *context) {
Query *query, AstStorage ast_storage, Context *context) {
auto vertex_counts = plan::MakeVertexCountCache(context->db_accessor_);
auto planning_context = plan::MakePlanningContext(
ast_storage, context->symbol_table_, vertex_counts);
ast_storage, context->symbol_table_, query, vertex_counts);
std::unique_ptr<plan::LogicalOperator> tmp_logical_plan;
double cost;
std::tie(tmp_logical_plan, cost) = plan::MakeLogicalPlan(

View File

@ -17,7 +17,8 @@ class DistributedInterpreter final : public Interpreter {
DistributedInterpreter(database::Master *db);
private:
std::unique_ptr<LogicalPlan> MakeLogicalPlan(AstStorage, Context *) override;
std::unique_ptr<LogicalPlan> MakeLogicalPlan(Query *, AstStorage,
Context *) override;
std::atomic<int64_t> next_plan_id_{0};
distributed::PlanDispatcher *plan_dispatcher_{nullptr};

View File

@ -4,16 +4,6 @@
namespace query {
AstStorage::AstStorage() {
std::unique_ptr<Query> root(new Query(++max_existing_uid_));
root_idx_ = 0;
storage_.emplace_back(std::move(root));
}
Query *AstStorage::query() const {
return dynamic_cast<Query *>(storage_[root_idx_].get());
}
ReturnBody CloneReturnBody(AstStorage &storage, const ReturnBody &body) {
ReturnBody new_body;
new_body.distinct = body.distinct;

View File

@ -130,7 +130,7 @@ class Tree;
// called NodeAtom...
class AstStorage {
public:
AstStorage();
AstStorage() = default;
AstStorage(const AstStorage &) = delete;
AstStorage &operator=(const AstStorage &) = delete;
AstStorage(AstStorage &&) = default;
@ -140,19 +140,13 @@ class AstStorage {
T *Create(Args &&... args) {
T *ptr = new T(++max_existing_uid_, std::forward<Args>(args)...);
std::unique_ptr<T> tmp(ptr);
if (std::is_same<T, Query>::value) {
root_idx_ = storage_.size();
}
storage_.emplace_back(std::move(tmp));
return ptr;
}
Query *query() const;
// Public only for serialization access
std::vector<std::unique_ptr<Tree>> storage_;
int max_existing_uid_ = -1;
size_t root_idx_;
};
cpp<#

View File

@ -29,13 +29,8 @@ cpp<#
CHECK(found != ast->storage_.end());
return found->get();
}
std::unique_ptr<Tree> root;
::query::Load(&root, tree, ast, loaded_uids);
if (dynamic_cast<Query *>(root.get())) {
ast->root_idx_ = ast->storage_.size();
}
ast->storage_.emplace_back(std::move(root));
loaded_uids->emplace_back(uid);
ast->max_existing_uid_ = std::max(ast->max_existing_uid_, uid);

View File

@ -35,7 +35,7 @@ antlrcpp::Any CypherMainVisitor::visitExplainQuery(
antlrcpp::Any CypherMainVisitor::visitCypherQuery(
MemgraphCypher::CypherQueryContext *ctx) {
query_ = storage_.query();
query_ = storage_->Create<Query>();
DCHECK(ctx->singleQuery()) << "Expected single query.";
query_->single_query_ = ctx->singleQuery()->accept(this).as<SingleQuery *>();
@ -59,8 +59,8 @@ antlrcpp::Any CypherMainVisitor::visitCypherQuery(
antlrcpp::Any CypherMainVisitor::visitIndexQuery(
MemgraphCypher::IndexQueryContext *ctx) {
query_ = storage_.query();
query_->single_query_ = storage_.Create<SingleQuery>();
query_ = storage_->Create<Query>();
query_->single_query_ = storage_->Create<SingleQuery>();
if (ctx->createIndex()) {
query_->single_query_->clauses_.emplace_back(
ctx->createIndex()->accept(this).as<CreateIndex *>());
@ -74,8 +74,8 @@ antlrcpp::Any CypherMainVisitor::visitIndexQuery(
antlrcpp::Any CypherMainVisitor::visitAuthQuery(
MemgraphCypher::AuthQueryContext *ctx) {
query_ = storage_.query();
query_->single_query_ = storage_.Create<SingleQuery>();
query_ = storage_->Create<Query>();
query_->single_query_ = storage_->Create<SingleQuery>();
CHECK(ctx->children.size() == 1)
<< "AuthQuery should have exactly one child!";
query_->single_query_->clauses_.push_back(
@ -85,8 +85,8 @@ antlrcpp::Any CypherMainVisitor::visitAuthQuery(
antlrcpp::Any CypherMainVisitor::visitStreamQuery(
MemgraphCypher::StreamQueryContext *ctx) {
query_ = storage_.query();
query_->single_query_ = storage_.Create<SingleQuery>();
query_ = storage_->Create<Query>();
query_->single_query_ = storage_->Create<SingleQuery>();
Clause *clause = nullptr;
if (ctx->createStream()) {
clause = ctx->createStream()->accept(this).as<CreateStream *>();
@ -109,7 +109,7 @@ antlrcpp::Any CypherMainVisitor::visitStreamQuery(
antlrcpp::Any CypherMainVisitor::visitCypherUnion(
MemgraphCypher::CypherUnionContext *ctx) {
bool distinct = !ctx->ALL();
auto *cypher_union = storage_.Create<CypherUnion>(distinct);
auto *cypher_union = storage_->Create<CypherUnion>(distinct);
DCHECK(ctx->singleQuery()) << "Expected single query.";
cypher_union->single_query_ =
ctx->singleQuery()->accept(this).as<SingleQuery *>();
@ -118,7 +118,7 @@ antlrcpp::Any CypherMainVisitor::visitCypherUnion(
antlrcpp::Any CypherMainVisitor::visitSingleQuery(
MemgraphCypher::SingleQueryContext *ctx) {
auto *single_query = storage_.Create<SingleQuery>();
auto *single_query = storage_->Create<SingleQuery>();
for (auto *child : ctx->clause()) {
antlrcpp::Any got = child->accept(this);
if (got.is<Clause *>()) {
@ -201,7 +201,7 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery(
while (true) {
std::string id_name = kAnonPrefix + std::to_string(id++);
if (users_identifiers.find(id_name) == users_identifiers.end()) {
*identifier = storage_.Create<Identifier>(id_name, false);
*identifier = storage_->Create<Identifier>(id_name, false);
break;
}
}
@ -250,7 +250,7 @@ antlrcpp::Any CypherMainVisitor::visitClause(
antlrcpp::Any CypherMainVisitor::visitCypherMatch(
MemgraphCypher::CypherMatchContext *ctx) {
auto *match = storage_.Create<Match>();
auto *match = storage_->Create<Match>();
match->optional_ = !!ctx->OPTIONAL();
if (ctx->where()) {
match->where_ = ctx->where()->accept(this);
@ -261,7 +261,7 @@ antlrcpp::Any CypherMainVisitor::visitCypherMatch(
antlrcpp::Any CypherMainVisitor::visitCreate(
MemgraphCypher::CreateContext *ctx) {
auto *create = storage_.Create<Create>();
auto *create = storage_->Create<Create>();
create->patterns_ = ctx->pattern()->accept(this).as<std::vector<Pattern *>>();
return create;
}
@ -273,7 +273,7 @@ antlrcpp::Any CypherMainVisitor::visitCreateIndex(
MemgraphCypher::CreateIndexContext *ctx) {
std::pair<std::string, storage::Property> key =
ctx->propertyKeyName()->accept(this);
return storage_.Create<CreateIndex>(
return storage_->Create<CreateIndex>(
dba_->Label(ctx->labelName()->accept(this)), key.second);
}
@ -289,7 +289,7 @@ antlrcpp::Any CypherMainVisitor::visitCreateUniqueIndex(
prop_name->accept(this);
properties.push_back(name_key.second);
}
return storage_.Create<CreateUniqueIndex>(
return storage_->Create<CreateUniqueIndex>(
dba_->Label(ctx->labelName()->accept(this)), properties);
}
@ -311,7 +311,7 @@ antlrcpp::Any CypherMainVisitor::visitUserOrRoleName(
*/
antlrcpp::Any CypherMainVisitor::visitCreateRole(
MemgraphCypher::CreateRoleContext *ctx) {
AuthQuery *auth = storage_.Create<AuthQuery>();
AuthQuery *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::CREATE_ROLE;
auth->role_ = ctx->role->accept(this).as<std::string>();
return auth;
@ -322,7 +322,7 @@ antlrcpp::Any CypherMainVisitor::visitCreateRole(
*/
antlrcpp::Any CypherMainVisitor::visitDropRole(
MemgraphCypher::DropRoleContext *ctx) {
AuthQuery *auth = storage_.Create<AuthQuery>();
AuthQuery *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::DROP_ROLE;
auth->role_ = ctx->role->accept(this).as<std::string>();
return auth;
@ -333,7 +333,7 @@ antlrcpp::Any CypherMainVisitor::visitDropRole(
*/
antlrcpp::Any CypherMainVisitor::visitShowRoles(
MemgraphCypher::ShowRolesContext *ctx) {
AuthQuery *auth = storage_.Create<AuthQuery>();
AuthQuery *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::SHOW_ROLES;
return auth;
}
@ -343,7 +343,7 @@ antlrcpp::Any CypherMainVisitor::visitShowRoles(
*/
antlrcpp::Any CypherMainVisitor::visitCreateUser(
MemgraphCypher::CreateUserContext *ctx) {
AuthQuery *auth = storage_.Create<AuthQuery>();
AuthQuery *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::CREATE_USER;
auth->user_ = ctx->user->accept(this).as<std::string>();
if (ctx->password) {
@ -360,7 +360,7 @@ antlrcpp::Any CypherMainVisitor::visitCreateUser(
*/
antlrcpp::Any CypherMainVisitor::visitSetPassword(
MemgraphCypher::SetPasswordContext *ctx) {
AuthQuery *auth = storage_.Create<AuthQuery>();
AuthQuery *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::SET_PASSWORD;
auth->user_ = ctx->user->accept(this).as<std::string>();
if (!ctx->password->StringLiteral() && !ctx->literal()->CYPHERNULL()) {
@ -375,7 +375,7 @@ antlrcpp::Any CypherMainVisitor::visitSetPassword(
*/
antlrcpp::Any CypherMainVisitor::visitDropUser(
MemgraphCypher::DropUserContext *ctx) {
AuthQuery *auth = storage_.Create<AuthQuery>();
AuthQuery *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::DROP_USER;
auth->user_ = ctx->user->accept(this).as<std::string>();
return auth;
@ -386,7 +386,7 @@ antlrcpp::Any CypherMainVisitor::visitDropUser(
*/
antlrcpp::Any CypherMainVisitor::visitShowUsers(
MemgraphCypher::ShowUsersContext *ctx) {
AuthQuery *auth = storage_.Create<AuthQuery>();
AuthQuery *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::SHOW_USERS;
return auth;
}
@ -396,7 +396,7 @@ antlrcpp::Any CypherMainVisitor::visitShowUsers(
*/
antlrcpp::Any CypherMainVisitor::visitSetRole(
MemgraphCypher::SetRoleContext *ctx) {
AuthQuery *auth = storage_.Create<AuthQuery>();
AuthQuery *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::SET_ROLE;
auth->user_ = ctx->user->accept(this).as<std::string>();
auth->role_ = ctx->role->accept(this).as<std::string>();
@ -408,7 +408,7 @@ antlrcpp::Any CypherMainVisitor::visitSetRole(
*/
antlrcpp::Any CypherMainVisitor::visitClearRole(
MemgraphCypher::ClearRoleContext *ctx) {
AuthQuery *auth = storage_.Create<AuthQuery>();
AuthQuery *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::CLEAR_ROLE;
auth->user_ = ctx->user->accept(this).as<std::string>();
return auth;
@ -419,7 +419,7 @@ antlrcpp::Any CypherMainVisitor::visitClearRole(
*/
antlrcpp::Any CypherMainVisitor::visitGrantPrivilege(
MemgraphCypher::GrantPrivilegeContext *ctx) {
AuthQuery *auth = storage_.Create<AuthQuery>();
AuthQuery *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::GRANT_PRIVILEGE;
auth->user_or_role_ = ctx->userOrRole->accept(this).as<std::string>();
if (ctx->privilegeList()) {
@ -438,7 +438,7 @@ antlrcpp::Any CypherMainVisitor::visitGrantPrivilege(
*/
antlrcpp::Any CypherMainVisitor::visitDenyPrivilege(
MemgraphCypher::DenyPrivilegeContext *ctx) {
AuthQuery *auth = storage_.Create<AuthQuery>();
AuthQuery *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::DENY_PRIVILEGE;
auth->user_or_role_ = ctx->userOrRole->accept(this).as<std::string>();
if (ctx->privilegeList()) {
@ -457,7 +457,7 @@ antlrcpp::Any CypherMainVisitor::visitDenyPrivilege(
*/
antlrcpp::Any CypherMainVisitor::visitRevokePrivilege(
MemgraphCypher::RevokePrivilegeContext *ctx) {
AuthQuery *auth = storage_.Create<AuthQuery>();
AuthQuery *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::REVOKE_PRIVILEGE;
auth->user_or_role_ = ctx->userOrRole->accept(this).as<std::string>();
if (ctx->privilegeList()) {
@ -493,7 +493,7 @@ antlrcpp::Any CypherMainVisitor::visitPrivilege(
*/
antlrcpp::Any CypherMainVisitor::visitShowPrivileges(
MemgraphCypher::ShowPrivilegesContext *ctx) {
AuthQuery *auth = storage_.Create<AuthQuery>();
AuthQuery *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::SHOW_PRIVILEGES;
auth->user_or_role_ = ctx->userOrRole->accept(this).as<std::string>();
return auth;
@ -504,7 +504,7 @@ antlrcpp::Any CypherMainVisitor::visitShowPrivileges(
*/
antlrcpp::Any CypherMainVisitor::visitShowRoleForUser(
MemgraphCypher::ShowRoleForUserContext *ctx) {
AuthQuery *auth = storage_.Create<AuthQuery>();
AuthQuery *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::SHOW_ROLE_FOR_USER;
auth->user_ = ctx->user->accept(this).as<std::string>();
return auth;
@ -515,7 +515,7 @@ antlrcpp::Any CypherMainVisitor::visitShowRoleForUser(
*/
antlrcpp::Any CypherMainVisitor::visitShowUsersForRole(
MemgraphCypher::ShowUsersForRoleContext *ctx) {
AuthQuery *auth = storage_.Create<AuthQuery>();
AuthQuery *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::SHOW_USERS_FOR_ROLE;
auth->role_ = ctx->role->accept(this).as<std::string>();
return auth;
@ -552,9 +552,9 @@ antlrcpp::Any CypherMainVisitor::visitCreateStream(
batch_size = ctx->batchSizeOption()->accept(this);
}
return storage_.Create<CreateStream>(stream_name, stream_uri, stream_topic,
transform_uri, batch_interval_in_ms,
batch_size);
return storage_->Create<CreateStream>(stream_name, stream_uri, stream_topic,
transform_uri, batch_interval_in_ms,
batch_size);
}
/**
@ -586,7 +586,8 @@ antlrcpp::Any CypherMainVisitor::visitBatchSizeOption(
*/
antlrcpp::Any CypherMainVisitor::visitDropStream(
MemgraphCypher::DropStreamContext *ctx) {
return storage_.Create<DropStream>(std::string(ctx->streamName()->getText()));
return storage_->Create<DropStream>(
std::string(ctx->streamName()->getText()));
}
/**
@ -594,7 +595,7 @@ antlrcpp::Any CypherMainVisitor::visitDropStream(
*/
antlrcpp::Any CypherMainVisitor::visitShowStreams(
MemgraphCypher::ShowStreamsContext *ctx) {
return storage_.Create<ShowStreams>();
return storage_->Create<ShowStreams>();
}
/**
@ -613,7 +614,8 @@ antlrcpp::Any CypherMainVisitor::visitStartStopStream(
limit_batches = ctx->limitBatchesOption()->accept(this);
}
return storage_.Create<StartStopStream>(stream_name, is_start, limit_batches);
return storage_->Create<StartStopStream>(stream_name, is_start,
limit_batches);
}
/**
@ -634,12 +636,12 @@ antlrcpp::Any CypherMainVisitor::visitLimitBatchesOption(
antlrcpp::Any CypherMainVisitor::visitStartStopAllStreams(
MemgraphCypher::StartStopAllStreamsContext *ctx) {
bool is_start = static_cast<bool>(ctx->START());
return storage_.Create<StartStopAllStreams>(is_start);
return storage_->Create<StartStopAllStreams>(is_start);
}
antlrcpp::Any CypherMainVisitor::visitCypherReturn(
MemgraphCypher::CypherReturnContext *ctx) {
auto *return_clause = storage_.Create<Return>();
auto *return_clause = storage_->Create<Return>();
return_clause->body_ = ctx->returnBody()->accept(this);
if (ctx->DISTINCT()) {
return_clause->body_.distinct = true;
@ -659,7 +661,7 @@ antlrcpp::Any CypherMainVisitor::visitTestStream(
limit_batches = ctx->limitBatchesOption()->accept(this);
}
return storage_.Create<TestStream>(stream_name, limit_batches);
return storage_->Create<TestStream>(stream_name, limit_batches);
}
antlrcpp::Any CypherMainVisitor::visitReturnBody(
@ -693,7 +695,7 @@ antlrcpp::Any CypherMainVisitor::visitReturnItems(
antlrcpp::Any CypherMainVisitor::visitReturnItem(
MemgraphCypher::ReturnItemContext *ctx) {
auto *named_expr = storage_.Create<NamedExpression>();
auto *named_expr = storage_->Create<NamedExpression>();
named_expr->expression_ = ctx->expression()->accept(this);
if (ctx->variable()) {
named_expr->name_ =
@ -726,10 +728,10 @@ antlrcpp::Any CypherMainVisitor::visitSortItem(
antlrcpp::Any CypherMainVisitor::visitNodePattern(
MemgraphCypher::NodePatternContext *ctx) {
auto *node = storage_.Create<NodeAtom>();
auto *node = storage_->Create<NodeAtom>();
if (ctx->variable()) {
std::string variable = ctx->variable()->accept(this);
node->identifier_ = storage_.Create<Identifier>(variable);
node->identifier_ = storage_->Create<Identifier>(variable);
users_identifiers.insert(variable);
} else {
anonymous_identifiers.push_back(&node->identifier_);
@ -848,7 +850,7 @@ antlrcpp::Any CypherMainVisitor::visitPatternPart(
Pattern *pattern = ctx->anonymousPatternPart()->accept(this);
if (ctx->variable()) {
std::string variable = ctx->variable()->accept(this);
pattern->identifier_ = storage_.Create<Identifier>(variable);
pattern->identifier_ = storage_->Create<Identifier>(variable);
users_identifiers.insert(variable);
} else {
anonymous_identifiers.push_back(&pattern->identifier_);
@ -861,7 +863,7 @@ antlrcpp::Any CypherMainVisitor::visitPatternElement(
if (ctx->patternElement()) {
return ctx->patternElement()->accept(this);
}
auto pattern = storage_.Create<Pattern>();
auto pattern = storage_->Create<Pattern>();
pattern->atoms_.push_back(ctx->nodePattern()->accept(this).as<NodeAtom *>());
for (auto *pattern_element_chain : ctx->patternElementChain()) {
std::pair<PatternAtom *, PatternAtom *> element =
@ -881,7 +883,7 @@ antlrcpp::Any CypherMainVisitor::visitPatternElementChain(
antlrcpp::Any CypherMainVisitor::visitRelationshipPattern(
MemgraphCypher::RelationshipPatternContext *ctx) {
auto *edge = storage_.Create<EdgeAtom>();
auto *edge = storage_->Create<EdgeAtom>();
auto relationshipDetail = ctx->relationshipDetail();
auto *variableExpansion =
@ -909,7 +911,7 @@ antlrcpp::Any CypherMainVisitor::visitRelationshipPattern(
if (relationshipDetail->name) {
std::string variable = relationshipDetail->name->accept(this);
edge->identifier_ = storage_.Create<Identifier>(variable);
edge->identifier_ = storage_->Create<Identifier>(variable);
users_identifiers.insert(variable);
} else {
anonymous_identifiers.push_back(&edge->identifier_);
@ -934,11 +936,11 @@ antlrcpp::Any CypherMainVisitor::visitRelationshipPattern(
std::string traversed_edge_variable =
lambda->traversed_edge->accept(this);
edge_lambda.inner_edge =
storage_.Create<Identifier>(traversed_edge_variable);
storage_->Create<Identifier>(traversed_edge_variable);
std::string traversed_node_variable =
lambda->traversed_node->accept(this);
edge_lambda.inner_node =
storage_.Create<Identifier>(traversed_node_variable);
storage_->Create<Identifier>(traversed_node_variable);
edge_lambda.expression = lambda->expression()->accept(this);
return edge_lambda;
};
@ -946,7 +948,7 @@ antlrcpp::Any CypherMainVisitor::visitRelationshipPattern(
if (relationshipDetail->total_weight) {
std::string total_weight_name =
relationshipDetail->total_weight->accept(this);
edge->total_weight_ = storage_.Create<Identifier>(total_weight_name);
edge->total_weight_ = storage_->Create<Identifier>(total_weight_name);
} else {
anonymous_identifiers.push_back(&edge->total_weight_);
}
@ -1160,7 +1162,8 @@ antlrcpp::Any CypherMainVisitor::visitExpression8(
first_operand = comparisons[0];
// Calculate logical and of results of comparisons.
for (int i = 1; i < (int)comparisons.size(); ++i) {
first_operand = storage_.Create<AndOperator>(first_operand, comparisons[i]);
first_operand =
storage_->Create<AndOperator>(first_operand, comparisons[i]);
}
return first_operand;
}
@ -1213,13 +1216,13 @@ antlrcpp::Any CypherMainVisitor::visitExpression3a(
for (auto *op : ctx->stringAndNullOperators()) {
if (op->IS() && op->NOT() && op->CYPHERNULL()) {
expression = static_cast<Expression *>(storage_.Create<NotOperator>(
storage_.Create<IsNullOperator>(expression)));
expression = static_cast<Expression *>(storage_->Create<NotOperator>(
storage_->Create<IsNullOperator>(expression)));
} else if (op->IS() && op->CYPHERNULL()) {
expression = static_cast<Expression *>(
storage_.Create<IsNullOperator>(expression));
storage_->Create<IsNullOperator>(expression));
} else if (op->IN()) {
expression = static_cast<Expression *>(storage_.Create<InListOperator>(
expression = static_cast<Expression *>(storage_->Create<InListOperator>(
expression, op->expression3b()->accept(this)));
} else {
std::string function_name;
@ -1235,7 +1238,7 @@ antlrcpp::Any CypherMainVisitor::visitExpression3a(
auto expression2 = op->expression3b()->accept(this);
std::vector<Expression *> args = {expression, expression2};
expression = static_cast<Expression *>(
storage_.Create<Function>(function_name, args));
storage_->Create<Function>(function_name, args));
}
}
return expression;
@ -1252,7 +1255,7 @@ antlrcpp::Any CypherMainVisitor::visitExpression3b(
for (auto *list_op : ctx->listIndexingOrSlicing()) {
if (list_op->getTokens(MemgraphCypher::DOTS).size() == 0U) {
// If there is no '..' then we need to create list indexing operator.
expression = storage_.Create<SubscriptOperator>(
expression = storage_->Create<SubscriptOperator>(
expression, list_op->expression()[0]->accept(this));
} else if (!list_op->lower_bound && !list_op->upper_bound) {
throw SemanticException(
@ -1266,7 +1269,7 @@ antlrcpp::Any CypherMainVisitor::visitExpression3b(
list_op->upper_bound
? static_cast<Expression *>(list_op->upper_bound->accept(this))
: nullptr;
expression = storage_.Create<ListSlicingOperator>(
expression = storage_->Create<ListSlicingOperator>(
expression, lower_bound_ast, upper_bound_ast);
}
}
@ -1285,7 +1288,7 @@ antlrcpp::Any CypherMainVisitor::visitExpression2a(
if (ctx->nodeLabels()) {
auto labels =
ctx->nodeLabels()->accept(this).as<std::vector<storage::Label>>();
expression = storage_.Create<LabelsTest>(expression, labels);
expression = storage_->Create<LabelsTest>(expression, labels);
}
return expression;
}
@ -1296,7 +1299,7 @@ antlrcpp::Any CypherMainVisitor::visitExpression2b(
for (auto *lookup : ctx->propertyLookup()) {
std::pair<std::string, storage::Property> key = lookup->accept(this);
auto property_lookup =
storage_.Create<PropertyLookup>(expression, key.first, key.second);
storage_->Create<PropertyLookup>(expression, key.first, key.second);
expression = property_lookup;
}
return expression;
@ -1314,21 +1317,21 @@ antlrcpp::Any CypherMainVisitor::visitAtom(MemgraphCypher::AtomContext *ctx) {
} else if (ctx->variable()) {
std::string variable = ctx->variable()->accept(this);
users_identifiers.insert(variable);
return static_cast<Expression *>(storage_.Create<Identifier>(variable));
return static_cast<Expression *>(storage_->Create<Identifier>(variable));
} else if (ctx->functionInvocation()) {
return static_cast<Expression *>(ctx->functionInvocation()->accept(this));
} else if (ctx->COUNT()) {
// Here we handle COUNT(*). COUNT(expression) is handled in
// visitFunctionInvocation with other aggregations. This is visible in
// functionInvocation and atom producions in opencypher grammar.
return static_cast<Expression *>(
storage_.Create<Aggregation>(nullptr, nullptr, Aggregation::Op::COUNT));
return static_cast<Expression *>(storage_->Create<Aggregation>(
nullptr, nullptr, Aggregation::Op::COUNT));
} else if (ctx->ALL()) {
auto *ident = storage_.Create<Identifier>(ctx->filterExpression()
->idInColl()
->variable()
->accept(this)
.as<std::string>());
auto *ident = storage_->Create<Identifier>(ctx->filterExpression()
->idInColl()
->variable()
->accept(this)
.as<std::string>());
Expression *list_expr =
ctx->filterExpression()->idInColl()->expression()->accept(this);
if (!ctx->filterExpression()->where()) {
@ -1336,13 +1339,13 @@ antlrcpp::Any CypherMainVisitor::visitAtom(MemgraphCypher::AtomContext *ctx) {
}
Where *where = ctx->filterExpression()->where()->accept(this);
return static_cast<Expression *>(
storage_.Create<All>(ident, list_expr, where));
storage_->Create<All>(ident, list_expr, where));
} else if (ctx->SINGLE()) {
auto *ident = storage_.Create<Identifier>(ctx->filterExpression()
->idInColl()
->variable()
->accept(this)
.as<std::string>());
auto *ident = storage_->Create<Identifier>(ctx->filterExpression()
->idInColl()
->variable()
->accept(this)
.as<std::string>());
Expression *list_expr =
ctx->filterExpression()->idInColl()->expression()->accept(this);
if (!ctx->filterExpression()->where()) {
@ -1350,35 +1353,35 @@ antlrcpp::Any CypherMainVisitor::visitAtom(MemgraphCypher::AtomContext *ctx) {
}
Where *where = ctx->filterExpression()->where()->accept(this);
return static_cast<Expression *>(
storage_.Create<Single>(ident, list_expr, where));
storage_->Create<Single>(ident, list_expr, where));
} else if (ctx->REDUCE()) {
auto *accumulator = storage_.Create<Identifier>(
auto *accumulator = storage_->Create<Identifier>(
ctx->reduceExpression()->accumulator->accept(this).as<std::string>());
Expression *initializer = ctx->reduceExpression()->initial->accept(this);
auto *ident = storage_.Create<Identifier>(ctx->reduceExpression()
->idInColl()
->variable()
->accept(this)
.as<std::string>());
auto *ident = storage_->Create<Identifier>(ctx->reduceExpression()
->idInColl()
->variable()
->accept(this)
.as<std::string>());
Expression *list =
ctx->reduceExpression()->idInColl()->expression()->accept(this);
Expression *expr =
ctx->reduceExpression()->expression().back()->accept(this);
return static_cast<Expression *>(
storage_.Create<Reduce>(accumulator, initializer, ident, list, expr));
storage_->Create<Reduce>(accumulator, initializer, ident, list, expr));
} else if (ctx->caseExpression()) {
return static_cast<Expression *>(ctx->caseExpression()->accept(this));
} else if (ctx->extractExpression()) {
auto *ident = storage_.Create<Identifier>(ctx->extractExpression()
->idInColl()
->variable()
->accept(this)
.as<std::string>());
auto *ident = storage_->Create<Identifier>(ctx->extractExpression()
->idInColl()
->variable()
->accept(this)
.as<std::string>());
Expression *list =
ctx->extractExpression()->idInColl()->expression()->accept(this);
Expression *expr = ctx->extractExpression()->expression()->accept(this);
return static_cast<Expression *>(
storage_.Create<Extract>(ident, list, expr));
storage_->Create<Extract>(ident, list, expr));
}
// TODO: Implement this. We don't support comprehensions, filtering... at
// the moment.
@ -1387,7 +1390,7 @@ antlrcpp::Any CypherMainVisitor::visitAtom(MemgraphCypher::AtomContext *ctx) {
antlrcpp::Any CypherMainVisitor::visitParameter(
MemgraphCypher::ParameterContext *ctx) {
return storage_.Create<ParameterLookup>(ctx->getStart()->getTokenIndex());
return storage_->Create<ParameterLookup>(ctx->getStart()->getTokenIndex());
}
antlrcpp::Any CypherMainVisitor::visitLiteral(
@ -1397,7 +1400,7 @@ antlrcpp::Any CypherMainVisitor::visitLiteral(
int token_position = ctx->getStart()->getTokenIndex();
if (ctx->CYPHERNULL()) {
return static_cast<Expression *>(
storage_.Create<PrimitiveLiteral>(TypedValue::Null, token_position));
storage_->Create<PrimitiveLiteral>(TypedValue::Null, token_position));
} else if (context_.is_query_cached) {
// Instead of generating PrimitiveLiteral, we generate a
// ParameterLookup, so that the AST can be cached. This allows for
@ -1405,24 +1408,24 @@ antlrcpp::Any CypherMainVisitor::visitLiteral(
// (even though they are not user provided). Note, that NULL always
// generates a PrimitiveLiteral.
return static_cast<Expression *>(
storage_.Create<ParameterLookup>(token_position));
storage_->Create<ParameterLookup>(token_position));
} else if (ctx->StringLiteral()) {
return static_cast<Expression *>(storage_.Create<PrimitiveLiteral>(
return static_cast<Expression *>(storage_->Create<PrimitiveLiteral>(
visitStringLiteral(ctx->StringLiteral()->getText()).as<std::string>(),
token_position));
} else if (ctx->booleanLiteral()) {
return static_cast<Expression *>(storage_.Create<PrimitiveLiteral>(
return static_cast<Expression *>(storage_->Create<PrimitiveLiteral>(
ctx->booleanLiteral()->accept(this).as<bool>(), token_position));
} else if (ctx->numberLiteral()) {
return static_cast<Expression *>(storage_.Create<PrimitiveLiteral>(
return static_cast<Expression *>(storage_->Create<PrimitiveLiteral>(
ctx->numberLiteral()->accept(this).as<TypedValue>(), token_position));
}
LOG(FATAL) << "Expected to handle all cases above";
} else if (ctx->listLiteral()) {
return static_cast<Expression *>(storage_.Create<ListLiteral>(
return static_cast<Expression *>(storage_->Create<ListLiteral>(
ctx->listLiteral()->accept(this).as<std::vector<Expression *>>()));
} else {
return static_cast<Expression *>(storage_.Create<MapLiteral>(
return static_cast<Expression *>(storage_->Create<MapLiteral>(
ctx->mapLiteral()
->accept(this)
.as<std::unordered_map<std::pair<std::string, storage::Property>,
@ -1462,33 +1465,33 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(
}
if (expressions.size() == 1U) {
if (function_name == Aggregation::kCount) {
return static_cast<Expression *>(storage_.Create<Aggregation>(
return static_cast<Expression *>(storage_->Create<Aggregation>(
expressions[0], nullptr, Aggregation::Op::COUNT));
}
if (function_name == Aggregation::kMin) {
return static_cast<Expression *>(storage_.Create<Aggregation>(
return static_cast<Expression *>(storage_->Create<Aggregation>(
expressions[0], nullptr, Aggregation::Op::MIN));
}
if (function_name == Aggregation::kMax) {
return static_cast<Expression *>(storage_.Create<Aggregation>(
return static_cast<Expression *>(storage_->Create<Aggregation>(
expressions[0], nullptr, Aggregation::Op::MAX));
}
if (function_name == Aggregation::kSum) {
return static_cast<Expression *>(storage_.Create<Aggregation>(
return static_cast<Expression *>(storage_->Create<Aggregation>(
expressions[0], nullptr, Aggregation::Op::SUM));
}
if (function_name == Aggregation::kAvg) {
return static_cast<Expression *>(storage_.Create<Aggregation>(
return static_cast<Expression *>(storage_->Create<Aggregation>(
expressions[0], nullptr, Aggregation::Op::AVG));
}
if (function_name == Aggregation::kCollect) {
return static_cast<Expression *>(storage_.Create<Aggregation>(
return static_cast<Expression *>(storage_->Create<Aggregation>(
expressions[0], nullptr, Aggregation::Op::COLLECT_LIST));
}
}
if (expressions.size() == 2U && function_name == Aggregation::kCollect) {
return static_cast<Expression *>(storage_.Create<Aggregation>(
return static_cast<Expression *>(storage_->Create<Aggregation>(
expressions[1], expressions[0], Aggregation::Op::COLLECT_MAP));
}
@ -1496,7 +1499,7 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(
if (!function)
throw SemanticException("Function '{}' doesn't exist.", function_name);
return static_cast<Expression *>(
storage_.Create<Function>(function_name, expressions));
storage_->Create<Function>(function_name, expressions));
}
antlrcpp::Any CypherMainVisitor::visitFunctionName(
@ -1533,7 +1536,7 @@ antlrcpp::Any CypherMainVisitor::visitBooleanLiteral(
antlrcpp::Any CypherMainVisitor::visitCypherDelete(
MemgraphCypher::CypherDeleteContext *ctx) {
auto *del = storage_.Create<Delete>();
auto *del = storage_->Create<Delete>();
if (ctx->DETACH()) {
del->detach_ = true;
}
@ -1544,7 +1547,7 @@ antlrcpp::Any CypherMainVisitor::visitCypherDelete(
}
antlrcpp::Any CypherMainVisitor::visitWhere(MemgraphCypher::WhereContext *ctx) {
auto *where = storage_.Create<Where>();
auto *where = storage_->Create<Where>();
where->expression_ = ctx->expression()->accept(this);
return where;
}
@ -1561,7 +1564,7 @@ antlrcpp::Any CypherMainVisitor::visitSetItem(
MemgraphCypher::SetItemContext *ctx) {
// SetProperty
if (ctx->propertyExpression()) {
auto *set_property = storage_.Create<SetProperty>();
auto *set_property = storage_->Create<SetProperty>();
set_property->property_lookup_ = ctx->propertyExpression()->accept(this);
set_property->expression_ = ctx->expression()->accept(this);
return static_cast<Clause *>(set_property);
@ -1570,8 +1573,8 @@ antlrcpp::Any CypherMainVisitor::visitSetItem(
// SetProperties either assignment or update
if (ctx->getTokens(MemgraphCypher::EQ).size() ||
ctx->getTokens(MemgraphCypher::PLUS_EQ).size()) {
auto *set_properties = storage_.Create<SetProperties>();
set_properties->identifier_ = storage_.Create<Identifier>(
auto *set_properties = storage_->Create<SetProperties>();
set_properties->identifier_ = storage_->Create<Identifier>(
ctx->variable()->accept(this).as<std::string>());
set_properties->expression_ = ctx->expression()->accept(this);
if (ctx->getTokens(MemgraphCypher::PLUS_EQ).size()) {
@ -1581,8 +1584,8 @@ antlrcpp::Any CypherMainVisitor::visitSetItem(
}
// SetLabels
auto *set_labels = storage_.Create<SetLabels>();
set_labels->identifier_ = storage_.Create<Identifier>(
auto *set_labels = storage_->Create<SetLabels>();
set_labels->identifier_ = storage_->Create<Identifier>(
ctx->variable()->accept(this).as<std::string>());
set_labels->labels_ =
ctx->nodeLabels()->accept(this).as<std::vector<storage::Label>>();
@ -1602,14 +1605,14 @@ antlrcpp::Any CypherMainVisitor::visitRemoveItem(
MemgraphCypher::RemoveItemContext *ctx) {
// RemoveProperty
if (ctx->propertyExpression()) {
auto *remove_property = storage_.Create<RemoveProperty>();
auto *remove_property = storage_->Create<RemoveProperty>();
remove_property->property_lookup_ = ctx->propertyExpression()->accept(this);
return static_cast<Clause *>(remove_property);
}
// RemoveLabels
auto *remove_labels = storage_.Create<RemoveLabels>();
remove_labels->identifier_ = storage_.Create<Identifier>(
auto *remove_labels = storage_->Create<RemoveLabels>();
remove_labels->identifier_ = storage_->Create<Identifier>(
ctx->variable()->accept(this).as<std::string>());
remove_labels->labels_ =
ctx->nodeLabels()->accept(this).as<std::vector<storage::Label>>();
@ -1622,7 +1625,7 @@ antlrcpp::Any CypherMainVisitor::visitPropertyExpression(
for (auto *lookup : ctx->propertyLookup()) {
std::pair<std::string, storage::Property> key = lookup->accept(this);
auto property_lookup =
storage_.Create<PropertyLookup>(expression, key.first, key.second);
storage_->Create<PropertyLookup>(expression, key.first, key.second);
expression = property_lookup;
}
// It is guaranteed by grammar that there is at least one propertyLookup.
@ -1639,16 +1642,16 @@ antlrcpp::Any CypherMainVisitor::visitCaseExpression(
Expression *else_expression =
ctx->else_expression
? ctx->else_expression->accept(this).as<Expression *>()
: storage_.Create<PrimitiveLiteral>(TypedValue::Null);
: storage_->Create<PrimitiveLiteral>(TypedValue::Null);
for (auto *alternative : alternatives) {
Expression *condition =
test_expression
? storage_.Create<EqualOperator>(
? storage_->Create<EqualOperator>(
test_expression, alternative->when_expression->accept(this))
: alternative->when_expression->accept(this).as<Expression *>();
Expression *then_expression = alternative->then_expression->accept(this);
else_expression = storage_.Create<IfOperator>(condition, then_expression,
else_expression);
else_expression = storage_->Create<IfOperator>(condition, then_expression,
else_expression);
}
return else_expression;
}
@ -1660,7 +1663,7 @@ antlrcpp::Any CypherMainVisitor::visitCaseAlternatives(
}
antlrcpp::Any CypherMainVisitor::visitWith(MemgraphCypher::WithContext *ctx) {
auto *with = storage_.Create<With>();
auto *with = storage_->Create<With>();
in_with_ = true;
with->body_ = ctx->returnBody()->accept(this);
in_with_ = false;
@ -1674,7 +1677,7 @@ antlrcpp::Any CypherMainVisitor::visitWith(MemgraphCypher::WithContext *ctx) {
}
antlrcpp::Any CypherMainVisitor::visitMerge(MemgraphCypher::MergeContext *ctx) {
auto *merge = storage_.Create<Merge>();
auto *merge = storage_->Create<Merge>();
merge->pattern_ = ctx->patternPart()->accept(this);
for (auto &merge_action : ctx->mergeAction()) {
auto set = merge_action->set()->accept(this).as<std::vector<Clause *>>();
@ -1690,11 +1693,11 @@ antlrcpp::Any CypherMainVisitor::visitMerge(MemgraphCypher::MergeContext *ctx) {
antlrcpp::Any CypherMainVisitor::visitUnwind(
MemgraphCypher::UnwindContext *ctx) {
auto *named_expr = storage_.Create<NamedExpression>();
auto *named_expr = storage_->Create<NamedExpression>();
named_expr->expression_ = ctx->expression()->accept(this);
named_expr->name_ =
std::string(ctx->variable()->accept(this).as<std::string>());
return storage_.Create<Unwind>(named_expr);
return storage_->Create<Unwind>(named_expr);
}
antlrcpp::Any CypherMainVisitor::visitFilterExpression(

View File

@ -20,43 +20,43 @@ using query::Context;
class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
public:
explicit CypherMainVisitor(ParsingContext context,
explicit CypherMainVisitor(ParsingContext context, AstStorage *storage,
database::GraphDbAccessor *dba)
: context_(context), dba_(dba) {}
: context_(context), storage_(storage), dba_(dba) {}
private:
Expression *CreateBinaryOperatorByToken(size_t token, Expression *e1,
Expression *e2) {
switch (token) {
case MemgraphCypher::OR:
return storage_.Create<OrOperator>(e1, e2);
return storage_->Create<OrOperator>(e1, e2);
case MemgraphCypher::XOR:
return storage_.Create<XorOperator>(e1, e2);
return storage_->Create<XorOperator>(e1, e2);
case MemgraphCypher::AND:
return storage_.Create<AndOperator>(e1, e2);
return storage_->Create<AndOperator>(e1, e2);
case MemgraphCypher::PLUS:
return storage_.Create<AdditionOperator>(e1, e2);
return storage_->Create<AdditionOperator>(e1, e2);
case MemgraphCypher::MINUS:
return storage_.Create<SubtractionOperator>(e1, e2);
return storage_->Create<SubtractionOperator>(e1, e2);
case MemgraphCypher::ASTERISK:
return storage_.Create<MultiplicationOperator>(e1, e2);
return storage_->Create<MultiplicationOperator>(e1, e2);
case MemgraphCypher::SLASH:
return storage_.Create<DivisionOperator>(e1, e2);
return storage_->Create<DivisionOperator>(e1, e2);
case MemgraphCypher::PERCENT:
return storage_.Create<ModOperator>(e1, e2);
return storage_->Create<ModOperator>(e1, e2);
case MemgraphCypher::EQ:
return storage_.Create<EqualOperator>(e1, e2);
return storage_->Create<EqualOperator>(e1, e2);
case MemgraphCypher::NEQ1:
case MemgraphCypher::NEQ2:
return storage_.Create<NotEqualOperator>(e1, e2);
return storage_->Create<NotEqualOperator>(e1, e2);
case MemgraphCypher::LT:
return storage_.Create<LessOperator>(e1, e2);
return storage_->Create<LessOperator>(e1, e2);
case MemgraphCypher::GT:
return storage_.Create<GreaterOperator>(e1, e2);
return storage_->Create<GreaterOperator>(e1, e2);
case MemgraphCypher::LTE:
return storage_.Create<LessEqualOperator>(e1, e2);
return storage_->Create<LessEqualOperator>(e1, e2);
case MemgraphCypher::GTE:
return storage_.Create<GreaterEqualOperator>(e1, e2);
return storage_->Create<GreaterEqualOperator>(e1, e2);
default:
throw utils::NotYetImplemented("binary operator");
}
@ -65,11 +65,11 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
Expression *CreateUnaryOperatorByToken(size_t token, Expression *e) {
switch (token) {
case MemgraphCypher::NOT:
return storage_.Create<NotOperator>(e);
return storage_->Create<NotOperator>(e);
case MemgraphCypher::PLUS:
return storage_.Create<UnaryPlusOperator>(e);
return storage_->Create<UnaryPlusOperator>(e);
case MemgraphCypher::MINUS:
return storage_.Create<UnaryMinusOperator>(e);
return storage_->Create<UnaryMinusOperator>(e);
default:
throw utils::NotYetImplemented("unary operator");
}
@ -750,18 +750,17 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
public:
Query *query() { return query_; }
AstStorage &storage() { return storage_; }
const static std::string kAnonPrefix;
private:
ParsingContext context_;
AstStorage *storage_;
database::GraphDbAccessor *dba_;
// Set of identifiers from queries.
std::unordered_set<std::string> users_identifiers;
// Identifiers that user didn't name.
std::vector<Identifier **> anonymous_identifiers;
AstStorage storage_;
Query *query_ = nullptr;
// All return items which are not variables must be aliased in with.
// We use this variable in visitReturnItem to check if we are in with or

View File

@ -99,10 +99,9 @@ class PrivilegeExtractor : public HierarchicalTreeVisitor {
std::vector<AuthQuery::Privilege> privileges_;
};
std::vector<AuthQuery::Privilege> GetRequiredPrivileges(
const AstStorage &ast_storage) {
std::vector<AuthQuery::Privilege> GetRequiredPrivileges(Query *query) {
PrivilegeExtractor extractor;
ast_storage.query()->Accept(extractor);
query->Accept(extractor);
return extractor.privileges();
}

View File

@ -3,6 +3,5 @@
#include "query/frontend/ast/ast.hpp"
namespace query {
std::vector<AuthQuery::Privilege> GetRequiredPrivileges(
const AstStorage &ast_storage);
std::vector<AuthQuery::Privilege> GetRequiredPrivileges(Query *query);
}

View File

@ -57,10 +57,13 @@ Interpreter::Results Interpreter::operator()(
ParsingContext parsing_context;
parsing_context.is_query_cached = true;
AstStorage ast_storage = QueryToAst(stripped, parsing_context, &db_accessor);
AstStorage ast_storage;
Query *ast =
QueryToAst(stripped, parsing_context, &ast_storage, &db_accessor);
// TODO: Maybe cache required privileges to improve performance on very simple
// queries.
auto required_privileges = query::GetRequiredPrivileges(ast_storage);
auto required_privileges = query::GetRequiredPrivileges(ast);
auto frontend_time = frontend_timer.Elapsed();
// Try to get a cached plan. Note that this local shared_ptr might be the only
@ -78,7 +81,8 @@ Interpreter::Results Interpreter::operator()(
utils::Timer planning_timer;
if (!plan) {
plan = plan_cache_access
.insert(stripped.hash(), AstToPlan(std::move(ast_storage), &ctx))
.insert(stripped.hash(),
AstToPlan(ast, std::move(ast_storage), &ctx))
.first->second;
}
auto planning_time = planning_timer.Elapsed();
@ -114,16 +118,17 @@ Interpreter::Results Interpreter::operator()(
}
std::shared_ptr<Interpreter::CachedPlan> Interpreter::AstToPlan(
AstStorage ast_storage, Context *ctx) {
Query *query, AstStorage ast_storage, Context *ctx) {
SymbolGenerator symbol_generator(ctx->symbol_table_);
ast_storage.query()->Accept(symbol_generator);
query->Accept(symbol_generator);
return std::make_shared<CachedPlan>(
MakeLogicalPlan(std::move(ast_storage), ctx));
MakeLogicalPlan(query, std::move(ast_storage), ctx));
}
AstStorage Interpreter::QueryToAst(const StrippedQuery &stripped,
const ParsingContext &context,
database::GraphDbAccessor *db_accessor) {
Query *Interpreter::QueryToAst(const StrippedQuery &stripped,
const ParsingContext &context,
AstStorage *ast_storage,
database::GraphDbAccessor *db_accessor) {
if (!context.is_query_cached) {
// stripped query -> AST
auto parser = [&] {
@ -134,9 +139,9 @@ AstStorage Interpreter::QueryToAst(const StrippedQuery &stripped,
}();
auto low_level_tree = parser->tree();
// AST -> high level tree
frontend::CypherMainVisitor visitor(context, db_accessor);
frontend::CypherMainVisitor visitor(context, ast_storage, db_accessor);
visitor.visit(low_level_tree);
return std::move(visitor.storage());
return visitor.query();
}
auto ast_cache_accessor = ast_cache_.access();
auto ast_it = ast_cache_accessor.find(stripped.hash());
@ -160,16 +165,16 @@ AstStorage Interpreter::QueryToAst(const StrippedQuery &stripped,
}();
auto low_level_tree = parser->tree();
// AST -> high level tree
frontend::CypherMainVisitor visitor(context, db_accessor);
CachedQuery cached_query;
frontend::CypherMainVisitor visitor(context, &cached_query.ast_storage,
db_accessor);
visitor.visit(low_level_tree);
cached_query.query = visitor.query();
// Cache it.
ast_it =
ast_cache_accessor.insert(stripped.hash(), std::move(visitor.storage()))
.first;
ast_it = ast_cache_accessor.insert(stripped.hash(), std::move(cached_query))
.first;
}
AstStorage new_ast;
ast_it->second.query()->Clone(new_ast);
return new_ast;
return ast_it->second.query->Clone(*ast_storage);
}
class SingleNodeLogicalPlan final : public LogicalPlan {
@ -194,10 +199,10 @@ class SingleNodeLogicalPlan final : public LogicalPlan {
};
std::unique_ptr<LogicalPlan> Interpreter::MakeLogicalPlan(
AstStorage ast_storage, Context *context) {
Query *query, AstStorage ast_storage, Context *context) {
auto vertex_counts = plan::MakeVertexCountCache(context->db_accessor_);
auto planning_context = plan::MakePlanningContext(
ast_storage, context->symbol_table_, vertex_counts);
ast_storage, context->symbol_table_, query, vertex_counts);
std::unique_ptr<plan::LogicalOperator> root;
double cost;
std::tie(root, cost) = plan::MakeLogicalPlan(

View File

@ -57,6 +57,11 @@ class Interpreter {
utils::Timer cache_timer_;
};
struct CachedQuery {
AstStorage ast_storage;
Query *query;
};
using PlanCacheT = ConcurrentMap<HashType, std::shared_ptr<CachedPlan>>;
public:
@ -180,10 +185,11 @@ class Interpreter {
// high level tree -> logical plan
// AstStorage and SymbolTable may be modified during planning. The created
// LogicalPlan must take ownership of AstStorage and SymbolTable.
virtual std::unique_ptr<LogicalPlan> MakeLogicalPlan(AstStorage, Context *);
virtual std::unique_ptr<LogicalPlan> MakeLogicalPlan(Query *, AstStorage,
Context *);
private:
ConcurrentMap<HashType, AstStorage> ast_cache_;
ConcurrentMap<HashType, CachedQuery> ast_cache_;
PlanCacheT plan_cache_;
// Antlr has singleton instance that is shared between threads. It is
// protected by locks inside of antlr. Unfortunately, they are not protected
@ -194,11 +200,12 @@ class Interpreter {
utils::SpinLock antlr_lock_;
// high level tree -> CachedPlan
std::shared_ptr<CachedPlan> AstToPlan(AstStorage ast_storage, Context *ctx);
std::shared_ptr<CachedPlan> AstToPlan(Query *query, AstStorage ast_storage,
Context *ctx);
// stripped query -> high level tree
AstStorage QueryToAst(const StrippedQuery &stripped,
const ParsingContext &context,
database::GraphDbAccessor *db_accessor);
Query *QueryToAst(const StrippedQuery &stripped,
const ParsingContext &context, AstStorage *ast_storage,
database::GraphDbAccessor *db_accessor);
};
} // namespace query

View File

@ -1,5 +1,6 @@
/// @file
/// This file is an entry point for invoking various planners via the following API:
/// This file is an entry point for invoking various planners via the following
/// API:
/// * `MakeLogicalPlanForSingleQuery`
/// * `MakeLogicalPlan`
@ -51,8 +52,8 @@ auto MakeLogicalPlanForSingleQuery(
template <class TPlanningContext>
auto MakeLogicalPlan(TPlanningContext &context, const Parameters &parameters,
bool use_variable_planner) {
auto query_parts =
CollectQueryParts(context.symbol_table, context.ast_storage);
auto query_parts = CollectQueryParts(context.symbol_table,
context.ast_storage, context.query);
auto &vertex_counts = context.db;
double total_cost = 0;
std::unique_ptr<LogicalOperator> last_op;
@ -99,7 +100,7 @@ auto MakeLogicalPlan(TPlanningContext &context, const Parameters &parameters,
prev_op, prev_op->OutputSymbols(context.symbol_table));
}
if (context.ast_storage.query()->explain_) {
if (context.query->explain_) {
last_op = std::make_unique<Explain>(
std::move(last_op),
context.symbol_table.CreateSymbol("QUERY PLAN", false),

View File

@ -443,8 +443,7 @@ void Filters::AnalyzeAndStoreFilter(Expression *expr,
// Converts a Query to multiple QueryParts. In the process new Ast nodes may be
// created, e.g. filter expressions.
std::vector<SingleQueryPart> CollectSingleQueryParts(
SymbolTable &symbol_table, AstStorage &storage,
SingleQuery *single_query) {
SymbolTable &symbol_table, AstStorage &storage, SingleQuery *single_query) {
std::vector<SingleQueryPart> query_parts(1);
auto *query_part = &query_parts.back();
for (auto &clause : single_query->clauses_) {
@ -477,8 +476,8 @@ std::vector<SingleQueryPart> CollectSingleQueryParts(
return query_parts;
}
QueryParts CollectQueryParts(SymbolTable &symbol_table, AstStorage &storage) {
auto *query = storage.query();
QueryParts CollectQueryParts(SymbolTable &symbol_table, AstStorage &storage,
Query *query) {
std::vector<QueryPart> query_parts;
auto *single_query = query->single_query_;

View File

@ -290,6 +290,6 @@ struct QueryParts {
/// and do some other preprocessing in order to generate multiple @c QueryPart
/// structures. @c AstStorage and @c SymbolTable may be used to create new
/// AST nodes.
QueryParts CollectQueryParts(SymbolTable &, AstStorage &);
QueryParts CollectQueryParts(SymbolTable &, AstStorage &, Query *);
} // namespace query::plan

View File

@ -21,9 +21,10 @@ struct PlanningContext {
///
/// Newly created AST nodes may be added to reference existing symbols.
SymbolTable &symbol_table;
/// @brief The storage is used to traverse the AST as well as create new nodes
/// for use in operators.
/// @brief The storage is used to create new AST nodes for use in operators.
AstStorage &ast_storage;
/// @brief Query to be planned
Query *query;
/// @brief TDbAccessor, which may be used to get some information from the
/// database to generate better plans. The accessor is required only to live
/// long enough for the plan generation to finish.
@ -39,8 +40,8 @@ struct PlanningContext {
template <class TDbAccessor>
auto MakePlanningContext(AstStorage &ast_storage, SymbolTable &symbol_table,
const TDbAccessor &db) {
return PlanningContext<TDbAccessor>{symbol_table, ast_storage, db};
Query *query, const TDbAccessor &db) {
return PlanningContext<TDbAccessor>{symbol_table, ast_storage, query, db};
}
// Contextual information used for generating match operators.

View File

@ -10,7 +10,9 @@
#include "query/plan/vertex_count_cache.hpp"
// Add chained MATCH (node1) -- (node2), MATCH (node2) -- (node3) ... clauses.
static void AddChainedMatches(int num_matches, query::AstStorage &storage) {
static query::Query *AddChainedMatches(int num_matches,
query::AstStorage &storage) {
auto *query = storage.Create<query::Query>();
for (int i = 0; i < num_matches; ++i) {
auto *match = storage.Create<query::Match>();
auto *pattern = storage.Create<query::Pattern>();
@ -26,8 +28,9 @@ static void AddChainedMatches(int num_matches, query::AstStorage &storage) {
pattern->atoms_.emplace_back(storage.Create<query::NodeAtom>(
storage.Create<query::Identifier>("node" + std::to_string(i))));
single_query->clauses_.emplace_back(match);
storage.query()->single_query_ = single_query;
query->single_query_ = single_query;
}
return query;
}
static void BM_PlanChainedMatches(benchmark::State &state) {
@ -37,13 +40,15 @@ static void BM_PlanChainedMatches(benchmark::State &state) {
state.PauseTiming();
query::AstStorage storage;
int num_matches = state.range(0);
AddChainedMatches(num_matches, storage);
auto *query = AddChainedMatches(num_matches, storage);
query::SymbolTable symbol_table;
query::SymbolGenerator symbol_generator(symbol_table);
storage.query()->Accept(symbol_generator);
auto ctx = query::plan::MakePlanningContext(storage, symbol_table, *dba);
query->Accept(symbol_generator);
auto ctx =
query::plan::MakePlanningContext(storage, symbol_table, query, *dba);
state.ResumeTiming();
auto query_parts = query::plan::CollectQueryParts(symbol_table, storage);
auto query_parts =
query::plan::CollectQueryParts(symbol_table, storage, query);
if (query_parts.query_parts.size() == 0) {
std::exit(EXIT_FAILURE);
}
@ -62,10 +67,11 @@ BENCHMARK(BM_PlanChainedMatches)
->Range(50, 400)
->Unit(benchmark::kMillisecond);
static void AddIndexedMatches(
static query::Query *AddIndexedMatches(
int num_matches, storage::Label label,
const std::pair<std::string, storage::Property> &property,
query::AstStorage &storage) {
auto *query = storage.Create<query::Query>();
for (int i = 0; i < num_matches; ++i) {
auto *match = storage.Create<query::Match>();
auto *pattern = storage.Create<query::Pattern>();
@ -79,8 +85,9 @@ static void AddIndexedMatches(
node->properties_[property] = storage.Create<query::PrimitiveLiteral>(i);
pattern->atoms_.emplace_back(node);
single_query->clauses_.emplace_back(match);
storage.query()->single_query_ = single_query;
query->single_query_ = single_query;
}
return query;
}
static auto CreateIndexedVertices(int index_count, int vertex_count,
@ -112,14 +119,16 @@ static void BM_PlanAndEstimateIndexedMatching(benchmark::State &state) {
while (state.KeepRunning()) {
state.PauseTiming();
query::AstStorage storage;
AddIndexedMatches(index_count, label, std::make_pair("prop", prop),
storage);
auto *query = AddIndexedMatches(index_count, label,
std::make_pair("prop", prop), storage);
query::SymbolTable symbol_table;
query::SymbolGenerator symbol_generator(symbol_table);
storage.query()->Accept(symbol_generator);
query->Accept(symbol_generator);
state.ResumeTiming();
auto ctx = query::plan::MakePlanningContext(storage, symbol_table, *dba);
auto query_parts = query::plan::CollectQueryParts(symbol_table, storage);
auto ctx =
query::plan::MakePlanningContext(storage, symbol_table, query, *dba);
auto query_parts =
query::plan::CollectQueryParts(symbol_table, storage, query);
if (query_parts.query_parts.size() == 0) {
std::exit(EXIT_FAILURE);
}
@ -146,15 +155,16 @@ static void BM_PlanAndEstimateIndexedMatchingWithCachedCounts(
while (state.KeepRunning()) {
state.PauseTiming();
query::AstStorage storage;
AddIndexedMatches(index_count, label, std::make_pair("prop", prop),
storage);
auto *query = AddIndexedMatches(index_count, label,
std::make_pair("prop", prop), storage);
query::SymbolTable symbol_table;
query::SymbolGenerator symbol_generator(symbol_table);
storage.query()->Accept(symbol_generator);
query->Accept(symbol_generator);
state.ResumeTiming();
auto ctx =
query::plan::MakePlanningContext(storage, symbol_table, vertex_counts);
auto query_parts = query::plan::CollectQueryParts(symbol_table, storage);
auto ctx = query::plan::MakePlanningContext(storage, symbol_table, query,
vertex_counts);
auto query_parts =
query::plan::CollectQueryParts(symbol_table, storage, query);
if (query_parts.query_parts.size() == 0) {
std::exit(EXIT_FAILURE);
}

View File

@ -432,33 +432,34 @@ void ExaminePlans(
}
}
query::AstStorage MakeAst(const std::string &query,
database::GraphDbAccessor &dba) {
query::Query *MakeAst(const std::string &query, query::AstStorage *storage,
database::GraphDbAccessor &dba) {
query::ParsingContext parsing_context;
parsing_context.is_query_cached = false;
// query -> AST
auto parser = std::make_unique<query::frontend::opencypher::Parser>(query);
// AST -> high level tree
query::frontend::CypherMainVisitor visitor(parsing_context, &dba);
query::frontend::CypherMainVisitor visitor(parsing_context, storage, &dba);
visitor.visit(parser->tree());
return std::move(visitor.storage());
return visitor.query();
}
query::SymbolTable MakeSymbolTable(const query::AstStorage &ast) {
query::SymbolTable MakeSymbolTable(query::Query *query) {
query::SymbolTable symbol_table;
query::SymbolGenerator symbol_generator(symbol_table);
ast.query()->Accept(symbol_generator);
query->Accept(symbol_generator);
return symbol_table;
}
// Returns a list of pairs (plan, estimated cost), sorted in the ascending
// order by cost.
auto MakeLogicalPlans(query::AstStorage &ast, query::SymbolTable &symbol_table,
auto MakeLogicalPlans(query::Query *query, query::AstStorage &ast,
query::SymbolTable &symbol_table,
InteractiveDbAccessor &dba) {
auto query_parts = query::plan::CollectQueryParts(symbol_table, ast);
auto query_parts = query::plan::CollectQueryParts(symbol_table, ast, query);
std::vector<std::pair<std::unique_ptr<query::plan::LogicalOperator>, double>>
plans_with_cost;
auto ctx = query::plan::MakePlanningContext(ast, symbol_table, dba);
auto ctx = query::plan::MakePlanningContext(ast, symbol_table, query, dba);
if (query_parts.query_parts.size() <= 0) {
std::cerr << "Failed to extract query parts" << std::endl;
std::exit(EXIT_FAILURE);
@ -499,10 +500,11 @@ void RunInteractivePlanning(database::GraphDbAccessor *dba) {
if (!line || *line == "quit") break;
if (line->empty()) continue;
try {
auto ast = MakeAst(*line, *dba);
auto symbol_table = MakeSymbolTable(ast);
query::AstStorage ast;
auto *query = MakeAst(*line, &ast, *dba);
auto symbol_table = MakeSymbolTable(query);
planning_timer.Start();
auto plans = MakeLogicalPlans(ast, symbol_table, interactive_db);
auto plans = MakeLogicalPlans(query, ast, symbol_table, interactive_db);
auto planning_time = planning_timer.Elapsed();
std::cout
<< "Planning took "

View File

@ -82,17 +82,14 @@ class Base {
// This generator uses ast constructed by parsing the query.
class AstGenerator : public Base {
public:
explicit AstGenerator(const std::string &query)
: Base(query),
parser_(query),
visitor_(context_, db_accessor_.get()),
query_([&]() {
visitor_.visit(parser_.tree());
return visitor_.query();
}()) {}
explicit AstGenerator(const std::string &query) : Base(query) {
::frontend::opencypher::Parser parser(query);
CypherMainVisitor visitor(context_, &ast_storage_, db_accessor_.get());
visitor.visit(parser.tree());
query_ = visitor.query();
}
::frontend::opencypher::Parser parser_;
CypherMainVisitor visitor_;
AstStorage ast_storage_;
Query *query_;
};
@ -103,7 +100,7 @@ class OriginalAfterCloningAstGenerator : public AstGenerator {
explicit OriginalAfterCloningAstGenerator(const std::string &query)
: AstGenerator(query) {
AstStorage storage;
visitor_.query()->Clone(storage);
query_->Clone(storage);
}
};
@ -112,15 +109,15 @@ class OriginalAfterCloningAstGenerator : public AstGenerator {
// any data from original ast.
class ClonedAstGenerator : public Base {
public:
explicit ClonedAstGenerator(const std::string &query)
: Base(query), query_([&]() {
::frontend::opencypher::Parser parser(query);
CypherMainVisitor visitor(context_, db_accessor_.get());
visitor.visit(parser.tree());
return visitor.query()->Clone(storage);
}()) {}
explicit ClonedAstGenerator(const std::string &query) : Base(query) {
::frontend::opencypher::Parser parser(query);
AstStorage tmp_storage;
CypherMainVisitor visitor(context_, &tmp_storage, db_accessor_.get());
visitor.visit(parser.tree());
query_ = visitor.query()->Clone(ast_storage_);
}
AstStorage storage;
AstStorage ast_storage_;
Query *query_;
};
@ -128,22 +125,18 @@ class ClonedAstGenerator : public Base {
// the same way it is done in ast cacheing in interpreter.
class CachedAstGenerator : public Base {
public:
explicit CachedAstGenerator(const std::string &query)
: Base(query),
storage_([&]() {
context_.is_query_cached = true;
StrippedQuery stripped(query_string_);
parameters_ = stripped.literals();
::frontend::opencypher::Parser parser(stripped.query());
CypherMainVisitor visitor(context_, db_accessor_.get());
visitor.visit(parser.tree());
AstStorage new_ast;
visitor.storage().query()->Clone(new_ast);
return new_ast;
}()),
query_(storage_.query()) {}
explicit CachedAstGenerator(const std::string &query) : Base(query) {
context_.is_query_cached = true;
StrippedQuery stripped(query_string_);
parameters_ = stripped.literals();
::frontend::opencypher::Parser parser(stripped.query());
AstStorage tmp_storage;
CypherMainVisitor visitor(context_, &tmp_storage, db_accessor_.get());
visitor.visit(parser.tree());
query_ = visitor.query()->Clone(ast_storage_);
}
AstStorage storage_;
AstStorage ast_storage_;
Query *query_;
};

File diff suppressed because it is too large Load Diff

View File

@ -247,21 +247,18 @@ auto GetCypherUnion(CypherUnion *cypher_union, SingleQuery *single_query) {
}
auto GetQuery(AstStorage &storage, SingleQuery *single_query) {
storage.query()->single_query_ = single_query;
return storage.query();
}
auto GetQuery(AstStorage &storage, SingleQuery *single_query,
CypherUnion *cypher_union) {
storage.query()->cypher_unions_.emplace_back(cypher_union);
return GetQuery(storage, single_query);
auto *query = storage.Create<Query>();
query->single_query_ = single_query;
return query;
}
template <class... T>
auto GetQuery(AstStorage &storage, SingleQuery *single_query,
CypherUnion *cypher_union, T *... cypher_unions) {
storage.query()->cypher_unions_.emplace_back(cypher_union);
return GetQuery(storage, single_query, cypher_unions...);
T *... cypher_unions) {
auto *query = storage.Create<Query>();
query->single_query_ = single_query;
query->cypher_unions_ = std::vector<CypherUnion *>{cypher_unions...};
return query;
}
// Helper functions for constructing RETURN and WITH clauses.

File diff suppressed because it is too large Load Diff

View File

@ -540,9 +540,10 @@ auto MakeSymbolTable(query::Query &query) {
template <class TPlanner, class TDbAccessor>
TPlanner MakePlanner(const TDbAccessor &dba, AstStorage &storage,
SymbolTable &symbol_table) {
auto planning_context = MakePlanningContext(storage, symbol_table, dba);
auto query_parts = CollectQueryParts(symbol_table, storage);
SymbolTable &symbol_table, Query *query) {
auto planning_context =
MakePlanningContext(storage, symbol_table, query, dba);
auto query_parts = CollectQueryParts(symbol_table, storage, query);
auto single_query_parts = query_parts.query_parts.at(0).single_query_parts;
return TPlanner(single_query_parts, planning_context);
}

View File

@ -24,83 +24,89 @@ class TestPrivilegeExtractor : public ::testing::Test {
};
TEST_F(TestPrivilegeExtractor, CreateNode) {
QUERY(SINGLE_QUERY(CREATE(PATTERN(NODE("n")))));
EXPECT_THAT(GetRequiredPrivileges(storage),
auto *query = QUERY(SINGLE_QUERY(CREATE(PATTERN(NODE("n")))));
EXPECT_THAT(GetRequiredPrivileges(query),
UnorderedElementsAre(AuthQuery::Privilege::CREATE));
}
TEST_F(TestPrivilegeExtractor, MatchNodeDelete) {
QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), DELETE(IDENT("n"))));
EXPECT_THAT(GetRequiredPrivileges(storage),
auto *query =
QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), DELETE(IDENT("n"))));
EXPECT_THAT(GetRequiredPrivileges(query),
UnorderedElementsAre(AuthQuery::Privilege::MATCH,
AuthQuery::Privilege::DELETE));
}
TEST_F(TestPrivilegeExtractor, MatchNodeReturn) {
QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN("n")));
EXPECT_THAT(GetRequiredPrivileges(storage),
auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN("n")));
EXPECT_THAT(GetRequiredPrivileges(query),
UnorderedElementsAre(AuthQuery::Privilege::MATCH));
}
TEST_F(TestPrivilegeExtractor, MatchCreateExpand) {
QUERY(SINGLE_QUERY(
auto *query = QUERY(SINGLE_QUERY(
MATCH(PATTERN(NODE("n"))),
CREATE(PATTERN(NODE("n"),
EDGE("r", EdgeAtom::Direction::OUT, {EDGE_TYPE}),
NODE("m")))));
EXPECT_THAT(GetRequiredPrivileges(storage),
EXPECT_THAT(GetRequiredPrivileges(query),
UnorderedElementsAre(AuthQuery::Privilege::MATCH,
AuthQuery::Privilege::CREATE));
}
TEST_F(TestPrivilegeExtractor, MatchNodeSetLabels) {
QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), SET("n", {LABEL_0, LABEL_1})));
EXPECT_THAT(GetRequiredPrivileges(storage),
auto *query = QUERY(
SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), SET("n", {LABEL_0, LABEL_1})));
EXPECT_THAT(GetRequiredPrivileges(query),
UnorderedElementsAre(AuthQuery::Privilege::MATCH,
AuthQuery::Privilege::SET));
}
TEST_F(TestPrivilegeExtractor, MatchNodeSetProperty) {
QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))),
SET(PROPERTY_LOOKUP("n", {"prop", PROP_0}), LITERAL(42))));
EXPECT_THAT(GetRequiredPrivileges(storage),
auto *query = QUERY(
SINGLE_QUERY(MATCH(PATTERN(NODE("n"))),
SET(PROPERTY_LOOKUP("n", {"prop", PROP_0}), LITERAL(42))));
EXPECT_THAT(GetRequiredPrivileges(query),
UnorderedElementsAre(AuthQuery::Privilege::MATCH,
AuthQuery::Privilege::SET));
}
TEST_F(TestPrivilegeExtractor, MatchNodeSetProperties) {
QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), SET("n", LIST())));
EXPECT_THAT(GetRequiredPrivileges(storage),
auto *query =
QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), SET("n", LIST())));
EXPECT_THAT(GetRequiredPrivileges(query),
UnorderedElementsAre(AuthQuery::Privilege::MATCH,
AuthQuery::Privilege::SET));
}
TEST_F(TestPrivilegeExtractor, MatchNodeRemoveLabels) {
QUERY(
auto *query = QUERY(
SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), REMOVE("n", {LABEL_0, LABEL_1})));
EXPECT_THAT(GetRequiredPrivileges(storage),
EXPECT_THAT(GetRequiredPrivileges(query),
UnorderedElementsAre(AuthQuery::Privilege::MATCH,
AuthQuery::Privilege::REMOVE));
}
TEST_F(TestPrivilegeExtractor, MatchNodeRemoveProperty) {
QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))),
REMOVE(PROPERTY_LOOKUP("n", {"prop", PROP_0}))));
EXPECT_THAT(GetRequiredPrivileges(storage),
auto *query =
QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))),
REMOVE(PROPERTY_LOOKUP("n", {"prop", PROP_0}))));
EXPECT_THAT(GetRequiredPrivileges(query),
UnorderedElementsAre(AuthQuery::Privilege::MATCH,
AuthQuery::Privilege::REMOVE));
}
TEST_F(TestPrivilegeExtractor, CreateIndex) {
QUERY(SINGLE_QUERY(CREATE_INDEX_ON(LABEL_0, PROP_0)));
EXPECT_THAT(GetRequiredPrivileges(storage),
auto *query = QUERY(SINGLE_QUERY(CREATE_INDEX_ON(LABEL_0, PROP_0)));
EXPECT_THAT(GetRequiredPrivileges(query),
UnorderedElementsAre(AuthQuery::Privilege::INDEX));
}
TEST_F(TestPrivilegeExtractor, AuthQuery) {
QUERY(SINGLE_QUERY(AUTH_QUERY(AuthQuery::Action::CREATE_ROLE, "", "role", "",
nullptr, std::vector<AuthQuery::Privilege>{})));
EXPECT_THAT(GetRequiredPrivileges(storage),
auto *query = QUERY(
SINGLE_QUERY(AUTH_QUERY(AuthQuery::Action::CREATE_ROLE, "", "role", "",
nullptr, std::vector<AuthQuery::Privilege>{})));
EXPECT_THAT(GetRequiredPrivileges(query),
UnorderedElementsAre(AuthQuery::Privilege::AUTH));
}
@ -121,8 +127,8 @@ TEST_F(TestPrivilegeExtractor, StreamQuery) {
STOP_ALL_STREAMS};
for (auto *stream_clause : stream_clauses) {
QUERY(SINGLE_QUERY(stream_clause));
EXPECT_THAT(GetRequiredPrivileges(storage),
auto *query = QUERY(SINGLE_QUERY(stream_clause));
EXPECT_THAT(GetRequiredPrivileges(query),
UnorderedElementsAre(AuthQuery::Privilege::STREAM));
}
}

View File

@ -63,12 +63,13 @@ void AssertRows(const std::vector<std::vector<TypedValue>> &datum,
};
void CheckPlansProduce(
size_t expected_plan_count, AstStorage &storage,
size_t expected_plan_count, query::Query *query, AstStorage &storage,
database::GraphDbAccessor &dba,
std::function<void(const std::vector<std::vector<TypedValue>> &)> check) {
auto symbol_table = MakeSymbolTable(*storage.query());
auto planning_context = MakePlanningContext(storage, symbol_table, dba);
auto query_parts = CollectQueryParts(symbol_table, storage);
auto symbol_table = MakeSymbolTable(*query);
auto planning_context =
MakePlanningContext(storage, symbol_table, query, dba);
auto query_parts = CollectQueryParts(symbol_table, storage, query);
EXPECT_TRUE(query_parts.query_parts.size() > 0);
auto single_query_parts = query_parts.query_parts.at(0).single_query_parts;
auto plans = MakeLogicalPlanForSingleQuery<VariableStartPlanner>(
@ -92,11 +93,11 @@ TEST(TestVariableStartPlanner, MatchReturn) {
dba->AdvanceCommand();
// Test MATCH (n) -[r]-> (m) RETURN n
AstStorage storage;
QUERY(SINGLE_QUERY(
auto *query = QUERY(SINGLE_QUERY(
MATCH(PATTERN(NODE("n"), EDGE("r", Direction::OUT), NODE("m"))),
RETURN("n")));
// We have 2 nodes `n` and `m` from which we could start, so expect 2 plans.
CheckPlansProduce(2, storage, *dba, [&](const auto &results) {
CheckPlansProduce(2, query, storage, *dba, [&](const auto &results) {
// We expect to produce only a single (v1) node.
AssertRows(results, {{v1}});
});
@ -115,12 +116,12 @@ TEST(TestVariableStartPlanner, MatchTripletPatternReturn) {
{
// Test `MATCH (n) -[r]-> (m) -[e]-> (l) RETURN n`
AstStorage storage;
QUERY(SINGLE_QUERY(
auto *query = QUERY(SINGLE_QUERY(
MATCH(PATTERN(NODE("n"), EDGE("r", Direction::OUT), NODE("m"),
EDGE("e", Direction::OUT), NODE("l"))),
RETURN("n")));
// We have 3 nodes: `n`, `m` and `l` from which we could start.
CheckPlansProduce(3, storage, *dba, [&](const auto &results) {
CheckPlansProduce(3, query, storage, *dba, [&](const auto &results) {
// We expect to produce only a single (v1) node.
AssertRows(results, {{v1}});
});
@ -128,11 +129,11 @@ TEST(TestVariableStartPlanner, MatchTripletPatternReturn) {
{
// Equivalent to `MATCH (n) -[r]-> (m), (m) -[e]-> (l) RETURN n`.
AstStorage storage;
QUERY(SINGLE_QUERY(
auto *query = QUERY(SINGLE_QUERY(
MATCH(PATTERN(NODE("n"), EDGE("r", Direction::OUT), NODE("m")),
PATTERN(NODE("m"), EDGE("e", Direction::OUT), NODE("l"))),
RETURN("n")));
CheckPlansProduce(3, storage, *dba, [&](const auto &results) {
CheckPlansProduce(3, query, storage, *dba, [&](const auto &results) {
AssertRows(results, {{v1}});
});
}
@ -150,13 +151,13 @@ TEST(TestVariableStartPlanner, MatchOptionalMatchReturn) {
dba->AdvanceCommand();
// Test MATCH (n) -[r]-> (m) OPTIONAL MATCH (m) -[e]-> (l) RETURN n, l
AstStorage storage;
QUERY(SINGLE_QUERY(
auto *query = QUERY(SINGLE_QUERY(
MATCH(PATTERN(NODE("n"), EDGE("r", Direction::OUT), NODE("m"))),
OPTIONAL_MATCH(PATTERN(NODE("m"), EDGE("e", Direction::OUT), NODE("l"))),
RETURN("n", "l")));
// We have 2 nodes `n` and `m` from which we could start the MATCH, and 2
// nodes for OPTIONAL MATCH. This should produce 2 * 2 plans.
CheckPlansProduce(4, storage, *dba, [&](const auto &results) {
CheckPlansProduce(4, query, storage, *dba, [&](const auto &results) {
// We expect to produce 2 rows:
// * (v1), (v3)
// * (v2), null
@ -176,14 +177,14 @@ TEST(TestVariableStartPlanner, MatchOptionalMatchMergeReturn) {
// Test MATCH (n) -[r]-> (m) OPTIONAL MATCH (m) -[e]-> (l)
// MERGE (u) -[q:r]-> (v) RETURN n, m, l, u, v
AstStorage storage;
QUERY(SINGLE_QUERY(
auto *query = QUERY(SINGLE_QUERY(
MATCH(PATTERN(NODE("n"), EDGE("r", Direction::OUT), NODE("m"))),
OPTIONAL_MATCH(PATTERN(NODE("m"), EDGE("e", Direction::OUT), NODE("l"))),
MERGE(PATTERN(NODE("u"), EDGE("q", Direction::OUT, {r_type}), NODE("v"))),
RETURN("n", "m", "l", "u", "v")));
// Since MATCH, OPTIONAL MATCH and MERGE each have 2 nodes from which we can
// start, we generate 2 * 2 * 2 plans.
CheckPlansProduce(8, storage, *dba, [&](const auto &results) {
CheckPlansProduce(8, query, storage, *dba, [&](const auto &results) {
// We expect to produce a single row: (v1), (v2), null, (v1), (v2)
AssertRows(results, {{v1, v2, TypedValue::Null, v1, v2}});
});
@ -199,14 +200,14 @@ TEST(TestVariableStartPlanner, MatchWithMatchReturn) {
dba->AdvanceCommand();
// Test MATCH (n) -[r]-> (m) WITH n MATCH (m) -[r]-> (l) RETURN n, m, l
AstStorage storage;
QUERY(SINGLE_QUERY(
auto *query = QUERY(SINGLE_QUERY(
MATCH(PATTERN(NODE("n"), EDGE("r", Direction::OUT), NODE("m"))),
WITH("n"),
MATCH(PATTERN(NODE("m"), EDGE("r", Direction::OUT), NODE("l"))),
RETURN("n", "m", "l")));
// We can start from 2 nodes in each match. Since WITH separates query parts,
// we expect to get 2 plans for each, which totals 2 * 2.
CheckPlansProduce(4, storage, *dba, [&](const auto &results) {
CheckPlansProduce(4, query, storage, *dba, [&](const auto &results) {
// We expect to produce a single row: (v1), (v1), (v2)
AssertRows(results, {{v1, v1, v2}});
});
@ -225,12 +226,13 @@ TEST(TestVariableStartPlanner, MatchVariableExpand) {
// Test MATCH (n) -[r*]-> (m) RETURN r
AstStorage storage;
auto edge = EDGE_VARIABLE("r", Direction::OUT);
QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r")));
auto *query = QUERY(
SINGLE_QUERY(MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r")));
// We expect to get a single column with the following rows:
TypedValue r1_list(std::vector<TypedValue>{r1}); // [r1]
TypedValue r2_list(std::vector<TypedValue>{r2}); // [r2]
TypedValue r1_r2_list(std::vector<TypedValue>{r1, r2}); // [r1, r2]
CheckPlansProduce(2, storage, *dba, [&](const auto &results) {
CheckPlansProduce(2, query, storage, *dba, [&](const auto &results) {
AssertRows(results, {{r1_list}, {r2_list}, {r1_r2_list}});
});
}
@ -254,11 +256,12 @@ TEST(TestVariableStartPlanner, MatchVariableExpandReferenceNode) {
AstStorage storage;
auto edge = EDGE_VARIABLE("r", Direction::OUT);
edge->upper_bound_ = PROPERTY_LOOKUP("n", id);
QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r")));
auto *query = QUERY(
SINGLE_QUERY(MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r")));
// We expect to get a single column with the following rows:
TypedValue r1_list(std::vector<TypedValue>{r1}); // [r1] (v1 -[*..1]-> v2)
TypedValue r2_list(std::vector<TypedValue>{r2}); // [r2] (v2 -[*..2]-> v3)
CheckPlansProduce(2, storage, dba, [&](const auto &results) {
CheckPlansProduce(2, query, storage, dba, [&](const auto &results) {
AssertRows(results, {{r1_list}, {r2_list}});
});
}
@ -280,11 +283,12 @@ TEST(TestVariableStartPlanner, MatchVariableExpandBoth) {
auto edge = EDGE_VARIABLE("r", Direction::BOTH);
auto node_n = NODE("n");
node_n->properties_[std::make_pair("id", id)] = LITERAL(1);
QUERY(SINGLE_QUERY(MATCH(PATTERN(node_n, edge, NODE("m"))), RETURN("r")));
auto *query =
QUERY(SINGLE_QUERY(MATCH(PATTERN(node_n, edge, NODE("m"))), RETURN("r")));
// We expect to get a single column with the following rows:
TypedValue r1_list(std::vector<TypedValue>{r1}); // [r1]
TypedValue r1_r2_list(std::vector<TypedValue>{r1, r2}); // [r1, r2]
CheckPlansProduce(2, storage, *dba, [&](const auto &results) {
CheckPlansProduce(2, query, storage, *dba, [&](const auto &results) {
AssertRows(results, {{r1_list}, {r1_r2_list}});
});
}
@ -313,10 +317,11 @@ TEST(TestVariableStartPlanner, MatchBfs) {
bfs->filter_lambda_.inner_node = IDENT("n");
bfs->filter_lambda_.expression = NEQ(PROPERTY_LOOKUP("n", id), LITERAL(3));
bfs->upper_bound_ = LITERAL(10);
QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), bfs, NODE("m"))), RETURN("r")));
auto *query = QUERY(
SINGLE_QUERY(MATCH(PATTERN(NODE("n"), bfs, NODE("m"))), RETURN("r")));
// We expect to get a single column with the following rows:
TypedValue r1_list(std::vector<TypedValue>{r1}); // [r1]
CheckPlansProduce(2, storage, dba, [&](const auto &results) {
CheckPlansProduce(2, query, storage, dba, [&](const auto &results) {
AssertRows(results, {{r1_list}});
});
}