add support for set and remove labels clauses

This commit is contained in:
DavIvek 2024-02-29 14:41:04 +01:00
parent 003d08c24a
commit 147a0e545a
6 changed files with 113 additions and 23 deletions

View File

@ -2632,20 +2632,30 @@ class SetLabels : public memgraph::query::Clause {
}
memgraph::query::Identifier *identifier_{nullptr};
std::vector<memgraph::query::LabelIx> labels_;
std::vector<std::variant<memgraph::query::LabelIx, memgraph::query::Expression *>> labels_;
SetLabels *Clone(AstStorage *storage) const override {
SetLabels *object = storage->Create<SetLabels>();
object->identifier_ = identifier_ ? identifier_->Clone(storage) : nullptr;
object->labels_.resize(labels_.size());
for (auto i = 0; i < object->labels_.size(); ++i) {
object->labels_[i] = storage->GetLabelIx(labels_[i].name);
if (const auto *label = std::get_if<LabelIx>(&labels_[i])) {
object->labels_[i] = storage->GetLabelIx(label->name);
} else {
object->labels_[i] = std::get<Expression *>(labels_[i])->Clone(storage);
}
}
return object;
}
protected:
SetLabels(Identifier *identifier, const std::vector<LabelIx> &labels) : identifier_(identifier), labels_(labels) {}
SetLabels(Identifier *identifier, const std::vector<std::variant<LabelIx, Expression *>> &labels)
: identifier_(identifier), labels_(labels) {}
SetLabels(Identifier *identifier, std::vector<LabelIx> labels) : identifier_(identifier) {
for (auto &label : labels) {
labels_.emplace_back(label);
}
}
private:
friend class AstStorage;
@ -2695,20 +2705,30 @@ class RemoveLabels : public memgraph::query::Clause {
}
memgraph::query::Identifier *identifier_{nullptr};
std::vector<memgraph::query::LabelIx> labels_;
std::vector<std::variant<memgraph::query::LabelIx, memgraph::query::Expression *>> labels_;
RemoveLabels *Clone(AstStorage *storage) const override {
RemoveLabels *object = storage->Create<RemoveLabels>();
object->identifier_ = identifier_ ? identifier_->Clone(storage) : nullptr;
object->labels_.resize(labels_.size());
for (auto i = 0; i < object->labels_.size(); ++i) {
object->labels_[i] = storage->GetLabelIx(labels_[i].name);
if (const auto *label = std::get_if<LabelIx>(&labels_[i])) {
object->labels_[i] = storage->GetLabelIx(label->name);
} else {
object->labels_[i] = std::get<Expression *>(labels_[i])->Clone(storage);
}
}
return object;
}
protected:
RemoveLabels(Identifier *identifier, const std::vector<LabelIx> &labels) : identifier_(identifier), labels_(labels) {}
RemoveLabels(Identifier *identifier, const std::vector<std::variant<LabelIx, Expression *>> &labels)
: identifier_(identifier), labels_(labels) {}
RemoveLabels(Identifier *identifier, std::vector<LabelIx> labels) : identifier_(identifier) {
for (auto &label : labels) {
labels_.emplace_back(label);
}
}
private:
friend class AstStorage;

View File

@ -2814,7 +2814,8 @@ antlrcpp::Any CypherMainVisitor::visitSetItem(MemgraphCypher::SetItemContext *ct
// SetLabels
auto *set_labels = storage_->Create<SetLabels>();
set_labels->identifier_ = storage_->Create<Identifier>(std::any_cast<std::string>(ctx->variable()->accept(this)));
set_labels->labels_ = std::any_cast<std::vector<LabelIx>>(ctx->nodeLabels()->accept(this));
set_labels->labels_ =
std::any_cast<std::vector<std::variant<LabelIx, Expression *>>>(ctx->nodeLabels()->accept(this));
return static_cast<Clause *>(set_labels);
}
@ -2837,7 +2838,8 @@ antlrcpp::Any CypherMainVisitor::visitRemoveItem(MemgraphCypher::RemoveItemConte
// RemoveLabels
auto *remove_labels = storage_->Create<RemoveLabels>();
remove_labels->identifier_ = storage_->Create<Identifier>(std::any_cast<std::string>(ctx->variable()->accept(this)));
remove_labels->labels_ = std::any_cast<std::vector<LabelIx>>(ctx->nodeLabels()->accept(this));
remove_labels->labels_ =
std::any_cast<std::vector<std::variant<LabelIx, Expression *>>>(ctx->nodeLabels()->accept(this));
return static_cast<Clause *>(remove_labels);
}

View File

@ -3142,9 +3142,18 @@ void SetProperties::SetPropertiesCursor::Shutdown() { input_cursor_->Shutdown();
void SetProperties::SetPropertiesCursor::Reset() { input_cursor_->Reset(); }
SetLabels::SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<storage::LabelId> &labels)
const std::vector<std::variant<storage::LabelId, query::Expression *>> &labels)
: input_(input), input_symbol_(std::move(input_symbol)), labels_(labels) {}
SetLabels::SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<storage::LabelId> &labels)
: input_(input), input_symbol_(std::move(input_symbol)) {
this->labels_.reserve(labels.size());
for (const auto &label : labels) {
this->labels_.emplace_back(label);
}
}
ACCEPT_WITH_INPUT(SetLabels)
UniqueCursorPtr SetLabels::MakeCursor(utils::MemoryResource *mem) const {
@ -3163,10 +3172,21 @@ SetLabels::SetLabelsCursor::SetLabelsCursor(const SetLabels &self, utils::Memory
bool SetLabels::SetLabelsCursor::Pull(Frame &frame, ExecutionContext &context) {
OOMExceptionEnabler oom_exception;
SCOPED_PROFILE_OP("SetLabels");
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
storage::View::NEW);
std::vector<storage::LabelId> labels;
for (const auto &label : self_.labels_) {
if (std::holds_alternative<storage::LabelId>(label)) {
labels.push_back(std::get<storage::LabelId>(label));
} else {
labels.push_back(
context.db_accessor->NameToLabel(std::get<query::Expression *>(label)->Accept(evaluator).ValueString()));
}
}
#ifdef MG_ENTERPRISE
if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
!context.auth_checker->Has(self_.labels_, memgraph::query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE)) {
!context.auth_checker->Has(labels, memgraph::query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE)) {
throw QueryRuntimeException("Couldn't set label due to not having enough permission!");
}
#endif
@ -3187,7 +3207,7 @@ bool SetLabels::SetLabelsCursor::Pull(Frame &frame, ExecutionContext &context) {
}
#endif
for (auto label : self_.labels_) {
for (auto label : labels) {
auto maybe_value = vertex.AddLabel(label);
if (maybe_value.HasError()) {
switch (maybe_value.GetError()) {
@ -3302,9 +3322,18 @@ void RemoveProperty::RemovePropertyCursor::Shutdown() { input_cursor_->Shutdown(
void RemoveProperty::RemovePropertyCursor::Reset() { input_cursor_->Reset(); }
RemoveLabels::RemoveLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<storage::LabelId> &labels)
const std::vector<std::variant<storage::LabelId, query::Expression *>> &labels)
: input_(input), input_symbol_(std::move(input_symbol)), labels_(labels) {}
RemoveLabels::RemoveLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<storage::LabelId> &labels)
: input_(input), input_symbol_(std::move(input_symbol)) {
this->labels_.reserve(labels.size());
for (const auto &label : labels) {
this->labels_.push_back(label);
}
}
ACCEPT_WITH_INPUT(RemoveLabels)
UniqueCursorPtr RemoveLabels::MakeCursor(utils::MemoryResource *mem) const {
@ -3323,10 +3352,21 @@ RemoveLabels::RemoveLabelsCursor::RemoveLabelsCursor(const RemoveLabels &self, u
bool RemoveLabels::RemoveLabelsCursor::Pull(Frame &frame, ExecutionContext &context) {
OOMExceptionEnabler oom_exception;
SCOPED_PROFILE_OP("RemoveLabels");
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
storage::View::NEW);
std::vector<storage::LabelId> labels;
for (const auto &label : self_.labels_) {
if (std::holds_alternative<storage::LabelId>(label)) {
labels.push_back(std::get<storage::LabelId>(label));
} else {
labels.push_back(
context.db_accessor->NameToLabel(std::get<query::Expression *>(label)->Accept(evaluator).ValueString()));
}
}
#ifdef MG_ENTERPRISE
if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
!context.auth_checker->Has(self_.labels_, memgraph::query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE)) {
!context.auth_checker->Has(labels, memgraph::query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE)) {
throw QueryRuntimeException("Couldn't remove label due to not having enough permission!");
}
#endif
@ -3347,7 +3387,7 @@ bool RemoveLabels::RemoveLabelsCursor::Pull(Frame &frame, ExecutionContext &cont
}
#endif
for (auto label : self_.labels_) {
for (auto label : labels) {
auto maybe_value = vertex.RemoveLabel(label);
if (maybe_value.HasError()) {
switch (maybe_value.GetError()) {

View File

@ -1442,6 +1442,8 @@ class SetLabels : public memgraph::query::plan::LogicalOperator {
SetLabels() = default;
SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<std::variant<storage::LabelId, query::Expression *>> &labels);
SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<storage::LabelId> &labels);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
@ -1454,7 +1456,7 @@ class SetLabels : public memgraph::query::plan::LogicalOperator {
std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
Symbol input_symbol_;
std::vector<storage::LabelId> labels_;
std::vector<std::variant<storage::LabelId, query::Expression *>> labels_;
std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
auto object = std::make_unique<SetLabels>();
@ -1531,6 +1533,8 @@ class RemoveLabels : public memgraph::query::plan::LogicalOperator {
RemoveLabels() = default;
RemoveLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<std::variant<storage::LabelId, query::Expression *>> &labels);
RemoveLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<storage::LabelId> &labels);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
@ -1543,7 +1547,7 @@ class RemoveLabels : public memgraph::query::plan::LogicalOperator {
std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
Symbol input_symbol_;
std::vector<storage::LabelId> labels_;
std::vector<std::variant<storage::LabelId, query::Expression *>> labels_;
std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
auto object = std::make_unique<RemoveLabels>();

View File

@ -634,8 +634,15 @@ bool PlanToJsonVisitor::PreVisit(SetLabels &op) {
json self;
self["name"] = "SetLabels";
self["input_symbol"] = ToJson(op.input_symbol_);
self["labels"] = ToJson(op.labels_, *dba_);
std::vector<storage::LabelId> labels;
for (auto label : op.labels_) {
if (const auto *label_node = std::get_if<Expression *>(&label)) {
labels = {};
break;
}
labels.push_back(std::get<storage::LabelId>(label));
}
self["labels"] = ToJson(labels, *dba_);
op.input_->Accept(*this);
self["input"] = PopOutput();
@ -660,7 +667,16 @@ bool PlanToJsonVisitor::PreVisit(RemoveLabels &op) {
json self;
self["name"] = "RemoveLabels";
self["input_symbol"] = ToJson(op.input_symbol_);
self["labels"] = ToJson(op.labels_, *dba_);
// not a solution, have to fix it
std::vector<storage::LabelId> labels;
for (auto label : op.labels_) {
if (const auto *label_node = std::get_if<Expression *>(&label)) {
labels = {};
break;
}
labels.push_back(std::get<storage::LabelId>(label));
}
self["labels"] = ToJson(labels, *dba_);
op.input_->Accept(*this);
self["input"] = PopOutput();

View File

@ -414,10 +414,14 @@ class RuleBasedPlanner {
return std::make_unique<plan::SetProperties>(std::move(input_op), input_symbol, set->expression_, op);
} else if (auto *set = utils::Downcast<query::SetLabels>(clause)) {
const auto &input_symbol = symbol_table.at(*set->identifier_);
std::vector<storage::LabelId> labels;
std::vector<std::variant<storage::LabelId, query::Expression *>> labels;
labels.reserve(set->labels_.size());
for (const auto &label : set->labels_) {
labels.push_back(GetLabel(label));
if (const auto *label_atom = std::get_if<LabelIx>(&label)) {
labels.emplace_back(GetLabel(*label_atom));
} else {
labels.emplace_back(std::get<query::Expression *>(label));
}
}
return std::make_unique<plan::SetLabels>(std::move(input_op), input_symbol, labels);
} else if (auto *rem = utils::Downcast<query::RemoveProperty>(clause)) {
@ -425,10 +429,14 @@ class RuleBasedPlanner {
rem->property_lookup_);
} else if (auto *rem = utils::Downcast<query::RemoveLabels>(clause)) {
const auto &input_symbol = symbol_table.at(*rem->identifier_);
std::vector<storage::LabelId> labels;
std::vector<std::variant<storage::LabelId, query::Expression *>> labels;
labels.reserve(rem->labels_.size());
for (const auto &label : rem->labels_) {
labels.push_back(GetLabel(label));
if (const auto *label_atom = std::get_if<LabelIx>(&label)) {
labels.emplace_back(GetLabel(*label_atom));
} else {
labels.emplace_back(std::get<query::Expression *>(label));
}
}
return std::make_unique<plan::RemoveLabels>(std::move(input_op), input_symbol, labels);
}