This commit is contained in:
DavIvek 2024-02-29 15:57:32 +01:00
parent 147a0e545a
commit 92e6c323c8
3 changed files with 18 additions and 7 deletions

View File

@ -1280,6 +1280,17 @@ class LabelsTest : public memgraph::query::Expression {
protected: protected:
LabelsTest(Expression *expression, const std::vector<LabelIx> &labels) : expression_(expression), labels_(labels) {} LabelsTest(Expression *expression, const std::vector<LabelIx> &labels) : expression_(expression), labels_(labels) {}
LabelsTest(Expression *expression, const std::vector<std::variant<LabelIx, Expression *>> &labels)
: expression_(expression) {
labels_.reserve(labels.size());
for (auto &label : labels) {
if (std::holds_alternative<LabelIx>(label)) {
labels_.push_back(std::get<LabelIx>(label));
} else {
throw SemanticException("You can't use expressions in labels test.");
}
}
}
private: private:
friend class AstStorage; friend class AstStorage;

View File

@ -2488,7 +2488,7 @@ antlrcpp::Any CypherMainVisitor::visitListIndexingOrSlicing(MemgraphCypher::List
antlrcpp::Any CypherMainVisitor::visitExpression2a(MemgraphCypher::Expression2aContext *ctx) { antlrcpp::Any CypherMainVisitor::visitExpression2a(MemgraphCypher::Expression2aContext *ctx) {
auto *expression = std::any_cast<Expression *>(ctx->expression2b()->accept(this)); auto *expression = std::any_cast<Expression *>(ctx->expression2b()->accept(this));
if (ctx->nodeLabels()) { if (ctx->nodeLabels()) {
auto labels = std::any_cast<std::vector<LabelIx>>(ctx->nodeLabels()->accept(this)); auto labels = std::any_cast<std::vector<std::variant<LabelIx, Expression *>>>(ctx->nodeLabels()->accept(this));
expression = storage_->Create<LabelsTest>(expression, labels); expression = storage_->Create<LabelsTest>(expression, labels);
} }
return expression; return expression;

View File

@ -380,8 +380,6 @@ bool CreateExpand::CreateExpandCursor::Pull(Frame &frame, ExecutionContext &cont
if (!input_cursor_->Pull(frame, context)) return false; if (!input_cursor_->Pull(frame, context)) return false;
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
storage::View::NEW); storage::View::NEW);
#ifdef MG_ENTERPRISE
std::vector<storage::LabelId> labels; std::vector<storage::LabelId> labels;
for (auto label : self_.node_info_.labels) { for (auto label : self_.node_info_.labels) {
if (const auto *label_atom = std::get_if<storage::LabelId>(&label)) { if (const auto *label_atom = std::get_if<storage::LabelId>(&label)) {
@ -391,6 +389,8 @@ bool CreateExpand::CreateExpandCursor::Pull(Frame &frame, ExecutionContext &cont
context.db_accessor->NameToLabel(std::get<Expression *>(label)->Accept(evaluator).ValueString())); context.db_accessor->NameToLabel(std::get<Expression *>(label)->Accept(evaluator).ValueString()));
} }
} }
#ifdef MG_ENTERPRISE
if (license::global_license_checker.IsEnterpriseValidFast()) { if (license::global_license_checker.IsEnterpriseValidFast()) {
const auto fine_grained_permission = self_.existing_node_ const auto fine_grained_permission = self_.existing_node_
? memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE ? memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE
@ -3148,9 +3148,9 @@ SetLabels::SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input
SetLabels::SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, SetLabels::SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<storage::LabelId> &labels) const std::vector<storage::LabelId> &labels)
: input_(input), input_symbol_(std::move(input_symbol)) { : input_(input), input_symbol_(std::move(input_symbol)) {
this->labels_.reserve(labels.size()); labels_.reserve(labels.size());
for (const auto &label : labels) { for (const auto &label : labels) {
this->labels_.emplace_back(label); labels_.emplace_back(label);
} }
} }
@ -3328,9 +3328,9 @@ RemoveLabels::RemoveLabels(const std::shared_ptr<LogicalOperator> &input, Symbol
RemoveLabels::RemoveLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, RemoveLabels::RemoveLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<storage::LabelId> &labels) const std::vector<storage::LabelId> &labels)
: input_(input), input_symbol_(std::move(input_symbol)) { : input_(input), input_symbol_(std::move(input_symbol)) {
this->labels_.reserve(labels.size()); labels_.reserve(labels.size());
for (const auto &label : labels) { for (const auto &label : labels) {
this->labels_.push_back(label); labels_.push_back(label);
} }
} }