Support multiple types for weighted shortest path (#278)

This commit is contained in:
Antonio Andelic 2021-10-19 14:39:23 +02:00 committed by GitHub
parent 10196f3d7d
commit 1a78c3695d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 155 additions and 43 deletions

View File

@ -49,6 +49,7 @@
#include "utils/pmr/vector.hpp"
#include "utils/readable_size.hpp"
#include "utils/string.hpp"
#include "utils/temporal.hpp"
// macro for the default implementation of LogicalOperator::Accept
// that accepts the visitor and visits it's input_ operator
@ -1445,7 +1446,7 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
// satisfy the "where" condition. if so, places them in the priority
// queue.
auto expand_pair = [this, &evaluator, &frame, &create_state](const EdgeAccessor &edge, const VertexAccessor &vertex,
double weight, int64_t depth) {
const TypedValue &total_weight, int64_t depth) {
auto *memory = evaluator.GetMemoryResource();
if (self_.filter_lambda_.expression) {
frame[self_.filter_lambda_.inner_edge_symbol] = edge;
@ -1457,27 +1458,48 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
frame[self_.weight_lambda_->inner_edge_symbol] = edge;
frame[self_.weight_lambda_->inner_node_symbol] = vertex;
TypedValue typed_weight = self_.weight_lambda_->expression->Accept(evaluator);
TypedValue current_weight = self_.weight_lambda_->expression->Accept(evaluator);
if (!typed_weight.IsNumeric()) {
throw QueryRuntimeException("Calculated weight must be numeric, got {}.", typed_weight.type());
if (!current_weight.IsNumeric() && !current_weight.IsDuration()) {
throw QueryRuntimeException("Calculated weight must be numeric or a Duration, got {}.", current_weight.type());
}
if ((typed_weight < TypedValue(0, memory)).ValueBool()) {
const auto is_valid_numeric = [&] {
return current_weight.IsNumeric() && (current_weight >= TypedValue(0, memory)).ValueBool();
};
const auto is_valid_duration = [&] {
return current_weight.IsDuration() && (current_weight >= TypedValue(utils::Duration(0), memory)).ValueBool();
};
if (!is_valid_numeric() && !is_valid_duration()) {
throw QueryRuntimeException("Calculated weight must be non-negative!");
}
auto next_state = create_state(vertex, depth);
auto next_weight = TypedValue(weight, memory) + typed_weight;
auto found_it = total_cost_.find(next_state);
if (found_it != total_cost_.end() && found_it->second.ValueDouble() <= next_weight.ValueDouble()) return;
pq_.push({next_weight.ValueDouble(), depth + 1, vertex, edge});
TypedValue next_weight = std::invoke([&] {
if (total_weight.IsNull()) {
return current_weight;
}
ValidateWeightTypes(current_weight, total_weight);
return TypedValue(current_weight, memory) + total_weight;
});
auto found_it = total_cost_.find(next_state);
if (found_it != total_cost_.end() && (found_it->second.IsNull() || (found_it->second <= next_weight).ValueBool()))
return;
pq_.push({next_weight, depth + 1, vertex, edge});
};
// Populates the priority queue structure with expansions
// from the given vertex. skips expansions that don't satisfy
// the "where" condition.
auto expand_from_vertex = [this, &expand_pair](const VertexAccessor &vertex, double weight, int64_t depth) {
auto expand_from_vertex = [this, &expand_pair](const VertexAccessor &vertex, const TypedValue &weight,
int64_t depth) {
if (self_.common_.direction != EdgeAtom::Direction::IN) {
auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::View::OLD, self_.common_.edge_types));
for (const auto &edge : out_edges) {
@ -1522,7 +1544,7 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
total_cost_.clear();
yielded_vertices_.clear();
pq_.push({0.0, 0, vertex, std::nullopt});
pq_.push({TypedValue(), 0, vertex, std::nullopt});
// We are adding the starting vertex to the set of yielded vertices
// because we don't want to yield paths that end with the starting
// vertex.
@ -1622,17 +1644,38 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
// Keeps track of vertices for which we yielded a path already.
utils::pmr::unordered_set<VertexAccessor> yielded_vertices_;
static void ValidateWeightTypes(const TypedValue &lhs, const TypedValue &rhs) {
if (!((lhs.IsNumeric() && lhs.IsNumeric()) || (rhs.IsDuration() && rhs.IsDuration()))) {
throw QueryRuntimeException(utils::MessageWithLink(
"All weights should be of the same type, either numeric or a Duration. Please update the weight "
"expression or the filter expression.",
"https://memgr.ph/wsp"));
}
}
// Priority queue comparator. Keep lowest weight on top of the queue.
class PriorityQueueComparator {
public:
bool operator()(const std::tuple<double, int64_t, VertexAccessor, std::optional<EdgeAccessor>> &lhs,
const std::tuple<double, int64_t, VertexAccessor, std::optional<EdgeAccessor>> &rhs) {
return std::get<0>(lhs) > std::get<0>(rhs);
bool operator()(const std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>> &lhs,
const std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>> &rhs) {
const auto &lhs_weight = std::get<0>(lhs);
const auto &rhs_weight = std::get<0>(rhs);
// Null defines minimum value for all types
if (lhs_weight.IsNull()) {
return false;
}
if (rhs_weight.IsNull()) {
return true;
}
ValidateWeightTypes(lhs_weight, rhs_weight);
return (lhs_weight > rhs_weight).ValueBool();
}
};
std::priority_queue<std::tuple<double, int64_t, VertexAccessor, std::optional<EdgeAccessor>>,
utils::pmr::vector<std::tuple<double, int64_t, VertexAccessor, std::optional<EdgeAccessor>>>,
std::priority_queue<std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>>,
utils::pmr::vector<std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>>>,
PriorityQueueComparator>
pq_;

View File

@ -40,10 +40,10 @@ Feature: Weighted Shortest Path
MATCH (n {a:'0'})-[le *wShortest 10 (e, n | e.w ) w]->(m) RETURN m.a, size(le) as s, w
"""
Then the result should be:
| m.a | s | w |
| '1' | 1 | 1.0 |
| '2' | 2 | 3.0 |
| '3' | 1 | 4.0 |
| m.a | s | w |
| '1' | 1 | 1 |
| '2' | 2 | 3 |
| '3' | 1 | 4 |
Scenario: Test match wShortest single edge type filtered
Given an empty graph
@ -116,4 +116,42 @@ Feature: Weighted Shortest Path
"""
Then an error should be raised
Scenario: Test match wShortest weight duration
Given an empty graph
And having executed:
"""
CREATE (n {a:'0'})-[:r {w: DURATION('PT1S')}]->({a:'1'})-[:r {w: DURATION('PT2S')}]->({a:'2'}), (n)-[:r {w: DURATION('PT4S')}]->({a:'3'})
"""
When executing query:
"""
MATCH (n {a:'0'})-[le *wShortest 10 (e, n | e.w ) w]->(m) RETURN m.a, size(le) as s, w
"""
Then the result should be:
| m.a | s | w |
| '1' | 1 | PT1S |
| '2' | 2 | PT3S |
| '3' | 1 | PT4S |
Scenario: Test match wShortest weight negative duration
Given an empty graph
And having executed:
"""
CREATE (n {a:'0'})-[:r {w: DURATION({seconds: -1})}]->({a:'1'})-[:r {w: DURATION('PT2S')}]->({a:'2'}), (n)-[:r {w: DURATION('PT4S')}]->({a:'3'})
"""
When executing query:
"""
MATCH (n {a:'0'})-[le *wShortest 10 (e, n | e.w ) w]->(m) RETURN m.a, size(le) as s, w
"""
Then an error should be raised
Scenario: Test match wShortest weight mixed numeric and duration as weights
Given an empty graph
And having executed:
"""
CREATE (n {a:'0'})-[:r {w: 2}]->({a:'1'})-[:r {w: DURATION('PT2S')}]->({a:'2'}), (n)-[:r {w: DURATION('PT4S')}]->({a:'3'})
"""
When executing query:
"""
MATCH (n {a:'0'})-[le *wShortest 10 (e, n | e.w ) w]->(m) RETURN m.a, size(le) as s, w
"""
Then an error should be raised

View File

@ -26,5 +26,5 @@ Feature: Queries related to all Stackoverflow questions related to WSP
"""
Then the result should be:
| hops | total_weight |
|'1 -> 3 -> 4'| 17.0 |
|'1 -> 3 -> 4'| 17 |

View File

@ -1,3 +1,14 @@
// Copyright 2021 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <cstdlib>
#include <filesystem>
@ -410,38 +421,58 @@ TEST_F(InterpreterTest, Bfs) {
// Test shortest path end to end.
TEST_F(InterpreterTest, ShortestPath) {
Interpret(
"CREATE (n:A {x: 1}), (m:B {x: 2}), (l:C {x: 1}), (n)-[:r1 {w: 1 "
"}]->(m)-[:r2 {w: 2}]->(l), (n)-[:r3 {w: 4}]->(l)");
const auto test_shortest_path = [this](const bool use_duration) {
const auto get_weight = [use_duration](const auto value) {
return fmt::format(use_duration ? "DURATION('PT{}S')" : "{}", value);
};
auto stream = Interpret("MATCH (n)-[e *wshortest 5 (e, n | e.w) ]->(m) return e");
Interpret(
fmt::format("CREATE (n:A {{x: 1}}), (m:B {{x: 2}}), (l:C {{x: 1}}), (n)-[:r1 {{w: {} "
"}}]->(m)-[:r2 {{w: {}}}]->(l), (n)-[:r3 {{w: {}}}]->(l)",
get_weight(1), get_weight(2), get_weight(4)));
ASSERT_EQ(stream.GetHeader().size(), 1U);
EXPECT_EQ(stream.GetHeader()[0], "e");
ASSERT_EQ(stream.GetResults().size(), 3U);
auto stream = Interpret("MATCH (n)-[e *wshortest 5 (e, n | e.w) ]->(m) return e");
auto dba = db_.Access();
std::vector<std::vector<std::string>> expected_results{{"r1"}, {"r2"}, {"r1", "r2"}};
ASSERT_EQ(stream.GetHeader().size(), 1U);
EXPECT_EQ(stream.GetHeader()[0], "e");
ASSERT_EQ(stream.GetResults().size(), 3U);
for (const auto &result : stream.GetResults()) {
const auto &edges = ToEdgeList(result[0]);
auto dba = db_.Access();
std::vector<std::vector<std::string>> expected_results{{"r1"}, {"r2"}, {"r1", "r2"}};
std::vector<std::string> datum;
datum.reserve(edges.size());
for (const auto &result : stream.GetResults()) {
const auto &edges = ToEdgeList(result[0]);
for (const auto &edge : edges) {
datum.push_back(edge.type);
}
std::vector<std::string> datum;
datum.reserve(edges.size());
bool any_match = false;
for (const auto &expected : expected_results) {
if (expected == datum) {
any_match = true;
break;
for (const auto &edge : edges) {
datum.push_back(edge.type);
}
bool any_match = false;
for (const auto &expected : expected_results) {
if (expected == datum) {
any_match = true;
break;
}
}
EXPECT_TRUE(any_match);
}
EXPECT_TRUE(any_match);
Interpret("MATCH (n) DETACH DELETE n");
};
constexpr bool kUseNumeric{false};
constexpr bool kUseDuration{true};
{
SCOPED_TRACE("Test with numeric values");
test_shortest_path(kUseNumeric);
}
{
SCOPED_TRACE("Test with Duration values");
test_shortest_path(kUseDuration);
}
}