diff --git a/src/query/frontend/ast/ast.cpp b/src/query/frontend/ast/ast.cpp index c5e4c84c4..6a9f05bad 100644 --- a/src/query/frontend/ast/ast.cpp +++ b/src/query/frontend/ast/ast.cpp @@ -293,4 +293,7 @@ constexpr utils::TypeInfo query::ShowDatabasesQuery::kType{utils::TypeId::AST_SH constexpr utils::TypeInfo query::EdgeImportModeQuery::kType{utils::TypeId::AST_EDGE_IMPORT_MODE_QUERY, "EdgeImportModeQuery", &query::Query::kType}; +constexpr utils::TypeInfo query::PatternComprehension::kType{utils::TypeId::AST_PATTERN_COMPREHENSION, + "PatternComprehension", &query::Expression::kType}; + } // namespace memgraph diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index 59860d5b0..b5e058491 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -3520,6 +3520,65 @@ class Exists : public memgraph::query::Expression { friend class AstStorage; }; +class PatternComprehension : public memgraph::query::Expression { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + PatternComprehension() = default; + + DEFVISITABLE(ExpressionVisitor); + DEFVISITABLE(ExpressionVisitor); + DEFVISITABLE(ExpressionVisitor); + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + if (variable_) { + variable_->Accept(visitor); + } + pattern_->Accept(visitor); + if (filter_) { + filter_->Accept(visitor); + } + resultExpr_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + PatternComprehension *MapTo(const Symbol &symbol) { + symbol_pos_ = symbol.position(); + return this; + } + + // The variable name. + Identifier *variable_{nullptr}; + // The pattern to match. + Pattern *pattern_{nullptr}; + // Optional WHERE clause for filtering. + Where *filter_{nullptr}; + // The projection expression. + Expression *resultExpr_{nullptr}; + + /// Symbol table position of the symbol this Aggregation is mapped to. + int32_t symbol_pos_{-1}; + + PatternComprehension *Clone(AstStorage *storage) const override { + PatternComprehension *object = storage->Create(); + object->pattern_ = pattern_ ? pattern_->Clone(storage) : nullptr; + object->filter_ = filter_ ? filter_->Clone(storage) : nullptr; + object->resultExpr_ = resultExpr_ ? resultExpr_->Clone(storage) : nullptr; + + object->symbol_pos_ = symbol_pos_; + return object; + } + + protected: + PatternComprehension(Identifier *variable, Pattern *pattern) : variable_(variable), pattern_(pattern) {} + + private: + friend class AstStorage; +}; + class CallSubquery : public memgraph::query::Clause { public: static const utils::TypeInfo kType; diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index 793c15a95..ff1586fe4 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -107,6 +107,7 @@ class Exists; class MultiDatabaseQuery; class ShowDatabasesQuery; class EdgeImportModeQuery; +class PatternComprehension; using TreeCompositeVisitor = utils::CompositeVisitor< SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator, @@ -116,7 +117,7 @@ using TreeCompositeVisitor = utils::CompositeVisitor< MapProjectionLiteral, PropertyLookup, AllPropertiesLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Any, None, CallProcedure, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, Merge, Unwind, RegexMatch, LoadCsv, - Foreach, Exists, CallSubquery, CypherQuery>; + Foreach, Exists, CallSubquery, CypherQuery, PatternComprehension>; using TreeLeafVisitor = utils::LeafVisitor; @@ -137,7 +138,7 @@ class ExpressionVisitor ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral, MapProjectionLiteral, PropertyLookup, AllPropertiesLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Any, None, - ParameterLookup, Identifier, PrimitiveLiteral, RegexMatch, Exists> {}; + ParameterLookup, Identifier, PrimitiveLiteral, RegexMatch, Exists, PatternComprehension> {}; template class QueryVisitor diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index b62c9e301..2d93fd757 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -1969,6 +1969,18 @@ antlrcpp::Any CypherMainVisitor::visitPatternElement(MemgraphCypher::PatternElem return pattern; } +antlrcpp::Any CypherMainVisitor::visitRelationshipsPattern(MemgraphCypher::RelationshipsPatternContext *ctx) { + auto *pattern = storage_->Create(); + pattern->atoms_.push_back(std::any_cast(ctx->nodePattern()->accept(this))); + for (auto *pattern_element_chain : ctx->patternElementChain()) { + auto element = std::any_cast>(pattern_element_chain->accept(this)); + pattern->atoms_.push_back(element.first); + pattern->atoms_.push_back(element.second); + } + anonymous_identifiers.push_back(&pattern->identifier_); + return pattern; +} + antlrcpp::Any CypherMainVisitor::visitPatternElementChain(MemgraphCypher::PatternElementChainContext *ctx) { return std::pair(std::any_cast(ctx->relationshipPattern()->accept(this)), std::any_cast(ctx->nodePattern()->accept(this))); @@ -2463,6 +2475,8 @@ antlrcpp::Any CypherMainVisitor::visitAtom(MemgraphCypher::AtomContext *ctx) { return static_cast(storage_->Create(ident, list, expr)); } else if (ctx->existsExpression()) { return std::any_cast(ctx->existsExpression()->accept(this)); + } else if (ctx->patternComprehension()) { + return std::any_cast(ctx->patternComprehension()->accept(this)); } // TODO: Implement this. We don't support comprehensions, filtering... at @@ -2523,6 +2537,19 @@ antlrcpp::Any CypherMainVisitor::visitExistsExpression(MemgraphCypher::ExistsExp return static_cast(exists); } +antlrcpp::Any CypherMainVisitor::visitPatternComprehension(MemgraphCypher::PatternComprehensionContext *ctx) { + auto *comprehension = storage_->Create(); + if (ctx->variable()) { + comprehension->variable_ = storage_->Create(std::any_cast(ctx->variable()->accept(this))); + } + comprehension->pattern_ = std::any_cast(ctx->relationshipsPattern()->accept(this)); + if (ctx->where()) { + comprehension->filter_ = std::any_cast(ctx->where()->accept(this)); + } + comprehension->resultExpr_ = std::any_cast(ctx->expression()->accept(this)); + return static_cast(comprehension); +} + antlrcpp::Any CypherMainVisitor::visitParenthesizedExpression(MemgraphCypher::ParenthesizedExpressionContext *ctx) { return std::any_cast(ctx->expression()->accept(this)); } diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index 7f75b0050..1aa887ad7 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -678,6 +678,11 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { */ antlrcpp::Any visitPatternElement(MemgraphCypher::PatternElementContext *ctx) override; + /** + * @return Pattern* + */ + antlrcpp::Any visitRelationshipsPattern(MemgraphCypher::RelationshipsPatternContext *ctx) override; + /** * @return vector> */ @@ -843,6 +848,11 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { */ antlrcpp::Any visitExistsExpression(MemgraphCypher::ExistsExpressionContext *ctx) override; + /** + * @return pattern comprehension (Expression) + */ + antlrcpp::Any visitPatternComprehension(MemgraphCypher::PatternComprehensionContext *ctx) override; + /** * @return Expression* */ diff --git a/src/query/frontend/ast/pretty_print.cpp b/src/query/frontend/ast/pretty_print.cpp index ef45afd7d..61bd23797 100644 --- a/src/query/frontend/ast/pretty_print.cpp +++ b/src/query/frontend/ast/pretty_print.cpp @@ -73,6 +73,7 @@ class ExpressionPrettyPrinter : public ExpressionVisitor { void Visit(ParameterLookup &op) override; void Visit(NamedExpression &op) override; void Visit(RegexMatch &op) override; + void Visit(PatternComprehension &op) override; private: std::ostream *out_; @@ -323,6 +324,10 @@ void ExpressionPrettyPrinter::Visit(NamedExpression &op) { void ExpressionPrettyPrinter::Visit(RegexMatch &op) { PrintOperator(out_, "=~", op.string_expr_, op.regex_); } +void ExpressionPrettyPrinter::Visit(PatternComprehension &op) { + PrintOperator(out_, "Pattern Comprehension", op.variable_, op.pattern_, op.filter_, op.resultExpr_); +} + } // namespace void PrintExpression(Expression *expr, std::ostream *out) { diff --git a/src/query/frontend/opencypher/grammar/Cypher.g4 b/src/query/frontend/opencypher/grammar/Cypher.g4 index bb435b85d..f4830ccef 100644 --- a/src/query/frontend/opencypher/grammar/Cypher.g4 +++ b/src/query/frontend/opencypher/grammar/Cypher.g4 @@ -296,7 +296,7 @@ functionName : symbolicName ( '.' symbolicName )* ; listComprehension : '[' filterExpression ( '|' expression )? ']' ; -patternComprehension : '[' ( variable '=' )? relationshipsPattern ( WHERE expression )? '|' expression ']' ; +patternComprehension : '[' ( variable '=' )? relationshipsPattern ( where )? '|' resultExpr=expression ']' ; propertyLookup : '.' ( propertyKeyName ) ; diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp index 207bbddbd..f9e6468f6 100644 --- a/src/query/frontend/semantic/symbol_generator.hpp +++ b/src/query/frontend/semantic/symbol_generator.hpp @@ -249,6 +249,7 @@ class PropertyLookupEvaluationModeVisitor : public ExpressionVisitor { void Visit(ParameterLookup &op) override{}; void Visit(NamedExpression &op) override { op.expression_->Accept(*this); }; void Visit(RegexMatch &op) override{}; + void Visit(PatternComprehension &op) override{}; void Visit(PropertyLookup & /*property_lookup*/) override; diff --git a/src/query/frontend/semantic/symbol_table.hpp b/src/query/frontend/semantic/symbol_table.hpp index 0b521356c..cf462c437 100644 --- a/src/query/frontend/semantic/symbol_table.hpp +++ b/src/query/frontend/semantic/symbol_table.hpp @@ -52,6 +52,9 @@ class SymbolTable final { const Symbol &at(const NamedExpression &nexpr) const { return table_.at(nexpr.symbol_pos_); } const Symbol &at(const Aggregation &aggr) const { return table_.at(aggr.symbol_pos_); } const Symbol &at(const Exists &exists) const { return table_.at(exists.symbol_pos_); } + const Symbol &at(const PatternComprehension &pattern_comprehension) const { + return table_.at(pattern_comprehension.symbol_pos_); + } // TODO: Remove these since members are public int32_t max_position() const { return static_cast(table_.size()); } diff --git a/src/query/interpret/eval.hpp b/src/query/interpret/eval.hpp index f4f3126cd..017dc9101 100644 --- a/src/query/interpret/eval.hpp +++ b/src/query/interpret/eval.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -101,6 +101,7 @@ class ReferenceExpressionEvaluator : public ExpressionVisitor { UNSUCCESSFUL_VISIT(ParameterLookup); UNSUCCESSFUL_VISIT(RegexMatch); UNSUCCESSFUL_VISIT(Exists); + UNSUCCESSFUL_VISIT(PatternComprehension); #undef UNSUCCESSFUL_VISIT @@ -170,6 +171,7 @@ class PrimitiveLiteralExpressionEvaluator : public ExpressionVisitor INVALID_VISIT(Identifier) INVALID_VISIT(RegexMatch) INVALID_VISIT(Exists) + INVALID_VISIT(PatternComprehension) #undef INVALID_VISIT private: @@ -1090,6 +1092,10 @@ class ExpressionEvaluator : public ExpressionVisitor { } } + TypedValue Visit(PatternComprehension & /*pattern_comprehension*/) override { + throw utils::NotYetImplemented("Expression evaluator can not handle pattern comprehension."); + } + private: template std::map GetAllProperties(const TRecordAccessor &record_accessor) { diff --git a/src/query/plan/preprocess.hpp b/src/query/plan/preprocess.hpp index 8e1955907..322da545a 100644 --- a/src/query/plan/preprocess.hpp +++ b/src/query/plan/preprocess.hpp @@ -230,6 +230,7 @@ class PatternFilterVisitor : public ExpressionVisitor { void Visit(ParameterLookup &op) override{}; void Visit(NamedExpression &op) override{}; void Visit(RegexMatch &op) override{}; + void Visit(PatternComprehension &op) override{}; std::vector getMatchings() { return matchings_; } diff --git a/src/query/plan/rule_based_planner.cpp b/src/query/plan/rule_based_planner.cpp index f3d0c1487..bf5e66158 100644 --- a/src/query/plan/rule_based_planner.cpp +++ b/src/query/plan/rule_based_planner.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -373,12 +373,12 @@ class ReturnBodyContext : public HierarchicalTreeVisitor { return true; } - bool Visit(ParameterLookup &) override { + bool Visit(ParameterLookup & /*unused*/) override { has_aggregation_.emplace_back(false); return true; } - bool PostVisit(RegexMatch ®ex_match) override { + bool PostVisit(RegexMatch & /*unused*/) override { MG_ASSERT(has_aggregation_.size() >= 2U, "Expected 2 has_aggregation_ flags for RegexMatch arguments"); bool has_aggr = has_aggregation_.back(); has_aggregation_.pop_back(); @@ -386,6 +386,10 @@ class ReturnBodyContext : public HierarchicalTreeVisitor { return true; } + bool PostVisit(PatternComprehension & /*unused*/) override { + throw utils::NotYetImplemented("Planner can not handle pattern comprehension."); + } + // Creates NamedExpression with an Identifier for each user declared symbol. // This should be used when body.all_identifiers is true, to generate // expressions for Produce operator. diff --git a/src/utils/typeinfo.hpp b/src/utils/typeinfo.hpp index 682b5ac55..944d35fab 100644 --- a/src/utils/typeinfo.hpp +++ b/src/utils/typeinfo.hpp @@ -190,6 +190,7 @@ enum class TypeId : uint64_t { AST_MULTI_DATABASE_QUERY, AST_SHOW_DATABASES, AST_EDGE_IMPORT_MODE_QUERY, + AST_PATTERN_COMPREHENSION, // Symbol SYMBOL, }; diff --git a/tests/gql_behave/tests/memgraph_V1/features/list_operations.feature b/tests/gql_behave/tests/memgraph_V1/features/list_operations.feature index bfe6b6225..8c5538d6b 100644 --- a/tests/gql_behave/tests/memgraph_V1/features/list_operations.feature +++ b/tests/gql_behave/tests/memgraph_V1/features/list_operations.feature @@ -279,3 +279,15 @@ Feature: List operators Then the result should be: | o | | (:Node {Status: 'This is the status'}) | + + Scenario: Simple list pattern comprehension + Given graph "graph_keanu" + When executing query: + """ + MATCH (keanu:Person {name: 'Keanu Reeves'}) + RETURN [(keanu)-->(b:Movie) WHERE b.title CONTAINS 'Matrix' | b.released] AS years + """ + Then an error should be raised +# Then the result should be: +# | years | +# | [2021,2003,2003,1999] | diff --git a/tests/gql_behave/tests/memgraph_V1/graphs/graph_keanu.cypher b/tests/gql_behave/tests/memgraph_V1/graphs/graph_keanu.cypher new file mode 100644 index 000000000..a7a72aced --- /dev/null +++ b/tests/gql_behave/tests/memgraph_V1/graphs/graph_keanu.cypher @@ -0,0 +1,16 @@ +CREATE + (keanu:Person {name: 'Keanu Reeves'}), + (johnnyMnemonic:Movie {title: 'Johnny Mnemonic', released: 1995}), + (theMatrixRevolutions:Movie {title: 'The Matrix Revolutions', released: 2003}), + (theMatrixReloaded:Movie {title: 'The Matrix Reloaded', released: 2003}), + (theReplacements:Movie {title: 'The Replacements', released: 2000}), + (theMatrix:Movie {title: 'The Matrix', released: 1999}), + (theDevilsAdvocate:Movie {title: 'The Devils Advocate', released: 1997}), + (theMatrixResurrections:Movie {title: 'The Matrix Resurrections', released: 2021}), + (keanu)-[:ACTED_IN]->(johnnyMnemonic), + (keanu)-[:ACTED_IN]->(theMatrixRevolutions), + (keanu)-[:ACTED_IN]->(theMatrixReloaded), + (keanu)-[:ACTED_IN]->(theReplacements), + (keanu)-[:ACTED_IN]->(theMatrix), + (keanu)-[:ACTED_IN]->(theDevilsAdvocate), + (keanu)-[:ACTED_IN]->(theMatrixResurrections);