diff --git a/config/testing.conf b/config/testing.conf index 5fef4f818..6d25f9383 100644 --- a/config/testing.conf +++ b/config/testing.conf @@ -34,4 +34,4 @@ --recover-on-startup=false # use ast caching ---ast-cache=false +--ast-cache=true diff --git a/src/query/common.cpp b/src/query/common.cpp index 2ebee4732..718d7fd18 100644 --- a/src/query/common.cpp +++ b/src/query/common.cpp @@ -27,8 +27,10 @@ std::string ParseStringLiteral(const std::string &s) { int j = i + 1; const int kShortUnicodeLength = 4; const int kLongUnicodeLength = 8; - while (j < (int)s.size() - 1 && j < i + kLongUnicodeLength + 1 && - isxdigit(s[j])) { + // isxdigit is undefined for negative char values; cast to unsigned char. + while (j < static_cast<int>(s.size()) - 1 && + j < i + kLongUnicodeLength + 1 && + isxdigit(static_cast<unsigned char>(s[j]))) { ++j; } if (j - i == kLongUnicodeLength + 1) { diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index dec4310a2..19d05f534 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -2,6 +2,7 @@ #include <map> #include <memory> +#include <unordered_map> #include <vector> #include "database/graph_db.hpp" @@ -772,17 +773,28 @@ class NamedExpression : public Tree { } NamedExpression *Clone(AstTreeStorage &storage) const override { - return storage.Create<NamedExpression>(name_, expression_->Clone(storage)); + return storage.Create<NamedExpression>(name_, expression_->Clone(storage), + token_position_); } std::string name_; Expression *expression_ = nullptr; + // This field contains token position of first token in named expression + // used to create name_. If NamedExpression object is not created from + // query or it is aliased leave this value at -1. 
+ int token_position_ = -1; protected: NamedExpression(int uid) : Tree(uid) {} NamedExpression(int uid, const std::string &name) : Tree(uid), name_(name) {} NamedExpression(int uid, const std::string &name, Expression *expression) : Tree(uid), name_(name), expression_(expression) {} + NamedExpression(int uid, const std::string &name, Expression *expression, + int token_position) + : Tree(uid), + name_(name), + expression_(expression), + token_position_(token_position) {} }; class PatternAtom : public Tree { @@ -1386,11 +1398,13 @@ class CachedAst { public: CachedAst(AstTreeStorage storage) : storage_(std::move(storage)) {} - /// Create new storage by plugging literals on its positions. - AstTreeStorage Plug(const Parameters &literals) { + /// Create new storage by plugging literals and named expressions at their + /// positions. + AstTreeStorage Plug(const Parameters &literals, + const std::unordered_map<int, std::string> &named_exprs) { AstTreeStorage new_ast; storage_.query()->Clone(new_ast); - LiteralsPlugger plugger(literals); + LiteralsPlugger plugger(literals, named_exprs); new_ast.query()->Accept(plugger); return new_ast; } @@ -1403,26 +1417,47 @@ class CachedAst { using HierarchicalTreeVisitor::Visit; using HierarchicalTreeVisitor::PostVisit; - LiteralsPlugger(const Parameters &parameters) : parameters_(parameters) {} + LiteralsPlugger(const Parameters &literals, + const std::unordered_map<int, std::string> &named_exprs) + : literals_(literals), named_exprs_(named_exprs) {} bool Visit(PrimitiveLiteral &literal) override { - // TODO: If literal is a part of NamedExpression then we need to change - // text in NamedExpression, otherwise wrong header will be returned. 
if (!literal.value_.IsNull()) { permanent_assert(literal.token_position_ != -1, "Use AstPlugLiteralsVisitor only on ast created by " "parsing queries"); - literal.value_ = parameters_.AtTokenPosition(literal.token_position_); + literal.value_ = literals_.AtTokenPosition(literal.token_position_); } return true; } + bool PreVisit(NamedExpression &named_expr) override { + // We care only about non-aliased named expressions in return. + if (!in_return_ || named_expr.token_position_ == -1) return true; + permanent_assert( + named_exprs_.count(named_expr.token_position_), + "There is no named expression string for needed position"); + named_expr.name_ = named_exprs_.at(named_expr.token_position_); + return true; + } + bool Visit(Identifier &) override { return true; } - private: - const Parameters &parameters_; - }; + bool PreVisit(Return &) override { + in_return_ = true; + return true; + } + bool PostVisit(Return &) override { + in_return_ = false; + return true; + } + + private: + const Parameters &literals_; + const std::unordered_map<int, std::string> &named_exprs_; + bool in_return_ = false; + }; AstTreeStorage storage_; }; diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index a6961ecbb..a7bb13d7d 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -211,8 +211,9 @@ antlrcpp::Any CypherMainVisitor::visitReturnItem( if (in_with_ && !dynamic_cast<Identifier *>(named_expr->expression_)) { throw SemanticException("Only variables can be non aliased in with"); } - // TODO: Should we get this by text or some escaping is needed? 
named_expr->name_ = std::string(ctx->getText()); + named_expr->token_position_ = + ctx->expression()->getStart()->getTokenIndex(); } return named_expr; } diff --git a/src/query/frontend/opencypher/grammar/Cypher.g4 b/src/query/frontend/opencypher/grammar/Cypher.g4 index 5abb9d1d3..e169f9ad1 100644 --- a/src/query/frontend/opencypher/grammar/Cypher.g4 +++ b/src/query/frontend/opencypher/grammar/Cypher.g4 @@ -121,7 +121,7 @@ properties : mapLiteral | parameter ; -relationshipTypes : ':' SP? relTypeName ( SP? '|' ':'? SP? relTypeName )* ; +relationshipTypes : ':' SP? relTypeName ( SP? '|' SP? ':'? SP? relTypeName )* ; nodeLabels : nodeLabel ( SP? nodeLabel )* ; diff --git a/src/query/frontend/stripped.cpp b/src/query/frontend/stripped.cpp index 7943a66e1..14e2fe28c 100644 --- a/src/query/frontend/stripped.cpp +++ b/src/query/frontend/stripped.cpp @@ -68,19 +68,24 @@ StrippedQuery::StrippedQuery(const std::string &query) : original_(query) { token_strings.push_back(new_value); }; - // Convert tokens to strings, perform lowercasing and filtering. - for (const auto &token : tokens) { + // For every token in original query remember token index in stripped query. + std::vector<int> position_mapping(tokens.size(), -1); + + // Convert tokens to strings, perform lowercasing and filtering, store + // literals and nonaliased named expressions in return. + for (int i = 0; i < static_cast<int>(tokens.size()); ++i) { + const auto &token = tokens[i]; // Position is calculated in query after stripping and whitespace - // normalisation, not before. There will be twice as much tokens before this - // one because space tokens will be inserted between every one. + // normalisation, not before. There will be twice as many tokens before + // this one because space tokens will be inserted between every one. 
int token_index = token_strings.size() * 2; switch (token.first) { case Token::UNMATCHED: debug_assert(false, "Shouldn't happen"); case Token::KEYWORD: { auto s = utils::ToLowerCase(token.second); - // We don't strip NULL, since it can appear in special expressions like - // IS NULL and IS NOT NULL, but we strip true and false keywords. + // We don't strip NULL, since it can appear in special expressions + // like IS NULL and IS NOT NULL, but we strip true and false keywords. if (s == "true") { replace_stripped(token_index, true, kStrippedBooleanToken); } else if (s == "false") { @@ -109,19 +114,100 @@ StrippedQuery::StrippedQuery(const std::string &query) : original_(query) { token_strings.push_back(token.second); break; } + + if (token.first != Token::SPACE) { + position_mapping[i] = token_index; + } } query_ = utils::Join(token_strings, " "); hash_ = fnv(query_); + + // Store nonaliased named expressions in returns in named_exprs_. + auto it = std::find_if(tokens.begin(), tokens.end(), + [](const std::pair<Token, std::string> &a) { + return utils::ToLowerCase(a.second) == "return"; + }); + // There is no RETURN so there is nothing to do here. + if (it == tokens.end()) return; + // Skip RETURN. + ++it; + + // Now we need to parse cypherReturn production from opencypher grammar. + // Skip leading whitespace and the DISTINCT keyword if there is one. + while (it != tokens.end() && it->first == Token::SPACE) { + ++it; + } + if (it != tokens.end() && utils::ToLowerCase(it->second) == "distinct") { + ++it; + } + + // We assume there is only one return statement and that return statement is + // the last one. Otherwise, query is invalid and either antlr parser or + // cypher_main_visitor will report an error. + // TODO: we shouldn't rely on the fact that those checks will be done + // after this step. We should do them here. 
+ while (it < tokens.end()) { + // Disregard leading whitespace + while (it != tokens.end() && it->first == Token::SPACE) { + ++it; + } + // There is only whitespace, nothing to do... + if (it == tokens.end()) break; + bool has_as = false; + auto last_non_space = it; + auto jt = it; + // Track open brackets/braces and parentheses so that we can + // recognize if comma is a named expression separator or part of a + // list or map literal / function call. + int num_open_braces = 0; + int num_open_parantheses = 0; + for (; jt != tokens.end() && + (jt->second != "," || num_open_braces || num_open_parantheses) && + utils::ToLowerCase(jt->second) != "order" && + utils::ToLowerCase(jt->second) != "skip" && + utils::ToLowerCase(jt->second) != "limit"; + ++jt) { + if (jt->second == "(") { + ++num_open_parantheses; + } else if (jt->second == ")") { + --num_open_parantheses; + } else if (jt->second == "[" || jt->second == "{") { + ++num_open_braces; + } else if (jt->second == "]" || jt->second == "}") { + --num_open_braces; + } + has_as |= utils::ToLowerCase(jt->second) == "as"; + if (jt->first != Token::SPACE) { + last_non_space = jt; + } + } + if (!has_as) { + // Named expression is not aliased. Save string disregarding leading and + // trailing whitespace. + std::string s; + for (auto kt = it; kt != last_non_space + 1; ++kt) { + s += kt->second; + } + named_exprs_[position_mapping[it - tokens.begin()]] = s; + } + if (jt != tokens.end() && jt->second == ",") { + // There are more named expressions. + it = jt + 1; + } else { + // We hit ORDER, SKIP, LIMIT or the end of the query -> we are done. + break; + } + } } std::string StrippedQuery::GetFirstUtf8Symbol(const char *_s) const { // According to // https://stackoverflow.com/questions/16260033/reinterpret-cast-between-char-and-stduint8-t-safe // this checks if casting from const char * to uint8_t is undefined behaviour. 
- static_assert( - std::is_same<std::uint8_t, unsigned char>::value, - "This library requires std::uint8_t to be implemented as unsigned char."); + static_assert(std::is_same<std::uint8_t, unsigned char>::value, + "This library requires std::uint8_t to be implemented as " + "unsigned char."); const uint8_t *s = reinterpret_cast<const uint8_t *>(_s); if ((*s >> 7) == 0x00) return std::string(_s, _s + 1); if ((*s >> 5) == 0x06) { @@ -150,8 +236,8 @@ std::string StrippedQuery::GetFirstUtf8Symbol(const char *_s) const { // From here until end of file there are functions that calculate matches for // every possible token. Functions are more or less compatible with Cypher.g4 -// grammar. Unfortunately, they contain a lof of special cases and shouldn't be -// changed without good reasons. +// grammar. Unfortunately, they contain a lot of special cases and shouldn't +// be changed without good reasons. // // Here be dragons, do not touch! // ____ __ @@ -192,14 +278,32 @@ int StrippedQuery::MatchSpecial(int start) const { int StrippedQuery::MatchString(int start) const { if (original_[start] != '"' && original_[start] != '\'') return 0; char start_char = original_[start]; - bool escaped = false; for (auto *p = original_.data() + start + 1; *p; ++p) { - if (escaped) { - escaped = false; - } else if (!escaped) { - if (*p == start_char) return p - (original_.data() + start) + 1; - if (*p == '\\') { - escaped = true; + if (*p == start_char) return p - (original_.data() + start) + 1; + if (*p == '\\') { + ++p; + if (*p == '\\' || *p == '\'' || *p == '"' || *p == 'B' || *p == 'b' || + *p == 'F' || *p == 'f' || *p == 'N' || *p == 'n' || *p == 'R' || + *p == 'r' || *p == 'T' || *p == 't') { + // Allowed escaped characters. 
+ continue; + } else if (*p == 'U' || *p == 'u') { + int cnt = 0; + auto *r = p + 1; + while (isxdigit(static_cast<unsigned char>(*r)) && cnt < 8) { + ++cnt; + ++r; + } + if (!*r) return 0; + if (cnt < 4) return 0; + if (cnt >= 4 && cnt < 8) { + p += 4; + } + if (cnt >= 8) { + p += 8; + } + } else { + return 0; } } } @@ -209,8 +313,9 @@ int StrippedQuery::MatchString(int start) const { int StrippedQuery::MatchDecimalInt(int start) const { if (original_[start] == '0') return 1; int i = start; - while (i < static_cast<int>(original_.size()) && '0' <= original_[i] && - original_[i] <= '9') { + // <cctype> classifiers are undefined for negative char values; cast first. + while (i < static_cast<int>(original_.size()) && + isdigit(static_cast<unsigned char>(original_[i]))) { ++i; } return i - start; @@ -232,10 +337,9 @@ int StrippedQuery::MatchHexadecimalInt(int start) const { if (start + 1 >= static_cast<int>(original_.size())) return 0; if (original_[start + 1] != 'x') return 0; int i = start + 2; - while (i < static_cast<int>(original_.size()) && - (('0' <= original_[i] && original_[i] <= '9') || - ('a' <= original_[i] && original_[i] <= 'f') || - ('A' <= original_[i] && original_[i] <= 'F'))) { + // <cctype> classifiers are undefined for negative char values; cast first. + while (i < static_cast<int>(original_.size()) && + isxdigit(static_cast<unsigned char>(original_[i]))) { ++i; } if (i == start + 2) return 0; diff --git a/src/query/frontend/stripped.hpp b/src/query/frontend/stripped.hpp index 1abbe0542..9e1a3d397 100644 --- a/src/query/frontend/stripped.hpp +++ b/src/query/frontend/stripped.hpp @@ -1,5 +1,8 @@ #pragma once +#include <string> +#include <unordered_map> + #include "query/parameters.hpp" #include "query/typed_value.hpp" #include "utils/assert.hpp" @@ -46,6 +49,7 @@ class StrippedQuery { const std::string &query() const { return query_; } auto &literals() const { return literals_; } + auto &named_expressions() const { return named_exprs_; } HashType hash() const { return hash_; } private: @@ -72,6 +76,10 @@ class StrippedQuery { // Token positions of stripped out literals mapped to their values. 
Parameters literals_; + // Token positions of nonaliased named expressions in return statement mapped + // to their original/unstripped string. + std::unordered_map<int, std::string> named_exprs_; + // Hash based on the stripped query. HashType hash_; }; diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 1b16ac9a9..5be604d14 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -3,7 +3,7 @@ // TODO: Remove this flag. Ast caching can be disabled by setting this flag to // false, this is useful for recerating antlr crashes in highly concurrent test. // Once antlr bugs are fixed, or real test is written this flag can be removed. -DEFINE_bool(ast_cache, false, "Use ast caching."); +DEFINE_bool(ast_cache, true, "Use ast caching."); DEFINE_bool(query_cost_planner, true, "Use the cost estimator to generate plans for queries."); diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index 1de96cd46..9aef7e9ce 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -67,7 +67,7 @@ class Interpreter { CachedAst(std::move(visitor.storage()))) .first; } - return it->second.Plug(stripped.literals()); + return it->second.Plug(stripped.literals(), stripped.named_expressions()); }(); clock_t frontend_end_time = clock(); diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index fbed0196a..1a6d29287 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -99,7 +99,7 @@ class CachedAstGenerator { CypherMainVisitor visitor(context_); visitor.visit(parser.tree()); CachedAst cached(std::move(visitor.storage())); - return cached.Plug(stripped.literals()); + return cached.Plug(stripped.literals(), stripped.named_expressions()); }()), query_(storage_.query()) {} diff --git a/tests/unit/query_engine.cpp b/tests/unit/query_engine.cpp index b6372fb94..8930359f8 100644 --- a/tests/unit/query_engine.cpp +++ b/tests/unit/query_engine.cpp @@ -19,6 +19,8 @@ 
TEST(QueryEngine, AstCache) { ResultStreamFaker stream; auto dba = dbms.active(); engine.Run("RETURN 2 + 3", *dba, stream); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "2 + 3"); ASSERT_EQ(stream.GetResults().size(), 1U); ASSERT_EQ(stream.GetResults()[0].size(), 1U); ASSERT_EQ(stream.GetResults()[0][0].Value<int64_t>(), 5); @@ -63,7 +65,18 @@ TEST(QueryEngine, AstCache) { // Cached ast, same literals, different whitespaces. ResultStreamFaker stream; auto dba = dbms.active(); - engine.Run("RETURN 10.5+1", *dba, stream); + engine.Run("RETURN 10.5 + 1", *dba, stream); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].Value<double>(), 11.5); + } + { + // Cached ast, same literals, different named header. + ResultStreamFaker stream; + auto dba = dbms.active(); + engine.Run("RETURN 10.5+1", *dba, stream); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "10.5+1"); ASSERT_EQ(stream.GetResults().size(), 1U); ASSERT_EQ(stream.GetResults()[0].size(), 1U); ASSERT_EQ(stream.GetResults()[0][0].Value<double>(), 11.5); diff --git a/tests/unit/stripped.cpp b/tests/unit/stripped.cpp index 2124fd14b..b13e76de3 100644 --- a/tests/unit/stripped.cpp +++ b/tests/unit/stripped.cpp @@ -3,8 +3,9 @@ // Created by Florijan Stamenkovic on 07.03.17. 
// +#include "gmock/gmock.h" #include "gtest/gtest.h" - +#include "query/exceptions.hpp" #include "query/frontend/stripped.hpp" #include "query/typed_value.hpp" @@ -12,6 +13,9 @@ using namespace query; namespace { +using testing::Pair; +using testing::UnorderedElementsAre; + void EXPECT_PROP_TRUE(const TypedValue& a) { EXPECT_TRUE(a.type() == TypedValue::Type::Bool && a.Value<bool>()); } @@ -114,6 +118,18 @@ TEST(QueryStripper, StringLiteral3) { EXPECT_EQ(stripped.query(), "return " + kStrippedStringToken); } +TEST(QueryStripper, StringLiteral4) { + StrippedQuery stripped("RETURN '\\u1Aa4'"); + EXPECT_EQ(stripped.literals().size(), 1); + EXPECT_EQ(stripped.literals().At(0).second.Value<std::string>(), u8"\u1Aa4"); + EXPECT_EQ(stripped.query(), "return " + kStrippedStringToken); +} + +TEST(QueryStripper, StringLiteralIllegalEscapedSequence) { + EXPECT_THROW(StrippedQuery("RETURN 'so\\x'"), LexingException); + EXPECT_THROW(StrippedQuery("RETURN 'so\\uabc'"), LexingException); +} + TEST(QueryStripper, TrueLiteral) { StrippedQuery stripped("RETURN trUE"); EXPECT_EQ(stripped.literals().size(), 1); @@ -222,4 +238,47 @@ TEST(QueryStripper, OtherTokens) { EXPECT_EQ(stripped.literals().size(), 0); EXPECT_EQ(stripped.query(), "+ += .. 
."); } + +TEST(QueryStripper, NamedExpression) { + StrippedQuery stripped("RETURN 2 + 3"); + EXPECT_THAT(stripped.named_expressions(), + UnorderedElementsAre(Pair(2, "2 + 3"))); +} + +TEST(QueryStripper, AliasedNamedExpression) { + StrippedQuery stripped("RETURN 2 + 3 AS x"); + EXPECT_THAT(stripped.named_expressions(), UnorderedElementsAre()); +} + +TEST(QueryStripper, MultipleNamedExpressions) { + StrippedQuery stripped("RETURN 2 + 3, x as s, x, n.x"); + EXPECT_THAT( + stripped.named_expressions(), + UnorderedElementsAre(Pair(2, "2 + 3"), Pair(18, "x"), Pair(22, "n.x"))); +} + +TEST(QueryStripper, ReturnOrderBy) { + StrippedQuery stripped("RETURN 2 + 3 ORDER BY n.x, x"); + EXPECT_THAT(stripped.named_expressions(), + UnorderedElementsAre(Pair(2, "2 + 3"))); +} + +TEST(QueryStripper, ReturnSkip) { + StrippedQuery stripped("RETURN 2 + 3 SKIP 10"); + EXPECT_THAT(stripped.named_expressions(), + UnorderedElementsAre(Pair(2, "2 + 3"))); +} + +TEST(QueryStripper, ReturnLimit) { + StrippedQuery stripped("RETURN 2 + 3 LIMIT 12"); + EXPECT_THAT(stripped.named_expressions(), + UnorderedElementsAre(Pair(2, "2 + 3"))); +} + +TEST(QueryStripper, ReturnListsAndFunctionCalls) { + StrippedQuery stripped("RETURN [1,2,[3, 4] , 5], f(1, 2), 3"); + EXPECT_THAT(stripped.named_expressions(), + UnorderedElementsAre(Pair(2, "[1,2,[3, 4] , 5]"), + Pair(30, "f(1, 2)"), Pair(44, "3"))); +} }