Add List Pattern Comprehension grammar. (#1588)

This commit is contained in:
Aidar Samerkhanov 2024-01-11 17:20:21 +03:00 committed by GitHub
parent 31f15b3651
commit 2e4d27c59a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 156 additions and 7 deletions

View File

@ -293,4 +293,7 @@ constexpr utils::TypeInfo query::ShowDatabasesQuery::kType{utils::TypeId::AST_SH
constexpr utils::TypeInfo query::EdgeImportModeQuery::kType{utils::TypeId::AST_EDGE_IMPORT_MODE_QUERY,
"EdgeImportModeQuery", &query::Query::kType};
constexpr utils::TypeInfo query::PatternComprehension::kType{utils::TypeId::AST_PATTERN_COMPREHENSION,
"PatternComprehension", &query::Expression::kType};
} // namespace memgraph

View File

@ -3520,6 +3520,65 @@ class Exists : public memgraph::query::Expression {
friend class AstStorage;
};
class PatternComprehension : public memgraph::query::Expression {
public:
static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; }
PatternComprehension() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue *>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
if (variable_) {
variable_->Accept(visitor);
}
pattern_->Accept(visitor);
if (filter_) {
filter_->Accept(visitor);
}
resultExpr_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
PatternComprehension *MapTo(const Symbol &symbol) {
symbol_pos_ = symbol.position();
return this;
}
// The variable name.
Identifier *variable_{nullptr};
// The pattern to match.
Pattern *pattern_{nullptr};
// Optional WHERE clause for filtering.
Where *filter_{nullptr};
// The projection expression.
Expression *resultExpr_{nullptr};
/// Symbol table position of the symbol this Aggregation is mapped to.
int32_t symbol_pos_{-1};
PatternComprehension *Clone(AstStorage *storage) const override {
PatternComprehension *object = storage->Create<PatternComprehension>();
object->pattern_ = pattern_ ? pattern_->Clone(storage) : nullptr;
object->filter_ = filter_ ? filter_->Clone(storage) : nullptr;
object->resultExpr_ = resultExpr_ ? resultExpr_->Clone(storage) : nullptr;
object->symbol_pos_ = symbol_pos_;
return object;
}
protected:
PatternComprehension(Identifier *variable, Pattern *pattern) : variable_(variable), pattern_(pattern) {}
private:
friend class AstStorage;
};
class CallSubquery : public memgraph::query::Clause {
public:
static const utils::TypeInfo kType;

View File

@ -107,6 +107,7 @@ class Exists;
class MultiDatabaseQuery;
class ShowDatabasesQuery;
class EdgeImportModeQuery;
class PatternComprehension;
using TreeCompositeVisitor = utils::CompositeVisitor<
SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator,
@ -116,7 +117,7 @@ using TreeCompositeVisitor = utils::CompositeVisitor<
MapProjectionLiteral, PropertyLookup, AllPropertiesLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce,
Extract, All, Single, Any, None, CallProcedure, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete,
Where, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, Merge, Unwind, RegexMatch, LoadCsv,
Foreach, Exists, CallSubquery, CypherQuery>;
Foreach, Exists, CallSubquery, CypherQuery, PatternComprehension>;
using TreeLeafVisitor = utils::LeafVisitor<Identifier, PrimitiveLiteral, ParameterLookup>;
@ -137,7 +138,7 @@ class ExpressionVisitor
ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator,
ListLiteral, MapLiteral, MapProjectionLiteral, PropertyLookup, AllPropertiesLookup,
LabelsTest, Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Any, None,
ParameterLookup, Identifier, PrimitiveLiteral, RegexMatch, Exists> {};
ParameterLookup, Identifier, PrimitiveLiteral, RegexMatch, Exists, PatternComprehension> {};
template <class TResult>
class QueryVisitor

View File

@ -1969,6 +1969,18 @@ antlrcpp::Any CypherMainVisitor::visitPatternElement(MemgraphCypher::PatternElem
return pattern;
}
antlrcpp::Any CypherMainVisitor::visitRelationshipsPattern(MemgraphCypher::RelationshipsPatternContext *ctx) {
auto *pattern = storage_->Create<Pattern>();
pattern->atoms_.push_back(std::any_cast<NodeAtom *>(ctx->nodePattern()->accept(this)));
for (auto *pattern_element_chain : ctx->patternElementChain()) {
auto element = std::any_cast<std::pair<PatternAtom *, PatternAtom *>>(pattern_element_chain->accept(this));
pattern->atoms_.push_back(element.first);
pattern->atoms_.push_back(element.second);
}
anonymous_identifiers.push_back(&pattern->identifier_);
return pattern;
}
antlrcpp::Any CypherMainVisitor::visitPatternElementChain(MemgraphCypher::PatternElementChainContext *ctx) {
return std::pair<PatternAtom *, PatternAtom *>(std::any_cast<EdgeAtom *>(ctx->relationshipPattern()->accept(this)),
std::any_cast<NodeAtom *>(ctx->nodePattern()->accept(this)));
@ -2463,6 +2475,8 @@ antlrcpp::Any CypherMainVisitor::visitAtom(MemgraphCypher::AtomContext *ctx) {
return static_cast<Expression *>(storage_->Create<Extract>(ident, list, expr));
} else if (ctx->existsExpression()) {
return std::any_cast<Expression *>(ctx->existsExpression()->accept(this));
} else if (ctx->patternComprehension()) {
return std::any_cast<Expression *>(ctx->patternComprehension()->accept(this));
}
// TODO: Implement this. We don't support comprehensions, filtering... at
@ -2523,6 +2537,19 @@ antlrcpp::Any CypherMainVisitor::visitExistsExpression(MemgraphCypher::ExistsExp
return static_cast<Expression *>(exists);
}
antlrcpp::Any CypherMainVisitor::visitPatternComprehension(MemgraphCypher::PatternComprehensionContext *ctx) {
auto *comprehension = storage_->Create<PatternComprehension>();
if (ctx->variable()) {
comprehension->variable_ = storage_->Create<Identifier>(std::any_cast<std::string>(ctx->variable()->accept(this)));
}
comprehension->pattern_ = std::any_cast<Pattern *>(ctx->relationshipsPattern()->accept(this));
if (ctx->where()) {
comprehension->filter_ = std::any_cast<Where *>(ctx->where()->accept(this));
}
comprehension->resultExpr_ = std::any_cast<Expression *>(ctx->expression()->accept(this));
return static_cast<Expression *>(comprehension);
}
antlrcpp::Any CypherMainVisitor::visitParenthesizedExpression(MemgraphCypher::ParenthesizedExpressionContext *ctx) {
return std::any_cast<Expression *>(ctx->expression()->accept(this));
}

View File

@ -678,6 +678,11 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
*/
antlrcpp::Any visitPatternElement(MemgraphCypher::PatternElementContext *ctx) override;
/**
* @return Pattern*
*/
antlrcpp::Any visitRelationshipsPattern(MemgraphCypher::RelationshipsPatternContext *ctx) override;
/**
* @return vector<pair<EdgeAtom*, NodeAtom*>>
*/
@ -843,6 +848,11 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
*/
antlrcpp::Any visitExistsExpression(MemgraphCypher::ExistsExpressionContext *ctx) override;
/**
* @return pattern comprehension (Expression)
*/
antlrcpp::Any visitPatternComprehension(MemgraphCypher::PatternComprehensionContext *ctx) override;
/**
* @return Expression*
*/

View File

@ -73,6 +73,7 @@ class ExpressionPrettyPrinter : public ExpressionVisitor<void> {
void Visit(ParameterLookup &op) override;
void Visit(NamedExpression &op) override;
void Visit(RegexMatch &op) override;
void Visit(PatternComprehension &op) override;
private:
std::ostream *out_;
@ -323,6 +324,10 @@ void ExpressionPrettyPrinter::Visit(NamedExpression &op) {
void ExpressionPrettyPrinter::Visit(RegexMatch &op) { PrintOperator(out_, "=~", op.string_expr_, op.regex_); }
void ExpressionPrettyPrinter::Visit(PatternComprehension &op) {
PrintOperator(out_, "Pattern Comprehension", op.variable_, op.pattern_, op.filter_, op.resultExpr_);
}
} // namespace
void PrintExpression(Expression *expr, std::ostream *out) {

View File

@ -296,7 +296,7 @@ functionName : symbolicName ( '.' symbolicName )* ;
listComprehension : '[' filterExpression ( '|' expression )? ']' ;
patternComprehension : '[' ( variable '=' )? relationshipsPattern ( WHERE expression )? '|' expression ']' ;
patternComprehension : '[' ( variable '=' )? relationshipsPattern ( where )? '|' resultExpr=expression ']' ;
propertyLookup : '.' ( propertyKeyName ) ;

View File

@ -249,6 +249,7 @@ class PropertyLookupEvaluationModeVisitor : public ExpressionVisitor<void> {
void Visit(ParameterLookup &op) override{};
void Visit(NamedExpression &op) override { op.expression_->Accept(*this); };
void Visit(RegexMatch &op) override{};
void Visit(PatternComprehension &op) override{};
void Visit(PropertyLookup & /*property_lookup*/) override;

View File

@ -52,6 +52,9 @@ class SymbolTable final {
const Symbol &at(const NamedExpression &nexpr) const { return table_.at(nexpr.symbol_pos_); }
const Symbol &at(const Aggregation &aggr) const { return table_.at(aggr.symbol_pos_); }
const Symbol &at(const Exists &exists) const { return table_.at(exists.symbol_pos_); }
const Symbol &at(const PatternComprehension &pattern_comprehension) const {
return table_.at(pattern_comprehension.symbol_pos_);
}
// TODO: Remove these since members are public
int32_t max_position() const { return static_cast<int32_t>(table_.size()); }

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -101,6 +101,7 @@ class ReferenceExpressionEvaluator : public ExpressionVisitor<TypedValue *> {
UNSUCCESSFUL_VISIT(ParameterLookup);
UNSUCCESSFUL_VISIT(RegexMatch);
UNSUCCESSFUL_VISIT(Exists);
UNSUCCESSFUL_VISIT(PatternComprehension);
#undef UNSUCCESSFUL_VISIT
@ -170,6 +171,7 @@ class PrimitiveLiteralExpressionEvaluator : public ExpressionVisitor<TypedValue>
INVALID_VISIT(Identifier)
INVALID_VISIT(RegexMatch)
INVALID_VISIT(Exists)
INVALID_VISIT(PatternComprehension)
#undef INVALID_VISIT
private:
@ -1090,6 +1092,10 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
}
}
TypedValue Visit(PatternComprehension & /*pattern_comprehension*/) override {
throw utils::NotYetImplemented("Expression evaluator can not handle pattern comprehension.");
}
private:
template <class TRecordAccessor>
std::map<storage::PropertyId, storage::PropertyValue> GetAllProperties(const TRecordAccessor &record_accessor) {

View File

@ -230,6 +230,7 @@ class PatternFilterVisitor : public ExpressionVisitor<void> {
void Visit(ParameterLookup &op) override{};
void Visit(NamedExpression &op) override{};
void Visit(RegexMatch &op) override{};
void Visit(PatternComprehension &op) override{};
std::vector<FilterMatching> getMatchings() { return matchings_; }

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -373,12 +373,12 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
return true;
}
bool Visit(ParameterLookup &) override {
bool Visit(ParameterLookup & /*unused*/) override {
has_aggregation_.emplace_back(false);
return true;
}
bool PostVisit(RegexMatch &regex_match) override {
bool PostVisit(RegexMatch & /*unused*/) override {
MG_ASSERT(has_aggregation_.size() >= 2U, "Expected 2 has_aggregation_ flags for RegexMatch arguments");
bool has_aggr = has_aggregation_.back();
has_aggregation_.pop_back();
@ -386,6 +386,10 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
return true;
}
bool PostVisit(PatternComprehension & /*unused*/) override {
throw utils::NotYetImplemented("Planner can not handle pattern comprehension.");
}
// Creates NamedExpression with an Identifier for each user declared symbol.
// This should be used when body.all_identifiers is true, to generate
// expressions for Produce operator.

View File

@ -190,6 +190,7 @@ enum class TypeId : uint64_t {
AST_MULTI_DATABASE_QUERY,
AST_SHOW_DATABASES,
AST_EDGE_IMPORT_MODE_QUERY,
AST_PATTERN_COMPREHENSION,
// Symbol
SYMBOL,
};

View File

@ -279,3 +279,15 @@ Feature: List operators
Then the result should be:
| o |
| (:Node {Status: 'This is the status'}) |
Scenario: Simple list pattern comprehension
Given graph "graph_keanu"
When executing query:
"""
MATCH (keanu:Person {name: 'Keanu Reeves'})
RETURN [(keanu)-->(b:Movie) WHERE b.title CONTAINS 'Matrix' | b.released] AS years
"""
Then an error should be raised
# Then the result should be:
# | years |
# | [2021,2003,2003,1999] |

View File

@ -0,0 +1,16 @@
CREATE
(keanu:Person {name: 'Keanu Reeves'}),
(johnnyMnemonic:Movie {title: 'Johnny Mnemonic', released: 1995}),
(theMatrixRevolutions:Movie {title: 'The Matrix Revolutions', released: 2003}),
(theMatrixReloaded:Movie {title: 'The Matrix Reloaded', released: 2003}),
(theReplacements:Movie {title: 'The Replacements', released: 2000}),
(theMatrix:Movie {title: 'The Matrix', released: 1999}),
(theDevilsAdvocate:Movie {title: 'The Devils Advocate', released: 1997}),
(theMatrixResurrections:Movie {title: 'The Matrix Resurrections', released: 2021}),
(keanu)-[:ACTED_IN]->(johnnyMnemonic),
(keanu)-[:ACTED_IN]->(theMatrixRevolutions),
(keanu)-[:ACTED_IN]->(theMatrixReloaded),
(keanu)-[:ACTED_IN]->(theReplacements),
(keanu)-[:ACTED_IN]->(theMatrix),
(keanu)-[:ACTED_IN]->(theDevilsAdvocate),
(keanu)-[:ACTED_IN]->(theMatrixResurrections);