From 6a547839e183e49bfb95e987e682e9dcdbe2abe9 Mon Sep 17 00:00:00 2001
From: Mislav Bradac <mislav.bradac@memgraph.io>
Date: Mon, 27 Mar 2017 10:04:37 +0200
Subject: [PATCH] Where and delete from antlr to highlevel ast

Reviewers: teon.banek

Reviewed By: teon.banek

Differential Revision: https://phabricator.memgraph.io/D181
---
 src/query/frontend/ast/ast.hpp                | 36 +++++++++++++-
 src/query/frontend/ast/ast_visitor.hpp        |  4 +-
 .../frontend/ast/cypher_main_visitor.cpp      | 27 ++++++++++-
 .../frontend/ast/cypher_main_visitor.hpp      | 11 +++++
 tests/unit/cypher_main_visitor.cpp            | 47 +++++++++++++++++++
 5 files changed, 122 insertions(+), 3 deletions(-)

diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp
index d02fd399f..c02d6cfa1 100644
--- a/src/query/frontend/ast/ast.hpp
+++ b/src/query/frontend/ast/ast.hpp
@@ -496,7 +496,6 @@ class Match : public Clause {
   friend class AstTreeStorage;
 
  public:
-  std::vector<Pattern *> patterns_;
   void Accept(TreeVisitorBase &visitor) override {
     visitor.Visit(*this);
     for (auto &pattern : patterns_) {
@@ -504,6 +503,8 @@ class Match : public Clause {
     }
     visitor.PostVisit(*this);
   }
+  std::vector<Pattern *> patterns_;
+  Where *where_ = nullptr;
 
  protected:
   Match(int uid) : Clause(uid) {}
@@ -526,6 +527,39 @@ class Return : public Clause {
   Return(int uid) : Clause(uid) {}
 };
 
+class Delete : public Clause {
+  friend class AstTreeStorage;
+
+ public:
+  void Accept(TreeVisitorBase &visitor) override {
+    visitor.Visit(*this);
+    for (auto &expr : expressions_) {
+      expr->Accept(visitor);
+    }
+    visitor.PostVisit(*this);
+  }
+  std::vector<Expression *> expressions_;
+  bool detach_ = false;
+
+ protected:
+  Delete(int uid) : Clause(uid) {}
+};
+
+class Where : public Tree {
+  friend class AstTreeStorage;
+
+ public:
+  void Accept(TreeVisitorBase &visitor) override {
+    visitor.Visit(*this);
+    expression_->Accept(visitor);
+    visitor.PostVisit(*this);
+  }
+  Expression *expression_ = nullptr;
+
+ protected:
+  Where(int uid) : Tree(uid) {}
+};
+
 // It would be better to call this AstTree, but we already have a class Tree,
 // which could be renamed to Node or AstTreeNode, but we also have a class
 // called NodeAtom...
diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp
index 73d3a855c..a5f5ef2fa 100644
--- a/src/query/frontend/ast/ast_visitor.hpp
+++ b/src/query/frontend/ast/ast_visitor.hpp
@@ -33,6 +33,8 @@ class LessOperator;
 class GreaterOperator;
 class LessEqualOperator;
 class GreaterEqualOperator;
+class Delete;
+class Where;
 
 using TreeVisitorBase = ::utils::Visitor<
     Query, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator,
@@ -40,5 +42,5 @@ using TreeVisitorBase = ::utils::Visitor<
     DivisionOperator, ModOperator, NotEqualOperator, EqualOperator,
     LessOperator, GreaterOperator, LessEqualOperator, GreaterEqualOperator,
     UnaryPlusOperator, UnaryMinusOperator, Identifier, Literal, PropertyLookup,
-    Create, Match, Return, Pattern, NodeAtom, EdgeAtom>;
+    Create, Match, Return, Pattern, NodeAtom, EdgeAtom, Delete, Where>;
 }
diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp
index e811c5929..abe2fab8f 100644
--- a/src/query/frontend/ast/cypher_main_visitor.cpp
+++ b/src/query/frontend/ast/cypher_main_visitor.cpp
@@ -50,6 +50,10 @@ antlrcpp::Any CypherMainVisitor::visitClause(CypherParser::ClauseContext *ctx) {
   if (ctx->create()) {
     return static_cast<Clause *>(ctx->create()->accept(this).as<Create *>());
   }
+  if (ctx->cypherDelete()) {
+    return static_cast<Clause *>(
+        ctx->cypherDelete()->accept(this).as<Delete *>());
+  }
   // TODO: implement other clauses.
   throw NotYetImplemented();
   return 0;
@@ -58,10 +62,13 @@ antlrcpp::Any CypherMainVisitor::visitClause(CypherParser::ClauseContext *ctx) {
 antlrcpp::Any CypherMainVisitor::visitCypherMatch(
     CypherParser::CypherMatchContext *ctx) {
   auto *match = storage_.Create<Match>();
-  if (ctx->OPTIONAL() || ctx->where()) {
+  if (ctx->OPTIONAL()) {
     // TODO: implement other clauses.
     throw NotYetImplemented();
   }
+  if (ctx->where()) {
+    match->where_ = ctx->where()->accept(this);
+  }
   match->patterns_ = ctx->pattern()->accept(this).as<std::vector<Pattern *>>();
   return match;
 }
@@ -665,5 +672,23 @@ antlrcpp::Any CypherMainVisitor::visitBooleanLiteral(
     throw std::exception();
   }
 }
+
+antlrcpp::Any CypherMainVisitor::visitCypherDelete(
+    CypherParser::CypherDeleteContext *ctx) {
+  auto *del = storage_.Create<Delete>();
+  if (ctx->DETACH()) {
+    del->detach_ = true;
+  }
+  for (auto *expression : ctx->expression()) {
+    del->expressions_.push_back(expression->accept(this));
+  }
+  return del;
+}
+
+antlrcpp::Any CypherMainVisitor::visitWhere(CypherParser::WhereContext *ctx) {
+  auto *where = storage_.Create<Where>();
+  where->expression_ = ctx->expression()->accept(this);
+  return where;
+}
 }
 }
diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp
index fd39d1a84..c73cc3776 100644
--- a/src/query/frontend/ast/cypher_main_visitor.hpp
+++ b/src/query/frontend/ast/cypher_main_visitor.hpp
@@ -400,6 +400,17 @@ class CypherMainVisitor : public antlropencypher::CypherBaseVisitor {
   antlrcpp::Any visitDoubleLiteral(
       CypherParser::DoubleLiteralContext *ctx) override;
 
+  /**
+   * @return Delete*
+   */
+  antlrcpp::Any visitCypherDelete(
+      CypherParser::CypherDeleteContext *ctx) override;
+
+  /**
+   * @return Where*
+   */
+  antlrcpp::Any visitWhere(CypherParser::WhereContext *ctx) override;
+
  public:
   Query *query() { return query_; }
   const static std::string kAnonPrefix;
diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp
index 138f127f7..3e5bd40a6 100644
--- a/tests/unit/cypher_main_visitor.cpp
+++ b/tests/unit/cypher_main_visitor.cpp
@@ -367,6 +367,7 @@ TEST(CypherMainVisitorTest, NodePattern) {
   ASSERT_EQ(query->clauses_.size(), 1U);
   auto *match = dynamic_cast<Match *>(query->clauses_[0]);
   ASSERT_TRUE(match);
+  ASSERT_FALSE(match->where_);
   ASSERT_EQ(match->patterns_.size(), 1U);
   ASSERT_TRUE(match->patterns_[0]);
   ASSERT_EQ(match->patterns_[0]->atoms_.size(), 1U);
@@ -396,6 +397,7 @@ TEST(CypherMainVisitorTest, NodePatternIdentifier) {
   AstGenerator ast_generator("MATCH (var)");
   auto *query = ast_generator.query_;
   auto *match = dynamic_cast<Match *>(query->clauses_[0]);
+  ASSERT_FALSE(match->where_);
   auto node = dynamic_cast<NodeAtom *>(match->patterns_[0]->atoms_[0]);
   ASSERT_TRUE(node->identifier_);
   ASSERT_EQ(node->identifier_->name_, "var");
@@ -407,6 +409,7 @@ TEST(CypherMainVisitorTest, RelationshipPatternNoDetails) {
   AstGenerator ast_generator("MATCH ()--()");
   auto *query = ast_generator.query_;
   auto *match = dynamic_cast<Match *>(query->clauses_[0]);
+  ASSERT_FALSE(match->where_);
   ASSERT_EQ(match->patterns_.size(), 1U);
   ASSERT_TRUE(match->patterns_[0]);
   ASSERT_EQ(match->patterns_[0]->atoms_.size(), 3U);
@@ -427,6 +430,7 @@ TEST(CypherMainVisitorTest, PatternPartBraces) {
   AstGenerator ast_generator("MATCH ((()--()))");
   auto *query = ast_generator.query_;
   auto *match = dynamic_cast<Match *>(query->clauses_[0]);
+  ASSERT_FALSE(match->where_);
   ASSERT_EQ(match->patterns_.size(), 1U);
   ASSERT_TRUE(match->patterns_[0]);
   ASSERT_EQ(match->patterns_[0]->atoms_.size(), 3U);
@@ -446,6 +450,7 @@ TEST(CypherMainVisitorTest, RelationshipPatternDetails) {
   AstGenerator ast_generator("MATCH ()<-[:type1|type2 {a : 5, b : 10}]-()");
   auto *query = ast_generator.query_;
   auto *match = dynamic_cast<Match *>(query->clauses_[0]);
+  ASSERT_FALSE(match->where_);
   auto *edge = dynamic_cast<EdgeAtom *>(match->patterns_[0]->atoms_[1]);
   ASSERT_EQ(edge->direction_, EdgeAtom::Direction::LEFT);
   ASSERT_THAT(
@@ -469,6 +474,7 @@ TEST(CypherMainVisitorTest, RelationshipPatternVariable) {
   AstGenerator ast_generator("MATCH ()-[var]->()");
   auto *query = ast_generator.query_;
   auto *match = dynamic_cast<Match *>(query->clauses_[0]);
+  ASSERT_FALSE(match->where_);
   auto *edge = dynamic_cast<EdgeAtom *>(match->patterns_[0]->atoms_[1]);
   ASSERT_EQ(edge->direction_, EdgeAtom::Direction::RIGHT);
   ASSERT_TRUE(edge->identifier_);
@@ -568,4 +574,45 @@ TEST(CypherMainVisitorTest, Create) {
   ASSERT_TRUE(node->identifier_);
   ASSERT_EQ(node->identifier_->name_, "n");
 }
+
+TEST(CypherMainVisitorTest, Delete) {
+  AstGenerator ast_generator("DELETE n, m");
+  auto *query = ast_generator.query_;
+  ASSERT_EQ(query->clauses_.size(), 1U);
+  auto *del = dynamic_cast<Delete *>(query->clauses_[0]);
+  ASSERT_TRUE(del);
+  ASSERT_FALSE(del->detach_);
+  ASSERT_EQ(del->expressions_.size(), 2U);
+  auto *identifier1 = dynamic_cast<Identifier *>(del->expressions_[0]);
+  ASSERT_TRUE(identifier1);
+  ASSERT_EQ(identifier1->name_, "n");
+  auto *identifier2 = dynamic_cast<Identifier *>(del->expressions_[1]);
+  ASSERT_TRUE(identifier2);
+  ASSERT_EQ(identifier2->name_, "m");
+}
+
+TEST(CypherMainVisitorTest, DeleteDetach) {
+  AstGenerator ast_generator("DETACH DELETE n");
+  auto *query = ast_generator.query_;
+  ASSERT_EQ(query->clauses_.size(), 1U);
+  auto *del = dynamic_cast<Delete *>(query->clauses_[0]);
+  ASSERT_TRUE(del);
+  ASSERT_TRUE(del->detach_);
+  ASSERT_EQ(del->expressions_.size(), 1U);
+  auto *identifier1 = dynamic_cast<Identifier *>(del->expressions_[0]);
+  ASSERT_TRUE(identifier1);
+  ASSERT_EQ(identifier1->name_, "n");
+}
+
+TEST(Visitor, MatchWhere) {
+  AstGenerator ast_generator("MATCH (n) WHERE n");
+  auto *query = ast_generator.query_;
+  ASSERT_EQ(query->clauses_.size(), 1U);
+  auto *match = dynamic_cast<Match *>(query->clauses_[0]);
+  ASSERT_TRUE(match);
+  ASSERT_TRUE(match->where_);
+  auto *identifier = dynamic_cast<Identifier *>(match->where_->expression_);
+  ASSERT_TRUE(identifier);
+  ASSERT_EQ(identifier->name_, "n");
+}
 }