From 1c51ce77ef3572ff6431e4c68a4cd77649102140 Mon Sep 17 00:00:00 2001
From: Mislav Bradac <mislav.bradac@memgraph.io>
Date: Wed, 26 Apr 2017 16:29:57 +0200
Subject: [PATCH] Even more awesome functions

Reviewers: florijan, teon.banek

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D319
---
 .../interpret/awesome_memgraph_functions.cpp  | 123 +++++++++++-------
 tests/unit/query_expression_evaluator.cpp     |  63 ++++++---
 2 files changed, 125 insertions(+), 61 deletions(-)

diff --git a/src/query/interpret/awesome_memgraph_functions.cpp b/src/query/interpret/awesome_memgraph_functions.cpp
index 672872160..5be9f0b56 100644
--- a/src/query/interpret/awesome_memgraph_functions.cpp
+++ b/src/query/interpret/awesome_memgraph_functions.cpp
@@ -30,7 +30,8 @@ namespace {
 // return same time. We need to store query start time somwhere.
 // TODO: Implement rest of the list functions.
 // TODO: Implement rand
-// TODO: Implement logarithmic, trigonometric, string and spatial functions
+// TODO: Implement degrees, haversin, radians
+// TODO: Implement string and spatial functions
 
 TypedValue Coalesce(const std::vector<TypedValue> &args, GraphDbAccessor &) {
   if (args.size() == 0U) {
@@ -316,54 +317,61 @@ TypedValue Abs(const std::vector<TypedValue> &args, GraphDbAccessor &) {
   }
 }
 
-TypedValue Ceil(const std::vector<TypedValue> &args, GraphDbAccessor &) {
-  if (args.size() != 1U) {
-    throw QueryRuntimeException("ceil requires one argument");
+#define WRAP_CMATH_FLOAT_FUNCTION(name, lowercased_name)                      \
+  TypedValue name(const std::vector<TypedValue> &args, GraphDbAccessor &) {   \
+    if (args.size() != 1U) {                                                  \
+      throw QueryRuntimeException(#lowercased_name " requires one argument"); \
+    }                                                                         \
+    switch (args[0].type()) {                                                 \
+      case TypedValue::Type::Null:                                            \
+        return TypedValue::Null;                                              \
+      case TypedValue::Type::Int:                                             \
+        return lowercased_name(args[0].Value<int64_t>());                     \
+      case TypedValue::Type::Double:                                          \
+        return lowercased_name(args[0].Value<double>());                      \
+      default:                                                                \
+        throw QueryRuntimeException(#lowercased_name                          \
+                                    " called with incompatible type");        \
+    }                                                                         \
   }
-  switch (args[0].type()) {
-    case TypedValue::Type::Null:
-      return TypedValue::Null;
-    case TypedValue::Type::Int:
-      return ceil(args[0].Value<int64_t>());
-    case TypedValue::Type::Double:
-      return ceil(args[0].Value<double>());
-    default:
-      throw QueryRuntimeException("ceil called with incompatible type");
-  }
-}
-
-TypedValue Floor(const std::vector<TypedValue> &args, GraphDbAccessor &) {
-  if (args.size() != 1U) {
-    throw QueryRuntimeException("floor requires one argument");
-  }
-  switch (args[0].type()) {
-    case TypedValue::Type::Null:
-      return TypedValue::Null;
-    case TypedValue::Type::Int:
-      return floor(args[0].Value<int64_t>());
-    case TypedValue::Type::Double:
-      return floor(args[0].Value<double>());
-    default:
-      throw QueryRuntimeException("floor called with incompatible type");
-  }
-}
 
+WRAP_CMATH_FLOAT_FUNCTION(Ceil, ceil)
+WRAP_CMATH_FLOAT_FUNCTION(Floor, floor)
 // We are not completely compatible with neoj4 in this function because,
 // neo4j rounds -0.5, -1.5, -2.5... to 0, -1, -2...
-TypedValue Round(const std::vector<TypedValue> &args, GraphDbAccessor &) {
-  if (args.size() != 1U) {
-    throw QueryRuntimeException("round requires one argument");
-  }
-  switch (args[0].type()) {
-    case TypedValue::Type::Null:
-      return TypedValue::Null;
-    case TypedValue::Type::Int:
-      return round(args[0].Value<int64_t>());
-    case TypedValue::Type::Double:
-      return round(args[0].Value<double>());
-    default:
-      throw QueryRuntimeException("round called with incompatible type");
+WRAP_CMATH_FLOAT_FUNCTION(Round, round)
+WRAP_CMATH_FLOAT_FUNCTION(Exp, exp)
+WRAP_CMATH_FLOAT_FUNCTION(Log, log)
+WRAP_CMATH_FLOAT_FUNCTION(Log10, log10)
+WRAP_CMATH_FLOAT_FUNCTION(Sqrt, sqrt)
+WRAP_CMATH_FLOAT_FUNCTION(Acos, acos)
+WRAP_CMATH_FLOAT_FUNCTION(Asin, asin)
+WRAP_CMATH_FLOAT_FUNCTION(Atan, atan)
+WRAP_CMATH_FLOAT_FUNCTION(Cos, cos)
+WRAP_CMATH_FLOAT_FUNCTION(Sin, sin)
+WRAP_CMATH_FLOAT_FUNCTION(Tan, tan)
+
+#undef WRAP_CMATH_FLOAT_FUNCTION
+
+TypedValue Atan2(const std::vector<TypedValue> &args, GraphDbAccessor &) {
+  if (args.size() != 2U) {
+    throw QueryRuntimeException("atan2 requires two arguments");
   }
+  if (args[0].type() == TypedValue::Type::Null) return TypedValue::Null;
+  if (args[1].type() == TypedValue::Type::Null) return TypedValue::Null;
+  auto to_double = [](const TypedValue &t) -> double {
+    switch (t.type()) {
+      case TypedValue::Type::Int:
+        return t.Value<int64_t>();
+      case TypedValue::Type::Double:
+        return t.Value<double>();
+      default:
+        throw QueryRuntimeException("atan2 called with incompatible types");
+    }
+  };
+  double y = to_double(args[0]);
+  double x = to_double(args[1]);
+  return atan2(y, x);
 }
 
 TypedValue Sign(const std::vector<TypedValue> &args, GraphDbAccessor &) {
@@ -382,6 +390,20 @@ TypedValue Sign(const std::vector<TypedValue> &args, GraphDbAccessor &) {
       throw QueryRuntimeException("sign called with incompatible type");
   }
 }
+
+TypedValue E(const std::vector<TypedValue> &args, GraphDbAccessor &) {
+  if (args.size() != 0U) {
+    throw QueryRuntimeException("e shouldn't be called with arguments");
+  }
+  return M_E;
+}
+
+TypedValue Pi(const std::vector<TypedValue> &args, GraphDbAccessor &) {
+  if (args.size() != 0U) {
+    throw QueryRuntimeException("pi shouldn't be called with arguments");
+  }
+  return M_PI;
+}
 }
 
 std::function<TypedValue(const std::vector<TypedValue> &, GraphDbAccessor &)>
@@ -404,7 +426,20 @@ NameToFunction(const std::string &function_name) {
   if (function_name == "CEIL") return Ceil;
   if (function_name == "FLOOR") return Floor;
   if (function_name == "ROUND") return Round;
+  if (function_name == "EXP") return Exp;
+  if (function_name == "LOG") return Log;
+  if (function_name == "LOG10") return Log10;
+  if (function_name == "SQRT") return Sqrt;
+  if (function_name == "ACOS") return Acos;
+  if (function_name == "ASIN") return Asin;
+  if (function_name == "ATAN") return Atan;
+  if (function_name == "ATAN2") return Atan2;
+  if (function_name == "COS") return Cos;
+  if (function_name == "SIN") return Sin;
+  if (function_name == "TAN") return Tan;
   if (function_name == "SIGN") return Sign;
+  if (function_name == "E") return E;
+  if (function_name == "PI") return Pi;
   return nullptr;
 }
 }
diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp
index 7e5b4784c..e9369b26c 100644
--- a/tests/unit/query_expression_evaluator.cpp
+++ b/tests/unit/query_expression_evaluator.cpp
@@ -1,3 +1,4 @@
+#include <cmath>
 #include <iterator>
 #include <memory>
 #include <vector>
@@ -523,26 +524,23 @@ TEST(ExpressionEvaluator, FunctionAbs) {
   ASSERT_THROW(EvaluateFunction("ABS", {true}), QueryRuntimeException);
 }
 
-TEST(ExpressionEvaluator, FunctionCeil) {
-  ASSERT_THROW(EvaluateFunction("CEIL", {}), QueryRuntimeException);
-  ASSERT_EQ(EvaluateFunction("CEIL", {TypedValue::Null}).type(),
+// Test if log works. If it does then all functions wrapped with
+// WRAP_CMATH_FLOAT_FUNCTION macro should work and are not gonna be tested for
+// correctnes..
+TEST(ExpressionEvaluator, FunctionLog) {
+  ASSERT_THROW(EvaluateFunction("LOG", {}), QueryRuntimeException);
+  ASSERT_EQ(EvaluateFunction("LOG", {TypedValue::Null}).type(),
             TypedValue::Type::Null);
-  ASSERT_EQ(EvaluateFunction("CEIL", {-2}).Value<double>(), -2);
-  ASSERT_EQ(EvaluateFunction("CEIL", {-2.5}).Value<double>(), -2);
-  ASSERT_EQ(EvaluateFunction("CEIL", {2.5}).Value<double>(), 3);
-  ASSERT_THROW(EvaluateFunction("CEIL", {true}), QueryRuntimeException);
-}
-
-TEST(ExpressionEvaluator, FunctionFloor) {
-  ASSERT_THROW(EvaluateFunction("FLOOR", {}), QueryRuntimeException);
-  ASSERT_EQ(EvaluateFunction("FLOOR", {TypedValue::Null}).type(),
-            TypedValue::Type::Null);
-  ASSERT_EQ(EvaluateFunction("FLOOR", {-2}).Value<double>(), -2);
-  ASSERT_EQ(EvaluateFunction("FLOOR", {-2.5}).Value<double>(), -3);
-  ASSERT_EQ(EvaluateFunction("FLOOR", {2.5}).Value<double>(), 2);
-  ASSERT_THROW(EvaluateFunction("FLOOR", {true}), QueryRuntimeException);
+  ASSERT_DOUBLE_EQ(EvaluateFunction("LOG", {2}).Value<double>(), log(2));
+  ASSERT_DOUBLE_EQ(EvaluateFunction("LOG", {1.5}).Value<double>(), log(1.5));
+  // Not portable, but should work on most platforms.
+  ASSERT_TRUE(std::isnan(EvaluateFunction("LOG", {-1.5}).Value<double>()));
+  ASSERT_THROW(EvaluateFunction("LOG", {true}), QueryRuntimeException);
 }
 
+// Function Round wraps round from cmath and will work if FunctionLog test
+// passes. This test is used to show behavior of round since it differs from
+// neo4j's round.
 TEST(ExpressionEvaluator, FunctionRound) {
   ASSERT_THROW(EvaluateFunction("ROUND", {}), QueryRuntimeException);
   ASSERT_EQ(EvaluateFunction("ROUND", {TypedValue::Null}).type(),
@@ -557,6 +555,27 @@ TEST(ExpressionEvaluator, FunctionRound) {
   ASSERT_THROW(EvaluateFunction("ROUND", {true}), QueryRuntimeException);
 }
 
+// Check if wrapped functions are callable (check if everything was spelled
+// correctly...). Wrapper correctnes is checked in FunctionLog test.
+TEST(ExpressionEvaluator, FunctionWrappedMathFunctions) {
+  for (auto function_name :
+       {"FLOOR", "CEIL", "ROUND", "EXP", "LOG", "LOG10", "SQRT", "ACOS", "ASIN",
+        "ATAN", "COS", "SIN", "TAN"}) {
+    EvaluateFunction(function_name, {0.5});
+  }
+}
+
+TEST(ExpressionEvaluator, FunctionAtan2) {
+  ASSERT_THROW(EvaluateFunction("ATAN2", {}), QueryRuntimeException);
+  ASSERT_EQ(EvaluateFunction("ATAN2", {TypedValue::Null, 1}).type(),
+            TypedValue::Type::Null);
+  ASSERT_EQ(EvaluateFunction("ATAN2", {1, TypedValue::Null}).type(),
+            TypedValue::Type::Null);
+  ASSERT_DOUBLE_EQ(EvaluateFunction("ATAN2", {2, -1.0}).Value<double>(),
+                   atan2(2, -1));
+  ASSERT_THROW(EvaluateFunction("ATAN2", {3.0, true}), QueryRuntimeException);
+}
+
 TEST(ExpressionEvaluator, FunctionSign) {
   ASSERT_THROW(EvaluateFunction("SIGN", {}), QueryRuntimeException);
   ASSERT_EQ(EvaluateFunction("SIGN", {TypedValue::Null}).type(),
@@ -567,4 +586,14 @@ TEST(ExpressionEvaluator, FunctionSign) {
   ASSERT_EQ(EvaluateFunction("SIGN", {2.5}).Value<int64_t>(), 1);
   ASSERT_THROW(EvaluateFunction("SIGN", {true}), QueryRuntimeException);
 }
+
+TEST(ExpressionEvaluator, FunctionE) {
+  ASSERT_THROW(EvaluateFunction("E", {1}), QueryRuntimeException);
+  ASSERT_DOUBLE_EQ(EvaluateFunction("E", {}).Value<double>(), M_E);
+}
+
+TEST(ExpressionEvaluator, FunctionPi) {
+  ASSERT_THROW(EvaluateFunction("PI", {1}), QueryRuntimeException);
+  ASSERT_DOUBLE_EQ(EvaluateFunction("PI", {}).Value<double>(), M_PI);
+}
 }