Remove the Coalesce function

Reviewers: teon.banek, mtomic

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1676
This commit is contained in:
Lovro Lugovic 2018-10-25 14:51:17 +02:00
parent d87664d5db
commit 06ae2ffecc
12 changed files with 136 additions and 76 deletions

View File

@ -882,6 +882,46 @@ cpp<#
cpp<#)
(:serialize :capnp))
(lcp:define-class coalesce (expression)
((expressions "std::vector<Expression *>"
:scope :public
:capnp-type "List(Tree)"
:capnp-save (save-ast-vector "Expression *")
:capnp-load (load-ast-vector "Expression *")
:documentation "A list of expressions to evaluate. None of the expressions should be nullptr."))
(:public
#>cpp
Coalesce() = default;
DEFVISITABLE(TreeVisitor<TypedValue>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
for (auto *expr : expressions_) {
if (!expr->Accept(visitor)) break;
}
}
return visitor.PostVisit(*this);
}
Coalesce *Clone(AstStorage &storage) const override {
std::vector<Expression *> expressions;
expressions.reserve(expressions_.size());
for (const auto &expr : expressions_) {
expressions.emplace_back(expr->Clone(storage));
}
return storage.Create<Coalesce>(std::move(expressions));
}
cpp<#
)
(:private
#>cpp
Coalesce(int uid, const std::vector<Expression *> &expressions)
: Expression(uid), expressions_(expressions) {}
friend class AstStorage;
cpp<#)
(:serialize :capnp))
(lcp:define-class extract (expression)
((identifier "Identifier *" :initval "nullptr" :scope :public
:capnp-type "Tree" :capnp-init nil

View File

@ -15,6 +15,7 @@ class LabelsTest;
class Aggregation;
class Function;
class Reduce;
class Coalesce;
class Extract;
class All;
class Single;
@ -73,9 +74,10 @@ using TreeCompositeVisitor = ::utils::CompositeVisitor<
LessEqualOperator, GreaterEqualOperator, InListOperator, SubscriptOperator,
ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator,
IsNullOperator, ListLiteral, MapLiteral, PropertyLookup, LabelsTest,
Aggregation, Function, Reduce, Extract, All, Single, Create, Match, Return,
With, Pattern, NodeAtom, EdgeAtom, Delete, Where, SetProperty,
SetProperties, SetLabels, RemoveProperty, RemoveLabels, Merge, Unwind>;
Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Create,
Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, Where,
SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, Merge,
Unwind>;
using TreeLeafVisitor =
::utils::LeafVisitor<Identifier, PrimitiveLiteral, ParameterLookup,
@ -100,9 +102,9 @@ using TreeVisitor = ::utils::Visitor<
InListOperator, SubscriptOperator, ListSlicingOperator, IfOperator,
UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral,
MapLiteral, PropertyLookup, LabelsTest, Aggregation, Function, Reduce,
Extract, All, Single, ParameterLookup, Create, Match, Return, With, Pattern,
NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels,
RemoveProperty, RemoveLabels, Merge, Unwind, Identifier, PrimitiveLiteral,
IndexQuery, AuthQuery, StreamQuery>;
Coalesce, Extract, All, Single, ParameterLookup, Create, Match, Return,
With, Pattern, NodeAtom, EdgeAtom, Delete, Where, SetProperty,
SetProperties, SetLabels, RemoveProperty, RemoveLabels, Merge, Unwind,
Identifier, PrimitiveLiteral, IndexQuery, AuthQuery, StreamQuery>;
} // namespace query

View File

@ -1318,6 +1318,13 @@ antlrcpp::Any CypherMainVisitor::visitAtom(MemgraphCypher::AtomContext *ctx) {
return static_cast<Expression *>(storage_->Create<Identifier>(variable));
} else if (ctx->functionInvocation()) {
return static_cast<Expression *>(ctx->functionInvocation()->accept(this));
} else if (ctx->COALESCE()) {
std::vector<Expression *> exprs;
for (auto *expr_context : ctx->expression()) {
exprs.emplace_back(expr_context->accept(this).as<Expression *>());
}
return static_cast<Expression *>(
storage_->Create<Coalesce>(std::move(exprs)));
} else if (ctx->COUNT()) {
// Here we handle COUNT(*). COUNT(expression) is handled in
// visitFunctionInvocation with other aggregations. This is visible in

View File

@ -194,6 +194,7 @@ atom : literal
| ( FILTER '(' filterExpression ')' )
| ( EXTRACT '(' extractExpression ')' )
| ( REDUCE '(' reduceExpression ')' )
| ( COALESCE '(' expression ( ',' expression )* ')' )
| ( ALL '(' filterExpression ')' )
| ( ANY '(' filterExpression ')' )
| ( NONE '(' filterExpression ')' )

View File

@ -77,6 +77,7 @@ ASCENDING : A S C E N D I N G ;
BFS : B F S ;
BY : B Y ;
CASE : C A S E ;
COALESCE : C O A L E S C E ;
CONTAINS : C O N T A I N S ;
COUNT : C O U N T ;
CREATE : C R E A T E ;

View File

@ -79,18 +79,18 @@ class Trie {
const int kBitsetSize = 65536;
const trie::Trie kKeywords = {
"union", "all", "optional", "match", "unwind", "as",
"merge", "on", "create", "set", "detach", "delete",
"remove", "with", "distinct", "return", "order", "by",
"skip", "limit", "ascending", "asc", "descending", "desc",
"where", "or", "xor", "and", "not", "in",
"starts", "ends", "contains", "is", "null", "case",
"when", "then", "else", "end", "count", "filter",
"extract", "any", "none", "single", "true", "false",
"reduce", "user", "password", "alter", "drop", "stream",
"streams", "load", "data", "kafka", "transform", "batch",
"interval", "show", "start", "stop", "size", "topic",
"test", "unique", "explain"};
"union", "all", "optional", "match", "unwind", "as",
"merge", "on", "create", "set", "detach", "delete",
"remove", "with", "distinct", "return", "order", "by",
"skip", "limit", "ascending", "asc", "descending", "desc",
"where", "or", "xor", "and", "not", "in",
"starts", "ends", "contains", "is", "null", "case",
"when", "then", "else", "end", "count", "filter",
"extract", "any", "none", "single", "true", "false",
"reduce", "coalesce", "user", "password", "alter", "drop",
"stream", "streams", "load", "data", "kafka", "transform",
"batch", "interval", "show", "start", "stop", "size",
"topic", "test", "unique", "explain"};
// Unicode codepoints that are allowed at the start of the unescaped name.
const std::bitset<kBitsetSize> kUnescapedNameAllowedStarts(std::string(

View File

@ -32,21 +32,6 @@ namespace {
// TODO: Implement degrees, haversin, radians
// TODO: Implement spatial functions
TypedValue Coalesce(TypedValue *args, int64_t nargs, const EvaluationContext &,
database::GraphDbAccessor *) {
// TODO: Perhaps this function should be done by the evaluator itself, so as
// to avoid evaluating all the arguments.
if (nargs == 0) {
throw QueryRuntimeException("'coalesce' requires at least one argument.");
}
for (int64_t i = 0; i < nargs; ++i) {
if (args[i].type() != TypedValue::Type::Null) {
return args[i];
}
}
return TypedValue::Null;
}
TypedValue EndNode(TypedValue *args, int64_t nargs, const EvaluationContext &,
database::GraphDbAccessor *) {
if (nargs != 1) {
@ -887,7 +872,6 @@ std::function<TypedValue(TypedValue *, int64_t, const EvaluationContext &,
database::GraphDbAccessor *)>
NameToFunction(const std::string &function_name) {
// Scalar functions
if (function_name == "COALESCE") return Coalesce;
if (function_name == "DEGREE") return Degree;
if (function_name == "INDEGREE") return InDegree;
if (function_name == "OUTDEGREE") return OutDegree;

View File

@ -13,7 +13,6 @@ namespace {
const char kStartsWith[] = "STARTSWITH";
const char kEndsWith[] = "ENDSWITH";
const char kContains[] = "CONTAINS";
const char kCoalesce[] = "COALESCE";
} // namespace
/// Return the function implementation with the given name.

View File

@ -367,25 +367,24 @@ class ExpressionEvaluator : public TreeVisitor<TypedValue> {
return value;
}
TypedValue Visit(Function &function) override {
// Handle COALESCE specially -- evaluate the arguments in order until one of
// them produces a non-null value.
if (function.function_name_ == kCoalesce) {
if (function.arguments_.size() == 0) {
throw QueryRuntimeException(
"'coalesce' requires at least one argument.");
}
TypedValue Visit(Coalesce &coalesce) override {
auto &exprs = coalesce.expressions_;
for (int64_t i = 0; i < function.arguments_.size(); ++i) {
TypedValue val = function.arguments_[i]->Accept(*this);
if (val.type() != TypedValue::Type::Null) {
return val;
}
}
return TypedValue::Null;
if (exprs.size() == 0) {
throw QueryRuntimeException("'coalesce' requires at least one argument.");
}
for (int64_t i = 0; i < exprs.size(); ++i) {
TypedValue val = exprs[i]->Accept(*this);
if (!val.IsNull()) {
return val;
}
}
return TypedValue::Null;
}
TypedValue Visit(Function &function) override {
// Stack allocate evaluated arguments when there's a small number of them.
if (function.arguments_.size() <= 8) {
TypedValue arguments[8];

View File

@ -182,7 +182,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
public:
bool PostVisit(ListLiteral &list_literal) override {
DCHECK(list_literal.elements_.size() <= has_aggregation_.size())
<< "Expected has_aggregation_ flags as much as there are list "
<< "Expected as many has_aggregation_ flags as there are list"
"elements.";
PostVisitCollectionLiteral(list_literal, [](auto it) { return *it; });
return true;
@ -241,6 +241,19 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
return true;
}
bool PostVisit(Coalesce &coalesce) override {
CHECK(has_aggregation_.size() >= coalesce.expressions_.size())
<< "Expected >= " << has_aggregation_.size()
<< "has_aggregation_ flags for COALESCE arguments";
bool has_aggr = false;
for (int i = 0; i < coalesce.expressions_.size(); ++i) {
has_aggr = has_aggr || has_aggregation_.back();
has_aggregation_.pop_back();
}
has_aggregation_.emplace_back(has_aggr);
return true;
}
bool PostVisit(Extract &extract) override {
// Remove the symbol bound by extract, because we are only interested
// in free (unbound) symbols.
@ -310,7 +323,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
bool PostVisit(Function &function) override {
DCHECK(function.arguments_.size() <= has_aggregation_.size())
<< "Expected has_aggregation_ flags as much as there are "
<< "Expected as many has_aggregation_ flags as there are"
"function arguments.";
bool has_aggr = false;
auto it = has_aggregation_.end();

View File

@ -602,6 +602,8 @@ auto GetMerge(AstStorage &storage, Pattern *pattern, OnMatch on_match,
storage.Create<query::Reduce>( \
storage.Create<query::Identifier>(accumulator), initializer, \
storage.Create<query::Identifier>(variable), list, expr)
#define COALESCE(...) \
storage.Create<query::Coalesce>(std::vector<query::Expression *>{__VA_ARGS__})
#define EXTRACT(variable, list, expr) \
storage.Create<query::Extract>(storage.Create<query::Identifier>(variable), \
list, expr)

View File

@ -812,6 +812,40 @@ TEST_F(ExpressionEvaluatorTest, FunctionExtractExceptions) {
EXPECT_THROW(extract->Accept(eval), QueryRuntimeException);
}
TEST_F(ExpressionEvaluatorTest, Coalesce) {
// coalesce()
EXPECT_THROW(COALESCE()->Accept(eval), QueryRuntimeException);
// coalesce(null, null)
EXPECT_TRUE(COALESCE(LITERAL(TypedValue::Null), LITERAL(TypedValue::Null))
->Accept(eval)
.IsNull());
// coalesce(null, 2, 3)
EXPECT_EQ(COALESCE(LITERAL(TypedValue::Null), LITERAL(2), LITERAL(3))
->Accept(eval)
.ValueInt(),
2);
// coalesce(null, 2, assert(false), 3)
EXPECT_EQ(COALESCE(LITERAL(TypedValue::Null), LITERAL(2),
FN("ASSERT", LITERAL(false)), LITERAL(3))
->Accept(eval)
.ValueInt(),
2);
// (null, assert(false))
EXPECT_THROW(COALESCE(LITERAL(TypedValue::Null), FN("ASSERT", LITERAL(false)))
->Accept(eval),
QueryRuntimeException);
// coalesce([null, null])
EXPECT_FALSE(COALESCE(LITERAL(TypedValue(std::vector<TypedValue>{
TypedValue::Null, TypedValue::Null})))
->Accept(eval)
.IsNull());
}
class ExpressionEvaluatorPropertyLookup : public ExpressionEvaluatorTest {
protected:
std::pair<std::string, storage::Property> prop_age =
@ -891,28 +925,6 @@ class FunctionTest : public ExpressionEvaluatorTest {
}
};
TEST_F(FunctionTest, Coalesce) {
ASSERT_THROW(EvaluateFunction("COALESCE", {}), QueryRuntimeException);
ASSERT_TRUE(EvaluateFunction("COALESCE", {TypedValue::Null, TypedValue::Null})
.IsNull());
ASSERT_EQ(EvaluateFunction("COALESCE", {TypedValue::Null, 2, 3}).ValueInt(),
2);
// (null, 2, assert(false), 3)
auto expressions1 = ExpressionsFromTypedValues({TypedValue::Null, 2, 3});
expressions1.insert(
expressions1.begin() + 2,
storage.Create<Function>("ASSERT", ExpressionsFromTypedValues({false})));
ASSERT_EQ(EvaluateFunctionWithExprs("COALESCE", expressions1).ValueInt(), 2);
// (null, assert(false))
auto expressions2 = ExpressionsFromTypedValues({TypedValue::Null});
expressions2.push_back(
storage.Create<Function>("ASSERT", ExpressionsFromTypedValues({false})));
ASSERT_THROW(EvaluateFunctionWithExprs("COALESCE", expressions2),
QueryRuntimeException);
}
TEST_F(FunctionTest, EndNode) {
ASSERT_THROW(EvaluateFunction("ENDNODE", {}), QueryRuntimeException);
ASSERT_TRUE(EvaluateFunction("ENDNODE", {TypedValue::Null}).IsNull());