Improve Visit performance (#774)

This commit is contained in:
Antonio Filipovic 2023-02-17 13:09:25 +01:00 committed by GitHub
parent bbce21e78f
commit 862a1afdf1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 181 additions and 42 deletions

View File

@ -301,6 +301,7 @@ cpp<#
(lcp:define-class expression (tree "::utils::Visitable<HierarchicalTreeVisitor>"
"::utils::Visitable<ExpressionVisitor<TypedValue>>"
"::utils::Visitable<ExpressionVisitor<TypedValue*>>"
"::utils::Visitable<ExpressionVisitor<void>>")
()
(:abstractp t)
@ -308,6 +309,7 @@ cpp<#
#>cpp
using utils::Visitable<HierarchicalTreeVisitor>::Accept;
using utils::Visitable<ExpressionVisitor<TypedValue>>::Accept;
using utils::Visitable<ExpressionVisitor<TypedValue*>>::Accept;
using utils::Visitable<ExpressionVisitor<void>>::Accept;
Expression() = default;
@ -407,6 +409,7 @@ cpp<#
(:public
#>cpp
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
@ -438,6 +441,7 @@ cpp<#
(:public
#>cpp
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
@ -485,6 +489,7 @@ cpp<#
}
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
@ -538,6 +543,7 @@ cpp<#
ListSlicingOperator() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
@ -581,6 +587,7 @@ cpp<#
IfOperator() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
@ -628,6 +635,7 @@ cpp<#
PrimitiveLiteral() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
DEFVISITABLE(HierarchicalTreeVisitor);
cpp<#)
@ -656,6 +664,7 @@ cpp<#
ListLiteral() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
@ -688,6 +697,7 @@ cpp<#
MapLiteral() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
@ -720,6 +730,7 @@ cpp<#
Identifier() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
DEFVISITABLE(HierarchicalTreeVisitor);
@ -757,6 +768,7 @@ cpp<#
PropertyLookup() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
@ -797,6 +809,7 @@ cpp<#
LabelsTest() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
@ -837,6 +850,7 @@ cpp<#
Function() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
@ -892,6 +906,7 @@ cpp<#
Reduce() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
@ -930,6 +945,7 @@ cpp<#
Coalesce() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
@ -969,6 +985,7 @@ cpp<#
Extract() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
@ -1005,6 +1022,7 @@ cpp<#
All() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
@ -1046,6 +1064,7 @@ cpp<#
Single() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
@ -1087,6 +1106,7 @@ cpp<#
Any() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
@ -1128,6 +1148,7 @@ cpp<#
None() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
@ -1159,6 +1180,7 @@ cpp<#
ParameterLookup() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
DEFVISITABLE(HierarchicalTreeVisitor);
cpp<#)
@ -1186,6 +1208,7 @@ cpp<#
RegexMatch() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
@ -1205,6 +1228,7 @@ cpp<#
(lcp:define-class named-expression (tree "::utils::Visitable<HierarchicalTreeVisitor>"
"::utils::Visitable<ExpressionVisitor<TypedValue>>"
"::utils::Visitable<ExpressionVisitor<TypedValue*>>"
"::utils::Visitable<ExpressionVisitor<void>>")
((name "std::string" :scope :public)
(expression "Expression *" :initval "nullptr" :scope :public
@ -1223,6 +1247,7 @@ cpp<#
NamedExpression() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<TypedValue*>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -31,6 +31,71 @@
namespace memgraph::query {
class ReferenceExpressionEvaluator : public ExpressionVisitor<TypedValue *> {
public:
ReferenceExpressionEvaluator(Frame *frame, const SymbolTable *symbol_table, const EvaluationContext *ctx)
: frame_(frame), symbol_table_(symbol_table), ctx_(ctx) {}
using ExpressionVisitor<TypedValue *>::Visit;
utils::MemoryResource *GetMemoryResource() const { return ctx_->memory; }
#define UNSUCCESSFUL_VISIT(expr_name) \
TypedValue *Visit(expr_name &expr) override { return nullptr; }
TypedValue *Visit(Identifier &ident) override { return &frame_->at(symbol_table_->at(ident)); }
UNSUCCESSFUL_VISIT(NamedExpression);
UNSUCCESSFUL_VISIT(OrOperator);
UNSUCCESSFUL_VISIT(XorOperator);
UNSUCCESSFUL_VISIT(AdditionOperator);
UNSUCCESSFUL_VISIT(SubtractionOperator);
UNSUCCESSFUL_VISIT(MultiplicationOperator);
UNSUCCESSFUL_VISIT(DivisionOperator);
UNSUCCESSFUL_VISIT(ModOperator);
UNSUCCESSFUL_VISIT(NotEqualOperator);
UNSUCCESSFUL_VISIT(EqualOperator);
UNSUCCESSFUL_VISIT(LessOperator);
UNSUCCESSFUL_VISIT(GreaterOperator);
UNSUCCESSFUL_VISIT(LessEqualOperator);
UNSUCCESSFUL_VISIT(GreaterEqualOperator);
UNSUCCESSFUL_VISIT(NotOperator);
UNSUCCESSFUL_VISIT(UnaryPlusOperator);
UNSUCCESSFUL_VISIT(UnaryMinusOperator);
UNSUCCESSFUL_VISIT(AndOperator);
UNSUCCESSFUL_VISIT(IfOperator);
UNSUCCESSFUL_VISIT(InListOperator);
UNSUCCESSFUL_VISIT(SubscriptOperator);
UNSUCCESSFUL_VISIT(ListSlicingOperator);
UNSUCCESSFUL_VISIT(IsNullOperator);
UNSUCCESSFUL_VISIT(PropertyLookup);
UNSUCCESSFUL_VISIT(LabelsTest);
UNSUCCESSFUL_VISIT(PrimitiveLiteral);
UNSUCCESSFUL_VISIT(ListLiteral);
UNSUCCESSFUL_VISIT(MapLiteral);
UNSUCCESSFUL_VISIT(Aggregation);
UNSUCCESSFUL_VISIT(Coalesce);
UNSUCCESSFUL_VISIT(Function);
UNSUCCESSFUL_VISIT(Reduce);
UNSUCCESSFUL_VISIT(Extract);
UNSUCCESSFUL_VISIT(All);
UNSUCCESSFUL_VISIT(Single);
UNSUCCESSFUL_VISIT(Any);
UNSUCCESSFUL_VISIT(None);
UNSUCCESSFUL_VISIT(ParameterLookup);
UNSUCCESSFUL_VISIT(RegexMatch);
private:
Frame *frame_;
const SymbolTable *symbol_table_;
const EvaluationContext *ctx_;
};
class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
public:
ExpressionEvaluator(Frame *frame, const SymbolTable &symbol_table, const EvaluationContext &ctx, DbAccessor *dba,
@ -159,50 +224,53 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
}
TypedValue Visit(SubscriptOperator &list_indexing) override {
auto lhs = list_indexing.expression1_->Accept(*this);
ReferenceExpressionEvaluator referenceExpressionEvaluator(frame_, symbol_table_, ctx_);
TypedValue *lhs_ptr = list_indexing.expression1_->Accept(referenceExpressionEvaluator);
TypedValue lhs;
const auto referenced = nullptr != lhs_ptr;
if (!referenced) {
lhs = list_indexing.expression1_->Accept(*this);
lhs_ptr = &lhs;
}
auto index = list_indexing.expression2_->Accept(*this);
if (!lhs.IsList() && !lhs.IsMap() && !lhs.IsVertex() && !lhs.IsEdge() && !lhs.IsNull())
if (!lhs_ptr->IsList() && !lhs_ptr->IsMap() && !lhs_ptr->IsVertex() && !lhs_ptr->IsEdge() && !lhs_ptr->IsNull())
throw QueryRuntimeException(
"Expected a list, a map, a node or an edge to index with '[]', got "
"{}.",
lhs.type());
if (lhs.IsNull() || index.IsNull()) return TypedValue(ctx_->memory);
if (lhs.IsList()) {
lhs_ptr->type());
if (lhs_ptr->IsNull() || index.IsNull()) return TypedValue(ctx_->memory);
if (lhs_ptr->IsList()) {
if (!index.IsInt()) throw QueryRuntimeException("Expected an integer as a list index, got {}.", index.type());
auto index_int = index.ValueInt();
// NOTE: Take non-const reference to list, so that we can move out the
// indexed element as the result.
auto &list = lhs.ValueList();
auto &list = lhs_ptr->ValueList();
if (index_int < 0) {
index_int += static_cast<int64_t>(list.size());
}
if (index_int >= static_cast<int64_t>(list.size()) || index_int < 0) return TypedValue(ctx_->memory);
// NOTE: Explicit move is needed, so that we return the move constructed
// value and preserve the correct MemoryResource.
return std::move(list[index_int]);
return referenced ? TypedValue(list[index_int], ctx_->memory)
: TypedValue(std::move(list[index_int]), ctx_->memory);
}
if (lhs.IsMap()) {
if (lhs_ptr->IsMap()) {
if (!index.IsString()) throw QueryRuntimeException("Expected a string as a map index, got {}.", index.type());
// NOTE: Take non-const reference to map, so that we can move out the
// looked-up element as the result.
auto &map = lhs.ValueMap();
auto &map = lhs_ptr->ValueMap();
auto found = map.find(index.ValueString());
if (found == map.end()) return TypedValue(ctx_->memory);
// NOTE: Explicit move is needed, so that we return the move constructed
// value and preserve the correct MemoryResource.
return std::move(found->second);
return referenced ? TypedValue(found->second, ctx_->memory) : TypedValue(std::move(found->second), ctx_->memory);
}
if (lhs.IsVertex()) {
if (lhs_ptr->IsVertex()) {
if (!index.IsString()) throw QueryRuntimeException("Expected a string as a property name, got {}.", index.type());
return TypedValue(GetProperty(lhs.ValueVertex(), index.ValueString()), ctx_->memory);
return {GetProperty(lhs_ptr->ValueVertex(), index.ValueString()), ctx_->memory};
}
if (lhs.IsEdge()) {
if (lhs_ptr->IsEdge()) {
if (!index.IsString()) throw QueryRuntimeException("Expected a string as a property name, got {}.", index.type());
return TypedValue(GetProperty(lhs.ValueEdge(), index.ValueString()), ctx_->memory);
}
return {GetProperty(lhs_ptr->ValueEdge(), index.ValueString()), ctx_->memory};
};
// lhs is Null
return TypedValue(ctx_->memory);
@ -258,7 +326,15 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
}
TypedValue Visit(PropertyLookup &property_lookup) override {
auto expression_result = property_lookup.expression_->Accept(*this);
ReferenceExpressionEvaluator referenceExpressionEvaluator(frame_, symbol_table_, ctx_);
TypedValue *expression_result_ptr = property_lookup.expression_->Accept(referenceExpressionEvaluator);
TypedValue expression_result;
if (nullptr == expression_result_ptr) {
expression_result = property_lookup.expression_->Accept(*this);
expression_result_ptr = &expression_result;
}
auto maybe_date = [this](const auto &date, const auto &prop_name) -> std::optional<TypedValue> {
if (prop_name == "year") {
return TypedValue(date.year, ctx_->memory);
@ -332,42 +408,38 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
}
return std::nullopt;
};
switch (expression_result.type()) {
switch (expression_result_ptr->type()) {
case TypedValue::Type::Null:
return TypedValue(ctx_->memory);
case TypedValue::Type::Vertex:
return TypedValue(GetProperty(expression_result.ValueVertex(), property_lookup.property_), ctx_->memory);
return TypedValue(GetProperty(expression_result_ptr->ValueVertex(), property_lookup.property_), ctx_->memory);
case TypedValue::Type::Edge:
return TypedValue(GetProperty(expression_result.ValueEdge(), property_lookup.property_), ctx_->memory);
return TypedValue(GetProperty(expression_result_ptr->ValueEdge(), property_lookup.property_), ctx_->memory);
case TypedValue::Type::Map: {
// NOTE: Take non-const reference to map, so that we can move out the
// looked-up element as the result.
auto &map = expression_result.ValueMap();
auto &map = expression_result_ptr->ValueMap();
auto found = map.find(property_lookup.property_.name.c_str());
if (found == map.end()) return TypedValue(ctx_->memory);
// NOTE: Explicit move is needed, so that we return the move constructed
// value and preserve the correct MemoryResource.
return std::move(found->second);
return TypedValue(found->second, ctx_->memory);
}
case TypedValue::Type::Duration: {
const auto &prop_name = property_lookup.property_.name;
const auto &dur = expression_result.ValueDuration();
const auto &dur = expression_result_ptr->ValueDuration();
if (auto dur_field = maybe_duration(dur, prop_name); dur_field) {
return std::move(*dur_field);
return TypedValue(*dur_field, ctx_->memory);
}
throw QueryRuntimeException("Invalid property name {} for Duration", prop_name);
}
case TypedValue::Type::Date: {
const auto &prop_name = property_lookup.property_.name;
const auto &date = expression_result.ValueDate();
const auto &date = expression_result_ptr->ValueDate();
if (auto date_field = maybe_date(date, prop_name); date_field) {
return std::move(*date_field);
return TypedValue(*date_field, ctx_->memory);
}
throw QueryRuntimeException("Invalid property name {} for Date", prop_name);
}
case TypedValue::Type::LocalTime: {
const auto &prop_name = property_lookup.property_.name;
const auto &lt = expression_result.ValueLocalTime();
const auto &lt = expression_result_ptr->ValueLocalTime();
if (auto lt_field = maybe_local_time(lt, prop_name); lt_field) {
return std::move(*lt_field);
}
@ -375,20 +447,20 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
}
case TypedValue::Type::LocalDateTime: {
const auto &prop_name = property_lookup.property_.name;
const auto &ldt = expression_result.ValueLocalDateTime();
const auto &ldt = expression_result_ptr->ValueLocalDateTime();
if (auto date_field = maybe_date(ldt.date, prop_name); date_field) {
return std::move(*date_field);
}
if (auto lt_field = maybe_local_time(ldt.local_time, prop_name); lt_field) {
return std::move(*lt_field);
return TypedValue(*lt_field, ctx_->memory);
}
throw QueryRuntimeException("Invalid property name {} for LocalDateTime", prop_name);
}
case TypedValue::Type::Graph: {
const auto &prop_name = property_lookup.property_.name;
const auto &graph = expression_result.ValueGraph();
const auto &graph = expression_result_ptr->ValueGraph();
if (auto graph_field = maybe_graph(graph, prop_name); graph_field) {
return std::move(*graph_field);
return TypedValue(*graph_field, ctx_->memory);
}
throw QueryRuntimeException("Invalid property name {} for Graph", prop_name);
}

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -26,6 +26,7 @@
#include "query/interpret/eval.hpp"
#include "query/interpret/frame.hpp"
#include "query/path.hpp"
#include "query/typed_value.hpp"
#include "storage/v2/storage.hpp"
#include "utils/exceptions.hpp"
#include "utils/string.hpp"
@ -426,6 +427,47 @@ TEST_F(ExpressionEvaluatorTest, VertexAndEdgeIndexing) {
}
}
TEST_F(ExpressionEvaluatorTest, TypedValueListIndexing) {
auto list_vector = memgraph::utils::pmr::vector<TypedValue>(ctx.memory);
list_vector.emplace_back("string1");
list_vector.emplace_back(TypedValue("string2"));
auto *identifier = storage.Create<Identifier>("n");
auto node_symbol = symbol_table.CreateSymbol("n", true);
identifier->MapTo(node_symbol);
frame[node_symbol] = TypedValue(list_vector, ctx.memory);
{
// Legal indexing.
auto *op = storage.Create<SubscriptOperator>(identifier, storage.Create<PrimitiveLiteral>(0));
auto value = Eval(op);
EXPECT_EQ(value.ValueString(), "string1");
}
{
// Out of bounds indexing
auto *op = storage.Create<SubscriptOperator>(identifier, storage.Create<PrimitiveLiteral>(3));
auto value = Eval(op);
EXPECT_TRUE(value.IsNull());
}
{
// Out of bounds indexing with negative bound.
auto *op = storage.Create<SubscriptOperator>(identifier, storage.Create<PrimitiveLiteral>(-100));
auto value = Eval(op);
EXPECT_TRUE(value.IsNull());
}
{
// Legal indexing with negative index.
auto *op = storage.Create<SubscriptOperator>(identifier, storage.Create<PrimitiveLiteral>(-2));
auto value = Eval(op);
EXPECT_EQ(value.ValueString(), "string1");
}
{
// Indexing with incompatible type.
auto *op = storage.Create<SubscriptOperator>(identifier, storage.Create<PrimitiveLiteral>("bla"));
EXPECT_THROW(Eval(op), QueryRuntimeException);
}
}
TEST_F(ExpressionEvaluatorTest, ListSlicingOperator) {
auto *list_literal = storage.Create<ListLiteral>(
std::vector<Expression *>{storage.Create<PrimitiveLiteral>(1), storage.Create<PrimitiveLiteral>(2),