Fix bug in SymbolTable
Summary: This is a bugfix for D1836. That change made `SymbolTable` return references to vector elements, which are invalidated when the vector grows, causing undefined behavior. This caused a `DCHECK` in `rule_based_planner.hpp` to trigger, which was noticed by @ipaljak 2 months later. All `DCHECK`s in `rule_based_planner.hpp` are now changed to `CHECK`s. Also, the hash function for `Symbol` was wrong, because it took the `user_declared` field into consideration while the `==` operator does not. Reviewers: ipaljak, teon.banek, mferencevic, msantl Reviewed By: msantl Subscribers: pullbot, ipaljak Differential Revision: https://phabricator.memgraph.io/D1938
This commit is contained in:
parent
3436094df6
commit
dc231fe4e7
@ -66,7 +66,6 @@ struct hash<query::Symbol> {
|
||||
size_t prime = 265443599u;
|
||||
size_t hash = std::hash<int>{}(symbol.position());
|
||||
hash ^= prime * std::hash<std::string>{}(symbol.name());
|
||||
hash ^= prime * std::hash<bool>{}(symbol.user_declared());
|
||||
hash ^= prime * std::hash<int>{}(static_cast<int>(symbol.type()));
|
||||
return hash;
|
||||
}
|
||||
|
@ -16,9 +16,11 @@ class SymbolTable final {
|
||||
int32_t token_position = -1) {
|
||||
CHECK(table_.size() <= std::numeric_limits<int32_t>::max())
|
||||
<< "SymbolTable size doesn't fit into 32-bit integer!";
|
||||
int32_t position = static_cast<int32_t>(table_.size());
|
||||
table_.emplace_back(name, position, user_declared, type, token_position);
|
||||
return table_.back();
|
||||
auto got = table_.emplace(position_, Symbol(name, position_, user_declared,
|
||||
type, token_position));
|
||||
CHECK(got.second) << "Duplicate symbol ID!";
|
||||
position_++;
|
||||
return got.first->second;
|
||||
}
|
||||
|
||||
const Symbol &at(const Identifier &ident) const {
|
||||
@ -36,7 +38,8 @@ class SymbolTable final {
|
||||
|
||||
const auto &table() const { return table_; }
|
||||
|
||||
std::vector<Symbol> table_;
|
||||
int32_t position_{0};
|
||||
std::map<int32_t, Symbol> table_;
|
||||
};
|
||||
|
||||
} // namespace query
|
||||
|
@ -77,7 +77,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
if (where) {
|
||||
where->Accept(*this);
|
||||
}
|
||||
DCHECK(aggregations_.empty())
|
||||
CHECK(aggregations_.empty())
|
||||
<< "Unexpected aggregations in ORDER BY or WHERE";
|
||||
}
|
||||
}
|
||||
@ -121,7 +121,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
|
||||
public:
|
||||
bool PostVisit(ListLiteral &list_literal) override {
|
||||
DCHECK(list_literal.elements_.size() <= has_aggregation_.size())
|
||||
CHECK(list_literal.elements_.size() <= has_aggregation_.size())
|
||||
<< "Expected as many has_aggregation_ flags as there are list"
|
||||
"elements.";
|
||||
PostVisitCollectionLiteral(list_literal, [](auto it) { return *it; });
|
||||
@ -129,7 +129,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
}
|
||||
|
||||
bool PostVisit(MapLiteral &map_literal) override {
|
||||
DCHECK(map_literal.elements_.size() <= has_aggregation_.size())
|
||||
CHECK(map_literal.elements_.size() <= has_aggregation_.size())
|
||||
<< "Expected has_aggregation_ flags as much as there are map elements.";
|
||||
PostVisitCollectionLiteral(map_literal, [](auto it) { return it->second; });
|
||||
return true;
|
||||
@ -139,7 +139,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
// Remove the symbol which is bound by all, because we are only interested
|
||||
// in free (unbound) symbols.
|
||||
used_symbols_.erase(symbol_table_.at(*all.identifier_));
|
||||
DCHECK(has_aggregation_.size() >= 3U)
|
||||
CHECK(has_aggregation_.size() >= 3U)
|
||||
<< "Expected 3 has_aggregation_ flags for ALL arguments";
|
||||
bool has_aggr = false;
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
@ -154,7 +154,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
// Remove the symbol which is bound by single, because we are only
|
||||
// interested in free (unbound) symbols.
|
||||
used_symbols_.erase(symbol_table_.at(*single.identifier_));
|
||||
DCHECK(has_aggregation_.size() >= 3U)
|
||||
CHECK(has_aggregation_.size() >= 3U)
|
||||
<< "Expected 3 has_aggregation_ flags for SINGLE arguments";
|
||||
bool has_aggr = false;
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
@ -170,7 +170,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
// in free (unbound) symbols.
|
||||
used_symbols_.erase(symbol_table_.at(*reduce.accumulator_));
|
||||
used_symbols_.erase(symbol_table_.at(*reduce.identifier_));
|
||||
DCHECK(has_aggregation_.size() >= 5U)
|
||||
CHECK(has_aggregation_.size() >= 5U)
|
||||
<< "Expected 5 has_aggregation_ flags for REDUCE arguments";
|
||||
bool has_aggr = false;
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
@ -198,7 +198,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
// Remove the symbol bound by extract, because we are only interested
|
||||
// in free (unbound) symbols.
|
||||
used_symbols_.erase(symbol_table_.at(*extract.identifier_));
|
||||
DCHECK(has_aggregation_.size() >= 3U)
|
||||
CHECK(has_aggregation_.size() >= 3U)
|
||||
<< "Expected 3 has_aggregation_ flags for EXTRACT arguments";
|
||||
bool has_aggr = false;
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
@ -257,12 +257,12 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
has_aggregation_.emplace_back(has_aggr);
|
||||
// TODO: Once we allow aggregations here, insert appropriate stuff in
|
||||
// group_by.
|
||||
DCHECK(!has_aggr) << "Currently aggregations in CASE are not allowed";
|
||||
CHECK(!has_aggr) << "Currently aggregations in CASE are not allowed";
|
||||
return false;
|
||||
}
|
||||
|
||||
bool PostVisit(Function &function) override {
|
||||
DCHECK(function.arguments_.size() <= has_aggregation_.size())
|
||||
CHECK(function.arguments_.size() <= has_aggregation_.size())
|
||||
<< "Expected as many has_aggregation_ flags as there are"
|
||||
"function arguments.";
|
||||
bool has_aggr = false;
|
||||
@ -278,7 +278,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
|
||||
#define VISIT_BINARY_OPERATOR(BinaryOperator) \
|
||||
bool PostVisit(BinaryOperator &op) override { \
|
||||
DCHECK(has_aggregation_.size() >= 2U) \
|
||||
CHECK(has_aggregation_.size() >= 2U) \
|
||||
<< "Expected at least 2 has_aggregation_ flags."; \
|
||||
/* has_aggregation_ stack is reversed, last result is from the 2nd */ \
|
||||
/* expression. */ \
|
||||
@ -336,7 +336,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
}
|
||||
|
||||
bool PostVisit(NamedExpression &named_expr) override {
|
||||
DCHECK(has_aggregation_.size() == 1U)
|
||||
CHECK(has_aggregation_.size() == 1U)
|
||||
<< "Expected to reduce has_aggregation_ to single boolean.";
|
||||
if (!has_aggregation_.back()) {
|
||||
group_by_.emplace_back(named_expr.expression_);
|
||||
@ -354,9 +354,9 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
// This should be used when body.all_identifiers is true, to generate
|
||||
// expressions for Produce operator.
|
||||
void ExpandUserSymbols() {
|
||||
DCHECK(named_expressions_.empty())
|
||||
CHECK(named_expressions_.empty())
|
||||
<< "ExpandUserSymbols should be first to fill named_expressions_";
|
||||
DCHECK(output_symbols_.empty())
|
||||
CHECK(output_symbols_.empty())
|
||||
<< "ExpandUserSymbols should be first to fill output_symbols_";
|
||||
for (const auto &symbol : bound_symbols_) {
|
||||
if (!symbol.user_declared()) {
|
||||
|
@ -94,20 +94,20 @@ template <typename T>
|
||||
auto ReducePattern(
|
||||
Pattern &pattern, std::function<T(NodeAtom *)> base,
|
||||
std::function<T(T, NodeAtom *, EdgeAtom *, NodeAtom *)> collect) {
|
||||
DCHECK(!pattern.atoms_.empty()) << "Missing atoms in pattern";
|
||||
CHECK(!pattern.atoms_.empty()) << "Missing atoms in pattern";
|
||||
auto atoms_it = pattern.atoms_.begin();
|
||||
auto current_node = utils::Downcast<NodeAtom>(*atoms_it++);
|
||||
DCHECK(current_node) << "First pattern atom is not a node";
|
||||
CHECK(current_node) << "First pattern atom is not a node";
|
||||
auto last_res = base(current_node);
|
||||
// Remaining atoms need to follow sequentially as (EdgeAtom, NodeAtom)*
|
||||
while (atoms_it != pattern.atoms_.end()) {
|
||||
auto edge = utils::Downcast<EdgeAtom>(*atoms_it++);
|
||||
DCHECK(edge) << "Expected an edge atom in pattern.";
|
||||
DCHECK(atoms_it != pattern.atoms_.end())
|
||||
CHECK(edge) << "Expected an edge atom in pattern.";
|
||||
CHECK(atoms_it != pattern.atoms_.end())
|
||||
<< "Edge atom should not end the pattern.";
|
||||
auto prev_node = current_node;
|
||||
current_node = utils::Downcast<NodeAtom>(*atoms_it++);
|
||||
DCHECK(current_node) << "Expected a node atom in pattern.";
|
||||
CHECK(current_node) << "Expected a node atom in pattern.";
|
||||
last_res = collect(std::move(last_res), prev_node, edge, current_node);
|
||||
}
|
||||
return last_res;
|
||||
@ -179,7 +179,7 @@ class RuleBasedPlanner {
|
||||
}
|
||||
int merge_id = 0;
|
||||
for (auto *clause : query_part.remaining_clauses) {
|
||||
DCHECK(!utils::IsSubtype(*clause, Match::kType))
|
||||
CHECK(!utils::IsSubtype(*clause, Match::kType))
|
||||
<< "Unexpected Match in remaining clauses";
|
||||
if (auto *ret = utils::Downcast<Return>(clause)) {
|
||||
input_op = impl::GenReturn(
|
||||
@ -407,7 +407,7 @@ class RuleBasedPlanner {
|
||||
symbol_table.at(*expansion.node2->identifier_);
|
||||
auto existing_node = utils::Contains(bound_symbols, node_symbol);
|
||||
const auto &edge_symbol = symbol_table.at(*edge->identifier_);
|
||||
DCHECK(!utils::Contains(bound_symbols, edge_symbol))
|
||||
CHECK(!utils::Contains(bound_symbols, edge_symbol))
|
||||
<< "Existing edges are not supported";
|
||||
std::vector<storage::EdgeType> edge_types;
|
||||
edge_types.reserve(edge->edge_types_.size());
|
||||
@ -439,7 +439,7 @@ class RuleBasedPlanner {
|
||||
bound_symbols.insert(filter_lambda.inner_edge_symbol).second;
|
||||
bool inner_node_bound =
|
||||
bound_symbols.insert(filter_lambda.inner_node_symbol).second;
|
||||
DCHECK(inner_edge_bound && inner_node_bound)
|
||||
CHECK(inner_edge_bound && inner_node_bound)
|
||||
<< "An inner edge and node can't be bound from before";
|
||||
}
|
||||
// Join regular filters with lambda filter expression, so that they
|
||||
@ -518,14 +518,14 @@ class RuleBasedPlanner {
|
||||
storage);
|
||||
}
|
||||
}
|
||||
DCHECK(named_paths.empty()) << "Expected to generate all named paths";
|
||||
CHECK(named_paths.empty()) << "Expected to generate all named paths";
|
||||
// We bound all named path symbols, so just add them to new_symbols.
|
||||
for (const auto &named_path : matching.named_paths) {
|
||||
DCHECK(utils::Contains(bound_symbols, named_path.first))
|
||||
CHECK(utils::Contains(bound_symbols, named_path.first))
|
||||
<< "Expected generated named path to have bound symbol";
|
||||
match_context.new_symbols.emplace_back(named_path.first);
|
||||
}
|
||||
DCHECK(filters.empty()) << "Expected to generate all filters";
|
||||
CHECK(filters.empty()) << "Expected to generate all filters";
|
||||
return last_op;
|
||||
}
|
||||
|
||||
@ -544,12 +544,12 @@ class RuleBasedPlanner {
|
||||
for (auto &set : merge.on_create_) {
|
||||
on_create = HandleWriteClause(set, on_create, *context_->symbol_table,
|
||||
context_->bound_symbols);
|
||||
DCHECK(on_create) << "Expected SET in MERGE ... ON CREATE";
|
||||
CHECK(on_create) << "Expected SET in MERGE ... ON CREATE";
|
||||
}
|
||||
for (auto &set : merge.on_match_) {
|
||||
on_match = HandleWriteClause(set, on_match, *context_->symbol_table,
|
||||
context_->bound_symbols);
|
||||
DCHECK(on_match) << "Expected SET in MERGE ... ON MATCH";
|
||||
CHECK(on_match) << "Expected SET in MERGE ... ON MATCH";
|
||||
}
|
||||
return std::make_unique<plan::Merge>(
|
||||
std::move(input_op), std::move(on_match), std::move(on_create));
|
||||
|
@ -43,5 +43,5 @@ struct TypedValue {
|
||||
}
|
||||
|
||||
struct SymbolTable {
|
||||
table @0 :List(Sem.Symbol);
|
||||
table @0 :Utils.Map(Utils.BoxInt32, Sem.Symbol);
|
||||
}
|
||||
|
@ -6,8 +6,8 @@
|
||||
#include "query/frontend/semantic/symbol_table.hpp"
|
||||
#include "query/serialization.capnp.h"
|
||||
#include "query/typed_value.hpp"
|
||||
#include "storage/distributed/rpc/serialization.hpp"
|
||||
#include "rpc/serialization.hpp"
|
||||
#include "storage/distributed/rpc/serialization.hpp"
|
||||
|
||||
namespace distributed {
|
||||
class DataManager;
|
||||
@ -32,19 +32,27 @@ void Load(TypedValueVectorCompare *comparator,
|
||||
|
||||
inline void Save(const SymbolTable &symbol_table,
|
||||
capnp::SymbolTable::Builder *builder) {
|
||||
auto list_builder = builder->initTable(symbol_table.table().size());
|
||||
utils::SaveVector<capnp::Symbol, Symbol>(
|
||||
symbol_table.table(), &list_builder,
|
||||
[](auto *builder, const auto &symbol) { Save(symbol, builder); });
|
||||
auto table_builder = builder->initTable();
|
||||
utils::SaveMap<utils::capnp::BoxInt32, capnp::Symbol,
|
||||
std::map<int32_t, Symbol>>(
|
||||
symbol_table.table(), &table_builder,
|
||||
[](auto *builder, const auto &entry) {
|
||||
auto key_builder = builder->initKey();
|
||||
key_builder.setValue(entry.first);
|
||||
auto value_builder = builder->initValue();
|
||||
Save(entry.second, &value_builder);
|
||||
});
|
||||
}
|
||||
|
||||
inline void Load(SymbolTable *symbol_table,
|
||||
const capnp::SymbolTable::Reader &reader) {
|
||||
utils::LoadVector<capnp::Symbol, Symbol>(
|
||||
utils::LoadMap<utils::capnp::BoxInt32, capnp::Symbol,
|
||||
std::map<int32_t, Symbol>>(
|
||||
&symbol_table->table_, reader.getTable(), [](const auto &reader) {
|
||||
Symbol val;
|
||||
Load(&val, reader);
|
||||
return val;
|
||||
std::pair<int32_t, Symbol> entry;
|
||||
entry.first = reader.getKey().getValue();
|
||||
Load(&entry.second, reader.getValue());
|
||||
return entry;
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -757,7 +757,7 @@ TEST_F(QueryPlanExpandVariable, NamedPath) {
|
||||
2, e, "m", GraphView::OLD);
|
||||
auto find_symbol = [this](const std::string &name) {
|
||||
for (const auto &sym : symbol_table.table())
|
||||
if (sym.name() == name) return sym;
|
||||
if (sym.second.name() == name) return sym.second;
|
||||
throw std::runtime_error("Symbol not found");
|
||||
};
|
||||
|
||||
|
@ -985,10 +985,10 @@ TEST_F(TestSymbolGenerator, MatchVariableLambdaSymbols) {
|
||||
EXPECT_EQ(symbol_table.max_position(), 7);
|
||||
// All symbols except `AS res` are anonymously generated.
|
||||
for (const auto &symbol : symbol_table.table()) {
|
||||
if (symbol.name() == "res") {
|
||||
EXPECT_TRUE(symbol.user_declared());
|
||||
if (symbol.second.name() == "res") {
|
||||
EXPECT_TRUE(symbol.second.user_declared());
|
||||
} else {
|
||||
EXPECT_FALSE(symbol.user_declared());
|
||||
EXPECT_FALSE(symbol.second.user_declared());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user