Implement extract function

Reviewers: teon.banek, buda

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1455
This commit is contained in:
Marin Tomic 2018-06-28 16:58:33 +02:00
parent c9b75cbb45
commit 86a00b00fa
14 changed files with 326 additions and 60 deletions

View File

@ -38,6 +38,7 @@ struct Expression {
all @10 :All;
single @11 :Single;
parameterLookup @12 :ParameterLookup;
extract @13 :Extract;
}
}
@ -285,6 +286,12 @@ struct Reduce {
expression @4 :Tree;
}
struct Extract {
identifier @0 :Tree;
list @1 :Tree;
expression @2 :Tree;
}
struct All {
identifier @0 :Tree;
listExpression @1 :Tree;

View File

@ -22,9 +22,7 @@ namespace query {
// safe, use a regular top-level function.
void *const AstStorage::kHelperId = (void *)CloneReturnBody;
AstStorage::AstStorage() {
storage_.emplace_back(new Query(next_uid_++));
}
AstStorage::AstStorage() { storage_.emplace_back(new Query(next_uid_++)); }
Query *AstStorage::query() const {
return dynamic_cast<Query *>(storage_[0].get());
@ -48,7 +46,7 @@ ReturnBody CloneReturnBody(AstStorage &storage, const ReturnBody &body) {
// Capnproto serialization.
Tree *AstStorage::Load(const capnp::Tree::Reader &tree,
std::vector<int> *loaded_uids) {
std::vector<int> *loaded_uids) {
auto uid = tree.getUid();
// Check if element already deserialized and if yes, return existing
@ -213,6 +211,10 @@ Expression *Expression::Construct(const capnp::Expression::Reader &reader,
auto single_reader = reader.getSingle();
return Single::Construct(single_reader, storage);
}
case capnp::Expression::EXTRACT: {
auto extract_reader = reader.getExtract();
return Extract::Construct(extract_reader, storage);
}
}
}
@ -360,8 +362,7 @@ void BinaryOperator::Save(capnp::BinaryOperator::Builder *builder,
}
void BinaryOperator::Load(const capnp::Tree::Reader &reader,
AstStorage *storage,
std::vector<int> *loaded_uids) {
AstStorage *storage, std::vector<int> *loaded_uids) {
Expression::Load(reader, storage, loaded_uids);
auto bop_reader = reader.getExpression().getBinaryOperator();
if (bop_reader.hasExpression1()) {
@ -700,8 +701,7 @@ void UnaryOperator::Save(capnp::UnaryOperator::Builder *builder,
}
}
void UnaryOperator::Load(const capnp::Tree::Reader &reader,
AstStorage *storage,
void UnaryOperator::Load(const capnp::Tree::Reader &reader, AstStorage *storage,
std::vector<int> *loaded_uids) {
Expression::Load(reader, storage, loaded_uids);
if (reader.hasExpression()) {
@ -924,8 +924,8 @@ void Function::Save(capnp::Function::Builder *builder,
}
}
void Function::Load(const capnp::Tree::Reader &base_reader,
AstStorage *storage, std::vector<int> *loaded_uids) {
void Function::Load(const capnp::Tree::Reader &base_reader, AstStorage *storage,
std::vector<int> *loaded_uids) {
Expression::Load(base_reader, storage, loaded_uids);
auto reader = base_reader.getExpression().getFunction();
function_name_ = reader.getFunctionName().cStr();
@ -1043,8 +1043,7 @@ void PropertyLookup::Save(capnp::PropertyLookup::Builder *builder,
}
void PropertyLookup::Load(const capnp::Tree::Reader &base_reader,
AstStorage *storage,
std::vector<int> *loaded_uids) {
AstStorage *storage, std::vector<int> *loaded_uids) {
Expression::Load(base_reader, storage, loaded_uids);
auto reader = base_reader.getExpression().getPropertyLookup();
if (reader.hasExpression()) {
@ -1084,8 +1083,8 @@ void Reduce::Save(capnp::Reduce::Builder *builder,
expression_->Save(&expr_builder, saved_uids);
}
void Reduce::Load(const capnp::Tree::Reader &base_reader,
AstStorage *storage, std::vector<int> *loaded_uids) {
void Reduce::Load(const capnp::Tree::Reader &base_reader, AstStorage *storage,
std::vector<int> *loaded_uids) {
Expression::Load(base_reader, storage, loaded_uids);
auto reader = base_reader.getExpression().getReduce();
const auto acc_reader = reader.getAccumulator();
@ -1109,6 +1108,43 @@ Reduce *Reduce::Construct(const capnp::Reduce::Reader &reader,
return storage->Create<Reduce>(nullptr, nullptr, nullptr, nullptr, nullptr);
}
// Extract
void Extract::Save(capnp::Expression::Builder *expr_builder,
std::vector<int> *saved_uids) {
Expression::Save(expr_builder, saved_uids);
auto builder = expr_builder->initExtract();
Save(&builder, saved_uids);
}
void Extract::Save(capnp::Extract::Builder *builder,
std::vector<int> *saved_uids) {
auto id_builder = builder->initIdentifier();
identifier_->Save(&id_builder, saved_uids);
auto list_builder = builder->initList();
list_->Save(&list_builder, saved_uids);
auto expr_builder = builder->initExpression();
expression_->Save(&expr_builder, saved_uids);
}
void Extract::Load(const capnp::Tree::Reader &base_reader, AstStorage *storage,
std::vector<int> *loaded_uids) {
Expression::Load(base_reader, storage, loaded_uids);
auto reader = base_reader.getExpression().getExtract();
const auto id_reader = reader.getIdentifier();
identifier_ =
dynamic_cast<Identifier *>(storage->Load(id_reader, loaded_uids));
const auto list_reader = reader.getList();
list_ = dynamic_cast<Expression *>(storage->Load(list_reader, loaded_uids));
const auto expr_reader = reader.getExpression();
expression_ =
dynamic_cast<Expression *>(storage->Load(expr_reader, loaded_uids));
}
Extract *Extract::Construct(const capnp::Extract::Reader &reader,
AstStorage *storage) {
return storage->Create<Extract>(nullptr, nullptr, nullptr);
}
// Single
void Single::Save(capnp::Expression::Builder *expr_builder,
std::vector<int> *saved_uids) {
@ -1127,8 +1163,8 @@ void Single::Save(capnp::Single::Builder *builder,
list_expression_->Save(&expr_builder, saved_uids);
}
void Single::Load(const capnp::Tree::Reader &base_reader,
AstStorage *storage, std::vector<int> *loaded_uids) {
void Single::Load(const capnp::Tree::Reader &base_reader, AstStorage *storage,
std::vector<int> *loaded_uids) {
Expression::Load(base_reader, storage, loaded_uids);
auto reader = base_reader.getExpression().getSingle();
const auto id_reader = reader.getIdentifier();
@ -1165,8 +1201,8 @@ void Where::Save(capnp::Where::Builder *builder, std::vector<int> *saved_uids) {
}
}
void Where::Load(const capnp::Tree::Reader &base_reader,
AstStorage *storage, std::vector<int> *loaded_uids) {
void Where::Load(const capnp::Tree::Reader &base_reader, AstStorage *storage,
std::vector<int> *loaded_uids) {
Tree::Load(base_reader, storage, loaded_uids);
auto reader = base_reader.getWhere();
if (reader.hasExpression()) {
@ -1278,8 +1314,8 @@ void Create::Save(capnp::Create::Builder *builder,
}
}
void Create::Load(const capnp::Tree::Reader &base_reader,
AstStorage *storage, std::vector<int> *loaded_uids) {
void Create::Load(const capnp::Tree::Reader &base_reader, AstStorage *storage,
std::vector<int> *loaded_uids) {
Clause::Load(base_reader, storage, loaded_uids);
auto reader = base_reader.getClause().getCreate();
for (const auto pattern_reader : reader.getPatterns()) {
@ -1339,8 +1375,8 @@ void Delete::Save(capnp::Delete::Builder *builder,
builder->setDetach(detach_);
}
void Delete::Load(const capnp::Tree::Reader &base_reader,
AstStorage *storage, std::vector<int> *loaded_uids) {
void Delete::Load(const capnp::Tree::Reader &base_reader, AstStorage *storage,
std::vector<int> *loaded_uids) {
Clause::Load(base_reader, storage, loaded_uids);
auto reader = base_reader.getClause().getDelete();
for (const auto tree_reader : reader.getExpressions()) {
@ -1378,8 +1414,8 @@ void Match::Save(capnp::Match::Builder *builder, std::vector<int> *saved_uids) {
builder->setOptional(optional_);
}
void Match::Load(const capnp::Tree::Reader &base_reader,
AstStorage *storage, std::vector<int> *loaded_uids) {
void Match::Load(const capnp::Tree::Reader &base_reader, AstStorage *storage,
std::vector<int> *loaded_uids) {
Clause::Load(base_reader, storage, loaded_uids);
auto reader = base_reader.getClause().getMatch();
for (const auto tree_reader : reader.getPatterns()) {
@ -1427,8 +1463,8 @@ void Merge::Save(capnp::Merge::Builder *builder, std::vector<int> *saved_uids) {
}
}
void Merge::Load(const capnp::Tree::Reader &base_reader,
AstStorage *storage, std::vector<int> *loaded_uids) {
void Merge::Load(const capnp::Tree::Reader &base_reader, AstStorage *storage,
std::vector<int> *loaded_uids) {
Clause::Load(base_reader, storage, loaded_uids);
auto reader = base_reader.getClause().getMerge();
for (const auto tree_reader : reader.getOnMatch()) {
@ -1473,8 +1509,7 @@ void RemoveLabels::Save(capnp::RemoveLabels::Builder *builder,
}
void RemoveLabels::Load(const capnp::Tree::Reader &base_reader,
AstStorage *storage,
std::vector<int> *loaded_uids) {
AstStorage *storage, std::vector<int> *loaded_uids) {
Clause::Load(base_reader, storage, loaded_uids);
auto reader = base_reader.getClause().getRemoveLabels();
if (reader.hasIdentifier()) {
@ -1511,8 +1546,7 @@ void RemoveProperty::Save(capnp::RemoveProperty::Builder *builder,
}
void RemoveProperty::Load(const capnp::Tree::Reader &base_reader,
AstStorage *storage,
std::vector<int> *loaded_uids) {
AstStorage *storage, std::vector<int> *loaded_uids) {
Clause::Load(base_reader, storage, loaded_uids);
auto reader = base_reader.getClause().getRemoveProperty();
if (reader.hasPropertyLookup()) {
@ -1608,8 +1642,8 @@ void LoadReturnBody(capnp::ReturnBody::Reader &rb_reader, ReturnBody &body,
}
}
void Return::Load(const capnp::Tree::Reader &base_reader,
AstStorage *storage, std::vector<int> *loaded_uids) {
void Return::Load(const capnp::Tree::Reader &base_reader, AstStorage *storage,
std::vector<int> *loaded_uids) {
Clause::Load(base_reader, storage, loaded_uids);
auto reader = base_reader.getClause().getReturn();
auto rb_reader = reader.getReturnBody();
@ -1726,8 +1760,7 @@ void SetProperties::Save(capnp::SetProperties::Builder *builder,
}
void SetProperties::Load(const capnp::Tree::Reader &base_reader,
AstStorage *storage,
std::vector<int> *loaded_uids) {
AstStorage *storage, std::vector<int> *loaded_uids) {
Clause::Load(base_reader, storage, loaded_uids);
auto reader = base_reader.getClause().getSetProperties();
if (reader.hasIdentifier()) {
@ -1764,8 +1797,8 @@ void Unwind::Save(capnp::Unwind::Builder *builder,
}
}
void Unwind::Load(const capnp::Tree::Reader &base_reader,
AstStorage *storage, std::vector<int> *loaded_uids) {
void Unwind::Load(const capnp::Tree::Reader &base_reader, AstStorage *storage,
std::vector<int> *loaded_uids) {
Clause::Load(base_reader, storage, loaded_uids);
auto reader = base_reader.getClause().getUnwind();
if (reader.hasNamedExpression()) {
@ -1809,8 +1842,7 @@ void With::Load(const capnp::Tree::Reader &base_reader, AstStorage *storage,
LoadReturnBody(rb_reader, body_, storage, loaded_uids);
}
With *With::Construct(const capnp::With::Reader &reader,
AstStorage *storage) {
With *With::Construct(const capnp::With::Reader &reader, AstStorage *storage) {
return storage->Create<With>();
}
@ -1866,8 +1898,8 @@ void DropUser::Save(capnp::DropUser::Builder *builder,
utils::SaveVector(usernames_, &usernames_builder);
}
void DropUser::Load(const capnp::Tree::Reader &base_reader,
AstStorage *storage, std::vector<int> *loaded_uids) {
void DropUser::Load(const capnp::Tree::Reader &base_reader, AstStorage *storage,
std::vector<int> *loaded_uids) {
Clause::Load(base_reader, storage, loaded_uids);
auto reader = base_reader.getClause().getDropUser();
usernames_.clear();
@ -1950,8 +1982,7 @@ void NamedExpression::Save(capnp::NamedExpression::Builder *builder,
}
void NamedExpression::Load(const capnp::Tree::Reader &base_reader,
AstStorage *storage,
std::vector<int> *loaded_uids) {
AstStorage *storage, std::vector<int> *loaded_uids) {
Tree::Load(base_reader, storage, loaded_uids);
auto reader = base_reader.getNamedExpression();
name_ = reader.getName().cStr();
@ -1994,8 +2025,8 @@ void Pattern::Save(capnp::Pattern::Builder *builder,
}
}
void Pattern::Load(const capnp::Tree::Reader &base_reader,
AstStorage *storage, std::vector<int> *loaded_uids) {
void Pattern::Load(const capnp::Tree::Reader &base_reader, AstStorage *storage,
std::vector<int> *loaded_uids) {
Tree::Load(base_reader, storage, loaded_uids);
auto reader = base_reader.getPattern();
if (reader.hasIdentifier()) {
@ -2048,8 +2079,8 @@ PatternAtom *PatternAtom::Construct(const capnp::PatternAtom::Reader &reader,
}
}
void PatternAtom::Load(const capnp::Tree::Reader &reader,
AstStorage *storage, std::vector<int> *loaded_uids) {
void PatternAtom::Load(const capnp::Tree::Reader &reader, AstStorage *storage,
std::vector<int> *loaded_uids) {
Tree::Load(reader, storage, loaded_uids);
auto pa_reader = reader.getPatternAtom();
if (pa_reader.hasIdentifier()) {
@ -2089,8 +2120,8 @@ void NodeAtom::Save(capnp::NodeAtom::Builder *builder,
}
}
void NodeAtom::Load(const capnp::Tree::Reader &base_reader,
AstStorage *storage, std::vector<int> *loaded_uids) {
void NodeAtom::Load(const capnp::Tree::Reader &base_reader, AstStorage *storage,
std::vector<int> *loaded_uids) {
PatternAtom::Load(base_reader, storage, loaded_uids);
auto reader = base_reader.getPatternAtom().getNodeAtom();
for (auto entry_reader : reader.getProperties()) {
@ -2231,8 +2262,8 @@ void LoadLambda(capnp::EdgeAtom::Lambda::Reader &reader,
}
}
void EdgeAtom::Load(const capnp::Tree::Reader &base_reader,
AstStorage *storage, std::vector<int> *loaded_uids) {
void EdgeAtom::Load(const capnp::Tree::Reader &base_reader, AstStorage *storage,
std::vector<int> *loaded_uids) {
PatternAtom::Load(base_reader, storage, loaded_uids);
auto reader = base_reader.getPatternAtom().getEdgeAtom();
switch (reader.getType()) {
@ -2418,6 +2449,7 @@ BOOST_CLASS_EXPORT_IMPLEMENT(query::LabelsTest);
BOOST_CLASS_EXPORT_IMPLEMENT(query::Aggregation);
BOOST_CLASS_EXPORT_IMPLEMENT(query::Function);
BOOST_CLASS_EXPORT_IMPLEMENT(query::Reduce);
BOOST_CLASS_EXPORT_IMPLEMENT(query::Extract);
BOOST_CLASS_EXPORT_IMPLEMENT(query::All);
BOOST_CLASS_EXPORT_IMPLEMENT(query::Single);
BOOST_CLASS_EXPORT_IMPLEMENT(query::ParameterLookup);

View File

@ -1707,6 +1707,79 @@ class Reduce : public Expression {
const unsigned int);
};
class Extract : public Expression {
friend class AstStorage;
public:
DEFVISITABLE(TreeVisitor<TypedValue>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
identifier_->Accept(visitor) && list_->Accept(visitor) &&
expression_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
Extract *Clone(AstStorage &storage) const override {
return storage.Create<Extract>(identifier_->Clone(storage),
list_->Clone(storage),
expression_->Clone(storage));
}
static Extract *Construct(const capnp::Extract::Reader &reader,
AstStorage *storage);
using Expression::Save;
// None of these should be nullptr after construction.
/// Identifier for the list element.
Identifier *identifier_ = nullptr;
/// Expression which produces a list which will be extracted.
Expression *list_ = nullptr;
/// Expression which produces the new value for list element.
Expression *expression_ = nullptr;
protected:
Extract(int uid, Identifier *identifier, Expression *list,
Expression *expression)
: Expression(uid),
identifier_(identifier),
list_(list),
expression_(expression) {}
void Save(capnp::Expression::Builder *builder,
std::vector<int> *saved_uids) override;
virtual void Save(capnp::Extract::Builder *builder,
std::vector<int> *saved_uids);
void Load(const capnp::Tree::Reader &tree_reader, AstStorage *storage,
std::vector<int> *loaded_uids) override;
private:
friend class boost::serialization::access;
BOOST_SERIALIZATION_SPLIT_MEMBER();
template <class TArchive>
void save(TArchive &ar, const unsigned int) const {
ar << boost::serialization::base_object<Expression>(*this);
SavePointer(ar, identifier_);
SavePointer(ar, list_);
SavePointer(ar, expression_);
}
template <class TArchive>
void load(TArchive &ar, const unsigned int) {
ar >> boost::serialization::base_object<Expression>(*this);
LoadPointer(ar, identifier_);
LoadPointer(ar, list_);
LoadPointer(ar, expression_);
}
template <class TArchive>
friend void boost::serialization::load_construct_data(TArchive &, Extract *,
const unsigned int);
};
// TODO: Think about representing All and Any as Reduce.
class All : public Expression {
friend class AstStorage;
@ -3641,6 +3714,7 @@ LOAD_AND_CONSTRUCT(query::Aggregation, 0, nullptr, nullptr,
query::Aggregation::Op::COUNT);
LOAD_AND_CONSTRUCT(query::Reduce, 0, nullptr, nullptr, nullptr, nullptr,
nullptr);
LOAD_AND_CONSTRUCT(query::Extract, 0, nullptr, nullptr, nullptr);
LOAD_AND_CONSTRUCT(query::All, 0, nullptr, nullptr, nullptr);
LOAD_AND_CONSTRUCT(query::Single, 0, nullptr, nullptr, nullptr);
LOAD_AND_CONSTRUCT(query::ParameterLookup, 0);
@ -3704,6 +3778,7 @@ BOOST_CLASS_EXPORT_KEY(query::LabelsTest);
BOOST_CLASS_EXPORT_KEY(query::Aggregation);
BOOST_CLASS_EXPORT_KEY(query::Function);
BOOST_CLASS_EXPORT_KEY(query::Reduce);
BOOST_CLASS_EXPORT_KEY(query::Extract);
BOOST_CLASS_EXPORT_KEY(query::All);
BOOST_CLASS_EXPORT_KEY(query::Single);
BOOST_CLASS_EXPORT_KEY(query::ParameterLookup);

View File

@ -15,6 +15,7 @@ class LabelsTest;
class Aggregation;
class Function;
class Reduce;
class Extract;
class All;
class Single;
class ParameterLookup;
@ -71,9 +72,9 @@ using TreeCompositeVisitor = ::utils::CompositeVisitor<
GreaterEqualOperator, InListOperator, ListMapIndexingOperator,
ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator,
IsNullOperator, ListLiteral, MapLiteral, PropertyLookup, LabelsTest,
Aggregation, Function, Reduce, All, Single, Create, Match, Return, With,
Pattern, NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties,
SetLabels, RemoveProperty, RemoveLabels, Merge, Unwind>;
Aggregation, Function, Reduce, Extract, All, Single, Create, Match, Return,
With, Pattern, NodeAtom, EdgeAtom, Delete, Where, SetProperty,
SetProperties, SetLabels, RemoveProperty, RemoveLabels, Merge, Unwind>;
using TreeLeafVisitor =
::utils::LeafVisitor<Identifier, PrimitiveLiteral, ParameterLookup,
@ -97,9 +98,10 @@ using TreeVisitor = ::utils::Visitor<
LessEqualOperator, GreaterEqualOperator, InListOperator,
ListMapIndexingOperator, ListSlicingOperator, IfOperator, UnaryPlusOperator,
UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral, PropertyLookup,
LabelsTest, Aggregation, Function, Reduce, All, Single, ParameterLookup,
Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, Where,
SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, Merge,
Unwind, Identifier, PrimitiveLiteral, CreateIndex, ModifyUser, DropUser>;
LabelsTest, Aggregation, Function, Reduce, Extract, All, Single,
ParameterLookup, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom,
Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty,
RemoveLabels, Merge, Unwind, Identifier, PrimitiveLiteral, CreateIndex,
ModifyUser, DropUser>;
} // namespace query

View File

@ -991,6 +991,17 @@ antlrcpp::Any CypherMainVisitor::visitAtom(CypherParser::AtomContext *ctx) {
storage_.Create<Reduce>(accumulator, initializer, ident, list, expr));
} else if (ctx->caseExpression()) {
return static_cast<Expression *>(ctx->caseExpression()->accept(this));
} else if (ctx->extractExpression()) {
auto *ident = storage_.Create<Identifier>(ctx->extractExpression()
->idInColl()
->variable()
->accept(this)
.as<std::string>());
Expression *list =
ctx->extractExpression()->idInColl()->expression()->accept(this);
Expression *expr = ctx->extractExpression()->expression()->accept(this);
return static_cast<Expression *>(
storage_.Create<Extract>(ident, list, expr));
}
// TODO: Implement this. We don't support comprehensions, filtering... at
// the moment.

View File

@ -188,7 +188,7 @@ atom : literal
| listComprehension
| patternComprehension
| ( FILTER SP? '(' SP? filterExpression SP? ')' )
| ( EXTRACT SP? '(' SP? filterExpression SP? ( SP? '|' expression )? ')' )
| ( EXTRACT SP? '(' SP? extractExpression SP? ')' )
| ( REDUCE SP? '(' SP? reduceExpression SP? ')' )
| ( ALL SP? '(' SP? filterExpression SP? ')' )
| ( ANY SP? '(' SP? filterExpression SP? ')' )
@ -231,6 +231,8 @@ filterExpression : idInColl ( SP? where )? ;
reduceExpression : accumulator=variable SP? '=' SP? initial=expression SP? ',' SP? idInColl SP? '|' SP? expression ;
extractExpression : idInColl SP? '|' SP? expression ;
idInColl : variable SP IN SP expression ;
functionInvocation : functionName SP? '(' SP? ( DISTINCT SP? )? ( expression SP? ( ',' SP? expression SP? )* )? ')' ;

View File

@ -351,6 +351,11 @@ bool SymbolGenerator::PreVisit(Reduce &reduce) {
return false;
}
bool SymbolGenerator::PreVisit(Extract &extract) {
extract.list_->Accept(*this);
VisitWithIdentifiers(*extract.expression_, {extract.identifier_});
return false;
}
// Pattern and its subparts.

View File

@ -61,6 +61,7 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
bool PreVisit(All &) override;
bool PreVisit(Single &) override;
bool PreVisit(Reduce &) override;
bool PreVisit(Extract &) override;
// Pattern and its subparts.
bool PreVisit(Pattern &) override;

View File

@ -372,6 +372,31 @@ class ExpressionEvaluator : public TreeVisitor<TypedValue> {
return accumulator;
}
TypedValue Visit(Extract &extract) override {
auto list_value = extract.list_->Accept(*this);
if (list_value.IsNull()) {
return TypedValue::Null;
}
if (list_value.type() != TypedValue::Type::List) {
throw QueryRuntimeException("'EXTRACT' expected a list, but got {}",
list_value.type());
}
const auto &list = list_value.Value<std::vector<TypedValue>>();
const auto &element_symbol =
context_->symbol_table_.at(*extract.identifier_);
std::vector<TypedValue> result;
result.reserve(list.size());
for (const auto &element : list) {
if (element.IsNull()) {
result.push_back(TypedValue::Null);
} else {
frame_[element_symbol] = element;
result.emplace_back(extract.expression_->Accept(*this));
}
}
return result;
}
TypedValue Visit(All &all) override {
auto list_value = all.list_expression_->Accept(*this);
if (list_value.IsNull()) {

View File

@ -241,6 +241,21 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
return true;
}
bool PostVisit(Extract &extract) override {
// Remove the symbol bound by extract, because we are only interested
// in free (unbound) symbols.
used_symbols_.erase(symbol_table_.at(*extract.identifier_));
DCHECK(has_aggregation_.size() >= 3U)
<< "Expected 3 has_aggregation_ flags for EXTRACT arguments";
bool has_aggr = false;
for (int i = 0; i < 3; ++i) {
has_aggr = has_aggr || has_aggregation_.back();
has_aggregation_.pop_back();
}
has_aggregation_.emplace_back(has_aggr);
return true;
}
bool Visit(Identifier &ident) override {
const auto &symbol = symbol_table_.at(ident);
if (!utils::Contains(output_symbols_, symbol)) {

View File

@ -1638,6 +1638,25 @@ TYPED_TEST(CypherMainVisitorTest, ReturnReduce) {
EXPECT_TRUE(add);
}
TYPED_TEST(CypherMainVisitorTest, ReturnExtract) {
TypeParam ast_generator("RETURN extract(x IN [1,2,3] | sum + x)");
auto *query = ast_generator.query_;
ASSERT_TRUE(query->single_query_);
auto *single_query = query->single_query_;
ASSERT_EQ(single_query->clauses_.size(), 1U);
auto *ret = dynamic_cast<Return *>(single_query->clauses_[0]);
ASSERT_TRUE(ret);
ASSERT_EQ(ret->body_.named_expressions.size(), 1U);
auto *extract =
dynamic_cast<Extract *>(ret->body_.named_expressions[0]->expression_);
ASSERT_TRUE(extract);
EXPECT_EQ(extract->identifier_->name_, "x");
auto *list_literal = dynamic_cast<ListLiteral *>(extract->list_);
EXPECT_TRUE(list_literal);
auto *add = dynamic_cast<AdditionOperator *>(extract->expression_);
EXPECT_TRUE(add);
}
TYPED_TEST(CypherMainVisitorTest, MatchBfsReturn) {
TypeParam ast_generator(
"MATCH (n) -[r:type1|type2 *bfs..10 (e, n|e.prop = 42)]-> (m) RETURN r");

View File

@ -287,8 +287,8 @@ void FillReturnBody(AstStorage &, ReturnBody &body, OrderBy order_by,
body.order_by = order_by.expressions;
body.limit = limit.expression;
}
void FillReturnBody(AstStorage &, ReturnBody &body, OrderBy order_by,
Skip skip, Limit limit = Limit{}) {
void FillReturnBody(AstStorage &, ReturnBody &body, OrderBy order_by, Skip skip,
Limit limit = Limit{}) {
body.order_by = order_by.expressions;
body.skip = skip.expression;
body.limit = limit.expression;
@ -570,6 +570,9 @@ auto GetMerge(AstStorage &storage, Pattern *pattern, OnMatch on_match,
storage.Create<query::Reduce>( \
storage.Create<query::Identifier>(accumulator), initializer, \
storage.Create<query::Identifier>(variable), list, expr)
#define EXTRACT(variable, list, expr) \
storage.Create<query::Extract>(storage.Create<query::Identifier>(variable), \
list, expr)
#define CREATE_USER(username, password) \
storage.Create<query::ModifyUser>((username), LITERAL(password), true)
#define ALTER_USER(username, password) \

View File

@ -1285,6 +1285,48 @@ TEST(ExpressionEvaluator, FunctionReduce) {
EXPECT_EQ(value.Value<int64_t>(), 3);
}
TEST(ExpressionEvaluator, FunctionExtract) {
AstStorage storage;
auto *ident_x = IDENT("x");
auto *extract =
EXTRACT("x", LIST(LITERAL(1), LITERAL(2), LITERAL(TypedValue::Null)),
ADD(ident_x, LITERAL(1)));
NoContextExpressionEvaluator eval;
const auto x_sym = eval.ctx.symbol_table_.CreateSymbol("x", true);
eval.ctx.symbol_table_[*extract->identifier_] = x_sym;
eval.ctx.symbol_table_[*ident_x] = x_sym;
auto value = extract->Accept(eval.eval);
EXPECT_EQ(value.type(), TypedValue::Type::List);
auto result = value.ValueList();
EXPECT_EQ(result[0].ValueInt(), 2);
EXPECT_EQ(result[1].ValueInt(), 3);
EXPECT_TRUE(result[2].IsNull());
}
TEST(ExpressionEvaluator, FunctionExtractNull) {
AstStorage storage;
auto *ident_x = IDENT("x");
auto *extract =
EXTRACT("x", LITERAL(TypedValue::Null), ADD(ident_x, LITERAL(1)));
NoContextExpressionEvaluator eval;
const auto x_sym = eval.ctx.symbol_table_.CreateSymbol("x", true);
eval.ctx.symbol_table_[*extract->identifier_] = x_sym;
eval.ctx.symbol_table_[*ident_x] = x_sym;
auto value = extract->Accept(eval.eval);
EXPECT_TRUE(value.IsNull());
}
TEST(ExpressionEvaluator, FunctionExtractExceptions) {
AstStorage storage;
auto *ident_x = IDENT("x");
auto *extract = EXTRACT("x", LITERAL("bla"), ADD(ident_x, LITERAL(1)));
NoContextExpressionEvaluator eval;
const auto x_sym = eval.ctx.symbol_table_.CreateSymbol("x", true);
eval.ctx.symbol_table_[*extract->identifier_] = x_sym;
eval.ctx.symbol_table_[*ident_x] = x_sym;
EXPECT_THROW(extract->Accept(eval.eval), QueryRuntimeException);
}
TEST(ExpressionEvaluator, FunctionAssert) {
// Invalid calls.
ASSERT_THROW(EvaluateFunction("ASSERT", {}), QueryRuntimeException);

View File

@ -849,6 +849,33 @@ TEST_F(TestSymbolGenerator, WithReturnReduce) {
EXPECT_NE(symbol_table.at(*reduce->accumulator_), symbol_table.at(*ret_as_y));
}
TEST_F(TestSymbolGenerator, WithReturnExtract) {
// Test WITH [1, 2, 3] AS x RETURN extract(x IN x | x + 1) AS x, x AS y
auto *with_as_x = AS("x");
auto *list_x = IDENT("x");
auto *expr_x = IDENT("x");
auto *extract = EXTRACT("x", LIST(list_x), ADD(expr_x, LITERAL(1)));
auto *ret_as_x = AS("x");
auto *ret_x = IDENT("x");
auto *ret_as_y = AS("y");
auto query = QUERY(
SINGLE_QUERY(WITH(LIST(LITERAL(1), LITERAL(2), LITERAL(3)), with_as_x),
RETURN(extract, ret_as_x, ret_x, ret_as_y)));
query->Accept(symbol_generator);
// Symbols for `WITH .. AS x`, `EXTRACT(x ...)`, `EXTRACT(...) AS x` and
// `AS y`.
EXPECT_EQ(symbol_table.max_position(), 4);
// Check `WITH .. AS x` is the same as `... IN x` and `RETURN ... x AS y`
EXPECT_EQ(symbol_table.at(*with_as_x), symbol_table.at(*list_x));
EXPECT_EQ(symbol_table.at(*with_as_x), symbol_table.at(*ret_x));
EXPECT_NE(symbol_table.at(*with_as_x),
symbol_table.at(*extract->identifier_));
EXPECT_NE(symbol_table.at(*with_as_x), symbol_table.at(*ret_as_x));
// Check `EXTRACT(x ...)` is only equal to `x + 1`
EXPECT_EQ(symbol_table.at(*extract->identifier_), symbol_table.at(*expr_x));
EXPECT_NE(symbol_table.at(*extract->identifier_), symbol_table.at(*ret_as_x));
}
TEST_F(TestSymbolGenerator, MatchBfsReturn) {
// Test MATCH (n) -[r *bfs..n.prop] (r, n | r.prop)]-> (m) RETURN r AS r
auto prop = dba.Property("prop");