Substitute named expressions in ast cache

Reviewers: buda, teon.banek

Reviewed By: buda

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D510
This commit is contained in:
Mislav Bradac 2017-06-26 15:42:13 +02:00
parent f89ef14823
commit 3119ae5343
12 changed files with 261 additions and 45 deletions

View File

@ -34,4 +34,4 @@
--recover-on-startup=false
# use ast caching
--ast-cache=false
--ast-cache=true

View File

@ -27,8 +27,8 @@ std::string ParseStringLiteral(const std::string &s) {
int j = i + 1;
const int kShortUnicodeLength = 4;
const int kLongUnicodeLength = 8;
while (j < (int)s.size() - 1 && j < i + kLongUnicodeLength + 1 &&
isxdigit(s[j])) {
while (j < static_cast<int>(s.size()) - 1 &&
j < i + kLongUnicodeLength + 1 && isxdigit(s[j])) {
++j;
}
if (j - i == kLongUnicodeLength + 1) {

View File

@ -2,6 +2,7 @@
#include <map>
#include <memory>
#include <unordered_map>
#include <vector>
#include "database/graph_db.hpp"
@ -772,17 +773,28 @@ class NamedExpression : public Tree {
}
NamedExpression *Clone(AstTreeStorage &storage) const override {
return storage.Create<NamedExpression>(name_, expression_->Clone(storage));
return storage.Create<NamedExpression>(name_, expression_->Clone(storage),
token_position_);
}
std::string name_;
Expression *expression_ = nullptr;
// This field contains token position of first token in named expression
// used to create name_. If NamedExpression object is not created from
// query or it is aliased leave this value at -1.
int token_position_ = -1;
protected:
NamedExpression(int uid) : Tree(uid) {}
NamedExpression(int uid, const std::string &name) : Tree(uid), name_(name) {}
NamedExpression(int uid, const std::string &name, Expression *expression)
: Tree(uid), name_(name), expression_(expression) {}
// Constructs a named expression and additionally records token_position, the
// position of the first token of the query text that name_ was created from.
// The other overloads do not set it, so token_position_ keeps its default -1.
NamedExpression(int uid, const std::string &name, Expression *expression,
int token_position)
: Tree(uid),
name_(name),
expression_(expression),
token_position_(token_position) {}
};
class PatternAtom : public Tree {
@ -1386,11 +1398,13 @@ class CachedAst {
public:
CachedAst(AstTreeStorage storage) : storage_(std::move(storage)) {}
/// Create new storage by plugging literals on its positions.
AstTreeStorage Plug(const Parameters &literals) {
/// Create new storage by plugging literals and named expressions into their
/// positions.
AstTreeStorage Plug(const Parameters &literals,
const std::unordered_map<int, std::string> &named_exprs) {
AstTreeStorage new_ast;
storage_.query()->Clone(new_ast);
LiteralsPlugger plugger(literals);
LiteralsPlugger plugger(literals, named_exprs);
new_ast.query()->Accept(plugger);
return new_ast;
}
@ -1403,26 +1417,47 @@ class CachedAst {
using HierarchicalTreeVisitor::Visit;
using HierarchicalTreeVisitor::PostVisit;
LiteralsPlugger(const Parameters &parameters) : parameters_(parameters) {}
// Binds the visitor to the two lookup tables it plugs values from:
//  - literals:    token position -> literal value,
//  - named_exprs: token position -> header text of a named expression.
// Both are held by reference (not copied), so they must outlive the visitor.
LiteralsPlugger(const Parameters &literals,
const std::unordered_map<int, std::string> &named_exprs)
: literals_(literals), named_exprs_(named_exprs) {}
bool Visit(PrimitiveLiteral &literal) override {
// TODO: If literal is a part of NamedExpression then we need to change
// text in NamedExpression, otherwise wrong header will be returned.
if (!literal.value_.IsNull()) {
permanent_assert(literal.token_position_ != -1,
"Use AstPlugLiteralsVisitor only on ast created by "
"parsing queries");
literal.value_ = parameters_.AtTokenPosition(literal.token_position_);
literal.value_ = literals_.AtTokenPosition(literal.token_position_);
}
return true;
}
// Replaces the name of a non-aliased named expression inside RETURN with the
// text stored for its token position. Always returns true so that traversal
// continues into the children of the named expression.
bool PreVisit(NamedExpression &named_expr) override {
// We care only about non-aliased named expressions in return. Aliased ones
// (and ones not created from a query) keep token_position_ at -1 and already
// carry the right name, so they are left untouched.
if (!in_return_ || named_expr.token_position_ == -1) return true;
// Every non-aliased named expression in return must have its text recorded
// under the same token position; a missing entry is treated as a fatal error.
permanent_assert(
named_exprs_.count(named_expr.token_position_),
"There is no named expression string for needed position");
named_expr.name_ = named_exprs_.at(named_expr.token_position_);
return true;
}
bool Visit(Identifier &) override { return true; }
private:
const Parameters &parameters_;
};
// Entering a Return clause: raise in_return_ so that PreVisit(NamedExpression&)
// renames only the named expressions found inside it.
bool PreVisit(Return &) override {
in_return_ = true;
return true;
}
// Leaving the Return clause: clear in_return_ so that named expressions
// visited afterwards are left as they are.
bool PostVisit(Return &) override {
in_return_ = false;
return true;
}
private:
const Parameters &literals_;
const std::unordered_map<int, std::string> &named_exprs_;
bool in_return_ = false;
};
AstTreeStorage storage_;
};

View File

@ -211,8 +211,9 @@ antlrcpp::Any CypherMainVisitor::visitReturnItem(
if (in_with_ && !dynamic_cast<Identifier *>(named_expr->expression_)) {
throw SemanticException("Only variables can be non aliased in with");
}
// TODO: Should we get this by text or some escaping is needed?
named_expr->name_ = std::string(ctx->getText());
named_expr->token_position_ =
ctx->expression()->getStart()->getTokenIndex();
}
return named_expr;
}

View File

@ -121,7 +121,7 @@ properties : mapLiteral
| parameter
;
relationshipTypes : ':' SP? relTypeName ( SP? '|' ':'? SP? relTypeName )* ;
relationshipTypes : ':' SP? relTypeName ( SP? '|' SP? ':'? SP? relTypeName )* ;
nodeLabels : nodeLabel ( SP? nodeLabel )* ;

View File

@ -68,19 +68,24 @@ StrippedQuery::StrippedQuery(const std::string &query) : original_(query) {
token_strings.push_back(new_value);
};
// Convert tokens to strings, perform lowercasing and filtering.
for (const auto &token : tokens) {
// For every token in original query remember token index in stripped query.
std::vector<int> position_mapping(tokens.size(), -1);
// Convert tokens to strings, perform lowercasing and filtering, store
// literals and nonaliased named expressions in return.
for (int i = 0; i < static_cast<int>(tokens.size()); ++i) {
const auto &token = tokens[i];
// Position is calculated in query after stripping and whitespace
// normalisation, not before. There will be twice as much tokens before this
// one because space tokens will be inserted between every one.
// normalisation, not before. There will be twice as much tokens before
// this one because space tokens will be inserted between every one.
int token_index = token_strings.size() * 2;
switch (token.first) {
case Token::UNMATCHED:
debug_assert(false, "Shouldn't happen");
case Token::KEYWORD: {
auto s = utils::ToLowerCase(token.second);
// We don't strip NULL, since it can appear in special expressions like
// IS NULL and IS NOT NULL, but we strip true and false keywords.
// We don't strip NULL, since it can appear in special expressions
// like IS NULL and IS NOT NULL, but we strip true and false keywords.
if (s == "true") {
replace_stripped(token_index, true, kStrippedBooleanToken);
} else if (s == "false") {
@ -109,19 +114,100 @@ StrippedQuery::StrippedQuery(const std::string &query) : original_(query) {
token_strings.push_back(token.second);
break;
}
if (token.first != Token::SPACE) {
position_mapping[i] = token_index;
}
}
query_ = utils::Join(token_strings, " ");
hash_ = fnv(query_);
// Store nonaliased named expressions in returns in named_exprs_.
auto it = std::find_if(tokens.begin(), tokens.end(),
[](const std::pair<Token, std::string> &a) {
return utils::ToLowerCase(a.second) == "return";
});
// There is no RETURN so there is nothing to do here.
if (it == tokens.end()) return;
// Skip RETURN;
++it;
// Now we need to parse cypherReturn production from opencypher grammar.
// Skip leading whitespaces and DISTINCT statement if there is one.
while (it != tokens.end() && it->first == Token::SPACE) {
++it;
}
if (it != tokens.end() && utils::ToLowerCase(it->second) == "distinct") {
++it;
}
// We assume there is only one return statement and that return statement is
// the last one. Otherwise, query is invalid and either antlr parser or
// cypher_main_visitor will report an error.
// TODO: we shouldn't rely on the fact that those checks will be done
// after this step. We should do them here.
while (it < tokens.end()) {
// Disregard leading whitespace
while (it != tokens.end() && it->first == Token::SPACE) {
++it;
}
// There is only whitespace, nothing to do...
if (it == tokens.end()) break;
bool has_as = false;
auto last_non_space = it;
auto jt = it;
// We should track number of opened braces and parentheses so that we can
// recognize if comma is a named expression separator or part of the
// list literal / function call.
int num_open_braces = 0;
int num_open_parantheses = 0;
for (; jt != tokens.end() &&
(jt->second != "," || num_open_braces || num_open_parantheses) &&
utils::ToLowerCase(jt->second) != "order" &&
utils::ToLowerCase(jt->second) != "skip" &&
utils::ToLowerCase(jt->second) != "limit";
++jt) {
if (jt->second == "(") {
++num_open_parantheses;
} else if (jt->second == ")") {
--num_open_parantheses;
} else if (jt->second == "[") {
++num_open_braces;
} else if (jt->second == "]") {
--num_open_braces;
}
has_as |= utils::ToLowerCase(jt->second) == "as";
if (jt->first != Token::SPACE) {
last_non_space = jt;
}
}
if (!has_as) {
// Named expression is not aliased. Save string disregarding leading and
// trailing whitespaces.
std::string s;
for (auto kt = it; kt != last_non_space + 1; ++kt) {
s += kt->second;
}
named_exprs_[position_mapping[it - tokens.begin()]] = s;
}
if (jt != tokens.end() && jt->second == ",") {
// There are more named expressions.
it = jt + 1;
} else {
// We hit ORDER, SKIP or LIMIT -> we are done.
break;
}
}
}
std::string StrippedQuery::GetFirstUtf8Symbol(const char *_s) const {
// According to
// https://stackoverflow.com/questions/16260033/reinterpret-cast-between-char-and-stduint8-t-safe
// this checks if casting from const char * to uint8_t is undefined behaviour.
static_assert(
std::is_same<std::uint8_t, unsigned char>::value,
"This library requires std::uint8_t to be implemented as unsigned char.");
static_assert(std::is_same<std::uint8_t, unsigned char>::value,
"This library requires std::uint8_t to be implemented as "
"unsigned char.");
const uint8_t *s = reinterpret_cast<const uint8_t *>(_s);
if ((*s >> 7) == 0x00) return std::string(_s, _s + 1);
if ((*s >> 5) == 0x06) {
@ -150,8 +236,8 @@ std::string StrippedQuery::GetFirstUtf8Symbol(const char *_s) const {
// From here until end of file there are functions that calculate matches for
// every possible token. Functions are more or less compatible with Cypher.g4
// grammar. Unfortunately, they contain a lot of special cases and shouldn't be
// changed without good reasons.
// grammar. Unfortunately, they contain a lot of special cases and shouldn't
// be changed without good reasons.
//
// Here be dragons, do not touch!
// ____ __
@ -192,14 +278,32 @@ int StrippedQuery::MatchSpecial(int start) const {
int StrippedQuery::MatchString(int start) const {
if (original_[start] != '"' && original_[start] != '\'') return 0;
char start_char = original_[start];
bool escaped = false;
for (auto *p = original_.data() + start + 1; *p; ++p) {
if (escaped) {
escaped = false;
} else if (!escaped) {
if (*p == start_char) return p - (original_.data() + start) + 1;
if (*p == '\\') {
escaped = true;
if (*p == start_char) return p - (original_.data() + start) + 1;
if (*p == '\\') {
++p;
if (*p == '\\' || *p == '\'' || *p == '"' || *p == 'B' || *p == 'b' ||
*p == 'F' || *p == 'f' || *p == 'N' || *p == 'n' || *p == 'R' ||
*p == 'r' || *p == 'T' || *p == 't') {
// Allowed escaped characters.
continue;
} else if (*p == 'U' || *p == 'u') {
int cnt = 0;
auto *r = p + 1;
while (isxdigit(*r) && cnt < 8) {
++cnt;
++r;
}
if (!*r) return 0;
if (cnt < 4) return 0;
if (cnt >= 4 && cnt < 8) {
p += 4;
}
if (cnt >= 8) {
p += 8;
}
} else {
return 0;
}
}
}
@ -209,8 +313,7 @@ int StrippedQuery::MatchString(int start) const {
int StrippedQuery::MatchDecimalInt(int start) const {
if (original_[start] == '0') return 1;
int i = start;
while (i < static_cast<int>(original_.size()) && '0' <= original_[i] &&
original_[i] <= '9') {
while (i < static_cast<int>(original_.size()) && isdigit(original_[i])) {
++i;
}
return i - start;
@ -232,10 +335,7 @@ int StrippedQuery::MatchHexadecimalInt(int start) const {
if (start + 1 >= static_cast<int>(original_.size())) return 0;
if (original_[start + 1] != 'x') return 0;
int i = start + 2;
while (i < static_cast<int>(original_.size()) &&
(('0' <= original_[i] && original_[i] <= '9') ||
('a' <= original_[i] && original_[i] <= 'f') ||
('A' <= original_[i] && original_[i] <= 'F'))) {
while (i < static_cast<int>(original_.size()) && isxdigit(original_[i])) {
++i;
}
if (i == start + 2) return 0;

View File

@ -1,5 +1,8 @@
#pragma once
#include <string>
#include <unordered_map>
#include "query/parameters.hpp"
#include "query/typed_value.hpp"
#include "utils/assert.hpp"
@ -46,6 +49,7 @@ class StrippedQuery {
const std::string &query() const { return query_; }
auto &literals() const { return literals_; }
auto &named_expressions() const { return named_exprs_; }
HashType hash() const { return hash_; }
private:
@ -72,6 +76,10 @@ class StrippedQuery {
// Token positions of stripped out literals mapped to their values.
Parameters literals_;
// Token positions of nonaliased named expressions in return statement mapped
// to their original/unstripped string.
std::unordered_map<int, std::string> named_exprs_;
// Hash based on the stripped query.
HashType hash_;
};

View File

@ -3,7 +3,7 @@
// TODO: Remove this flag. Ast caching can be disabled by setting this flag to
// false, this is useful for recreating antlr crashes in highly concurrent test.
// Once antlr bugs are fixed, or real test is written this flag can be removed.
DEFINE_bool(ast_cache, false, "Use ast caching.");
DEFINE_bool(ast_cache, true, "Use ast caching.");
DEFINE_bool(query_cost_planner, true,
"Use the cost estimator to generate plans for queries.");

View File

@ -67,7 +67,7 @@ class Interpreter {
CachedAst(std::move(visitor.storage())))
.first;
}
return it->second.Plug(stripped.literals());
return it->second.Plug(stripped.literals(), stripped.named_expressions());
}();
clock_t frontend_end_time = clock();

View File

@ -99,7 +99,7 @@ class CachedAstGenerator {
CypherMainVisitor visitor(context_);
visitor.visit(parser.tree());
CachedAst cached(std::move(visitor.storage()));
return cached.Plug(stripped.literals());
return cached.Plug(stripped.literals(), stripped.named_expressions());
}()),
query_(storage_.query()) {}

View File

@ -19,6 +19,8 @@ TEST(QueryEngine, AstCache) {
ResultStreamFaker stream;
auto dba = dbms.active();
engine.Run("RETURN 2 + 3", *dba, stream);
ASSERT_EQ(stream.GetHeader().size(), 1U);
EXPECT_EQ(stream.GetHeader()[0], "2 + 3");
ASSERT_EQ(stream.GetResults().size(), 1U);
ASSERT_EQ(stream.GetResults()[0].size(), 1U);
ASSERT_EQ(stream.GetResults()[0][0].Value<int64_t>(), 5);
@ -63,7 +65,18 @@ TEST(QueryEngine, AstCache) {
// Cached ast, same literals, different whitespaces.
ResultStreamFaker stream;
auto dba = dbms.active();
engine.Run("RETURN 10.5+1", *dba, stream);
engine.Run("RETURN 10.5 + 1", *dba, stream);
ASSERT_EQ(stream.GetResults().size(), 1U);
ASSERT_EQ(stream.GetResults()[0].size(), 1U);
ASSERT_EQ(stream.GetResults()[0][0].Value<double>(), 11.5);
}
{
// Cached ast, same literals, different named header.
ResultStreamFaker stream;
auto dba = dbms.active();
engine.Run("RETURN 10.5+1", *dba, stream);
ASSERT_EQ(stream.GetHeader().size(), 1U);
EXPECT_EQ(stream.GetHeader()[0], "10.5+1");
ASSERT_EQ(stream.GetResults().size(), 1U);
ASSERT_EQ(stream.GetResults()[0].size(), 1U);
ASSERT_EQ(stream.GetResults()[0][0].Value<double>(), 11.5);

View File

@ -3,8 +3,9 @@
// Created by Florijan Stamenkovic on 07.03.17.
//
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "query/exceptions.hpp"
#include "query/frontend/stripped.hpp"
#include "query/typed_value.hpp"
@ -12,6 +13,9 @@ using namespace query;
namespace {
using testing::Pair;
using testing::UnorderedElementsAre;
// Test helper: expects that `a` is a Bool-typed value that holds true.
// The type is checked first, so Value<bool>() is only evaluated for Bool values.
void EXPECT_PROP_TRUE(const TypedValue& a) {
EXPECT_TRUE(a.type() == TypedValue::Type::Bool && a.Value<bool>());
}
@ -114,6 +118,18 @@ TEST(QueryStripper, StringLiteral3) {
EXPECT_EQ(stripped.query(), "return " + kStrippedStringToken);
}
// A string literal containing a four-hex-digit unicode escape with mixed-case
// digits (\u1Aa4) is stripped as a single literal whose value is the decoded
// character.
TEST(QueryStripper, StringLiteral4) {
StrippedQuery stripped("RETURN '\\u1Aa4'");
EXPECT_EQ(stripped.literals().size(), 1);
EXPECT_EQ(stripped.literals().At(0).second.Value<std::string>(), u8"\u1Aa4");
EXPECT_EQ(stripped.query(), "return " + kStrippedStringToken);
}
// String literals with an invalid escape must be rejected with LexingException:
//  - \x is not one of the allowed escaped characters,
//  - \uabc has only three hex digits after \u.
TEST(QueryStripper, StringLiteralIllegalEscapedSequence) {
EXPECT_THROW(StrippedQuery("RETURN 'so\\x'"), LexingException);
EXPECT_THROW(StrippedQuery("RETURN 'so\\uabc'"), LexingException);
}
TEST(QueryStripper, TrueLiteral) {
StrippedQuery stripped("RETURN trUE");
EXPECT_EQ(stripped.literals().size(), 1);
@ -222,4 +238,47 @@ TEST(QueryStripper, OtherTokens) {
EXPECT_EQ(stripped.literals().size(), 0);
EXPECT_EQ(stripped.query(), "+ += .. .");
}
// A single non-aliased return expression is recorded with its original text.
// The key is the position of its first token in the stripped query: "return"
// is token 0, the separating space is token 1, so the expression starts at 2.
TEST(QueryStripper, NamedExpression) {
StrippedQuery stripped("RETURN 2 + 3");
EXPECT_THAT(stripped.named_expressions(),
UnorderedElementsAre(Pair(2, "2 + 3")));
}
// An expression aliased with AS must not be recorded: the expected map of
// named expressions is empty.
TEST(QueryStripper, AliasedNamedExpression) {
StrippedQuery stripped("RETURN 2 + 3 AS x");
EXPECT_THAT(stripped.named_expressions(), UnorderedElementsAre());
}
// Comma-separated return items are handled one by one: the aliased item
// ("x as s") is skipped, every other item is recorded under the stripped-query
// position of its first token (2, 18 and 22 here).
TEST(QueryStripper, MultipleNamedExpressions) {
StrippedQuery stripped("RETURN 2 + 3, x as s, x, n.x");
EXPECT_THAT(
stripped.named_expressions(),
UnorderedElementsAre(Pair(2, "2 + 3"), Pair(18, "x"), Pair(22, "n.x")));
}
// ORDER ends the list of return items: the recorded text stops before it
// (trailing whitespace dropped) and the ORDER BY expressions are not recorded.
TEST(QueryStripper, ReturnOrderBy) {
StrippedQuery stripped("RETURN 2 + 3 ORDER BY n.x, x");
EXPECT_THAT(stripped.named_expressions(),
UnorderedElementsAre(Pair(2, "2 + 3")));
}
// SKIP ends the list of return items: only the text before it is recorded.
TEST(QueryStripper, ReturnSkip) {
StrippedQuery stripped("RETURN 2 + 3 SKIP 10");
EXPECT_THAT(stripped.named_expressions(),
UnorderedElementsAre(Pair(2, "2 + 3")));
}
// LIMIT ends the list of return items: only the text before it is recorded.
TEST(QueryStripper, ReturnLimit) {
StrippedQuery stripped("RETURN 2 + 3 LIMIT 12");
EXPECT_THAT(stripped.named_expressions(),
UnorderedElementsAre(Pair(2, "2 + 3")));
}
// Commas nested inside list literals ([...]) or function-call parentheses do
// not separate return items; only top-level commas do. Each item keeps its
// original text, including inner whitespace, and is keyed by the
// stripped-query position of its first token (2, 30 and 44 here).
TEST(QueryStripper, ReturnListsAndFunctionCalls) {
StrippedQuery stripped("RETURN [1,2,[3, 4] , 5], f(1, 2), 3");
EXPECT_THAT(stripped.named_expressions(),
UnorderedElementsAre(Pair(2, "[1,2,[3, 4] , 5]"),
Pair(30, "f(1, 2)"), Pair(44, "3")));
}
}