Support multiple types for weighted shortest path (#278)
This commit is contained in:
parent
10196f3d7d
commit
1a78c3695d
@ -49,6 +49,7 @@
|
||||
#include "utils/pmr/vector.hpp"
|
||||
#include "utils/readable_size.hpp"
|
||||
#include "utils/string.hpp"
|
||||
#include "utils/temporal.hpp"
|
||||
|
||||
// macro for the default implementation of LogicalOperator::Accept
|
||||
// that accepts the visitor and visits it's input_ operator
|
||||
@ -1445,7 +1446,7 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
|
||||
// satisfy the "where" condition. if so, places them in the priority
|
||||
// queue.
|
||||
auto expand_pair = [this, &evaluator, &frame, &create_state](const EdgeAccessor &edge, const VertexAccessor &vertex,
|
||||
double weight, int64_t depth) {
|
||||
const TypedValue &total_weight, int64_t depth) {
|
||||
auto *memory = evaluator.GetMemoryResource();
|
||||
if (self_.filter_lambda_.expression) {
|
||||
frame[self_.filter_lambda_.inner_edge_symbol] = edge;
|
||||
@ -1457,27 +1458,48 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
|
||||
frame[self_.weight_lambda_->inner_edge_symbol] = edge;
|
||||
frame[self_.weight_lambda_->inner_node_symbol] = vertex;
|
||||
|
||||
TypedValue typed_weight = self_.weight_lambda_->expression->Accept(evaluator);
|
||||
TypedValue current_weight = self_.weight_lambda_->expression->Accept(evaluator);
|
||||
|
||||
if (!typed_weight.IsNumeric()) {
|
||||
throw QueryRuntimeException("Calculated weight must be numeric, got {}.", typed_weight.type());
|
||||
if (!current_weight.IsNumeric() && !current_weight.IsDuration()) {
|
||||
throw QueryRuntimeException("Calculated weight must be numeric or a Duration, got {}.", current_weight.type());
|
||||
}
|
||||
if ((typed_weight < TypedValue(0, memory)).ValueBool()) {
|
||||
|
||||
const auto is_valid_numeric = [&] {
|
||||
return current_weight.IsNumeric() && (current_weight >= TypedValue(0, memory)).ValueBool();
|
||||
};
|
||||
|
||||
const auto is_valid_duration = [&] {
|
||||
return current_weight.IsDuration() && (current_weight >= TypedValue(utils::Duration(0), memory)).ValueBool();
|
||||
};
|
||||
|
||||
if (!is_valid_numeric() && !is_valid_duration()) {
|
||||
throw QueryRuntimeException("Calculated weight must be non-negative!");
|
||||
}
|
||||
|
||||
auto next_state = create_state(vertex, depth);
|
||||
auto next_weight = TypedValue(weight, memory) + typed_weight;
|
||||
auto found_it = total_cost_.find(next_state);
|
||||
if (found_it != total_cost_.end() && found_it->second.ValueDouble() <= next_weight.ValueDouble()) return;
|
||||
|
||||
pq_.push({next_weight.ValueDouble(), depth + 1, vertex, edge});
|
||||
TypedValue next_weight = std::invoke([&] {
|
||||
if (total_weight.IsNull()) {
|
||||
return current_weight;
|
||||
}
|
||||
|
||||
ValidateWeightTypes(current_weight, total_weight);
|
||||
|
||||
return TypedValue(current_weight, memory) + total_weight;
|
||||
});
|
||||
|
||||
auto found_it = total_cost_.find(next_state);
|
||||
if (found_it != total_cost_.end() && (found_it->second.IsNull() || (found_it->second <= next_weight).ValueBool()))
|
||||
return;
|
||||
|
||||
pq_.push({next_weight, depth + 1, vertex, edge});
|
||||
};
|
||||
|
||||
// Populates the priority queue structure with expansions
|
||||
// from the given vertex. skips expansions that don't satisfy
|
||||
// the "where" condition.
|
||||
auto expand_from_vertex = [this, &expand_pair](const VertexAccessor &vertex, double weight, int64_t depth) {
|
||||
auto expand_from_vertex = [this, &expand_pair](const VertexAccessor &vertex, const TypedValue &weight,
|
||||
int64_t depth) {
|
||||
if (self_.common_.direction != EdgeAtom::Direction::IN) {
|
||||
auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::View::OLD, self_.common_.edge_types));
|
||||
for (const auto &edge : out_edges) {
|
||||
@ -1522,7 +1544,7 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
|
||||
total_cost_.clear();
|
||||
yielded_vertices_.clear();
|
||||
|
||||
pq_.push({0.0, 0, vertex, std::nullopt});
|
||||
pq_.push({TypedValue(), 0, vertex, std::nullopt});
|
||||
// We are adding the starting vertex to the set of yielded vertices
|
||||
// because we don't want to yield paths that end with the starting
|
||||
// vertex.
|
||||
@ -1622,17 +1644,38 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
|
||||
// Keeps track of vertices for which we yielded a path already.
|
||||
utils::pmr::unordered_set<VertexAccessor> yielded_vertices_;
|
||||
|
||||
static void ValidateWeightTypes(const TypedValue &lhs, const TypedValue &rhs) {
|
||||
if (!((lhs.IsNumeric() && lhs.IsNumeric()) || (rhs.IsDuration() && rhs.IsDuration()))) {
|
||||
throw QueryRuntimeException(utils::MessageWithLink(
|
||||
"All weights should be of the same type, either numeric or a Duration. Please update the weight "
|
||||
"expression or the filter expression.",
|
||||
"https://memgr.ph/wsp"));
|
||||
}
|
||||
}
|
||||
|
||||
// Priority queue comparator. Keep lowest weight on top of the queue.
|
||||
class PriorityQueueComparator {
|
||||
public:
|
||||
bool operator()(const std::tuple<double, int64_t, VertexAccessor, std::optional<EdgeAccessor>> &lhs,
|
||||
const std::tuple<double, int64_t, VertexAccessor, std::optional<EdgeAccessor>> &rhs) {
|
||||
return std::get<0>(lhs) > std::get<0>(rhs);
|
||||
bool operator()(const std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>> &lhs,
|
||||
const std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>> &rhs) {
|
||||
const auto &lhs_weight = std::get<0>(lhs);
|
||||
const auto &rhs_weight = std::get<0>(rhs);
|
||||
// Null defines minimum value for all types
|
||||
if (lhs_weight.IsNull()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (rhs_weight.IsNull()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
ValidateWeightTypes(lhs_weight, rhs_weight);
|
||||
return (lhs_weight > rhs_weight).ValueBool();
|
||||
}
|
||||
};
|
||||
|
||||
std::priority_queue<std::tuple<double, int64_t, VertexAccessor, std::optional<EdgeAccessor>>,
|
||||
utils::pmr::vector<std::tuple<double, int64_t, VertexAccessor, std::optional<EdgeAccessor>>>,
|
||||
std::priority_queue<std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>>,
|
||||
utils::pmr::vector<std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>>>,
|
||||
PriorityQueueComparator>
|
||||
pq_;
|
||||
|
||||
|
@ -40,10 +40,10 @@ Feature: Weighted Shortest Path
|
||||
MATCH (n {a:'0'})-[le *wShortest 10 (e, n | e.w ) w]->(m) RETURN m.a, size(le) as s, w
|
||||
"""
|
||||
Then the result should be:
|
||||
| m.a | s | w |
|
||||
| '1' | 1 | 1.0 |
|
||||
| '2' | 2 | 3.0 |
|
||||
| '3' | 1 | 4.0 |
|
||||
| m.a | s | w |
|
||||
| '1' | 1 | 1 |
|
||||
| '2' | 2 | 3 |
|
||||
| '3' | 1 | 4 |
|
||||
|
||||
Scenario: Test match wShortest single edge type filtered
|
||||
Given an empty graph
|
||||
@ -116,4 +116,42 @@ Feature: Weighted Shortest Path
|
||||
"""
|
||||
Then an error should be raised
|
||||
|
||||
Scenario: Test match wShortest weight duration
|
||||
Given an empty graph
|
||||
And having executed:
|
||||
"""
|
||||
CREATE (n {a:'0'})-[:r {w: DURATION('PT1S')}]->({a:'1'})-[:r {w: DURATION('PT2S')}]->({a:'2'}), (n)-[:r {w: DURATION('PT4S')}]->({a:'3'})
|
||||
"""
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (n {a:'0'})-[le *wShortest 10 (e, n | e.w ) w]->(m) RETURN m.a, size(le) as s, w
|
||||
"""
|
||||
Then the result should be:
|
||||
| m.a | s | w |
|
||||
| '1' | 1 | PT1S |
|
||||
| '2' | 2 | PT3S |
|
||||
| '3' | 1 | PT4S |
|
||||
|
||||
Scenario: Test match wShortest weight negative duration
|
||||
Given an empty graph
|
||||
And having executed:
|
||||
"""
|
||||
CREATE (n {a:'0'})-[:r {w: DURATION({seconds: -1})}]->({a:'1'})-[:r {w: DURATION('PT2S')}]->({a:'2'}), (n)-[:r {w: DURATION('PT4S')}]->({a:'3'})
|
||||
"""
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (n {a:'0'})-[le *wShortest 10 (e, n | e.w ) w]->(m) RETURN m.a, size(le) as s, w
|
||||
"""
|
||||
Then an error should be raised
|
||||
|
||||
Scenario: Test match wShortest weight mixed numeric and duration as weights
|
||||
Given an empty graph
|
||||
And having executed:
|
||||
"""
|
||||
CREATE (n {a:'0'})-[:r {w: 2}]->({a:'1'})-[:r {w: DURATION('PT2S')}]->({a:'2'}), (n)-[:r {w: DURATION('PT4S')}]->({a:'3'})
|
||||
"""
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (n {a:'0'})-[le *wShortest 10 (e, n | e.w ) w]->(m) RETURN m.a, size(le) as s, w
|
||||
"""
|
||||
Then an error should be raised
|
||||
|
@ -26,5 +26,5 @@ Feature: Queries related to all Stackoverflow questions related to WSP
|
||||
"""
|
||||
Then the result should be:
|
||||
| hops | total_weight |
|
||||
|'1 -> 3 -> 4'| 17.0 |
|
||||
|'1 -> 3 -> 4'| 17 |
|
||||
|
||||
|
@ -1,3 +1,14 @@
|
||||
// Copyright 2021 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <filesystem>
|
||||
|
||||
@ -410,38 +421,58 @@ TEST_F(InterpreterTest, Bfs) {
|
||||
|
||||
// Test shortest path end to end.
|
||||
TEST_F(InterpreterTest, ShortestPath) {
|
||||
Interpret(
|
||||
"CREATE (n:A {x: 1}), (m:B {x: 2}), (l:C {x: 1}), (n)-[:r1 {w: 1 "
|
||||
"}]->(m)-[:r2 {w: 2}]->(l), (n)-[:r3 {w: 4}]->(l)");
|
||||
const auto test_shortest_path = [this](const bool use_duration) {
|
||||
const auto get_weight = [use_duration](const auto value) {
|
||||
return fmt::format(use_duration ? "DURATION('PT{}S')" : "{}", value);
|
||||
};
|
||||
|
||||
auto stream = Interpret("MATCH (n)-[e *wshortest 5 (e, n | e.w) ]->(m) return e");
|
||||
Interpret(
|
||||
fmt::format("CREATE (n:A {{x: 1}}), (m:B {{x: 2}}), (l:C {{x: 1}}), (n)-[:r1 {{w: {} "
|
||||
"}}]->(m)-[:r2 {{w: {}}}]->(l), (n)-[:r3 {{w: {}}}]->(l)",
|
||||
get_weight(1), get_weight(2), get_weight(4)));
|
||||
|
||||
ASSERT_EQ(stream.GetHeader().size(), 1U);
|
||||
EXPECT_EQ(stream.GetHeader()[0], "e");
|
||||
ASSERT_EQ(stream.GetResults().size(), 3U);
|
||||
auto stream = Interpret("MATCH (n)-[e *wshortest 5 (e, n | e.w) ]->(m) return e");
|
||||
|
||||
auto dba = db_.Access();
|
||||
std::vector<std::vector<std::string>> expected_results{{"r1"}, {"r2"}, {"r1", "r2"}};
|
||||
ASSERT_EQ(stream.GetHeader().size(), 1U);
|
||||
EXPECT_EQ(stream.GetHeader()[0], "e");
|
||||
ASSERT_EQ(stream.GetResults().size(), 3U);
|
||||
|
||||
for (const auto &result : stream.GetResults()) {
|
||||
const auto &edges = ToEdgeList(result[0]);
|
||||
auto dba = db_.Access();
|
||||
std::vector<std::vector<std::string>> expected_results{{"r1"}, {"r2"}, {"r1", "r2"}};
|
||||
|
||||
std::vector<std::string> datum;
|
||||
datum.reserve(edges.size());
|
||||
for (const auto &result : stream.GetResults()) {
|
||||
const auto &edges = ToEdgeList(result[0]);
|
||||
|
||||
for (const auto &edge : edges) {
|
||||
datum.push_back(edge.type);
|
||||
}
|
||||
std::vector<std::string> datum;
|
||||
datum.reserve(edges.size());
|
||||
|
||||
bool any_match = false;
|
||||
for (const auto &expected : expected_results) {
|
||||
if (expected == datum) {
|
||||
any_match = true;
|
||||
break;
|
||||
for (const auto &edge : edges) {
|
||||
datum.push_back(edge.type);
|
||||
}
|
||||
|
||||
bool any_match = false;
|
||||
for (const auto &expected : expected_results) {
|
||||
if (expected == datum) {
|
||||
any_match = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_TRUE(any_match);
|
||||
}
|
||||
|
||||
EXPECT_TRUE(any_match);
|
||||
Interpret("MATCH (n) DETACH DELETE n");
|
||||
};
|
||||
|
||||
constexpr bool kUseNumeric{false};
|
||||
constexpr bool kUseDuration{true};
|
||||
{
|
||||
SCOPED_TRACE("Test with numeric values");
|
||||
test_shortest_path(kUseNumeric);
|
||||
}
|
||||
{
|
||||
SCOPED_TRACE("Test with Duration values");
|
||||
test_shortest_path(kUseDuration);
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user