Fix bug in SymbolTable

Summary:
This is a bugfix for D1836. It made `SymbolTable` return references to vector
elements, which then get invalidated and weird stuff happens.

This made a `DCHECK` in `rule_based_planner.hpp` trigger, and it was noticed by
@ipaljak 2 months later. All `DCHECK`s in `rule_based_planner.hpp` are now
changed to `CHECK`s.

Also, hash function for `Symbol` was wrong, because it also took
`user_declared` field into consideration, and `==` operator doesn't do that.

Reviewers: ipaljak, teon.banek, mferencevic, msantl

Reviewed By: msantl

Subscribers: pullbot, ipaljak

Differential Revision: https://phabricator.memgraph.io/D1938
This commit is contained in:
Marin Tomic 2019-04-02 16:13:41 +02:00
parent 3436094df6
commit dc231fe4e7
8 changed files with 55 additions and 45 deletions

View File

@ -66,7 +66,6 @@ struct hash<query::Symbol> {
size_t prime = 265443599u;
size_t hash = std::hash<int>{}(symbol.position());
hash ^= prime * std::hash<std::string>{}(symbol.name());
hash ^= prime * std::hash<bool>{}(symbol.user_declared());
hash ^= prime * std::hash<int>{}(static_cast<int>(symbol.type()));
return hash;
}

View File

@ -16,9 +16,11 @@ class SymbolTable final {
int32_t token_position = -1) {
CHECK(table_.size() <= std::numeric_limits<int32_t>::max())
<< "SymbolTable size doesn't fit into 32-bit integer!";
int32_t position = static_cast<int32_t>(table_.size());
table_.emplace_back(name, position, user_declared, type, token_position);
return table_.back();
auto got = table_.emplace(position_, Symbol(name, position_, user_declared,
type, token_position));
CHECK(got.second) << "Duplicate symbol ID!";
position_++;
return got.first->second;
}
const Symbol &at(const Identifier &ident) const {
@ -36,7 +38,8 @@ class SymbolTable final {
const auto &table() const { return table_; }
std::vector<Symbol> table_;
int32_t position_{0};
std::map<int32_t, Symbol> table_;
};
} // namespace query

View File

@ -77,7 +77,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
if (where) {
where->Accept(*this);
}
DCHECK(aggregations_.empty())
CHECK(aggregations_.empty())
<< "Unexpected aggregations in ORDER BY or WHERE";
}
}
@ -121,7 +121,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
public:
bool PostVisit(ListLiteral &list_literal) override {
DCHECK(list_literal.elements_.size() <= has_aggregation_.size())
CHECK(list_literal.elements_.size() <= has_aggregation_.size())
<< "Expected as many has_aggregation_ flags as there are list"
"elements.";
PostVisitCollectionLiteral(list_literal, [](auto it) { return *it; });
@ -129,7 +129,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
}
bool PostVisit(MapLiteral &map_literal) override {
DCHECK(map_literal.elements_.size() <= has_aggregation_.size())
CHECK(map_literal.elements_.size() <= has_aggregation_.size())
<< "Expected has_aggregation_ flags as much as there are map elements.";
PostVisitCollectionLiteral(map_literal, [](auto it) { return it->second; });
return true;
@ -139,7 +139,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
// Remove the symbol which is bound by all, because we are only interested
// in free (unbound) symbols.
used_symbols_.erase(symbol_table_.at(*all.identifier_));
DCHECK(has_aggregation_.size() >= 3U)
CHECK(has_aggregation_.size() >= 3U)
<< "Expected 3 has_aggregation_ flags for ALL arguments";
bool has_aggr = false;
for (int i = 0; i < 3; ++i) {
@ -154,7 +154,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
// Remove the symbol which is bound by single, because we are only
// interested in free (unbound) symbols.
used_symbols_.erase(symbol_table_.at(*single.identifier_));
DCHECK(has_aggregation_.size() >= 3U)
CHECK(has_aggregation_.size() >= 3U)
<< "Expected 3 has_aggregation_ flags for SINGLE arguments";
bool has_aggr = false;
for (int i = 0; i < 3; ++i) {
@ -170,7 +170,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
// in free (unbound) symbols.
used_symbols_.erase(symbol_table_.at(*reduce.accumulator_));
used_symbols_.erase(symbol_table_.at(*reduce.identifier_));
DCHECK(has_aggregation_.size() >= 5U)
CHECK(has_aggregation_.size() >= 5U)
<< "Expected 5 has_aggregation_ flags for REDUCE arguments";
bool has_aggr = false;
for (int i = 0; i < 5; ++i) {
@ -198,7 +198,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
// Remove the symbol bound by extract, because we are only interested
// in free (unbound) symbols.
used_symbols_.erase(symbol_table_.at(*extract.identifier_));
DCHECK(has_aggregation_.size() >= 3U)
CHECK(has_aggregation_.size() >= 3U)
<< "Expected 3 has_aggregation_ flags for EXTRACT arguments";
bool has_aggr = false;
for (int i = 0; i < 3; ++i) {
@ -257,12 +257,12 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
has_aggregation_.emplace_back(has_aggr);
// TODO: Once we allow aggregations here, insert appropriate stuff in
// group_by.
DCHECK(!has_aggr) << "Currently aggregations in CASE are not allowed";
CHECK(!has_aggr) << "Currently aggregations in CASE are not allowed";
return false;
}
bool PostVisit(Function &function) override {
DCHECK(function.arguments_.size() <= has_aggregation_.size())
CHECK(function.arguments_.size() <= has_aggregation_.size())
<< "Expected as many has_aggregation_ flags as there are"
"function arguments.";
bool has_aggr = false;
@ -278,7 +278,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
#define VISIT_BINARY_OPERATOR(BinaryOperator) \
bool PostVisit(BinaryOperator &op) override { \
DCHECK(has_aggregation_.size() >= 2U) \
CHECK(has_aggregation_.size() >= 2U) \
<< "Expected at least 2 has_aggregation_ flags."; \
/* has_aggregation_ stack is reversed, last result is from the 2nd */ \
/* expression. */ \
@ -336,7 +336,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
}
bool PostVisit(NamedExpression &named_expr) override {
DCHECK(has_aggregation_.size() == 1U)
CHECK(has_aggregation_.size() == 1U)
<< "Expected to reduce has_aggregation_ to single boolean.";
if (!has_aggregation_.back()) {
group_by_.emplace_back(named_expr.expression_);
@ -354,9 +354,9 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
// This should be used when body.all_identifiers is true, to generate
// expressions for Produce operator.
void ExpandUserSymbols() {
DCHECK(named_expressions_.empty())
CHECK(named_expressions_.empty())
<< "ExpandUserSymbols should be first to fill named_expressions_";
DCHECK(output_symbols_.empty())
CHECK(output_symbols_.empty())
<< "ExpandUserSymbols should be first to fill output_symbols_";
for (const auto &symbol : bound_symbols_) {
if (!symbol.user_declared()) {

View File

@ -94,20 +94,20 @@ template <typename T>
auto ReducePattern(
Pattern &pattern, std::function<T(NodeAtom *)> base,
std::function<T(T, NodeAtom *, EdgeAtom *, NodeAtom *)> collect) {
DCHECK(!pattern.atoms_.empty()) << "Missing atoms in pattern";
CHECK(!pattern.atoms_.empty()) << "Missing atoms in pattern";
auto atoms_it = pattern.atoms_.begin();
auto current_node = utils::Downcast<NodeAtom>(*atoms_it++);
DCHECK(current_node) << "First pattern atom is not a node";
CHECK(current_node) << "First pattern atom is not a node";
auto last_res = base(current_node);
// Remaining atoms need to follow sequentially as (EdgeAtom, NodeAtom)*
while (atoms_it != pattern.atoms_.end()) {
auto edge = utils::Downcast<EdgeAtom>(*atoms_it++);
DCHECK(edge) << "Expected an edge atom in pattern.";
DCHECK(atoms_it != pattern.atoms_.end())
CHECK(edge) << "Expected an edge atom in pattern.";
CHECK(atoms_it != pattern.atoms_.end())
<< "Edge atom should not end the pattern.";
auto prev_node = current_node;
current_node = utils::Downcast<NodeAtom>(*atoms_it++);
DCHECK(current_node) << "Expected a node atom in pattern.";
CHECK(current_node) << "Expected a node atom in pattern.";
last_res = collect(std::move(last_res), prev_node, edge, current_node);
}
return last_res;
@ -179,7 +179,7 @@ class RuleBasedPlanner {
}
int merge_id = 0;
for (auto *clause : query_part.remaining_clauses) {
DCHECK(!utils::IsSubtype(*clause, Match::kType))
CHECK(!utils::IsSubtype(*clause, Match::kType))
<< "Unexpected Match in remaining clauses";
if (auto *ret = utils::Downcast<Return>(clause)) {
input_op = impl::GenReturn(
@ -407,7 +407,7 @@ class RuleBasedPlanner {
symbol_table.at(*expansion.node2->identifier_);
auto existing_node = utils::Contains(bound_symbols, node_symbol);
const auto &edge_symbol = symbol_table.at(*edge->identifier_);
DCHECK(!utils::Contains(bound_symbols, edge_symbol))
CHECK(!utils::Contains(bound_symbols, edge_symbol))
<< "Existing edges are not supported";
std::vector<storage::EdgeType> edge_types;
edge_types.reserve(edge->edge_types_.size());
@ -439,7 +439,7 @@ class RuleBasedPlanner {
bound_symbols.insert(filter_lambda.inner_edge_symbol).second;
bool inner_node_bound =
bound_symbols.insert(filter_lambda.inner_node_symbol).second;
DCHECK(inner_edge_bound && inner_node_bound)
CHECK(inner_edge_bound && inner_node_bound)
<< "An inner edge and node can't be bound from before";
}
// Join regular filters with lambda filter expression, so that they
@ -518,14 +518,14 @@ class RuleBasedPlanner {
storage);
}
}
DCHECK(named_paths.empty()) << "Expected to generate all named paths";
CHECK(named_paths.empty()) << "Expected to generate all named paths";
// We bound all named path symbols, so just add them to new_symbols.
for (const auto &named_path : matching.named_paths) {
DCHECK(utils::Contains(bound_symbols, named_path.first))
CHECK(utils::Contains(bound_symbols, named_path.first))
<< "Expected generated named path to have bound symbol";
match_context.new_symbols.emplace_back(named_path.first);
}
DCHECK(filters.empty()) << "Expected to generate all filters";
CHECK(filters.empty()) << "Expected to generate all filters";
return last_op;
}
@ -544,12 +544,12 @@ class RuleBasedPlanner {
for (auto &set : merge.on_create_) {
on_create = HandleWriteClause(set, on_create, *context_->symbol_table,
context_->bound_symbols);
DCHECK(on_create) << "Expected SET in MERGE ... ON CREATE";
CHECK(on_create) << "Expected SET in MERGE ... ON CREATE";
}
for (auto &set : merge.on_match_) {
on_match = HandleWriteClause(set, on_match, *context_->symbol_table,
context_->bound_symbols);
DCHECK(on_match) << "Expected SET in MERGE ... ON MATCH";
CHECK(on_match) << "Expected SET in MERGE ... ON MATCH";
}
return std::make_unique<plan::Merge>(
std::move(input_op), std::move(on_match), std::move(on_create));

View File

@ -43,5 +43,5 @@ struct TypedValue {
}
struct SymbolTable {
table @0 :List(Sem.Symbol);
table @0 :Utils.Map(Utils.BoxInt32, Sem.Symbol);
}

View File

@ -6,8 +6,8 @@
#include "query/frontend/semantic/symbol_table.hpp"
#include "query/serialization.capnp.h"
#include "query/typed_value.hpp"
#include "storage/distributed/rpc/serialization.hpp"
#include "rpc/serialization.hpp"
#include "storage/distributed/rpc/serialization.hpp"
namespace distributed {
class DataManager;
@ -32,19 +32,27 @@ void Load(TypedValueVectorCompare *comparator,
inline void Save(const SymbolTable &symbol_table,
capnp::SymbolTable::Builder *builder) {
auto list_builder = builder->initTable(symbol_table.table().size());
utils::SaveVector<capnp::Symbol, Symbol>(
symbol_table.table(), &list_builder,
[](auto *builder, const auto &symbol) { Save(symbol, builder); });
auto table_builder = builder->initTable();
utils::SaveMap<utils::capnp::BoxInt32, capnp::Symbol,
std::map<int32_t, Symbol>>(
symbol_table.table(), &table_builder,
[](auto *builder, const auto &entry) {
auto key_builder = builder->initKey();
key_builder.setValue(entry.first);
auto value_builder = builder->initValue();
Save(entry.second, &value_builder);
});
}
inline void Load(SymbolTable *symbol_table,
const capnp::SymbolTable::Reader &reader) {
utils::LoadVector<capnp::Symbol, Symbol>(
utils::LoadMap<utils::capnp::BoxInt32, capnp::Symbol,
std::map<int32_t, Symbol>>(
&symbol_table->table_, reader.getTable(), [](const auto &reader) {
Symbol val;
Load(&val, reader);
return val;
std::pair<int32_t, Symbol> entry;
entry.first = reader.getKey().getValue();
Load(&entry.second, reader.getValue());
return entry;
});
}

View File

@ -757,7 +757,7 @@ TEST_F(QueryPlanExpandVariable, NamedPath) {
2, e, "m", GraphView::OLD);
auto find_symbol = [this](const std::string &name) {
for (const auto &sym : symbol_table.table())
if (sym.name() == name) return sym;
if (sym.second.name() == name) return sym.second;
throw std::runtime_error("Symbol not found");
};

View File

@ -985,10 +985,10 @@ TEST_F(TestSymbolGenerator, MatchVariableLambdaSymbols) {
EXPECT_EQ(symbol_table.max_position(), 7);
// All symbols except `AS res` are anonymously generated.
for (const auto &symbol : symbol_table.table()) {
if (symbol.name() == "res") {
EXPECT_TRUE(symbol.user_declared());
if (symbol.second.name() == "res") {
EXPECT_TRUE(symbol.second.user_declared());
} else {
EXPECT_FALSE(symbol.user_declared());
EXPECT_FALSE(symbol.second.user_declared());
}
}
}