diff --git a/src/query/frontend/ast/ast.lcp b/src/query/frontend/ast/ast.lcp index 577d15113..a1f3c8ef5 100644 --- a/src/query/frontend/ast/ast.lcp +++ b/src/query/frontend/ast/ast.lcp @@ -301,6 +301,7 @@ cpp<# (lcp:define-class expression (tree "::utils::Visitable<HierarchicalTreeVisitor>" "::utils::Visitable<ExpressionVisitor<TypedValue>>" + "::utils::Visitable<ExpressionVisitor<TypedValue*>>" "::utils::Visitable<ExpressionVisitor<void>>") () (:abstractp t) @@ -308,6 +309,7 @@ cpp<# #>cpp using utils::Visitable<HierarchicalTreeVisitor>::Accept; using utils::Visitable<ExpressionVisitor<TypedValue>>::Accept; + using utils::Visitable<ExpressionVisitor<TypedValue*>>::Accept; using utils::Visitable<ExpressionVisitor<void>>::Accept; Expression() = default; @@ -407,6 +409,7 @@ cpp<# (:public #>cpp DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); bool Accept(HierarchicalTreeVisitor &visitor) override { if (visitor.PreVisit(*this)) { @@ -438,6 +441,7 @@ cpp<# (:public #>cpp DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); bool Accept(HierarchicalTreeVisitor &visitor) override { if (visitor.PreVisit(*this)) { @@ -485,6 +489,7 @@ cpp<# } DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); bool Accept(HierarchicalTreeVisitor &visitor) override { if (visitor.PreVisit(*this)) { @@ -538,6 +543,7 @@ cpp<# ListSlicingOperator() = default; DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); bool Accept(HierarchicalTreeVisitor &visitor) override { if (visitor.PreVisit(*this)) { @@ -581,6 +587,7 @@ cpp<# IfOperator() = default; DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); bool Accept(HierarchicalTreeVisitor &visitor) override { if (visitor.PreVisit(*this)) { @@ -628,6 +635,7 @@ cpp<# PrimitiveLiteral() = default; DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); DEFVISITABLE(HierarchicalTreeVisitor); cpp<#) @@ -656,6 +664,7 @@ cpp<# ListLiteral() = default; DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); bool Accept(HierarchicalTreeVisitor &visitor) override { if (visitor.PreVisit(*this)) { @@ -688,6 +697,7 @@ cpp<# MapLiteral() = default; DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); bool Accept(HierarchicalTreeVisitor &visitor) override { if (visitor.PreVisit(*this)) { @@ -720,6 +730,7 @@ cpp<# Identifier() = default; DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); DEFVISITABLE(HierarchicalTreeVisitor); @@ -757,6 +768,7 @@ cpp<# PropertyLookup() = default; DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); bool Accept(HierarchicalTreeVisitor &visitor) override { if (visitor.PreVisit(*this)) { @@ -797,6 +809,7 @@ cpp<# LabelsTest() = default; DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); bool Accept(HierarchicalTreeVisitor &visitor) override { if (visitor.PreVisit(*this)) { @@ -837,6 +850,7 @@ cpp<# Function() = default; DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); bool Accept(HierarchicalTreeVisitor &visitor) override { if (visitor.PreVisit(*this)) { @@ -892,6 +906,7 @@ cpp<# Reduce() = default; DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); bool Accept(HierarchicalTreeVisitor &visitor) override { if (visitor.PreVisit(*this)) { @@ -930,6 +945,7 @@ cpp<# Coalesce() = default; DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); bool Accept(HierarchicalTreeVisitor &visitor) override { if (visitor.PreVisit(*this)) { @@ -969,6 +985,7 @@ cpp<# Extract() = default; DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); bool Accept(HierarchicalTreeVisitor &visitor) override { if (visitor.PreVisit(*this)) { @@ -1005,6 +1022,7 @@ cpp<# All() = default; DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); bool Accept(HierarchicalTreeVisitor &visitor) override { if (visitor.PreVisit(*this)) { @@ -1046,6 +1064,7 @@ cpp<# Single() = default; DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); bool Accept(HierarchicalTreeVisitor &visitor) override { if (visitor.PreVisit(*this)) { @@ -1087,6 +1106,7 @@ cpp<# Any() = default; DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); bool Accept(HierarchicalTreeVisitor &visitor) override { if (visitor.PreVisit(*this)) { @@ -1128,6 +1148,7 @@ cpp<# None() = default; DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); bool Accept(HierarchicalTreeVisitor &visitor) override { if (visitor.PreVisit(*this)) { @@ -1159,6 +1180,7 @@ cpp<# ParameterLookup() = default; DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); DEFVISITABLE(HierarchicalTreeVisitor); cpp<#) @@ -1186,6 +1208,7 @@ cpp<# RegexMatch() = default; DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); bool Accept(HierarchicalTreeVisitor &visitor) override { if (visitor.PreVisit(*this)) { @@ -1205,6 +1228,7 @@ cpp<# (lcp:define-class named-expression (tree "::utils::Visitable<HierarchicalTreeVisitor>" "::utils::Visitable<ExpressionVisitor<TypedValue>>" + "::utils::Visitable<ExpressionVisitor<TypedValue*>>" "::utils::Visitable<ExpressionVisitor<void>>") ((name "std::string" :scope :public) (expression "Expression *" :initval "nullptr" :scope :public @@ -1223,6 +1247,7 @@ cpp<# NamedExpression() = default; DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue*>); DEFVISITABLE(ExpressionVisitor<void>); bool Accept(HierarchicalTreeVisitor &visitor) override { if (visitor.PreVisit(*this)) { diff --git a/src/query/interpret/eval.hpp b/src/query/interpret/eval.hpp index 1ea20846d..74df223a5 100644 --- a/src/query/interpret/eval.hpp +++ b/src/query/interpret/eval.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -31,6 +31,71 @@ namespace memgraph::query { +class ReferenceExpressionEvaluator : public ExpressionVisitor<TypedValue *> { + public: + ReferenceExpressionEvaluator(Frame *frame, const SymbolTable *symbol_table, const EvaluationContext *ctx) + : frame_(frame), symbol_table_(symbol_table), ctx_(ctx) {} + + using ExpressionVisitor<TypedValue *>::Visit; + + utils::MemoryResource *GetMemoryResource() const { return ctx_->memory; } + +#define UNSUCCESSFUL_VISIT(expr_name) \ + TypedValue *Visit(expr_name &expr) override { return nullptr; } + + TypedValue *Visit(Identifier &ident) override { return &frame_->at(symbol_table_->at(ident)); } + + UNSUCCESSFUL_VISIT(NamedExpression); + UNSUCCESSFUL_VISIT(OrOperator); + UNSUCCESSFUL_VISIT(XorOperator); + UNSUCCESSFUL_VISIT(AdditionOperator); + UNSUCCESSFUL_VISIT(SubtractionOperator); + UNSUCCESSFUL_VISIT(MultiplicationOperator); + UNSUCCESSFUL_VISIT(DivisionOperator); + UNSUCCESSFUL_VISIT(ModOperator); + UNSUCCESSFUL_VISIT(NotEqualOperator); + UNSUCCESSFUL_VISIT(EqualOperator); + UNSUCCESSFUL_VISIT(LessOperator); + UNSUCCESSFUL_VISIT(GreaterOperator); + UNSUCCESSFUL_VISIT(LessEqualOperator); + UNSUCCESSFUL_VISIT(GreaterEqualOperator); + + UNSUCCESSFUL_VISIT(NotOperator); + UNSUCCESSFUL_VISIT(UnaryPlusOperator); + UNSUCCESSFUL_VISIT(UnaryMinusOperator); + + UNSUCCESSFUL_VISIT(AndOperator); + UNSUCCESSFUL_VISIT(IfOperator); + UNSUCCESSFUL_VISIT(InListOperator); + + UNSUCCESSFUL_VISIT(SubscriptOperator); + + UNSUCCESSFUL_VISIT(ListSlicingOperator); + UNSUCCESSFUL_VISIT(IsNullOperator); + UNSUCCESSFUL_VISIT(PropertyLookup); + UNSUCCESSFUL_VISIT(LabelsTest); + + UNSUCCESSFUL_VISIT(PrimitiveLiteral); + UNSUCCESSFUL_VISIT(ListLiteral); + UNSUCCESSFUL_VISIT(MapLiteral); + UNSUCCESSFUL_VISIT(Aggregation); + UNSUCCESSFUL_VISIT(Coalesce); + UNSUCCESSFUL_VISIT(Function); + UNSUCCESSFUL_VISIT(Reduce); + UNSUCCESSFUL_VISIT(Extract); + UNSUCCESSFUL_VISIT(All); + UNSUCCESSFUL_VISIT(Single); + UNSUCCESSFUL_VISIT(Any); + UNSUCCESSFUL_VISIT(None); + UNSUCCESSFUL_VISIT(ParameterLookup); + UNSUCCESSFUL_VISIT(RegexMatch); + + private: + Frame *frame_; + const SymbolTable *symbol_table_; + const EvaluationContext *ctx_; +}; + class ExpressionEvaluator : public ExpressionVisitor<TypedValue> { public: ExpressionEvaluator(Frame *frame, const SymbolTable &symbol_table, const EvaluationContext &ctx, DbAccessor *dba, @@ -159,50 +224,53 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> { } TypedValue Visit(SubscriptOperator &list_indexing) override { - auto lhs = list_indexing.expression1_->Accept(*this); + ReferenceExpressionEvaluator referenceExpressionEvaluator(frame_, symbol_table_, ctx_); + + TypedValue *lhs_ptr = list_indexing.expression1_->Accept(referenceExpressionEvaluator); + TypedValue lhs; + const auto referenced = nullptr != lhs_ptr; + if (!referenced) { + lhs = list_indexing.expression1_->Accept(*this); + lhs_ptr = &lhs; + } auto index = list_indexing.expression2_->Accept(*this); - if (!lhs.IsList() && !lhs.IsMap() && !lhs.IsVertex() && !lhs.IsEdge() && !lhs.IsNull()) + if (!lhs_ptr->IsList() && !lhs_ptr->IsMap() && !lhs_ptr->IsVertex() && !lhs_ptr->IsEdge() && !lhs_ptr->IsNull()) throw QueryRuntimeException( "Expected a list, a map, a node or an edge to index with '[]', got " "{}.", - lhs.type()); - if (lhs.IsNull() || index.IsNull()) return TypedValue(ctx_->memory); - if (lhs.IsList()) { + lhs_ptr->type()); + if (lhs_ptr->IsNull() || index.IsNull()) return TypedValue(ctx_->memory); + if (lhs_ptr->IsList()) { if (!index.IsInt()) throw QueryRuntimeException("Expected an integer as a list index, got {}.", index.type()); auto index_int = index.ValueInt(); - // NOTE: Take non-const reference to list, so that we can move out the - // indexed element as the result. - auto &list = lhs.ValueList(); + auto &list = lhs_ptr->ValueList(); if (index_int < 0) { index_int += static_cast<int64_t>(list.size()); } if (index_int >= static_cast<int64_t>(list.size()) || index_int < 0) return TypedValue(ctx_->memory); - // NOTE: Explicit move is needed, so that we return the move constructed - // value and preserve the correct MemoryResource. - return std::move(list[index_int]); + return referenced ? TypedValue(list[index_int], ctx_->memory) + : TypedValue(std::move(list[index_int]), ctx_->memory); } - if (lhs.IsMap()) { + if (lhs_ptr->IsMap()) { if (!index.IsString()) throw QueryRuntimeException("Expected a string as a map index, got {}.", index.type()); // NOTE: Take non-const reference to map, so that we can move out the // looked-up element as the result. - auto &map = lhs.ValueMap(); + auto &map = lhs_ptr->ValueMap(); auto found = map.find(index.ValueString()); if (found == map.end()) return TypedValue(ctx_->memory); - // NOTE: Explicit move is needed, so that we return the move constructed - // value and preserve the correct MemoryResource. - return std::move(found->second); + return referenced ? TypedValue(found->second, ctx_->memory) : TypedValue(std::move(found->second), ctx_->memory); } - if (lhs.IsVertex()) { + if (lhs_ptr->IsVertex()) { if (!index.IsString()) throw QueryRuntimeException("Expected a string as a property name, got {}.", index.type()); - return TypedValue(GetProperty(lhs.ValueVertex(), index.ValueString()), ctx_->memory); + return {GetProperty(lhs_ptr->ValueVertex(), index.ValueString()), ctx_->memory}; } - if (lhs.IsEdge()) { + if (lhs_ptr->IsEdge()) { if (!index.IsString()) throw QueryRuntimeException("Expected a string as a property name, got {}.", index.type()); - return TypedValue(GetProperty(lhs.ValueEdge(), index.ValueString()), ctx_->memory); - } + return {GetProperty(lhs_ptr->ValueEdge(), index.ValueString()), ctx_->memory}; + }; // lhs is Null return TypedValue(ctx_->memory); @@ -258,7 +326,15 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> { } TypedValue Visit(PropertyLookup &property_lookup) override { - auto expression_result = property_lookup.expression_->Accept(*this); + ReferenceExpressionEvaluator referenceExpressionEvaluator(frame_, symbol_table_, ctx_); + + TypedValue *expression_result_ptr = property_lookup.expression_->Accept(referenceExpressionEvaluator); + TypedValue expression_result; + + if (nullptr == expression_result_ptr) { + expression_result = property_lookup.expression_->Accept(*this); + expression_result_ptr = &expression_result; + } auto maybe_date = [this](const auto &date, const auto &prop_name) -> std::optional<TypedValue> { if (prop_name == "year") { return TypedValue(date.year, ctx_->memory); @@ -332,42 +408,38 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> { } return std::nullopt; }; - switch (expression_result.type()) { + switch (expression_result_ptr->type()) { case TypedValue::Type::Null: return TypedValue(ctx_->memory); case TypedValue::Type::Vertex: - return TypedValue(GetProperty(expression_result.ValueVertex(), property_lookup.property_), ctx_->memory); + return TypedValue(GetProperty(expression_result_ptr->ValueVertex(), property_lookup.property_), ctx_->memory); case TypedValue::Type::Edge: - return TypedValue(GetProperty(expression_result.ValueEdge(), property_lookup.property_), ctx_->memory); + return TypedValue(GetProperty(expression_result_ptr->ValueEdge(), property_lookup.property_), ctx_->memory); case TypedValue::Type::Map: { - // NOTE: Take non-const reference to map, so that we can move out the - // looked-up element as the result. - auto &map = expression_result.ValueMap(); + auto &map = expression_result_ptr->ValueMap(); auto found = map.find(property_lookup.property_.name.c_str()); if (found == map.end()) return TypedValue(ctx_->memory); - // NOTE: Explicit move is needed, so that we return the move constructed - // value and preserve the correct MemoryResource. - return std::move(found->second); + return TypedValue(found->second, ctx_->memory); } case TypedValue::Type::Duration: { const auto &prop_name = property_lookup.property_.name; - const auto &dur = expression_result.ValueDuration(); + const auto &dur = expression_result_ptr->ValueDuration(); if (auto dur_field = maybe_duration(dur, prop_name); dur_field) { - return std::move(*dur_field); + return TypedValue(*dur_field, ctx_->memory); } throw QueryRuntimeException("Invalid property name {} for Duration", prop_name); } case TypedValue::Type::Date: { const auto &prop_name = property_lookup.property_.name; - const auto &date = expression_result.ValueDate(); + const auto &date = expression_result_ptr->ValueDate(); if (auto date_field = maybe_date(date, prop_name); date_field) { - return std::move(*date_field); + return TypedValue(*date_field, ctx_->memory); } throw QueryRuntimeException("Invalid property name {} for Date", prop_name); } case TypedValue::Type::LocalTime: { const auto &prop_name = property_lookup.property_.name; - const auto < = expression_result.ValueLocalTime(); + const auto < = expression_result_ptr->ValueLocalTime(); if (auto lt_field = maybe_local_time(lt, prop_name); lt_field) { return std::move(*lt_field); } @@ -375,20 +447,20 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> { } case TypedValue::Type::LocalDateTime: { const auto &prop_name = property_lookup.property_.name; - const auto &ldt = expression_result.ValueLocalDateTime(); + const auto &ldt = expression_result_ptr->ValueLocalDateTime(); if (auto date_field = maybe_date(ldt.date, prop_name); date_field) { return std::move(*date_field); } if (auto lt_field = maybe_local_time(ldt.local_time, prop_name); lt_field) { - return std::move(*lt_field); + return TypedValue(*lt_field, ctx_->memory); } throw QueryRuntimeException("Invalid property name {} for LocalDateTime", prop_name); } case TypedValue::Type::Graph: { const auto &prop_name = property_lookup.property_.name; - const auto &graph = expression_result.ValueGraph(); + const auto &graph = expression_result_ptr->ValueGraph(); if (auto graph_field = maybe_graph(graph, prop_name); graph_field) { - return std::move(*graph_field); + return TypedValue(*graph_field, ctx_->memory); } throw QueryRuntimeException("Invalid property name {} for Graph", prop_name); } diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index 6e61e07d6..f1ae0d652 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -26,6 +26,7 @@ #include "query/interpret/eval.hpp" #include "query/interpret/frame.hpp" #include "query/path.hpp" +#include "query/typed_value.hpp" #include "storage/v2/storage.hpp" #include "utils/exceptions.hpp" #include "utils/string.hpp" @@ -426,6 +427,47 @@ TEST_F(ExpressionEvaluatorTest, VertexAndEdgeIndexing) { } } +TEST_F(ExpressionEvaluatorTest, TypedValueListIndexing) { + auto list_vector = memgraph::utils::pmr::vector<TypedValue>(ctx.memory); + list_vector.emplace_back("string1"); + list_vector.emplace_back(TypedValue("string2")); + + auto *identifier = storage.Create<Identifier>("n"); + auto node_symbol = symbol_table.CreateSymbol("n", true); + identifier->MapTo(node_symbol); + frame[node_symbol] = TypedValue(list_vector, ctx.memory); + + { + // Legal indexing. + auto *op = storage.Create<SubscriptOperator>(identifier, storage.Create<PrimitiveLiteral>(0)); + auto value = Eval(op); + EXPECT_EQ(value.ValueString(), "string1"); + } + { + // Out of bounds indexing + auto *op = storage.Create<SubscriptOperator>(identifier, storage.Create<PrimitiveLiteral>(3)); + auto value = Eval(op); + EXPECT_TRUE(value.IsNull()); + } + { + // Out of bounds indexing with negative bound. + auto *op = storage.Create<SubscriptOperator>(identifier, storage.Create<PrimitiveLiteral>(-100)); + auto value = Eval(op); + EXPECT_TRUE(value.IsNull()); + } + { + // Legal indexing with negative index. + auto *op = storage.Create<SubscriptOperator>(identifier, storage.Create<PrimitiveLiteral>(-2)); + auto value = Eval(op); + EXPECT_EQ(value.ValueString(), "string1"); + } + { + // Indexing with incompatible type. + auto *op = storage.Create<SubscriptOperator>(identifier, storage.Create<PrimitiveLiteral>("bla")); + EXPECT_THROW(Eval(op), QueryRuntimeException); + } +} + TEST_F(ExpressionEvaluatorTest, ListSlicingOperator) { auto *list_literal = storage.Create<ListLiteral>( std::vector<Expression *>{storage.Create<PrimitiveLiteral>(1), storage.Create<PrimitiveLiteral>(2),