Serialize Ast classes

Summary:
Although the first solution used cereal, the final implementation uses
Boost. Since cereal is still used in the codebase, compilation has
been modified to support multithreaded cereal.

In addition to serializing Ast classes, the following also needed to be
serialized:

  * GraphDbTypes
  * Symbol
  * TypedValue

TypedValue is treated specially: its serialization code is inlined in
the Ast class that holds it, namely PrimitiveLiteral.

Another special case was the Function Ast class, which now stores a
function name which is resolved to a concrete std::function on
construction.

Tests have been added for serialized Ast in
tests/unit/cypher_main_visitor.

Reviewers: mferencevic, mislav.bradac, florijan

Reviewed By: mislav.bradac

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1067
This commit is contained in:
Teon Banek 2017-12-20 11:24:48 +01:00
parent 3ae45e0d19
commit 5f7837d613
13 changed files with 1591 additions and 43 deletions

View File

@ -128,6 +128,9 @@ if (USE_READLINE)
endif() endif()
endif() endif()
# Ask FindBoost for static Boost libraries and require the serialization
# component; configuration fails if it cannot be found (REQUIRED).
set(Boost_USE_STATIC_LIBS ON)
find_package(Boost REQUIRED COMPONENTS serialization)
set(libs_dir ${CMAKE_SOURCE_DIR}/libs) set(libs_dir ${CMAKE_SOURCE_DIR}/libs)
add_subdirectory(libs EXCLUDE_FROM_ALL) add_subdirectory(libs EXCLUDE_FROM_ALL)

4
init
View File

@ -1,10 +1,14 @@
#!/bin/bash -e #!/bin/bash -e
# TODO: Consider putting boost library in libs/setup.sh, since the license
# allows source modification and static compilation. Unfortunately, it is quite
# a pain to set up the boost build process.
required_pkgs=(git arcanist # source code control required_pkgs=(git arcanist # source code control
cmake clang-3.8 llvm-3.8 pkg-config # build system cmake clang-3.8 llvm-3.8 pkg-config # build system
curl wget # for downloading libs curl wget # for downloading libs
uuid-dev default-jre-headless # required by antlr uuid-dev default-jre-headless # required by antlr
libreadline-dev # for memgraph console libreadline-dev # for memgraph console
libboost-serialization-dev
python3 python-virtualenv python3-pip # for qa, macro_benchmark and stress tests python3 python-virtualenv python3-pip # for qa, macro_benchmark and stress tests
) )

View File

@ -167,3 +167,6 @@ import_header_library(json ${CMAKE_CURRENT_SOURCE_DIR})
# Setup cereal # Setup cereal
import_header_library(cereal "${CMAKE_CURRENT_SOURCE_DIR}/cereal/include") import_header_library(cereal "${CMAKE_CURRENT_SOURCE_DIR}/cereal/include")
# Make cereal multithreaded by passing -DCEREAL_THREAD_SAFE=1 (note that -D is omitted below).
set_property(TARGET cereal PROPERTY
INTERFACE_COMPILE_DEFINITIONS CEREAL_THREAD_SAFE=1)

View File

@ -54,7 +54,7 @@ set(memgraph_src_files
# memgraph_lib depend on these libraries # memgraph_lib depend on these libraries
set(MEMGRAPH_ALL_LIBS stdc++fs Threads::Threads fmt cppitertools cereal set(MEMGRAPH_ALL_LIBS stdc++fs Threads::Threads fmt cppitertools cereal
antlr_opencypher_parser_lib dl glog gflags) antlr_opencypher_parser_lib dl glog gflags Boost::serialization)
if (USE_LTALLOC) if (USE_LTALLOC)
list(APPEND MEMGRAPH_ALL_LIBS ltalloc) list(APPEND MEMGRAPH_ALL_LIBS ltalloc)

View File

@ -2,6 +2,9 @@
#include <string> #include <string>
#include "boost/serialization/base_object.hpp"
#include "cereal/types/base_class.hpp"
#include "utils/total_ordering.hpp" #include "utils/total_ordering.hpp"
namespace GraphDbTypes { namespace GraphDbTypes {
@ -33,18 +36,67 @@ class Common : TotalOrdering<TSpecificType> {
private: private:
StorageT storage_{0}; StorageT storage_{0};
// Lets Boost.Serialization call the private serialize() member below.
friend class boost::serialization::access;
// Boost.Serialization hook: hands the only data member, storage_, to the
// archive through operator&. The unnamed second parameter is ignored.
template <class TArchive>
void serialize(TArchive &ar, const unsigned int) {
ar & storage_;
}
}; };
class Label : public Common<Label> { class Label : public Common<Label> {
using Common::Common; using Common::Common;
// Lets Boost.Serialization call the private two-argument serialize() below.
friend class boost::serialization::access;
// Boost.Serialization hook: Label adds no state of its own, so it only
// forwards the Common<Label> base sub-object to the archive. The unnamed
// second parameter is ignored.
template <class TArchive>
void serialize(TArchive &ar, const unsigned int) {
ar & boost::serialization::base_object<Common<Label>>(*this);
}
public:
/** Required for cereal serialization. */
// One-argument overload used by cereal (the two-argument overload above is
// the Boost one); it likewise forwards only the Common<Label> base.
template <class Archive>
void serialize(Archive &archive) {
archive(cereal::base_class<Common<Label>>(this));
}
}; };
class EdgeType : public Common<EdgeType> { class EdgeType : public Common<EdgeType> {
using Common::Common; using Common::Common;
// Lets Boost.Serialization call the private two-argument serialize() below.
friend class boost::serialization::access;
// Boost.Serialization hook: EdgeType adds no state of its own, so it only
// forwards the Common<EdgeType> base sub-object to the archive. The unnamed
// second parameter is ignored.
template <class TArchive>
void serialize(TArchive &ar, const unsigned int) {
ar & boost::serialization::base_object<Common<EdgeType>>(*this);
}
public:
/** Required for cereal serialization. */
// One-argument overload used by cereal (the two-argument overload above is
// the Boost one); it likewise forwards only the Common<EdgeType> base.
template <class Archive>
void serialize(Archive &archive) {
archive(cereal::base_class<Common<EdgeType>>(this));
}
}; };
class Property : public Common<Property> { class Property : public Common<Property> {
using Common::Common; using Common::Common;
// Lets Boost.Serialization call the private two-argument serialize() below.
friend class boost::serialization::access;
// Boost.Serialization hook: Property adds no state of its own, so it only
// forwards the Common<Property> base sub-object to the archive. The unnamed
// second parameter is ignored.
template <class TArchive>
void serialize(TArchive &ar, const unsigned int) {
ar & boost::serialization::base_object<Common<Property>>(*this);
}
public:
/** Required for cereal serialization. */
// One-argument overload used by cereal (the two-argument overload above is
// the Boost one); it likewise forwards only the Common<Property> base.
template <class Archive>
void serialize(Archive &archive) {
archive(cereal::base_class<Common<Property>>(this));
}
}; };
}; // namespace GraphDbTypes }; // namespace GraphDbTypes

View File

@ -10,6 +10,14 @@ Query *AstTreeStorage::query() const {
return dynamic_cast<Query *>(storage_[0].get()); return dynamic_cast<Query *>(storage_[0].get());
} }
// Returns the largest uid() reported by any element of storage_, or -1 when
// storage_ holds no elements.
int AstTreeStorage::MaximumStorageUid() const {
  int highest_uid = -1;
  for (const auto &node : storage_) {
    const int candidate = node->uid();
    // Keep the running maximum; ties leave the current value untouched.
    if (highest_uid < candidate) {
      highest_uid = candidate;
    }
  }
  return highest_uid;
}
ReturnBody CloneReturnBody(AstTreeStorage &storage, const ReturnBody &body) { ReturnBody CloneReturnBody(AstTreeStorage &storage, const ReturnBody &body) {
ReturnBody new_body; ReturnBody new_body;
new_body.distinct = body.distinct; new_body.distinct = body.distinct;
@ -26,3 +34,133 @@ ReturnBody CloneReturnBody(AstTreeStorage &storage, const ReturnBody &body) {
} }
} // namespace query } // namespace query
// Generates a `load_construct_data` function template for DerivedClass.
// The generated function placement-constructs a DerivedClass in the memory
// that `cls` points to, passing the macro's variadic arguments to the
// constructor. The archive `ar` and the unsigned argument are not used.
// NOTE(review): presumably needed because these classes cannot be
// default-constructed by Boost.Serialization when loaded through a pointer —
// confirm against the class definitions.
#define LOAD_AND_CONSTRUCT(DerivedClass, ...)                  \
template <class TArchive>                                    \
void load_construct_data(TArchive &ar, DerivedClass *cls,    \
const unsigned int) {               \
::new (cls) DerivedClass(__VA_ARGS__);                     \
}
// All of the serialization cruft follows
//
// One LOAD_AND_CONSTRUCT expansion per AST class, placed in
// boost::serialization. The arguments after the class name are the
// placeholder constructor arguments used when the object is created on load.
// Every entry passes 0 first.
// NOTE(review): the leading 0 looks like a placeholder uid that is replaced
// when the object's own serialize() runs — confirm in ast.hpp.
namespace boost::serialization {
LOAD_AND_CONSTRUCT(query::Where, 0);
LOAD_AND_CONSTRUCT(query::OrOperator, 0);
LOAD_AND_CONSTRUCT(query::XorOperator, 0);
LOAD_AND_CONSTRUCT(query::AndOperator, 0);
LOAD_AND_CONSTRUCT(query::AdditionOperator, 0);
LOAD_AND_CONSTRUCT(query::SubtractionOperator, 0);
LOAD_AND_CONSTRUCT(query::MultiplicationOperator, 0);
LOAD_AND_CONSTRUCT(query::DivisionOperator, 0);
LOAD_AND_CONSTRUCT(query::ModOperator, 0);
LOAD_AND_CONSTRUCT(query::NotEqualOperator, 0);
LOAD_AND_CONSTRUCT(query::EqualOperator, 0);
LOAD_AND_CONSTRUCT(query::LessOperator, 0);
LOAD_AND_CONSTRUCT(query::GreaterOperator, 0);
LOAD_AND_CONSTRUCT(query::LessEqualOperator, 0);
LOAD_AND_CONSTRUCT(query::GreaterEqualOperator, 0);
LOAD_AND_CONSTRUCT(query::InListOperator, 0);
LOAD_AND_CONSTRUCT(query::ListMapIndexingOperator, 0);
LOAD_AND_CONSTRUCT(query::ListSlicingOperator, 0, nullptr, nullptr, nullptr);
LOAD_AND_CONSTRUCT(query::IfOperator, 0, nullptr, nullptr, nullptr);
LOAD_AND_CONSTRUCT(query::NotOperator, 0);
LOAD_AND_CONSTRUCT(query::UnaryPlusOperator, 0);
LOAD_AND_CONSTRUCT(query::UnaryMinusOperator, 0);
LOAD_AND_CONSTRUCT(query::IsNullOperator, 0);
LOAD_AND_CONSTRUCT(query::PrimitiveLiteral, 0);
LOAD_AND_CONSTRUCT(query::ListLiteral, 0);
LOAD_AND_CONSTRUCT(query::MapLiteral, 0);
LOAD_AND_CONSTRUCT(query::Identifier, 0, "");
LOAD_AND_CONSTRUCT(query::PropertyLookup, 0, nullptr, "",
GraphDbTypes::Property());
LOAD_AND_CONSTRUCT(query::LabelsTest, 0, nullptr,
std::vector<GraphDbTypes::Label>());
LOAD_AND_CONSTRUCT(query::Function, 0);
LOAD_AND_CONSTRUCT(query::Aggregation, 0, nullptr, nullptr,
query::Aggregation::Op::COUNT);
LOAD_AND_CONSTRUCT(query::All, 0, nullptr, nullptr, nullptr);
LOAD_AND_CONSTRUCT(query::ParameterLookup, 0);
LOAD_AND_CONSTRUCT(query::NamedExpression, 0);
LOAD_AND_CONSTRUCT(query::NodeAtom, 0);
LOAD_AND_CONSTRUCT(query::EdgeAtom, 0);
LOAD_AND_CONSTRUCT(query::Pattern, 0);
LOAD_AND_CONSTRUCT(query::SingleQuery, 0);
LOAD_AND_CONSTRUCT(query::CypherUnion, 0);
LOAD_AND_CONSTRUCT(query::Query, 0);
LOAD_AND_CONSTRUCT(query::Create, 0);
LOAD_AND_CONSTRUCT(query::Match, 0);
LOAD_AND_CONSTRUCT(query::Return, 0);
LOAD_AND_CONSTRUCT(query::With, 0);
LOAD_AND_CONSTRUCT(query::Delete, 0);
LOAD_AND_CONSTRUCT(query::SetProperty, 0);
LOAD_AND_CONSTRUCT(query::SetProperties, 0);
LOAD_AND_CONSTRUCT(query::SetLabels, 0);
LOAD_AND_CONSTRUCT(query::RemoveProperty, 0);
LOAD_AND_CONSTRUCT(query::RemoveLabels, 0);
LOAD_AND_CONSTRUCT(query::Merge, 0);
LOAD_AND_CONSTRUCT(query::Unwind, 0);
LOAD_AND_CONSTRUCT(query::CreateIndex, 0);
}  // namespace boost::serialization
// The helper macro is only meant for the list above; drop it so it does not
// leak into the rest of the translation unit.
#undef LOAD_AND_CONSTRUCT
// Include archives before registering most derived types.
// The archive headers are deliberately included here, mid-file, so that they
// precede the BOOST_CLASS_EXPORT_IMPLEMENT lines below: both the binary and
// the text archive flavours must be visible at the point of registration.
// NOTE(review): per the Boost.Serialization export documentation, the export
// macro instantiates code only for archive headers already included — keep
// this ordering when adding archives or classes.
#include "boost/archive/binary_iarchive.hpp"
#include "boost/archive/binary_oarchive.hpp"
#include "boost/archive/text_iarchive.hpp"
#include "boost/archive/text_oarchive.hpp"
// One registration per concrete AST class.
BOOST_CLASS_EXPORT_IMPLEMENT(query::Query);
BOOST_CLASS_EXPORT_IMPLEMENT(query::SingleQuery);
BOOST_CLASS_EXPORT_IMPLEMENT(query::CypherUnion);
BOOST_CLASS_EXPORT_IMPLEMENT(query::NamedExpression);
BOOST_CLASS_EXPORT_IMPLEMENT(query::OrOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::XorOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::AndOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::NotOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::AdditionOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::SubtractionOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::MultiplicationOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::DivisionOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::ModOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::NotEqualOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::EqualOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::LessOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::GreaterOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::LessEqualOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::GreaterEqualOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::InListOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::ListMapIndexingOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::ListSlicingOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::IfOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::UnaryPlusOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::UnaryMinusOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::IsNullOperator);
BOOST_CLASS_EXPORT_IMPLEMENT(query::ListLiteral);
BOOST_CLASS_EXPORT_IMPLEMENT(query::MapLiteral);
BOOST_CLASS_EXPORT_IMPLEMENT(query::PropertyLookup);
BOOST_CLASS_EXPORT_IMPLEMENT(query::LabelsTest);
BOOST_CLASS_EXPORT_IMPLEMENT(query::Aggregation);
BOOST_CLASS_EXPORT_IMPLEMENT(query::Function);
BOOST_CLASS_EXPORT_IMPLEMENT(query::All);
BOOST_CLASS_EXPORT_IMPLEMENT(query::ParameterLookup);
BOOST_CLASS_EXPORT_IMPLEMENT(query::Create);
BOOST_CLASS_EXPORT_IMPLEMENT(query::Match);
BOOST_CLASS_EXPORT_IMPLEMENT(query::Return);
BOOST_CLASS_EXPORT_IMPLEMENT(query::With);
BOOST_CLASS_EXPORT_IMPLEMENT(query::Pattern);
BOOST_CLASS_EXPORT_IMPLEMENT(query::NodeAtom);
BOOST_CLASS_EXPORT_IMPLEMENT(query::EdgeAtom);
BOOST_CLASS_EXPORT_IMPLEMENT(query::Delete);
BOOST_CLASS_EXPORT_IMPLEMENT(query::Where);
BOOST_CLASS_EXPORT_IMPLEMENT(query::SetProperty);
BOOST_CLASS_EXPORT_IMPLEMENT(query::SetProperties);
BOOST_CLASS_EXPORT_IMPLEMENT(query::SetLabels);
BOOST_CLASS_EXPORT_IMPLEMENT(query::RemoveProperty);
BOOST_CLASS_EXPORT_IMPLEMENT(query::RemoveLabels);
BOOST_CLASS_EXPORT_IMPLEMENT(query::Merge);
BOOST_CLASS_EXPORT_IMPLEMENT(query::Unwind);
BOOST_CLASS_EXPORT_IMPLEMENT(query::Identifier);
BOOST_CLASS_EXPORT_IMPLEMENT(query::PrimitiveLiteral);
BOOST_CLASS_EXPORT_IMPLEMENT(query::CreateIndex);

File diff suppressed because it is too large Load Diff

View File

@ -755,27 +755,24 @@ antlrcpp::Any CypherMainVisitor::visitExpression3a(
expression = static_cast<Expression *>(storage_.Create<InListOperator>( expression = static_cast<Expression *>(storage_.Create<InListOperator>(
expression, op->expression3b()->accept(this))); expression, op->expression3b()->accept(this)));
} else { } else {
std::function<TypedValue(const std::vector<TypedValue> &, std::string function_name;
GraphDbAccessor &)>
f;
if (op->STARTS() && op->WITH()) { if (op->STARTS() && op->WITH()) {
f = NameToFunction(kStartsWith); function_name = kStartsWith;
} else if (op->ENDS() && op->WITH()) { } else if (op->ENDS() && op->WITH()) {
f = NameToFunction(kEndsWith); function_name = kEndsWith;
} else if (op->CONTAINS()) { } else if (op->CONTAINS()) {
f = NameToFunction(kContains); function_name = kContains;
} else { } else {
throw utils::NotYetImplemented("function '{}'", op->getText()); throw utils::NotYetImplemented("function '{}'", op->getText());
} }
auto expression2 = op->expression3b()->accept(this); auto expression2 = op->expression3b()->accept(this);
std::vector<Expression *> args = {expression, expression2}; std::vector<Expression *> args = {expression, expression2};
expression = expression = static_cast<Expression *>(
static_cast<Expression *>(storage_.Create<Function>(f, args)); storage_.Create<Function>(function_name, args));
} }
} }
return expression; return expression;
} }
antlrcpp::Any CypherMainVisitor::visitStringAndNullOperators( antlrcpp::Any CypherMainVisitor::visitStringAndNullOperators(
CypherParser::StringAndNullOperatorsContext *) { CypherParser::StringAndNullOperatorsContext *) {
DLOG(FATAL) << "Should never be called. See documentation in hpp."; DLOG(FATAL) << "Should never be called. See documentation in hpp.";
@ -989,7 +986,7 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(
if (!function) if (!function)
throw SemanticException("Function '{}' doesn't exist.", function_name); throw SemanticException("Function '{}' doesn't exist.", function_name);
return static_cast<Expression *>( return static_cast<Expression *>(
storage_.Create<Function>(function, expressions)); storage_.Create<Function>(function_name, expressions));
} }
antlrcpp::Any CypherMainVisitor::visitFunctionName( antlrcpp::Any CypherMainVisitor::visitFunctionName(

View File

@ -43,6 +43,17 @@ class Symbol {
bool user_declared_ = true; bool user_declared_ = true;
Type type_ = Type::Any; Type type_ = Type::Any;
int token_position_ = -1; int token_position_ = -1;
// Lets Boost.Serialization reach the private serialize() member.
friend class boost::serialization::access;

// Boost.Serialization hook: feeds name_, position_, user_declared_, type_ and
// token_position_ to the archive, in that order. The unsigned argument is
// ignored.
template <class TArchive>
void serialize(TArchive &archive, const unsigned int /* unused */) {
  archive & name_ & position_ & user_declared_ & type_ & token_position_;
}
}; };
} // namespace query } // namespace query

View File

@ -350,7 +350,7 @@ class ExpressionEvaluator : public TreeVisitor<TypedValue> {
for (const auto &argument : function.arguments_) { for (const auto &argument : function.arguments_) {
arguments.emplace_back(argument->Accept(*this)); arguments.emplace_back(argument->Accept(*this));
} }
return function.function_(arguments, db_accessor_); return function.function()(arguments, db_accessor_);
} }
TypedValue Visit(All &all) override { TypedValue Visit(All &all) override {

View File

@ -6,8 +6,11 @@
#include <vector> #include <vector>
#include "antlr4-runtime.h" #include "antlr4-runtime.h"
#include "boost/archive/text_iarchive.hpp"
#include "boost/archive/text_oarchive.hpp"
#include "gmock/gmock.h" #include "gmock/gmock.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "query/context.hpp" #include "query/context.hpp"
#include "query/frontend/ast/ast.hpp" #include "query/frontend/ast/ast.hpp"
#include "query/frontend/ast/cypher_main_visitor.hpp" #include "query/frontend/ast/cypher_main_visitor.hpp"
@ -107,11 +110,39 @@ class CachedAstGenerator : public Base {
Query *query_; Query *query_;
}; };
// This generator serializes the parsed ast and uses the deserialized one.
class SerializedAstGenerator : public Base {
public:
SerializedAstGenerator(const std::string &query)
: Base(query),
storage_([&]() {
::frontend::opencypher::Parser parser(query);
CypherMainVisitor visitor(context_);
visitor.visit(parser.tree());
std::stringstream stream;
{
boost::archive::text_oarchive out_archive(stream);
out_archive << *visitor.query();
}
AstTreeStorage new_ast;
{
boost::archive::text_iarchive in_archive(stream);
new_ast.Load(in_archive);
}
return new_ast;
}()),
query_(storage_.query()) {}
AstTreeStorage storage_;
Query *query_;
};
template <typename T> template <typename T>
class CypherMainVisitorTest : public ::testing::Test {}; class CypherMainVisitorTest : public ::testing::Test {};
typedef ::testing::Types<AstGenerator, OriginalAfterCloningAstGenerator, typedef ::testing::Types<AstGenerator, OriginalAfterCloningAstGenerator,
ClonedAstGenerator, CachedAstGenerator> ClonedAstGenerator, CachedAstGenerator,
SerializedAstGenerator>
AstGeneratorTypes; AstGeneratorTypes;
TYPED_TEST_CASE(CypherMainVisitorTest, AstGeneratorTypes); TYPED_TEST_CASE(CypherMainVisitorTest, AstGeneratorTypes);
@ -712,7 +743,7 @@ TYPED_TEST(CypherMainVisitorTest, Function) {
auto *function = dynamic_cast<Function *>( auto *function = dynamic_cast<Function *>(
return_clause->body_.named_expressions[0]->expression_); return_clause->body_.named_expressions[0]->expression_);
ASSERT_TRUE(function); ASSERT_TRUE(function);
ASSERT_TRUE(function->function_); ASSERT_TRUE(function->function());
} }
TYPED_TEST(CypherMainVisitorTest, StringLiteralDoubleQuotes) { TYPED_TEST(CypherMainVisitorTest, StringLiteralDoubleQuotes) {

View File

@ -579,9 +579,9 @@ auto GetMerge(AstTreeStorage &storage, Pattern *pattern, OnMatch on_match,
#define AND(expr1, expr2) storage.Create<query::AndOperator>((expr1), (expr2)) #define AND(expr1, expr2) storage.Create<query::AndOperator>((expr1), (expr2))
#define OR(expr1, expr2) storage.Create<query::OrOperator>((expr1), (expr2)) #define OR(expr1, expr2) storage.Create<query::OrOperator>((expr1), (expr2))
// Function call // Function call
#define FN(function_name, ...) \ #define FN(function_name, ...) \
storage.Create<query::Function>( \ storage.Create<query::Function>( \
query::NameToFunction(utils::ToUpperCase(function_name)), \ utils::ToUpperCase(function_name), \
std::vector<query::Expression *>{__VA_ARGS__}) std::vector<query::Expression *>{__VA_ARGS__})
// List slicing // List slicing
#define SLICE(list, lower_bound, upper_bound) \ #define SLICE(list, lower_bound, upper_bound) \

View File

@ -49,8 +49,7 @@ TypedValue EvaluateFunction(const std::string &function_name,
for (const auto &arg : args) { for (const auto &arg : args) {
expressions.push_back(storage.Create<PrimitiveLiteral>(arg)); expressions.push_back(storage.Create<PrimitiveLiteral>(arg));
} }
auto *op = auto *op = storage.Create<Function>(function_name, expressions);
storage.Create<Function>(NameToFunction(function_name), expressions);
return op->Accept(eval); return op->Accept(eval);
} }