Add classic and hierarchical visiting

Summary:
Merge utils/visitor directory into single file.
Rename Visitor to HierarchicalVisitor.
Add regular Visitor.
Split HierarchicalVisitor into LeafVisitor and CompositeVisitor.
Add more documentation on visitor pattern.
Make PostVisit and Visit return bool.

Reviewers: florijan, mislav.bradac, buda

Reviewed By: mislav.bradac

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D364
This commit is contained in:
Teon Banek 2017-05-16 09:16:46 +02:00
parent 9640633dd1
commit 15d7226515
12 changed files with 697 additions and 498 deletions

View File

@ -9,13 +9,12 @@
#include "query/frontend/ast/ast_visitor.hpp"
#include "query/typed_value.hpp"
#include "utils/assert.hpp"
#include "utils/visitor/visitable.hpp"
namespace query {
class AstTreeStorage;
class Tree : public ::utils::Visitable<TreeVisitorBase> {
class Tree : public ::utils::Visitable<HierarchicalTreeVisitor> {
friend class AstTreeStorage;
public:
@ -64,13 +63,11 @@ class OrOperator : public BinaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression1_->Accept(visitor);
expression2_->Accept(visitor);
visitor.PostVisit(*this);
expression1_->Accept(visitor) && expression2_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
protected:
@ -81,13 +78,11 @@ class XorOperator : public BinaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression1_->Accept(visitor);
expression2_->Accept(visitor);
visitor.PostVisit(*this);
expression1_->Accept(visitor) && expression2_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
protected:
@ -98,13 +93,11 @@ class AndOperator : public BinaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression1_->Accept(visitor);
expression2_->Accept(visitor);
visitor.PostVisit(*this);
expression1_->Accept(visitor) && expression2_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
protected:
@ -119,13 +112,11 @@ class FilterAndOperator : public BinaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression1_->Accept(visitor);
expression2_->Accept(visitor);
visitor.PostVisit(*this);
expression1_->Accept(visitor) && expression2_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
protected:
@ -136,13 +127,11 @@ class AdditionOperator : public BinaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression1_->Accept(visitor);
expression2_->Accept(visitor);
visitor.PostVisit(*this);
expression1_->Accept(visitor) && expression2_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
protected:
@ -153,13 +142,11 @@ class SubtractionOperator : public BinaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression1_->Accept(visitor);
expression2_->Accept(visitor);
visitor.PostVisit(*this);
expression1_->Accept(visitor) && expression2_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
protected:
@ -170,13 +157,11 @@ class MultiplicationOperator : public BinaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression1_->Accept(visitor);
expression2_->Accept(visitor);
visitor.PostVisit(*this);
expression1_->Accept(visitor) && expression2_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
protected:
@ -187,13 +172,11 @@ class DivisionOperator : public BinaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression1_->Accept(visitor);
expression2_->Accept(visitor);
visitor.PostVisit(*this);
expression1_->Accept(visitor) && expression2_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
protected:
@ -204,13 +187,11 @@ class ModOperator : public BinaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression1_->Accept(visitor);
expression2_->Accept(visitor);
visitor.PostVisit(*this);
expression1_->Accept(visitor) && expression2_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
protected:
@ -221,13 +202,11 @@ class NotEqualOperator : public BinaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression1_->Accept(visitor);
expression2_->Accept(visitor);
visitor.PostVisit(*this);
expression1_->Accept(visitor) && expression2_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
protected:
@ -238,13 +217,11 @@ class EqualOperator : public BinaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression1_->Accept(visitor);
expression2_->Accept(visitor);
visitor.PostVisit(*this);
expression1_->Accept(visitor) && expression2_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
protected:
@ -255,13 +232,11 @@ class LessOperator : public BinaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression1_->Accept(visitor);
expression2_->Accept(visitor);
visitor.PostVisit(*this);
expression1_->Accept(visitor) && expression2_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
protected:
@ -272,13 +247,11 @@ class GreaterOperator : public BinaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression1_->Accept(visitor);
expression2_->Accept(visitor);
visitor.PostVisit(*this);
expression1_->Accept(visitor) && expression2_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
protected:
@ -289,13 +262,11 @@ class LessEqualOperator : public BinaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression1_->Accept(visitor);
expression2_->Accept(visitor);
visitor.PostVisit(*this);
expression1_->Accept(visitor) && expression2_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
protected:
@ -306,13 +277,11 @@ class GreaterEqualOperator : public BinaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression1_->Accept(visitor);
expression2_->Accept(visitor);
visitor.PostVisit(*this);
expression1_->Accept(visitor) && expression2_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
protected:
@ -323,13 +292,11 @@ class InListOperator : public BinaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression1_->Accept(visitor);
expression2_->Accept(visitor);
visitor.PostVisit(*this);
expression1_->Accept(visitor) && expression2_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
protected:
@ -340,13 +307,11 @@ class ListIndexingOperator : public BinaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression1_->Accept(visitor);
expression2_->Accept(visitor);
visitor.PostVisit(*this);
expression1_->Accept(visitor) && expression2_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
protected:
@ -357,18 +322,17 @@ class ListSlicingOperator : public Expression {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
list_->Accept(visitor);
if (lower_bound_) {
lower_bound_->Accept(visitor);
bool cont = list_->Accept(visitor);
if (cont && lower_bound_) {
cont = lower_bound_->Accept(visitor);
}
if (upper_bound_) {
if (cont && upper_bound_) {
upper_bound_->Accept(visitor);
}
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
Expression *list_;
@ -388,12 +352,11 @@ class NotOperator : public UnaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression_->Accept(visitor);
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
protected:
@ -404,12 +367,11 @@ class UnaryPlusOperator : public UnaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression_->Accept(visitor);
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
protected:
@ -420,12 +382,11 @@ class UnaryMinusOperator : public UnaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression_->Accept(visitor);
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
protected:
@ -436,10 +397,11 @@ class IsNullOperator : public UnaryOperator {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
visitor.Visit(*this);
expression_->Accept(visitor);
visitor.PostVisit(*this);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
expression_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
protected:
@ -457,8 +419,8 @@ class PrimitiveLiteral : public BaseLiteral {
friend class AstTreeStorage;
public:
DEFVISITABLE(HierarchicalTreeVisitor);
TypedValue value_;
DEFVISITABLE(TreeVisitorBase);
protected:
PrimitiveLiteral(int uid) : BaseLiteral(uid) {}
@ -471,12 +433,12 @@ class ListLiteral : public BaseLiteral {
public:
const std::vector<Expression *> elements_;
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
for (auto expr_ptr : elements_) expr_ptr->Accept(visitor);
visitor.PostVisit(*this);
for (auto expr_ptr : elements_)
if (!expr_ptr->Accept(visitor)) break;
}
return visitor.PostVisit(*this);
}
protected:
@ -489,7 +451,7 @@ class Identifier : public Expression {
friend class AstTreeStorage;
public:
DEFVISITABLE(TreeVisitorBase);
DEFVISITABLE(HierarchicalTreeVisitor);
std::string name_;
bool user_declared_ = true;
@ -503,12 +465,11 @@ class PropertyLookup : public Expression {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression_->Accept(visitor);
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
Expression *expression_ = nullptr;
@ -531,12 +492,11 @@ class LabelsTest : public Expression {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression_->Accept(visitor);
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
Expression *expression_ = nullptr;
@ -552,12 +512,11 @@ class EdgeTypeTest : public Expression {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression_->Accept(visitor);
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
Expression *expression_ = nullptr;
@ -576,14 +535,13 @@ class Function : public Expression {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
for (auto *argument : arguments_) {
argument->Accept(visitor);
if (!argument->Accept(visitor)) break;
}
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
std::function<TypedValue(const std::vector<TypedValue> &, GraphDbAccessor &)>
@ -609,14 +567,13 @@ class Aggregation : public UnaryOperator {
static const constexpr char *const kSum = "SUM";
static const constexpr char *const kAvg = "AVG";
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
if (expression_) {
expression_->Accept(visitor);
}
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
Op op_;
@ -633,12 +590,11 @@ class NamedExpression : public Tree {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression_->Accept(visitor);
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
std::string name_;
@ -667,12 +623,11 @@ class NodeAtom : public PatternAtom {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
identifier_->Accept(visitor);
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
std::vector<GraphDbTypes::Label> labels_;
@ -692,12 +647,11 @@ class EdgeAtom : public PatternAtom {
// necessarily go from left to right
enum class Direction { LEFT, RIGHT, BOTH };
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
identifier_->Accept(visitor);
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
Direction direction_ = Direction::BOTH;
@ -722,14 +676,13 @@ class Pattern : public Tree {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
for (auto &part : atoms_) {
part->Accept(visitor);
if (!part->Accept(visitor)) break;
}
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
Identifier *identifier_ = nullptr;
std::vector<PatternAtom *> atoms_;
@ -742,14 +695,13 @@ class Query : public Tree {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
for (auto &clause : clauses_) {
clause->Accept(visitor);
if (!clause->Accept(visitor)) break;
}
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
std::vector<Clause *> clauses_;
@ -763,14 +715,13 @@ class Create : public Clause {
public:
Create(int uid) : Clause(uid) {}
std::vector<Pattern *> patterns_;
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
for (auto &pattern : patterns_) {
pattern->Accept(visitor);
if (!pattern->Accept(visitor)) break;
}
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
};
@ -778,12 +729,11 @@ class Where : public Tree {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression_->Accept(visitor);
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
Expression *expression_ = nullptr;
@ -796,17 +746,20 @@ class Match : public Clause {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
bool cont = true;
for (auto &pattern : patterns_) {
pattern->Accept(visitor);
if (!pattern->Accept(visitor)) {
cont = false;
break;
}
}
if (where_) {
if (cont && where_) {
where_->Accept(visitor);
}
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
std::vector<Pattern *> patterns_;
Where *where_ = nullptr;
@ -842,19 +795,27 @@ class Return : public Clause {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
bool cont = true;
for (auto &expr : body_.named_expressions) {
expr->Accept(visitor);
if (!expr->Accept(visitor)) {
cont = false;
break;
}
}
for (auto &order_by : body_.order_by) {
order_by.second->Accept(visitor);
if (cont) {
for (auto &order_by : body_.order_by) {
if (!order_by.second->Accept(visitor)) {
cont = false;
break;
}
}
}
if (body_.skip) body_.skip->Accept(visitor);
if (body_.limit) body_.limit->Accept(visitor);
visitor.PostVisit(*this);
if (cont && body_.skip) cont = body_.skip->Accept(visitor);
if (cont && body_.limit) cont = body_.limit->Accept(visitor);
}
return visitor.PostVisit(*this);
}
ReturnBody body_;
@ -867,20 +828,28 @@ class With : public Clause {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
bool cont = true;
for (auto &expr : body_.named_expressions) {
expr->Accept(visitor);
if (!expr->Accept(visitor)) {
cont = false;
break;
}
}
for (auto &order_by : body_.order_by) {
order_by.second->Accept(visitor);
if (cont) {
for (auto &order_by : body_.order_by) {
if (!order_by.second->Accept(visitor)) {
cont = false;
break;
}
}
}
if (where_) where_->Accept(visitor);
if (body_.skip) body_.skip->Accept(visitor);
if (body_.limit) body_.limit->Accept(visitor);
visitor.PostVisit(*this);
if (cont && where_) cont = where_->Accept(visitor);
if (cont && body_.skip) cont = body_.skip->Accept(visitor);
if (cont && body_.limit) cont = body_.limit->Accept(visitor);
}
return visitor.PostVisit(*this);
}
ReturnBody body_;
@ -894,14 +863,13 @@ class Delete : public Clause {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
for (auto &expr : expressions_) {
expr->Accept(visitor);
if (!expr->Accept(visitor)) break;
}
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
std::vector<Expression *> expressions_;
bool detach_ = false;
@ -914,13 +882,11 @@ class SetProperty : public Clause {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
property_lookup_->Accept(visitor);
expression_->Accept(visitor);
visitor.PostVisit(*this);
property_lookup_->Accept(visitor) && expression_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
PropertyLookup *property_lookup_ = nullptr;
Expression *expression_ = nullptr;
@ -937,13 +903,11 @@ class SetProperties : public Clause {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
identifier_->Accept(visitor);
expression_->Accept(visitor);
visitor.PostVisit(*this);
identifier_->Accept(visitor) && expression_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
Identifier *identifier_ = nullptr;
Expression *expression_ = nullptr;
@ -963,12 +927,11 @@ class SetLabels : public Clause {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
identifier_->Accept(visitor);
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
Identifier *identifier_ = nullptr;
std::vector<GraphDbTypes::Label> labels_;
@ -984,12 +947,11 @@ class RemoveProperty : public Clause {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
property_lookup_->Accept(visitor);
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
PropertyLookup *property_lookup_ = nullptr;
@ -1003,12 +965,11 @@ class RemoveLabels : public Clause {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
identifier_->Accept(visitor);
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
Identifier *identifier_ = nullptr;
std::vector<GraphDbTypes::Label> labels_;
@ -1024,18 +985,27 @@ class Merge : public Clause {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
pattern_->Accept(visitor);
for (auto &set : on_match_) {
set->Accept(visitor);
bool cont = pattern_->Accept(visitor);
if (cont) {
for (auto &set : on_match_) {
if (!set->Accept(visitor)) {
cont = false;
break;
}
}
}
for (auto &set : on_create_) {
set->Accept(visitor);
if (cont) {
for (auto &set : on_create_) {
if (!set->Accept(visitor)) {
cont = false;
break;
}
}
}
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
Pattern *pattern_ = nullptr;
@ -1050,12 +1020,11 @@ class Unwind : public Clause {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
named_expression_->Accept(visitor);
visitor.PostVisit(*this);
}
return visitor.PostVisit(*this);
}
NamedExpression *const named_expression_ = nullptr;

View File

@ -1,10 +1,10 @@
#pragma once
#include "utils/visitor/visitor.hpp"
#include "utils/visitor.hpp"
namespace query {
// Forward declares for TreeVisitorBase
// Forward declares for Tree visitors.
class Query;
class NamedExpression;
class Identifier;
@ -54,16 +54,27 @@ class RemoveLabels;
class Merge;
class Unwind;
using TreeVisitorBase = ::utils::Visitor<
using TreeCompositeVisitor = ::utils::CompositeVisitor<
Query, NamedExpression, OrOperator, XorOperator, AndOperator,
FilterAndOperator, NotOperator, AdditionOperator, SubtractionOperator,
MultiplicationOperator, DivisionOperator, ModOperator, NotEqualOperator,
EqualOperator, LessOperator, GreaterOperator, LessEqualOperator,
GreaterEqualOperator, InListOperator, ListIndexingOperator,
ListSlicingOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator,
Identifier, PrimitiveLiteral, ListLiteral, PropertyLookup, LabelsTest,
EdgeTypeTest, Aggregation, Function, Create, Match, Return, With, Pattern,
NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels,
RemoveProperty, RemoveLabels, Merge, Unwind>;
ListLiteral, PropertyLookup, LabelsTest, EdgeTypeTest, Aggregation,
Function, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete,
Where, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels,
Merge, Unwind>;
using TreeLeafVisitor = ::utils::LeafVisitor<Identifier, PrimitiveLiteral>;
class HierarchicalTreeVisitor : public TreeCompositeVisitor,
public TreeLeafVisitor {
public:
using TreeCompositeVisitor::PreVisit;
using TreeCompositeVisitor::PostVisit;
using typename TreeLeafVisitor::ReturnType;
using TreeLeafVisitor::Visit;
};
} // namespace query

View File

@ -109,8 +109,14 @@ void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) {
// Clauses
void SymbolGenerator::Visit(Create &create) { scope_.in_create = true; }
void SymbolGenerator::PostVisit(Create &create) { scope_.in_create = false; }
bool SymbolGenerator::PreVisit(Create &create) {
scope_.in_create = true;
return true;
}
bool SymbolGenerator::PostVisit(Create &create) {
scope_.in_create = false;
return true;
}
bool SymbolGenerator::PreVisit(Return &ret) {
scope_.in_return = true;
@ -126,22 +132,38 @@ bool SymbolGenerator::PreVisit(With &with) {
return false; // We handled the traversal ourselves.
}
void SymbolGenerator::Visit(Where &) { scope_.in_where = true; }
void SymbolGenerator::PostVisit(Where &) { scope_.in_where = false; }
bool SymbolGenerator::PreVisit(Where &) {
scope_.in_where = true;
return true;
}
bool SymbolGenerator::PostVisit(Where &) {
scope_.in_where = false;
return true;
}
void SymbolGenerator::Visit(Merge &) { scope_.in_merge = true; }
void SymbolGenerator::PostVisit(Merge &) { scope_.in_merge = false; }
bool SymbolGenerator::PreVisit(Merge &) {
scope_.in_merge = true;
return true;
}
bool SymbolGenerator::PostVisit(Merge &) {
scope_.in_merge = false;
return true;
}
void SymbolGenerator::PostVisit(Unwind &unwind) {
bool SymbolGenerator::PostVisit(Unwind &unwind) {
const auto &name = unwind.named_expression_->name_;
if (HasSymbol(name)) {
throw RedeclareVariableError(name);
}
symbol_table_[*unwind.named_expression_] = CreateSymbol(name, true);
return true;
}
void SymbolGenerator::Visit(Match &) { scope_.in_match = true; }
void SymbolGenerator::PostVisit(Match &) {
bool SymbolGenerator::PreVisit(Match &) {
scope_.in_match = true;
return true;
}
bool SymbolGenerator::PostVisit(Match &) {
scope_.in_match = false;
// Check variables in property maps after visiting Match, so that they can
// reference symbols out of bind order.
@ -150,11 +172,12 @@ void SymbolGenerator::PostVisit(Match &) {
symbol_table_[*ident] = scope_.symbols[ident->name_];
}
scope_.identifiers_in_property_maps.clear();
return true;
}
// Expressions
void SymbolGenerator::Visit(Identifier &ident) {
SymbolGenerator::ReturnType SymbolGenerator::Visit(Identifier &ident) {
if (scope_.in_skip || scope_.in_limit) {
throw SemanticException("Variables are not allowed in {}",
scope_.in_skip ? "SKIP" : "LIMIT");
@ -193,9 +216,10 @@ void SymbolGenerator::Visit(Identifier &ident) {
symbol = scope_.symbols[ident.name_];
}
symbol_table_[ident] = symbol;
return true;
}
void SymbolGenerator::Visit(Aggregation &aggr) {
bool SymbolGenerator::PreVisit(Aggregation &aggr) {
// Check if the aggregation can be used in this context. This check should
// probably move to a separate phase, which checks if the query is well
// formed.
@ -215,29 +239,33 @@ void SymbolGenerator::Visit(Aggregation &aggr) {
symbol_table_.CreateSymbol("", false, Symbol::Type::Number);
scope_.in_aggregation = true;
scope_.has_aggregation = true;
return true;
}
void SymbolGenerator::PostVisit(Aggregation &aggr) {
bool SymbolGenerator::PostVisit(Aggregation &aggr) {
scope_.in_aggregation = false;
return true;
}
// Pattern and its subparts.
void SymbolGenerator::Visit(Pattern &pattern) {
bool SymbolGenerator::PreVisit(Pattern &pattern) {
scope_.in_pattern = true;
if ((scope_.in_create || scope_.in_merge) && pattern.atoms_.size() == 1U) {
debug_assert(dynamic_cast<NodeAtom *>(pattern.atoms_[0]),
"Expected a single NodeAtom in Pattern");
scope_.in_create_node = true;
}
return true;
}
void SymbolGenerator::PostVisit(Pattern &pattern) {
bool SymbolGenerator::PostVisit(Pattern &pattern) {
scope_.in_pattern = false;
scope_.in_create_node = false;
return true;
}
void SymbolGenerator::Visit(NodeAtom &node_atom) {
bool SymbolGenerator::PreVisit(NodeAtom &node_atom) {
scope_.in_node_atom = true;
bool props_or_labels =
!node_atom.properties_.empty() || !node_atom.labels_.empty();
@ -252,13 +280,15 @@ void SymbolGenerator::Visit(NodeAtom &node_atom) {
kv.second->Accept(*this);
}
scope_.in_property_map = false;
return true;
}
void SymbolGenerator::PostVisit(NodeAtom &node_atom) {
bool SymbolGenerator::PostVisit(NodeAtom &node_atom) {
scope_.in_node_atom = false;
return true;
}
void SymbolGenerator::Visit(EdgeAtom &edge_atom) {
bool SymbolGenerator::PreVisit(EdgeAtom &edge_atom) {
scope_.in_edge_atom = true;
if (scope_.in_create || scope_.in_merge) {
scope_.in_create_edge = true;
@ -274,11 +304,13 @@ void SymbolGenerator::Visit(EdgeAtom &edge_atom) {
"when creating an edge");
}
}
return true;
}
void SymbolGenerator::PostVisit(EdgeAtom &edge_atom) {
bool SymbolGenerator::PostVisit(EdgeAtom &edge_atom) {
scope_.in_edge_atom = false;
scope_.in_create_edge = false;
return true;
}
bool SymbolGenerator::HasSymbol(const std::string &name) {

View File

@ -16,39 +16,41 @@ namespace query {
/// During the process of symbol generation, simple semantic checks are
/// performed. Such as, redeclaring a variable or conflicting expectations of
/// variable types.
class SymbolGenerator : public TreeVisitorBase {
class SymbolGenerator : public HierarchicalTreeVisitor {
public:
SymbolGenerator(SymbolTable &symbol_table) : symbol_table_(symbol_table) {}
using TreeVisitorBase::PreVisit;
using TreeVisitorBase::Visit;
using TreeVisitorBase::PostVisit;
using HierarchicalTreeVisitor::PreVisit;
using typename HierarchicalTreeVisitor::ReturnType;
using HierarchicalTreeVisitor::Visit;
using HierarchicalTreeVisitor::PostVisit;
// Clauses
void Visit(Create &) override;
void PostVisit(Create &) override;
bool PreVisit(Create &) override;
bool PostVisit(Create &) override;
bool PreVisit(Return &) override;
bool PreVisit(With &) override;
void Visit(Where &) override;
void PostVisit(Where &) override;
void Visit(Merge &) override;
void PostVisit(Merge &) override;
void PostVisit(Unwind &) override;
void Visit(Match &) override;
void PostVisit(Match &) override;
bool PreVisit(Where &) override;
bool PostVisit(Where &) override;
bool PreVisit(Merge &) override;
bool PostVisit(Merge &) override;
bool PostVisit(Unwind &) override;
bool PreVisit(Match &) override;
bool PostVisit(Match &) override;
// Expressions
void Visit(Identifier &) override;
void Visit(Aggregation &) override;
void PostVisit(Aggregation &) override;
ReturnType Visit(Identifier &) override;
ReturnType Visit(PrimitiveLiteral &) override { return true; }
bool PreVisit(Aggregation &) override;
bool PostVisit(Aggregation &) override;
// Pattern and its subparts.
void Visit(Pattern &) override;
void PostVisit(Pattern &) override;
void Visit(NodeAtom &) override;
void PostVisit(NodeAtom &) override;
void Visit(EdgeAtom &) override;
void PostVisit(EdgeAtom &) override;
bool PreVisit(Pattern &) override;
bool PostVisit(Pattern &) override;
bool PreVisit(NodeAtom &) override;
bool PostVisit(NodeAtom &) override;
bool PreVisit(EdgeAtom &) override;
bool PostVisit(EdgeAtom &) override;
private:
// Scope stores the state of where we are when visiting the AST and a map of

View File

@ -16,7 +16,7 @@
namespace query {
class ExpressionEvaluator : public TreeVisitorBase {
class ExpressionEvaluator : public HierarchicalTreeVisitor {
public:
ExpressionEvaluator(Frame &frame, const SymbolTable &symbol_table,
GraphDbAccessor &db_accessor,
@ -39,32 +39,37 @@ class ExpressionEvaluator : public TreeVisitorBase {
return last;
}
using TreeVisitorBase::PreVisit;
using TreeVisitorBase::Visit;
using TreeVisitorBase::PostVisit;
using HierarchicalTreeVisitor::PreVisit;
using typename HierarchicalTreeVisitor::ReturnType;
using HierarchicalTreeVisitor::Visit;
using HierarchicalTreeVisitor::PostVisit;
void PostVisit(NamedExpression &named_expression) override {
bool PostVisit(NamedExpression &named_expression) override {
auto symbol = symbol_table_.at(named_expression);
frame_[symbol] = PopBack();
return true;
}
void Visit(Identifier &ident) override {
ReturnType Visit(Identifier &ident) override {
auto value = frame_[symbol_table_.at(ident)];
SwitchAccessors(value);
result_stack_.emplace_back(std::move(value));
return true;
}
#define BINARY_OPERATOR_VISITOR(OP_NODE, CPP_OP) \
void PostVisit(OP_NODE &) override { \
bool PostVisit(OP_NODE &) override { \
auto expression2 = PopBack(); \
auto expression1 = PopBack(); \
result_stack_.push_back(expression1 CPP_OP expression2); \
return true; \
}
#define UNARY_OPERATOR_VISITOR(OP_NODE, CPP_OP) \
void PostVisit(OP_NODE &) override { \
bool PostVisit(OP_NODE &) override { \
auto expression = PopBack(); \
result_stack_.push_back(CPP_OP expression); \
return true; \
}
BINARY_OPERATOR_VISITOR(OrOperator, ||);
@ -103,12 +108,12 @@ class ExpressionEvaluator : public TreeVisitorBase {
return false;
}
void PostVisit(InListOperator &) override {
bool PostVisit(InListOperator &) override {
auto _list = PopBack();
auto literal = PopBack();
if (_list.IsNull()) {
result_stack_.emplace_back(TypedValue::Null);
return;
return true;
}
// Exceptions have higher priority than returning null.
// We need to convert list to its type before checking if literal is null,
@ -116,7 +121,7 @@ class ExpressionEvaluator : public TreeVisitorBase {
auto list = _list.Value<std::vector<TypedValue>>();
if (literal.IsNull()) {
result_stack_.emplace_back(TypedValue::Null);
return;
return true;
}
auto has_null = false;
for (const auto &element : list) {
@ -125,7 +130,7 @@ class ExpressionEvaluator : public TreeVisitorBase {
has_null = true;
} else if (result.Value<bool>()) {
result_stack_.emplace_back(true);
return;
return true;
}
}
if (has_null) {
@ -133,9 +138,10 @@ class ExpressionEvaluator : public TreeVisitorBase {
} else {
result_stack_.emplace_back(false);
}
return true;
}
void PostVisit(ListIndexingOperator &) override {
bool PostVisit(ListIndexingOperator &) override {
// TODO: implement this for maps
auto _index = PopBack();
if (_index.type() != TypedValue::Type::Int &&
@ -150,7 +156,7 @@ class ExpressionEvaluator : public TreeVisitorBase {
if (_index.type() == TypedValue::Type::Null ||
_list.type() == TypedValue::Type::Null) {
result_stack_.emplace_back(TypedValue::Null);
return;
return true;
}
auto index = _index.Value<int64_t>();
const auto &list = _list.Value<std::vector<TypedValue>>();
@ -159,12 +165,13 @@ class ExpressionEvaluator : public TreeVisitorBase {
}
if (index >= static_cast<int64_t>(list.size()) || index < 0) {
result_stack_.emplace_back(TypedValue::Null);
return;
return true;
}
result_stack_.emplace_back(list[index]);
return true;
}
void PostVisit(ListSlicingOperator &op) override {
bool PostVisit(ListSlicingOperator &op) override {
// If some type is null we can't return null, because throwing exception on
// illegal type has higher priority.
auto is_null = false;
@ -193,7 +200,7 @@ class ExpressionEvaluator : public TreeVisitorBase {
if (is_null) {
result_stack_.emplace_back(TypedValue::Null);
return;
return true;
}
const auto &list = _list.Value<std::vector<TypedValue>>();
auto normalise_bound = [&](int64_t bound) {
@ -207,18 +214,20 @@ class ExpressionEvaluator : public TreeVisitorBase {
auto upper_bound = normalise_bound(_upper_bound.Value<int64_t>());
if (upper_bound <= lower_bound) {
result_stack_.emplace_back(std::vector<TypedValue>());
return;
return true;
}
result_stack_.emplace_back(std::vector<TypedValue>(
list.begin() + lower_bound, list.begin() + upper_bound));
return true;
}
void PostVisit(IsNullOperator &) override {
bool PostVisit(IsNullOperator &) override {
auto expression = PopBack();
result_stack_.push_back(TypedValue(expression.IsNull()));
return true;
}
void PostVisit(PropertyLookup &property_lookup) override {
bool PostVisit(PropertyLookup &property_lookup) override {
auto expression_result = PopBack();
switch (expression_result.type()) {
case TypedValue::Type::Null:
@ -244,9 +253,10 @@ class ExpressionEvaluator : public TreeVisitorBase {
throw TypedValueException(
"Expected Node, Edge or Map for property lookup");
}
return true;
}
void PostVisit(LabelsTest &labels_test) override {
bool PostVisit(LabelsTest &labels_test) override {
auto expression_result = PopBack();
switch (expression_result.type()) {
case TypedValue::Type::Null:
@ -257,7 +267,7 @@ class ExpressionEvaluator : public TreeVisitorBase {
for (const auto label : labels_test.labels_) {
if (!vertex.has_label(label)) {
result_stack_.emplace_back(false);
return;
return true;
}
}
result_stack_.emplace_back(true);
@ -266,9 +276,10 @@ class ExpressionEvaluator : public TreeVisitorBase {
default:
throw TypedValueException("Expected Node in labels test");
}
return true;
}
void PostVisit(EdgeTypeTest &edge_type_test) override {
bool PostVisit(EdgeTypeTest &edge_type_test) override {
auto expression_result = PopBack();
switch (expression_result.type()) {
case TypedValue::Type::Null:
@ -280,7 +291,7 @@ class ExpressionEvaluator : public TreeVisitorBase {
for (const auto edge_type : edge_type_test.edge_types_) {
if (edge_type == real_edge_type) {
result_stack_.emplace_back(true);
return;
return true;
}
}
result_stack_.emplace_back(false);
@ -289,21 +300,24 @@ class ExpressionEvaluator : public TreeVisitorBase {
default:
throw TypedValueException("Expected Edge in edge type test");
}
return true;
}
void Visit(PrimitiveLiteral &literal) override {
ReturnType Visit(PrimitiveLiteral &literal) override {
// TODO: no need to evaluate constants, we can write it to frame in one
// of the previous phases.
result_stack_.push_back(literal.value_);
return true;
}
void PostVisit(ListLiteral &literal) override {
bool PostVisit(ListLiteral &literal) override {
std::vector<TypedValue> result;
result.reserve(literal.elements_.size());
for (size_t i = 0; i < literal.elements_.size(); i++)
result.emplace_back(PopBack());
std::reverse(result.begin(), result.end());
result_stack_.emplace_back(std::move(result));
return true;
}
bool PreVisit(Aggregation &aggregation) override {
@ -316,13 +330,14 @@ class ExpressionEvaluator : public TreeVisitorBase {
return false;
}
void PostVisit(Function &function) override {
bool PostVisit(Function &function) override {
std::vector<TypedValue> arguments;
for (int i = 0; i < static_cast<int>(function.arguments_.size()); ++i) {
arguments.push_back(PopBack());
}
reverse(arguments.begin(), arguments.end());
result_stack_.emplace_back(function.function_(arguments, db_accessor_));
return true;
}
private:

View File

@ -8,27 +8,16 @@
// macro for the default implementation of LogicalOperator::Accept
// that accepts the visitor and visits it's input_ operator
#define ACCEPT_WITH_INPUT(class_name) \
void class_name::Accept(LogicalOperatorVisitor &visitor) { \
if (visitor.PreVisit(*this)) { \
visitor.Visit(*this); \
input_->Accept(visitor); \
visitor.PostVisit(*this); \
} \
#define ACCEPT_WITH_INPUT(class_name) \
bool class_name::Accept(HierarchicalLogicalOperatorVisitor &visitor) { \
if (visitor.PreVisit(*this)) { \
input_->Accept(visitor); \
} \
return visitor.PostVisit(*this); \
}
namespace query::plan {
void Once::Accept(LogicalOperatorVisitor &visitor) {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
visitor.PostVisit(*this);
}
}
std::unique_ptr<Cursor> Once::MakeCursor(GraphDbAccessor &) {
return std::make_unique<OnceCursor>();
}
bool Once::OnceCursor::Pull(Frame &, const SymbolTable &) {
if (!did_pull_) {
did_pull_ = true;
@ -37,6 +26,10 @@ bool Once::OnceCursor::Pull(Frame &, const SymbolTable &) {
return false;
}
std::unique_ptr<Cursor> Once::MakeCursor(GraphDbAccessor &) {
return std::make_unique<OnceCursor>();
}
void Once::OnceCursor::Reset() { did_pull_ = false; }
CreateNode::CreateNode(const NodeAtom *node_atom,
@ -1316,14 +1309,12 @@ Merge::Merge(const std::shared_ptr<LogicalOperator> input,
merge_match_(merge_match),
merge_create_(merge_create) {}
void Merge::Accept(LogicalOperatorVisitor &visitor) {
bool Merge::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
input_->Accept(visitor);
merge_match_->Accept(visitor);
merge_create_->Accept(visitor);
visitor.PostVisit(*this);
input_->Accept(visitor) && merge_match_->Accept(visitor) &&
merge_create_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
std::unique_ptr<Cursor> Merge::MakeCursor(GraphDbAccessor &db) {
@ -1384,13 +1375,11 @@ Optional::Optional(const std::shared_ptr<LogicalOperator> &input,
optional_(optional),
optional_symbols_(optional_symbols) {}
void Optional::Accept(LogicalOperatorVisitor &visitor) {
bool Optional::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
input_->Accept(visitor);
optional_->Accept(visitor);
visitor.PostVisit(*this);
input_->Accept(visitor) && optional_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
std::unique_ptr<Cursor> Optional::MakeCursor(GraphDbAccessor &db) {

View File

@ -14,8 +14,7 @@
#include "query/common.hpp"
#include "query/frontend/semantic/symbol_table.hpp"
#include "utils/hashing/fnv.hpp"
#include "utils/visitor/visitable.hpp"
#include "utils/visitor/visitor.hpp"
#include "utils/visitor.hpp"
namespace query {
@ -76,21 +75,37 @@ class Optional;
class Unwind;
class Distinct;
/** @brief Base class for visitors of @c LogicalOperator class hierarchy. */
using LogicalOperatorVisitor = ::utils::Visitor<
using LogicalOperatorCompositeVisitor = ::utils::CompositeVisitor<
Once, CreateNode, CreateExpand, ScanAll, Expand, Filter, Produce, Delete,
SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels,
ExpandUniquenessFilter<VertexAccessor>,
ExpandUniquenessFilter<EdgeAccessor>, Accumulate, AdvanceCommand, Aggregate,
Skip, Limit, OrderBy, Merge, Optional, Unwind, Distinct>;
using LogicalOperatorLeafVisitor = ::utils::LeafVisitor<Once>;
/**
* @brief Base class for hierarhical visitors of @c LogicalOperator class
* hierarchy.
*/
class HierarchicalLogicalOperatorVisitor
: public LogicalOperatorCompositeVisitor,
public LogicalOperatorLeafVisitor {
public:
using LogicalOperatorCompositeVisitor::PreVisit;
using LogicalOperatorCompositeVisitor::PostVisit;
using typename LogicalOperatorLeafVisitor::ReturnType;
using LogicalOperatorLeafVisitor::Visit;
};
/** @brief Base class for logical operators.
*
* Each operator describes an operation, which is to be performed on the
* database. Operators are iterated over using a @c Cursor. Various operators
* can serve as inputs to others and thus a sequence of operations is formed.
*/
class LogicalOperator : public ::utils::Visitable<LogicalOperatorVisitor> {
class LogicalOperator
: public ::utils::Visitable<HierarchicalLogicalOperatorVisitor> {
public:
/** @brief Constructs a @c Cursor which is used to run this operator.
*
@ -121,7 +136,7 @@ class LogicalOperator : public ::utils::Visitable<LogicalOperatorVisitor> {
*/
class Once : public LogicalOperator {
public:
void Accept(LogicalOperatorVisitor &visitor) override;
DEFVISITABLE(HierarchicalLogicalOperatorVisitor);
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
private:
@ -155,7 +170,7 @@ class CreateNode : public LogicalOperator {
*/
CreateNode(const NodeAtom *node_atom,
const std::shared_ptr<LogicalOperator> &input);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
private:
@ -210,7 +225,7 @@ class CreateExpand : public LogicalOperator {
CreateExpand(const NodeAtom *node_atom, const EdgeAtom *edge_atom,
const std::shared_ptr<LogicalOperator> &input,
Symbol input_symbol, bool existing_node);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
private:
@ -276,7 +291,7 @@ class ScanAll : public LogicalOperator {
ScanAll(const NodeAtom *node_atom,
const std::shared_ptr<LogicalOperator> &input,
GraphView graph_view = GraphView::OLD);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
private:
@ -347,7 +362,7 @@ class Expand : public LogicalOperator {
const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
bool existing_node, bool existing_edge,
GraphView graph_view = GraphView::AS_IS);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
private:
@ -434,7 +449,7 @@ class Filter : public LogicalOperator {
public:
Filter(const std::shared_ptr<LogicalOperator> &input_,
Expression *expression_);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
private:
@ -469,7 +484,7 @@ class Produce : public LogicalOperator {
public:
Produce(const std::shared_ptr<LogicalOperator> &input,
const std::vector<NamedExpression *> named_expressions);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
std::vector<Symbol> OutputSymbols(const SymbolTable &) override;
const std::vector<NamedExpression *> &named_expressions();
@ -501,7 +516,7 @@ class Delete : public LogicalOperator {
public:
Delete(const std::shared_ptr<LogicalOperator> &input_,
const std::vector<Expression *> &expressions, bool detach_);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
private:
@ -535,7 +550,7 @@ class SetProperty : public LogicalOperator {
public:
SetProperty(const std::shared_ptr<LogicalOperator> &input,
PropertyLookup *lhs, Expression *rhs);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
private:
@ -579,7 +594,7 @@ class SetProperties : public LogicalOperator {
SetProperties(const std::shared_ptr<LogicalOperator> &input,
Symbol input_symbol, Expression *rhs, Op op);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
private:
@ -619,7 +634,7 @@ class SetLabels : public LogicalOperator {
public:
SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<GraphDbTypes::Label> &labels);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
private:
@ -647,7 +662,7 @@ class RemoveProperty : public LogicalOperator {
public:
RemoveProperty(const std::shared_ptr<LogicalOperator> &input,
PropertyLookup *lhs);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
private:
@ -678,7 +693,7 @@ class RemoveLabels : public LogicalOperator {
RemoveLabels(const std::shared_ptr<LogicalOperator> &input,
Symbol input_symbol,
const std::vector<GraphDbTypes::Label> &labels);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
private:
@ -725,7 +740,7 @@ class ExpandUniquenessFilter : public LogicalOperator {
ExpandUniquenessFilter(const std::shared_ptr<LogicalOperator> &input,
Symbol expand_symbol,
const std::vector<Symbol> &previous_symbols);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
private:
@ -777,7 +792,7 @@ class Accumulate : public LogicalOperator {
public:
Accumulate(const std::shared_ptr<LogicalOperator> &input,
const std::vector<Symbol> &symbols, bool advance_command = false);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
const auto &symbols() const { return symbols_; };
@ -837,7 +852,7 @@ class Aggregate : public LogicalOperator {
const std::vector<Element> &aggregations,
const std::vector<Expression *> &group_by,
const std::vector<Symbol> &remember);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
const auto &aggregations() const { return aggregations_; }
@ -944,7 +959,7 @@ class Aggregate : public LogicalOperator {
class Skip : public LogicalOperator {
public:
Skip(const std::shared_ptr<LogicalOperator> &input, Expression *expression);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
std::vector<Symbol> OutputSymbols(const SymbolTable &) override;
@ -987,7 +1002,7 @@ class Skip : public LogicalOperator {
class Limit : public LogicalOperator {
public:
Limit(const std::shared_ptr<LogicalOperator> &input, Expression *expression);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
std::vector<Symbol> OutputSymbols(const SymbolTable &) override;
@ -1028,7 +1043,7 @@ class OrderBy : public LogicalOperator {
OrderBy(const std::shared_ptr<LogicalOperator> &input,
const std::vector<std::pair<Ordering, Expression *>> &order_by,
const std::vector<Symbol> &output_symbols);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
std::vector<Symbol> OutputSymbols(const SymbolTable &) override;
@ -1102,7 +1117,7 @@ class Merge : public LogicalOperator {
Merge(const std::shared_ptr<LogicalOperator> input,
const std::shared_ptr<LogicalOperator> merge_match,
const std::shared_ptr<LogicalOperator> merge_create);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
auto input() const { return input_; }
@ -1147,7 +1162,7 @@ class Optional : public LogicalOperator {
Optional(const std::shared_ptr<LogicalOperator> &input,
const std::shared_ptr<LogicalOperator> &optional,
const std::vector<Symbol> &optional_symbols);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
auto input() const { return input_; }
@ -1188,7 +1203,7 @@ class Unwind : public LogicalOperator {
public:
Unwind(const std::shared_ptr<LogicalOperator> &input,
Expression *input_expression_, Symbol output_symbol);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
private:
@ -1225,7 +1240,7 @@ class Distinct : public LogicalOperator {
public:
Distinct(const std::shared_ptr<LogicalOperator> &input,
const std::vector<Symbol> &value_symbols);
void Accept(LogicalOperatorVisitor &visitor) override;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
std::vector<Symbol> OutputSymbols(const SymbolTable &) override;

View File

@ -103,16 +103,23 @@ auto GenCreate(Create &create, LogicalOperator *input_op,
}
// Collects symbols from identifiers found in visited AST nodes.
class UsedSymbolsCollector : public TreeVisitorBase {
class UsedSymbolsCollector : public HierarchicalTreeVisitor {
public:
UsedSymbolsCollector(const SymbolTable &symbol_table)
: symbol_table_(symbol_table) {}
using TreeVisitorBase::Visit;
void Visit(Identifier &ident) override {
using HierarchicalTreeVisitor::PreVisit;
using HierarchicalTreeVisitor::PostVisit;
using typename HierarchicalTreeVisitor::ReturnType;
using HierarchicalTreeVisitor::Visit;
ReturnType Visit(Identifier &ident) override {
symbols_.insert(symbol_table_.at(ident));
return true;
}
ReturnType Visit(PrimitiveLiteral &) override { return true; }
std::unordered_set<Symbol> symbols_;
const SymbolTable &symbol_table_;
};
@ -367,7 +374,7 @@ auto GenMatches(std::vector<Match *> &matches, LogicalOperator *input_op,
//
// In addition to the above, we collect information on used symbols,
// aggregations and expressions used for group by.
class ReturnBodyContext : public TreeVisitorBase {
class ReturnBodyContext : public HierarchicalTreeVisitor {
public:
ReturnBodyContext(const ReturnBody &body, SymbolTable &symbol_table,
const std::unordered_set<Symbol> &bound_symbols,
@ -408,17 +415,21 @@ class ReturnBodyContext : public TreeVisitorBase {
}
}
using TreeVisitorBase::PreVisit;
using TreeVisitorBase::Visit;
using TreeVisitorBase::PostVisit;
using HierarchicalTreeVisitor::PreVisit;
using HierarchicalTreeVisitor::Visit;
using HierarchicalTreeVisitor::PostVisit;
void Visit(PrimitiveLiteral &) override {
bool Visit(PrimitiveLiteral &) override {
has_aggregation_.emplace_back(false);
return true;
}
void Visit(ListLiteral &) override { has_aggregation_.emplace_back(false); }
bool PreVisit(ListLiteral &) override {
has_aggregation_.emplace_back(false);
return true;
}
void Visit(Identifier &ident) override {
bool Visit(Identifier &ident) override {
const auto &symbol = symbol_table_.at(ident);
if (std::find(output_symbols_.begin(), output_symbols_.end(), symbol) ==
output_symbols_.end()) {
@ -427,10 +438,11 @@ class ReturnBodyContext : public TreeVisitorBase {
used_symbols_.insert(symbol);
}
has_aggregation_.emplace_back(false);
return true;
}
#define VISIT_BINARY_OPERATOR(BinaryOperator) \
void PostVisit(BinaryOperator &op) override { \
bool PostVisit(BinaryOperator &op) override { \
/* has_aggregation_ stack is reversed, last result is from the 2nd \
* expression. */ \
bool aggr2 = has_aggregation_.back(); \
@ -445,6 +457,7 @@ class ReturnBodyContext : public TreeVisitorBase {
} \
/* Propagate that this whole expression may contain an aggregation. */ \
has_aggregation_.emplace_back(has_aggr); \
return true; \
}
VISIT_BINARY_OPERATOR(OrOperator)
@ -464,7 +477,7 @@ class ReturnBodyContext : public TreeVisitorBase {
#undef VISIT_BINARY_OPERATOR
void PostVisit(Aggregation &aggr) override {
bool PostVisit(Aggregation &aggr) override {
// Aggregation contains a virtual symbol, where the result will be stored.
const auto &symbol = symbol_table_.at(aggr);
aggregations_.emplace_back(aggr.expression_, aggr.op_, symbol);
@ -477,13 +490,15 @@ class ReturnBodyContext : public TreeVisitorBase {
// Possible optimization is to skip remembering symbols inside aggregation.
// If and when implementing this, don't forget that Accumulate needs *all*
// the symbols, including those inside aggregation.
return true;
}
void PostVisit(NamedExpression &named_expr) override {
bool PostVisit(NamedExpression &named_expr) override {
if (!has_aggregation_.back()) {
group_by_.emplace_back(named_expr.expression_);
}
has_aggregation_.pop_back();
return true;
}
// Creates NamedExpression with an Identifier for each user declared symbol.

261
src/utils/visitor.hpp Normal file
View File

@ -0,0 +1,261 @@
/// @file visitor.hpp
///
/// @brief This file contains the generic implementation of visitor pattern.
///
/// There are 2 approaches to the pattern:
///
/// * classic visitor pattern using @c Accept and @c Visit methods, and
/// * hierarchical visitor which also uses @c PreVisit and @c PostVisit
/// methods.
///
/// Classic Visitor
/// ===============
///
/// Explanation on the classic visitor pattern can be found from many
/// sources, but here is the link to hopefully most easily accessible
/// information: https://en.wikipedia.org/wiki/Visitor_pattern
///
/// The idea behind the generic implementation of classic visitor pattern is to
/// allow returning any type via @c Accept and @c Visit methods. Traversing the
/// class hierarchy is relegated to the visitor classes. Therefore, visitor
/// should call @c Accept on children when visiting their parents. To implement
/// such a visitor refer to @c Visitor and @c Visitable classes.
///
/// Hierarchical Visitor
/// ====================
///
/// Unlike the classic visitor, the intent of this design is to allow the
/// visited structure itself to control the traversal. This way the internal
/// children structure of classes can remain private. On the other hand,
/// visitors may want to differentiate visiting composite types from leaf types.
/// Composite types are those which contain visitable children, unlike the leaf
/// nodes. Differentiation is accomplished by providing @c PreVisit and @c
/// PostVisit methods, which should be called inside @c Accept of composite
/// types. Regular @c Visit is only called inside @c Accept of leaf types.
/// To implement such a visitor refer to @c CompositeVisitor, @c LeafVisitor and
/// @c Visitable classes.
///
/// Implementation of hierarchical visiting is modelled after:
/// http://wiki.c2.com/?HierarchicalVisitorPattern
#pragma once
namespace utils {
// Don't use anonymous namespace, because each translation unit will then get a
// unique type. This may cause errors if one wants to check the type.
namespace detail {
template <typename R, class... T>
class VisitorBase;
template <typename R, class Head, class... Tail>
class VisitorBase<R, Head, Tail...> : public VisitorBase<R, Tail...> {
public:
using typename VisitorBase<R, Tail...>::ReturnType;
using VisitorBase<R, Tail...>::Visit;
virtual ReturnType Visit(Head &) = 0;
};
template <typename R, class T>
class VisitorBase<R, T> {
public:
/// @brief ReturnType of the @c Visit method.
using ReturnType = R;
virtual ~VisitorBase() = default;
/// @brief Visit an instance of @c T.
virtual ReturnType Visit(T &) = 0;
};
template <class... T>
class CompositeVisitorBase;
template <class Head, class... Tail>
class CompositeVisitorBase<Head, Tail...>
: public CompositeVisitorBase<Tail...> {
public:
virtual bool PreVisit(Head &) { return true; }
virtual bool PostVisit(Head &) { return true; }
using CompositeVisitorBase<Tail...>::PreVisit;
using CompositeVisitorBase<Tail...>::PostVisit;
};
template <class T>
class CompositeVisitorBase<T> {
public:
/// @brief Start visiting an instance of *composite* type @c TVisitable.
///
/// This function should be used to control whether the visitor should be sent
/// further down the tree of classes. It is only called at the start of
/// @c Accept method of a composite type. The default implementation returns
/// true, which means that the visiting should continue.
///
/// @return bool indicating whether to continue visiting.
virtual bool PreVisit(T &) { return true; }
/// @brief Finish visiting an instance of *composite* type @c TVisitable.
///
/// This function should be used to control whether the visitor should be sent
/// to the siblings of currently visited instance. It is called at the end of
/// @c Accept method in a composite type. The default implementation returns
/// true, which means that visiting should continue.
///
/// @return bool indicating whether to continue visiting.
virtual bool PostVisit(T &) { return true; }
};
} // namespace detail
/// @brief Inherit from this class if you want to visit TVisitable types.
///
/// This visitor is the standard implementation of visitor pattern, where the
/// traversal should be done in the visitor implementation itself. This is
/// different from @c CompositeVisitor, where the traversal is handled by
/// visited classes. Therefore, this visitor contains only the @c Visit method,
/// which has a generic @c ReturnType.
///
/// Example usage:
/// @code
/// // Typedef for convenience or to establish a base class of visitors.
/// typedef Visitor<TypedValue, Identifier, AddOp> ExpressionVisitorBase;
/// class ExpressionVisitor : public ExpressionVisitorBase {
/// public:
/// using ExpressionVisitorBase::Visit;
///
/// TypedValue Visit(Identifier &ident) override {
/// // Visiting Identifier returns the value of it from execution frame.
/// return frame_[ident];
/// }
/// TypedValue Visit(AddOp &add_op) override {
/// // Visiting '+' sums the evaluation of both sides.
/// auto res1 = add_op.expression1_->Accept(*this);
/// auto res2 = add_op.expression2_->Accept(*this);
/// return res1 + res2;
/// }
/// };
/// @endcode
///
/// @sa Visitable
/// @sa CompositeVisitor
/// @sa detail::VisitorBase<R, T>::Visit
template <typename TReturn, class... TVisitable>
class Visitor : public detail::VisitorBase<TReturn, TVisitable...> {
public:
using typename detail::VisitorBase<TReturn, TVisitable...>::ReturnType;
using detail::VisitorBase<TReturn, TVisitable...>::Visit;
};
/// @brief Inherit from this class if you want to visit *leaf* TVisitable types.
///
/// This visitor is meant for hierarchical visiting of classes, where the
/// traversal is done by the visited classes themselves. It should be paired
/// with @c CompositeVisitor.
///
/// The @c Visit method should return true, if the visitor wishes to continue
/// traversing the sibling leaf classes.
template <class... TVisitable>
using LeafVisitor = Visitor<bool, TVisitable...>;
/// @brief Inherit from this class if you want to visit *composite* TVisitable
/// types.
///
/// This visitor is meant for hierarchical visiting of classes, where the
/// traversal is done by the visited classes themselves. Therefore, this visitor
/// contains @c PreVisit and @c PostVisit methods which are only called when
/// entering and leaving *composite* classes. It should be paired with @c
/// LeafVisitor. The standard @c Visit method is called only on *leaf* classes,
/// which do not have any visitable children in them. If you wish to use the
/// regular visitor pattern, refer to @c Visitor.
///
/// Example usage:
/// @code
///
/// class ExpressionVisitor
/// : public CompositeVisitor<AddOp>, // AddOp is a composite type
/// public LeafVisitor<Identifier> { // Identifier is a leaf type
/// public:
/// using CompositeVisitor::PreVisit;
/// using CompositeVisitor::PostVisit;
/// using LeafVisitor::Visit;
///
/// bool PreVisit(AddOp &add_op) override {
/// // Custom implementation for *composite* AddOp expression.
/// }
///
/// void Visit(Identifier &identifier) override {
/// // Custom implementation for *leaf* Identifier expression.
/// }
/// };
/// @endcode
///
/// @sa Visitable
/// @sa LeafVisitor
/// @sa Visitor
/// @sa detail::CompositeVisitorBase<T>::PreVisit
/// @sa detail::CompositeVisitorBase<T>::PostVisit
template <class... TVisitable>
class CompositeVisitor : public detail::CompositeVisitorBase<TVisitable...> {
public:
using detail::CompositeVisitorBase<TVisitable...>::PreVisit;
using detail::CompositeVisitorBase<TVisitable...>::PostVisit;
};
/// @brief Inherit from this class to allow visiting from TVisitor class.
///
/// Example usage with @c CompositeVisitor:
/// @code
/// class Expression : public Visitable<ExpressionVisitor> { ... };
///
/// class Identifier : public ExpressionVisitor {
/// public:
/// // Use default Accept implementation, since this is a *leaf* type.
/// DEFVISITABLE(ExpressionVisitor)
/// ....
/// };
///
/// class AddOp : public Expression {
/// public:
/// // Implement custom Accept, since this is a *composite* type.
/// bool Accept(ExpressionVisitor &visitor) override {
/// if (visitor.PreVisit(*this)) {
/// // Send visitor to children. Accept returns bool, which when false
/// // should stop the traversal to siblings.
/// expression1_->Accept(*this) && expression2_->Accept(*this);
/// }
/// return visitor.PostVisit(*this);
/// }
/// ...
///
/// private:
/// Expression *expression1_;
/// Expression *expression1_;
/// ...
/// };
/// @endcode
///
/// @sa DEFVISITABLE
/// @sa Visitor
/// @sa LeafVisitor
/// @sa CompositeVisitor
template <class TVisitor>
class Visitable {
public:
virtual ~Visitable() = default;
/// @brief Accept the @c TVisitor instance and call its @c Visit method.
virtual typename TVisitor::ReturnType Accept(TVisitor &) = 0;
/// Default implementation for @c utils::Visitable::Accept, which works for
/// visitors of @c TVisitor type. This should be used to implement regular
/// @c utils::Visitor, as well as for *leaf* types when accepting
/// @c utils::CompositeVisitor.
///
/// @sa utils::Visitable
#define DEFVISITABLE(TVisitor) \
TVisitor::ReturnType Accept(TVisitor &visitor) override { \
return visitor.Visit(*this); \
}
};
} // namespace utils

View File

@ -1,49 +0,0 @@
/// @file visitable.hpp
#pragma once
namespace utils {
/// Inherit from this class to allow visiting from TVisitor class.
/// Example usage:
///
/// class Expression : public Visitable<ExpressionVisitor> {
/// };
/// class Identifier : public ExpressionVisitor {
/// public:
/// DEFVISITABLE(ExpressionVisitor) // Use default Accept implementation
/// ....
/// };
/// class Literal : public Expression {
/// public:
/// void Accept(ExpressionVisitor &visitor) override {
/// // Implement custom Accept.
/// if (visitor.PreVisit(*this)) {
/// visitor.Visit(*this);
/// ... // e.g. send visitor to children
/// visitor.PostVisit(*this);
/// }
/// }
/// };
///
/// @sa DEFVISITABLE
/// @sa Visitor
template <class TVisitor>
class Visitable {
public:
virtual ~Visitable() = default;
virtual void Accept(TVisitor &) = 0;
/// Default implementation for @c Accept, which works for visitors of
/// @c TVisitor type.
/// @sa utils::Visitable
#define DEFVISITABLE(TVisitor) \
void Accept(TVisitor &visitor) override { \
if (visitor.PreVisit(*this)) { \
visitor.Visit(*this); \
visitor.PostVisit(*this); \
} \
}
};
} // namespace utils

View File

@ -1,71 +0,0 @@
#pragma once
namespace utils {
// Don't use anonymous namespace, because each translation unit will then get a
// unique type. This may cause errors if one wants to check the type.
namespace detail {
template <typename T>
class VisitorBase {
public:
virtual ~VisitorBase() = default;
virtual bool PreVisit(T &) { return true; }
virtual void Visit(T &) {}
virtual void PostVisit(T &) {}
};
template <typename... T>
class RecursiveVisitorBase;
template <typename Head, typename... Tail>
class RecursiveVisitorBase<Head, Tail...>
: public VisitorBase<Head>, public RecursiveVisitorBase<Tail...> {
public:
using VisitorBase<Head>::PreVisit;
using VisitorBase<Head>::Visit;
using VisitorBase<Head>::PostVisit;
using RecursiveVisitorBase<Tail...>::PreVisit;
using RecursiveVisitorBase<Tail...>::Visit;
using RecursiveVisitorBase<Tail...>::PostVisit;
};
template <typename T>
class RecursiveVisitorBase<T> : public VisitorBase<T> {
public:
using VisitorBase<T>::PreVisit;
using VisitorBase<T>::Visit;
using VisitorBase<T>::PostVisit;
};
} // namespace detail
/// @brief Inherit from this class if you want to visit TVisitable types.
///
/// Example usage:
///
/// // Typedef for convenience or to establish a base class of visitors.
/// typedef Visitor<Identifier, Literal> ExpressionVisitorBase;
/// class ExpressionVisitor : public ExpressionVisitorBase {
/// public:
/// using ExpressionVisitorBase::PreVisit;
/// using ExpressionVisitorBase::Visit;
/// using ExpressionVisitorBase::PostVisit;
///
/// void Visit(Identifier &identifier) override {
/// // Custom implementation of visiting Identifier.
/// }
/// };
///
/// @sa Visitable
template <typename... TVisitable>
class Visitor : public detail::RecursiveVisitorBase<TVisitable...> {
public:
using detail::RecursiveVisitorBase<TVisitable...>::PreVisit;
using detail::RecursiveVisitorBase<TVisitable...>::Visit;
using detail::RecursiveVisitorBase<TVisitable...>::PostVisit;
};
} // namespace utils

View File

@ -29,37 +29,41 @@ class BaseOpChecker {
virtual void CheckOp(LogicalOperator &, const SymbolTable &) = 0;
};
class PlanChecker : public LogicalOperatorVisitor {
class PlanChecker : public HierarchicalLogicalOperatorVisitor {
public:
using LogicalOperatorVisitor::PreVisit;
using LogicalOperatorVisitor::Visit;
using LogicalOperatorVisitor::PostVisit;
using HierarchicalLogicalOperatorVisitor::PreVisit;
using HierarchicalLogicalOperatorVisitor::Visit;
using HierarchicalLogicalOperatorVisitor::PostVisit;
PlanChecker(const std::list<BaseOpChecker *> &checkers,
const SymbolTable &symbol_table)
: checkers_(checkers), symbol_table_(symbol_table) {}
void Visit(CreateNode &op) override { CheckOp(op); }
void Visit(CreateExpand &op) override { CheckOp(op); }
void Visit(Delete &op) override { CheckOp(op); }
void Visit(ScanAll &op) override { CheckOp(op); }
void Visit(Expand &op) override { CheckOp(op); }
void Visit(Filter &op) override { CheckOp(op); }
void Visit(Produce &op) override { CheckOp(op); }
void Visit(SetProperty &op) override { CheckOp(op); }
void Visit(SetProperties &op) override { CheckOp(op); }
void Visit(SetLabels &op) override { CheckOp(op); }
void Visit(RemoveProperty &op) override { CheckOp(op); }
void Visit(RemoveLabels &op) override { CheckOp(op); }
void Visit(ExpandUniquenessFilter<VertexAccessor> &op) override {
CheckOp(op);
#define PRE_VISIT(TOp) \
bool PreVisit(TOp &op) override { \
CheckOp(op); \
return true; \
}
void Visit(ExpandUniquenessFilter<EdgeAccessor> &op) override { CheckOp(op); }
void Visit(Accumulate &op) override { CheckOp(op); }
void Visit(Aggregate &op) override { CheckOp(op); }
void Visit(Skip &op) override { CheckOp(op); }
void Visit(Limit &op) override { CheckOp(op); }
void Visit(OrderBy &op) override { CheckOp(op); }
PRE_VISIT(CreateNode);
PRE_VISIT(CreateExpand);
PRE_VISIT(Delete);
PRE_VISIT(ScanAll);
PRE_VISIT(Expand);
PRE_VISIT(Filter);
PRE_VISIT(Produce);
PRE_VISIT(SetProperty);
PRE_VISIT(SetProperties);
PRE_VISIT(SetLabels);
PRE_VISIT(RemoveProperty);
PRE_VISIT(RemoveLabels);
PRE_VISIT(ExpandUniquenessFilter<VertexAccessor>);
PRE_VISIT(ExpandUniquenessFilter<EdgeAccessor>);
PRE_VISIT(Accumulate);
PRE_VISIT(Aggregate);
PRE_VISIT(Skip);
PRE_VISIT(Limit);
PRE_VISIT(OrderBy);
bool PreVisit(Merge &op) override {
CheckOp(op);
op.input()->Accept(*this);
@ -70,8 +74,14 @@ class PlanChecker : public LogicalOperatorVisitor {
op.input()->Accept(*this);
return false;
}
void Visit(Unwind &op) override { CheckOp(op); }
void Visit(Distinct &op) override { CheckOp(op); }
PRE_VISIT(Unwind);
PRE_VISIT(Distinct);
bool Visit(Once &op) override {
// Ignore checking Once, it is implicitly at the end.
return true;
}
#undef PRE_VISIT
std::list<BaseOpChecker *> checkers_;