diff --git a/src/auth/models.cpp b/src/auth/models.cpp index 54bf24da6..32b2ca291 100644 --- a/src/auth/models.cpp +++ b/src/auth/models.cpp @@ -37,7 +37,7 @@ const std::vector<Permission> kPermissionsAll = { Permission::CONSTRAINT, Permission::DUMP, Permission::AUTH, Permission::REPLICATION, Permission::DURABILITY, Permission::READ_FILE, Permission::FREE_MEMORY, Permission::TRIGGER, Permission::CONFIG, Permission::STREAM, Permission::MODULE_READ, Permission::MODULE_WRITE, - Permission::WEBSOCKET}; + Permission::WEBSOCKET, Permission::SCHEMA}; } // namespace std::string PermissionToString(Permission permission) { @@ -84,6 +84,8 @@ std::string PermissionToString(Permission permission) { return "MODULE_WRITE"; case Permission::WEBSOCKET: return "WEBSOCKET"; + case Permission::SCHEMA: + return "SCHEMA"; } } diff --git a/src/auth/models.hpp b/src/auth/models.hpp index 0f01c0a39..00c26464b 100644 --- a/src/auth/models.hpp +++ b/src/auth/models.hpp @@ -38,7 +38,8 @@ enum class Permission : uint64_t { STREAM = 1U << 17U, MODULE_READ = 1U << 18U, MODULE_WRITE = 1U << 19U, - WEBSOCKET = 1U << 20U + WEBSOCKET = 1U << 20U, + SCHEMA = 1U << 21U }; // clang-format on diff --git a/src/common/types.hpp b/src/common/types.hpp new file mode 100644 index 000000000..09a0aecf5 --- /dev/null +++ b/src/common/types.hpp @@ -0,0 +1,19 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#pragma once + +#include <cstdint> + +namespace memgraph::common { +enum class SchemaType : uint8_t { BOOL, INT, STRING, DATE, LOCALTIME, LOCALDATETIME, DURATION }; + +} // namespace memgraph::common diff --git a/src/glue/v2/auth.cpp b/src/glue/v2/auth.cpp new file mode 100644 index 000000000..c550f39a0 --- /dev/null +++ b/src/glue/v2/auth.cpp @@ -0,0 +1,64 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "glue/v2/auth.hpp" + +namespace memgraph::glue::v2 { + +auth::Permission PrivilegeToPermission(query::v2::AuthQuery::Privilege privilege) { + switch (privilege) { + case query::v2::AuthQuery::Privilege::MATCH: + return auth::Permission::MATCH; + case query::v2::AuthQuery::Privilege::CREATE: + return auth::Permission::CREATE; + case query::v2::AuthQuery::Privilege::MERGE: + return auth::Permission::MERGE; + case query::v2::AuthQuery::Privilege::DELETE: + return auth::Permission::DELETE; + case query::v2::AuthQuery::Privilege::SET: + return auth::Permission::SET; + case query::v2::AuthQuery::Privilege::REMOVE: + return auth::Permission::REMOVE; + case query::v2::AuthQuery::Privilege::INDEX: + return auth::Permission::INDEX; + case query::v2::AuthQuery::Privilege::STATS: + return auth::Permission::STATS; + case query::v2::AuthQuery::Privilege::CONSTRAINT: + return auth::Permission::CONSTRAINT; + case query::v2::AuthQuery::Privilege::DUMP: + return auth::Permission::DUMP; + case query::v2::AuthQuery::Privilege::REPLICATION: + return 
auth::Permission::REPLICATION; + case query::v2::AuthQuery::Privilege::DURABILITY: + return auth::Permission::DURABILITY; + case query::v2::AuthQuery::Privilege::READ_FILE: + return auth::Permission::READ_FILE; + case query::v2::AuthQuery::Privilege::FREE_MEMORY: + return auth::Permission::FREE_MEMORY; + case query::v2::AuthQuery::Privilege::TRIGGER: + return auth::Permission::TRIGGER; + case query::v2::AuthQuery::Privilege::CONFIG: + return auth::Permission::CONFIG; + case query::v2::AuthQuery::Privilege::AUTH: + return auth::Permission::AUTH; + case query::v2::AuthQuery::Privilege::STREAM: + return auth::Permission::STREAM; + case query::v2::AuthQuery::Privilege::MODULE_READ: + return auth::Permission::MODULE_READ; + case query::v2::AuthQuery::Privilege::MODULE_WRITE: + return auth::Permission::MODULE_WRITE; + case query::v2::AuthQuery::Privilege::WEBSOCKET: + return auth::Permission::WEBSOCKET; + case query::v2::AuthQuery::Privilege::SCHEMA: + return auth::Permission::SCHEMA; + } +} +} // namespace memgraph::glue::v2 diff --git a/src/glue/v2/auth.hpp b/src/glue/v2/auth.hpp new file mode 100644 index 000000000..3410f6d57 --- /dev/null +++ b/src/glue/v2/auth.hpp @@ -0,0 +1,25 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include "auth/models.hpp" +#include "query/v2/frontend/ast/ast.hpp" + +namespace memgraph::glue::v2 { + +/** + * This function converts query::v2::AuthQuery::Privilege to its corresponding + * auth::Permission.
+ */ +auth::Permission PrivilegeToPermission(query::v2::AuthQuery::Privilege privilege); + +} // namespace memgraph::glue::v2 diff --git a/src/glue/v2/communication.cpp b/src/glue/v2/communication.cpp new file mode 100644 index 000000000..6d99a9a92 --- /dev/null +++ b/src/glue/v2/communication.cpp @@ -0,0 +1,275 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "glue/v2/communication.hpp" + +#include <map> +#include <string> +#include <vector> + +#include "storage/v3/edge_accessor.hpp" +#include "storage/v3/storage.hpp" +#include "storage/v3/vertex_accessor.hpp" +#include "utils/temporal.hpp" + +using memgraph::communication::bolt::Value; + +namespace memgraph::glue::v2 { + +query::v2::TypedValue ToTypedValue(const Value &value) { + switch (value.type()) { + case Value::Type::Null: + return {}; + case Value::Type::Bool: + return query::v2::TypedValue(value.ValueBool()); + case Value::Type::Int: + return query::v2::TypedValue(value.ValueInt()); + case Value::Type::Double: + return query::v2::TypedValue(value.ValueDouble()); + case Value::Type::String: + return query::v2::TypedValue(value.ValueString()); + case Value::Type::List: { + std::vector<query::v2::TypedValue> list; + list.reserve(value.ValueList().size()); + for (const auto &v : value.ValueList()) list.push_back(ToTypedValue(v)); + return query::v2::TypedValue(std::move(list)); + } + case Value::Type::Map: { + std::map<std::string, query::v2::TypedValue> map; + for (const auto &kv : value.ValueMap()) 
map.emplace(kv.first, ToTypedValue(kv.second)); + return query::v2::TypedValue(std::move(map)); + } + case Value::Type::Vertex: + case Value::Type::Edge: + case Value::Type::UnboundedEdge: + case Value::Type::Path: + throw communication::bolt::ValueException("Unsupported conversion from Value to TypedValue"); + case Value::Type::Date: + return query::v2::TypedValue(value.ValueDate()); + case Value::Type::LocalTime: + return query::v2::TypedValue(value.ValueLocalTime()); + case Value::Type::LocalDateTime: + return query::v2::TypedValue(value.ValueLocalDateTime()); + case Value::Type::Duration: + return query::v2::TypedValue(value.ValueDuration()); + } +} + +storage::v3::Result<communication::bolt::Vertex> ToBoltVertex(const query::v2::VertexAccessor &vertex, + const storage::v3::Storage &db, storage::v3::View view) { + return ToBoltVertex(vertex.impl_, db, view); +} + +storage::v3::Result<communication::bolt::Edge> ToBoltEdge(const query::v2::EdgeAccessor &edge, + const storage::v3::Storage &db, storage::v3::View view) { + return ToBoltEdge(edge.impl_, db, view); +} + +storage::v3::Result<Value> ToBoltValue(const query::v2::TypedValue &value, const storage::v3::Storage &db, + storage::v3::View view) { + switch (value.type()) { + case query::v2::TypedValue::Type::Null: + return Value(); + case query::v2::TypedValue::Type::Bool: + return Value(value.ValueBool()); + case query::v2::TypedValue::Type::Int: + return Value(value.ValueInt()); + case query::v2::TypedValue::Type::Double: + return Value(value.ValueDouble()); + case query::v2::TypedValue::Type::String: + return Value(std::string(value.ValueString())); + case query::v2::TypedValue::Type::List: { + std::vector<Value> values; + values.reserve(value.ValueList().size()); + for (const auto &v : value.ValueList()) { + auto maybe_value = ToBoltValue(v, db, view); + if (maybe_value.HasError()) return maybe_value.GetError(); + values.emplace_back(std::move(*maybe_value)); + } + return Value(std::move(values)); + } + case 
query::v2::TypedValue::Type::Map: { + std::map<std::string, Value> map; + for (const auto &kv : value.ValueMap()) { + auto maybe_value = ToBoltValue(kv.second, db, view); + if (maybe_value.HasError()) return maybe_value.GetError(); + map.emplace(kv.first, std::move(*maybe_value)); + } + return Value(std::move(map)); + } + case query::v2::TypedValue::Type::Vertex: { + auto maybe_vertex = ToBoltVertex(value.ValueVertex(), db, view); + if (maybe_vertex.HasError()) return maybe_vertex.GetError(); + return Value(std::move(*maybe_vertex)); + } + case query::v2::TypedValue::Type::Edge: { + auto maybe_edge = ToBoltEdge(value.ValueEdge(), db, view); + if (maybe_edge.HasError()) return maybe_edge.GetError(); + return Value(std::move(*maybe_edge)); + } + case query::v2::TypedValue::Type::Path: { + auto maybe_path = ToBoltPath(value.ValuePath(), db, view); + if (maybe_path.HasError()) return maybe_path.GetError(); + return Value(std::move(*maybe_path)); + } + case query::v2::TypedValue::Type::Date: + return Value(value.ValueDate()); + case query::v2::TypedValue::Type::LocalTime: + return Value(value.ValueLocalTime()); + case query::v2::TypedValue::Type::LocalDateTime: + return Value(value.ValueLocalDateTime()); + case query::v2::TypedValue::Type::Duration: + return Value(value.ValueDuration()); + } +} + +storage::v3::Result<communication::bolt::Vertex> ToBoltVertex(const storage::v3::VertexAccessor &vertex, + const storage::v3::Storage &db, storage::v3::View view) { + auto id = communication::bolt::Id::FromUint(vertex.Gid().AsUint()); + auto maybe_labels = vertex.Labels(view); + if (maybe_labels.HasError()) return maybe_labels.GetError(); + std::vector<std::string> labels; + labels.reserve(maybe_labels->size()); + for (const auto &label : *maybe_labels) { + labels.push_back(db.LabelToName(label)); + } + auto maybe_properties = vertex.Properties(view); + if (maybe_properties.HasError()) return maybe_properties.GetError(); + std::map<std::string, Value> properties; + for (const 
auto &prop : *maybe_properties) { + properties[db.PropertyToName(prop.first)] = ToBoltValue(prop.second); + } + return communication::bolt::Vertex{id, labels, properties}; +} + +storage::v3::Result<communication::bolt::Edge> ToBoltEdge(const storage::v3::EdgeAccessor &edge, + const storage::v3::Storage &db, storage::v3::View view) { + auto id = communication::bolt::Id::FromUint(edge.Gid().AsUint()); + auto from = communication::bolt::Id::FromUint(edge.FromVertex().Gid().AsUint()); + auto to = communication::bolt::Id::FromUint(edge.ToVertex().Gid().AsUint()); + const auto &type = db.EdgeTypeToName(edge.EdgeType()); + auto maybe_properties = edge.Properties(view); + if (maybe_properties.HasError()) return maybe_properties.GetError(); + std::map<std::string, Value> properties; + for (const auto &prop : *maybe_properties) { + properties[db.PropertyToName(prop.first)] = ToBoltValue(prop.second); + } + return communication::bolt::Edge{id, from, to, type, properties}; +} + +storage::v3::Result<communication::bolt::Path> ToBoltPath(const query::v2::Path &path, const storage::v3::Storage &db, + storage::v3::View view) { + std::vector<communication::bolt::Vertex> vertices; + vertices.reserve(path.vertices().size()); + for (const auto &v : path.vertices()) { + auto maybe_vertex = ToBoltVertex(v, db, view); + if (maybe_vertex.HasError()) return maybe_vertex.GetError(); + vertices.emplace_back(std::move(*maybe_vertex)); + } + std::vector<communication::bolt::Edge> edges; + edges.reserve(path.edges().size()); + for (const auto &e : path.edges()) { + auto maybe_edge = ToBoltEdge(e, db, view); + if (maybe_edge.HasError()) return maybe_edge.GetError(); + edges.emplace_back(std::move(*maybe_edge)); + } + return communication::bolt::Path(vertices, edges); +} + +storage::v3::PropertyValue ToPropertyValue(const Value &value) { + switch (value.type()) { + case Value::Type::Null: + return {}; + case Value::Type::Bool: + return storage::v3::PropertyValue(value.ValueBool()); + case 
Value::Type::Int: + return storage::v3::PropertyValue(value.ValueInt()); + case Value::Type::Double: + return storage::v3::PropertyValue(value.ValueDouble()); + case Value::Type::String: + return storage::v3::PropertyValue(value.ValueString()); + case Value::Type::List: { + std::vector<storage::v3::PropertyValue> vec; + vec.reserve(value.ValueList().size()); + for (const auto &v : value.ValueList()) vec.emplace_back(ToPropertyValue(v)); + return storage::v3::PropertyValue(std::move(vec)); + } + case Value::Type::Map: { + std::map<std::string, storage::v3::PropertyValue> map; + for (const auto &kv : value.ValueMap()) map.emplace(kv.first, ToPropertyValue(kv.second)); + return storage::v3::PropertyValue(std::move(map)); + } + case Value::Type::Vertex: + case Value::Type::Edge: + case Value::Type::UnboundedEdge: + case Value::Type::Path: + throw communication::bolt::ValueException("Unsupported conversion from Value to PropertyValue"); + case Value::Type::Date: + return storage::v3::PropertyValue( + storage::v3::TemporalData(storage::v3::TemporalType::Date, value.ValueDate().MicrosecondsSinceEpoch())); + case Value::Type::LocalTime: + return storage::v3::PropertyValue(storage::v3::TemporalData(storage::v3::TemporalType::LocalTime, + value.ValueLocalTime().MicrosecondsSinceEpoch())); + case Value::Type::LocalDateTime: + return storage::v3::PropertyValue(storage::v3::TemporalData(storage::v3::TemporalType::LocalDateTime, + value.ValueLocalDateTime().MicrosecondsSinceEpoch())); + case Value::Type::Duration: + return storage::v3::PropertyValue( + storage::v3::TemporalData(storage::v3::TemporalType::Duration, value.ValueDuration().microseconds)); + } +} + +Value ToBoltValue(const storage::v3::PropertyValue &value) { + switch (value.type()) { + case storage::v3::PropertyValue::Type::Null: + return {}; + case storage::v3::PropertyValue::Type::Bool: + return {value.ValueBool()}; + case storage::v3::PropertyValue::Type::Int: + return {value.ValueInt()}; + case 
storage::v3::PropertyValue::Type::Double: + return {value.ValueDouble()}; + case storage::v3::PropertyValue::Type::String: + return {value.ValueString()}; + case storage::v3::PropertyValue::Type::List: { + const auto &values = value.ValueList(); + std::vector<Value> vec; + vec.reserve(values.size()); + for (const auto &v : values) { + vec.push_back(ToBoltValue(v)); + } + return {std::move(vec)}; + } + case storage::v3::PropertyValue::Type::Map: { + const auto &map = value.ValueMap(); + std::map<std::string, Value> dv_map; + for (const auto &kv : map) { + dv_map.emplace(kv.first, ToBoltValue(kv.second)); + } + return {std::move(dv_map)}; + } + case storage::v3::PropertyValue::Type::TemporalData: { + const auto &type = value.ValueTemporalData(); + switch (type.type) { + case storage::v3::TemporalType::Date: + return {utils::Date(type.microseconds)}; + case storage::v3::TemporalType::LocalTime: + return {utils::LocalTime(type.microseconds)}; + case storage::v3::TemporalType::LocalDateTime: + return {utils::LocalDateTime(type.microseconds)}; + case storage::v3::TemporalType::Duration: + return {utils::Duration(type.microseconds)}; + } + } + } +} + +} // namespace memgraph::glue::v2 diff --git a/src/glue/v2/communication.hpp b/src/glue/v2/communication.hpp new file mode 100644 index 000000000..13bf96fca --- /dev/null +++ b/src/glue/v2/communication.hpp @@ -0,0 +1,68 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt.
+ +/// @file Conversion functions between Value and other memgraph types. +#pragma once + +#include "communication/bolt/v1/value.hpp" +#include "query/v2/typed_value.hpp" +#include "storage/v3/property_value.hpp" +#include "storage/v3/result.hpp" +#include "storage/v3/view.hpp" + +namespace memgraph::storage::v3 { +class EdgeAccessor; +class Storage; +class VertexAccessor; +} // namespace memgraph::storage::v3 + +namespace memgraph::glue::v2 { + +/// @param storage::v3::VertexAccessor for converting to +/// communication::bolt::Vertex. +/// @param storage::v3::Storage for getting label and property names. +/// @param storage::v3::View for deciding which vertex attributes are visible. +/// +/// @throw std::bad_alloc +storage::v3::Result<communication::bolt::Vertex> ToBoltVertex(const storage::v3::VertexAccessor &vertex, + const storage::v3::Storage &db, storage::v3::View view); + +/// @param storage::v3::EdgeAccessor for converting to communication::bolt::Edge. +/// @param storage::v3::Storage for getting edge type and property names. +/// @param storage::v3::View for deciding which edge attributes are visible. +/// +/// @throw std::bad_alloc +storage::v3::Result<communication::bolt::Edge> ToBoltEdge(const storage::v3::EdgeAccessor &edge, + const storage::v3::Storage &db, storage::v3::View view); + +/// @param query::v2::Path for converting to communication::bolt::Path. +/// @param storage::v3::Storage for ToBoltVertex and ToBoltEdge. +/// @param storage::v3::View for ToBoltVertex and ToBoltEdge. +/// +/// @throw std::bad_alloc +storage::v3::Result<communication::bolt::Path> ToBoltPath(const query::v2::Path &path, const storage::v3::Storage &db, + storage::v3::View view); + +/// @param query::v2::TypedValue for converting to communication::bolt::Value. +/// @param storage::v3::Storage for ToBoltVertex and ToBoltEdge. +/// @param storage::v3::View for ToBoltVertex and ToBoltEdge. 
+/// +/// @throw std::bad_alloc +storage::v3::Result<communication::bolt::Value> ToBoltValue(const query::v2::TypedValue &value, + const storage::v3::Storage &db, storage::v3::View view); + +query::v2::TypedValue ToTypedValue(const communication::bolt::Value &value); + +communication::bolt::Value ToBoltValue(const storage::v3::PropertyValue &value); + +storage::v3::PropertyValue ToPropertyValue(const communication::bolt::Value &value); + +} // namespace memgraph::glue::v2 diff --git a/src/io/address.hpp b/src/io/address.hpp index ad2c92b57..94a231e07 100644 --- a/src/io/address.hpp +++ b/src/io/address.hpp @@ -37,22 +37,19 @@ struct Address { return ret; } - bool operator==(const Address &other) const { - return (unique_id == other.unique_id) && (last_known_ip == other.last_known_ip) && - (last_known_port == other.last_known_port); - } + friend bool operator==(const Address &lhs, const Address &rhs) = default; /// unique_id is most dominant for ordering, then last_known_ip, then last_known_port - bool operator<(const Address &other) const { - if (unique_id != other.unique_id) { - return unique_id < other.unique_id; + friend bool operator<(const Address &lhs, const Address &rhs) { + if (lhs.unique_id != rhs.unique_id) { + return lhs.unique_id < rhs.unique_id; } - if (last_known_ip != other.last_known_ip) { - return last_known_ip < other.last_known_ip; + if (lhs.last_known_ip != rhs.last_known_ip) { + return lhs.last_known_ip < rhs.last_known_ip; } - return last_known_port < other.last_known_port; + return lhs.last_known_port < rhs.last_known_port; } std::string ToString() const { diff --git a/src/io/future.hpp b/src/io/future.hpp index 63f3989de..7b9a4461c 100644 --- a/src/io/future.hpp +++ b/src/io/future.hpp @@ -77,7 +77,7 @@ class Shared { // so we have to get out of its way to avoid // a cyclical deadlock. 
lock.unlock(); - simulator_progressed = (simulator_notifier_)(); + simulator_progressed = std::invoke(simulator_notifier_); lock.lock(); if (item_) { // item may have been filled while we @@ -97,7 +97,7 @@ class Shared { return Take(); } - bool IsReady() { + bool IsReady() const { std::unique_lock<std::mutex> lock(mu_); return item_; } @@ -125,7 +125,7 @@ class Shared { cv_.notify_all(); } - bool IsAwaited() { + bool IsAwaited() const { std::unique_lock<std::mutex> lock(mu_); return waiting_; } diff --git a/src/io/rsm/raft.hpp b/src/io/rsm/raft.hpp index d74b8b69d..dc1a1e9b3 100644 --- a/src/io/rsm/raft.hpp +++ b/src/io/rsm/raft.hpp @@ -9,10 +9,8 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. -// TODO(tyler) buffer out-of-order Append buffers to reassemble more quickly +// TODO(tyler) buffer out-of-order Append buffers on the Followers to reassemble more quickly // TODO(tyler) handle granular batch sizes based on simple flow control -// TODO(tyler) add "application" test that asserts that all state machines apply the same items in-order -// TODO(tyler) add proper token-based deterministic scheduling #pragma once @@ -21,6 +19,7 @@ #include <map> #include <set> #include <thread> +#include <unordered_map> #include <vector> #include "io/simulator/simulator.hpp" @@ -31,14 +30,9 @@ namespace memgraph::io::rsm { using memgraph::io::Address; using memgraph::io::Duration; using memgraph::io::Io; -using memgraph::io::ResponseEnvelope; using memgraph::io::ResponseFuture; using memgraph::io::ResponseResult; using memgraph::io::Time; -using memgraph::io::simulator::Simulator; -using memgraph::io::simulator::SimulatorConfig; -using memgraph::io::simulator::SimulatorStats; -using memgraph::io::simulator::SimulatorTransport; using Term = uint64_t; using LogIndex = uint64_t; @@ -77,7 +71,14 @@ struct ReadResponse { std::optional<Address> retry_leader; }; -// TODO(tyler) add docs +/// AppendRequest is a raft-level message that the Leader +/// 
periodically broadcasts to all Follower peers. This +/// serves three main roles: +/// 1. acts as a heartbeat from the Leader to the Follower +/// 2. replicates new data that the Leader has received to the Follower +/// 3. informs Follower peers when the commit index has increased, +/// signalling that it is now safe to apply log items to the +/// replicated state machine template <typename WriteRequest> struct AppendRequest { Term term = 0; @@ -115,7 +116,7 @@ struct CommonState { Term term = 0; std::vector<std::pair<Term, WriteRequest>> log; LogIndex committed_log_size = 0; - LogIndex last_applied = 0; + LogIndex applied_size = 0; }; struct FollowerTracker { @@ -124,17 +125,17 @@ struct FollowerTracker { }; struct PendingClientRequest { - LogIndex log_index; RequestId request_id; Address address; + Time received_at; }; struct Leader { std::map<Address, FollowerTracker> followers; - std::deque<PendingClientRequest> pending_client_requests; + std::unordered_map<LogIndex, PendingClientRequest> pending_client_requests; Time last_broadcast = Time::min(); - void Print() { std::cout << "\tLeader \t"; } + static void Print() { std::cout << "\tLeader \t"; } }; struct Candidate { @@ -142,48 +143,36 @@ struct Candidate { Time election_began = Time::min(); std::set<Address> outstanding_votes; - void Print() { std::cout << "\tCandidate\t"; } + static void Print() { std::cout << "\tCandidate\t"; } }; struct Follower { Time last_received_append_entries_timestamp; Address leader_address; - void Print() { std::cout << "\tFollower \t"; } + static void Print() { std::cout << "\tFollower \t"; } }; using Role = std::variant<Candidate, Leader, Follower>; /* -TODO make concept that expresses the fact that any RSM must -be able to have a specific type applied to it, returning -another interesting result type. 
- -all ReplicatedState classes should have an apply method +all ReplicatedState classes should have an Apply method that returns our WriteResponseValue: -ReadResponse read(ReadOperation); -WriteResponseValue ReplicatedState::apply(WriteRequest); +ReadResponse Read(ReadOperation); +WriteResponseValue ReplicatedState::Apply(WriteRequest); for examples: if the state is uint64_t, and WriteRequest is `struct PlusOne {};`, and WriteResponseValue is also uint64_t (the new value), then -each call to state.apply(PlusOne{}) will return the new value +each call to state.Apply(PlusOne{}) will return the new value after incrementing it. 0, 1, 2, 3... and this will be sent back to the client that requested the mutation. In practice, these mutations will usually be predicated on some previous value, so that they are idempotent, functioning similarly to a CAS operation. - -template<typename Write, typename T, typename WriteResponse> -concept Rsm = requires(T t, Write w) -{ - { t.read(r) } -> std::same_as<ReadResponse>; - { t.apply(w) } -> std::same_as<WriteResponseValue>; -}; */ - template <typename WriteOperation, typename ReadOperation, typename ReplicatedState, typename WriteResponseValue, typename ReadResponseValue> concept Rsm = requires(ReplicatedState state, WriteOperation w, ReadOperation r) { @@ -201,7 +190,7 @@ concept Rsm = requires(ReplicatedState state, WriteOperation w, ReadOperation r) /// identical order across all replicas after an WriteOperation reaches consensus. 
/// ReadOperation the type of operations that do not require consensus before executing directly /// on a const ReplicatedState & -/// ReadResponseValue the return value of calling ReplicatedState::read(ReadOperation), which is executed directly +/// ReadResponseValue the return value of calling ReplicatedState::Read(ReadOperation), which is executed directly /// without going through consensus first template <typename IoImpl, typename ReplicatedState, typename WriteOperation, typename WriteResponseValue, typename ReadOperation, typename ReadResponseValue> @@ -255,41 +244,36 @@ class Raft { indices.push_back(f.confirmed_contiguous_index); Log("at port ", addr.last_known_port, " has confirmed contiguous index of: ", f.confirmed_contiguous_index); } - std::ranges::sort(indices, std::ranges::greater()); - // assuming reverse sort (using std::ranges::greater) - size_t new_committed_log_size = indices[(indices.size() / 2)]; - // TODO(tyler / gabor) for each index between the old - // index and the new one, apply that log's WriteOperation - // to our replicated_state_, and use the specific return - // value of the ReplicatedState::apply method (WriteResponseValue) - // to respondto the requester. - // - // this will completely replace the while loop below + // reverse sort from highest to lowest (using std::ranges::greater) + std::ranges::sort(indices, std::ranges::greater()); + + size_t new_committed_log_size = indices[(indices.size() / 2)]; state_.committed_log_size = new_committed_log_size; - Log("committed_log_size is now ", state_.committed_log_size); + // For each index between the old index and the new one (inclusive), + // Apply that log's WriteOperation to our replicated_state_, + // and use the specific return value of the ReplicatedState::Apply + // method (WriteResponseValue) to respond to the requester. 
+ for (; state_.applied_size < state_.committed_log_size; state_.applied_size++) { + const LogIndex apply_index = state_.applied_size; + const auto &write_request = state_.log[apply_index].second; + WriteResponseValue write_return = replicated_state_.Apply(write_request); + + if (leader.pending_client_requests.contains(apply_index)) { + PendingClientRequest client_request = std::move(leader.pending_client_requests.at(apply_index)); - while (!leader.pending_client_requests.empty()) { - const auto &front = leader.pending_client_requests.front(); - if (front.log_index <= state_.committed_log_size) { - const auto &write_request = state_.log[front.log_index].second; - WriteResponseValue write_return = replicated_state_.apply(write_request); WriteResponse<WriteResponseValue> resp; resp.success = true; resp.write_return = write_return; - // Log("responding SUCCESS to client"); - // WriteResponse rr{ - // .success = true, - // .retry_leader = std::nullopt, - // }; - io_.Send(front.address, front.request_id, std::move(resp)); - leader.pending_client_requests.pop_front(); - } else { - break; + + io_.Send(client_request.address, client_request.request_id, std::move(resp)); + leader.pending_client_requests.erase(apply_index); } } + + Log("committed_log_size is now ", state_.committed_log_size); } // Raft paper - 5.1 @@ -477,7 +461,7 @@ class Raft { BroadcastAppendEntries(leader.followers); leader.last_broadcast = now; } - // TODO(tyler) TimeOutOldClientRequests(); + return std::nullopt; } @@ -493,7 +477,6 @@ class Raft { /// message that has been received. ///////////////////////////////////////////////////////////// - // Don't we need more stuff in this variant? 
void Handle(std::variant<ReadRequest<ReadOperation>, AppendRequest<WriteOperation>, AppendResponse, WriteRequest<WriteOperation>, VoteRequest, VoteResponse> &&message_variant, RequestId request_id, Address from_address) { @@ -575,14 +558,14 @@ class Raft { .next_index = committed_log_size, .confirmed_contiguous_index = committed_log_size, }; - followers.insert({address, std::move(follower)}); + followers.insert({address, follower}); } for (const auto &address : candidate.outstanding_votes) { FollowerTracker follower{ .next_index = state_.log.size(), .confirmed_contiguous_index = 0, }; - followers.insert({address, std::move(follower)}); + followers.insert({address, follower}); } Log("becoming Leader at term ", state_.term); @@ -591,7 +574,7 @@ class Raft { return Leader{ .followers = std::move(followers), - .pending_client_requests = std::deque<PendingClientRequest>(), + .pending_client_requests = std::unordered_map<LogIndex, PendingClientRequest>(), }; } @@ -599,7 +582,7 @@ class Raft { } template <typename AllRoles> - std::optional<Role> Handle(AllRoles &, VoteResponse &&res, RequestId request_id, Address from_address) { + std::optional<Role> Handle(AllRoles &, VoteResponse &&, RequestId, Address) { Log("non-Candidate received VoteResponse"); return std::nullopt; } @@ -668,7 +651,7 @@ class Raft { Log("req.last_log_term differs from our leader term at that slot, expected: ", LastLogTerm(), " but got ", req.last_log_term); } else { - // happy path - apply log + // happy path - Apply log Log("applying batch of entries to log of size ", req.entries.size()); MG_ASSERT(req.last_log_index >= state_.committed_log_size, @@ -683,6 +666,11 @@ class Raft { state_.committed_log_size = std::min(req.leader_commit, LastLogIndex()); + for (; state_.applied_size < state_.committed_log_size; state_.applied_size++) { + const auto &write_request = state_.log[state_.applied_size].second; + replicated_state_.Apply(write_request); + } + res.success = true; } @@ -691,7 +679,7 @@ class 
Raft { return std::nullopt; } - std::optional<Role> Handle(Leader &leader, AppendResponse &&res, RequestId request_id, Address from_address) { + std::optional<Role> Handle(Leader &leader, AppendResponse &&res, RequestId, Address from_address) { if (res.term != state_.term) { } else if (!leader.followers.contains(from_address)) { Log("received AppendResponse from unknown Follower"); @@ -714,7 +702,7 @@ class Raft { } template <typename AllRoles> - std::optional<Role> Handle(AllRoles &, AppendResponse &&res, RequestId request_id, Address from_address) { + std::optional<Role> Handle(AllRoles &, AppendResponse &&, RequestId, Address) { // we used to be the leader, and are getting old delayed responses return std::nullopt; } @@ -724,13 +712,11 @@ class Raft { ///////////////////////////////////////////////////////////// // Leaders are able to immediately respond to the requester (with a ReadResponseValue) applied to the ReplicatedState - std::optional<Role> Handle(Leader &leader, ReadRequest<ReadOperation> &&req, RequestId request_id, - Address from_address) { - // TODO(tyler / gabor) implement - + std::optional<Role> Handle(Leader &, ReadRequest<ReadOperation> &&req, RequestId request_id, Address from_address) { + Log("handling ReadOperation"); ReadOperation read_operation = req.operation; - ReadResponseValue read_return = replicated_state_.read(read_operation); + ReadResponseValue read_return = replicated_state_.Read(read_operation); ReadResponse<ReadResponseValue> resp{ .success = true, @@ -744,10 +730,7 @@ class Raft { } // Candidates should respond with a failure, similar to the Candidate + WriteRequest failure below - std::optional<Role> Handle(Candidate &, ReadRequest<ReadOperation> &&req, RequestId request_id, - Address from_address) { - // TODO(tyler / gabor) implement - + std::optional<Role> Handle(Candidate &, ReadRequest<ReadOperation> &&, RequestId request_id, Address from_address) { Log("received ReadOperation - not redirecting because no Leader is 
known"); auto res = ReadResponse<ReadResponseValue>{}; @@ -760,11 +743,9 @@ class Raft { return std::nullopt; } - // Leaders should respond with a redirection, similar to the Follower + WriteRequest response below - std::optional<Role> Handle(Follower &follower, ReadRequest<ReadOperation> &&req, RequestId request_id, + // Followers should respond with a redirection, similar to the Follower + WriteRequest response below + std::optional<Role> Handle(Follower &follower, ReadRequest<ReadOperation> &&, RequestId request_id, Address from_address) { - // TODO(tyler / gabor) implement - auto res = ReadResponse<ReadResponseValue>{}; res.success = false; @@ -781,7 +762,7 @@ class Raft { // server. If the client’s first choice is not the leader, that // server will reject the client’s request and supply information // about the most recent leader it has heard from. - std::optional<Role> Handle(Follower &follower, WriteRequest<WriteOperation> &&req, RequestId request_id, + std::optional<Role> Handle(Follower &follower, WriteRequest<WriteOperation> &&, RequestId request_id, Address from_address) { auto res = WriteResponse<WriteResponseValue>{}; @@ -794,8 +775,7 @@ class Raft { return std::nullopt; } - std::optional<Role> Handle(Candidate &, WriteRequest<WriteOperation> &&req, RequestId request_id, - Address from_address) { + std::optional<Role> Handle(Candidate &, WriteRequest<WriteOperation> &&, RequestId request_id, Address from_address) { Log("received WriteRequest - not redirecting because no Leader is known"); auto res = WriteResponse<WriteResponseValue>{}; @@ -811,22 +791,23 @@ class Raft { // only leaders actually handle replication requests from clients std::optional<Role> Handle(Leader &leader, WriteRequest<WriteOperation> &&req, RequestId request_id, Address from_address) { - Log("received WriteRequest"); + Log("handling WriteRequest"); // we are the leader. 
add item to log and send Append to peers state_.log.emplace_back(std::pair(state_.term, std::move(req.operation))); + LogIndex log_index = state_.log.size() - 1; + PendingClientRequest pcr{ - .log_index = state_.log.size() - 1, .request_id = request_id, .address = from_address, + .received_at = io_.Now(), }; - leader.pending_client_requests.push_back(pcr); + leader.pending_client_requests.emplace(log_index, pcr); BroadcastAppendEntries(leader.followers); - // TODO(tyler) add message to pending requests buffer, reply asynchronously return std::nullopt; } }; diff --git a/src/io/simulator/simulator_handle.hpp b/src/io/simulator/simulator_handle.hpp index 668d2389b..6abaa129d 100644 --- a/src/io/simulator/simulator_handle.hpp +++ b/src/io/simulator/simulator_handle.hpp @@ -43,16 +43,16 @@ struct PromiseKey { Address replier_address; public: - bool operator<(const PromiseKey &other) const { - if (requester_address != other.requester_address) { - return requester_address < other.requester_address; + friend bool operator<(const PromiseKey &lhs, const PromiseKey &rhs) { + if (lhs.requester_address != rhs.requester_address) { + return lhs.requester_address < rhs.requester_address; } - if (request_id != other.request_id) { - return request_id < other.request_id; + if (lhs.request_id != rhs.request_id) { + return lhs.request_id < rhs.request_id; } - return replier_address < other.replier_address; + return lhs.replier_address < rhs.replier_address; } }; diff --git a/src/query/db_accessor.hpp b/src/query/db_accessor.hpp index f325e1282..f8d400e54 100644 --- a/src/query/db_accessor.hpp +++ b/src/query/db_accessor.hpp @@ -51,7 +51,6 @@ class EdgeAccessor final { public: storage::EdgeAccessor impl_; - public: explicit EdgeAccessor(storage::EdgeAccessor impl) : impl_(std::move(impl)) {} bool IsVisible(storage::View view) const { return impl_.IsVisible(view); } @@ -97,7 +96,6 @@ class VertexAccessor final { static EdgeAccessor MakeEdgeAccessor(const storage::EdgeAccessor impl) { 
return EdgeAccessor(impl); } - public: explicit VertexAccessor(storage::VertexAccessor impl) : impl_(impl) {} bool IsVisible(storage::View view) const { return impl_.IsVisible(view); } diff --git a/src/query/frontend/stripped_lexer_constants.hpp b/src/query/frontend/stripped_lexer_constants.hpp index 42b7b4aeb..784692b53 100644 --- a/src/query/frontend/stripped_lexer_constants.hpp +++ b/src/query/frontend/stripped_lexer_constants.hpp @@ -204,7 +204,7 @@ const trie::Trie kKeywords = {"union", "pulsar", "service_url", "version", - "websocket" + "websocket", "foreach"}; // Unicode codepoints that are allowed at the start of the unescaped name. diff --git a/src/query/metadata.cpp b/src/query/metadata.cpp index f4e8512fd..fa80c61f5 100644 --- a/src/query/metadata.cpp +++ b/src/query/metadata.cpp @@ -114,4 +114,4 @@ std::string ExecutionStatsKeyToString(const ExecutionStats::Key key) { } } -} // namespace memgraph::query \ No newline at end of file +} // namespace memgraph::query diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index a4ae9da66..72117f3ad 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -368,13 +368,12 @@ VertexAccessor &CreateExpand::CreateExpandCursor::OtherVertex(Frame &frame, Exec TypedValue &dest_node_value = frame[self_.node_info_.symbol]; ExpectType(self_.node_info_.symbol, dest_node_value, TypedValue::Type::Vertex); return dest_node_value.ValueVertex(); - } else { - auto &created_vertex = CreateLocalVertex(self_.node_info_, &frame, context); - if (context.trigger_context_collector) { - context.trigger_context_collector->RegisterCreatedObject(created_vertex); - } - return created_vertex; } + auto &created_vertex = CreateLocalVertex(self_.node_info_, &frame, context); + if (context.trigger_context_collector) { + context.trigger_context_collector->RegisterCreatedObject(created_vertex); + } + return created_vertex; } template <class TVerticesFun> diff --git a/src/query/v2/common.hpp 
b/src/query/v2/common.hpp index e79ca996c..ee8f72ddf 100644 --- a/src/query/v2/common.hpp +++ b/src/query/v2/common.hpp @@ -16,6 +16,7 @@ #include <cstdint> #include <string> #include <string_view> +#include <type_traits> #include "query/v2/db_accessor.hpp" #include "query/v2/exceptions.hpp" @@ -24,8 +25,12 @@ #include "query/v2/typed_value.hpp" #include "storage/v3/id_types.hpp" #include "storage/v3/property_value.hpp" +#include "storage/v3/result.hpp" +#include "storage/v3/schema_validator.hpp" #include "storage/v3/view.hpp" +#include "utils/exceptions.hpp" #include "utils/logging.hpp" +#include "utils/variant_helpers.hpp" namespace memgraph::query::v2 { @@ -81,27 +86,79 @@ concept AccessorWithSetProperty = requires(T accessor, const storage::v3::Proper { accessor.SetProperty(key, new_value) } -> std::same_as<storage::v3::Result<storage::v3::PropertyValue>>; }; +inline void HandleSchemaViolation(const storage::v3::SchemaViolation &schema_violation, const DbAccessor &dba) { + switch (schema_violation.status) { + case storage::v3::SchemaViolation::ValidationStatus::VERTEX_HAS_NO_PRIMARY_PROPERTY: { + throw SchemaViolationException( + fmt::format("Primary key {} not defined on label :{}", + storage::v3::SchemaTypeToString(schema_violation.violated_schema_property->type), + dba.LabelToName(schema_violation.label))); + } + case storage::v3::SchemaViolation::ValidationStatus::NO_SCHEMA_DEFINED_FOR_LABEL: { + throw SchemaViolationException( + fmt::format("Label :{} is not a primary label", dba.LabelToName(schema_violation.label))); + } + case storage::v3::SchemaViolation::ValidationStatus::VERTEX_PROPERTY_WRONG_TYPE: { + throw SchemaViolationException( + fmt::format("Wrong type of property {} in schema :{}, should be of type {}", + *schema_violation.violated_property_value, dba.LabelToName(schema_violation.label), + storage::v3::SchemaTypeToString(schema_violation.violated_schema_property->type))); + } + case 
storage::v3::SchemaViolation::ValidationStatus::VERTEX_UPDATE_PRIMARY_KEY: { + throw SchemaViolationException(fmt::format("Updating of primary key {} on schema :{} not supported", + *schema_violation.violated_property_value, + dba.LabelToName(schema_violation.label))); + } + case storage::v3::SchemaViolation::ValidationStatus::VERTEX_MODIFY_PRIMARY_LABEL: { + throw SchemaViolationException(fmt::format("Cannot add or remove label :{} since it is a primary label", + dba.LabelToName(schema_violation.label))); + } + case storage::v3::SchemaViolation::ValidationStatus::VERTEX_SECONDARY_LABEL_IS_PRIMARY: { + throw SchemaViolationException( + fmt::format("Cannot create vertex with secondary label :{}", dba.LabelToName(schema_violation.label))); + } + } +} + +inline void HandleErrorOnPropertyUpdate(const storage::v3::Error error) { + switch (error) { + case storage::v3::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to set properties on a deleted object."); + case storage::v3::Error::PROPERTIES_DISABLED: + throw QueryRuntimeException("Can't set property because properties on edges are disabled."); + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when setting a property."); + } +} + /// Set a property `value` mapped with given `key` on a `record`. 
/// /// @throw QueryRuntimeException if value cannot be set as a property value template <AccessorWithSetProperty T> -storage::v3::PropertyValue PropsSetChecked(T *record, const storage::v3::PropertyId &key, const TypedValue &value) { +storage::v3::PropertyValue PropsSetChecked(T *record, const DbAccessor &dba, const storage::v3::PropertyId &key, + const TypedValue &value) { try { - auto maybe_old_value = record->SetProperty(key, storage::v3::PropertyValue(value)); - if (maybe_old_value.HasError()) { - switch (maybe_old_value.GetError()) { - case storage::v3::Error::SERIALIZATION_ERROR: - throw TransactionSerializationException(); - case storage::v3::Error::DELETED_OBJECT: - throw QueryRuntimeException("Trying to set properties on a deleted object."); - case storage::v3::Error::PROPERTIES_DISABLED: - throw QueryRuntimeException("Can't set property because properties on edges are disabled."); - case storage::v3::Error::VERTEX_HAS_EDGES: - case storage::v3::Error::NONEXISTENT_OBJECT: - throw QueryRuntimeException("Unexpected error when setting a property."); + if constexpr (std::is_same_v<T, VertexAccessor>) { + const auto maybe_old_value = record->SetPropertyAndValidate(key, storage::v3::PropertyValue(value)); + if (maybe_old_value.HasError()) { + std::visit(utils::Overloaded{[](const storage::v3::Error error) { HandleErrorOnPropertyUpdate(error); }, + [&dba](const storage::v3::SchemaViolation &schema_violation) { + HandleSchemaViolation(schema_violation, dba); + }}, + maybe_old_value.GetError()); } + return std::move(*maybe_old_value); + } else { + // No validation on edge properties + const auto maybe_old_value = record->SetProperty(key, storage::v3::PropertyValue(value)); + if (maybe_old_value.HasError()) { + HandleErrorOnPropertyUpdate(maybe_old_value.GetError()); + } + return std::move(*maybe_old_value); } - return std::move(*maybe_old_value); } catch (const TypedValueException &) { throw QueryRuntimeException("'{}' cannot be used as a property value.", 
value.type()); } diff --git a/src/query/v2/db_accessor.hpp b/src/query/v2/db_accessor.hpp index 90ea6d431..a09bcffcb 100644 --- a/src/query/v2/db_accessor.hpp +++ b/src/query/v2/db_accessor.hpp @@ -12,6 +12,7 @@ #pragma once #include <optional> +#include <vector> #include <cppitertools/filter.hpp> #include <cppitertools/imap.hpp> @@ -23,7 +24,7 @@ /////////////////////////////////////////////////////////// // Our communication layer and query engine don't mix -// very well on Centos because OpenSSL version avaialable +// very well on Centos because OpenSSL version available // on Centos 7 include libkrb5 which has brilliant macros // called TRUE and FALSE. For more detailed explanation go // to memgraph.cpp. @@ -34,6 +35,8 @@ // simply undefine those macros as we're sure that libkrb5 // won't and can't be used anywhere in the query engine. #include "storage/v3/storage.hpp" +#include "utils/logging.hpp" +#include "utils/result.hpp" #undef FALSE #undef TRUE @@ -51,7 +54,6 @@ class EdgeAccessor final { public: storage::v3::EdgeAccessor impl_; - public: explicit EdgeAccessor(storage::v3::EdgeAccessor impl) : impl_(std::move(impl)) {} bool IsVisible(storage::v3::View view) const { return impl_.IsVisible(view); } @@ -99,17 +101,26 @@ class VertexAccessor final { static EdgeAccessor MakeEdgeAccessor(const storage::v3::EdgeAccessor impl) { return EdgeAccessor(impl); } - public: explicit VertexAccessor(storage::v3::VertexAccessor impl) : impl_(impl) {} bool IsVisible(storage::v3::View view) const { return impl_.IsVisible(view); } auto Labels(storage::v3::View view) const { return impl_.Labels(view); } + auto PrimaryLabel(storage::v3::View view) const { return impl_.PrimaryLabel(view); } + storage::v3::Result<bool> AddLabel(storage::v3::LabelId label) { return impl_.AddLabel(label); } + storage::v3::ResultSchema<bool> AddLabelAndValidate(storage::v3::LabelId label) { + return impl_.AddLabelAndValidate(label); + } + storage::v3::Result<bool> RemoveLabel(storage::v3::LabelId 
label) { return impl_.RemoveLabel(label); } + storage::v3::ResultSchema<bool> RemoveLabelAndValidate(storage::v3::LabelId label) { + return impl_.RemoveLabelAndValidate(label); + } + storage::v3::Result<bool> HasLabel(storage::v3::View view, storage::v3::LabelId label) const { return impl_.HasLabel(label, view); } @@ -126,8 +137,13 @@ class VertexAccessor final { return impl_.SetProperty(key, value); } - storage::v3::Result<storage::v3::PropertyValue> RemoveProperty(storage::v3::PropertyId key) { - return SetProperty(key, storage::v3::PropertyValue()); + storage::v3::ResultSchema<storage::v3::PropertyValue> SetPropertyAndValidate( + storage::v3::PropertyId key, const storage::v3::PropertyValue &value) { + return impl_.SetPropertyAndValidate(key, value); + } + + storage::v3::ResultSchema<storage::v3::PropertyValue> RemovePropertyAndValidate(storage::v3::PropertyId key) { + return SetPropertyAndValidate(key, storage::v3::PropertyValue{}); } storage::v3::Result<std::map<storage::v3::PropertyId, storage::v3::PropertyValue>> ClearProperties() { @@ -254,7 +270,18 @@ class DbAccessor final { return VerticesIterable(accessor_->Vertices(label, property, lower, upper, view)); } - VertexAccessor InsertVertex() { return VertexAccessor(accessor_->CreateVertex()); } + // TODO Remove when query modules have been fixed + [[deprecated]] VertexAccessor InsertVertex() { return VertexAccessor(accessor_->CreateVertex()); } + + storage::v3::ResultSchema<VertexAccessor> InsertVertexAndValidate( + const storage::v3::LabelId primary_label, const std::vector<storage::v3::LabelId> &labels, + const std::vector<std::pair<storage::v3::PropertyId, storage::v3::PropertyValue>> &properties) { + auto maybe_vertex_acc = accessor_->CreateVertexAndValidate(primary_label, labels, properties); + if (maybe_vertex_acc.HasError()) { + return {std::move(maybe_vertex_acc.GetError())}; + } + return VertexAccessor{maybe_vertex_acc.GetValue()}; + } storage::v3::Result<EdgeAccessor> InsertEdge(VertexAccessor 
*from, VertexAccessor *to, const storage::v3::EdgeTypeId &edge_type) { @@ -312,7 +339,7 @@ class DbAccessor final { return std::optional<VertexAccessor>{}; } - return std::make_optional<VertexAccessor>(*value); + return {std::make_optional<VertexAccessor>(*value)}; } storage::v3::PropertyId NameToProperty(const std::string_view name) { return accessor_->NameToProperty(name); } @@ -361,6 +388,10 @@ class DbAccessor final { storage::v3::IndicesInfo ListAllIndices() const { return accessor_->ListAllIndices(); } storage::v3::ConstraintsInfo ListAllConstraints() const { return accessor_->ListAllConstraints(); } + + const storage::v3::SchemaValidator &GetSchemaValidator() const { return accessor_->GetSchemaValidator(); } + + storage::v3::SchemasInfo ListAllSchemas() const { return accessor_->ListAllSchemas(); } }; } // namespace memgraph::query::v2 diff --git a/src/query/v2/exceptions.hpp b/src/query/v2/exceptions.hpp index e0802a6cc..959672eae 100644 --- a/src/query/v2/exceptions.hpp +++ b/src/query/v2/exceptions.hpp @@ -224,4 +224,12 @@ class VersionInfoInMulticommandTxException : public QueryException { : QueryException("Version info query not allowed in multicommand transactions.") {} }; +/** + * An exception for an illegal operation that violates schema + */ +class SchemaViolationException : public QueryRuntimeException { + public: + using QueryRuntimeException::QueryRuntimeException; +}; + } // namespace memgraph::query::v2 diff --git a/src/query/v2/frontend/ast/ast.lcp b/src/query/v2/frontend/ast/ast.lcp index b858ab71f..023be58f2 100644 --- a/src/query/v2/frontend/ast/ast.lcp +++ b/src/query/v2/frontend/ast/ast.lcp @@ -134,6 +134,15 @@ cpp<# } cpp<#)) + +(defun clone-schema-property-vector (source dest) + #>cpp + ${dest}.reserve(${source}.size()); + for (const auto &[property_ix, property_type]: ${source}) { + ${dest}.emplace_back(storage->GetPropertyIx(property_ix.name), property_type); + } + cpp<#) + ;; The following index structs serve as a decoupling point of 
AST from ;; concrete database types. All the names are collected in AstStorage, and can ;; be indexed through these instances. This means that we can create a vector @@ -2256,7 +2265,7 @@ cpp<# (lcp:define-enum privilege (create delete match merge set remove index stats auth constraint dump replication durability read_file free_memory trigger config stream module_read module_write - websocket) + websocket schema) (:serialize)) #>cpp AuthQuery() = default; @@ -2298,7 +2307,7 @@ const std::vector<AuthQuery::Privilege> kPrivilegesAll = { AuthQuery::Privilege::FREE_MEMORY, AuthQuery::Privilege::TRIGGER, AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM, AuthQuery::Privilege::MODULE_READ, AuthQuery::Privilege::MODULE_WRITE, - AuthQuery::Privilege::WEBSOCKET}; + AuthQuery::Privilege::WEBSOCKET, AuthQuery::Privilege::SCHEMA}; cpp<# (lcp:define-class info-query (query) @@ -2671,6 +2680,39 @@ cpp<# (:serialize (:slk)) (:clone)) +(lcp:define-class schema-query (query) + ((action "Action" :scope :public) + (label "LabelIx" :scope :public + :slk-load (lambda (member) + #>cpp + slk::Load(&self->${member}, reader, storage); + cpp<#) + :clone (lambda (source dest) + #>cpp + ${dest} = storage->GetLabelIx(${source}.name); + cpp<#)) + (schema_type_map "std::vector<std::pair<PropertyIx, common::SchemaType>>" + :slk-save #'slk-save-property-map + :slk-load #'slk-load-property-map + :clone #'clone-schema-property-vector + :scope :public)) + + (:public + (lcp:define-enum action + (create-schema drop-schema show-schema show-schemas) + (:serialize)) + #>cpp + SchemaQuery() = default; + + DEFVISITABLE(QueryVisitor<void>); + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + (lcp:pop-namespace) ;; namespace v2 (lcp:pop-namespace) ;; namespace query (lcp:pop-namespace) ;; namespace memgraph diff --git a/src/query/v2/frontend/ast/ast_visitor.hpp b/src/query/v2/frontend/ast/ast_visitor.hpp index 3cd7f9074..77c25cffb 100644 --- 
a/src/query/v2/frontend/ast/ast_visitor.hpp +++ b/src/query/v2/frontend/ast/ast_visitor.hpp @@ -94,6 +94,7 @@ class StreamQuery; class SettingQuery; class VersionQuery; class Foreach; +class SchemaQuery; using TreeCompositeVisitor = utils::CompositeVisitor< SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator, @@ -125,9 +126,9 @@ class ExpressionVisitor None, ParameterLookup, Identifier, PrimitiveLiteral, RegexMatch> {}; template <class TResult> -class QueryVisitor - : public utils::Visitor<TResult, CypherQuery, ExplainQuery, ProfileQuery, IndexQuery, AuthQuery, InfoQuery, - ConstraintQuery, DumpQuery, ReplicationQuery, LockPathQuery, FreeMemoryQuery, TriggerQuery, - IsolationLevelQuery, CreateSnapshotQuery, StreamQuery, SettingQuery, VersionQuery> {}; +class QueryVisitor : public utils::Visitor<TResult, CypherQuery, ExplainQuery, ProfileQuery, IndexQuery, AuthQuery, + InfoQuery, ConstraintQuery, DumpQuery, ReplicationQuery, LockPathQuery, + FreeMemoryQuery, TriggerQuery, IsolationLevelQuery, CreateSnapshotQuery, + StreamQuery, SettingQuery, VersionQuery, SchemaQuery> {}; } // namespace memgraph::query::v2 diff --git a/src/query/v2/frontend/ast/cypher_main_visitor.cpp b/src/query/v2/frontend/ast/cypher_main_visitor.cpp index 976e3abfc..5c8304b5b 100644 --- a/src/query/v2/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/v2/frontend/ast/cypher_main_visitor.cpp @@ -17,6 +17,7 @@ #include <cstring> #include <iterator> #include <limits> +#include <ranges> #include <string> #include <tuple> #include <type_traits> @@ -27,6 +28,7 @@ #include <boost/preprocessor/cat.hpp> +#include "common/types.hpp" #include "query/v2/exceptions.hpp" #include "query/v2/frontend/ast/ast.hpp" #include "query/v2/frontend/ast/ast_visitor.hpp" @@ -275,18 +277,7 @@ antlrcpp::Any CypherMainVisitor::visitRegisterReplica(MemgraphCypher::RegisterRe replication_query->replica_name_ = 
std::any_cast<std::string>(ctx->replicaName()->symbolicName()->accept(this)); if (ctx->SYNC()) { replication_query->sync_mode_ = memgraph::query::v2::ReplicationQuery::SyncMode::SYNC; - if (ctx->WITH() && ctx->TIMEOUT()) { - if (ctx->timeout->numberLiteral()) { - // we accept both double and integer literals - replication_query->timeout_ = std::any_cast<Expression *>(ctx->timeout->accept(this)); - } else { - throw SemanticException("Timeout should be a integer or double literal!"); - } - } } else if (ctx->ASYNC()) { - if (ctx->WITH() && ctx->TIMEOUT()) { - throw SyntaxException("Timeout can be set only for the SYNC replication mode!"); - } replication_query->sync_mode_ = memgraph::query::v2::ReplicationQuery::SyncMode::ASYNC; } @@ -1358,6 +1349,7 @@ antlrcpp::Any CypherMainVisitor::visitPrivilege(MemgraphCypher::PrivilegeContext if (ctx->MODULE_READ()) return AuthQuery::Privilege::MODULE_READ; if (ctx->MODULE_WRITE()) return AuthQuery::Privilege::MODULE_WRITE; if (ctx->WEBSOCKET()) return AuthQuery::Privilege::WEBSOCKET; + if (ctx->SCHEMA()) return AuthQuery::Privilege::SCHEMA; LOG_FATAL("Should not get here - unknown privilege!"); } @@ -2364,6 +2356,93 @@ antlrcpp::Any CypherMainVisitor::visitForeach(MemgraphCypher::ForeachContext *ct return for_each; } +antlrcpp::Any CypherMainVisitor::visitSchemaQuery(MemgraphCypher::SchemaQueryContext *ctx) { + MG_ASSERT(ctx->children.size() == 1, "SchemaQuery should have exactly one child!"); + auto *schema_query = std::any_cast<SchemaQuery *>(ctx->children[0]->accept(this)); + query_ = schema_query; + return schema_query; +} + +antlrcpp::Any CypherMainVisitor::visitShowSchema(MemgraphCypher::ShowSchemaContext *ctx) { + auto *schema_query = storage_->Create<SchemaQuery>(); + schema_query->action_ = SchemaQuery::Action::SHOW_SCHEMA; + schema_query->label_ = AddLabel(std::any_cast<std::string>(ctx->labelName()->accept(this))); + query_ = schema_query; + return schema_query; +} + +antlrcpp::Any 
CypherMainVisitor::visitShowSchemas(MemgraphCypher::ShowSchemasContext * /*ctx*/) { + auto *schema_query = storage_->Create<SchemaQuery>(); + schema_query->action_ = SchemaQuery::Action::SHOW_SCHEMAS; + query_ = schema_query; + return schema_query; +} + +antlrcpp::Any CypherMainVisitor::visitPropertyType(MemgraphCypher::PropertyTypeContext *ctx) { + MG_ASSERT(ctx->symbolicName()); + const auto property_type = utils::ToLowerCase(std::any_cast<std::string>(ctx->symbolicName()->accept(this))); + if (property_type == "bool") { + return common::SchemaType::BOOL; + } + if (property_type == "string") { + return common::SchemaType::STRING; + } + if (property_type == "integer") { + return common::SchemaType::INT; + } + if (property_type == "date") { + return common::SchemaType::DATE; + } + if (property_type == "duration") { + return common::SchemaType::DURATION; + } + if (property_type == "localdatetime") { + return common::SchemaType::LOCALDATETIME; + } + if (property_type == "localtime") { + return common::SchemaType::LOCALTIME; + } + throw SyntaxException("Property type must be one of the supported types!"); +} + +/** + * @return Schema* + */ +antlrcpp::Any CypherMainVisitor::visitSchemaPropertyMap(MemgraphCypher::SchemaPropertyMapContext *ctx) { + std::vector<std::pair<PropertyIx, common::SchemaType>> schema_property_map; + for (auto *property_key_pair : ctx->propertyKeyTypePair()) { + auto key = std::any_cast<PropertyIx>(property_key_pair->propertyKeyName()->accept(this)); + auto type = std::any_cast<common::SchemaType>(property_key_pair->propertyType()->accept(this)); + if (std::ranges::find_if(schema_property_map, [&key](const auto &elem) { return elem.first == key; }) != + schema_property_map.end()) { + throw SemanticException("Same property name can't appear twice in a schema map."); + } + schema_property_map.emplace_back(key, type); + } + return schema_property_map; +} + +antlrcpp::Any CypherMainVisitor::visitCreateSchema(MemgraphCypher::CreateSchemaContext *ctx) 
{ + auto *schema_query = storage_->Create<SchemaQuery>(); + schema_query->action_ = SchemaQuery::Action::CREATE_SCHEMA; + schema_query->label_ = AddLabel(std::any_cast<std::string>(ctx->labelName()->accept(this))); + schema_query->schema_type_map_ = + std::any_cast<std::vector<std::pair<PropertyIx, common::SchemaType>>>(ctx->schemaPropertyMap()->accept(this)); + query_ = schema_query; + return schema_query; +} + +/** + * @return Schema* + */ +antlrcpp::Any CypherMainVisitor::visitDropSchema(MemgraphCypher::DropSchemaContext *ctx) { + auto *schema_query = storage_->Create<SchemaQuery>(); + schema_query->action_ = SchemaQuery::Action::DROP_SCHEMA; + schema_query->label_ = AddLabel(std::any_cast<std::string>(ctx->labelName()->accept(this))); + query_ = schema_query; + return schema_query; +} + LabelIx CypherMainVisitor::AddLabel(const std::string &name) { return storage_->GetLabelIx(name); } PropertyIx CypherMainVisitor::AddProperty(const std::string &name) { return storage_->GetPropertyIx(name); } diff --git a/src/query/v2/frontend/ast/cypher_main_visitor.hpp b/src/query/v2/frontend/ast/cypher_main_visitor.hpp index 767c8ce65..0052cd279 100644 --- a/src/query/v2/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/v2/frontend/ast/cypher_main_visitor.hpp @@ -849,6 +849,41 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { */ antlrcpp::Any visitForeach(MemgraphCypher::ForeachContext *ctx) override; + /** + * @return Schema* + */ + antlrcpp::Any visitPropertyType(MemgraphCypher::PropertyTypeContext *ctx) override; + + /** + * @return Schema* + */ + antlrcpp::Any visitSchemaPropertyMap(MemgraphCypher::SchemaPropertyMapContext *ctx) override; + + /** + * @return Schema* + */ + antlrcpp::Any visitSchemaQuery(MemgraphCypher::SchemaQueryContext *ctx) override; + + /** + * @return Schema* + */ + antlrcpp::Any visitShowSchema(MemgraphCypher::ShowSchemaContext *ctx) override; + + /** + * @return Schema* + */ + antlrcpp::Any 
visitShowSchemas(MemgraphCypher::ShowSchemasContext *ctx) override; + + /** + * @return Schema* + */ + antlrcpp::Any visitCreateSchema(MemgraphCypher::CreateSchemaContext *ctx) override; + + /** + * @return Schema* + */ + antlrcpp::Any visitDropSchema(MemgraphCypher::DropSchemaContext *ctx) override; + public: Query *query() { return query_; } const static std::string kAnonPrefix; diff --git a/src/query/v2/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/v2/frontend/opencypher/grammar/MemgraphCypher.g4 index b412a474a..956320bbf 100644 --- a/src/query/v2/frontend/opencypher/grammar/MemgraphCypher.g4 +++ b/src/query/v2/frontend/opencypher/grammar/MemgraphCypher.g4 @@ -46,10 +46,10 @@ memgraphCypherKeyword : cypherKeyword | DROP | DUMP | EXECUTE - | FOR - | FOREACH | FREE | FROM + | FOR + | FOREACH | GLOBAL | GRANT | HEADER @@ -76,6 +76,8 @@ memgraphCypherKeyword : cypherKeyword | ROLE | ROLES | QUOTE + | SCHEMA + | SCHEMAS | SESSION | SETTING | SETTINGS @@ -122,6 +124,7 @@ query : cypherQuery | streamQuery | settingQuery | versionQuery + | schemaQuery ; authQuery : createRole @@ -192,6 +195,12 @@ settingQuery : setSetting | showSettings ; +schemaQuery : showSchema + | showSchemas + | createSchema + | dropSchema + ; + loadCsv : LOAD CSV FROM csvFile ( WITH | NO ) HEADER ( IGNORE BAD ) ? ( DELIMITER delimiter ) ? @@ -254,6 +263,7 @@ privilege : CREATE | MODULE_READ | MODULE_WRITE | WEBSOCKET + | SCHEMA ; privilegeList : privilege ( ',' privilege )* ; @@ -276,7 +286,6 @@ replicaName : symbolicName ; socketAddress : literal ; registerReplica : REGISTER REPLICA replicaName ( SYNC | ASYNC ) - ( WITH TIMEOUT timeout=literal ) ? 
TO socketAddress ; dropReplica : DROP REPLICA replicaName ; @@ -374,3 +383,17 @@ showSetting : SHOW DATABASE SETTING settingName ; showSettings : SHOW DATABASE SETTINGS ; versionQuery : SHOW VERSION ; + +showSchema : SHOW SCHEMA ON ':' labelName ; + +showSchemas : SHOW SCHEMAS ; + +propertyType : symbolicName ; + +propertyKeyTypePair : propertyKeyName propertyType ; + +schemaPropertyMap : '(' propertyKeyTypePair ( ',' propertyKeyTypePair )* ')' ; + +createSchema : CREATE SCHEMA ON ':' labelName schemaPropertyMap ; + +dropSchema : DROP SCHEMA ON ':' labelName ; diff --git a/src/query/v2/frontend/opencypher/grammar/MemgraphCypherLexer.g4 b/src/query/v2/frontend/opencypher/grammar/MemgraphCypherLexer.g4 index 55e5d53a2..869141033 100644 --- a/src/query/v2/frontend/opencypher/grammar/MemgraphCypherLexer.g4 +++ b/src/query/v2/frontend/opencypher/grammar/MemgraphCypherLexer.g4 @@ -89,6 +89,8 @@ REVOKE : R E V O K E ; ROLE : R O L E ; ROLES : R O L E S ; QUOTE : Q U O T E ; +SCHEMA : S C H E M A ; +SCHEMAS : S C H E M A S ; SERVICE_URL : S E R V I C E UNDERSCORE U R L ; SESSION : S E S S I O N ; SETTING : S E T T I N G ; diff --git a/src/query/v2/frontend/semantic/required_privileges.cpp b/src/query/v2/frontend/semantic/required_privileges.cpp index 0790529cf..df160fac1 100644 --- a/src/query/v2/frontend/semantic/required_privileges.cpp +++ b/src/query/v2/frontend/semantic/required_privileges.cpp @@ -80,6 +80,8 @@ class PrivilegeExtractor : public QueryVisitor<void>, public HierarchicalTreeVis void Visit(VersionQuery & /*version_query*/) override { AddPrivilege(AuthQuery::Privilege::STATS); } + void Visit(SchemaQuery & /*schema_query*/) override { AddPrivilege(AuthQuery::Privilege::SCHEMA); } + bool PreVisit(Create & /*unused*/) override { AddPrivilege(AuthQuery::Privilege::CREATE); return false; diff --git a/src/query/v2/frontend/stripped_lexer_constants.hpp b/src/query/v2/frontend/stripped_lexer_constants.hpp index 4e52fbdc4..df52066fc 100644 --- 
a/src/query/v2/frontend/stripped_lexer_constants.hpp +++ b/src/query/v2/frontend/stripped_lexer_constants.hpp @@ -204,8 +204,9 @@ const trie::Trie kKeywords = {"union", "pulsar", "service_url", "version", - "websocket" - "foreach"}; + "websocket", + "foreach", + "schema"}; // Unicode codepoints that are allowed at the start of the unescaped name. const std::bitset<kBitsetSize> kUnescapedNameAllowedStarts( diff --git a/src/query/v2/interpreter.cpp b/src/query/v2/interpreter.cpp index 73f3e00be..c65b5e971 100644 --- a/src/query/v2/interpreter.cpp +++ b/src/query/v2/interpreter.cpp @@ -877,6 +877,102 @@ Callback HandleSettingQuery(SettingQuery *setting_query, const Parameters ¶m } } +Callback HandleSchemaQuery(SchemaQuery *schema_query, InterpreterContext *interpreter_context, + std::vector<Notification> *notifications) { + Callback callback; + switch (schema_query->action_) { + case SchemaQuery::Action::SHOW_SCHEMAS: { + callback.header = {"label", "primary_key"}; + callback.fn = [interpreter_context]() { + auto *db = interpreter_context->db; + auto schemas_info = db->ListAllSchemas(); + std::vector<std::vector<TypedValue>> results; + results.reserve(schemas_info.schemas.size()); + + for (const auto &[label_id, schema_types] : schemas_info.schemas) { + std::vector<TypedValue> schema_info_row; + schema_info_row.reserve(3); + + schema_info_row.emplace_back(db->LabelToName(label_id)); + std::vector<std::string> primary_key_properties; + primary_key_properties.reserve(schema_types.size()); + std::transform(schema_types.begin(), schema_types.end(), std::back_inserter(primary_key_properties), + [&db](const auto &schema_type) { + return db->PropertyToName(schema_type.property_id) + + "::" + storage::v3::SchemaTypeToString(schema_type.type); + }); + + schema_info_row.emplace_back(utils::Join(primary_key_properties, ", ")); + results.push_back(std::move(schema_info_row)); + } + return results; + }; + return callback; + } + case SchemaQuery::Action::SHOW_SCHEMA: { + 
callback.header = {"property_name", "property_type"}; + callback.fn = [interpreter_context, primary_label = schema_query->label_]() { + auto *db = interpreter_context->db; + const auto label = db->NameToLabel(primary_label.name); + const auto *schema = db->GetSchema(label); + std::vector<std::vector<TypedValue>> results; + if (schema) { + for (const auto &schema_property : schema->second) { + std::vector<TypedValue> schema_info_row; + schema_info_row.reserve(2); + schema_info_row.emplace_back(db->PropertyToName(schema_property.property_id)); + schema_info_row.emplace_back(storage::v3::SchemaTypeToString(schema_property.type)); + results.push_back(std::move(schema_info_row)); + } + return results; + } + throw QueryException(fmt::format("Schema on label :{} not found!", primary_label.name)); + }; + return callback; + } + case SchemaQuery::Action::CREATE_SCHEMA: { + auto schema_type_map = schema_query->schema_type_map_; + if (schema_query->schema_type_map_.empty()) { + throw SyntaxException("One or more types have to be defined in schema definition."); + } + callback.fn = [interpreter_context, primary_label = schema_query->label_, + schema_type_map = std::move(schema_type_map)]() { + auto *db = interpreter_context->db; + const auto label = db->NameToLabel(primary_label.name); + std::vector<storage::v3::SchemaProperty> schemas_types; + schemas_types.reserve(schema_type_map.size()); + for (const auto &schema_type : schema_type_map) { + auto property_id = db->NameToProperty(schema_type.first.name); + schemas_types.push_back({property_id, schema_type.second}); + } + if (!db->CreateSchema(label, schemas_types)) { + throw QueryException(fmt::format("Schema on label :{} already exists!", primary_label.name)); + } + return std::vector<std::vector<TypedValue>>{}; + }; + notifications->emplace_back(SeverityLevel::INFO, NotificationCode::CREATE_SCHEMA, + fmt::format("Create schema on label :{}", schema_query->label_.name)); + return callback; + } + case 
SchemaQuery::Action::DROP_SCHEMA: { + callback.fn = [interpreter_context, primary_label = schema_query->label_]() { + auto *db = interpreter_context->db; + const auto label = db->NameToLabel(primary_label.name); + + if (!db->DropSchema(label)) { + throw QueryException(fmt::format("Schema on label :{} does not exist!", primary_label.name)); + } + + return std::vector<std::vector<TypedValue>>{}; + }; + notifications->emplace_back(SeverityLevel::INFO, NotificationCode::DROP_SCHEMA, + fmt::format("Dropped schema on label :{}", schema_query->label_.name)); + return callback; + } + } + return callback; +} + // Struct for lazy pulling from a vector struct PullPlanVector { explicit PullPlanVector(std::vector<std::vector<TypedValue>> values) : values_(std::move(values)) {} @@ -2072,6 +2168,32 @@ PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_ RWType::NONE}; } +PreparedQuery PrepareSchemaQuery(ParsedQuery parsed_query, bool in_explicit_transaction, + InterpreterContext *interpreter_context, std::vector<Notification> *notifications) { + if (in_explicit_transaction) { + throw ConstraintInMulticommandTxException(); + } + auto *schema_query = utils::Downcast<SchemaQuery>(parsed_query.query); + MG_ASSERT(schema_query); + auto callback = HandleSchemaQuery(schema_query, interpreter_context, notifications); + + return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), + [handler = std::move(callback.fn), action = QueryHandlerResult::NOTHING, + pull_plan = std::shared_ptr<PullPlanVector>(nullptr)]( + AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> { + if (!pull_plan) { + auto results = handler(); + pull_plan = std::make_shared<PullPlanVector>(std::move(results)); + } + + if (pull_plan->Pull(stream, n)) { + return action; + } + return std::nullopt; + }, + RWType::NONE}; +} + void Interpreter::BeginTransaction() { const auto prepared_query = PrepareTransactionQuery("BEGIN"); 
prepared_query.query_handler(nullptr, {}); @@ -2205,6 +2327,9 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, prepared_query = PrepareSettingQuery(std::move(parsed_query), in_explicit_transaction_, &*execution_db_accessor_); } else if (utils::Downcast<VersionQuery>(parsed_query.query)) { prepared_query = PrepareVersionQuery(std::move(parsed_query), in_explicit_transaction_); + } else if (utils::Downcast<SchemaQuery>(parsed_query.query)) { + prepared_query = PrepareSchemaQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_, + &query_execution->notifications); } else { LOG_FATAL("Should not get here -- unknown query type!"); } diff --git a/src/query/v2/metadata.cpp b/src/query/v2/metadata.cpp index fe7461e79..f8a14d4a0 100644 --- a/src/query/v2/metadata.cpp +++ b/src/query/v2/metadata.cpp @@ -38,6 +38,8 @@ constexpr std::string_view GetCodeString(const NotificationCode code) { return "CreateIndex"sv; case NotificationCode::CREATE_STREAM: return "CreateStream"sv; + case NotificationCode::CREATE_SCHEMA: + return "CreateSchema"sv; case NotificationCode::CHECK_STREAM: return "CheckStream"sv; case NotificationCode::CREATE_TRIGGER: @@ -48,6 +50,8 @@ constexpr std::string_view GetCodeString(const NotificationCode code) { return "DropReplica"sv; case NotificationCode::DROP_INDEX: return "DropIndex"sv; + case NotificationCode::DROP_SCHEMA: + return "DropSchema"sv; case NotificationCode::DROP_STREAM: return "DropStream"sv; case NotificationCode::DROP_TRIGGER: diff --git a/src/query/v2/metadata.hpp b/src/query/v2/metadata.hpp index ffc621d64..c5211b1c1 100644 --- a/src/query/v2/metadata.hpp +++ b/src/query/v2/metadata.hpp @@ -26,12 +26,14 @@ enum class SeverityLevel : uint8_t { INFO, WARNING }; enum class NotificationCode : uint8_t { CREATE_CONSTRAINT, CREATE_INDEX, + CREATE_SCHEMA, CHECK_STREAM, CREATE_STREAM, CREATE_TRIGGER, DROP_CONSTRAINT, DROP_INDEX, DROP_REPLICA, + DROP_SCHEMA, DROP_STREAM, DROP_TRIGGER, 
EXISTANT_INDEX, diff --git a/src/query/v2/plan/operator.cpp b/src/query/v2/plan/operator.cpp index 4dd0bf693..4e3005d75 100644 --- a/src/query/v2/plan/operator.cpp +++ b/src/query/v2/plan/operator.cpp @@ -52,6 +52,7 @@ #include "utils/readable_size.hpp" #include "utils/string.hpp" #include "utils/temporal.hpp" +#include "utils/variant_helpers.hpp" // macro for the default implementation of LogicalOperator::Accept // that accepts the visitor and visits it's input_ operator @@ -174,45 +175,58 @@ CreateNode::CreateNode(const std::shared_ptr<LogicalOperator> &input, const Node // Creates a vertex on this GraphDb. Returns a reference to vertex placed on the // frame. -VertexAccessor &CreateLocalVertex(const NodeCreationInfo &node_info, Frame *frame, ExecutionContext &context) { +VertexAccessor &CreateLocalVertexAtomically(const NodeCreationInfo &node_info, Frame *frame, + ExecutionContext &context) { auto &dba = *context.db_accessor; - auto new_node = dba.InsertVertex(); - context.execution_stats[ExecutionStats::Key::CREATED_NODES] += 1; - for (auto label : node_info.labels) { - auto maybe_error = new_node.AddLabel(label); - if (maybe_error.HasError()) { - switch (maybe_error.GetError()) { - case storage::v3::Error::SERIALIZATION_ERROR: - throw TransactionSerializationException(); - case storage::v3::Error::DELETED_OBJECT: - throw QueryRuntimeException("Trying to set a label on a deleted node."); - case storage::v3::Error::VERTEX_HAS_EDGES: - case storage::v3::Error::PROPERTIES_DISABLED: - case storage::v3::Error::NONEXISTENT_OBJECT: - throw QueryRuntimeException("Unexpected error when setting a label."); - } - } - context.execution_stats[ExecutionStats::Key::CREATED_LABELS] += 1; - } // Evaluator should use the latest accessors, as modified in this query, when // setting properties on new nodes. 
ExpressionEvaluator evaluator(frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::v3::View::NEW); - // TODO: PropsSetChecked allocates a PropertyValue, make it use context.memory - // when we update PropertyValue with custom allocator. + + std::vector<std::pair<storage::v3::PropertyId, storage::v3::PropertyValue>> properties; if (const auto *node_info_properties = std::get_if<PropertiesMapList>(&node_info.properties)) { + properties.reserve(node_info_properties->size()); for (const auto &[key, value_expression] : *node_info_properties) { - PropsSetChecked(&new_node, key, value_expression->Accept(evaluator)); + properties.emplace_back(key, storage::v3::PropertyValue(value_expression->Accept(evaluator))); } } else { - auto property_map = evaluator.Visit(*std::get<ParameterLookup *>(node_info.properties)); - for (const auto &[key, value] : property_map.ValueMap()) { + auto property_map = evaluator.Visit(*std::get<ParameterLookup *>(node_info.properties)).ValueMap(); + properties.reserve(property_map.size()); + + for (const auto &[key, value] : property_map) { auto property_id = dba.NameToProperty(key); - PropsSetChecked(&new_node, property_id, value); + properties.emplace_back(property_id, value); } } - (*frame)[node_info.symbol] = new_node; + if (node_info.labels.empty()) { + throw QueryRuntimeException("Primary label must be defined!"); + } + const auto primary_label = node_info.labels[0]; + std::vector<storage::v3::LabelId> secondary_labels(node_info.labels.begin() + 1, node_info.labels.end()); + auto maybe_new_node = dba.InsertVertexAndValidate(primary_label, secondary_labels, properties); + if (maybe_new_node.HasError()) { + std::visit(utils::Overloaded{[&dba](const storage::v3::SchemaViolation &schema_violation) { + HandleSchemaViolation(schema_violation, dba); + }, + [](const storage::v3::Error error) { + switch (error) { + case storage::v3::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case 
storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to set a label on a deleted node."); + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + case storage::v3::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when setting a label."); + } + }}, + maybe_new_node.GetError()); + } + + context.execution_stats[ExecutionStats::Key::CREATED_NODES] += 1; + + (*frame)[node_info.symbol] = *maybe_new_node; return (*frame)[node_info.symbol].ValueVertex(); } @@ -237,7 +251,7 @@ bool CreateNode::CreateNodeCursor::Pull(Frame &frame, ExecutionContext &context) SCOPED_PROFILE_OP("CreateNode"); if (input_cursor_->Pull(frame, context)) { - auto created_vertex = CreateLocalVertex(self_.node_info_, &frame, context); + auto created_vertex = CreateLocalVertexAtomically(self_.node_info_, &frame, context); if (context.trigger_context_collector) { context.trigger_context_collector->RegisterCreatedObject(created_vertex); } @@ -286,13 +300,13 @@ EdgeAccessor CreateEdge(const EdgeCreationInfo &edge_info, DbAccessor *dba, Vert auto &edge = *maybe_edge; if (const auto *properties = std::get_if<PropertiesMapList>(&edge_info.properties)) { for (const auto &[key, value_expression] : *properties) { - PropsSetChecked(&edge, key, value_expression->Accept(*evaluator)); + PropsSetChecked(&edge, *dba, key, value_expression->Accept(*evaluator)); } } else { auto property_map = evaluator->Visit(*std::get<ParameterLookup *>(edge_info.properties)); for (const auto &[key, value] : property_map.ValueMap()) { auto property_id = dba->NameToProperty(key); - PropsSetChecked(&edge, property_id, value); + PropsSetChecked(&edge, *dba, property_id, value); } } @@ -368,13 +382,12 @@ VertexAccessor &CreateExpand::CreateExpandCursor::OtherVertex(Frame &frame, Exec TypedValue &dest_node_value = frame[self_.node_info_.symbol]; ExpectType(self_.node_info_.symbol, dest_node_value, TypedValue::Type::Vertex); return 
dest_node_value.ValueVertex(); - } else { - auto &created_vertex = CreateLocalVertex(self_.node_info_, &frame, context); - if (context.trigger_context_collector) { - context.trigger_context_collector->RegisterCreatedObject(created_vertex); - } - return created_vertex; } + auto &created_vertex = CreateLocalVertexAtomically(self_.node_info_, &frame, context); + if (context.trigger_context_collector) { + context.trigger_context_collector->RegisterCreatedObject(created_vertex); + } + return created_vertex; } template <class TVerticesFun> @@ -2050,7 +2063,7 @@ bool SetProperty::SetPropertyCursor::Pull(Frame &frame, ExecutionContext &contex switch (lhs.type()) { case TypedValue::Type::Vertex: { - auto old_value = PropsSetChecked(&lhs.ValueVertex(), self_.property_, rhs); + auto old_value = PropsSetChecked(&lhs.ValueVertex(), *context.db_accessor, self_.property_, rhs); context.execution_stats[ExecutionStats::Key::UPDATED_PROPERTIES] += 1; if (context.trigger_context_collector) { // rhs cannot be moved because it was created with the allocator that is only valid during current pull @@ -2060,7 +2073,7 @@ bool SetProperty::SetPropertyCursor::Pull(Frame &frame, ExecutionContext &contex break; } case TypedValue::Type::Edge: { - auto old_value = PropsSetChecked(&lhs.ValueEdge(), self_.property_, rhs); + auto old_value = PropsSetChecked(&lhs.ValueEdge(), *context.db_accessor, self_.property_, rhs); context.execution_stats[ExecutionStats::Key::UPDATED_PROPERTIES] += 1; if (context.trigger_context_collector) { // rhs cannot be moved because it was created with the allocator that is only valid during current pull @@ -2216,7 +2229,7 @@ void SetPropertiesOnRecord(TRecordAccessor *record, const TypedValue &rhs, SetPr case TypedValue::Type::Map: { for (const auto &kv : rhs.ValueMap()) { auto key = context->db_accessor->NameToProperty(kv.first); - auto old_value = PropsSetChecked(record, key, kv.second); + auto old_value = PropsSetChecked(record, *context->db_accessor, key, kv.second); 
if (should_register_change) { register_set_property(std::move(old_value), key, kv.second); } @@ -2300,22 +2313,31 @@ bool SetLabels::SetLabelsCursor::Pull(Frame &frame, ExecutionContext &context) { // Skip setting labels on Null (can occur in optional match). if (vertex_value.IsNull()) return true; ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); + + auto &dba = *context.db_accessor; auto &vertex = vertex_value.ValueVertex(); - for (auto label : self_.labels_) { - auto maybe_value = vertex.AddLabel(label); + for (const auto label : self_.labels_) { + auto maybe_value = vertex.AddLabelAndValidate(label); if (maybe_value.HasError()) { - switch (maybe_value.GetError()) { - case storage::v3::Error::SERIALIZATION_ERROR: - throw TransactionSerializationException(); - case storage::v3::Error::DELETED_OBJECT: - throw QueryRuntimeException("Trying to set a label on a deleted node."); - case storage::v3::Error::VERTEX_HAS_EDGES: - case storage::v3::Error::PROPERTIES_DISABLED: - case storage::v3::Error::NONEXISTENT_OBJECT: - throw QueryRuntimeException("Unexpected error when setting a label."); - } + std::visit(utils::Overloaded{[](const storage::v3::Error error) { + switch (error) { + case storage::v3::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to set a label on a deleted node."); + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + case storage::v3::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when setting a label."); + } + }, + [&dba](const storage::v3::SchemaViolation schema_violation) { + HandleSchemaViolation(schema_violation, dba); + }}, + maybe_value.GetError()); } + context.execution_stats[ExecutionStats::Key::CREATED_LABELS]++; if (context.trigger_context_collector && *maybe_value) { context.trigger_context_collector->RegisterSetVertexLabel(vertex, label); } @@ 
-2358,26 +2380,11 @@ bool RemoveProperty::RemovePropertyCursor::Pull(Frame &frame, ExecutionContext & TypedValue lhs = self_.lhs_->expression_->Accept(evaluator); auto remove_prop = [property = self_.property_, &context](auto *record) { - auto maybe_old_value = record->RemoveProperty(property); - if (maybe_old_value.HasError()) { - switch (maybe_old_value.GetError()) { - case storage::v3::Error::DELETED_OBJECT: - throw QueryRuntimeException("Trying to remove a property on a deleted graph element."); - case storage::v3::Error::SERIALIZATION_ERROR: - throw TransactionSerializationException(); - case storage::v3::Error::PROPERTIES_DISABLED: - throw QueryRuntimeException( - "Can't remove property because properties on edges are " - "disabled."); - case storage::v3::Error::VERTEX_HAS_EDGES: - case storage::v3::Error::NONEXISTENT_OBJECT: - throw QueryRuntimeException("Unexpected error when removing property."); - } - } + auto old_value = PropsSetChecked(record, *context.db_accessor, property, TypedValue{}); if (context.trigger_context_collector) { context.trigger_context_collector->RegisterRemovedObjectProperty(*record, property, - TypedValue(std::move(*maybe_old_value))); + TypedValue(std::move(old_value))); } }; @@ -2431,18 +2438,25 @@ bool RemoveLabels::RemoveLabelsCursor::Pull(Frame &frame, ExecutionContext &cont ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); auto &vertex = vertex_value.ValueVertex(); for (auto label : self_.labels_) { - auto maybe_value = vertex.RemoveLabel(label); + auto maybe_value = vertex.RemoveLabelAndValidate(label); if (maybe_value.HasError()) { - switch (maybe_value.GetError()) { - case storage::v3::Error::SERIALIZATION_ERROR: - throw TransactionSerializationException(); - case storage::v3::Error::DELETED_OBJECT: - throw QueryRuntimeException("Trying to remove labels from a deleted node."); - case storage::v3::Error::VERTEX_HAS_EDGES: - case storage::v3::Error::PROPERTIES_DISABLED: - case 
storage::v3::Error::NONEXISTENT_OBJECT: - throw QueryRuntimeException("Unexpected error when removing labels from a node."); - } + std::visit( + utils::Overloaded{[](const storage::v3::Error error) { + switch (error) { + case storage::v3::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to remove labels from a deleted node."); + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + case storage::v3::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when removing labels from a node."); + } + }, + [&context](const storage::v3::SchemaViolation &schema_violation) { + HandleSchemaViolation(schema_violation, *context.db_accessor); + }}, + maybe_value.GetError()); } context.execution_stats[ExecutionStats::Key::DELETED_LABELS] += 1; diff --git a/src/storage/v2/indices.hpp b/src/storage/v2/indices.hpp index eed22e8b5..336b70ead 100644 --- a/src/storage/v2/indices.hpp +++ b/src/storage/v2/indices.hpp @@ -111,7 +111,7 @@ class LabelIndex { Iterable Vertices(LabelId label, View view, Transaction *transaction) { auto it = index_.find(label); MG_ASSERT(it != index_.end(), "Index for label {} doesn't exist", label.AsUint()); - return Iterable(it->second.access(), label, view, transaction, indices_, constraints_, config_); + return {it->second.access(), label, view, transaction, indices_, constraints_, config_}; } int64_t ApproximateVertexCount(LabelId label) { @@ -216,8 +216,8 @@ class LabelPropertyIndex { auto it = index_.find({label, property}); MG_ASSERT(it != index_.end(), "Index for label {} and property {} doesn't exist", label.AsUint(), property.AsUint()); - return Iterable(it->second.access(), label, property, lower_bound, upper_bound, view, transaction, indices_, - constraints_, config_); + return {it->second.access(), label, property, lower_bound, upper_bound, view, + transaction, indices_, constraints_, 
config_}; } int64_t ApproximateVertexCount(LabelId label, PropertyId property) const { diff --git a/src/storage/v2/storage.cpp b/src/storage/v2/storage.cpp index cbc41da86..cee74574d 100644 --- a/src/storage/v2/storage.cpp +++ b/src/storage/v2/storage.cpp @@ -489,12 +489,12 @@ VertexAccessor Storage::Accessor::CreateVertex() { OOMExceptionEnabler oom_exception; auto gid = storage_->vertex_id_.fetch_add(1, std::memory_order_acq_rel); auto acc = storage_->vertices_.access(); - auto delta = CreateDeleteObjectDelta(&transaction_); + auto *delta = CreateDeleteObjectDelta(&transaction_); auto [it, inserted] = acc.insert(Vertex{storage::Gid::FromUint(gid), delta}); MG_ASSERT(inserted, "The vertex must be inserted here!"); MG_ASSERT(it != acc.end(), "Invalid Vertex accessor!"); delta->prev.Set(&*it); - return VertexAccessor(&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_); + return {&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_}; } VertexAccessor Storage::Accessor::CreateVertex(storage::Gid gid) { @@ -508,12 +508,12 @@ VertexAccessor Storage::Accessor::CreateVertex(storage::Gid gid) { storage_->vertex_id_.store(std::max(storage_->vertex_id_.load(std::memory_order_acquire), gid.AsUint() + 1), std::memory_order_release); auto acc = storage_->vertices_.access(); - auto delta = CreateDeleteObjectDelta(&transaction_); + auto *delta = CreateDeleteObjectDelta(&transaction_); auto [it, inserted] = acc.insert(Vertex{gid, delta}); MG_ASSERT(inserted, "The vertex must be inserted here!"); MG_ASSERT(it != acc.end(), "Invalid Vertex accessor!"); delta->prev.Set(&*it); - return VertexAccessor(&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_); + return {&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_}; } std::optional<VertexAccessor> Storage::Accessor::FindVertex(Gid gid, View view) { diff --git a/src/storage/v3/CMakeLists.txt b/src/storage/v3/CMakeLists.txt index 
e9322b45a..9bc8e1f10 100644 --- a/src/storage/v3/CMakeLists.txt +++ b/src/storage/v3/CMakeLists.txt @@ -10,6 +10,8 @@ set(storage_v3_src_files indices.cpp property_store.cpp vertex_accessor.cpp + schemas.cpp + schema_validator.cpp storage.cpp) # #### Replication ##### diff --git a/src/storage/v3/constraints.cpp b/src/storage/v3/constraints.cpp index e04fa4070..a00b28aa1 100644 --- a/src/storage/v3/constraints.cpp +++ b/src/storage/v3/constraints.cpp @@ -16,6 +16,7 @@ #include <map> #include "storage/v3/mvcc.hpp" +#include "storage/v3/vertex.hpp" #include "utils/logging.hpp" namespace memgraph::storage::v3 { @@ -59,7 +60,7 @@ bool LastCommittedVersionHasLabelProperty(const Vertex &vertex, LabelId label, c std::lock_guard<utils::SpinLock> guard(vertex.lock); delta = vertex.delta; deleted = vertex.deleted; - has_label = utils::Contains(vertex.labels, label); + has_label = VertexHasLabel(vertex, label); size_t i = 0; for (const auto &property : properties) { @@ -142,7 +143,7 @@ bool AnyVersionHasLabelProperty(const Vertex &vertex, LabelId label, const std:: Delta *delta{nullptr}; { std::lock_guard<utils::SpinLock> guard(vertex.lock); - has_label = utils::Contains(vertex.labels, label); + has_label = VertexHasLabel(vertex, label); deleted = vertex.deleted; delta = vertex.delta; @@ -267,7 +268,7 @@ bool UniqueConstraints::Entry::operator==(const std::vector<PropertyValue> &rhs) void UniqueConstraints::UpdateBeforeCommit(const Vertex *vertex, const Transaction &tx) { for (auto &[label_props, storage] : constraints_) { - if (!utils::Contains(vertex->labels, label_props.first)) { + if (!VertexHasLabel(*vertex, label_props.first)) { continue; } auto values = ExtractPropertyValues(*vertex, label_props.second); @@ -301,7 +302,7 @@ utils::BasicResult<ConstraintViolation, UniqueConstraints::CreationStatus> Uniqu auto acc = constraint->second.access(); for (const Vertex &vertex : vertices) { - if (vertex.deleted || !utils::Contains(vertex.labels, label)) { + if (vertex.deleted || 
!VertexHasLabel(vertex, label)) { continue; } auto values = ExtractPropertyValues(vertex, properties); @@ -352,7 +353,7 @@ std::optional<ConstraintViolation> UniqueConstraints::Validate(const Vertex &ver for (const auto &[label_props, storage] : constraints_) { const auto &label = label_props.first; const auto &properties = label_props.second; - if (!utils::Contains(vertex.labels, label)) { + if (!VertexHasLabel(vertex, label)) { continue; } diff --git a/src/storage/v3/constraints.hpp b/src/storage/v3/constraints.hpp index 3c0715c4d..25d329ad9 100644 --- a/src/storage/v3/constraints.hpp +++ b/src/storage/v3/constraints.hpp @@ -158,7 +158,7 @@ inline utils::BasicResult<ConstraintViolation, bool> CreateExistenceConstraint( return false; } for (const auto &vertex : vertices) { - if (!vertex.deleted && utils::Contains(vertex.labels, label) && !vertex.properties.HasProperty(property)) { + if (!vertex.deleted && VertexHasLabel(vertex, label) && !vertex.properties.HasProperty(property)) { return ConstraintViolation{ConstraintViolation::Type::EXISTENCE, label, std::set<PropertyId>{property}}; } } @@ -184,7 +184,7 @@ inline bool DropExistenceConstraint(Constraints *constraints, LabelId label, Pro [[nodiscard]] inline std::optional<ConstraintViolation> ValidateExistenceConstraints(const Vertex &vertex, const Constraints &constraints) { for (const auto &[label, property] : constraints.existence_constraints) { - if (!vertex.deleted && utils::Contains(vertex.labels, label) && !vertex.properties.HasProperty(property)) { + if (!vertex.deleted && VertexHasLabel(vertex, label) && !vertex.properties.HasProperty(property)) { return ConstraintViolation{ConstraintViolation::Type::EXISTENCE, label, std::set<PropertyId>{property}}; } } diff --git a/src/storage/v3/durability/snapshot.cpp b/src/storage/v3/durability/snapshot.cpp index 48b4184ad..1942778e2 100644 --- a/src/storage/v3/durability/snapshot.cpp +++ b/src/storage/v3/durability/snapshot.cpp @@ -628,8 +628,9 @@ RecoveredSnapshot 
LoadSnapshot(const std::filesystem::path &path, utils::SkipLis void CreateSnapshot(Transaction *transaction, const std::filesystem::path &snapshot_directory, const std::filesystem::path &wal_directory, uint64_t snapshot_retention_count, utils::SkipList<Vertex> *vertices, utils::SkipList<Edge> *edges, NameIdMapper *name_id_mapper, - Indices *indices, Constraints *constraints, Config::Items items, const std::string &uuid, - const std::string_view epoch_id, const std::deque<std::pair<std::string, uint64_t>> &epoch_history, + Indices *indices, Constraints *constraints, Config::Items items, + const SchemaValidator &schema_validator, const std::string &uuid, const std::string_view epoch_id, + const std::deque<std::pair<std::string, uint64_t>> &epoch_history, utils::FileRetainer *file_retainer) { // Ensure that the storage directory exists. utils::EnsureDirOrDie(snapshot_directory); @@ -713,8 +714,9 @@ void CreateSnapshot(Transaction *transaction, const std::filesystem::path &snaps // type and invalid from/to pointers because we don't know them here, // but that isn't an issue because we won't use that part of the API // here. - auto ea = - EdgeAccessor{edge_ref, EdgeTypeId::FromUint(0UL), nullptr, nullptr, transaction, indices, constraints, items}; + // TODO(jbajic) Fix snapshot with new schema rules + auto ea = EdgeAccessor{edge_ref, EdgeTypeId::FromUint(0UL), nullptr, nullptr, transaction, indices, constraints, + items, schema_validator}; // Get edge data. auto maybe_props = ea.Properties(View::OLD); @@ -742,7 +744,7 @@ void CreateSnapshot(Transaction *transaction, const std::filesystem::path &snaps auto acc = vertices->access(); for (auto &vertex : acc) { // The visibility check is implemented for vertices so we use it here. 
- auto va = VertexAccessor::Create(&vertex, transaction, indices, constraints, items, View::OLD); + auto va = VertexAccessor::Create(&vertex, transaction, indices, constraints, items, schema_validator, View::OLD); if (!va) continue; // Get vertex data. diff --git a/src/storage/v3/durability/snapshot.hpp b/src/storage/v3/durability/snapshot.hpp index 785ca7ed2..28c0edf08 100644 --- a/src/storage/v3/durability/snapshot.hpp +++ b/src/storage/v3/durability/snapshot.hpp @@ -21,6 +21,7 @@ #include "storage/v3/edge.hpp" #include "storage/v3/indices.hpp" #include "storage/v3/name_id_mapper.hpp" +#include "storage/v3/schema_validator.hpp" #include "storage/v3/transaction.hpp" #include "storage/v3/vertex.hpp" #include "utils/file_locker.hpp" @@ -68,8 +69,9 @@ RecoveredSnapshot LoadSnapshot(const std::filesystem::path &path, utils::SkipLis void CreateSnapshot(Transaction *transaction, const std::filesystem::path &snapshot_directory, const std::filesystem::path &wal_directory, uint64_t snapshot_retention_count, utils::SkipList<Vertex> *vertices, utils::SkipList<Edge> *edges, NameIdMapper *name_id_mapper, - Indices *indices, Constraints *constraints, Config::Items items, const std::string &uuid, - std::string_view epoch_id, const std::deque<std::pair<std::string, uint64_t>> &epoch_history, + Indices *indices, Constraints *constraints, Config::Items items, + const SchemaValidator &schema_validator, const std::string &uuid, std::string_view epoch_id, + const std::deque<std::pair<std::string, uint64_t>> &epoch_history, utils::FileRetainer *file_retainer); } // namespace memgraph::storage::v3::durability diff --git a/src/storage/v3/edge_accessor.cpp b/src/storage/v3/edge_accessor.cpp index 2a1294d8e..abb5597e5 100644 --- a/src/storage/v3/edge_accessor.cpp +++ b/src/storage/v3/edge_accessor.cpp @@ -15,6 +15,7 @@ #include "storage/v3/mvcc.hpp" #include "storage/v3/property_value.hpp" +#include "storage/v3/schema_validator.hpp" #include "storage/v3/vertex_accessor.hpp" #include 
"utils/memory_tracker.hpp" @@ -54,11 +55,11 @@ bool EdgeAccessor::IsVisible(const View view) const { } VertexAccessor EdgeAccessor::FromVertex() const { - return VertexAccessor{from_vertex_, transaction_, indices_, constraints_, config_}; + return {from_vertex_, transaction_, indices_, constraints_, config_, *schema_validator_}; } VertexAccessor EdgeAccessor::ToVertex() const { - return VertexAccessor{to_vertex_, transaction_, indices_, constraints_, config_}; + return {to_vertex_, transaction_, indices_, constraints_, config_, *schema_validator_}; } Result<PropertyValue> EdgeAccessor::SetProperty(PropertyId property, const PropertyValue &value) { diff --git a/src/storage/v3/edge_accessor.hpp b/src/storage/v3/edge_accessor.hpp index 60497c80b..cf8b658d8 100644 --- a/src/storage/v3/edge_accessor.hpp +++ b/src/storage/v3/edge_accessor.hpp @@ -18,6 +18,7 @@ #include "storage/v3/config.hpp" #include "storage/v3/result.hpp" +#include "storage/v3/schema_validator.hpp" #include "storage/v3/transaction.hpp" #include "storage/v3/view.hpp" @@ -34,7 +35,8 @@ class EdgeAccessor final { public: EdgeAccessor(EdgeRef edge, EdgeTypeId edge_type, Vertex *from_vertex, Vertex *to_vertex, Transaction *transaction, - Indices *indices, Constraints *constraints, Config::Items config, bool for_deleted = false) + Indices *indices, Constraints *constraints, Config::Items config, + const SchemaValidator &schema_validator, bool for_deleted = false) : edge_(edge), edge_type_(edge_type), from_vertex_(from_vertex), @@ -43,6 +45,7 @@ class EdgeAccessor final { indices_(indices), constraints_(constraints), config_(config), + schema_validator_{&schema_validator}, for_deleted_(for_deleted) {} /// @return true if the object is visible from the current transaction @@ -91,6 +94,7 @@ class EdgeAccessor final { Indices *indices_; Constraints *constraints_; Config::Items config_; + const SchemaValidator *schema_validator_; // if the accessor was created for a deleted edge. 
// Accessor behaves differently for some methods based on this diff --git a/src/storage/v3/indices.cpp b/src/storage/v3/indices.cpp index 340484228..aab1c0fe6 100644 --- a/src/storage/v3/indices.cpp +++ b/src/storage/v3/indices.cpp @@ -10,10 +10,12 @@ // licenses/APL.txt. #include "indices.hpp" + #include <limits> #include "storage/v3/mvcc.hpp" #include "storage/v3/property_value.hpp" +#include "storage/v3/schema_validator.hpp" #include "utils/bound.hpp" #include "utils/logging.hpp" #include "utils/memory_tracker.hpp" @@ -327,7 +329,7 @@ void LabelIndex::RemoveObsoleteEntries(uint64_t oldest_active_start_timestamp) { LabelIndex::Iterable::Iterator::Iterator(Iterable *self, utils::SkipList<Entry>::Iterator index_iterator) : self_(self), index_iterator_(index_iterator), - current_vertex_accessor_(nullptr, nullptr, nullptr, nullptr, self_->config_), + current_vertex_accessor_(nullptr, nullptr, nullptr, nullptr, self_->config_, *self_->schema_validator_), current_vertex_(nullptr) { AdvanceUntilValid(); } @@ -345,8 +347,8 @@ void LabelIndex::Iterable::Iterator::AdvanceUntilValid() { } if (CurrentVersionHasLabel(*index_iterator_->vertex, self_->label_, self_->transaction_, self_->view_)) { current_vertex_ = index_iterator_->vertex; - current_vertex_accessor_ = - VertexAccessor{current_vertex_, self_->transaction_, self_->indices_, self_->constraints_, self_->config_}; + current_vertex_accessor_ = VertexAccessor{current_vertex_, self_->transaction_, self_->indices_, + self_->constraints_, self_->config_, *self_->schema_validator_}; break; } } @@ -354,14 +356,15 @@ void LabelIndex::Iterable::Iterator::AdvanceUntilValid() { LabelIndex::Iterable::Iterable(utils::SkipList<Entry>::Accessor index_accessor, LabelId label, View view, Transaction *transaction, Indices *indices, Constraints *constraints, - Config::Items config) + Config::Items config, const SchemaValidator &schema_validator) : index_accessor_(std::move(index_accessor)), label_(label), view_(view), 
transaction_(transaction), indices_(indices), constraints_(constraints), - config_(config) {} + config_(config), + schema_validator_(&schema_validator) {} void LabelIndex::RunGC() { for (auto &index_entry : index_) { @@ -478,7 +481,7 @@ void LabelPropertyIndex::RemoveObsoleteEntries(uint64_t oldest_active_start_time LabelPropertyIndex::Iterable::Iterator::Iterator(Iterable *self, utils::SkipList<Entry>::Iterator index_iterator) : self_(self), index_iterator_(index_iterator), - current_vertex_accessor_(nullptr, nullptr, nullptr, nullptr, self_->config_), + current_vertex_accessor_(nullptr, nullptr, nullptr, nullptr, self_->config_, *self_->schema_validator_), current_vertex_(nullptr) { AdvanceUntilValid(); } @@ -517,8 +520,8 @@ void LabelPropertyIndex::Iterable::Iterator::AdvanceUntilValid() { if (CurrentVersionHasLabelProperty(*index_iterator_->vertex, self_->label_, self_->property_, index_iterator_->value, self_->transaction_, self_->view_)) { current_vertex_ = index_iterator_->vertex; - current_vertex_accessor_ = - VertexAccessor(current_vertex_, self_->transaction_, self_->indices_, self_->constraints_, self_->config_); + current_vertex_accessor_ = VertexAccessor(current_vertex_, self_->transaction_, self_->indices_, + self_->constraints_, self_->config_, *self_->schema_validator_); break; } } @@ -541,7 +544,7 @@ LabelPropertyIndex::Iterable::Iterable(utils::SkipList<Entry>::Accessor index_ac const std::optional<utils::Bound<PropertyValue>> &lower_bound, const std::optional<utils::Bound<PropertyValue>> &upper_bound, View view, Transaction *transaction, Indices *indices, Constraints *constraints, - Config::Items config) + Config::Items config, const SchemaValidator &schema_validator) : index_accessor_(std::move(index_accessor)), label_(label), property_(property), @@ -551,7 +554,8 @@ LabelPropertyIndex::Iterable::Iterable(utils::SkipList<Entry>::Accessor index_ac transaction_(transaction), indices_(indices), constraints_(constraints), - config_(config) { + 
config_(config), + schema_validator_(&schema_validator) { // We have to fix the bounds that the user provided to us. If the user // provided only one bound we should make sure that only values of that type // are returned by the iterator. We ensure this by supplying either an diff --git a/src/storage/v3/indices.hpp b/src/storage/v3/indices.hpp index fe9e83b88..cdc5df630 100644 --- a/src/storage/v3/indices.hpp +++ b/src/storage/v3/indices.hpp @@ -11,12 +11,14 @@ #pragma once +#include <cstdint> #include <optional> #include <tuple> #include <utility> #include "storage/v3/config.hpp" #include "storage/v3/property_value.hpp" +#include "storage/v3/schema_validator.hpp" #include "storage/v3/transaction.hpp" #include "storage/v3/vertex_accessor.hpp" #include "utils/bound.hpp" @@ -51,8 +53,8 @@ class LabelIndex { }; public: - LabelIndex(Indices *indices, Constraints *constraints, Config::Items config) - : indices_(indices), constraints_(constraints), config_(config) {} + LabelIndex(Indices *indices, Constraints *constraints, Config::Items config, const SchemaValidator &schema_validator) + : indices_(indices), constraints_(constraints), config_(config), schema_validator_{&schema_validator} {} /// @throw std::bad_alloc void UpdateOnAddLabel(LabelId label, Vertex *vertex, const Transaction &tx); @@ -72,7 +74,7 @@ class LabelIndex { class Iterable { public: Iterable(utils::SkipList<Entry>::Accessor index_accessor, LabelId label, View view, Transaction *transaction, - Indices *indices, Constraints *constraints, Config::Items config); + Indices *indices, Constraints *constraints, Config::Items config, const SchemaValidator &schema_validator); class Iterator { public: @@ -105,13 +107,14 @@ class LabelIndex { Indices *indices_; Constraints *constraints_; Config::Items config_; + const SchemaValidator *schema_validator_; }; /// Returns an self with vertices visible from the given transaction. 
Iterable Vertices(LabelId label, View view, Transaction *transaction) { auto it = index_.find(label); MG_ASSERT(it != index_.end(), "Index for label {} doesn't exist", label.AsUint()); - return {it->second.access(), label, view, transaction, indices_, constraints_, config_}; + return {it->second.access(), label, view, transaction, indices_, constraints_, config_, *schema_validator_}; } int64_t ApproximateVertexCount(LabelId label) { @@ -129,6 +132,7 @@ class LabelIndex { Indices *indices_; Constraints *constraints_; Config::Items config_; + const SchemaValidator *schema_validator_; }; class LabelPropertyIndex { @@ -146,8 +150,9 @@ class LabelPropertyIndex { }; public: - LabelPropertyIndex(Indices *indices, Constraints *constraints, Config::Items config) - : indices_(indices), constraints_(constraints), config_(config) {} + LabelPropertyIndex(Indices *indices, Constraints *constraints, Config::Items config, + const SchemaValidator &schema_validator) + : indices_(indices), constraints_(constraints), config_(config), schema_validator_{&schema_validator} {} /// @throw std::bad_alloc void UpdateOnAddLabel(LabelId label, Vertex *vertex, const Transaction &tx); @@ -171,7 +176,7 @@ class LabelPropertyIndex { Iterable(utils::SkipList<Entry>::Accessor index_accessor, LabelId label, PropertyId property, const std::optional<utils::Bound<PropertyValue>> &lower_bound, const std::optional<utils::Bound<PropertyValue>> &upper_bound, View view, Transaction *transaction, - Indices *indices, Constraints *constraints, Config::Items config); + Indices *indices, Constraints *constraints, Config::Items config, const SchemaValidator &schema_validator); class Iterator { public: @@ -208,16 +213,17 @@ class LabelPropertyIndex { Indices *indices_; Constraints *constraints_; Config::Items config_; + const SchemaValidator *schema_validator_; }; Iterable Vertices(LabelId label, PropertyId property, const std::optional<utils::Bound<PropertyValue>> &lower_bound, - const 
std::optional<utils::Bound<PropertyValue>> &upper_bound, View view, - Transaction *transaction) { + const std::optional<utils::Bound<PropertyValue>> &upper_bound, View view, Transaction *transaction, + const SchemaValidator &schema_validator_) { auto it = index_.find({label, property}); MG_ASSERT(it != index_.end(), "Index for label {} and property {} doesn't exist", label.AsUint(), property.AsUint()); - return {it->second.access(), label, property, lower_bound, upper_bound, view, - transaction, indices_, constraints_, config_}; + return {it->second.access(), label, property, lower_bound, upper_bound, view, + transaction, indices_, constraints_, config_, schema_validator_}; } int64_t ApproximateVertexCount(LabelId label, PropertyId property) const { @@ -246,11 +252,13 @@ class LabelPropertyIndex { Indices *indices_; Constraints *constraints_; Config::Items config_; + const SchemaValidator *schema_validator_; }; struct Indices { - Indices(Constraints *constraints, Config::Items config) - : label_index(this, constraints, config), label_property_index(this, constraints, config) {} + Indices(Constraints *constraints, Config::Items config, const SchemaValidator &schema_validator) + : label_index(this, constraints, config, schema_validator), + label_property_index(this, constraints, config, schema_validator) {} // Disable copy and move because members hold pointer to `this`. Indices(const Indices &) = delete; diff --git a/src/storage/v3/replication/replication_server.cpp b/src/storage/v3/replication/replication_server.cpp index 0568598e1..4b9b38d77 100644 --- a/src/storage/v3/replication/replication_server.cpp +++ b/src/storage/v3/replication/replication_server.cpp @@ -10,6 +10,7 @@ // licenses/APL.txt. 
#include "storage/v3/replication/replication_server.hpp" + #include <atomic> #include <filesystem> @@ -162,9 +163,10 @@ void Storage::ReplicationServer::SnapshotHandler(slk::Reader *req_reader, slk::B storage_->edges_.clear(); storage_->constraints_ = Constraints(); - storage_->indices_.label_index = LabelIndex(&storage_->indices_, &storage_->constraints_, storage_->config_.items); - storage_->indices_.label_property_index = - LabelPropertyIndex(&storage_->indices_, &storage_->constraints_, storage_->config_.items); + storage_->indices_.label_index = + LabelIndex(&storage_->indices_, &storage_->constraints_, storage_->config_.items, storage_->schema_validator_); + storage_->indices_.label_property_index = LabelPropertyIndex(&storage_->indices_, &storage_->constraints_, + storage_->config_.items, storage_->schema_validator_); try { spdlog::debug("Loading snapshot"); auto recovered_snapshot = durability::LoadSnapshot(*maybe_snapshot_path, &storage_->vertices_, &storage_->edges_, @@ -461,7 +463,8 @@ uint64_t Storage::ReplicationServer::ReadAndApplyDelta(durability::BaseDecoder * &transaction->transaction_, &storage_->indices_, &storage_->constraints_, - storage_->config_.items}; + storage_->config_.items, + storage_->schema_validator_}; auto ret = ea.SetProperty(transaction->NameToProperty(delta.vertex_edge_set_property.property), delta.vertex_edge_set_property.value); diff --git a/src/storage/v3/schema_validator.cpp b/src/storage/v3/schema_validator.cpp new file mode 100644 index 000000000..4aa466f7f --- /dev/null +++ b/src/storage/v3/schema_validator.cpp @@ -0,0 +1,106 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "storage/v3/schema_validator.hpp" + +#include <bits/ranges_algo.h> +#include <cstddef> +#include <ranges> + +#include "storage/v3/schemas.hpp" + +namespace memgraph::storage::v3 { + +bool operator==(const SchemaViolation &lhs, const SchemaViolation &rhs) { + return lhs.status == rhs.status && lhs.label == rhs.label && + lhs.violated_schema_property == rhs.violated_schema_property && + lhs.violated_property_value == rhs.violated_property_value; +} + +SchemaViolation::SchemaViolation(ValidationStatus status, LabelId label) : status{status}, label{label} {} + +SchemaViolation::SchemaViolation(ValidationStatus status, LabelId label, SchemaProperty violated_schema_property) + : status{status}, label{label}, violated_schema_property{violated_schema_property} {} + +SchemaViolation::SchemaViolation(ValidationStatus status, LabelId label, SchemaProperty violated_schema_property, + PropertyValue violated_property_value) + : status{status}, + label{label}, + violated_schema_property{violated_schema_property}, + violated_property_value{violated_property_value} {} + +SchemaValidator::SchemaValidator(Schemas &schemas) : schemas_{schemas} {} + +[[nodiscard]] std::optional<SchemaViolation> SchemaValidator::ValidateVertexCreate( + LabelId primary_label, const std::vector<LabelId> &labels, + const std::vector<std::pair<PropertyId, PropertyValue>> &properties) const { + // Schema on primary label + const auto *schema = schemas_.GetSchema(primary_label); + if (schema == nullptr) { + return SchemaViolation(SchemaViolation::ValidationStatus::NO_SCHEMA_DEFINED_FOR_LABEL, primary_label); + } + + // Is there another primary label among secondary labels + for (const auto &secondary_label : labels) { + if (schemas_.GetSchema(secondary_label)) { + return 
SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_SECONDARY_LABEL_IS_PRIMARY, secondary_label); + } + } + + // Check only properties defined by schema + for (const auto &schema_type : schema->second) { + // Check schema property existence + auto property_pair = std::ranges::find_if( + properties, [schema_property_id = schema_type.property_id](const auto &property_type_value) { + return property_type_value.first == schema_property_id; + }); + if (property_pair == properties.end()) { + return SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_HAS_NO_PRIMARY_PROPERTY, primary_label, + schema_type); + } + + // Check schema property type + if (auto property_schema_type = PropertyTypeToSchemaType(property_pair->second); + property_schema_type && *property_schema_type != schema_type.type) { + return SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_PROPERTY_WRONG_TYPE, primary_label, schema_type, + property_pair->second); + } + } + + return std::nullopt; +} + +[[nodiscard]] std::optional<SchemaViolation> SchemaValidator::ValidatePropertyUpdate( + const LabelId primary_label, const PropertyId property_id) const { + // Verify existence of schema on primary label + const auto *schema = schemas_.GetSchema(primary_label); + MG_ASSERT(schema, "Cannot validate against non existing schema!"); + + // Verify that updating property is not part of schema + if (const auto schema_property = std::ranges::find_if( + schema->second, + [property_id](const auto &schema_property) { return property_id == schema_property.property_id; }); + schema_property != schema->second.end()) { + return SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_UPDATE_PRIMARY_KEY, primary_label, + *schema_property); + } + return std::nullopt; +} + +[[nodiscard]] std::optional<SchemaViolation> SchemaValidator::ValidateLabelUpdate(const LabelId label) const { + const auto *schema = schemas_.GetSchema(label); + if (schema) { + return 
SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_MODIFY_PRIMARY_LABEL, label); + } + return std::nullopt; +} + +} // namespace memgraph::storage::v3 diff --git a/src/storage/v3/schema_validator.hpp b/src/storage/v3/schema_validator.hpp new file mode 100644 index 000000000..a2da4609c --- /dev/null +++ b/src/storage/v3/schema_validator.hpp @@ -0,0 +1,69 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <optional> +#include <variant> + +#include "storage/v3/id_types.hpp" +#include "storage/v3/property_value.hpp" +#include "storage/v3/result.hpp" +#include "storage/v3/schemas.hpp" + +namespace memgraph::storage::v3 { + +struct SchemaViolation { + enum class ValidationStatus : uint8_t { + VERTEX_HAS_NO_PRIMARY_PROPERTY, + NO_SCHEMA_DEFINED_FOR_LABEL, + VERTEX_PROPERTY_WRONG_TYPE, + VERTEX_UPDATE_PRIMARY_KEY, + VERTEX_MODIFY_PRIMARY_LABEL, + VERTEX_SECONDARY_LABEL_IS_PRIMARY, + }; + + SchemaViolation(ValidationStatus status, LabelId label); + + SchemaViolation(ValidationStatus status, LabelId label, SchemaProperty violated_schema_property); + + SchemaViolation(ValidationStatus status, LabelId label, SchemaProperty violated_schema_property, + PropertyValue violated_property_value); + + friend bool operator==(const SchemaViolation &lhs, const SchemaViolation &rhs); + + ValidationStatus status; + LabelId label; + std::optional<SchemaProperty> violated_schema_property; + std::optional<PropertyValue> violated_property_value; +}; + +class 
SchemaValidator { + public: + explicit SchemaValidator(Schemas &schemas); + + [[nodiscard]] std::optional<SchemaViolation> ValidateVertexCreate( + LabelId primary_label, const std::vector<LabelId> &labels, + const std::vector<std::pair<PropertyId, PropertyValue>> &properties) const; + + [[nodiscard]] std::optional<SchemaViolation> ValidatePropertyUpdate(LabelId primary_label, + PropertyId property_id) const; + + [[nodiscard]] std::optional<SchemaViolation> ValidateLabelUpdate(LabelId label) const; + + private: + Schemas &schemas_; +}; + +template <typename TValue> +using ResultSchema = utils::BasicResult<std::variant<SchemaViolation, Error>, TValue>; + +} // namespace memgraph::storage::v3 diff --git a/src/storage/v3/schemas.cpp b/src/storage/v3/schemas.cpp new file mode 100644 index 000000000..2f89c80c0 --- /dev/null +++ b/src/storage/v3/schemas.cpp @@ -0,0 +1,112 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#include "storage/v3/schemas.hpp" + +#include <unordered_map> +#include <vector> + +#include "storage/v3/property_value.hpp" + +namespace memgraph::storage::v3 { + +bool operator==(const SchemaProperty &lhs, const SchemaProperty &rhs) { + return lhs.property_id == rhs.property_id && lhs.type == rhs.type; +} + +Schemas::SchemasList Schemas::ListSchemas() const { + Schemas::SchemasList ret; + ret.reserve(schemas_.size()); + std::transform(schemas_.begin(), schemas_.end(), std::back_inserter(ret), + [](const auto &schema_property_type) { return schema_property_type; }); + return ret; +} + +const Schemas::Schema *Schemas::GetSchema(const LabelId primary_label) const { + if (auto schema_map = schemas_.find(primary_label); schema_map != schemas_.end()) { + return &*schema_map; + } + return nullptr; +} + +bool Schemas::CreateSchema(const LabelId primary_label, const std::vector<SchemaProperty> &schemas_types) { + if (schemas_.contains(primary_label)) { + return false; + } + schemas_.emplace(primary_label, schemas_types); + return true; +} + +bool Schemas::DropSchema(const LabelId primary_label) { return schemas_.erase(primary_label); } + +std::optional<common::SchemaType> PropertyTypeToSchemaType(const PropertyValue &property_value) { + switch (property_value.type()) { + case PropertyValue::Type::Bool: { + return common::SchemaType::BOOL; + } + case PropertyValue::Type::Int: { + return common::SchemaType::INT; + } + case PropertyValue::Type::String: { + return common::SchemaType::STRING; + } + case PropertyValue::Type::TemporalData: { + switch (property_value.ValueTemporalData().type) { + case TemporalType::Date: { + return common::SchemaType::DATE; + } + case TemporalType::LocalDateTime: { + return common::SchemaType::LOCALDATETIME; + } + case TemporalType::LocalTime: { + return common::SchemaType::LOCALTIME; + } + case TemporalType::Duration: { + return common::SchemaType::DURATION; + } + } + } + case PropertyValue::Type::Double: + case PropertyValue::Type::Null: + 
case PropertyValue::Type::Map: + case PropertyValue::Type::List: { + return std::nullopt; + } + } +} + +std::string SchemaTypeToString(const common::SchemaType type) { + switch (type) { + case common::SchemaType::BOOL: { + return "Bool"; + } + case common::SchemaType::INT: { + return "Integer"; + } + case common::SchemaType::STRING: { + return "String"; + } + case common::SchemaType::DATE: { + return "Date"; + } + case common::SchemaType::LOCALTIME: { + return "LocalTime"; + } + case common::SchemaType::LOCALDATETIME: { + return "LocalDateTime"; + } + case common::SchemaType::DURATION: { + return "Duration"; + } + } +} + +} // namespace memgraph::storage::v3 diff --git a/src/storage/v3/schemas.hpp b/src/storage/v3/schemas.hpp new file mode 100644 index 000000000..157ee7a35 --- /dev/null +++ b/src/storage/v3/schemas.hpp @@ -0,0 +1,70 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#pragma once + +#include <memory> +#include <optional> +#include <unordered_map> +#include <utility> +#include <vector> + +#include "common/types.hpp" +#include "storage/v3/id_types.hpp" +#include "storage/v3/property_value.hpp" +#include "storage/v3/temporal.hpp" +#include "utils/result.hpp" + +namespace memgraph::storage::v3 { + +struct SchemaProperty { + PropertyId property_id; + common::SchemaType type; + + friend bool operator==(const SchemaProperty &lhs, const SchemaProperty &rhs); +}; + +/// Structure that represents a collection of schemas +/// Schema can be mapped under only one label => primary label +class Schemas { + public: + using SchemasMap = std::unordered_map<LabelId, std::vector<SchemaProperty>>; + using Schema = SchemasMap::value_type; + using SchemasList = std::vector<Schema>; + + Schemas() = default; + Schemas(const Schemas &) = delete; + Schemas(Schemas &&) = delete; + Schemas &operator=(const Schemas &) = delete; + Schemas &operator=(Schemas &&) = delete; + ~Schemas() = default; + + [[nodiscard]] SchemasList ListSchemas() const; + + [[nodiscard]] const Schema *GetSchema(LabelId primary_label) const; + + // Returns true if it was successfully created or false if the schema + // already exists + [[nodiscard]] bool CreateSchema(LabelId label, const std::vector<SchemaProperty> &schemas_types); + + // Returns true if it was successfully dropped or false if the schema + // does not exist + [[nodiscard]] bool DropSchema(LabelId label); + + private: + SchemasMap schemas_; +}; + +std::optional<common::SchemaType> PropertyTypeToSchemaType(const PropertyValue &property_value); + +std::string SchemaTypeToString(common::SchemaType type); + +} // namespace memgraph::storage::v3 diff --git a/src/storage/v3/storage.cpp b/src/storage/v3/storage.cpp index 5ab334893..28fd250fb 100644 --- a/src/storage/v3/storage.cpp +++ b/src/storage/v3/storage.cpp @@ -15,9 +15,11 @@ #include <atomic> #include <memory> #include <mutex> +#include <optional> #include <variant> 
#include <gflags/gflags.h> +#include <spdlog/spdlog.h> #include "io/network/endpoint.hpp" #include "storage/v3/constraints.hpp" @@ -27,6 +29,7 @@ #include "storage/v3/durability/snapshot.hpp" #include "storage/v3/durability/wal.hpp" #include "storage/v3/edge_accessor.hpp" +#include "storage/v3/id_types.hpp" #include "storage/v3/indices.hpp" #include "storage/v3/mvcc.hpp" #include "storage/v3/replication/config.hpp" @@ -35,10 +38,12 @@ #include "storage/v3/replication/rpc.hpp" #include "storage/v3/transaction.hpp" #include "storage/v3/vertex_accessor.hpp" +#include "utils/exceptions.hpp" #include "utils/file.hpp" #include "utils/logging.hpp" #include "utils/memory_tracker.hpp" #include "utils/message.hpp" +#include "utils/result.hpp" #include "utils/rw_lock.hpp" #include "utils/spin_lock.hpp" #include "utils/stat.hpp" @@ -54,9 +59,9 @@ inline constexpr uint16_t kEpochHistoryRetention = 1000; auto AdvanceToVisibleVertex(utils::SkipList<Vertex>::Iterator it, utils::SkipList<Vertex>::Iterator end, std::optional<VertexAccessor> *vertex, Transaction *tx, View view, Indices *indices, - Constraints *constraints, Config::Items config) { + Constraints *constraints, Config::Items config, const SchemaValidator &schema_validator) { while (it != end) { - *vertex = VertexAccessor::Create(&*it, tx, indices, constraints, config, view); + *vertex = VertexAccessor::Create(&*it, tx, indices, constraints, config, schema_validator, view); if (!*vertex) { ++it; continue; @@ -69,14 +74,14 @@ auto AdvanceToVisibleVertex(utils::SkipList<Vertex>::Iterator it, utils::SkipLis AllVerticesIterable::Iterator::Iterator(AllVerticesIterable *self, utils::SkipList<Vertex>::Iterator it) : self_(self), it_(AdvanceToVisibleVertex(it, self->vertices_accessor_.end(), &self->vertex_, self->transaction_, self->view_, - self->indices_, self_->constraints_, self->config_)) {} + self->indices_, self_->constraints_, self->config_, *self->schema_validator_)) {} VertexAccessor 
AllVerticesIterable::Iterator::operator*() const { return *self_->vertex_; } AllVerticesIterable::Iterator &AllVerticesIterable::Iterator::operator++() { ++it_; it_ = AdvanceToVisibleVertex(it_, self_->vertices_accessor_.end(), &self_->vertex_, self_->transaction_, self_->view_, - self_->indices_, self_->constraints_, self_->config_); + self_->indices_, self_->constraints_, self_->config_, *self_->schema_validator_); return *this; } @@ -300,7 +305,8 @@ bool VerticesIterable::Iterator::operator==(const Iterator &other) const { } Storage::Storage(Config config) - : indices_(&constraints_, config.items), + : schema_validator_(schemas_), + indices_(&constraints_, config.items, schema_validator_), isolation_level_(config.transaction.isolation_level), config_(config), snapshot_directory_(config_.durability.storage_directory / durability::kSnapshotDirectory), @@ -462,7 +468,8 @@ Storage::Accessor::~Accessor() { FinalizeTransaction(); } -VertexAccessor Storage::Accessor::CreateVertex() { +// TODO Remove when import csv is fixed +[[deprecated]] VertexAccessor Storage::Accessor::CreateVertex() { OOMExceptionEnabler oom_exception; auto gid = storage_->vertex_id_.fetch_add(1, std::memory_order_acq_rel); auto acc = storage_->vertices_.access(); @@ -470,10 +477,12 @@ VertexAccessor Storage::Accessor::CreateVertex() { auto [it, inserted] = acc.insert(Vertex{Gid::FromUint(gid), delta}); MG_ASSERT(inserted, "The vertex must be inserted here!"); MG_ASSERT(it != acc.end(), "Invalid Vertex accessor!"); + delta->prev.Set(&*it); - return {&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_}; + return {&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, storage_->schema_validator_}; } +// TODO Remove when replication is fixed VertexAccessor Storage::Accessor::CreateVertex(Gid gid) { OOMExceptionEnabler oom_exception; // NOTE: When we update the next `vertex_id_` here we perform a RMW @@ -486,18 +495,53 @@ VertexAccessor 
Storage::Accessor::CreateVertex(Gid gid) { std::memory_order_release); auto acc = storage_->vertices_.access(); auto *delta = CreateDeleteObjectDelta(&transaction_); - auto [it, inserted] = acc.insert(Vertex{gid, delta}); + auto [it, inserted] = acc.insert(Vertex{gid}); MG_ASSERT(inserted, "The vertex must be inserted here!"); MG_ASSERT(it != acc.end(), "Invalid Vertex accessor!"); delta->prev.Set(&*it); - return {&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_}; + return {&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, storage_->schema_validator_}; +} + +ResultSchema<VertexAccessor> Storage::Accessor::CreateVertexAndValidate( + LabelId primary_label, const std::vector<LabelId> &labels, + const std::vector<std::pair<PropertyId, PropertyValue>> &properties) { + auto maybe_schema_violation = GetSchemaValidator().ValidateVertexCreate(primary_label, labels, properties); + if (maybe_schema_violation) { + return {std::move(*maybe_schema_violation)}; + } + OOMExceptionEnabler oom_exception; + auto gid = storage_->vertex_id_.fetch_add(1, std::memory_order_acq_rel); + auto acc = storage_->vertices_.access(); + auto *delta = CreateDeleteObjectDelta(&transaction_); + auto [it, inserted] = acc.insert(Vertex{Gid::FromUint(gid), delta, primary_label}); + MG_ASSERT(inserted, "The vertex must be inserted here!"); + MG_ASSERT(it != acc.end(), "Invalid Vertex accessor!"); + delta->prev.Set(&*it); + + auto va = VertexAccessor{ + &*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, storage_->schema_validator_}; + for (const auto label : labels) { + const auto maybe_error = va.AddLabel(label); + if (maybe_error.HasError()) { + return {maybe_error.GetError()}; + } + } + // Set properties + for (auto [property_id, property_value] : properties) { + const auto maybe_error = va.SetProperty(property_id, property_value); + if (maybe_error.HasError()) { + return {maybe_error.GetError()}; + } + } + return va; } 
std::optional<VertexAccessor> Storage::Accessor::FindVertex(Gid gid, View view) { auto acc = storage_->vertices_.access(); auto it = acc.find(gid); if (it == acc.end()) return std::nullopt; - return VertexAccessor::Create(&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, view); + return VertexAccessor::Create(&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, + storage_->schema_validator_, view); } Result<std::optional<VertexAccessor>> Storage::Accessor::DeleteVertex(VertexAccessor *vertex) { @@ -520,7 +564,7 @@ Result<std::optional<VertexAccessor>> Storage::Accessor::DeleteVertex(VertexAcce vertex_ptr->deleted = true; return std::make_optional<VertexAccessor>(vertex_ptr, &transaction_, &storage_->indices_, &storage_->constraints_, - config_, true); + config_, storage_->schema_validator_, true); } Result<std::optional<std::pair<VertexAccessor, std::vector<EdgeAccessor>>>> Storage::Accessor::DetachDeleteVertex( @@ -550,7 +594,7 @@ Result<std::optional<std::pair<VertexAccessor, std::vector<EdgeAccessor>>>> Stor for (const auto &item : in_edges) { auto [edge_type, from_vertex, edge] = item; EdgeAccessor e(edge, edge_type, from_vertex, vertex_ptr, &transaction_, &storage_->indices_, - &storage_->constraints_, config_); + &storage_->constraints_, config_, storage_->schema_validator_); auto ret = DeleteEdge(&e); if (ret.HasError()) { MG_ASSERT(ret.GetError() == Error::SERIALIZATION_ERROR, "Invalid database state!"); @@ -564,7 +608,7 @@ Result<std::optional<std::pair<VertexAccessor, std::vector<EdgeAccessor>>>> Stor for (const auto &item : out_edges) { auto [edge_type, to_vertex, edge] = item; EdgeAccessor e(edge, edge_type, vertex_ptr, to_vertex, &transaction_, &storage_->indices_, &storage_->constraints_, - config_); + config_, storage_->schema_validator_); auto ret = DeleteEdge(&e); if (ret.HasError()) { MG_ASSERT(ret.GetError() == Error::SERIALIZATION_ERROR, "Invalid database state!"); @@ -590,7 +634,8 @@ 
Result<std::optional<std::pair<VertexAccessor, std::vector<EdgeAccessor>>>> Stor vertex_ptr->deleted = true; return std::make_optional<ReturnType>( - VertexAccessor{vertex_ptr, &transaction_, &storage_->indices_, &storage_->constraints_, config_, true}, + VertexAccessor{vertex_ptr, &transaction_, &storage_->indices_, &storage_->constraints_, config_, + storage_->schema_validator_, true}, std::move(deleted_edges)); } @@ -650,7 +695,7 @@ Result<EdgeAccessor> Storage::Accessor::CreateEdge(VertexAccessor *from, VertexA storage_->edge_count_.fetch_add(1, std::memory_order_acq_rel); return EdgeAccessor(edge, edge_type, from_vertex, to_vertex, &transaction_, &storage_->indices_, - &storage_->constraints_, config_); + &storage_->constraints_, config_, storage_->schema_validator_); } Result<EdgeAccessor> Storage::Accessor::CreateEdge(VertexAccessor *from, VertexAccessor *to, EdgeTypeId edge_type, @@ -718,7 +763,7 @@ Result<EdgeAccessor> Storage::Accessor::CreateEdge(VertexAccessor *from, VertexA storage_->edge_count_.fetch_add(1, std::memory_order_acq_rel); return EdgeAccessor(edge, edge_type, from_vertex, to_vertex, &transaction_, &storage_->indices_, - &storage_->constraints_, config_); + &storage_->constraints_, config_, storage_->schema_validator_); } Result<std::optional<EdgeAccessor>> Storage::Accessor::DeleteEdge(EdgeAccessor *edge) { @@ -802,7 +847,8 @@ Result<std::optional<EdgeAccessor>> Storage::Accessor::DeleteEdge(EdgeAccessor * storage_->edge_count_.fetch_add(-1, std::memory_order_acq_rel); return std::make_optional<EdgeAccessor>(edge_ref, edge_type, from_vertex, to_vertex, &transaction_, - &storage_->indices_, &storage_->constraints_, config_, true); + &storage_->indices_, &storage_->constraints_, config_, + storage_->schema_validator_, true); } const std::string &Storage::Accessor::LabelToName(LabelId label) const { return storage_->LabelToName(label); } @@ -815,11 +861,11 @@ const std::string &Storage::Accessor::EdgeTypeToName(EdgeTypeId edge_type) const 
return storage_->EdgeTypeToName(edge_type); } -LabelId Storage::Accessor::NameToLabel(const std::string_view &name) { return storage_->NameToLabel(name); } +LabelId Storage::Accessor::NameToLabel(const std::string_view name) { return storage_->NameToLabel(name); } -PropertyId Storage::Accessor::NameToProperty(const std::string_view &name) { return storage_->NameToProperty(name); } +PropertyId Storage::Accessor::NameToProperty(const std::string_view name) { return storage_->NameToProperty(name); } -EdgeTypeId Storage::Accessor::NameToEdgeType(const std::string_view &name) { return storage_->NameToEdgeType(name); } +EdgeTypeId Storage::Accessor::NameToEdgeType(const std::string_view name) { return storage_->NameToEdgeType(name); } void Storage::Accessor::AdvanceCommand() { ++transaction_.command_id; } @@ -846,11 +892,11 @@ utils::BasicResult<ConstraintViolation, void> Storage::Accessor::Commit( auto validation_result = ValidateExistenceConstraints(*prev.vertex, storage_->constraints_); if (validation_result) { Abort(); - return *validation_result; + return {*validation_result}; } } - // Result of validating the vertex against unqiue constraints. It has to be + // Result of validating the vertex against unique constraints. It has to be // declared outside of the critical section scope because its value is // tested for Abort call which has to be done out of the scope. 
std::optional<ConstraintViolation> unique_constraint_violation; @@ -931,7 +977,7 @@ utils::BasicResult<ConstraintViolation, void> Storage::Accessor::Commit( if (unique_constraint_violation) { Abort(); - return *unique_constraint_violation; + return {*unique_constraint_violation}; } } is_transaction_active_ = false; @@ -1124,13 +1170,13 @@ const std::string &Storage::EdgeTypeToName(EdgeTypeId edge_type) const { return name_id_mapper_.IdToName(edge_type.AsUint()); } -LabelId Storage::NameToLabel(const std::string_view &name) { return LabelId::FromUint(name_id_mapper_.NameToId(name)); } +LabelId Storage::NameToLabel(const std::string_view name) { return LabelId::FromUint(name_id_mapper_.NameToId(name)); } -PropertyId Storage::NameToProperty(const std::string_view &name) { +PropertyId Storage::NameToProperty(const std::string_view name) { return PropertyId::FromUint(name_id_mapper_.NameToId(name)); } -EdgeTypeId Storage::NameToEdgeType(const std::string_view &name) { +EdgeTypeId Storage::NameToEdgeType(const std::string_view name) { return EdgeTypeId::FromUint(name_id_mapper_.NameToId(name)); } @@ -1232,11 +1278,29 @@ UniqueConstraints::DeletionStatus Storage::DropUniqueConstraint( return UniqueConstraints::DeletionStatus::SUCCESS; } +const SchemaValidator &Storage::Accessor::GetSchemaValidator() const { return storage_->schema_validator_; } + ConstraintsInfo Storage::ListAllConstraints() const { std::shared_lock<utils::RWLock> storage_guard_(main_lock_); return {ListExistenceConstraints(constraints_), constraints_.unique_constraints.ListConstraints()}; } +SchemasInfo Storage::ListAllSchemas() const { + std::shared_lock<utils::RWLock> storage_guard_(main_lock_); + return {schemas_.ListSchemas()}; +} + +const Schemas::Schema *Storage::GetSchema(const LabelId primary_label) const { + std::shared_lock<utils::RWLock> storage_guard_(main_lock_); + return schemas_.GetSchema(primary_label); +} + +bool Storage::CreateSchema(const LabelId primary_label, const 
std::vector<SchemaProperty> &schemas_types) { + return schemas_.CreateSchema(primary_label, schemas_types); +} + +bool Storage::DropSchema(const LabelId primary_label) { return schemas_.DropSchema(primary_label); } + StorageInfo Storage::GetInfo() const { auto vertex_count = vertices_.size(); auto edge_count = edge_count_.load(std::memory_order_acquire); @@ -1253,21 +1317,22 @@ VerticesIterable Storage::Accessor::Vertices(LabelId label, View view) { } VerticesIterable Storage::Accessor::Vertices(LabelId label, PropertyId property, View view) { - return VerticesIterable(storage_->indices_.label_property_index.Vertices(label, property, std::nullopt, std::nullopt, - view, &transaction_)); + return VerticesIterable(storage_->indices_.label_property_index.Vertices( + label, property, std::nullopt, std::nullopt, view, &transaction_, storage_->schema_validator_)); } VerticesIterable Storage::Accessor::Vertices(LabelId label, PropertyId property, const PropertyValue &value, View view) { return VerticesIterable(storage_->indices_.label_property_index.Vertices( - label, property, utils::MakeBoundInclusive(value), utils::MakeBoundInclusive(value), view, &transaction_)); + label, property, utils::MakeBoundInclusive(value), utils::MakeBoundInclusive(value), view, &transaction_, + storage_->schema_validator_)); } VerticesIterable Storage::Accessor::Vertices(LabelId label, PropertyId property, const std::optional<utils::Bound<PropertyValue>> &lower_bound, const std::optional<utils::Bound<PropertyValue>> &upper_bound, View view) { - return VerticesIterable( - storage_->indices_.label_property_index.Vertices(label, property, lower_bound, upper_bound, view, &transaction_)); + return VerticesIterable(storage_->indices_.label_property_index.Vertices( + label, property, lower_bound, upper_bound, view, &transaction_, storage_->schema_validator_)); } Transaction Storage::CreateTransaction(IsolationLevel isolation_level) { @@ -1795,8 +1860,8 @@ 
utils::BasicResult<Storage::CreateSnapshotError> Storage::CreateSnapshot() { // Create snapshot. durability::CreateSnapshot(&transaction, snapshot_directory_, wal_directory_, config_.durability.snapshot_retention_count, &vertices_, &edges_, &name_id_mapper_, - &indices_, &constraints_, config_.items, uuid_, epoch_id_, epoch_history_, - &file_retainer_); + &indices_, &constraints_, config_.items, schema_validator_, uuid_, epoch_id_, + epoch_history_, &file_retainer_); // Finalize snapshot transaction. commit_log_->MarkFinished(transaction.start_timestamp); diff --git a/src/storage/v3/storage.hpp b/src/storage/v3/storage.hpp index 40fe708b1..93a6687ce 100644 --- a/src/storage/v3/storage.hpp +++ b/src/storage/v3/storage.hpp @@ -16,8 +16,10 @@ #include <optional> #include <shared_mutex> #include <variant> +#include <vector> #include "io/network/endpoint.hpp" +#include "kvstore/kvstore.hpp" #include "storage/v3/commit_log.hpp" #include "storage/v3/config.hpp" #include "storage/v3/constraints.hpp" @@ -25,14 +27,19 @@ #include "storage/v3/durability/wal.hpp" #include "storage/v3/edge.hpp" #include "storage/v3/edge_accessor.hpp" +#include "storage/v3/id_types.hpp" #include "storage/v3/indices.hpp" #include "storage/v3/isolation_level.hpp" #include "storage/v3/mvcc.hpp" #include "storage/v3/name_id_mapper.hpp" +#include "storage/v3/property_value.hpp" #include "storage/v3/result.hpp" +#include "storage/v3/schema_validator.hpp" +#include "storage/v3/schemas.hpp" #include "storage/v3/transaction.hpp" #include "storage/v3/vertex.hpp" #include "storage/v3/vertex_accessor.hpp" +#include "utils/exceptions.hpp" #include "utils/file_locker.hpp" #include "utils/on_scope_exit.hpp" #include "utils/rw_lock.hpp" @@ -66,6 +73,7 @@ class AllVerticesIterable final { Indices *indices_; Constraints *constraints_; Config::Items config_; + const SchemaValidator *schema_validator_; std::optional<VertexAccessor> vertex_; public: @@ -86,13 +94,15 @@ class AllVerticesIterable final { }; 
AllVerticesIterable(utils::SkipList<Vertex>::Accessor vertices_accessor, Transaction *transaction, View view, - Indices *indices, Constraints *constraints, Config::Items config) + Indices *indices, Constraints *constraints, Config::Items config, + SchemaValidator *schema_validator) : vertices_accessor_(std::move(vertices_accessor)), transaction_(transaction), view_(view), indices_(indices), constraints_(constraints), - config_(config) {} + config_(config), + schema_validator_(schema_validator) {} Iterator begin() { return {this, vertices_accessor_.begin()}; } Iterator end() { return {this, vertices_accessor_.end()}; } @@ -173,6 +183,11 @@ struct ConstraintsInfo { std::vector<std::pair<LabelId, std::set<PropertyId>>> unique; }; +/// Structure used to return information about existing schemas in the storage +struct SchemasInfo { + Schemas::SchemasList schemas; +}; + /// Structure used to return information about the storage. struct StorageInfo { uint64_t vertex_count; @@ -190,10 +205,6 @@ class Storage final { /// @throw std::bad_alloc explicit Storage(Config config = Config()); - Storage(const Storage &) = delete; - Storage(Storage &&) = delete; - Storage &operator=(const Storage &) = delete; - Storage &operator=(Storage &&) = delete; ~Storage(); class Accessor final { @@ -213,15 +224,21 @@ class Storage final { ~Accessor(); - /// @throw std::bad_alloc VertexAccessor CreateVertex(); + VertexAccessor CreateVertex(Gid gid); + + /// @throw std::bad_alloc + ResultSchema<VertexAccessor> CreateVertexAndValidate( + LabelId primary_label, const std::vector<LabelId> &labels, + const std::vector<std::pair<PropertyId, PropertyValue>> &properties); + std::optional<VertexAccessor> FindVertex(Gid gid, View view); VerticesIterable Vertices(View view) { return VerticesIterable(AllVerticesIterable(storage_->vertices_.access(), &transaction_, view, - &storage_->indices_, &storage_->constraints_, - storage_->config_.items)); + &storage_->indices_, &storage_->constraints_, 
storage_->config_.items, + &storage_->schema_validator_)); } VerticesIterable Vertices(LabelId label, View view); @@ -287,13 +304,13 @@ class Storage final { const std::string &EdgeTypeToName(EdgeTypeId edge_type) const; /// @throw std::bad_alloc if unable to insert a new mapping - LabelId NameToLabel(const std::string_view &name); + LabelId NameToLabel(std::string_view name); /// @throw std::bad_alloc if unable to insert a new mapping - PropertyId NameToProperty(const std::string_view &name); + PropertyId NameToProperty(std::string_view name); /// @throw std::bad_alloc if unable to insert a new mapping - EdgeTypeId NameToEdgeType(const std::string_view &name); + EdgeTypeId NameToEdgeType(std::string_view name); bool LabelIndexExists(LabelId label) const { return storage_->indices_.label_index.IndexExists(label); } @@ -310,6 +327,10 @@ class Storage final { storage_->constraints_.unique_constraints.ListConstraints()}; } + const SchemaValidator &GetSchemaValidator() const; + + SchemasInfo ListAllSchemas() const { return {storage_->schemas_.ListSchemas()}; } + void AdvanceCommand(); /// Commit returns `ConstraintViolation` if the changes made by this @@ -325,7 +346,7 @@ class Storage final { private: /// @throw std::bad_alloc - VertexAccessor CreateVertex(Gid gid); + VertexAccessor CreateVertex(Gid gid, LabelId primary_label); /// @throw std::bad_alloc Result<EdgeAccessor> CreateEdge(VertexAccessor *from, VertexAccessor *to, EdgeTypeId edge_type, Gid gid); @@ -347,13 +368,13 @@ class Storage final { const std::string &EdgeTypeToName(EdgeTypeId edge_type) const; /// @throw std::bad_alloc if unable to insert a new mapping - LabelId NameToLabel(const std::string_view &name); + LabelId NameToLabel(std::string_view name); /// @throw std::bad_alloc if unable to insert a new mapping - PropertyId NameToProperty(const std::string_view &name); + PropertyId NameToProperty(std::string_view name); /// @throw std::bad_alloc if unable to insert a new mapping - EdgeTypeId 
NameToEdgeType(const std::string_view &name); + EdgeTypeId NameToEdgeType(std::string_view name); /// @throw std::bad_alloc bool CreateIndex(LabelId label, std::optional<uint64_t> desired_commit_timestamp = {}); @@ -368,7 +389,7 @@ class Storage final { IndicesInfo ListAllIndices() const; /// Creates an existence constraint. Returns true if the constraint was - /// successfuly added, false if it already exists and a `ConstraintViolation` + /// successfully added, false if it already exists and a `ConstraintViolation` /// if there is an existing vertex violating the constraint. /// /// @throw std::bad_alloc @@ -406,6 +427,14 @@ class Storage final { ConstraintsInfo ListAllConstraints() const; + SchemasInfo ListAllSchemas() const; + + const Schemas::Schema *GetSchema(LabelId primary_label) const; + + bool CreateSchema(LabelId primary_label, const std::vector<SchemaProperty> &schemas_types); + + bool DropSchema(LabelId primary_label); + StorageInfo GetInfo() const; bool LockPath(); @@ -415,7 +444,12 @@ class Storage final { bool SetMainReplicationRole(); - enum class RegisterReplicaError : uint8_t { NAME_EXISTS, END_POINT_EXISTS, CONNECTION_FAILED }; + enum class RegisterReplicaError : uint8_t { + NAME_EXISTS, + END_POINT_EXISTS, + CONNECTION_FAILED, + COULD_NOT_BE_PERSISTED + }; /// @pre The instance should have a MAIN role /// @pre Timeout can only be set for SYNC replication @@ -493,8 +527,10 @@ class Storage final { NameIdMapper name_id_mapper_; + SchemaValidator schema_validator_; Constraints constraints_; Indices indices_; + Schemas schemas_; // Transaction engine utils::SpinLock engine_lock_; diff --git a/src/storage/v3/vertex.hpp b/src/storage/v3/vertex.hpp index 3bf1bbb49..618d23849 100644 --- a/src/storage/v3/vertex.hpp +++ b/src/storage/v3/vertex.hpp @@ -19,18 +19,39 @@ #include "storage/v3/edge_ref.hpp" #include "storage/v3/id_types.hpp" #include "storage/v3/property_store.hpp" +#include "utils/algorithm.hpp" #include "utils/spin_lock.hpp" namespace 
memgraph::storage::v3 { struct Vertex { + Vertex(Gid gid, Delta *delta, LabelId primary_label) + : gid(gid), primary_label{primary_label}, deleted(false), delta(delta) { + MG_ASSERT(delta == nullptr || delta->action == Delta::Action::DELETE_OBJECT, + "Vertex must be created with an initial DELETE_OBJECT delta!"); + } + + // TODO remove this when import replication is solved + Vertex(Gid gid, LabelId primary_label) : gid(gid), primary_label{primary_label}, deleted(false) { + MG_ASSERT(delta == nullptr || delta->action == Delta::Action::DELETE_OBJECT, + "Vertex must be created with an initial DELETE_OBJECT delta!"); + } + + // TODO remove this when import csv is solved Vertex(Gid gid, Delta *delta) : gid(gid), deleted(false), delta(delta) { MG_ASSERT(delta == nullptr || delta->action == Delta::Action::DELETE_OBJECT, "Vertex must be created with an initial DELETE_OBJECT delta!"); } + // TODO remove this when import replication is solved + explicit Vertex(Gid gid) : gid(gid), deleted(false) { + MG_ASSERT(delta == nullptr || delta->action == Delta::Action::DELETE_OBJECT, + "Vertex must be created with an initial DELETE_OBJECT delta!"); + } + Gid gid; + LabelId primary_label; std::vector<LabelId> labels; PropertyStore properties; @@ -52,4 +73,8 @@ inline bool operator<(const Vertex &first, const Vertex &second) { return first. 
inline bool operator==(const Vertex &first, const Gid &second) { return first.gid == second; } inline bool operator<(const Vertex &first, const Gid &second) { return first.gid < second; } +inline bool VertexHasLabel(const Vertex &vertex, const LabelId label) { + return vertex.primary_label == label || utils::Contains(vertex.labels, label); +} + } // namespace memgraph::storage::v3 diff --git a/src/storage/v3/vertex_accessor.cpp b/src/storage/v3/vertex_accessor.cpp index f9f60fa5e..867537d47 100644 --- a/src/storage/v3/vertex_accessor.cpp +++ b/src/storage/v3/vertex_accessor.cpp @@ -18,6 +18,8 @@ #include "storage/v3/indices.hpp" #include "storage/v3/mvcc.hpp" #include "storage/v3/property_value.hpp" +#include "storage/v3/schema_validator.hpp" +#include "storage/v3/vertex.hpp" #include "utils/logging.hpp" #include "utils/memory_tracker.hpp" @@ -61,12 +63,13 @@ std::pair<bool, bool> IsVisible(Vertex *vertex, Transaction *transaction, View v } // namespace detail std::optional<VertexAccessor> VertexAccessor::Create(Vertex *vertex, Transaction *transaction, Indices *indices, - Constraints *constraints, Config::Items config, View view) { + Constraints *constraints, Config::Items config, + const SchemaValidator &schema_validator, View view) { if (const auto [exists, deleted] = detail::IsVisible(vertex, transaction, view); !exists || deleted) { return std::nullopt; } - return VertexAccessor{vertex, transaction, indices, constraints, config}; + return VertexAccessor{vertex, transaction, indices, constraints, config, schema_validator}; } bool VertexAccessor::IsVisible(View view) const { @@ -93,6 +96,28 @@ Result<bool> VertexAccessor::AddLabel(LabelId label) { return true; } +ResultSchema<bool> VertexAccessor::AddLabelAndValidate(LabelId label) { + if (const auto maybe_violation_error = vertex_validator_.ValidateAddLabel(label); maybe_violation_error) { + return {*maybe_violation_error}; + } + utils::MemoryTracker::OutOfMemoryExceptionEnabler oom_exception; + 
std::lock_guard<utils::SpinLock> guard(vertex_->lock); + + if (!PrepareForWrite(transaction_, vertex_)) return {Error::SERIALIZATION_ERROR}; + + if (vertex_->deleted) return {Error::DELETED_OBJECT}; + + if (std::find(vertex_->labels.begin(), vertex_->labels.end(), label) != vertex_->labels.end()) return false; + + CreateAndLinkDelta(transaction_, vertex_, Delta::RemoveLabelTag(), label); + + vertex_->labels.push_back(label); + + UpdateOnAddLabel(indices_, label, vertex_, *transaction_); + + return true; +} + Result<bool> VertexAccessor::RemoveLabel(LabelId label) { std::lock_guard<utils::SpinLock> guard(vertex_->lock); @@ -110,6 +135,26 @@ Result<bool> VertexAccessor::RemoveLabel(LabelId label) { return true; } +ResultSchema<bool> VertexAccessor::RemoveLabelAndValidate(LabelId label) { + if (const auto maybe_violation_error = vertex_validator_.ValidateRemoveLabel(label); maybe_violation_error) { + return {*maybe_violation_error}; + } + std::lock_guard<utils::SpinLock> guard(vertex_->lock); + + if (!PrepareForWrite(transaction_, vertex_)) return {Error::SERIALIZATION_ERROR}; + + if (vertex_->deleted) return {Error::DELETED_OBJECT}; + + auto it = std::find(vertex_->labels.begin(), vertex_->labels.end(), label); + if (it == vertex_->labels.end()) return false; + + CreateAndLinkDelta(transaction_, vertex_, Delta::AddLabelTag(), label); + + std::swap(*it, *vertex_->labels.rbegin()); + vertex_->labels.pop_back(); + return true; +} + Result<bool> VertexAccessor::HasLabel(LabelId label, View view) const { bool exists = true; bool deleted = false; @@ -118,7 +163,7 @@ Result<bool> VertexAccessor::HasLabel(LabelId label, View view) const { { std::lock_guard<utils::SpinLock> guard(vertex_->lock); deleted = vertex_->deleted; - has_label = std::find(vertex_->labels.begin(), vertex_->labels.end(), label) != vertex_->labels.end(); + has_label = VertexHasLabel(*vertex_, label); delta = vertex_->delta; } ApplyDeltasForRead(transaction_, delta, view, [&exists, &deleted, &has_label, 
label](const Delta &delta) { @@ -158,6 +203,40 @@ Result<bool> VertexAccessor::HasLabel(LabelId label, View view) const { return has_label; } +Result<LabelId> VertexAccessor::PrimaryLabel(const View view) const { + bool exists = true; + bool deleted = false; + Delta *delta = nullptr; + { + std::lock_guard<utils::SpinLock> guard(vertex_->lock); + deleted = vertex_->deleted; + delta = vertex_->delta; + } + ApplyDeltasForRead(transaction_, delta, view, [&exists, &deleted](const Delta &delta) { + switch (delta.action) { + case Delta::Action::DELETE_OBJECT: { + exists = false; + break; + } + case Delta::Action::RECREATE_OBJECT: { + deleted = false; + break; + } + case Delta::Action::ADD_LABEL: + case Delta::Action::REMOVE_LABEL: + case Delta::Action::SET_PROPERTY: + case Delta::Action::ADD_IN_EDGE: + case Delta::Action::ADD_OUT_EDGE: + case Delta::Action::REMOVE_IN_EDGE: + case Delta::Action::REMOVE_OUT_EDGE: + break; + } + }); + if (!exists) return Error::NONEXISTENT_OBJECT; + if (!for_deleted_ && deleted) return Error::DELETED_OBJECT; + return vertex_->primary_label; +} + Result<std::vector<LabelId>> VertexAccessor::Labels(View view) const { bool exists = true; bool deleted = false; @@ -230,6 +309,36 @@ Result<PropertyValue> VertexAccessor::SetProperty(PropertyId property, const Pro return std::move(current_value); } +ResultSchema<PropertyValue> VertexAccessor::SetPropertyAndValidate(PropertyId property, const PropertyValue &value) { + if (auto maybe_violation_error = vertex_validator_.ValidatePropertyUpdate(property); maybe_violation_error) { + return {*maybe_violation_error}; + } + utils::MemoryTracker::OutOfMemoryExceptionEnabler oom_exception; + std::lock_guard<utils::SpinLock> guard(vertex_->lock); + + if (!PrepareForWrite(transaction_, vertex_)) { + return {Error::SERIALIZATION_ERROR}; + } + + if (vertex_->deleted) { + return {Error::DELETED_OBJECT}; + } + + auto current_value = vertex_->properties.GetProperty(property); + // We could skip setting the value if 
the previous one is the same to the new + // one. This would save some memory as a delta would not be created as well as + // avoid copying the value. The reason we are not doing that is because the + // current code always follows the logical pattern of "create a delta" and + // "modify in-place". Additionally, the created delta will make other + // transactions get a SERIALIZATION_ERROR. + CreateAndLinkDelta(transaction_, vertex_, Delta::SetPropertyTag(), property, current_value); + vertex_->properties.SetProperty(property, value); + + UpdateOnSetProperty(indices_, property, value, vertex_, *transaction_); + + return std::move(current_value); +} + Result<std::map<PropertyId, PropertyValue>> VertexAccessor::ClearProperties() { std::lock_guard<utils::SpinLock> guard(vertex_->lock); @@ -414,7 +523,8 @@ Result<std::vector<EdgeAccessor>> VertexAccessor::InEdges(View view, const std:: ret.reserve(in_edges.size()); for (const auto &item : in_edges) { const auto &[edge_type, from_vertex, edge] = item; - ret.emplace_back(edge, edge_type, from_vertex, vertex_, transaction_, indices_, constraints_, config_); + ret.emplace_back(edge, edge_type, from_vertex, vertex_, transaction_, indices_, constraints_, config_, + *vertex_validator_.schema_validator); } return std::move(ret); } @@ -494,7 +604,8 @@ Result<std::vector<EdgeAccessor>> VertexAccessor::OutEdges(View view, const std: ret.reserve(out_edges.size()); for (const auto &item : out_edges) { const auto &[edge_type, to_vertex, edge] = item; - ret.emplace_back(edge, edge_type, vertex_, to_vertex, transaction_, indices_, constraints_, config_); + ret.emplace_back(edge, edge_type, vertex_, to_vertex, transaction_, indices_, constraints_, config_, + *vertex_validator_.schema_validator); } return std::move(ret); } @@ -575,4 +686,21 @@ Result<size_t> VertexAccessor::OutDegree(View view) const { return degree; } +VertexAccessor::VertexValidator::VertexValidator(const SchemaValidator &schema_validator, const Vertex *vertex) + : 
schema_validator{&schema_validator}, vertex_{vertex} {} + +[[nodiscard]] std::optional<SchemaViolation> VertexAccessor::VertexValidator::ValidatePropertyUpdate( + PropertyId property_id) const { + MG_ASSERT(vertex_ != nullptr, "Cannot validate vertex which is nullptr"); + return schema_validator->ValidatePropertyUpdate(vertex_->primary_label, property_id); +}; + +[[nodiscard]] std::optional<SchemaViolation> VertexAccessor::VertexValidator::ValidateAddLabel(LabelId label) const { + return schema_validator->ValidateLabelUpdate(label); +} + +[[nodiscard]] std::optional<SchemaViolation> VertexAccessor::VertexValidator::ValidateRemoveLabel(LabelId label) const { + return schema_validator->ValidateLabelUpdate(label); +} + } // namespace memgraph::storage::v3 diff --git a/src/storage/v3/vertex_accessor.hpp b/src/storage/v3/vertex_accessor.hpp index 88a2828c8..6cc7d4ff3 100644 --- a/src/storage/v3/vertex_accessor.hpp +++ b/src/storage/v3/vertex_accessor.hpp @@ -13,6 +13,8 @@ #include <optional> +#include "storage/v3/id_types.hpp" +#include "storage/v3/schema_validator.hpp" #include "storage/v3/vertex.hpp" #include "storage/v3/config.hpp" @@ -29,20 +31,39 @@ struct Constraints; class VertexAccessor final { private: + struct VertexValidator { + // TODO(jbajic) Beware since vertex is pointer it will be accessed even as nullptr + explicit VertexValidator(const SchemaValidator &schema_validator, const Vertex *vertex); + + [[nodiscard]] std::optional<SchemaViolation> ValidatePropertyUpdate(PropertyId property_id) const; + + [[nodiscard]] std::optional<SchemaViolation> ValidateAddLabel(LabelId label) const; + + [[nodiscard]] std::optional<SchemaViolation> ValidateRemoveLabel(LabelId label) const; + + const SchemaValidator *schema_validator; + + private: + const Vertex *vertex_; + }; friend class Storage; public: + // Be careful when using VertexAccessor since it can be instantiated with + // nullptr values VertexAccessor(Vertex *vertex, Transaction *transaction, Indices *indices, 
Constraints *constraints, - Config::Items config, bool for_deleted = false) + Config::Items config, const SchemaValidator &schema_validator, bool for_deleted = false) : vertex_(vertex), transaction_(transaction), indices_(indices), constraints_(constraints), config_(config), + vertex_validator_{schema_validator, vertex}, for_deleted_(for_deleted) {} static std::optional<VertexAccessor> Create(Vertex *vertex, Transaction *transaction, Indices *indices, - Constraints *constraints, Config::Items config, View view); + Constraints *constraints, Config::Items config, + const SchemaValidator &schema_validator, View view); /// @return true if the object is visible from the current transaction bool IsVisible(View view) const; @@ -52,11 +73,23 @@ class VertexAccessor final { /// @throw std::bad_alloc Result<bool> AddLabel(LabelId label); + /// Add a label and return `true` if insertion took place. + /// `false` is returned if the label already existed, or SchemaViolation + /// if adding the label has violated one of the schema constraints. + /// @throw std::bad_alloc + ResultSchema<bool> AddLabelAndValidate(LabelId label); + /// Remove a label and return `true` if deletion took place. /// `false` is returned if the vertex did not have a label already. /// @throw std::bad_alloc Result<bool> RemoveLabel(LabelId label); + /// Remove a label and return `true` if deletion took place. + /// `false` is returned if the vertex did not have a label already, or SchemaViolation + /// if removing the label has violated one of the schema constraints. + /// @throw std::bad_alloc + ResultSchema<bool> RemoveLabelAndValidate(LabelId label); + Result<bool> HasLabel(LabelId label, View view) const; /// @throw std::bad_alloc @@ -64,10 +97,16 @@ class VertexAccessor final { /// std::vector::max_size(). Result<std::vector<LabelId>> Labels(View view) const; + Result<LabelId> PrimaryLabel(View view) const; + /// Set a property value and return the old value.
/// @throw std::bad_alloc Result<PropertyValue> SetProperty(PropertyId property, const PropertyValue &value); + /// Set a property value and return the old value or error. + /// @throw std::bad_alloc + ResultSchema<PropertyValue> SetPropertyAndValidate(PropertyId property, const PropertyValue &value); + /// Remove all properties and return the values of the removed properties. /// @throw std::bad_alloc Result<std::map<PropertyId, PropertyValue>> ClearProperties(); @@ -96,6 +135,8 @@ class VertexAccessor final { Gid Gid() const noexcept { return vertex_->gid; } + const SchemaValidator *GetSchemaValidator() const; + bool operator==(const VertexAccessor &other) const noexcept { return vertex_ == other.vertex_ && transaction_ == other.transaction_; } @@ -107,6 +148,7 @@ class VertexAccessor final { Indices *indices_; Constraints *constraints_; Config::Items config_; + VertexValidator vertex_validator_; // if the accessor was created for a deleted vertex. // Accessor behaves differently for some methods based on this diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index d8af2c201..9de4860ef 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -84,9 +84,9 @@ target_link_libraries(${test_prefix}plan_pretty_print mg-query) add_unit_test(query_cost_estimator.cpp) target_link_libraries(${test_prefix}query_cost_estimator mg-query) -add_unit_test(query_dump.cpp ${CMAKE_SOURCE_DIR}/src/glue/communication.cpp) -target_link_libraries(${test_prefix}query_dump mg-communication mg-query) - +# TODO Fix later on +# add_unit_test(query_dump.cpp ${CMAKE_SOURCE_DIR}/src/glue/communication.cpp) +# target_link_libraries(${test_prefix}query_dump mg-communication mg-query) add_unit_test(query_expression_evaluator.cpp) target_link_libraries(${test_prefix}query_expression_evaluator mg-query) @@ -318,6 +318,41 @@ target_link_libraries(${test_prefix}storage_v2_replication mg-storage-v2 fmt) add_unit_test(storage_v2_isolation_level.cpp) 
target_link_libraries(${test_prefix}storage_v2_isolation_level mg-storage-v2) +# Test mg-storage-v3 +add_unit_test(storage_v3.cpp) +target_link_libraries(${test_prefix}storage_v3 mg-storage-v3) + +add_unit_test(storage_v3_schema.cpp) +target_link_libraries(${test_prefix}storage_v3_schema mg-storage-v3) + +# Test mg-query-v3 +add_unit_test(query_v2_interpreter.cpp ${CMAKE_SOURCE_DIR}/src/glue/v2/communication.cpp) +target_link_libraries(${test_prefix}query_v2_interpreter mg-storage-v3 mg-query-v2 mg-communication) + +add_unit_test(query_v2_query_plan_accumulate_aggregate.cpp) +target_link_libraries(${test_prefix}query_v2_query_plan_accumulate_aggregate mg-query-v2) + +add_unit_test(query_v2_query_plan_create_set_remove_delete.cpp) +target_link_libraries(${test_prefix}query_v2_query_plan_create_set_remove_delete mg-query-v2) + +add_unit_test(query_v2_query_plan_bag_semantics.cpp) +target_link_libraries(${test_prefix}query_v2_query_plan_bag_semantics mg-query-v2) + +add_unit_test(query_v2_query_plan_edge_cases.cpp ${CMAKE_SOURCE_DIR}/src/glue/v2/communication.cpp) +target_link_libraries(${test_prefix}query_v2_query_plan_edge_cases mg-communication mg-query-v2) + +add_unit_test(query_v2_query_plan_v2_create_set_remove_delete.cpp) +target_link_libraries(${test_prefix}query_v2_query_plan_v2_create_set_remove_delete mg-query-v2) + +add_unit_test(query_v2_query_plan_match_filter_return.cpp) +target_link_libraries(${test_prefix}query_v2_query_plan_match_filter_return mg-query-v2) + +add_unit_test(query_v2_cypher_main_visitor.cpp) +target_link_libraries(${test_prefix}query_v2_cypher_main_visitor mg-query-v2) + +add_unit_test(query_v2_query_required_privileges.cpp) +target_link_libraries(${test_prefix}query_v2_query_required_privileges mg-query-v2) + add_unit_test(replication_persistence_helper.cpp) target_link_libraries(${test_prefix}replication_persistence_helper mg-storage-v2) @@ -362,10 +397,6 @@ find_package(Boost REQUIRED) add_unit_test(websocket.cpp) 
target_link_libraries(${test_prefix}websocket mg-communication Boost::headers) -# Test storage-v3 -add_unit_test(storage_v3.cpp) -target_link_libraries(${test_prefix}storage_v3 mg-storage-v3) - # Test future add_unit_test(future.cpp) -target_link_libraries(${test_prefix}future mg-io) +target_link_libraries(${test_prefix}future mg-io) \ No newline at end of file diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index b0dd41a72..903cacf12 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -531,9 +531,9 @@ auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vec memgraph::query::test_common::OnCreate { \ std::vector<memgraph::query::Clause *> { __VA_ARGS__ } \ } -#define CREATE_INDEX_ON(label, property) \ +#define CREATE_INDEX_ON(label, property) \ storage.Create<memgraph::query::IndexQuery>(memgraph::query::IndexQuery::Action::CREATE, (label), \ - std::vector<memgraph::query::PropertyIx>{(property)}) + std::vector<memgraph::query::PropertyIx>{(property)}) #define QUERY(...) memgraph::query::test_common::GetQuery(storage, __VA_ARGS__) #define SINGLE_QUERY(...) memgraph::query::test_common::GetSingleQuery(storage.Create<SingleQuery>(), __VA_ARGS__) #define UNION(...) memgraph::query::test_common::GetCypherUnion(storage.Create<CypherUnion>(true), __VA_ARGS__) diff --git a/tests/unit/query_plan_bag_semantics.cpp b/tests/unit/query_plan_bag_semantics.cpp index f0b0916f4..d4cbaecf5 100644 --- a/tests/unit/query_plan_bag_semantics.cpp +++ b/tests/unit/query_plan_bag_semantics.cpp @@ -9,11 +9,6 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. -// -// Copyright 2017 Memgraph -// Created by Florijan Stamenkovic on 14.03.17. 
-// - #include <algorithm> #include <iterator> #include <memory> diff --git a/tests/unit/query_plan_common.hpp b/tests/unit/query_plan_common.hpp index 037875754..48667607d 100644 --- a/tests/unit/query_plan_common.hpp +++ b/tests/unit/query_plan_common.hpp @@ -99,6 +99,16 @@ ScanAllTuple MakeScanAll(AstStorage &storage, SymbolTable &symbol_table, const s return ScanAllTuple{node, logical_op, symbol}; } +ScanAllTuple MakeScanAllNew(AstStorage &storage, SymbolTable &symbol_table, const std::string &identifier, + std::shared_ptr<LogicalOperator> input = {nullptr}, + memgraph::storage::View view = memgraph::storage::View::OLD) { + auto *node = NODE(identifier, "label"); + auto symbol = symbol_table.CreateSymbol(identifier, true); + node->identifier_->MapTo(symbol); + auto logical_op = std::make_shared<ScanAll>(input, symbol, view); + return ScanAllTuple{node, logical_op, symbol}; +} + /** * Creates and returns a tuple of stuff for a scan-all starting * from the node with the given name and label. diff --git a/tests/unit/query_v2_cypher_main_visitor.cpp b/tests/unit/query_v2_cypher_main_visitor.cpp new file mode 100644 index 000000000..308f9a049 --- /dev/null +++ b/tests/unit/query_v2_cypher_main_visitor.cpp @@ -0,0 +1,4326 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#include <algorithm> +#include <climits> +#include <limits> +#include <optional> +#include <string> +#include <unordered_map> +#include <variant> +#include <vector> + +////////////////////////////////////////////////////// +// "json.hpp" should always come before "antrl4-runtime.h" +// "json.hpp" uses libc's EOF macro while +// "antrl4-runtime.h" contains a static variable of the +// same name, EOF. +// This hides the definition of the macro which causes +// the compilation to fail. +#include <json/json.hpp> +////////////////////////////////////////////////////// +#include <antlr4-runtime.h> +#include <gmock/gmock-matchers.h> +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include "common/types.hpp" +#include "query/v2/exceptions.hpp" +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/frontend/ast/cypher_main_visitor.hpp" +#include "query/v2/frontend/opencypher/parser.hpp" +#include "query/v2/frontend/stripped.hpp" +#include "query/v2/procedure/cypher_types.hpp" +#include "query/v2/procedure/mg_procedure_impl.hpp" +#include "query/v2/procedure/module.hpp" +#include "query/v2/typed_value.hpp" + +#include "utils/string.hpp" +#include "utils/variant_helpers.hpp" + +using namespace memgraph::query::v2; +using namespace memgraph::query::v2::frontend; +using memgraph::query::v2::TypedValue; +using testing::ElementsAre; +using testing::Pair; +using testing::UnorderedElementsAre; + +// Base class for all test types +class Base { + public: + ParsingContext context_; + Parameters parameters_; + + virtual ~Base() {} + + virtual Query *ParseQuery(const std::string &query_string) = 0; + + virtual PropertyIx Prop(const std::string &prop_name) = 0; + + virtual LabelIx Label(const std::string &label_name) = 0; + + virtual EdgeTypeIx EdgeType(const std::string &edge_type_name) = 0; + + TypedValue LiteralValue(Expression *expression) { + if (context_.is_query_cached) { + auto *param_lookup = dynamic_cast<ParameterLookup *>(expression); + return 
TypedValue(parameters_.AtTokenPosition(param_lookup->token_position_)); + } else { + auto *literal = dynamic_cast<PrimitiveLiteral *>(expression); + return TypedValue(literal->value_); + } + } + + TypedValue GetLiteral(Expression *expression, const bool use_parameter_lookup, + const std::optional<int> &token_position = std::nullopt) const { + if (use_parameter_lookup) { + auto *param_lookup = dynamic_cast<ParameterLookup *>(expression); + if (param_lookup == nullptr) { + ADD_FAILURE(); + return {}; + } + if (token_position) { + EXPECT_EQ(param_lookup->token_position_, *token_position); + } + return TypedValue(parameters_.AtTokenPosition(param_lookup->token_position_)); + } + + auto *literal = dynamic_cast<PrimitiveLiteral *>(expression); + if (literal == nullptr) { + ADD_FAILURE(); + return {}; + } + if (token_position) { + EXPECT_EQ(literal->token_position_, *token_position); + } + return TypedValue(literal->value_); + } + + template <class TValue> + void CheckLiteral(Expression *expression, const TValue &expected, + const std::optional<int> &token_position = std::nullopt) const { + TypedValue expected_tv(expected); + const auto use_parameter_lookup = !expected_tv.IsNull() && context_.is_query_cached; + TypedValue value = GetLiteral(expression, use_parameter_lookup, token_position); + EXPECT_TRUE(TypedValue::BoolEqual{}(value, expected_tv)); + } +}; + +// This generator uses ast constructed by parsing the query. 
+class AstGenerator : public Base { + public: + Query *ParseQuery(const std::string &query_string) override { + ::frontend::opencypher::Parser parser(query_string); + CypherMainVisitor visitor(context_, &ast_storage_); + visitor.visit(parser.tree()); + return visitor.query(); + } + + PropertyIx Prop(const std::string &prop_name) override { return ast_storage_.GetPropertyIx(prop_name); } + + LabelIx Label(const std::string &name) override { return ast_storage_.GetLabelIx(name); } + + EdgeTypeIx EdgeType(const std::string &name) override { return ast_storage_.GetEdgeTypeIx(name); } + + AstStorage ast_storage_; +}; + +// This clones ast, but uses original one. This done just to ensure that cloning +// doesn't change original. +class OriginalAfterCloningAstGenerator : public AstGenerator { + public: + Query *ParseQuery(const std::string &query_string) override { + auto *original_query = AstGenerator::ParseQuery(query_string); + AstStorage storage; + original_query->Clone(&storage); + return original_query; + } +}; + +// This generator clones parsed ast and uses that one. +// Original ast is cleared after cloning to ensure that cloned ast doesn't reuse +// any data from original ast. +class ClonedAstGenerator : public Base { + public: + Query *ParseQuery(const std::string &query_string) override { + ::frontend::opencypher::Parser parser(query_string); + AstStorage tmp_storage; + { + // Add a label, property and edge type into temporary storage so + // indices have to change in cloned AST. 
+ tmp_storage.GetLabelIx("jkfdklajfkdalsj"); + tmp_storage.GetPropertyIx("fdjakfjdklfjdaslk"); + tmp_storage.GetEdgeTypeIx("fdjkalfjdlkajfdkla"); + } + CypherMainVisitor visitor(context_, &tmp_storage); + visitor.visit(parser.tree()); + return visitor.query()->Clone(&ast_storage_); + } + + PropertyIx Prop(const std::string &prop_name) override { return ast_storage_.GetPropertyIx(prop_name); } + + LabelIx Label(const std::string &name) override { return ast_storage_.GetLabelIx(name); } + + EdgeTypeIx EdgeType(const std::string &name) override { return ast_storage_.GetEdgeTypeIx(name); } + + AstStorage ast_storage_; +}; + +// This generator strips ast, clones it and then plugs stripped out literals in +// the same way it is done in ast cacheing in interpreter. +class CachedAstGenerator : public Base { + public: + Query *ParseQuery(const std::string &query_string) override { + context_.is_query_cached = true; + StrippedQuery stripped(query_string); + parameters_ = stripped.literals(); + ::frontend::opencypher::Parser parser(stripped.query()); + AstStorage tmp_storage; + CypherMainVisitor visitor(context_, &tmp_storage); + visitor.visit(parser.tree()); + return visitor.query()->Clone(&ast_storage_); + } + + PropertyIx Prop(const std::string &prop_name) override { return ast_storage_.GetPropertyIx(prop_name); } + + LabelIx Label(const std::string &name) override { return ast_storage_.GetLabelIx(name); } + + EdgeTypeIx EdgeType(const std::string &name) override { return ast_storage_.GetEdgeTypeIx(name); } + + AstStorage ast_storage_; +}; + +class MockModule : public procedure::Module { + public: + MockModule(){}; + ~MockModule() override{}; + MockModule(const MockModule &) = delete; + MockModule(MockModule &&) = delete; + MockModule &operator=(const MockModule &) = delete; + MockModule &operator=(MockModule &&) = delete; + + bool Close() override { return true; }; + + const std::map<std::string, mgp_proc, std::less<>> *Procedures() const override { return &procedures; } 
+ + const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override { return &transformations; } + + const std::map<std::string, mgp_func, std::less<>> *Functions() const override { return &functions; } + + std::optional<std::filesystem::path> Path() const override { return std::nullopt; }; + + std::map<std::string, mgp_proc, std::less<>> procedures{}; + std::map<std::string, mgp_trans, std::less<>> transformations{}; + std::map<std::string, mgp_func, std::less<>> functions{}; +}; + +void DummyProcCallback(mgp_list * /*args*/, mgp_graph * /*graph*/, mgp_result * /*result*/, mgp_memory * /*memory*/){}; +void DummyFuncCallback(mgp_list * /*args*/, mgp_func_context * /*func_ctx*/, mgp_func_result * /*result*/, + mgp_memory * /*memory*/){}; + +enum class ProcedureType { WRITE, READ }; + +std::string ToString(const ProcedureType type) { return type == ProcedureType::WRITE ? "write" : "read"; } + +class CypherMainVisitorTest : public ::testing::TestWithParam<std::shared_ptr<Base>> { + public: + void SetUp() override { + { + auto mock_module_owner = std::make_unique<MockModule>(); + mock_module = mock_module_owner.get(); + procedure::gModuleRegistry.RegisterModule("mock_module", std::move(mock_module_owner)); + } + { + auto mock_module_with_dots_in_name_owner = std::make_unique<MockModule>(); + mock_module_with_dots_in_name = mock_module_with_dots_in_name_owner.get(); + procedure::gModuleRegistry.RegisterModule("mock_module.with.dots.in.name", + std::move(mock_module_with_dots_in_name_owner)); + } + } + + void TearDown() override { + // To release any_type + procedure::gModuleRegistry.UnloadAllModules(); + } + + static void AddProc(MockModule &module, const char *name, const std::vector<std::string_view> &args, + const std::vector<std::string_view> &results, const ProcedureType type) { + memgraph::utils::MemoryResource *memory = memgraph::utils::NewDeleteResource(); + const bool is_write = type == ProcedureType::WRITE; + mgp_proc proc(name, 
DummyProcCallback, memory, {.is_write = is_write}); + for (const auto arg : args) { + proc.args.emplace_back(memgraph::utils::pmr::string{arg, memory}, &any_type); + } + for (const auto result : results) { + proc.results.emplace(memgraph::utils::pmr::string{result, memory}, std::make_pair(&any_type, false)); + } + module.procedures.emplace(name, std::move(proc)); + } + + static void AddFunc(MockModule &module, const char *name, const std::vector<std::string_view> &args) { + memgraph::utils::MemoryResource *memory = memgraph::utils::NewDeleteResource(); + mgp_func func(name, DummyFuncCallback, memory); + for (const auto arg : args) { + func.args.emplace_back(memgraph::utils::pmr::string{arg, memory}, &any_type); + } + module.functions.emplace(name, std::move(func)); + } + + std::string CreateProcByType(const ProcedureType type, const std::vector<std::string_view> &args) { + const auto proc_name = std::string{"proc_"} + ToString(type); + SCOPED_TRACE(proc_name); + AddProc(*mock_module, proc_name.c_str(), {}, args, type); + return std::string{"mock_module."} + proc_name; + } + + static const procedure::AnyType any_type; + MockModule *mock_module{nullptr}; + MockModule *mock_module_with_dots_in_name{nullptr}; +}; + +const procedure::AnyType CypherMainVisitorTest::any_type{}; + +std::shared_ptr<Base> gAstGeneratorTypes[] = { + std::make_shared<AstGenerator>(), + std::make_shared<OriginalAfterCloningAstGenerator>(), + std::make_shared<ClonedAstGenerator>(), + std::make_shared<CachedAstGenerator>(), +}; + +INSTANTIATE_TEST_CASE_P(AstGeneratorTypes, CypherMainVisitorTest, ::testing::ValuesIn(gAstGeneratorTypes)); + +// NOTE: The above used to use *Typed Tests* functionality of gtest library. +// Unfortunately, the compilation time of this test increased to full 2 minutes! +// Although using Typed Tests is the recommended way to achieve what we want, we +// are (ab)using *Value-Parameterized Tests* functionality instead. 
This cuts +// down the compilation time to about 20 seconds. The original code is here for +// future reference in case someone gets the idea to change to *appropriate* +// Typed Tests mechanism and ruin the compilation times. +// +// typedef ::testing::Types<AstGenerator, OriginalAfterCloningAstGenerator, +// ClonedAstGenerator, CachedAstGenerator> +// AstGeneratorTypes; +// +// TYPED_TEST_CASE(CypherMainVisitorTest, AstGeneratorTypes); + +TEST_P(CypherMainVisitorTest, SyntaxException) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("CREATE ()-[*1....2]-()"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, SyntaxExceptionOnTrailingText) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 2 + 2 mirko"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, PropertyLookup) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN n.x")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *property_lookup = dynamic_cast<PropertyLookup *>(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(property_lookup->expression_); + auto identifier = dynamic_cast<Identifier *>(property_lookup->expression_); + ASSERT_TRUE(identifier); + ASSERT_EQ(identifier->name_, "n"); + ASSERT_EQ(property_lookup->property_, ast_generator.Prop("x")); +} + +TEST_P(CypherMainVisitorTest, LabelsTest) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN n:x:y")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *labels_test = 
dynamic_cast<LabelsTest *>(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(labels_test->expression_); + auto identifier = dynamic_cast<Identifier *>(labels_test->expression_); + ASSERT_TRUE(identifier); + ASSERT_EQ(identifier->name_, "n"); + ASSERT_THAT(labels_test->labels_, ElementsAre(ast_generator.Label("x"), ast_generator.Label("y"))); +} + +TEST_P(CypherMainVisitorTest, EscapedLabel) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN n:`l-$\"'ab``e````l`")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *labels_test = dynamic_cast<LabelsTest *>(return_clause->body_.named_expressions[0]->expression_); + auto identifier = dynamic_cast<Identifier *>(labels_test->expression_); + ASSERT_EQ(identifier->name_, "n"); + ASSERT_THAT(labels_test->labels_, ElementsAre(ast_generator.Label("l-$\"'ab`e``l"))); +} + +TEST_P(CypherMainVisitorTest, KeywordLabel) { + for (const auto &label : {"DeLete", "UsER"}) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(fmt::format("RETURN n:{}", label))); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *labels_test = dynamic_cast<LabelsTest *>(return_clause->body_.named_expressions[0]->expression_); + auto identifier = dynamic_cast<Identifier *>(labels_test->expression_); + ASSERT_EQ(identifier->name_, "n"); + ASSERT_THAT(labels_test->labels_, ElementsAre(ast_generator.Label(label))); + } +} + +TEST_P(CypherMainVisitorTest, HexLetterLabel) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery 
*>(ast_generator.ParseQuery("RETURN n:a")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *labels_test = dynamic_cast<LabelsTest *>(return_clause->body_.named_expressions[0]->expression_); + auto identifier = dynamic_cast<Identifier *>(labels_test->expression_); + EXPECT_EQ(identifier->name_, "n"); + ASSERT_THAT(labels_test->labels_, ElementsAre(ast_generator.Label("a"))); +} + +TEST_P(CypherMainVisitorTest, ReturnNoDistinctNoBagSemantics) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN x")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ASSERT_FALSE(return_clause->body_.all_identifiers); + ASSERT_EQ(return_clause->body_.order_by.size(), 0U); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 1U); + ASSERT_FALSE(return_clause->body_.limit); + ASSERT_FALSE(return_clause->body_.skip); + ASSERT_FALSE(return_clause->body_.distinct); +} + +TEST_P(CypherMainVisitorTest, ReturnDistinct) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN DISTINCT x")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ASSERT_TRUE(return_clause->body_.distinct); +} + +TEST_P(CypherMainVisitorTest, ReturnLimit) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN x LIMIT 5")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto 
*single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ASSERT_TRUE(return_clause->body_.limit); + ast_generator.CheckLiteral(return_clause->body_.limit, 5); +} + +TEST_P(CypherMainVisitorTest, ReturnSkip) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN x SKIP 5")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ASSERT_TRUE(return_clause->body_.skip); + ast_generator.CheckLiteral(return_clause->body_.skip, 5); +} + +TEST_P(CypherMainVisitorTest, ReturnOrderBy) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN x, y, z ORDER BY z ASC, x, y DESC")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ASSERT_EQ(return_clause->body_.order_by.size(), 3U); + std::vector<std::pair<Ordering, std::string>> ordering; + for (const auto &sort_item : return_clause->body_.order_by) { + auto *identifier = dynamic_cast<Identifier *>(sort_item.expression); + ordering.emplace_back(sort_item.ordering, identifier->name_); + } + ASSERT_THAT(ordering, + UnorderedElementsAre(Pair(Ordering::ASC, "z"), Pair(Ordering::ASC, "x"), Pair(Ordering::DESC, "y"))); +} + +TEST_P(CypherMainVisitorTest, ReturnNamedIdentifier) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN var AS var5")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return 
*>(single_query->clauses_[0]); + ASSERT_FALSE(return_clause->body_.all_identifiers); + auto *named_expr = return_clause->body_.named_expressions[0]; + ASSERT_EQ(named_expr->name_, "var5"); + auto *identifier = dynamic_cast<Identifier *>(named_expr->expression_); + ASSERT_EQ(identifier->name_, "var"); +} + +TEST_P(CypherMainVisitorTest, ReturnAsterisk) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN *")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ASSERT_TRUE(return_clause->body_.all_identifiers); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 0U); +} + +TEST_P(CypherMainVisitorTest, IntegerLiteral) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN 42")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, 42, 1); +} + +TEST_P(CypherMainVisitorTest, IntegerLiteralTooLarge) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 10000000000000000000000000"), SemanticException); +} + +TEST_P(CypherMainVisitorTest, BooleanLiteralTrue) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN TrUe")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, true, 1); +} + +TEST_P(CypherMainVisitorTest, BooleanLiteralFalse) { + auto &ast_generator = *GetParam(); + auto *query = 
dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN faLSE")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, false, 1); +} + +TEST_P(CypherMainVisitorTest, NullLiteral) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN nULl")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, TypedValue(), 1); +} + +TEST_P(CypherMainVisitorTest, ParenthesizedExpression) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN (2)")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, 2); +} + +TEST_P(CypherMainVisitorTest, OrOperator) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN true Or false oR n")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *or_operator2 = dynamic_cast<OrOperator *>(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(or_operator2); + auto *or_operator1 = dynamic_cast<OrOperator *>(or_operator2->expression1_); + ASSERT_TRUE(or_operator1); + ast_generator.CheckLiteral(or_operator1->expression1_, true); + 
ast_generator.CheckLiteral(or_operator1->expression2_, false); + auto *operand3 = dynamic_cast<Identifier *>(or_operator2->expression2_); + ASSERT_TRUE(operand3); + ASSERT_EQ(operand3->name_, "n"); +} + +TEST_P(CypherMainVisitorTest, XorOperator) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN true xOr false")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *xor_operator = dynamic_cast<XorOperator *>(return_clause->body_.named_expressions[0]->expression_); + ast_generator.CheckLiteral(xor_operator->expression1_, true); + ast_generator.CheckLiteral(xor_operator->expression2_, false); +} + +TEST_P(CypherMainVisitorTest, AndOperator) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN true and false")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *and_operator = dynamic_cast<AndOperator *>(return_clause->body_.named_expressions[0]->expression_); + ast_generator.CheckLiteral(and_operator->expression1_, true); + ast_generator.CheckLiteral(and_operator->expression2_, false); +} + +TEST_P(CypherMainVisitorTest, AdditionSubtractionOperators) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN 1 - 2 + 3")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *addition_operator = dynamic_cast<AdditionOperator *>(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(addition_operator); + auto *subtraction_operator = dynamic_cast<SubtractionOperator 
*>(addition_operator->expression1_); + ASSERT_TRUE(subtraction_operator); + ast_generator.CheckLiteral(subtraction_operator->expression1_, 1); + ast_generator.CheckLiteral(subtraction_operator->expression2_, 2); + ast_generator.CheckLiteral(addition_operator->expression2_, 3); +} + +TEST_P(CypherMainVisitorTest, MulitplicationOperator) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN 2 * 3")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *mult_operator = dynamic_cast<MultiplicationOperator *>(return_clause->body_.named_expressions[0]->expression_); + ast_generator.CheckLiteral(mult_operator->expression1_, 2); + ast_generator.CheckLiteral(mult_operator->expression2_, 3); +} + +TEST_P(CypherMainVisitorTest, DivisionOperator) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN 2 / 3")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *div_operator = dynamic_cast<DivisionOperator *>(return_clause->body_.named_expressions[0]->expression_); + ast_generator.CheckLiteral(div_operator->expression1_, 2); + ast_generator.CheckLiteral(div_operator->expression2_, 3); +} + +TEST_P(CypherMainVisitorTest, ModOperator) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN 2 % 3")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *mod_operator = dynamic_cast<ModOperator *>(return_clause->body_.named_expressions[0]->expression_); + ast_generator.CheckLiteral(mod_operator->expression1_, 
2); + ast_generator.CheckLiteral(mod_operator->expression2_, 3); +} + +#define CHECK_COMPARISON(TYPE, VALUE1, VALUE2) \ + do { \ + auto *and_operator = dynamic_cast<AndOperator *>(_operator); \ + ASSERT_TRUE(and_operator); \ + _operator = and_operator->expression1_; \ + auto *cmp_operator = dynamic_cast<TYPE *>(and_operator->expression2_); \ + ASSERT_TRUE(cmp_operator); \ + ast_generator.CheckLiteral(cmp_operator->expression1_, VALUE1); \ + ast_generator.CheckLiteral(cmp_operator->expression2_, VALUE2); \ + } while (0) + +TEST_P(CypherMainVisitorTest, ComparisonOperators) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN 2 = 3 != 4 <> 5 < 6 > 7 <= 8 >= 9")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + Expression *_operator = return_clause->body_.named_expressions[0]->expression_; + CHECK_COMPARISON(GreaterEqualOperator, 8, 9); + CHECK_COMPARISON(LessEqualOperator, 7, 8); + CHECK_COMPARISON(GreaterOperator, 6, 7); + CHECK_COMPARISON(LessOperator, 5, 6); + CHECK_COMPARISON(NotEqualOperator, 4, 5); + CHECK_COMPARISON(NotEqualOperator, 3, 4); + auto *cmp_operator = dynamic_cast<EqualOperator *>(_operator); + ASSERT_TRUE(cmp_operator); + ast_generator.CheckLiteral(cmp_operator->expression1_, 2); + ast_generator.CheckLiteral(cmp_operator->expression2_, 3); +} + +#undef CHECK_COMPARISON + +TEST_P(CypherMainVisitorTest, ListIndexing) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN [1,2,3] [ 2 ]")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *list_index_op = dynamic_cast<SubscriptOperator *>(return_clause->body_.named_expressions[0]->expression_); + 
ASSERT_TRUE(list_index_op); + auto *list = dynamic_cast<ListLiteral *>(list_index_op->expression1_); + EXPECT_TRUE(list); + ast_generator.CheckLiteral(list_index_op->expression2_, 2); +} + +TEST_P(CypherMainVisitorTest, ListSlicingOperatorNoBounds) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("RETURN [1,2,3] [ .. ]"), SemanticException); +} + +TEST_P(CypherMainVisitorTest, ListSlicingOperator) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN [1,2,3] [ .. 2 ]")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *list_slicing_op = dynamic_cast<ListSlicingOperator *>(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(list_slicing_op); + auto *list = dynamic_cast<ListLiteral *>(list_slicing_op->list_); + EXPECT_TRUE(list); + EXPECT_FALSE(list_slicing_op->lower_bound_); + ast_generator.CheckLiteral(list_slicing_op->upper_bound_, 2); +} + +TEST_P(CypherMainVisitorTest, InListOperator) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN 5 IN [1,2]")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *in_list_operator = dynamic_cast<InListOperator *>(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(in_list_operator); + ast_generator.CheckLiteral(in_list_operator->expression1_, 5); + auto *list = dynamic_cast<ListLiteral *>(in_list_operator->expression2_); + ASSERT_TRUE(list); +} + +TEST_P(CypherMainVisitorTest, InWithListIndexing) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN 1 IN [[1,2]][0]")); + ASSERT_TRUE(query); + 
ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *in_list_operator = dynamic_cast<InListOperator *>(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(in_list_operator); + ast_generator.CheckLiteral(in_list_operator->expression1_, 1); + auto *list_indexing = dynamic_cast<SubscriptOperator *>(in_list_operator->expression2_); + ASSERT_TRUE(list_indexing); + auto *list = dynamic_cast<ListLiteral *>(list_indexing->expression1_); + EXPECT_TRUE(list); + ast_generator.CheckLiteral(list_indexing->expression2_, 0); +} + +TEST_P(CypherMainVisitorTest, CaseGenericForm) { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN CASE WHEN n < 10 THEN 1 WHEN n > 10 THEN 2 END")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *if_operator = dynamic_cast<IfOperator *>(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(if_operator); + auto *condition = dynamic_cast<LessOperator *>(if_operator->condition_); + ASSERT_TRUE(condition); + ast_generator.CheckLiteral(if_operator->then_expression_, 1); + + auto *if_operator2 = dynamic_cast<IfOperator *>(if_operator->else_expression_); + ASSERT_TRUE(if_operator2); + auto *condition2 = dynamic_cast<GreaterOperator *>(if_operator2->condition_); + ASSERT_TRUE(condition2); + ast_generator.CheckLiteral(if_operator2->then_expression_, 2); + ast_generator.CheckLiteral(if_operator2->else_expression_, TypedValue()); +} + +TEST_P(CypherMainVisitorTest, CaseGenericFormElse) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN CASE WHEN n < 10 THEN 1 ELSE 2 END")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = 
query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *if_operator = dynamic_cast<IfOperator *>(return_clause->body_.named_expressions[0]->expression_); + auto *condition = dynamic_cast<LessOperator *>(if_operator->condition_); + ASSERT_TRUE(condition); + ast_generator.CheckLiteral(if_operator->then_expression_, 1); + ast_generator.CheckLiteral(if_operator->else_expression_, 2); +} + +TEST_P(CypherMainVisitorTest, CaseSimpleForm) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN CASE 5 WHEN 10 THEN 1 END")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *if_operator = dynamic_cast<IfOperator *>(return_clause->body_.named_expressions[0]->expression_); + auto *condition = dynamic_cast<EqualOperator *>(if_operator->condition_); + ASSERT_TRUE(condition); + ast_generator.CheckLiteral(condition->expression1_, 5); + ast_generator.CheckLiteral(condition->expression2_, 10); + ast_generator.CheckLiteral(if_operator->then_expression_, 1); + ast_generator.CheckLiteral(if_operator->else_expression_, TypedValue()); +} + +TEST_P(CypherMainVisitorTest, IsNull) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN 2 iS NulL")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *is_type_operator = dynamic_cast<IsNullOperator *>(return_clause->body_.named_expressions[0]->expression_); + ast_generator.CheckLiteral(is_type_operator->expression_, 2); +} + +TEST_P(CypherMainVisitorTest, IsNotNull) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN 2 iS nOT NulL")); + 
ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *not_operator = dynamic_cast<NotOperator *>(return_clause->body_.named_expressions[0]->expression_); + auto *is_type_operator = dynamic_cast<IsNullOperator *>(not_operator->expression_); + ast_generator.CheckLiteral(is_type_operator->expression_, 2); +} + +TEST_P(CypherMainVisitorTest, NotOperator) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN not true")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *not_operator = dynamic_cast<NotOperator *>(return_clause->body_.named_expressions[0]->expression_); + ast_generator.CheckLiteral(not_operator->expression_, true); +} + +TEST_P(CypherMainVisitorTest, UnaryMinusPlusOperators) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN -+5")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *unary_minus_operator = + dynamic_cast<UnaryMinusOperator *>(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(unary_minus_operator); + auto *unary_plus_operator = dynamic_cast<UnaryPlusOperator *>(unary_minus_operator->expression_); + ASSERT_TRUE(unary_plus_operator); + ast_generator.CheckLiteral(unary_plus_operator->expression_, 5); +} + +TEST_P(CypherMainVisitorTest, Aggregation) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>( + ast_generator.ParseQuery("RETURN COUNT(a), MIN(b), MAX(c), SUM(d), AVG(e), COLLECT(f), COUNT(*)")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto 
*single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 7U); + Aggregation::Op ops[] = {Aggregation::Op::COUNT, Aggregation::Op::MIN, Aggregation::Op::MAX, + Aggregation::Op::SUM, Aggregation::Op::AVG, Aggregation::Op::COLLECT_LIST}; + std::string ids[] = {"a", "b", "c", "d", "e", "f"}; + for (int i = 0; i < 6; ++i) { + auto *aggregation = dynamic_cast<Aggregation *>(return_clause->body_.named_expressions[i]->expression_); + ASSERT_TRUE(aggregation); + ASSERT_EQ(aggregation->op_, ops[i]); + auto *identifier = dynamic_cast<Identifier *>(aggregation->expression1_); + ASSERT_TRUE(identifier); + ASSERT_EQ(identifier->name_, ids[i]); + } + auto *aggregation = dynamic_cast<Aggregation *>(return_clause->body_.named_expressions[6]->expression_); + ASSERT_TRUE(aggregation); + ASSERT_EQ(aggregation->op_, Aggregation::Op::COUNT); + ASSERT_FALSE(aggregation->expression1_); +} + +TEST_P(CypherMainVisitorTest, UndefinedFunction) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("RETURN " + "IHopeWeWillNeverHaveAwesomeMemgraphProcedureWithS" + "uchALongAndAwesomeNameSinceThisTestWouldFail(1)"), + SemanticException); +} + +TEST_P(CypherMainVisitorTest, MissingFunction) { + AddFunc(*mock_module, "get", {}); + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("RETURN missing_function.get()"), SemanticException); +} + +TEST_P(CypherMainVisitorTest, Function) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN abs(n, 2)")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 1); + auto *function = dynamic_cast<Function 
*>(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(function); + ASSERT_TRUE(function->function_); +} + +TEST_P(CypherMainVisitorTest, MagicFunction) { + AddFunc(*mock_module, "get", {}); + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN mock_module.get()")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 1); + auto *function = dynamic_cast<Function *>(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(function); + ASSERT_TRUE(function->function_); +} + +TEST_P(CypherMainVisitorTest, StringLiteralDoubleQuotes) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN \"mi'rko\"")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, "mi'rko", 1); +} + +TEST_P(CypherMainVisitorTest, StringLiteralSingleQuotes) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN 'mi\"rko'")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, "mi\"rko", 1); +} + +TEST_P(CypherMainVisitorTest, StringLiteralEscapedChars) { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN '\\\\\\'\\\"\\b\\B\\f\\F\\n\\N\\r\\R\\t\\T'")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto 
*single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, "\\'\"\b\b\f\f\n\n\r\r\t\t", 1); +} + +TEST_P(CypherMainVisitorTest, StringLiteralEscapedUtf16) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN '\\u221daaa\\u221daaa'")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, + "\xE2\x88\x9D" + "aaa" + "\xE2\x88\x9D" + "aaa", + 1); // u8"\u221daaa\u221daaa" +} + +TEST_P(CypherMainVisitorTest, StringLiteralEscapedUtf16Error) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("RETURN '\\U221daaa'"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, StringLiteralEscapedUtf32) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN '\\U0001F600aaaa\\U0001F600aaaaaaaa'")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, + "\xF0\x9F\x98\x80" + "aaaa" + "\xF0\x9F\x98\x80" + "aaaaaaaa", + 1); // u8"\U0001F600aaaa\U0001F600aaaaaaaa" +} + +TEST_P(CypherMainVisitorTest, DoubleLiteral) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN 3.5")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + 
ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, 3.5, 1); +} + +TEST_P(CypherMainVisitorTest, DoubleLiteralExponent) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN 5e-1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, 0.5, 1); +} + +TEST_P(CypherMainVisitorTest, ListLiteral) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN [3, [], 'johhny']")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *list_literal = dynamic_cast<ListLiteral *>(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(list_literal); + ASSERT_EQ(3, list_literal->elements_.size()); + ast_generator.CheckLiteral(list_literal->elements_[0], 3); + auto *elem_1 = dynamic_cast<ListLiteral *>(list_literal->elements_[1]); + ASSERT_TRUE(elem_1); + EXPECT_EQ(0, elem_1->elements_.size()); + ast_generator.CheckLiteral(list_literal->elements_[2], "johhny"); +} + +TEST_P(CypherMainVisitorTest, MapLiteral) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN {a: 1, b: 'bla', c: [1, {a: 42}]}")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + auto *map_literal = dynamic_cast<MapLiteral *>(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(map_literal); + ASSERT_EQ(3, map_literal->elements_.size()); + 
ast_generator.CheckLiteral(map_literal->elements_[ast_generator.Prop("a")], 1); + ast_generator.CheckLiteral(map_literal->elements_[ast_generator.Prop("b")], "bla"); + auto *elem_2 = dynamic_cast<ListLiteral *>(map_literal->elements_[ast_generator.Prop("c")]); + ASSERT_TRUE(elem_2); + EXPECT_EQ(2, elem_2->elements_.size()); + auto *elem_2_1 = dynamic_cast<MapLiteral *>(elem_2->elements_[1]); + ASSERT_TRUE(elem_2_1); + EXPECT_EQ(1, elem_2_1->elements_.size()); +} + +TEST_P(CypherMainVisitorTest, NodePattern) { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH (:label1:label2:label3 {a : 5, b : 10}) RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *match = dynamic_cast<Match *>(single_query->clauses_[0]); + ASSERT_TRUE(match); + EXPECT_FALSE(match->optional_); + EXPECT_FALSE(match->where_); + ASSERT_EQ(match->patterns_.size(), 1U); + ASSERT_TRUE(match->patterns_[0]); + ASSERT_EQ(match->patterns_[0]->atoms_.size(), 1U); + auto node = dynamic_cast<NodeAtom *>(match->patterns_[0]->atoms_[0]); + ASSERT_TRUE(node); + ASSERT_TRUE(node->identifier_); + EXPECT_EQ(node->identifier_->name_, CypherMainVisitor::kAnonPrefix + std::to_string(1)); + EXPECT_FALSE(node->identifier_->user_declared_); + EXPECT_THAT(node->labels_, UnorderedElementsAre(ast_generator.Label("label1"), ast_generator.Label("label2"), + ast_generator.Label("label3"))); + std::unordered_map<PropertyIx, int64_t> properties; + for (auto x : std::get<0>(node->properties_)) { + TypedValue value = ast_generator.LiteralValue(x.second); + ASSERT_TRUE(value.type() == TypedValue::Type::Int); + properties[x.first] = value.ValueInt(); + } + EXPECT_THAT(properties, UnorderedElementsAre(Pair(ast_generator.Prop("a"), 5), Pair(ast_generator.Prop("b"), 10))); +} + +TEST_P(CypherMainVisitorTest, PropertyMapSameKeyAppearsTwice) { + auto 
&ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("MATCH ({a : 1, a : 2})"), SemanticException); +} + +TEST_P(CypherMainVisitorTest, NodePatternIdentifier) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH (var) RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast<Match *>(single_query->clauses_[0]); + ASSERT_TRUE(match); + EXPECT_FALSE(match->optional_); + EXPECT_FALSE(match->where_); + auto node = dynamic_cast<NodeAtom *>(match->patterns_[0]->atoms_[0]); + ASSERT_TRUE(node); + ASSERT_TRUE(node->identifier_); + EXPECT_EQ(node->identifier_->name_, "var"); + EXPECT_TRUE(node->identifier_->user_declared_); + EXPECT_THAT(node->labels_, UnorderedElementsAre()); + EXPECT_THAT(std::get<0>(node->properties_), UnorderedElementsAre()); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternNoDetails) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH ()--() RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast<Match *>(single_query->clauses_[0]); + ASSERT_TRUE(match); + EXPECT_FALSE(match->optional_); + EXPECT_FALSE(match->where_); + ASSERT_EQ(match->patterns_.size(), 1U); + ASSERT_TRUE(match->patterns_[0]); + ASSERT_EQ(match->patterns_[0]->atoms_.size(), 3U); + auto *node1 = dynamic_cast<NodeAtom *>(match->patterns_[0]->atoms_[0]); + ASSERT_TRUE(node1); + auto *edge = dynamic_cast<EdgeAtom *>(match->patterns_[0]->atoms_[1]); + ASSERT_TRUE(edge); + auto *node2 = dynamic_cast<NodeAtom *>(match->patterns_[0]->atoms_[2]); + ASSERT_TRUE(node2); + ASSERT_TRUE(node1->identifier_); + ASSERT_TRUE(edge->identifier_); + ASSERT_TRUE(node2->identifier_); + EXPECT_THAT( + std::vector<std::string>({node1->identifier_->name_, edge->identifier_->name_, 
node2->identifier_->name_}), + UnorderedElementsAre(CypherMainVisitor::kAnonPrefix + std::to_string(1), + CypherMainVisitor::kAnonPrefix + std::to_string(2), + CypherMainVisitor::kAnonPrefix + std::to_string(3))); + EXPECT_FALSE(node1->identifier_->user_declared_); + EXPECT_FALSE(edge->identifier_->user_declared_); + EXPECT_FALSE(node2->identifier_->user_declared_); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::BOTH); +} + +// PatternPart in braces. +TEST_P(CypherMainVisitorTest, PatternPartBraces) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH ((()--())) RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast<Match *>(single_query->clauses_[0]); + ASSERT_TRUE(match); + EXPECT_FALSE(match->where_); + ASSERT_EQ(match->patterns_.size(), 1U); + ASSERT_TRUE(match->patterns_[0]); + ASSERT_EQ(match->patterns_[0]->atoms_.size(), 3U); + auto *node1 = dynamic_cast<NodeAtom *>(match->patterns_[0]->atoms_[0]); + ASSERT_TRUE(node1); + auto *edge = dynamic_cast<EdgeAtom *>(match->patterns_[0]->atoms_[1]); + ASSERT_TRUE(edge); + auto *node2 = dynamic_cast<NodeAtom *>(match->patterns_[0]->atoms_[2]); + ASSERT_TRUE(node2); + ASSERT_TRUE(node1->identifier_); + ASSERT_TRUE(edge->identifier_); + ASSERT_TRUE(node2->identifier_); + EXPECT_THAT( + std::vector<std::string>({node1->identifier_->name_, edge->identifier_->name_, node2->identifier_->name_}), + UnorderedElementsAre(CypherMainVisitor::kAnonPrefix + std::to_string(1), + CypherMainVisitor::kAnonPrefix + std::to_string(2), + CypherMainVisitor::kAnonPrefix + std::to_string(3))); + EXPECT_FALSE(node1->identifier_->user_declared_); + EXPECT_FALSE(edge->identifier_->user_declared_); + EXPECT_FALSE(node2->identifier_->user_declared_); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::BOTH); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternDetails) { + auto 
&ast_generator = *GetParam(); + auto *query = + dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH ()<-[:type1|type2 {a : 5, b : 10}]-() RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast<Match *>(single_query->clauses_[0]); + ASSERT_TRUE(match); + EXPECT_FALSE(match->optional_); + EXPECT_FALSE(match->where_); + auto *edge = dynamic_cast<EdgeAtom *>(match->patterns_[0]->atoms_[1]); + ASSERT_TRUE(edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::IN); + EXPECT_THAT(edge->edge_types_, + UnorderedElementsAre(ast_generator.EdgeType("type1"), ast_generator.EdgeType("type2"))); + std::unordered_map<PropertyIx, int64_t> properties; + for (auto x : std::get<0>(edge->properties_)) { + TypedValue value = ast_generator.LiteralValue(x.second); + ASSERT_TRUE(value.type() == TypedValue::Type::Int); + properties[x.first] = value.ValueInt(); + } + EXPECT_THAT(properties, UnorderedElementsAre(Pair(ast_generator.Prop("a"), 5), Pair(ast_generator.Prop("b"), 10))); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternVariable) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH ()-[var]->() RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast<Match *>(single_query->clauses_[0]); + ASSERT_TRUE(match); + EXPECT_FALSE(match->optional_); + EXPECT_FALSE(match->where_); + auto *edge = dynamic_cast<EdgeAtom *>(match->patterns_[0]->atoms_[1]); + ASSERT_TRUE(edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::OUT); + ASSERT_TRUE(edge->identifier_); + EXPECT_THAT(edge->identifier_->name_, "var"); + EXPECT_TRUE(edge->identifier_->user_declared_); +} + +// Assert that match has a single pattern with a single edge atom and store it +// in edge parameter. 
+void AssertMatchSingleEdgeAtom(Match *match, EdgeAtom *&edge) { + ASSERT_TRUE(match); + ASSERT_EQ(match->patterns_.size(), 1U); + ASSERT_EQ(match->patterns_[0]->atoms_.size(), 3U); + edge = dynamic_cast<EdgeAtom *>(match->patterns_[0]->atoms_[1]); + ASSERT_TRUE(edge); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternUnbounded) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH ()-[r*]->() RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast<Match *>(single_query->clauses_[0]); + EdgeAtom *edge = nullptr; + AssertMatchSingleEdgeAtom(match, edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::OUT); + EXPECT_EQ(edge->type_, EdgeAtom::Type::DEPTH_FIRST); + EXPECT_EQ(edge->lower_bound_, nullptr); + EXPECT_EQ(edge->upper_bound_, nullptr); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternLowerBounded) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH ()-[r*42..]->() RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast<Match *>(single_query->clauses_[0]); + EdgeAtom *edge = nullptr; + AssertMatchSingleEdgeAtom(match, edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::OUT); + EXPECT_EQ(edge->type_, EdgeAtom::Type::DEPTH_FIRST); + ast_generator.CheckLiteral(edge->lower_bound_, 42); + EXPECT_EQ(edge->upper_bound_, nullptr); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternUpperBounded) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH ()-[r*..42]->() RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast<Match *>(single_query->clauses_[0]); + EdgeAtom *edge = nullptr; + 
AssertMatchSingleEdgeAtom(match, edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::OUT); + EXPECT_EQ(edge->type_, EdgeAtom::Type::DEPTH_FIRST); + EXPECT_EQ(edge->lower_bound_, nullptr); + ast_generator.CheckLiteral(edge->upper_bound_, 42); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternLowerUpperBounded) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH ()-[r*24..42]->() RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast<Match *>(single_query->clauses_[0]); + EdgeAtom *edge = nullptr; + AssertMatchSingleEdgeAtom(match, edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::OUT); + EXPECT_EQ(edge->type_, EdgeAtom::Type::DEPTH_FIRST); + ast_generator.CheckLiteral(edge->lower_bound_, 24); + ast_generator.CheckLiteral(edge->upper_bound_, 42); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternFixedRange) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH ()-[r*42]->() RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast<Match *>(single_query->clauses_[0]); + EdgeAtom *edge = nullptr; + AssertMatchSingleEdgeAtom(match, edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::OUT); + EXPECT_EQ(edge->type_, EdgeAtom::Type::DEPTH_FIRST); + ast_generator.CheckLiteral(edge->lower_bound_, 42); + ast_generator.CheckLiteral(edge->upper_bound_, 42); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternFloatingUpperBound) { + // [r*1...2] should be parsed as [r*1..0.2] + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH ()-[r*1...2]->() RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = 
dynamic_cast<Match *>(single_query->clauses_[0]); + EdgeAtom *edge = nullptr; + AssertMatchSingleEdgeAtom(match, edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::OUT); + EXPECT_EQ(edge->type_, EdgeAtom::Type::DEPTH_FIRST); + ast_generator.CheckLiteral(edge->lower_bound_, 1); + ast_generator.CheckLiteral(edge->upper_bound_, 0.2); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternUnboundedWithProperty) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH ()-[r* {prop: 42}]->() RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast<Match *>(single_query->clauses_[0]); + EdgeAtom *edge = nullptr; + AssertMatchSingleEdgeAtom(match, edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::OUT); + EXPECT_EQ(edge->type_, EdgeAtom::Type::DEPTH_FIRST); + EXPECT_EQ(edge->lower_bound_, nullptr); + EXPECT_EQ(edge->upper_bound_, nullptr); + ast_generator.CheckLiteral(std::get<0>(edge->properties_)[ast_generator.Prop("prop")], 42); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternDotsUnboundedWithEdgeTypeProperty) { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH ()-[r:edge_type*..{prop: 42}]->() RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast<Match *>(single_query->clauses_[0]); + EdgeAtom *edge = nullptr; + AssertMatchSingleEdgeAtom(match, edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::OUT); + EXPECT_EQ(edge->type_, EdgeAtom::Type::DEPTH_FIRST); + EXPECT_EQ(edge->lower_bound_, nullptr); + EXPECT_EQ(edge->upper_bound_, nullptr); + ast_generator.CheckLiteral(std::get<0>(edge->properties_)[ast_generator.Prop("prop")], 42); + ASSERT_EQ(edge->edge_types_.size(), 1U); + auto edge_type = ast_generator.EdgeType("edge_type"); + 
EXPECT_EQ(edge->edge_types_[0], edge_type); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternUpperBoundedWithProperty) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH ()-[r*..2{prop: 42}]->() RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast<Match *>(single_query->clauses_[0]); + EdgeAtom *edge = nullptr; + AssertMatchSingleEdgeAtom(match, edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::OUT); + EXPECT_EQ(edge->type_, EdgeAtom::Type::DEPTH_FIRST); + EXPECT_EQ(edge->lower_bound_, nullptr); + ast_generator.CheckLiteral(edge->upper_bound_, 2); + ast_generator.CheckLiteral(std::get<0>(edge->properties_)[ast_generator.Prop("prop")], 42); +} + +// TODO maybe uncomment +// // PatternPart with variable. +// TEST_P(CypherMainVisitorTest, PatternPartVariable) { +// ParserTables parser("CREATE var=()--()"); +// ASSERT_EQ(parser.identifiers_map_.size(), 1U); +// ASSERT_EQ(parser.pattern_parts_.size(), 1U); +// ASSERT_EQ(parser.relationships_.size(), 1U); +// ASSERT_EQ(parser.nodes_.size(), 2U); +// ASSERT_EQ(parser.pattern_parts_.begin()->second.nodes.size(), 2U); +// ASSERT_EQ(parser.pattern_parts_.begin()->second.relationships.size(), 1U); +// ASSERT_NE(parser.identifiers_map_.find("var"), +// parser.identifiers_map_.end()); +// auto output_identifier = parser.identifiers_map_["var"]; +// ASSERT_NE(parser.pattern_parts_.find(output_identifier), +// parser.pattern_parts_.end()); +// } + +TEST_P(CypherMainVisitorTest, ReturnUnanemdIdentifier) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN var")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + 
ASSERT_TRUE(return_clause); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 1U); + auto *named_expr = return_clause->body_.named_expressions[0]; + ASSERT_TRUE(named_expr); + ASSERT_EQ(named_expr->name_, "var"); + auto *identifier = dynamic_cast<Identifier *>(named_expr->expression_); + ASSERT_TRUE(identifier); + ASSERT_EQ(identifier->name_, "var"); + ASSERT_TRUE(identifier->user_declared_); +} + +TEST_P(CypherMainVisitorTest, Create) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CREATE (n)")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *create = dynamic_cast<Create *>(single_query->clauses_[0]); + ASSERT_TRUE(create); + ASSERT_EQ(create->patterns_.size(), 1U); + ASSERT_TRUE(create->patterns_[0]); + ASSERT_EQ(create->patterns_[0]->atoms_.size(), 1U); + auto node = dynamic_cast<NodeAtom *>(create->patterns_[0]->atoms_[0]); + ASSERT_TRUE(node); + ASSERT_TRUE(node->identifier_); + ASSERT_EQ(node->identifier_->name_, "n"); +} + +TEST_P(CypherMainVisitorTest, Delete) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("DELETE n, m")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *del = dynamic_cast<Delete *>(single_query->clauses_[0]); + ASSERT_TRUE(del); + ASSERT_FALSE(del->detach_); + ASSERT_EQ(del->expressions_.size(), 2U); + auto *identifier1 = dynamic_cast<Identifier *>(del->expressions_[0]); + ASSERT_TRUE(identifier1); + ASSERT_EQ(identifier1->name_, "n"); + auto *identifier2 = dynamic_cast<Identifier *>(del->expressions_[1]); + ASSERT_TRUE(identifier2); + ASSERT_EQ(identifier2->name_, "m"); +} + +TEST_P(CypherMainVisitorTest, DeleteDetach) { + auto &ast_generator = *GetParam(); + auto *query = 
dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("DETACH DELETE n")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *del = dynamic_cast<Delete *>(single_query->clauses_[0]); + ASSERT_TRUE(del); + ASSERT_TRUE(del->detach_); + ASSERT_EQ(del->expressions_.size(), 1U); + auto *identifier1 = dynamic_cast<Identifier *>(del->expressions_[0]); + ASSERT_TRUE(identifier1); + ASSERT_EQ(identifier1->name_, "n"); +} + +TEST_P(CypherMainVisitorTest, OptionalMatchWhere) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("OPTIONAL MATCH (n) WHERE m RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *match = dynamic_cast<Match *>(single_query->clauses_[0]); + ASSERT_TRUE(match); + EXPECT_TRUE(match->optional_); + ASSERT_TRUE(match->where_); + auto *identifier = dynamic_cast<Identifier *>(match->where_->expression_); + ASSERT_TRUE(identifier); + ASSERT_EQ(identifier->name_, "m"); +} + +TEST_P(CypherMainVisitorTest, Set) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("SET a.x = b, c = d, e += f, g : h : i ")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 4U); + + { + auto *set_property = dynamic_cast<SetProperty *>(single_query->clauses_[0]); + ASSERT_TRUE(set_property); + ASSERT_TRUE(set_property->property_lookup_); + auto *identifier1 = dynamic_cast<Identifier *>(set_property->property_lookup_->expression_); + ASSERT_TRUE(identifier1); + ASSERT_EQ(identifier1->name_, "a"); + ASSERT_EQ(set_property->property_lookup_->property_, ast_generator.Prop("x")); + auto *identifier2 = dynamic_cast<Identifier *>(set_property->expression_); + 
ASSERT_EQ(identifier2->name_, "b"); + } + + { + auto *set_properties_assignment = dynamic_cast<SetProperties *>(single_query->clauses_[1]); + ASSERT_TRUE(set_properties_assignment); + ASSERT_FALSE(set_properties_assignment->update_); + ASSERT_TRUE(set_properties_assignment->identifier_); + ASSERT_EQ(set_properties_assignment->identifier_->name_, "c"); + auto *identifier = dynamic_cast<Identifier *>(set_properties_assignment->expression_); + ASSERT_EQ(identifier->name_, "d"); + } + + { + auto *set_properties_update = dynamic_cast<SetProperties *>(single_query->clauses_[2]); + ASSERT_TRUE(set_properties_update); + ASSERT_TRUE(set_properties_update->update_); + ASSERT_TRUE(set_properties_update->identifier_); + ASSERT_EQ(set_properties_update->identifier_->name_, "e"); + auto *identifier = dynamic_cast<Identifier *>(set_properties_update->expression_); + ASSERT_EQ(identifier->name_, "f"); + } + + { + auto *set_labels = dynamic_cast<SetLabels *>(single_query->clauses_[3]); + ASSERT_TRUE(set_labels); + ASSERT_TRUE(set_labels->identifier_); + ASSERT_EQ(set_labels->identifier_->name_, "g"); + ASSERT_THAT(set_labels->labels_, UnorderedElementsAre(ast_generator.Label("h"), ast_generator.Label("i"))); + } +} + +TEST_P(CypherMainVisitorTest, Remove) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("REMOVE a.x, g : h : i")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + { + auto *remove_property = dynamic_cast<RemoveProperty *>(single_query->clauses_[0]); + ASSERT_TRUE(remove_property); + ASSERT_TRUE(remove_property->property_lookup_); + auto *identifier1 = dynamic_cast<Identifier *>(remove_property->property_lookup_->expression_); + ASSERT_TRUE(identifier1); + ASSERT_EQ(identifier1->name_, "a"); + ASSERT_EQ(remove_property->property_lookup_->property_, ast_generator.Prop("x")); + } + { + auto *remove_labels = 
dynamic_cast<RemoveLabels *>(single_query->clauses_[1]); + ASSERT_TRUE(remove_labels); + ASSERT_TRUE(remove_labels->identifier_); + ASSERT_EQ(remove_labels->identifier_->name_, "g"); + ASSERT_THAT(remove_labels->labels_, UnorderedElementsAre(ast_generator.Label("h"), ast_generator.Label("i"))); + } +} + +TEST_P(CypherMainVisitorTest, With) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("WITH n AS m RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *with = dynamic_cast<With *>(single_query->clauses_[0]); + ASSERT_TRUE(with); + ASSERT_FALSE(with->body_.distinct); + ASSERT_FALSE(with->body_.limit); + ASSERT_FALSE(with->body_.skip); + ASSERT_EQ(with->body_.order_by.size(), 0U); + ASSERT_FALSE(with->where_); + ASSERT_EQ(with->body_.named_expressions.size(), 1U); + auto *named_expr = with->body_.named_expressions[0]; + ASSERT_EQ(named_expr->name_, "m"); + auto *identifier = dynamic_cast<Identifier *>(named_expr->expression_); + ASSERT_EQ(identifier->name_, "n"); +} + +TEST_P(CypherMainVisitorTest, WithNonAliasedExpression) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("WITH n.x RETURN 1"), SemanticException); +} + +TEST_P(CypherMainVisitorTest, WithNonAliasedVariable) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("WITH n RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *with = dynamic_cast<With *>(single_query->clauses_[0]); + ASSERT_TRUE(with); + ASSERT_EQ(with->body_.named_expressions.size(), 1U); + auto *named_expr = with->body_.named_expressions[0]; + ASSERT_EQ(named_expr->name_, "n"); + auto *identifier = dynamic_cast<Identifier *>(named_expr->expression_); + 
ASSERT_EQ(identifier->name_, "n"); +} + +TEST_P(CypherMainVisitorTest, WithDistinct) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("WITH DISTINCT n AS m RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *with = dynamic_cast<With *>(single_query->clauses_[0]); + ASSERT_TRUE(with->body_.distinct); + ASSERT_FALSE(with->where_); + ASSERT_EQ(with->body_.named_expressions.size(), 1U); + auto *named_expr = with->body_.named_expressions[0]; + ASSERT_EQ(named_expr->name_, "m"); + auto *identifier = dynamic_cast<Identifier *>(named_expr->expression_); + ASSERT_EQ(identifier->name_, "n"); +} + +TEST_P(CypherMainVisitorTest, WithBag) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("WITH n as m ORDER BY m SKIP 1 LIMIT 2 RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *with = dynamic_cast<With *>(single_query->clauses_[0]); + ASSERT_FALSE(with->body_.distinct); + ASSERT_FALSE(with->where_); + ASSERT_EQ(with->body_.named_expressions.size(), 1U); + // No need to check contents of body. That is checked in RETURN clause tests. 
+ ASSERT_EQ(with->body_.order_by.size(), 1U); + ASSERT_TRUE(with->body_.limit); + ASSERT_TRUE(with->body_.skip); +} + +TEST_P(CypherMainVisitorTest, WithWhere) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("WITH n AS m WHERE k RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *with = dynamic_cast<With *>(single_query->clauses_[0]); + ASSERT_TRUE(with); + ASSERT_TRUE(with->where_); + auto *identifier = dynamic_cast<Identifier *>(with->where_->expression_); + ASSERT_TRUE(identifier); + ASSERT_EQ(identifier->name_, "k"); + ASSERT_EQ(with->body_.named_expressions.size(), 1U); + auto *named_expr = with->body_.named_expressions[0]; + ASSERT_EQ(named_expr->name_, "m"); + auto *identifier2 = dynamic_cast<Identifier *>(named_expr->expression_); + ASSERT_EQ(identifier2->name_, "n"); +} + +TEST_P(CypherMainVisitorTest, WithAnonymousVariableCapture) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("WITH 5 as anon1 MATCH () return *")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 3U); + auto *match = dynamic_cast<Match *>(single_query->clauses_[1]); + ASSERT_TRUE(match); + ASSERT_EQ(match->patterns_.size(), 1U); + auto *pattern = match->patterns_[0]; + ASSERT_TRUE(pattern); + ASSERT_EQ(pattern->atoms_.size(), 1U); + auto *atom = dynamic_cast<NodeAtom *>(pattern->atoms_[0]); + ASSERT_TRUE(atom); + ASSERT_NE("anon1", atom->identifier_->name_); +} + +TEST_P(CypherMainVisitorTest, ClausesOrdering) { + // Obviously some of the ridiculous combinations don't fail here, but they + // will fail in semantic analysis or they make perfect sense as a part of + // bigger query. 
+ auto &ast_generator = *GetParam(); + ast_generator.ParseQuery("RETURN 1"); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 1 RETURN 1"), SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 1 MATCH (n) RETURN n"), SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 1 DELETE n"), SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 1 MERGE (n)"), SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 1 WITH n AS m RETURN 1"), SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 1 AS n UNWIND n AS x RETURN x"), SemanticException); + + ASSERT_THROW(ast_generator.ParseQuery("OPTIONAL MATCH (n) MATCH (m) RETURN n, m"), SemanticException); + ast_generator.ParseQuery("OPTIONAL MATCH (n) WITH n MATCH (m) RETURN n, m"); + ast_generator.ParseQuery("OPTIONAL MATCH (n) OPTIONAL MATCH (m) RETURN n, m"); + ast_generator.ParseQuery("MATCH (n) OPTIONAL MATCH (m) RETURN n, m"); + + ast_generator.ParseQuery("CREATE (n)"); + ASSERT_THROW(ast_generator.ParseQuery("SET n:x MATCH (n) RETURN n"), SemanticException); + ast_generator.ParseQuery("REMOVE n.x SET n.x = 1"); + ast_generator.ParseQuery("REMOVE n:L RETURN n"); + ast_generator.ParseQuery("SET n.x = 1 WITH n AS m RETURN m"); + + ASSERT_THROW(ast_generator.ParseQuery("MATCH (n)"), SemanticException); + ast_generator.ParseQuery("MATCH (n) MATCH (n) RETURN n"); + ast_generator.ParseQuery("MATCH (n) SET n = m"); + ast_generator.ParseQuery("MATCH (n) RETURN n"); + ast_generator.ParseQuery("MATCH (n) WITH n AS m RETURN m"); + + ASSERT_THROW(ast_generator.ParseQuery("WITH 1 AS n"), SemanticException); + ast_generator.ParseQuery("WITH 1 AS n WITH n AS m RETURN m"); + ast_generator.ParseQuery("WITH 1 AS n RETURN n"); + ast_generator.ParseQuery("WITH 1 AS n SET n += m"); + ast_generator.ParseQuery("WITH 1 AS n MATCH (n) RETURN n"); + + ASSERT_THROW(ast_generator.ParseQuery("UNWIND [1,2,3] AS x"), SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("CREATE 
(n) UNWIND [1,2,3] AS x RETURN x"), SemanticException); + ast_generator.ParseQuery("UNWIND [1,2,3] AS x CREATE (n) RETURN x"); + ast_generator.ParseQuery("CREATE (n) WITH n UNWIND [1,2,3] AS x RETURN x"); +} + +TEST_P(CypherMainVisitorTest, Merge) { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MERGE (a) -[:r]- (b) ON MATCH SET a.x = b.x " + "ON CREATE SET b :label ON MATCH SET b = a")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *merge = dynamic_cast<Merge *>(single_query->clauses_[0]); + ASSERT_TRUE(merge); + EXPECT_TRUE(dynamic_cast<Pattern *>(merge->pattern_)); + ASSERT_EQ(merge->on_match_.size(), 2U); + EXPECT_TRUE(dynamic_cast<SetProperty *>(merge->on_match_[0])); + EXPECT_TRUE(dynamic_cast<SetProperties *>(merge->on_match_[1])); + ASSERT_EQ(merge->on_create_.size(), 1U); + EXPECT_TRUE(dynamic_cast<SetLabels *>(merge->on_create_[0])); +} + +TEST_P(CypherMainVisitorTest, Unwind) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("UNWIND [1,2,3] AS elem RETURN elem")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *unwind = dynamic_cast<Unwind *>(single_query->clauses_[0]); + ASSERT_TRUE(unwind); + auto *ret = dynamic_cast<Return *>(single_query->clauses_[1]); + EXPECT_TRUE(ret); + ASSERT_TRUE(unwind->named_expression_); + EXPECT_EQ(unwind->named_expression_->name_, "elem"); + auto *expr = unwind->named_expression_->expression_; + ASSERT_TRUE(expr); + ASSERT_TRUE(dynamic_cast<ListLiteral *>(expr)); +} + +TEST_P(CypherMainVisitorTest, UnwindWithoutAsError) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("UNWIND [1,2,3] RETURN 42"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, 
CreateIndex) { + auto &ast_generator = *GetParam(); + auto *index_query = dynamic_cast<IndexQuery *>(ast_generator.ParseQuery("Create InDeX oN :mirko(slavko)")); + ASSERT_TRUE(index_query); + EXPECT_EQ(index_query->action_, IndexQuery::Action::CREATE); + EXPECT_EQ(index_query->label_, ast_generator.Label("mirko")); + std::vector<PropertyIx> expected_properties{ast_generator.Prop("slavko")}; + EXPECT_EQ(index_query->properties_, expected_properties); +} + +TEST_P(CypherMainVisitorTest, DropIndex) { + auto &ast_generator = *GetParam(); + auto *index_query = dynamic_cast<IndexQuery *>(ast_generator.ParseQuery("dRoP InDeX oN :mirko(slavko)")); + ASSERT_TRUE(index_query); + EXPECT_EQ(index_query->action_, IndexQuery::Action::DROP); + EXPECT_EQ(index_query->label_, ast_generator.Label("mirko")); + std::vector<PropertyIx> expected_properties{ast_generator.Prop("slavko")}; + EXPECT_EQ(index_query->properties_, expected_properties); +} + +TEST_P(CypherMainVisitorTest, DropIndexWithoutProperties) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("dRoP InDeX oN :mirko()"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, DropIndexWithMultipleProperties) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("dRoP InDeX oN :mirko(slavko, pero)"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, ReturnAll) { + { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("RETURN all(x in [1,2,3])"), SyntaxException); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN all(x IN [1,2,3] WHERE x = 2)")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *ret = dynamic_cast<Return *>(single_query->clauses_[0]); + ASSERT_TRUE(ret); + ASSERT_EQ(ret->body_.named_expressions.size(), 1U); + auto *all = dynamic_cast<All 
*>(ret->body_.named_expressions[0]->expression_); + ASSERT_TRUE(all); + EXPECT_EQ(all->identifier_->name_, "x"); + auto *list_literal = dynamic_cast<ListLiteral *>(all->list_expression_); + EXPECT_TRUE(list_literal); + auto *eq = dynamic_cast<EqualOperator *>(all->where_->expression_); + EXPECT_TRUE(eq); + } +} + +TEST_P(CypherMainVisitorTest, ReturnSingle) { + { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("RETURN single(x in [1,2,3])"), SyntaxException); + } + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN single(x IN [1,2,3] WHERE x = 2)")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *ret = dynamic_cast<Return *>(single_query->clauses_[0]); + ASSERT_TRUE(ret); + ASSERT_EQ(ret->body_.named_expressions.size(), 1U); + auto *single = dynamic_cast<Single *>(ret->body_.named_expressions[0]->expression_); + ASSERT_TRUE(single); + EXPECT_EQ(single->identifier_->name_, "x"); + auto *list_literal = dynamic_cast<ListLiteral *>(single->list_expression_); + EXPECT_TRUE(list_literal); + auto *eq = dynamic_cast<EqualOperator *>(single->where_->expression_); + EXPECT_TRUE(eq); +} + +TEST_P(CypherMainVisitorTest, ReturnReduce) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN reduce(sum = 0, x IN [1,2,3] | sum + x)")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *ret = dynamic_cast<Return *>(single_query->clauses_[0]); + ASSERT_TRUE(ret); + ASSERT_EQ(ret->body_.named_expressions.size(), 1U); + auto *reduce = dynamic_cast<Reduce *>(ret->body_.named_expressions[0]->expression_); + ASSERT_TRUE(reduce); + EXPECT_EQ(reduce->accumulator_->name_, "sum"); + 
ast_generator.CheckLiteral(reduce->initializer_, 0); + EXPECT_EQ(reduce->identifier_->name_, "x"); + auto *list_literal = dynamic_cast<ListLiteral *>(reduce->list_); + EXPECT_TRUE(list_literal); + auto *add = dynamic_cast<AdditionOperator *>(reduce->expression_); + EXPECT_TRUE(add); +} + +TEST_P(CypherMainVisitorTest, ReturnExtract) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN extract(x IN [1,2,3] | sum + x)")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *ret = dynamic_cast<Return *>(single_query->clauses_[0]); + ASSERT_TRUE(ret); + ASSERT_EQ(ret->body_.named_expressions.size(), 1U); + auto *extract = dynamic_cast<Extract *>(ret->body_.named_expressions[0]->expression_); + ASSERT_TRUE(extract); + EXPECT_EQ(extract->identifier_->name_, "x"); + auto *list_literal = dynamic_cast<ListLiteral *>(extract->list_); + EXPECT_TRUE(list_literal); + auto *add = dynamic_cast<AdditionOperator *>(extract->expression_); + EXPECT_TRUE(add); +} + +TEST_P(CypherMainVisitorTest, MatchBfsReturn) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>( + ast_generator.ParseQuery("MATCH (n) -[r:type1|type2 *bfs..10 (e, n|e.prop = 42)]-> (m) RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *match = dynamic_cast<Match *>(single_query->clauses_[0]); + ASSERT_TRUE(match); + ASSERT_EQ(match->patterns_.size(), 1U); + ASSERT_EQ(match->patterns_[0]->atoms_.size(), 3U); + auto *bfs = dynamic_cast<EdgeAtom *>(match->patterns_[0]->atoms_[1]); + ASSERT_TRUE(bfs); + EXPECT_TRUE(bfs->IsVariable()); + EXPECT_EQ(bfs->direction_, EdgeAtom::Direction::OUT); + EXPECT_THAT(bfs->edge_types_, UnorderedElementsAre(ast_generator.EdgeType("type1"), 
ast_generator.EdgeType("type2"))); + EXPECT_EQ(bfs->identifier_->name_, "r"); + EXPECT_EQ(bfs->filter_lambda_.inner_edge->name_, "e"); + EXPECT_TRUE(bfs->filter_lambda_.inner_edge->user_declared_); + EXPECT_EQ(bfs->filter_lambda_.inner_node->name_, "n"); + EXPECT_TRUE(bfs->filter_lambda_.inner_node->user_declared_); + ast_generator.CheckLiteral(bfs->upper_bound_, 10); + auto *eq = dynamic_cast<EqualOperator *>(bfs->filter_lambda_.expression); + ASSERT_TRUE(eq); +} + +TEST_P(CypherMainVisitorTest, MatchVariableLambdaSymbols) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH () -[*]- () RETURN *")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *match = dynamic_cast<Match *>(single_query->clauses_[0]); + ASSERT_TRUE(match); + ASSERT_EQ(match->patterns_.size(), 1U); + ASSERT_EQ(match->patterns_[0]->atoms_.size(), 3U); + auto *var_expand = dynamic_cast<EdgeAtom *>(match->patterns_[0]->atoms_[1]); + ASSERT_TRUE(var_expand); + ASSERT_TRUE(var_expand->IsVariable()); + EXPECT_FALSE(var_expand->filter_lambda_.inner_edge->user_declared_); + EXPECT_FALSE(var_expand->filter_lambda_.inner_node->user_declared_); +} + +TEST_P(CypherMainVisitorTest, MatchWShortestReturn) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>( + ast_generator.ParseQuery("MATCH ()-[r:type1|type2 *wShortest 10 (we, wn | 42) total_weight " + "(e, n | true)]->() RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *match = dynamic_cast<Match *>(single_query->clauses_[0]); + ASSERT_TRUE(match); + ASSERT_EQ(match->patterns_.size(), 1U); + ASSERT_EQ(match->patterns_[0]->atoms_.size(), 3U); + auto *shortest = dynamic_cast<EdgeAtom *>(match->patterns_[0]->atoms_[1]); + 
ASSERT_TRUE(shortest); + EXPECT_TRUE(shortest->IsVariable()); + EXPECT_EQ(shortest->type_, EdgeAtom::Type::WEIGHTED_SHORTEST_PATH); + EXPECT_EQ(shortest->direction_, EdgeAtom::Direction::OUT); + EXPECT_THAT(shortest->edge_types_, + UnorderedElementsAre(ast_generator.EdgeType("type1"), ast_generator.EdgeType("type2"))); + ast_generator.CheckLiteral(shortest->upper_bound_, 10); + EXPECT_FALSE(shortest->lower_bound_); + EXPECT_EQ(shortest->identifier_->name_, "r"); + EXPECT_EQ(shortest->filter_lambda_.inner_edge->name_, "e"); + EXPECT_TRUE(shortest->filter_lambda_.inner_edge->user_declared_); + EXPECT_EQ(shortest->filter_lambda_.inner_node->name_, "n"); + EXPECT_TRUE(shortest->filter_lambda_.inner_node->user_declared_); + ast_generator.CheckLiteral(shortest->filter_lambda_.expression, true); + EXPECT_EQ(shortest->weight_lambda_.inner_edge->name_, "we"); + EXPECT_TRUE(shortest->weight_lambda_.inner_edge->user_declared_); + EXPECT_EQ(shortest->weight_lambda_.inner_node->name_, "wn"); + EXPECT_TRUE(shortest->weight_lambda_.inner_node->user_declared_); + ast_generator.CheckLiteral(shortest->weight_lambda_.expression, 42); + ASSERT_TRUE(shortest->total_weight_); + EXPECT_EQ(shortest->total_weight_->name_, "total_weight"); + EXPECT_TRUE(shortest->total_weight_->user_declared_); +} + +TEST_P(CypherMainVisitorTest, MatchWShortestNoFilterReturn) { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH ()-[r:type1|type2 *wShortest 10 (we, wn | 42)]->() " + "RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *match = dynamic_cast<Match *>(single_query->clauses_[0]); + ASSERT_TRUE(match); + ASSERT_EQ(match->patterns_.size(), 1U); + ASSERT_EQ(match->patterns_[0]->atoms_.size(), 3U); + auto *shortest = dynamic_cast<EdgeAtom *>(match->patterns_[0]->atoms_[1]); + ASSERT_TRUE(shortest); + 
EXPECT_TRUE(shortest->IsVariable()); + EXPECT_EQ(shortest->type_, EdgeAtom::Type::WEIGHTED_SHORTEST_PATH); + EXPECT_EQ(shortest->direction_, EdgeAtom::Direction::OUT); + EXPECT_THAT(shortest->edge_types_, + UnorderedElementsAre(ast_generator.EdgeType("type1"), ast_generator.EdgeType("type2"))); + ast_generator.CheckLiteral(shortest->upper_bound_, 10); + EXPECT_FALSE(shortest->lower_bound_); + EXPECT_EQ(shortest->identifier_->name_, "r"); + EXPECT_FALSE(shortest->filter_lambda_.expression); + EXPECT_FALSE(shortest->filter_lambda_.inner_edge->user_declared_); + EXPECT_FALSE(shortest->filter_lambda_.inner_node->user_declared_); + EXPECT_EQ(shortest->weight_lambda_.inner_edge->name_, "we"); + EXPECT_TRUE(shortest->weight_lambda_.inner_edge->user_declared_); + EXPECT_EQ(shortest->weight_lambda_.inner_node->name_, "wn"); + EXPECT_TRUE(shortest->weight_lambda_.inner_node->user_declared_); + ast_generator.CheckLiteral(shortest->weight_lambda_.expression, 42); + ASSERT_TRUE(shortest->total_weight_); + EXPECT_FALSE(shortest->total_weight_->user_declared_); +} + +TEST_P(CypherMainVisitorTest, SemanticExceptionOnWShortestLowerBound) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("MATCH ()-[r *wShortest 10.. 
(e, n | 42)]-() RETURN r"), SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("MATCH ()-[r *wShortest 10..20 (e, n | 42)]-() RETURN r"), SemanticException); +} + +TEST_P(CypherMainVisitorTest, SemanticExceptionOnWShortestWithoutLambda) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("MATCH ()-[r *wShortest]-() RETURN r"), SemanticException); +} + +TEST_P(CypherMainVisitorTest, SemanticExceptionOnUnionTypeMix) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 5 as X UNION ALL RETURN 6 AS X UNION RETURN 7 AS X"), + SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 5 as X UNION RETURN 6 AS X UNION ALL RETURN 7 AS X"), + SemanticException); +} + +TEST_P(CypherMainVisitorTest, Union) { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN 5 AS X, 6 AS Y UNION RETURN 6 AS X, 5 AS Y")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ASSERT_FALSE(return_clause->body_.all_identifiers); + ASSERT_EQ(return_clause->body_.order_by.size(), 0U); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 2U); + ASSERT_FALSE(return_clause->body_.limit); + ASSERT_FALSE(return_clause->body_.skip); + ASSERT_FALSE(return_clause->body_.distinct); + + ASSERT_EQ(query->cypher_unions_.size(), 1); + auto *cypher_union = query->cypher_unions_.at(0); + ASSERT_TRUE(cypher_union); + ASSERT_TRUE(cypher_union->distinct_); + ASSERT_TRUE(single_query = cypher_union->single_query_); + ASSERT_EQ(single_query->clauses_.size(), 1U); + return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ASSERT_FALSE(return_clause->body_.all_identifiers); + ASSERT_EQ(return_clause->body_.order_by.size(), 0U); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 
2U); + ASSERT_FALSE(return_clause->body_.limit); + ASSERT_FALSE(return_clause->body_.skip); + ASSERT_FALSE(return_clause->body_.distinct); +} + +TEST_P(CypherMainVisitorTest, UnionAll) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>( + ast_generator.ParseQuery("RETURN 5 AS X UNION ALL RETURN 6 AS X UNION ALL RETURN 7 AS X")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ASSERT_FALSE(return_clause->body_.all_identifiers); + ASSERT_EQ(return_clause->body_.order_by.size(), 0U); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 1U); + ASSERT_FALSE(return_clause->body_.limit); + ASSERT_FALSE(return_clause->body_.skip); + ASSERT_FALSE(return_clause->body_.distinct); + + ASSERT_EQ(query->cypher_unions_.size(), 2); + + auto *cypher_union = query->cypher_unions_.at(0); + ASSERT_TRUE(cypher_union); + ASSERT_FALSE(cypher_union->distinct_); + ASSERT_TRUE(single_query = cypher_union->single_query_); + ASSERT_EQ(single_query->clauses_.size(), 1U); + return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ASSERT_FALSE(return_clause->body_.all_identifiers); + ASSERT_EQ(return_clause->body_.order_by.size(), 0U); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 1U); + ASSERT_FALSE(return_clause->body_.limit); + ASSERT_FALSE(return_clause->body_.skip); + ASSERT_FALSE(return_clause->body_.distinct); + + cypher_union = query->cypher_unions_.at(1); + ASSERT_TRUE(cypher_union); + ASSERT_FALSE(cypher_union->distinct_); + ASSERT_TRUE(single_query = cypher_union->single_query_); + ASSERT_EQ(single_query->clauses_.size(), 1U); + return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + ASSERT_FALSE(return_clause->body_.all_identifiers); + ASSERT_EQ(return_clause->body_.order_by.size(), 0U); + 
ASSERT_EQ(return_clause->body_.named_expressions.size(), 1U); + ASSERT_FALSE(return_clause->body_.limit); + ASSERT_FALSE(return_clause->body_.skip); + ASSERT_FALSE(return_clause->body_.distinct); +} + +void check_auth_query(Base *ast_generator, std::string input, AuthQuery::Action action, std::string user, + std::string role, std::string user_or_role, std::optional<TypedValue> password, + std::vector<AuthQuery::Privilege> privileges) { + auto *auth_query = dynamic_cast<AuthQuery *>(ast_generator->ParseQuery(input)); + ASSERT_TRUE(auth_query); + EXPECT_EQ(auth_query->action_, action); + EXPECT_EQ(auth_query->user_, user); + EXPECT_EQ(auth_query->role_, role); + EXPECT_EQ(auth_query->user_or_role_, user_or_role); + ASSERT_EQ(static_cast<bool>(auth_query->password_), static_cast<bool>(password)); + if (password) { + ast_generator->CheckLiteral(auth_query->password_, *password); + } + EXPECT_EQ(auth_query->privileges_, privileges); +} + +TEST_P(CypherMainVisitorTest, UserOrRoleName) { + auto &ast_generator = *GetParam(); + check_auth_query(&ast_generator, "CREATE ROLE `user`", AuthQuery::Action::CREATE_ROLE, "", "user", "", {}, {}); + check_auth_query(&ast_generator, "CREATE ROLE us___er", AuthQuery::Action::CREATE_ROLE, "", "us___er", "", {}, {}); + check_auth_query(&ast_generator, "CREATE ROLE `us+er`", AuthQuery::Action::CREATE_ROLE, "", "us+er", "", {}, {}); + check_auth_query(&ast_generator, "CREATE ROLE `us|er`", AuthQuery::Action::CREATE_ROLE, "", "us|er", "", {}, {}); + check_auth_query(&ast_generator, "CREATE ROLE `us er`", AuthQuery::Action::CREATE_ROLE, "", "us er", "", {}, {}); +} + +TEST_P(CypherMainVisitorTest, CreateRole) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("CREATE ROLE"), SyntaxException); + check_auth_query(&ast_generator, "CREATE ROLE rola", AuthQuery::Action::CREATE_ROLE, "", "rola", "", {}, {}); + ASSERT_THROW(ast_generator.ParseQuery("CREATE ROLE lagano rolamo"), SyntaxException); +} + 
+TEST_P(CypherMainVisitorTest, DropRole) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("DROP ROLE"), SyntaxException); + check_auth_query(&ast_generator, "DROP ROLE rola", AuthQuery::Action::DROP_ROLE, "", "rola", "", {}, {}); + ASSERT_THROW(ast_generator.ParseQuery("DROP ROLE lagano rolamo"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, ShowRoles) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("SHOW ROLES ROLES"), SyntaxException); + check_auth_query(&ast_generator, "SHOW ROLES", AuthQuery::Action::SHOW_ROLES, "", "", "", {}, {}); +} + +TEST_P(CypherMainVisitorTest, CreateUser) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("CREATE USER"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CREATE USER 123"), SyntaxException); + check_auth_query(&ast_generator, "CREATE USER user", AuthQuery::Action::CREATE_USER, "user", "", "", {}, {}); + check_auth_query(&ast_generator, "CREATE USER user IDENTIFIED BY 'password'", AuthQuery::Action::CREATE_USER, "user", + "", "", TypedValue("password"), {}); + check_auth_query(&ast_generator, "CREATE USER user IDENTIFIED BY ''", AuthQuery::Action::CREATE_USER, "user", "", "", + TypedValue(""), {}); + check_auth_query(&ast_generator, "CREATE USER user IDENTIFIED BY null", AuthQuery::Action::CREATE_USER, "user", "", + "", TypedValue(), {}); + ASSERT_THROW(ast_generator.ParseQuery("CRATE USER user IDENTIFIED BY password"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CREATE USER user IDENTIFIED BY 5"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CREATE USER user IDENTIFIED BY "), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, SetPassword) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("SET PASSWORD FOR"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("SET PASSWORD FOR user "), SyntaxException); + check_auth_query(&ast_generator, "SET PASSWORD 
FOR user TO null", AuthQuery::Action::SET_PASSWORD, "user", "", "", + TypedValue(), {}); + check_auth_query(&ast_generator, "SET PASSWORD FOR user TO 'password'", AuthQuery::Action::SET_PASSWORD, "user", "", + "", TypedValue("password"), {}); + ASSERT_THROW(ast_generator.ParseQuery("SET PASSWORD FOR user To 5"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, DropUser) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("DROP USER"), SyntaxException); + check_auth_query(&ast_generator, "DROP USER user", AuthQuery::Action::DROP_USER, "user", "", "", {}, {}); + ASSERT_THROW(ast_generator.ParseQuery("DROP USER lagano rolamo"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, ShowUsers) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("SHOW USERS ROLES"), SyntaxException); + check_auth_query(&ast_generator, "SHOW USERS", AuthQuery::Action::SHOW_USERS, "", "", "", {}, {}); +} + +TEST_P(CypherMainVisitorTest, SetRole) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("SET ROLE"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("SET ROLE user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("SET ROLE FOR user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("SET ROLE FOR user TO"), SyntaxException); + check_auth_query(&ast_generator, "SET ROLE FOR user TO role", AuthQuery::Action::SET_ROLE, "user", "role", "", {}, + {}); + check_auth_query(&ast_generator, "SET ROLE FOR user TO null", AuthQuery::Action::SET_ROLE, "user", "null", "", {}, + {}); +} + +TEST_P(CypherMainVisitorTest, ClearRole) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("CLEAR ROLE"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CLEAR ROLE user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CLEAR ROLE FOR user TO"), SyntaxException); + check_auth_query(&ast_generator, "CLEAR ROLE FOR user", 
AuthQuery::Action::CLEAR_ROLE, "user", "", "", {}, {}); +} + +TEST_P(CypherMainVisitorTest, GrantPrivilege) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("GRANT"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("GRANT TO user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("GRANT BLABLA TO user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("GRANT MATCH, TO user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("GRANT MATCH, BLABLA TO user"), SyntaxException); + check_auth_query(&ast_generator, "GRANT MATCH TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MATCH}); + check_auth_query(&ast_generator, "GRANT MATCH, AUTH TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MATCH, AuthQuery::Privilege::AUTH}); + // Verify that all privileges are correctly visited. + check_auth_query(&ast_generator, "GRANT CREATE TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::CREATE}); + check_auth_query(&ast_generator, "GRANT DELETE TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::DELETE}); + check_auth_query(&ast_generator, "GRANT MERGE TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MERGE}); + check_auth_query(&ast_generator, "GRANT SET TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::SET}); + check_auth_query(&ast_generator, "GRANT REMOVE TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::REMOVE}); + check_auth_query(&ast_generator, "GRANT INDEX TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::INDEX}); + check_auth_query(&ast_generator, "GRANT STATS TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::STATS}); + 
check_auth_query(&ast_generator, "GRANT AUTH TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::AUTH}); + check_auth_query(&ast_generator, "GRANT CONSTRAINT TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::CONSTRAINT}); + check_auth_query(&ast_generator, "GRANT DUMP TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::DUMP}); + check_auth_query(&ast_generator, "GRANT REPLICATION TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::REPLICATION}); + check_auth_query(&ast_generator, "GRANT DURABILITY TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::DURABILITY}); + check_auth_query(&ast_generator, "GRANT READ_FILE TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::READ_FILE}); + check_auth_query(&ast_generator, "GRANT FREE_MEMORY TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::FREE_MEMORY}); + check_auth_query(&ast_generator, "GRANT TRIGGER TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::TRIGGER}); + check_auth_query(&ast_generator, "GRANT CONFIG TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::CONFIG}); + check_auth_query(&ast_generator, "GRANT STREAM TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::STREAM}); + check_auth_query(&ast_generator, "GRANT WEBSOCKET TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::WEBSOCKET}); + check_auth_query(&ast_generator, "GRANT MODULE_READ TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MODULE_READ}); + check_auth_query(&ast_generator, "GRANT MODULE_WRITE TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + 
{AuthQuery::Privilege::MODULE_WRITE}); + check_auth_query(&ast_generator, "GRANT SCHEMA TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::SCHEMA}); +} + +TEST_P(CypherMainVisitorTest, DenyPrivilege) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("DENY"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("DENY TO user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("DENY BLABLA TO user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("DENY MATCH, TO user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("DENY MATCH, BLABLA TO user"), SyntaxException); + check_auth_query(&ast_generator, "DENY MATCH TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MATCH}); + check_auth_query(&ast_generator, "DENY MATCH, AUTH TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MATCH, AuthQuery::Privilege::AUTH}); + // Verify that all privileges are correctly visited. 
+ check_auth_query(&ast_generator, "DENY CREATE TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::CREATE}); + check_auth_query(&ast_generator, "DENY DELETE TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::DELETE}); + check_auth_query(&ast_generator, "DENY MERGE TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MERGE}); + check_auth_query(&ast_generator, "DENY SET TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::SET}); + check_auth_query(&ast_generator, "DENY REMOVE TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::REMOVE}); + check_auth_query(&ast_generator, "DENY INDEX TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::INDEX}); + check_auth_query(&ast_generator, "DENY STATS TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::STATS}); + check_auth_query(&ast_generator, "DENY AUTH TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::AUTH}); + check_auth_query(&ast_generator, "DENY CONSTRAINT TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::CONSTRAINT}); + check_auth_query(&ast_generator, "DENY DUMP TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::DUMP}); + check_auth_query(&ast_generator, "DENY WEBSOCKET TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::WEBSOCKET}); + check_auth_query(&ast_generator, "DENY MODULE_READ TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MODULE_READ}); + check_auth_query(&ast_generator, "DENY MODULE_WRITE TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MODULE_WRITE}); + check_auth_query(&ast_generator, "DENY SCHEMA TO user", 
AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::SCHEMA}); +} + +TEST_P(CypherMainVisitorTest, RevokePrivilege) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("REVOKE"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("REVOKE FROM user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("REVOKE BLABLA FROM user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("REVOKE MATCH, FROM user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("REVOKE MATCH, BLABLA FROM user"), SyntaxException); + check_auth_query(&ast_generator, "REVOKE MATCH FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MATCH}); + check_auth_query(&ast_generator, "REVOKE MATCH, AUTH FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", + {}, {AuthQuery::Privilege::MATCH, AuthQuery::Privilege::AUTH}); + check_auth_query(&ast_generator, "REVOKE ALL PRIVILEGES FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", + "user", {}, kPrivilegesAll); + // Verify that all privileges are correctly visited. 
+ check_auth_query(&ast_generator, "REVOKE CREATE FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::CREATE}); + check_auth_query(&ast_generator, "REVOKE DELETE FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::DELETE}); + check_auth_query(&ast_generator, "REVOKE MERGE FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MERGE}); + check_auth_query(&ast_generator, "REVOKE SET FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::SET}); + check_auth_query(&ast_generator, "REVOKE REMOVE FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::REMOVE}); + check_auth_query(&ast_generator, "REVOKE INDEX FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::INDEX}); + check_auth_query(&ast_generator, "REVOKE STATS FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::STATS}); + check_auth_query(&ast_generator, "REVOKE AUTH FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::AUTH}); + check_auth_query(&ast_generator, "REVOKE CONSTRAINT FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", + {}, {AuthQuery::Privilege::CONSTRAINT}); + check_auth_query(&ast_generator, "REVOKE DUMP FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::DUMP}); + check_auth_query(&ast_generator, "REVOKE WEBSOCKET FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", + {}, {AuthQuery::Privilege::WEBSOCKET}); + check_auth_query(&ast_generator, "REVOKE MODULE_READ FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", + {}, {AuthQuery::Privilege::MODULE_READ}); + check_auth_query(&ast_generator, "REVOKE MODULE_WRITE FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", + {}, 
{AuthQuery::Privilege::MODULE_WRITE}); + check_auth_query(&ast_generator, "REVOKE SCHEMA FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::SCHEMA}); +} + +TEST_P(CypherMainVisitorTest, ShowPrivileges) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("SHOW PRIVILEGES FOR"), SyntaxException); + check_auth_query(&ast_generator, "SHOW PRIVILEGES FOR user", AuthQuery::Action::SHOW_PRIVILEGES, "", "", "user", {}, + {}); + ASSERT_THROW(ast_generator.ParseQuery("SHOW PRIVILEGES FOR user1, user2"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, ShowRoleForUser) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("SHOW ROLE FOR "), SyntaxException); + check_auth_query(&ast_generator, "SHOW ROLE FOR user", AuthQuery::Action::SHOW_ROLE_FOR_USER, "user", "", "", {}, {}); + ASSERT_THROW(ast_generator.ParseQuery("SHOW ROLE FOR user1, user2"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, ShowUsersForRole) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("SHOW USERS FOR "), SyntaxException); + check_auth_query(&ast_generator, "SHOW USERS FOR role", AuthQuery::Action::SHOW_USERS_FOR_ROLE, "", "role", "", {}, + {}); + ASSERT_THROW(ast_generator.ParseQuery("SHOW USERS FOR role1, role2"), SyntaxException); +} + +void check_replication_query(Base *ast_generator, const ReplicationQuery *query, const std::string name, + const std::optional<TypedValue> socket_address, const ReplicationQuery::SyncMode sync_mode, + const std::optional<TypedValue> port = {}) { + EXPECT_EQ(query->replica_name_, name); + EXPECT_EQ(query->sync_mode_, sync_mode); + ASSERT_EQ(static_cast<bool>(query->socket_address_), static_cast<bool>(socket_address)); + if (socket_address) { + ast_generator->CheckLiteral(query->socket_address_, *socket_address); + } + ASSERT_EQ(static_cast<bool>(query->port_), static_cast<bool>(port)); + if (port) { + 
ast_generator->CheckLiteral(query->port_, *port); + } +} + +TEST_P(CypherMainVisitorTest, TestShowReplicationMode) { + auto &ast_generator = *GetParam(); + const std::string raw_query = "SHOW REPLICATION ROLE"; + auto *parsed_query = dynamic_cast<ReplicationQuery *>(ast_generator.ParseQuery(raw_query)); + EXPECT_EQ(parsed_query->action_, ReplicationQuery::Action::SHOW_REPLICATION_ROLE); +} + +TEST_P(CypherMainVisitorTest, TestShowReplicasQuery) { + auto &ast_generator = *GetParam(); + const std::string raw_query = "SHOW REPLICAS"; + auto *parsed_query = dynamic_cast<ReplicationQuery *>(ast_generator.ParseQuery(raw_query)); + EXPECT_EQ(parsed_query->action_, ReplicationQuery::Action::SHOW_REPLICAS); +} + +TEST_P(CypherMainVisitorTest, TestSetReplicationMode) { + auto &ast_generator = *GetParam(); + + { + const std::string query = "SET REPLICATION ROLE"; + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = "SET REPLICATION ROLE TO BUTTERY"; + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = "SET REPLICATION ROLE TO MAIN"; + auto *parsed_query = dynamic_cast<ReplicationQuery *>(ast_generator.ParseQuery(query)); + EXPECT_EQ(parsed_query->action_, ReplicationQuery::Action::SET_REPLICATION_ROLE); + EXPECT_EQ(parsed_query->role_, ReplicationQuery::ReplicationRole::MAIN); + } + + { + const std::string query = "SET REPLICATION ROLE TO MAIN WITH PORT 10000"; + ASSERT_THROW(ast_generator.ParseQuery(query), SemanticException); + } + + { + const std::string query = "SET REPLICATION ROLE TO REPLICA WITH PORT 10000"; + auto *parsed_query = dynamic_cast<ReplicationQuery *>(ast_generator.ParseQuery(query)); + EXPECT_EQ(parsed_query->action_, ReplicationQuery::Action::SET_REPLICATION_ROLE); + EXPECT_EQ(parsed_query->role_, ReplicationQuery::ReplicationRole::REPLICA); + ast_generator.CheckLiteral(parsed_query->port_, TypedValue(10000)); + } +} + +TEST_P(CypherMainVisitorTest, 
TestRegisterReplicationQuery) { + auto &ast_generator = *GetParam(); + + const std::string faulty_query = "REGISTER REPLICA TO"; + ASSERT_THROW(ast_generator.ParseQuery(faulty_query), SyntaxException); + + const std::string faulty_query_with_timeout = R"(REGISTER REPLICA replica1 SYNC WITH TIMEOUT 1.0 TO "127.0.0.1")"; + ASSERT_THROW(ast_generator.ParseQuery(faulty_query_with_timeout), SyntaxException); + + const std::string correct_query = R"(REGISTER REPLICA replica1 SYNC TO "127.0.0.1")"; + auto *correct_query_parsed = dynamic_cast<ReplicationQuery *>(ast_generator.ParseQuery(correct_query)); + check_replication_query(&ast_generator, correct_query_parsed, "replica1", TypedValue("127.0.0.1"), + ReplicationQuery::SyncMode::SYNC); + + std::string full_query = R"(REGISTER REPLICA replica2 SYNC TO "1.1.1.1:10000")"; + auto *full_query_parsed = dynamic_cast<ReplicationQuery *>(ast_generator.ParseQuery(full_query)); + ASSERT_TRUE(full_query_parsed); + check_replication_query(&ast_generator, full_query_parsed, "replica2", TypedValue("1.1.1.1:10000"), + ReplicationQuery::SyncMode::SYNC); +} + +TEST_P(CypherMainVisitorTest, TestDeleteReplica) { + auto &ast_generator = *GetParam(); + + std::string missing_name_query = "DROP REPLICA"; + ASSERT_THROW(ast_generator.ParseQuery(missing_name_query), SyntaxException); + + std::string correct_query = "DROP REPLICA replica1"; + auto *correct_query_parsed = dynamic_cast<ReplicationQuery *>(ast_generator.ParseQuery(correct_query)); + ASSERT_TRUE(correct_query_parsed); + EXPECT_EQ(correct_query_parsed->replica_name_, "replica1"); +} + +TEST_P(CypherMainVisitorTest, TestExplainRegularQuery) { + auto &ast_generator = *GetParam(); + EXPECT_TRUE(dynamic_cast<ExplainQuery *>(ast_generator.ParseQuery("EXPLAIN RETURN n"))); +} + +TEST_P(CypherMainVisitorTest, TestExplainExplainQuery) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("EXPLAIN EXPLAIN RETURN n"), SyntaxException); +} + 
+TEST_P(CypherMainVisitorTest, TestExplainAuthQuery) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("EXPLAIN SHOW ROLES"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, TestProfileRegularQuery) { + { + auto &ast_generator = *GetParam(); + EXPECT_TRUE(dynamic_cast<ProfileQuery *>(ast_generator.ParseQuery("PROFILE RETURN n"))); + } +} + +TEST_P(CypherMainVisitorTest, TestProfileComplicatedQuery) { + { + auto &ast_generator = *GetParam(); + EXPECT_TRUE( + dynamic_cast<ProfileQuery *>(ast_generator.ParseQuery("profile optional match (n) where n.hello = 5 " + "return n union optional match (n) where n.there = 10 " + "return n"))); + } +} + +TEST_P(CypherMainVisitorTest, TestProfileProfileQuery) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("PROFILE PROFILE RETURN n"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, TestProfileAuthQuery) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("PROFILE SHOW ROLES"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, TestShowStorageInfo) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<InfoQuery *>(ast_generator.ParseQuery("SHOW STORAGE INFO")); + ASSERT_TRUE(query); + EXPECT_EQ(query->info_type_, InfoQuery::InfoType::STORAGE); +} + +TEST_P(CypherMainVisitorTest, TestShowIndexInfo) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<InfoQuery *>(ast_generator.ParseQuery("SHOW INDEX INFO")); + ASSERT_TRUE(query); + EXPECT_EQ(query->info_type_, InfoQuery::InfoType::INDEX); +} + +TEST_P(CypherMainVisitorTest, TestShowConstraintInfo) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<InfoQuery *>(ast_generator.ParseQuery("SHOW CONSTRAINT INFO")); + ASSERT_TRUE(query); + EXPECT_EQ(query->info_type_, InfoQuery::InfoType::CONSTRAINT); +} + +TEST_P(CypherMainVisitorTest, CreateConstraintSyntaxError) { + auto &ast_generator = *GetParam(); + 
EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (:label) ASSERT EXISTS"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT () ASSERT EXISTS"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON () ASSERT EXISTS(prop1)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON () ASSERT EXISTS (prop1, prop2)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT " + "EXISTS (n.prop1, missing.prop2)"), + SemanticException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT " + "EXISTS (m.prop1, m.prop2)"), + SemanticException); + + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (:label) ASSERT IS UNIQUE"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT () ASSERT IS UNIQUE"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON () ASSERT prop1 IS UNIQUE"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON () ASSERT prop1, prop2 IS UNIQUE"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT " + "n.prop1, missing.prop2 IS UNIQUE"), + SemanticException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT " + "m.prop1, m.prop2 IS UNIQUE"), + SemanticException); + + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (:label) ASSERT IS NODE KEY"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT () ASSERT IS NODE KEY"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON () ASSERT (prop1) IS NODE KEY"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON () ASSERT (prop1, prop2) IS NODE KEY"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT " + "(n.prop1, missing.prop2) IS NODE KEY"), + 
SemanticException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT " + "(m.prop1, m.prop2) IS NODE KEY"), + SemanticException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT " + "n.prop1, n.prop2 IS NODE KEY"), + SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT " + "exists(n.prop1, n.prop2) IS NODE KEY"), + SyntaxException); +} + +TEST_P(CypherMainVisitorTest, CreateConstraint) { + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<ConstraintQuery *>( + ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT EXISTS(n.prop1)")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::CREATE); + EXPECT_EQ(query->constraint_.type, Constraint::Type::EXISTS); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, UnorderedElementsAre(ast_generator.Prop("prop1"))); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<ConstraintQuery *>( + ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT EXISTS (n.prop1, n.prop2)")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::CREATE); + EXPECT_EQ(query->constraint_.type, Constraint::Type::EXISTS); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, + UnorderedElementsAre(ast_generator.Prop("prop1"), ast_generator.Prop("prop2"))); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<ConstraintQuery *>( + ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT n.prop1 IS UNIQUE")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::CREATE); + EXPECT_EQ(query->constraint_.type, Constraint::Type::UNIQUE); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, 
UnorderedElementsAre(ast_generator.Prop("prop1"))); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<ConstraintQuery *>( + ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT n.prop1, n.prop2 IS UNIQUE")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::CREATE); + EXPECT_EQ(query->constraint_.type, Constraint::Type::UNIQUE); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, + UnorderedElementsAre(ast_generator.Prop("prop1"), ast_generator.Prop("prop2"))); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<ConstraintQuery *>( + ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT (n.prop1) IS NODE KEY")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::CREATE); + EXPECT_EQ(query->constraint_.type, Constraint::Type::NODE_KEY); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, UnorderedElementsAre(ast_generator.Prop("prop1"))); + } + { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast<ConstraintQuery *>(ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT " + "(n.prop1, n.prop2) IS NODE KEY")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::CREATE); + EXPECT_EQ(query->constraint_.type, Constraint::Type::NODE_KEY); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, + UnorderedElementsAre(ast_generator.Prop("prop1"), ast_generator.Prop("prop2"))); + } +} + +TEST_P(CypherMainVisitorTest, DropConstraint) { + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<ConstraintQuery *>( + ast_generator.ParseQuery("DROP CONSTRAINT ON (n:label) ASSERT EXISTS(n.prop1)")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::DROP); + 
EXPECT_EQ(query->constraint_.type, Constraint::Type::EXISTS); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, UnorderedElementsAre(ast_generator.Prop("prop1"))); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<ConstraintQuery *>( + ast_generator.ParseQuery("DROP CONSTRAINT ON (n:label) ASSERT EXISTS(n.prop1, n.prop2)")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::DROP); + EXPECT_EQ(query->constraint_.type, Constraint::Type::EXISTS); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, + UnorderedElementsAre(ast_generator.Prop("prop1"), ast_generator.Prop("prop2"))); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<ConstraintQuery *>( + ast_generator.ParseQuery("DROP CONSTRAINT ON (n:label) ASSERT n.prop1 IS UNIQUE")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::DROP); + EXPECT_EQ(query->constraint_.type, Constraint::Type::UNIQUE); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, UnorderedElementsAre(ast_generator.Prop("prop1"))); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<ConstraintQuery *>( + ast_generator.ParseQuery("DROP CONSTRAINT ON (n:label) ASSERT n.prop1, n.prop2 IS UNIQUE")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::DROP); + EXPECT_EQ(query->constraint_.type, Constraint::Type::UNIQUE); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, + UnorderedElementsAre(ast_generator.Prop("prop1"), ast_generator.Prop("prop2"))); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<ConstraintQuery *>( + ast_generator.ParseQuery("DROP CONSTRAINT ON (n:label) ASSERT (n.prop1) IS NODE KEY")); + 
ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::DROP); + EXPECT_EQ(query->constraint_.type, Constraint::Type::NODE_KEY); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, UnorderedElementsAre(ast_generator.Prop("prop1"))); + } + { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast<ConstraintQuery *>(ast_generator.ParseQuery("DROP CONSTRAINT ON (n:label) ASSERT " + "(n.prop1, n.prop2) IS NODE KEY")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::DROP); + EXPECT_EQ(query->constraint_.type, Constraint::Type::NODE_KEY); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, + UnorderedElementsAre(ast_generator.Prop("prop1"), ast_generator.Prop("prop2"))); + } +} + +TEST_P(CypherMainVisitorTest, RegexMatch) { + { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH (n) WHERE n.name =~ \".*bla.*\" RETURN n.name")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *match_clause = dynamic_cast<Match *>(single_query->clauses_[0]); + ASSERT_TRUE(match_clause); + auto *regex_match = dynamic_cast<RegexMatch *>(match_clause->where_->expression_); + ASSERT_TRUE(regex_match); + ASSERT_TRUE(dynamic_cast<PropertyLookup *>(regex_match->string_expr_)); + ast_generator.CheckLiteral(regex_match->regex_, ".*bla.*"); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN \"text\" =~ \".*bla.*\"")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast<Return *>(single_query->clauses_[0]); + 
ASSERT_TRUE(return_clause); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 1U); + auto *named_expression = return_clause->body_.named_expressions[0]; + auto *regex_match = dynamic_cast<RegexMatch *>(named_expression->expression_); + ASSERT_TRUE(regex_match); + ast_generator.CheckLiteral(regex_match->string_expr_, "text"); + ast_generator.CheckLiteral(regex_match->regex_, ".*bla.*"); + } +} + +// NOLINTNEXTLINE(hicpp-special-member-functions) +TEST_P(CypherMainVisitorTest, DumpDatabase) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<DumpQuery *>(ast_generator.ParseQuery("DUMP DATABASE")); + ASSERT_TRUE(query); +} + +namespace { +template <class TAst> +void CheckCallProcedureDefaultMemoryLimit(const TAst &ast, const CallProcedure &call_proc) { + // Should be 100 MB + auto *literal = dynamic_cast<PrimitiveLiteral *>(call_proc.memory_limit_); + ASSERT_TRUE(literal); + TypedValue value(literal->value_); + ASSERT_TRUE(TypedValue::BoolEqual{}(value, TypedValue(100))); + ASSERT_EQ(call_proc.memory_scale_, 1024 * 1024); +} +} // namespace + +TEST_P(CypherMainVisitorTest, CallProcedureWithDotsInName) { + AddProc(*mock_module_with_dots_in_name, "proc", {}, {"res"}, ProcedureType::WRITE); + auto &ast_generator = *GetParam(); + + auto *query = + dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mock_module.with.dots.in.name.proc() YIELD res")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mock_module.with.dots.in.name.proc"); + ASSERT_TRUE(call_proc->arguments_.empty()); + std::vector<std::string> identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + ASSERT_TRUE(identifier->user_declared_); + 
identifier_names.push_back(identifier->name_); + } + std::vector<std::string> expected_names{"res"}; + ASSERT_EQ(identifier_names, expected_names); + ASSERT_EQ(identifier_names, call_proc->result_fields_); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithDashesInName) { + AddProc(*mock_module, "proc-with-dashes", {}, {"res"}, ProcedureType::READ); + auto &ast_generator = *GetParam(); + + auto *query = + dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL `mock_module.proc-with-dashes`() YIELD res")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mock_module.proc-with-dashes"); + ASSERT_TRUE(call_proc->arguments_.empty()); + std::vector<std::string> identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + ASSERT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + std::vector<std::string> expected_names{"res"}; + ASSERT_EQ(identifier_names, expected_names); + ASSERT_EQ(identifier_names, call_proc->result_fields_); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithYieldSomeFields) { + auto &ast_generator = *GetParam(); + auto check_proc = [this, &ast_generator](const ProcedureType type) { + const auto proc_name = std::string{"proc_"} + ToString(type); + SCOPED_TRACE(proc_name); + const auto fully_qualified_proc_name = std::string{"mock_module."} + proc_name; + AddProc(*mock_module, proc_name.c_str(), {}, {"fst", "field-with-dashes", "last_field"}, type); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery( + fmt::format("CALL 
{}() YIELD fst, `field-with-dashes`, last_field", fully_qualified_proc_name))); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->is_write_, type == ProcedureType::WRITE); + ASSERT_EQ(call_proc->procedure_name_, fully_qualified_proc_name); + ASSERT_TRUE(call_proc->arguments_.empty()); + ASSERT_EQ(call_proc->result_fields_.size(), 3U); + ASSERT_EQ(call_proc->result_identifiers_.size(), call_proc->result_fields_.size()); + std::vector<std::string> identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + ASSERT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + std::vector<std::string> expected_names{"fst", "field-with-dashes", "last_field"}; + ASSERT_EQ(identifier_names, expected_names); + ASSERT_EQ(identifier_names, call_proc->result_fields_); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); + }; + check_proc(ProcedureType::READ); + check_proc(ProcedureType::WRITE); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithYieldAliasedFields) { + AddProc(*mock_module, "proc", {}, {"fst", "snd", "thrd"}, ProcedureType::READ); + auto &ast_generator = *GetParam(); + + auto *query = + dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mock_module.proc() YIELD fst AS res1, snd AS " + "`result-with-dashes`, thrd AS last_result")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mock_module.proc"); + ASSERT_TRUE(call_proc->arguments_.empty()); + 
ASSERT_EQ(call_proc->result_fields_.size(), 3U); + ASSERT_EQ(call_proc->result_identifiers_.size(), call_proc->result_fields_.size()); + std::vector<std::string> identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + ASSERT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + std::vector<std::string> aliased_names{"res1", "result-with-dashes", "last_result"}; + ASSERT_EQ(identifier_names, aliased_names); + std::vector<std::string> field_names{"fst", "snd", "thrd"}; + ASSERT_EQ(call_proc->result_fields_, field_names); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithArguments) { + AddProc(*mock_module, "proc", {"arg1", "arg2", "arg3"}, {"res"}, ProcedureType::READ); + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mock_module.proc(0, 1, 2) YIELD res")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mock_module.proc"); + ASSERT_EQ(call_proc->arguments_.size(), 3U); + for (int64_t i = 0; i < 3; ++i) { + ast_generator.CheckLiteral(call_proc->arguments_[i], i); + } + std::vector<std::string> identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + ASSERT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + std::vector<std::string> expected_names{"res"}; + ASSERT_EQ(identifier_names, expected_names); + ASSERT_EQ(identifier_names, call_proc->result_fields_); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); +} + 
+TEST_P(CypherMainVisitorTest, CallProcedureYieldAsterisk) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mg.procedures() YIELD *")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mg.procedures"); + ASSERT_TRUE(call_proc->arguments_.empty()); + std::vector<std::string> identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + ASSERT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + ASSERT_THAT(identifier_names, UnorderedElementsAre("name", "signature", "is_write", "path", "is_editable")); + ASSERT_EQ(identifier_names, call_proc->result_fields_); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); +} + +TEST_P(CypherMainVisitorTest, CallProcedureYieldAsteriskReturnAsterisk) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mg.procedures() YIELD * RETURN *")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *ret = dynamic_cast<Return *>(single_query->clauses_[1]); + ASSERT_TRUE(ret); + ASSERT_TRUE(ret->body_.all_identifiers); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mg.procedures"); + ASSERT_TRUE(call_proc->arguments_.empty()); + std::vector<std::string> identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + 
ASSERT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + ASSERT_THAT(identifier_names, UnorderedElementsAre("name", "signature", "is_write", "path", "is_editable")); + ASSERT_EQ(identifier_names, call_proc->result_fields_); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithoutYield) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mg.load_all()")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mg.load_all"); + ASSERT_TRUE(call_proc->arguments_.empty()); + ASSERT_TRUE(call_proc->result_fields_.empty()); + ASSERT_TRUE(call_proc->result_identifiers_.empty()); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithMemoryLimitWithoutYield) { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mg.load_all() PROCEDURE MEMORY LIMIT 32 KB")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mg.load_all"); + ASSERT_TRUE(call_proc->arguments_.empty()); + ASSERT_TRUE(call_proc->result_fields_.empty()); + ASSERT_TRUE(call_proc->result_identifiers_.empty()); + ast_generator.CheckLiteral(call_proc->memory_limit_, 32); + ASSERT_EQ(call_proc->memory_scale_, 1024); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithMemoryUnlimitedWithoutYield) { + auto &ast_generator = *GetParam(); + auto *query 
= dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mg.load_all() PROCEDURE MEMORY UNLIMITED")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mg.load_all"); + ASSERT_TRUE(call_proc->arguments_.empty()); + ASSERT_TRUE(call_proc->result_fields_.empty()); + ASSERT_TRUE(call_proc->result_identifiers_.empty()); + ASSERT_FALSE(call_proc->memory_limit_); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithMemoryLimit) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>( + ast_generator.ParseQuery("CALL mg.load_all() PROCEDURE MEMORY LIMIT 32 MB YIELD res")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mg.load_all"); + ASSERT_TRUE(call_proc->arguments_.empty()); + std::vector<std::string> identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + ASSERT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + std::vector<std::string> expected_names{"res"}; + ASSERT_EQ(identifier_names, expected_names); + ASSERT_EQ(identifier_names, call_proc->result_fields_); + ast_generator.CheckLiteral(call_proc->memory_limit_, 32); + ASSERT_EQ(call_proc->memory_scale_, 1024 * 1024); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithMemoryUnlimited) { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mg.load_all() PROCEDURE MEMORY UNLIMITED YIELD res")); + 
ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mg.load_all"); + ASSERT_TRUE(call_proc->arguments_.empty()); + std::vector<std::string> identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + ASSERT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + std::vector<std::string> expected_names{"res"}; + ASSERT_EQ(identifier_names, expected_names); + ASSERT_EQ(identifier_names, call_proc->result_fields_); + ASSERT_FALSE(call_proc->memory_limit_); +} + +namespace { +template <typename TException = SyntaxException> +void TestInvalidQuery(const auto &query, Base &ast_generator) { + SCOPED_TRACE(query); + EXPECT_THROW(ast_generator.ParseQuery(query), TException) << query; +} + +template <typename TException = SyntaxException> +void TestInvalidQueryWithMessage(const auto &query, Base &ast_generator, const std::string_view message) { + bool exception_is_thrown = false; + try { + ast_generator.ParseQuery(query); + } catch (const TException &se) { + EXPECT_EQ(std::string_view{se.what()}, message); + exception_is_thrown = true; + } catch (...) 
{ + FAIL() << "Unexpected exception"; + } + EXPECT_TRUE(exception_is_thrown); +} + +void CheckParsedCallProcedure(const CypherQuery &query, Base &ast_generator, + const std::string_view fully_qualified_proc_name, + const std::vector<std::string_view> &args, const ProcedureType type, + const size_t clauses_size, const size_t call_procedure_index) { + ASSERT_NE(query.single_query_, nullptr); + auto *single_query = query.single_query_; + EXPECT_EQ(single_query->clauses_.size(), clauses_size); + ASSERT_FALSE(single_query->clauses_.empty()); + ASSERT_LT(call_procedure_index, clauses_size); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[call_procedure_index]); + ASSERT_NE(call_proc, nullptr); + EXPECT_EQ(call_proc->procedure_name_, fully_qualified_proc_name); + EXPECT_TRUE(call_proc->arguments_.empty()); + EXPECT_EQ(call_proc->result_fields_.size(), 2U); + EXPECT_EQ(call_proc->result_identifiers_.size(), call_proc->result_fields_.size()); + std::vector<std::string> identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + EXPECT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + std::vector<std::string> args_as_str{}; + std::transform(args.begin(), args.end(), std::back_inserter(args_as_str), + [](const std::string_view arg) { return std::string{arg}; }); + EXPECT_EQ(identifier_names, args_as_str); + EXPECT_EQ(identifier_names, call_proc->result_fields_); + ASSERT_EQ(call_proc->is_write_, type == ProcedureType::WRITE); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); +}; +} // namespace + +TEST_P(CypherMainVisitorTest, CallProcedureMultipleQueryPartsAfter) { + auto &ast_generator = *GetParam(); + static constexpr std::string_view fst{"fst"}; + static constexpr std::string_view snd{"snd"}; + const std::vector args{fst, snd}; + + const auto read_proc = CreateProcByType(ProcedureType::READ, args); + 
const auto write_proc = CreateProcByType(ProcedureType::WRITE, args); + + const auto check_parsed_call_proc = [&ast_generator, &args](const CypherQuery &query, + const std::string_view fully_qualified_proc_name, + const ProcedureType type, const size_t clause_size) { + CheckParsedCallProcedure(query, ast_generator, fully_qualified_proc_name, args, type, clause_size, 0); + }; + { + SCOPED_TRACE("Read query part"); + { + SCOPED_TRACE("With WITH"); + static constexpr std::string_view kQueryWithWith{"CALL {}() YIELD {},{} WITH {},{} UNWIND {} as u RETURN u"}; + static constexpr size_t kQueryParts{4}; + { + SCOPED_TRACE("Write proc"); + const auto query_str = fmt::format(kQueryWithWith, write_proc, fst, snd, fst, snd, fst); + TestInvalidQueryWithMessage<SemanticException>( + query_str, ast_generator, + "WITH can't be put after calling a writeable procedure, only RETURN clause can be put after."); + } + { + SCOPED_TRACE("Read proc"); + const auto query_str = fmt::format(kQueryWithWith, read_proc, fst, snd, fst, snd, fst); + const auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(query_str)); + ASSERT_NE(query, nullptr); + check_parsed_call_proc(*query, read_proc, ProcedureType::READ, kQueryParts); + } + } + { + SCOPED_TRACE("Without WITH"); + static constexpr std::string_view kQueryWithoutWith{"CALL {}() YIELD {},{} UNWIND {} as u RETURN u"}; + static constexpr size_t kQueryParts{3}; + { + SCOPED_TRACE("Write proc"); + const auto query_str = fmt::format(kQueryWithoutWith, write_proc, fst, snd, fst); + TestInvalidQueryWithMessage<SemanticException>( + query_str, ast_generator, + "UNWIND can't be put after calling a writeable procedure, only RETURN clause can be put after."); + } + { + SCOPED_TRACE("Read proc"); + const auto query_str = fmt::format(kQueryWithoutWith, read_proc, fst, snd, fst); + const auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(query_str)); + ASSERT_NE(query, nullptr); + check_parsed_call_proc(*query, read_proc, 
ProcedureType::READ, kQueryParts); + } + } + } + { + SCOPED_TRACE("Write query part"); + { + SCOPED_TRACE("With WITH"); + static constexpr std::string_view kQueryWithWith{ + "CALL {}() YIELD {},{} WITH {},{} CREATE(n {{prop : {}}}) RETURN n"}; + static constexpr size_t kQueryParts{4}; + { + SCOPED_TRACE("Write proc"); + const auto query_str = fmt::format(kQueryWithWith, write_proc, fst, snd, fst, snd, fst); + TestInvalidQueryWithMessage<SemanticException>( + query_str, ast_generator, + "WITH can't be put after calling a writeable procedure, only RETURN clause can be put after."); + } + { + SCOPED_TRACE("Read proc"); + const auto query_str = fmt::format(kQueryWithWith, read_proc, fst, snd, fst, snd, fst); + const auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(query_str)); + ASSERT_NE(query, nullptr); + check_parsed_call_proc(*query, read_proc, ProcedureType::READ, kQueryParts); + } + } + { + SCOPED_TRACE("Without WITH"); + static constexpr std::string_view kQueryWithoutWith{"CALL {}() YIELD {},{} CREATE(n {{prop : {}}}) RETURN n"}; + static constexpr size_t kQueryParts{3}; + { + SCOPED_TRACE("Write proc"); + const auto query_str = fmt::format(kQueryWithoutWith, write_proc, fst, snd, fst); + TestInvalidQueryWithMessage<SemanticException>( + query_str, ast_generator, + "Update clause can't be put after calling a writeable procedure, only RETURN clause can be put after."); + } + { + SCOPED_TRACE("Read proc"); + const auto query_str = fmt::format(kQueryWithoutWith, read_proc, fst, snd, fst, snd, fst); + const auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(query_str)); + ASSERT_NE(query, nullptr); + check_parsed_call_proc(*query, read_proc, ProcedureType::READ, kQueryParts); + } + } + } +} + +TEST_P(CypherMainVisitorTest, CallProcedureMultipleQueryPartsBefore) { + auto &ast_generator = *GetParam(); + static constexpr std::string_view fst{"fst"}; + static constexpr std::string_view snd{"snd"}; + const std::vector args{fst, snd}; + + 
const auto read_proc = CreateProcByType(ProcedureType::READ, args); + const auto write_proc = CreateProcByType(ProcedureType::WRITE, args); + + const auto check_parsed_call_proc = [&ast_generator, &args](const CypherQuery &query, + const std::string_view fully_qualified_proc_name, + const ProcedureType type, const size_t clause_size) { + CheckParsedCallProcedure(query, ast_generator, fully_qualified_proc_name, args, type, clause_size, clause_size - 2); + }; + { + SCOPED_TRACE("Read query part"); + static constexpr std::string_view kQueryWithReadQueryPart{"MATCH (n) CALL {}() YIELD * RETURN *"}; + static constexpr size_t kQueryParts{3}; + { + SCOPED_TRACE("Write proc"); + const auto query_str = fmt::format(kQueryWithReadQueryPart, write_proc); + const auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(query_str)); + ASSERT_NE(query, nullptr); + check_parsed_call_proc(*query, write_proc, ProcedureType::WRITE, kQueryParts); + } + { + SCOPED_TRACE("Read proc"); + const auto query_str = fmt::format(kQueryWithReadQueryPart, read_proc); + const auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(query_str)); + ASSERT_NE(query, nullptr); + check_parsed_call_proc(*query, read_proc, ProcedureType::READ, kQueryParts); + } + } + { + SCOPED_TRACE("Write query part"); + static constexpr std::string_view kQueryWithWriteQueryPart{"CREATE (n) WITH n CALL {}() YIELD * RETURN *"}; + static constexpr size_t kQueryParts{4}; + { + SCOPED_TRACE("Write proc"); + const auto query_str = fmt::format(kQueryWithWriteQueryPart, write_proc, fst, snd, fst); + TestInvalidQueryWithMessage<SemanticException>( + query_str, ast_generator, "Write procedures cannot be used in queries that contains any update clauses!"); + } + { + SCOPED_TRACE("Read proc"); + const auto query_str = fmt::format(kQueryWithWriteQueryPart, read_proc, fst, snd, fst, snd, fst); + const auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(query_str)); + ASSERT_NE(query, nullptr); + 
check_parsed_call_proc(*query, read_proc, ProcedureType::READ, kQueryParts); + } + } +} + +TEST_P(CypherMainVisitorTest, CallProcedureMultipleProcedures) { + auto &ast_generator = *GetParam(); + static constexpr std::string_view fst{"fst"}; + static constexpr std::string_view snd{"snd"}; + const std::vector args{fst, snd}; + + const auto read_proc = CreateProcByType(ProcedureType::READ, args); + const auto write_proc = CreateProcByType(ProcedureType::WRITE, args); + + { + SCOPED_TRACE("Read then write"); + const auto query_str = fmt::format("CALL {}() YIELD * CALL {}() YIELD * RETURN *", read_proc, write_proc); + static constexpr size_t kQueryParts{3}; + const auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(query_str)); + ASSERT_NE(query, nullptr); + + CheckParsedCallProcedure(*query, ast_generator, read_proc, args, ProcedureType::READ, kQueryParts, 0); + CheckParsedCallProcedure(*query, ast_generator, write_proc, args, ProcedureType::WRITE, kQueryParts, 1); + } + { + SCOPED_TRACE("Write then read"); + const auto query_str = fmt::format("CALL {}() YIELD * CALL {}() YIELD * RETURN *", write_proc, read_proc); + TestInvalidQueryWithMessage<SemanticException>( + query_str, ast_generator, + "CALL can't be put after calling a writeable procedure, only RETURN clause can be put after."); + } + { + SCOPED_TRACE("Write twice"); + const auto query_str = fmt::format("CALL {}() YIELD * CALL {}() YIELD * RETURN *", write_proc, write_proc); + TestInvalidQueryWithMessage<SemanticException>( + query_str, ast_generator, + "CALL can't be put after calling a writeable procedure, only RETURN clause can be put after."); + } +} + +TEST_P(CypherMainVisitorTest, IncorrectCallProcedure) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc-with-dashes()"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc() yield field-with-dashes"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc() yield 
field.with.dots"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc() yield res AS result-with-dashes"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc() yield res AS result.with.dots"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("WITH 42 AS x CALL not_standalone(x)"), SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("CALL procedure() YIELD"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 42, CALL procedure() YIELD"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 42, CALL procedure() YIELD res"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 42 AS x CALL procedure() YIELD res"), SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc.with.dots() MEMORY YIELD res"), SyntaxException); + // mg.procedures returns something, so it needs to have a YIELD. + ASSERT_THROW(ast_generator.ParseQuery("CALL mg.procedures()"), SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("CALL mg.procedures() PROCEDURE MEMORY UNLIMITED"), SemanticException); + // TODO: Implement support for the following syntax. These are defined in + // Neo4j and accepted in openCypher CIP. 
+ ASSERT_THROW(ast_generator.ParseQuery("CALL proc"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc RETURN 42"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc() YIELD res WHERE res > 42"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc() YIELD res WHERE res > 42 RETURN *"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, TestLockPathQuery) { + auto &ast_generator = *GetParam(); + + const auto test_lock_path_query = [&](const std::string_view command, const LockPathQuery::Action action) { + ASSERT_THROW(ast_generator.ParseQuery(command.data()), SyntaxException); + + { + const std::string query = fmt::format("{} ME", command); + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = fmt::format("{} DATA", command); + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = fmt::format("{} DATA STUFF", command); + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = fmt::format("{} DATA DIRECTORY", command); + auto *parsed_query = dynamic_cast<LockPathQuery *>(ast_generator.ParseQuery(query)); + ASSERT_TRUE(parsed_query); + EXPECT_EQ(parsed_query->action_, action); + } + }; + + test_lock_path_query("LOCK", LockPathQuery::Action::LOCK_PATH); + test_lock_path_query("UNLOCK", LockPathQuery::Action::UNLOCK_PATH); +} + +TEST_P(CypherMainVisitorTest, TestLoadCsvClause) { + auto &ast_generator = *GetParam(); + + { + const std::string query = R"(LOAD CSV FROM "file.csv")"; + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = R"(LOAD CSV FROM "file.csv" WITH)"; + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = R"(LOAD CSV FROM "file.csv" WITH HEADER)"; + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = R"(LOAD 
CSV FROM "file.csv" WITH HEADER DELIMITER ";")"; + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = R"(LOAD CSV FROM "file.csv" WITH HEADER DELIMITER ";" QUOTE "'")"; + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = R"(LOAD CSV FROM "file.csv" WITH HEADER DELIMITER ";" QUOTE "'" AS)"; + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = R"(LOAD CSV FROM file WITH HEADER IGNORE BAD DELIMITER ";" QUOTE "'" AS x)"; + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = R"(LOAD CSV FROM "file.csv" WITH HEADER IGNORE BAD DELIMITER 0 QUOTE "'" AS x)"; + ASSERT_THROW(ast_generator.ParseQuery(query), SemanticException); + } + + { + const std::string query = R"(LOAD CSV FROM "file.csv" WITH HEADER IGNORE BAD DELIMITER ";" QUOTE 0 AS x)"; + ASSERT_THROW(ast_generator.ParseQuery(query), SemanticException); + } + + { + // can't be a standalone clause + const std::string query = R"(LOAD CSV FROM "file.csv" WITH HEADER IGNORE BAD DELIMITER ";" QUOTE "'" AS x)"; + ASSERT_THROW(ast_generator.ParseQuery(query), SemanticException); + } + + { + const std::string query = + R"(LOAD CSV FROM "file.csv" WITH HEADER IGNORE BAD DELIMITER ";" QUOTE "'" AS x RETURN x)"; + auto *parsed_query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(query)); + ASSERT_TRUE(parsed_query); + auto *load_csv_clause = dynamic_cast<LoadCsv *>(parsed_query->single_query_->clauses_[0]); + ASSERT_TRUE(load_csv_clause); + ASSERT_TRUE(load_csv_clause->with_header_); + ASSERT_TRUE(load_csv_clause->ignore_bad_); + } +} + +TEST_P(CypherMainVisitorTest, MemoryLimit) { + auto &ast_generator = *GetParam(); + + ASSERT_THROW(ast_generator.ParseQuery("RETURN x QUE"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN x QUERY"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN x 
QUERY MEM"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN x QUERY MEMORY"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN x QUERY MEMORY LIM"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN x QUERY MEMORY LIMIT"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN x QUERY MEMORY LIMIT KB"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN x QUERY MEMORY LIMIT 12GB"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("QUERY MEMORY LIMIT 12KB RETURN x"), SyntaxException); + + { + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN x")); + ASSERT_TRUE(query); + ASSERT_FALSE(query->memory_limit_); + } + + { + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN x QUERY MEMORY LIMIT 12KB")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->memory_limit_); + ast_generator.CheckLiteral(query->memory_limit_, 12); + ASSERT_EQ(query->memory_scale_, 1024U); + } + + { + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("RETURN x QUERY MEMORY LIMIT 12MB")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->memory_limit_); + ast_generator.CheckLiteral(query->memory_limit_, 12); + ASSERT_EQ(query->memory_scale_, 1024U * 1024U); + } + + { + auto *query = dynamic_cast<CypherQuery *>( + ast_generator.ParseQuery("CALL mg.procedures() YIELD x RETURN x QUERY MEMORY LIMIT 12MB")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->memory_limit_); + ast_generator.CheckLiteral(query->memory_limit_, 12); + ASSERT_EQ(query->memory_scale_, 1024U * 1024U); + + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); + } + + { + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery( + "CALL mg.procedures() PROCEDURE MEMORY LIMIT 
3KB YIELD x RETURN x QUERY MEMORY LIMIT 12MB")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->memory_limit_); + ast_generator.CheckLiteral(query->memory_limit_, 12); + ASSERT_EQ(query->memory_scale_, 1024U * 1024U); + + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + ASSERT_TRUE(call_proc->memory_limit_); + ast_generator.CheckLiteral(call_proc->memory_limit_, 3); + ASSERT_EQ(call_proc->memory_scale_, 1024U); + } + + { + auto *query = dynamic_cast<CypherQuery *>( + ast_generator.ParseQuery("CALL mg.procedures() PROCEDURE MEMORY LIMIT 3KB YIELD x RETURN x")); + ASSERT_TRUE(query); + ASSERT_FALSE(query->memory_limit_); + + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + ASSERT_TRUE(call_proc->memory_limit_); + ast_generator.CheckLiteral(call_proc->memory_limit_, 3); + ASSERT_EQ(call_proc->memory_scale_, 1024U); + } + + { + auto *query = + dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mg.load_all() PROCEDURE MEMORY LIMIT 3KB")); + ASSERT_TRUE(query); + ASSERT_FALSE(query->memory_limit_); + + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + ASSERT_TRUE(call_proc->memory_limit_); + ast_generator.CheckLiteral(call_proc->memory_limit_, 3); + ASSERT_EQ(call_proc->memory_scale_, 1024U); + } + + { + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mg.load_all() QUERY MEMORY LIMIT 3KB")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->memory_limit_); + ast_generator.CheckLiteral(query->memory_limit_, 3); + ASSERT_EQ(query->memory_scale_, 1024U); + + 
ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); + } +} + +// DROP TRIGGER: the complete statement must parse into a TriggerQuery with the DROP_TRIGGER action and the name. +TEST_P(CypherMainVisitorTest, DropTrigger) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("DROP TR", ast_generator); + TestInvalidQuery("DROP TRIGGER", ast_generator); + + auto *parsed_query = dynamic_cast<TriggerQuery *>(ast_generator.ParseQuery("DROP TRIGGER trigger")); + // Fatal check: a query of another type must fail the test instead of being dereferenced as null below. + ASSERT_NE(parsed_query, nullptr); + EXPECT_EQ(parsed_query->action_, TriggerQuery::Action::DROP_TRIGGER); + EXPECT_EQ(parsed_query->trigger_name_, "trigger"); +} + +// SHOW TRIGGERS: the complete statement must parse into a TriggerQuery with the SHOW_TRIGGERS action. +TEST_P(CypherMainVisitorTest, ShowTriggers) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("SHOW TR", ast_generator); + TestInvalidQuery("SHOW TRIGGER", ast_generator); + + auto *parsed_query = dynamic_cast<TriggerQuery *>(ast_generator.ParseQuery("SHOW TRIGGERS")); + // Fatal check: a query of another type must fail the test instead of being dereferenced as null below. + ASSERT_NE(parsed_query, nullptr); + EXPECT_EQ(parsed_query->action_, TriggerQuery::Action::SHOW_TRIGGERS); +} + +namespace { +// Parses `query` and compares every field of the resulting CREATE TRIGGER query with the expected values. +// `before_commit_` is expected to be true exactly when `phase` spells "BEFORE". +void ValidateCreateQuery(Base &ast_generator, const auto &query, const auto &trigger_name, + const memgraph::query::v2::TriggerQuery::EventType event_type, const auto &phase, + const auto &statement) { + auto *parsed_query = dynamic_cast<TriggerQuery *>(ast_generator.ParseQuery(query)); + // Fatal check: a query of another type must fail the test instead of being dereferenced as null below. + ASSERT_NE(parsed_query, nullptr); + EXPECT_EQ(parsed_query->action_, TriggerQuery::Action::CREATE_TRIGGER); + EXPECT_EQ(parsed_query->trigger_name_, trigger_name); + EXPECT_EQ(parsed_query->event_type_, event_type); + // Compare by content: callers pass `phase` as a const char *, and `phase == "BEFORE"` would compare addresses. + EXPECT_EQ(parsed_query->before_commit_, std::string_view{phase} == "BEFORE"); + EXPECT_EQ(parsed_query->statement_, statement); +} +} // namespace + +TEST_P(CypherMainVisitorTest, CreateTriggers) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("CREATE TRIGGER", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON ", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON ()", 
ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON -->", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON CREATE", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON () CREATE", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON --> CREATE", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON DELETE", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON () DELETE", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON --> DELETE", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON UPDATE", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON () UPDATE", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON --> UPDATE", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON CREATE BEFORE", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON CREATE BEFORE COMMIT", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON CREATE AFTER", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON CREATE AFTER COMMIT", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON -> CREATE AFTER COMMIT EXECUTE a", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON ) CREATE AFTER COMMIT EXECUTE a", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON ( CREATE AFTER COMMIT EXECUTE a", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON CRETE AFTER COMMIT EXECUTE a", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON DELET AFTER COMMIT EXECUTE a", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON UPDTE AFTER COMMIT EXECUTE a", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON UPDATE COMMIT EXECUTE a", ast_generator); + + static constexpr std::string_view query_template = "CREATE TRIGGER trigger {} {} COMMIT EXECUTE {}"; + + static constexpr std::array events{ + std::pair{"", memgraph::query::v2::TriggerQuery::EventType::ANY}, + std::pair{"ON CREATE", 
memgraph::query::v2::TriggerQuery::EventType::CREATE}, + std::pair{"ON () CREATE", memgraph::query::v2::TriggerQuery::EventType::VERTEX_CREATE}, + std::pair{"ON --> CREATE", memgraph::query::v2::TriggerQuery::EventType::EDGE_CREATE}, + std::pair{"ON DELETE", memgraph::query::v2::TriggerQuery::EventType::DELETE}, + std::pair{"ON () DELETE", memgraph::query::v2::TriggerQuery::EventType::VERTEX_DELETE}, + std::pair{"ON --> DELETE", memgraph::query::v2::TriggerQuery::EventType::EDGE_DELETE}, + std::pair{"ON UPDATE", memgraph::query::v2::TriggerQuery::EventType::UPDATE}, + std::pair{"ON () UPDATE", memgraph::query::v2::TriggerQuery::EventType::VERTEX_UPDATE}, + std::pair{"ON --> UPDATE", memgraph::query::v2::TriggerQuery::EventType::EDGE_UPDATE}}; + + static constexpr std::array phases{"BEFORE", "AFTER"}; + + static constexpr std::array statements{ + "", "SOME SUPER\nSTATEMENT", "Statement with 12312321 3 ", " Statement with 12312321 3 " + + }; + + for (const auto &[event_string, event_type] : events) { + for (const auto &phase : phases) { + for (const auto &statement : statements) { + ValidateCreateQuery(ast_generator, fmt::format(query_template, event_string, phase, statement), "trigger", + event_type, phase, memgraph::utils::Trim(statement)); + } + } + } +} + +namespace { +// Parses `query` and checks the scope and the isolation level stored on the resulting IsolationLevelQuery. +void ValidateSetIsolationLevelQuery(Base &ast_generator, const auto &query, const auto scope, + const auto isolation_level) { + auto *parsed_query = dynamic_cast<IsolationLevelQuery *>(ast_generator.ParseQuery(query)); + // Fatal check: a query of another type must fail the test instead of being dereferenced as null below. + ASSERT_NE(parsed_query, nullptr); + EXPECT_EQ(parsed_query->isolation_level_scope_, scope); + EXPECT_EQ(parsed_query->isolation_level_, isolation_level); +} +} // namespace + +TEST_P(CypherMainVisitorTest, SetIsolationLevelQuery) { + auto &ast_generator = *GetParam(); + TestInvalidQuery("SET ISO", ast_generator); + TestInvalidQuery("SET TRANSACTION ISOLATION", ast_generator); + TestInvalidQuery("SET TRANSACTION ISOLATION LEVEL", ast_generator); + TestInvalidQuery("SET TRANSACTION ISOLATION LEVEL READ COMMITTED", 
ast_generator); + TestInvalidQuery("SET NEXT TRANSACTION ISOLATION LEVEL", ast_generator); + TestInvalidQuery("SET ISOLATION LEVEL READ COMMITTED", ast_generator); + TestInvalidQuery("SET GLOBAL ISOLATION LEVEL READ COMMITTED", ast_generator); + TestInvalidQuery("SET GLOBAL TRANSACTION ISOLATION LEVEL READ COMITTED", ast_generator); + TestInvalidQuery("SET GLOBAL TRANSACTION ISOLATION LEVEL READ_COMITTED", ast_generator); + TestInvalidQuery("SET SESSION TRANSACTION ISOLATION LEVEL READCOMITTED", ast_generator); + + static constexpr std::array scopes{ + std::pair{"GLOBAL", memgraph::query::v2::IsolationLevelQuery::IsolationLevelScope::GLOBAL}, + std::pair{"SESSION", memgraph::query::v2::IsolationLevelQuery::IsolationLevelScope::SESSION}, + std::pair{"NEXT", memgraph::query::v2::IsolationLevelQuery::IsolationLevelScope::NEXT}}; + static constexpr std::array isolation_levels{ + std::pair{"READ UNCOMMITTED", memgraph::query::v2::IsolationLevelQuery::IsolationLevel::READ_UNCOMMITTED}, + std::pair{"READ COMMITTED", memgraph::query::v2::IsolationLevelQuery::IsolationLevel::READ_COMMITTED}, + std::pair{"SNAPSHOT ISOLATION", memgraph::query::v2::IsolationLevelQuery::IsolationLevel::SNAPSHOT_ISOLATION}}; + + static constexpr const auto *query_template = "SET {} TRANSACTION ISOLATION LEVEL {}"; + + for (const auto &[scope_string, scope] : scopes) { + for (const auto &[isolation_level_string, isolation_level] : isolation_levels) { + ValidateSetIsolationLevelQuery(ast_generator, fmt::format(query_template, scope_string, isolation_level_string), + scope, isolation_level); + } + } +} + +TEST_P(CypherMainVisitorTest, CreateSnapshotQuery) { + auto &ast_generator = *GetParam(); + ASSERT_TRUE(dynamic_cast<CreateSnapshotQuery *>(ast_generator.ParseQuery("CREATE SNAPSHOT"))); +} + +void CheckOptionalExpression(Base &ast_generator, Expression *expression, const std::optional<TypedValue> &expected) { + EXPECT_EQ(expression != nullptr, expected.has_value()); + if (expected.has_value()) { 
+ EXPECT_NO_FATAL_FAILURE(ast_generator.CheckLiteral(expression, *expected)); + } +}; + +// Parses `query_string` and checks that it yields a StreamQuery with the given action and stream name whose other +// fields are empty or null; only `batch_limit_` and `timeout_` are compared against the optional expected literals. +void ValidateMostlyEmptyStreamQuery(Base &ast_generator, const std::string &query_string, + const StreamQuery::Action action, const std::string_view stream_name, + const std::optional<TypedValue> &batch_limit = std::nullopt, + const std::optional<TypedValue> &timeout = std::nullopt) { + auto *parsed_query = dynamic_cast<StreamQuery *>(ast_generator.ParseQuery(query_string)); + ASSERT_NE(parsed_query, nullptr); + EXPECT_EQ(parsed_query->action_, action); + EXPECT_EQ(parsed_query->stream_name_, stream_name); + auto topic_names = std::get_if<Expression *>(&parsed_query->topic_names_); + // Fatal on purpose: the next check dereferences the pointer, so a non-fatal EXPECT would crash on a mismatch. + ASSERT_NE(topic_names, nullptr); + EXPECT_EQ(*topic_names, nullptr); + EXPECT_TRUE(parsed_query->transform_name_.empty()); + EXPECT_TRUE(parsed_query->consumer_group_.empty()); + EXPECT_EQ(parsed_query->batch_interval_, nullptr); + EXPECT_EQ(parsed_query->batch_size_, nullptr); + EXPECT_EQ(parsed_query->service_url_, nullptr); + EXPECT_EQ(parsed_query->bootstrap_servers_, nullptr); + EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->batch_limit_, batch_limit)); + EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->timeout_, timeout)); + EXPECT_TRUE(parsed_query->configs_.empty()); + EXPECT_TRUE(parsed_query->credentials_.empty()); +} + +TEST_P(CypherMainVisitorTest, DropStream) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("DROP ST", ast_generator); + TestInvalidQuery("DROP STREAM", ast_generator); + TestInvalidQuery("DROP STREAMS", ast_generator); + + ValidateMostlyEmptyStreamQuery(ast_generator, "DrOP STREAm droppedStream", StreamQuery::Action::DROP_STREAM, + "droppedStream"); +} + +TEST_P(CypherMainVisitorTest, StartStream) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("START ST", ast_generator); + TestInvalidQuery("START STREAM", ast_generator); + 
TestInvalidQuery("START STREAMS", ast_generator); + + ValidateMostlyEmptyStreamQuery(ast_generator, "START STREAM startedStream", StreamQuery::Action::START_STREAM, + "startedStream"); +} + +TEST_P(CypherMainVisitorTest, StartAllStreams) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("START ALL", ast_generator); + TestInvalidQuery("START ALL STREAM", ast_generator); + TestInvalidQuery("START STREAMS ALL", ast_generator); + + ValidateMostlyEmptyStreamQuery(ast_generator, "StARt AlL StrEAMS", StreamQuery::Action::START_ALL_STREAMS, ""); +} + +TEST_P(CypherMainVisitorTest, ShowStreams) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("SHOW ALL", ast_generator); + TestInvalidQuery("SHOW STREAM", ast_generator); + TestInvalidQuery("SHOW STREAMS ALL", ast_generator); + + ValidateMostlyEmptyStreamQuery(ast_generator, "SHOW STREAMS", StreamQuery::Action::SHOW_STREAMS, ""); +} + +TEST_P(CypherMainVisitorTest, StopStream) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("STOP ST", ast_generator); + TestInvalidQuery("STOP STREAM", ast_generator); + TestInvalidQuery("STOP STREAMS", ast_generator); + TestInvalidQuery("STOP STREAM invalid stream name", ast_generator); + + ValidateMostlyEmptyStreamQuery(ast_generator, "STOP stREAM stoppedStream", StreamQuery::Action::STOP_STREAM, + "stoppedStream"); +} + +TEST_P(CypherMainVisitorTest, StopAllStreams) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("STOP ALL", ast_generator); + TestInvalidQuery("STOP ALL STREAM", ast_generator); + TestInvalidQuery("STOP STREAMS ALL", ast_generator); + + ValidateMostlyEmptyStreamQuery(ast_generator, "SToP ALL STReaMS", StreamQuery::Action::STOP_ALL_STREAMS, ""); +} + +void ValidateTopicNames(const auto &topic_names, const std::vector<std::string> &expected_topic_names, + Base &ast_generator) { + std::visit(memgraph::utils::Overloaded{ + [&](Expression *expression) { + ast_generator.CheckLiteral(expression, memgraph::utils::Join(expected_topic_names, 
",")); + }, + [&](const std::vector<std::string> &topic_names) { EXPECT_EQ(topic_names, expected_topic_names); }}, + topic_names); +} + +void ValidateCreateKafkaStreamQuery(Base &ast_generator, const std::string &query_string, + const std::string_view stream_name, const std::vector<std::string> &topic_names, + const std::string_view transform_name, const std::string_view consumer_group, + const std::optional<TypedValue> &batch_interval, + const std::optional<TypedValue> &batch_size, + const std::string_view bootstrap_servers, + const std::unordered_map<std::string, std::string> &configs, + const std::unordered_map<std::string, std::string> &credentials) { + SCOPED_TRACE(query_string); + StreamQuery *parsed_query{nullptr}; + ASSERT_NO_THROW(parsed_query = dynamic_cast<StreamQuery *>(ast_generator.ParseQuery(query_string))) << query_string; + ASSERT_NE(parsed_query, nullptr); + EXPECT_EQ(parsed_query->stream_name_, stream_name); + ValidateTopicNames(parsed_query->topic_names_, topic_names, ast_generator); + EXPECT_EQ(parsed_query->transform_name_, transform_name); + EXPECT_EQ(parsed_query->consumer_group_, consumer_group); + EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->batch_interval_, batch_interval)); + EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->batch_size_, batch_size)); + EXPECT_EQ(parsed_query->batch_limit_, nullptr); + if (bootstrap_servers.empty()) { + EXPECT_EQ(parsed_query->bootstrap_servers_, nullptr); + } else { + EXPECT_NE(parsed_query->bootstrap_servers_, nullptr); + } + + const auto evaluate_config_map = [&ast_generator](const std::unordered_map<Expression *, Expression *> &config_map) { + std::unordered_map<std::string, std::string> evaluated_config_map; + const auto expr_to_str = [&ast_generator](Expression *expression) { + return std::string{ast_generator.GetLiteral(expression, ast_generator.context_.is_query_cached).ValueString()}; + }; + std::transform(config_map.begin(), 
config_map.end(), + std::inserter(evaluated_config_map, evaluated_config_map.end()), + [&expr_to_str](const auto expr_pair) { + return std::pair{expr_to_str(expr_pair.first), expr_to_str(expr_pair.second)}; + }); + return evaluated_config_map; + }; + + using testing::UnorderedElementsAreArray; + EXPECT_THAT(evaluate_config_map(parsed_query->configs_), UnorderedElementsAreArray(configs.begin(), configs.end())); + EXPECT_THAT(evaluate_config_map(parsed_query->credentials_), + UnorderedElementsAreArray(credentials.begin(), credentials.end())); +} + +TEST_P(CypherMainVisitorTest, CreateKafkaStream) { + auto &ast_generator = *GetParam(); + TestInvalidQuery("CREATE KAFKA STREAM", ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM invalid stream name TOPICS topic1 TRANSFORM transform", ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS invalid topic name TRANSFORM transform", ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM invalid transformation name", ast_generator); + // required configs are missing + TestInvalidQuery<SemanticException>("CREATE KAFKA STREAM stream TRANSFORM transform", ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS TRANSFORM transform", ast_generator); + // required configs are missing + TestInvalidQuery<SemanticException>("CREATE KAFKA STREAM stream TOPICS topic1", ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM", ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform CONSUMER_GROUP", ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform CONSUMER_GROUP invalid consumer group", + ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform BATCH_INTERVAL", ast_generator); + TestInvalidQuery<SemanticException>( + "CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform BATCH_INTERVAL 'invalid interval'", 
ast_generator); + TestInvalidQuery<SemanticException>("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform TOPICS topic2", + ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform BATCH_SIZE", ast_generator); + TestInvalidQuery<SemanticException>( + "CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform BATCH_SIZE 'invalid size'", ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1, TRANSFORM transform BATCH_SIZE 2 CONSUMER_GROUP Gru", + ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform BOOTSTRAP_SERVERS localhost:9092", + ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform BOOTSTRAP_SERVERS", ast_generator); + // the keys must be string literals + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform CONFIGS { symbolicname : 'string' }", + ast_generator); + TestInvalidQuery( + "CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform CREDENTIALS { symbolicname : 'string' }", + ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform CREDENTIALS 2", ast_generator); + + const std::vector<std::string> topic_names{"topic1_name.with_dot", "topic1_name.with_multiple.dots", + "topic-name.with-multiple.dots-and-dashes"}; + + static constexpr std::string_view kStreamName{"SomeSuperStream"}; + static constexpr std::string_view kTransformName{"moreAwesomeTransform"}; + + auto check_topic_names = [&](const std::vector<std::string> &topic_names) { + static constexpr std::string_view kConsumerGroup{"ConsumerGru"}; + static constexpr int kBatchInterval = 324; + const TypedValue batch_interval_value{kBatchInterval}; + static constexpr int kBatchSize = 1; + const TypedValue batch_size_value{kBatchSize}; + + const auto topic_names_as_str = memgraph::utils::Join(topic_names, ","); + + ValidateCreateKafkaStreamQuery( + ast_generator, + 
fmt::format("CREATE KAFKA STREAM {} TOPICS {} TRANSFORM {}", kStreamName, topic_names_as_str, kTransformName), + kStreamName, topic_names, kTransformName, "", std::nullopt, std::nullopt, {}, {}, {}); + + ValidateCreateKafkaStreamQuery(ast_generator, + fmt::format("CREATE KAFKA STREAM {} TOPICS {} TRANSFORM {} CONSUMER_GROUP {} ", + kStreamName, topic_names_as_str, kTransformName, kConsumerGroup), + kStreamName, topic_names, kTransformName, kConsumerGroup, std::nullopt, std::nullopt, + {}, {}, {}); + + ValidateCreateKafkaStreamQuery(ast_generator, + fmt::format("CREATE KAFKA STREAM {} TRANSFORM {} TOPICS {} BATCH_INTERVAL {}", + kStreamName, kTransformName, topic_names_as_str, kBatchInterval), + kStreamName, topic_names, kTransformName, "", batch_interval_value, std::nullopt, {}, + {}, {}); + + ValidateCreateKafkaStreamQuery(ast_generator, + fmt::format("CREATE KAFKA STREAM {} BATCH_SIZE {} TOPICS {} TRANSFORM {}", + kStreamName, kBatchSize, topic_names_as_str, kTransformName), + kStreamName, topic_names, kTransformName, "", std::nullopt, batch_size_value, {}, {}, + {}); + + ValidateCreateKafkaStreamQuery(ast_generator, + fmt::format("CREATE KAFKA STREAM {} TOPICS '{}' BATCH_SIZE {} TRANSFORM {}", + kStreamName, topic_names_as_str, kBatchSize, kTransformName), + kStreamName, topic_names, kTransformName, "", std::nullopt, batch_size_value, {}, {}, + {}); + + ValidateCreateKafkaStreamQuery( + ast_generator, + fmt::format("CREATE KAFKA STREAM {} TOPICS {} TRANSFORM {} CONSUMER_GROUP {} BATCH_INTERVAL {} BATCH_SIZE {}", + kStreamName, topic_names_as_str, kTransformName, kConsumerGroup, kBatchInterval, kBatchSize), + kStreamName, topic_names, kTransformName, kConsumerGroup, batch_interval_value, batch_size_value, {}, {}, {}); + using namespace std::string_literals; + const auto host1 = "localhost:9094"s; + ValidateCreateKafkaStreamQuery( + ast_generator, + fmt::format("CREATE KAFKA STREAM {} TOPICS {} CONSUMER_GROUP {} BATCH_SIZE {} BATCH_INTERVAL {} TRANSFORM {} " + 
"BOOTSTRAP_SERVERS '{}'", + kStreamName, topic_names_as_str, kConsumerGroup, kBatchSize, kBatchInterval, kTransformName, host1), + kStreamName, topic_names, kTransformName, kConsumerGroup, batch_interval_value, batch_size_value, host1, {}, + {}); + + ValidateCreateKafkaStreamQuery( + ast_generator, + fmt::format("CREATE KAFKA STREAM {} CONSUMER_GROUP {} TOPICS {} BATCH_INTERVAL {} TRANSFORM {} BATCH_SIZE {} " + "BOOTSTRAP_SERVERS '{}'", + kStreamName, kConsumerGroup, topic_names_as_str, kBatchInterval, kTransformName, kBatchSize, host1), + kStreamName, topic_names, kTransformName, kConsumerGroup, batch_interval_value, batch_size_value, host1, {}, + {}); + + const auto host2 = "localhost:9094,localhost:1994,168.1.1.256:345"s; + ValidateCreateKafkaStreamQuery( + ast_generator, + fmt::format("CREATE KAFKA STREAM {} TOPICS {} BOOTSTRAP_SERVERS '{}' CONSUMER_GROUP {} TRANSFORM {} " + "BATCH_INTERVAL {} BATCH_SIZE {}", + kStreamName, topic_names_as_str, host2, kConsumerGroup, kTransformName, kBatchInterval, kBatchSize), + kStreamName, topic_names, kTransformName, kConsumerGroup, batch_interval_value, batch_size_value, host2, {}, + {}); + }; + + for (const auto &topic_name : topic_names) { + EXPECT_NO_FATAL_FAILURE(check_topic_names({topic_name})); + } + + EXPECT_NO_FATAL_FAILURE(check_topic_names(topic_names)); + + auto check_consumer_group = [&](const std::string_view consumer_group) { + const std::string kTopicName{"topic1"}; + ValidateCreateKafkaStreamQuery(ast_generator, + fmt::format("CREATE KAFKA STREAM {} TOPICS {} TRANSFORM {} CONSUMER_GROUP {}", + kStreamName, kTopicName, kTransformName, consumer_group), + kStreamName, {kTopicName}, kTransformName, consumer_group, std::nullopt, + std::nullopt, {}, {}, {}); + }; + + using namespace std::literals; + static constexpr std::array consumer_groups{"consumergru"sv, "consumer-group-with-dash"sv, + "consumer_group.with.dot"sv, "consumer-group.With-Dot-and.dash"sv}; + + for (const auto consumer_group : consumer_groups) { + 
EXPECT_NO_FATAL_FAILURE(check_consumer_group(consumer_group)); + } + + // Builds CREATE KAFKA STREAM queries that pass `config_map` once as CONFIGS and once as CREDENTIALS, and checks + // that the parsed query reports the same key/value pairs. + auto check_config_map = [&](const std::unordered_map<std::string, std::string> &config_map) { + const std::string kTopicName{"topic1"}; + + // Renders the map as a map literal with quoted keys and values, e.g. {'k1': 'v1', 'k2': 'v2'}. + const auto map_as_str = std::invoke([&config_map] { + std::stringstream buffer; + buffer << '{'; + if (!config_map.empty()) { + auto it = config_map.begin(); + buffer << fmt::format("'{}': '{}'", it->first, it->second); + // Advance past the entry written above; looping from begin() would emit the first pair twice. + for (++it; it != config_map.end(); ++it) { + buffer << fmt::format(", '{}': '{}'", it->first, it->second); + } + } + buffer << '}'; + return std::move(buffer).str(); + }); + + ValidateCreateKafkaStreamQuery(ast_generator, + fmt::format("CREATE KAFKA STREAM {} TOPICS {} TRANSFORM {} CONFIGS {}", kStreamName, + kTopicName, kTransformName, map_as_str), + kStreamName, {kTopicName}, kTransformName, "", std::nullopt, std::nullopt, {}, + config_map, {}); + + ValidateCreateKafkaStreamQuery(ast_generator, + fmt::format("CREATE KAFKA STREAM {} TOPICS {} TRANSFORM {} CREDENTIALS {}", + kStreamName, kTopicName, kTransformName, map_as_str), + kStreamName, {kTopicName}, kTransformName, "", std::nullopt, std::nullopt, {}, {}, + config_map); + }; + + const std::array config_maps = {std::unordered_map<std::string, std::string>{}, + std::unordered_map<std::string, std::string>{{"key", "value"}}, + std::unordered_map<std::string, std::string>{{"key.with.dot", "value.with.doth"}, + {"key with space", "value with space"}}}; + for (const auto &map_to_test : config_maps) { + EXPECT_NO_FATAL_FAILURE(check_config_map(map_to_test)); + } +} + +void ValidateCreatePulsarStreamQuery(Base &ast_generator, const std::string &query_string, + const std::string_view stream_name, const std::vector<std::string> &topic_names, + const std::string_view transform_name, + const std::optional<TypedValue> &batch_interval, + const std::optional<TypedValue> &batch_size, const std::string_view service_url) { + SCOPED_TRACE(query_string); + + StreamQuery *parsed_query{nullptr}; + 
ASSERT_NO_THROW(parsed_query = dynamic_cast<StreamQuery *>(ast_generator.ParseQuery(query_string))) << query_string; + ASSERT_NE(parsed_query, nullptr); + EXPECT_EQ(parsed_query->stream_name_, stream_name); + ValidateTopicNames(parsed_query->topic_names_, topic_names, ast_generator); + EXPECT_EQ(parsed_query->transform_name_, transform_name); + EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->batch_interval_, batch_interval)); + EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->batch_size_, batch_size)); + EXPECT_EQ(parsed_query->batch_limit_, nullptr); + if (service_url.empty()) { + EXPECT_EQ(parsed_query->service_url_, nullptr); + return; + } + EXPECT_NE(parsed_query->service_url_, nullptr); +} + +TEST_P(CypherMainVisitorTest, CreatePulsarStream) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("CREATE PULSAR STREAM", ast_generator); + TestInvalidQuery<SemanticException>("CREATE PULSAR STREAM stream", ast_generator); + TestInvalidQuery("CREATE PULSAR STREAM stream TOPICS", ast_generator); + TestInvalidQuery<SemanticException>("CREATE PULSAR STREAM stream TOPICS topic_name", ast_generator); + TestInvalidQuery("CREATE PULSAR STREAM stream TOPICS topic_name TRANSFORM", ast_generator); + TestInvalidQuery("CREATE PULSAR STREAM stream TOPICS topic_name TRANSFORM transform.name SERVICE_URL", ast_generator); + TestInvalidQuery<SemanticException>( + "CREATE PULSAR STREAM stream TOPICS topic_name TRANSFORM transform.name SERVICE_URL 1", ast_generator); + TestInvalidQuery( + "CREATE PULSAR STREAM stream TOPICS topic_name TRANSFORM transform.name BOOTSTRAP_SERVERS 'bootstrap'", + ast_generator); + TestInvalidQuery<SemanticException>( + "CREATE PULSAR STREAM stream TOPICS topic_name TRANSFORM transform.name SERVICE_URL 'test' TOPICS topic_name", + ast_generator); + TestInvalidQuery<SemanticException>( + "CREATE PULSAR STREAM stream TRANSFORM transform.name TOPICS topic_name TRANSFORM transform.name SERVICE_URL 
" + "'test'", + ast_generator); + TestInvalidQuery<SemanticException>( + "CREATE PULSAR STREAM stream BATCH_INTERVAL 1 TOPICS topic_name TRANSFORM transform.name SERVICE_URL 'test' " + "BATCH_INTERVAL 1000", + ast_generator); + TestInvalidQuery<SemanticException>( + "CREATE PULSAR STREAM stream BATCH_INTERVAL 'a' TOPICS topic_name TRANSFORM transform.name SERVICE_URL 'test'", + ast_generator); + TestInvalidQuery<SemanticException>( + "CREATE PULSAR STREAM stream BATCH_SIZE 'a' TOPICS topic_name TRANSFORM transform.name SERVICE_URL 'test'", + ast_generator); + + const std::vector<std::string> topic_names{"topic1", "topic2"}; + const std::string topic_names_str = memgraph::utils::Join(topic_names, ","); + static constexpr std::string_view kStreamName{"PulsarStream"}; + static constexpr std::string_view kTransformName{"boringTransformation"}; + static constexpr std::string_view kServiceUrl{"localhost"}; + static constexpr int kBatchSize{1000}; + static constexpr int kBatchInterval{231321}; + + { + SCOPED_TRACE("single topic"); + ValidateCreatePulsarStreamQuery( + ast_generator, + fmt::format("CREATE PULSAR STREAM {} TOPICS {} TRANSFORM {}", kStreamName, topic_names[0], kTransformName), + kStreamName, {topic_names[0]}, kTransformName, std::nullopt, std::nullopt, ""); + } + { + SCOPED_TRACE("multiple topics"); + ValidateCreatePulsarStreamQuery( + ast_generator, + fmt::format("CREATE PULSAR STREAM {} TRANSFORM {} TOPICS {}", kStreamName, kTransformName, topic_names_str), + kStreamName, topic_names, kTransformName, std::nullopt, std::nullopt, ""); + } + { + SCOPED_TRACE("topic name in string"); + ValidateCreatePulsarStreamQuery( + ast_generator, + fmt::format("CREATE PULSAR STREAM {} TRANSFORM {} TOPICS '{}'", kStreamName, kTransformName, topic_names_str), + kStreamName, topic_names, kTransformName, std::nullopt, std::nullopt, ""); + } + { + SCOPED_TRACE("service url"); + ValidateCreatePulsarStreamQuery(ast_generator, + fmt::format("CREATE PULSAR STREAM {} SERVICE_URL 
'{}' TRANSFORM {} TOPICS {}", + kStreamName, kServiceUrl, kTransformName, topic_names_str), + kStreamName, topic_names, kTransformName, std::nullopt, std::nullopt, kServiceUrl); + ValidateCreatePulsarStreamQuery(ast_generator, + fmt::format("CREATE PULSAR STREAM {} TRANSFORM {} SERVICE_URL '{}' TOPICS {}", + kStreamName, kTransformName, kServiceUrl, topic_names_str), + kStreamName, topic_names, kTransformName, std::nullopt, std::nullopt, kServiceUrl); + } + { + SCOPED_TRACE("batch size"); + ValidateCreatePulsarStreamQuery( + ast_generator, + fmt::format("CREATE PULSAR STREAM {} SERVICE_URL '{}' BATCH_SIZE {} TRANSFORM {} TOPICS {}", kStreamName, + kServiceUrl, kBatchSize, kTransformName, topic_names_str), + kStreamName, topic_names, kTransformName, std::nullopt, TypedValue(kBatchSize), kServiceUrl); + ValidateCreatePulsarStreamQuery( + ast_generator, + fmt::format("CREATE PULSAR STREAM {} TRANSFORM {} SERVICE_URL '{}' TOPICS {} BATCH_SIZE {}", kStreamName, + kTransformName, kServiceUrl, topic_names_str, kBatchSize), + kStreamName, topic_names, kTransformName, std::nullopt, TypedValue(kBatchSize), kServiceUrl); + } + { + SCOPED_TRACE("batch interval"); + ValidateCreatePulsarStreamQuery( + ast_generator, + fmt::format("CREATE PULSAR STREAM {} BATCH_INTERVAL {} SERVICE_URL '{}' BATCH_SIZE {} TRANSFORM {} TOPICS {}", + kStreamName, kBatchInterval, kServiceUrl, kBatchSize, kTransformName, topic_names_str), + kStreamName, topic_names, kTransformName, TypedValue(kBatchInterval), TypedValue(kBatchSize), kServiceUrl); + ValidateCreatePulsarStreamQuery( + ast_generator, + fmt::format("CREATE PULSAR STREAM {} TRANSFORM {} SERVICE_URL '{}' BATCH_INTERVAL {} TOPICS {} BATCH_SIZE {}", + kStreamName, kTransformName, kServiceUrl, kBatchInterval, topic_names_str, kBatchSize), + kStreamName, topic_names, kTransformName, TypedValue(kBatchInterval), TypedValue(kBatchSize), kServiceUrl); + } +} + +TEST_P(CypherMainVisitorTest, CheckStream) { + auto &ast_generator = *GetParam(); + + 
TestInvalidQuery("CHECK STREAM", ast_generator); + TestInvalidQuery("CHECK STREAMS", ast_generator); + TestInvalidQuery("CHECK STREAMS something", ast_generator); + TestInvalidQuery("CHECK STREAM something,something", ast_generator); + TestInvalidQuery("CHECK STREAM something BATCH LIMIT 1", ast_generator); + TestInvalidQuery("CHECK STREAM something BATCH_LIMIT", ast_generator); + TestInvalidQuery("CHECK STREAM something TIMEOUT", ast_generator); + TestInvalidQuery("CHECK STREAM something BATCH_LIMIT 1 TIMEOUT", ast_generator); + TestInvalidQuery<SemanticException>("CHECK STREAM something BATCH_LIMIT 'it should be an integer'", ast_generator); + TestInvalidQuery<SemanticException>("CHECK STREAM something BATCH_LIMIT 2.5", ast_generator); + TestInvalidQuery<SemanticException>("CHECK STREAM something TIMEOUT 'it should be an integer'", ast_generator); + + ValidateMostlyEmptyStreamQuery(ast_generator, "CHECK STREAM checkedStream", StreamQuery::Action::CHECK_STREAM, + "checkedStream"); + ValidateMostlyEmptyStreamQuery(ast_generator, "CHECK STREAM checkedStream bAtCH_LIMIT 42", + StreamQuery::Action::CHECK_STREAM, "checkedStream", TypedValue(42)); + ValidateMostlyEmptyStreamQuery(ast_generator, "CHECK STREAM checkedStream TimEOuT 666", + StreamQuery::Action::CHECK_STREAM, "checkedStream", std::nullopt, TypedValue(666)); + ValidateMostlyEmptyStreamQuery(ast_generator, "CHECK STREAM checkedStream BATCH_LIMIT 30 TIMEOUT 444", + StreamQuery::Action::CHECK_STREAM, "checkedStream", TypedValue(30), TypedValue(444)); +} + +TEST_P(CypherMainVisitorTest, SettingQuery) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("SHOW DB SETTINGS", ast_generator); + TestInvalidQuery("SHOW SETTINGS", ast_generator); + TestInvalidQuery("SHOW DATABASE SETTING", ast_generator); + TestInvalidQuery("SHOW DB SETTING 'setting'", ast_generator); + TestInvalidQuery("SHOW SETTING 'setting'", ast_generator); + TestInvalidQuery<SemanticException>("SHOW DATABASE SETTING 1", ast_generator); + 
TestInvalidQuery("SET SETTING 'setting' TO 'value'", ast_generator); + TestInvalidQuery("SET DB SETTING 'setting' TO 'value'", ast_generator); + TestInvalidQuery<SemanticException>("SET DATABASE SETTING 1 TO 'value'", ast_generator); + TestInvalidQuery<SemanticException>("SET DATABASE SETTING 'setting' TO 2", ast_generator); + + const auto validate_setting_query = [&](const auto &query, const auto action, + const std::optional<TypedValue> &expected_setting_name, + const std::optional<TypedValue> &expected_setting_value) { + auto *parsed_query = dynamic_cast<SettingQuery *>(ast_generator.ParseQuery(query)); + EXPECT_EQ(parsed_query->action_, action) << query; + EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->setting_name_, expected_setting_name)); + EXPECT_NO_FATAL_FAILURE( + CheckOptionalExpression(ast_generator, parsed_query->setting_value_, expected_setting_value)); + }; + + validate_setting_query("SHOW DATABASE SETTINGS", SettingQuery::Action::SHOW_ALL_SETTINGS, std::nullopt, std::nullopt); + validate_setting_query("SHOW DATABASE SETTING 'setting'", SettingQuery::Action::SHOW_SETTING, TypedValue{"setting"}, + std::nullopt); + validate_setting_query("SET DATABASE SETTING 'setting' TO 'value'", SettingQuery::Action::SET_SETTING, + TypedValue{"setting"}, TypedValue{"value"}); +} + +TEST_P(CypherMainVisitorTest, VersionQuery) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("SHOW VERION", ast_generator); + TestInvalidQuery("SHOW VER", ast_generator); + TestInvalidQuery("SHOW VERSIONS", ast_generator); + ASSERT_NO_THROW(ast_generator.ParseQuery("SHOW VERSION")); +} + +TEST_P(CypherMainVisitorTest, ForeachThrow) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("FOREACH(i IN [1, 2] | UNWIND [1,2,3] AS j CREATE (n))"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("FOREACH(i IN [1, 2] CREATE (:Foo {prop : i}))"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("FOREACH(i IN [1, 
2] | MATCH (n)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("FOREACH(i IN x | MATCH (n)"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, Foreach) { + auto &ast_generator = *GetParam(); + // CREATE + { + auto *query = dynamic_cast<CypherQuery *>( + ast_generator.ParseQuery("FOREACH (age IN [1, 2, 3] | CREATE (m:Age {amount: age}))")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *foreach = dynamic_cast<Foreach *>(single_query->clauses_[0]); + ASSERT_TRUE(foreach); + ASSERT_TRUE(foreach->named_expression_); + EXPECT_EQ(foreach->named_expression_->name_, "age"); + auto *expr = foreach->named_expression_->expression_; + ASSERT_TRUE(expr); + ASSERT_TRUE(dynamic_cast<ListLiteral *>(expr)); + const auto &clauses = foreach->clauses_; + ASSERT_TRUE(clauses.size() == 1); + ASSERT_TRUE(dynamic_cast<Create *>(clauses.front())); + } + // SET + { + auto *query = + dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("FOREACH (i IN nodes(path) | SET i.checkpoint = true)")); + auto *foreach = dynamic_cast<Foreach *>(query->single_query_->clauses_[0]); + const auto &clauses = foreach->clauses_; + ASSERT_TRUE(clauses.size() == 1); + ASSERT_TRUE(dynamic_cast<SetProperty *>(clauses.front())); + } + // REMOVE + { + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("FOREACH (i IN nodes(path) | REMOVE i.prop)")); + auto *foreach = dynamic_cast<Foreach *>(query->single_query_->clauses_[0]); + const auto &clauses = foreach->clauses_; + ASSERT_TRUE(clauses.size() == 1); + ASSERT_TRUE(dynamic_cast<RemoveProperty *>(clauses.front())); + } + // MERGE + { + // merge works as create here + auto *query = + dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("FOREACH (i IN [1, 2, 3] | MERGE (n {no : i}))")); + auto *foreach = dynamic_cast<Foreach *>(query->single_query_->clauses_[0]); + const auto &clauses = foreach->clauses_; + 
ASSERT_TRUE(clauses.size() == 1); + ASSERT_TRUE(dynamic_cast<Merge *>(clauses.front())); + } + // CYPHER DELETE + { + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("FOREACH (i IN nodes(path) | DETACH DELETE i)")); + auto *foreach = dynamic_cast<Foreach *>(query->single_query_->clauses_[0]); + const auto &clauses = foreach->clauses_; + ASSERT_TRUE(clauses.size() == 1); + ASSERT_TRUE(dynamic_cast<Delete *>(clauses.front())); + } + // nested FOREACH + { + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery( + "FOREACH (i IN nodes(path) | FOREACH (age IN i.list | CREATE (m:Age {amount: age})))")); + + auto *foreach = dynamic_cast<Foreach *>(query->single_query_->clauses_[0]); + const auto &clauses = foreach->clauses_; + ASSERT_TRUE(clauses.size() == 1); + ASSERT_TRUE(dynamic_cast<Foreach *>(clauses.front())); + } + // Multiple update clauses + { + auto *query = dynamic_cast<CypherQuery *>( + ast_generator.ParseQuery("FOREACH (i IN nodes(path) | SET i.checkpoint = true REMOVE i.prop)")); + auto *foreach = dynamic_cast<Foreach *>(query->single_query_->clauses_[0]); + const auto &clauses = foreach->clauses_; + ASSERT_TRUE(clauses.size() == 2); + ASSERT_TRUE(dynamic_cast<SetProperty *>(clauses.front())); + ASSERT_TRUE(dynamic_cast<RemoveProperty *>(*++clauses.begin())); + } +} + +TEST_P(CypherMainVisitorTest, TestShowSchemas) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<SchemaQuery *>(ast_generator.ParseQuery("SHOW SCHEMAS")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::SHOW_SCHEMAS); +} + +TEST_P(CypherMainVisitorTest, TestShowSchema) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("SHOW SCHEMA ON label"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("SHOW SCHEMA :label"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("SHOW SCHEMA label"), SyntaxException); + + auto *query = dynamic_cast<SchemaQuery *>(ast_generator.ParseQuery("SHOW 
SCHEMA ON :label")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::SHOW_SCHEMA); + EXPECT_EQ(query->label_, ast_generator.Label("label")); +} + +void AssertSchemaPropertyMap(auto &schema_property_map, + std::vector<std::pair<std::string, memgraph::common::SchemaType>> properties_type, + auto &ast_generator) { + ASSERT_EQ(schema_property_map.size(), properties_type.size()); // fatal: the loop indexes both containers + for (size_t i{0}; i < schema_property_map.size(); ++i) { + // Assert PropertyId + EXPECT_EQ(schema_property_map[i].first, ast_generator.Prop(properties_type[i].first)); + // Assert Property Type + EXPECT_EQ(schema_property_map[i].second, properties_type[i].second); + } +} + +TEST_P(CypherMainVisitorTest, TestCreateSchema) { + { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label()"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(123 INTEGER)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(name TYPE)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(name, age)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(name, DURATION)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON label(name INTEGER)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(name INTEGER, name INTEGER)"), SemanticException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(name INTEGER, name STRING)"), SemanticException); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<SchemaQuery *>(ast_generator.ParseQuery("CREATE SCHEMA ON :label1(name STRING)")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::CREATE_SCHEMA); + EXPECT_EQ(query->label_, ast_generator.Label("label1")); +
AssertSchemaPropertyMap(query->schema_type_map_, {{"name", memgraph::common::SchemaType::STRING}}, ast_generator); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<SchemaQuery *>(ast_generator.ParseQuery("CREATE SCHEMA ON :label2(name string)")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::CREATE_SCHEMA); + EXPECT_EQ(query->label_, ast_generator.Label("label2")); + AssertSchemaPropertyMap(query->schema_type_map_, {{"name", memgraph::common::SchemaType::STRING}}, ast_generator); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<SchemaQuery *>( + ast_generator.ParseQuery("CREATE SCHEMA ON :label3(first_name STRING, last_name STRING)")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::CREATE_SCHEMA); + EXPECT_EQ(query->label_, ast_generator.Label("label3")); + AssertSchemaPropertyMap( + query->schema_type_map_, + {{"first_name", memgraph::common::SchemaType::STRING}, {"last_name", memgraph::common::SchemaType::STRING}}, + ast_generator); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<SchemaQuery *>( + ast_generator.ParseQuery("CREATE SCHEMA ON :label4(name STRING, age INTEGER, dur DURATION, birthday1 " + "LOCALDATETIME, birthday2 DATE, some_time LOCALTIME, speaks_truth BOOL)")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::CREATE_SCHEMA); + EXPECT_EQ(query->label_, ast_generator.Label("label4")); + AssertSchemaPropertyMap(query->schema_type_map_, + { + {"name", memgraph::common::SchemaType::STRING}, + {"age", memgraph::common::SchemaType::INT}, + {"dur", memgraph::common::SchemaType::DURATION}, + {"birthday1", memgraph::common::SchemaType::LOCALDATETIME}, + {"birthday2", memgraph::common::SchemaType::DATE}, + {"some_time", memgraph::common::SchemaType::LOCALTIME}, + {"speaks_truth", memgraph::common::SchemaType::BOOL}, + }, + ast_generator); + } +} + +TEST_P(CypherMainVisitorTest, TestDropSchema) { + auto 
&ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("DROP SCHEMA"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("DROP SCHEMA ON label"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("DROP SCHEMA :label"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("DROP SCHEMA ON :label()"), SyntaxException); + + auto *query = dynamic_cast<SchemaQuery *>(ast_generator.ParseQuery("DROP SCHEMA ON :label")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::DROP_SCHEMA); + EXPECT_EQ(query->label_, ast_generator.Label("label")); +} diff --git a/tests/unit/query_v2_interpreter.cpp b/tests/unit/query_v2_interpreter.cpp new file mode 100644 index 000000000..b73cbeb5a --- /dev/null +++ b/tests/unit/query_v2_interpreter.cpp @@ -0,0 +1,1645 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#include <algorithm> +#include <cstddef> +#include <cstdlib> +#include <filesystem> +#include <unordered_set> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "communication/bolt/v1/value.hpp" +#include "glue/v2/communication.hpp" +#include "query/v2/auth_checker.hpp" +#include "query/v2/config.hpp" +#include "query/v2/exceptions.hpp" +#include "query/v2/interpreter.hpp" +#include "query/v2/stream.hpp" +#include "query/v2/typed_value.hpp" +#include "query_v2_query_common.hpp" +#include "result_stream_faker.hpp" +#include "storage/v3/isolation_level.hpp" +#include "storage/v3/property_value.hpp" +#include "utils/csv_parsing.hpp" +#include "utils/logging.hpp" + +namespace { + +auto ToEdgeList(const memgraph::communication::bolt::Value &v) { // Collects ValueEdge() of every list element. + std::vector<memgraph::communication::bolt::Edge> list; + for (const auto &x : v.ValueList()) { // bound by reference: no per-element Value copy + list.push_back(x.ValueEdge()); + } + return list; +} + +auto StringToUnorderedSet(const std::string &element) { // Splits on ", " and returns the distinct pieces. + const auto element_split = memgraph::utils::Split(element, ", "); + return std::unordered_set<std::string>(element_split.begin(), element_split.end()); +} + +struct InterpreterFaker { + InterpreterFaker(memgraph::storage::v3::Storage *db, const memgraph::query::v2::InterpreterConfig config, + const std::filesystem::path &data_directory) + : interpreter_context(db, config, data_directory), interpreter(&interpreter_context) { + interpreter_context.auth_checker = &auth_checker; + } + + auto Prepare(const std::string &query, + const std::map<std::string, memgraph::storage::v3::PropertyValue> &params = {}) { + ResultStreamFaker stream(interpreter_context.db); + + const auto [header, _, qid] = interpreter.Prepare(query, params, nullptr); + stream.Header(header); + return std::make_pair(std::move(stream), qid); + } + + void Pull(ResultStreamFaker *stream, std::optional<int> n = {}, std::optional<int> qid = {}) { + const auto summary = interpreter.Pull(stream, n, qid); + stream->Summary(summary); + } + + /** + * Execute
the given query and commit the transaction. + * + * Return the query stream. + */ + auto Interpret(const std::string &query, + const std::map<std::string, memgraph::storage::v3::PropertyValue> &params = {}) { + auto prepare_result = Prepare(query, params); + + auto &stream = prepare_result.first; + auto summary = interpreter.Pull(&stream, {}, prepare_result.second); + stream.Summary(summary); + + return std::move(stream); + } + + memgraph::query::v2::AllowEverythingAuthChecker auth_checker; + memgraph::query::v2::InterpreterContext interpreter_context; + memgraph::query::v2::Interpreter interpreter; +}; + +} // namespace + +// TODO: This is not a unit test, but tests/integration dir is chaotic at the +// moment. After tests refactoring is done, move/rename this. + +class InterpreterTest : public ::testing::Test { + protected: + memgraph::storage::v3::Storage db_; + std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_query_v2_interpreter"}; + + InterpreterFaker default_interpreter{&db_, {}, data_directory}; + + auto Prepare(const std::string &query, + const std::map<std::string, memgraph::storage::v3::PropertyValue> &params = {}) { + return default_interpreter.Prepare(query, params); + } + + void Pull(ResultStreamFaker *stream, std::optional<int> n = {}, std::optional<int> qid = {}) { + default_interpreter.Pull(stream, n, qid); + } + + auto Interpret(const std::string &query, + const std::map<std::string, memgraph::storage::v3::PropertyValue> &params = {}) { + return default_interpreter.Interpret(query, params); + } +}; + +TEST_F(InterpreterTest, MultiplePulls) { + { + auto [stream, qid] = Prepare("UNWIND [1,2,3,4,5] as n RETURN n"); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "n"); + Pull(&stream, 1); + ASSERT_EQ(stream.GetSummary().count("has_more"), 1); + ASSERT_TRUE(stream.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); +
ASSERT_EQ(stream.GetResults()[0][0].ValueInt(), 1); + Pull(&stream, 2); + ASSERT_EQ(stream.GetSummary().count("has_more"), 1); + ASSERT_TRUE(stream.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream.GetResults().size(), 3U); + ASSERT_EQ(stream.GetResults()[1][0].ValueInt(), 2); + ASSERT_EQ(stream.GetResults()[2][0].ValueInt(), 3); + Pull(&stream); + ASSERT_EQ(stream.GetSummary().count("has_more"), 1); + ASSERT_FALSE(stream.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream.GetResults().size(), 5U); + ASSERT_EQ(stream.GetResults()[3][0].ValueInt(), 4); + ASSERT_EQ(stream.GetResults()[4][0].ValueInt(), 5); + } +} + +// Run query with different ast twice to see if query executes correctly when +// ast is read from cache. +TEST_F(InterpreterTest, AstCache) { + { + auto stream = Interpret("RETURN 2 + 3"); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "2 + 3"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueInt(), 5); + } + { + // Cached ast, different literals. + auto stream = Interpret("RETURN 5 + 4"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueInt(), 9); + } + { + // Different ast (because of different types). + auto stream = Interpret("RETURN 5.5 + 4"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueDouble(), 9.5); + } + { + // Cached ast, same literals. + auto stream = Interpret("RETURN 2 + 3"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueInt(), 5); + } + { + // Cached ast, different literals. 
+ auto stream = Interpret("RETURN 10.5 + 1"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueDouble(), 11.5); + } + { + // Cached ast, same literals, different whitespaces. + auto stream = Interpret("RETURN 10.5 + 1"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueDouble(), 11.5); + } + { + // Cached ast, same literals, different named header. + auto stream = Interpret("RETURN 10.5+1"); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "10.5+1"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueDouble(), 11.5); + } +} + +// Run query with same ast multiple times with different parameters. +TEST_F(InterpreterTest, Parameters) { + { + auto stream = Interpret("RETURN $2 + $`a b`", {{"2", memgraph::storage::v3::PropertyValue(10)}, + {"a b", memgraph::storage::v3::PropertyValue(15)}}); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "$2 + $`a b`"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueInt(), 25); + } + { + // Not needed parameter. + auto stream = Interpret("RETURN $2 + $`a b`", {{"2", memgraph::storage::v3::PropertyValue(10)}, + {"a b", memgraph::storage::v3::PropertyValue(15)}, + {"c", memgraph::storage::v3::PropertyValue(10)}}); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "$2 + $`a b`"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueInt(), 25); + } + { + // Cached ast, different parameters. 
+ auto stream = Interpret("RETURN $2 + $`a b`", {{"2", memgraph::storage::v3::PropertyValue("da")}, + {"a b", memgraph::storage::v3::PropertyValue("ne")}}); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueString(), "dane"); + } + { + // Non-primitive literal. + auto stream = Interpret( + "RETURN $2", {{"2", memgraph::storage::v3::PropertyValue(std::vector<memgraph::storage::v3::PropertyValue>{ + memgraph::storage::v3::PropertyValue(5), memgraph::storage::v3::PropertyValue(2), + memgraph::storage::v3::PropertyValue(3)})}}); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + auto result = + memgraph::query::v2::test_common::ToIntList(memgraph::glue::v2::ToTypedValue(stream.GetResults()[0][0])); + ASSERT_THAT(result, testing::ElementsAre(5, 2, 3)); + } + { + // Cached ast, unprovided parameter. + ASSERT_THROW(Interpret("RETURN $2 + $`a b`", {{"2", memgraph::storage::v3::PropertyValue("da")}, + {"ab", memgraph::storage::v3::PropertyValue("ne")}}), + memgraph::query::v2::UnprovidedParameterError); + } +} + +// Run CREATE/MATCH/MERGE queries with property map +TEST_F(InterpreterTest, ParametersAsPropertyMap) { + { + EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :label(name STRING, age INTEGER)")); + std::map<std::string, memgraph::storage::v3::PropertyValue> property_map{}; + property_map["name"] = memgraph::storage::v3::PropertyValue("name1"); + property_map["age"] = memgraph::storage::v3::PropertyValue(25); + auto stream = + Interpret("CREATE (n:label $prop) RETURN n", { + {"prop", memgraph::storage::v3::PropertyValue(property_map)}, + }); + ASSERT_EQ(stream.GetHeader().size(), 1U); + ASSERT_EQ(stream.GetHeader()[0], "n"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + auto result = stream.GetResults()[0][0].ValueVertex(); + EXPECT_EQ(result.properties["name"].ValueString(), "name1"); + 
EXPECT_EQ(result.properties["age"].ValueInt(), 25); + } + { + EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :Person(name STRING, age INTEGER)")); + std::map<std::string, memgraph::storage::v3::PropertyValue> property_map{}; + property_map["name"] = memgraph::storage::v3::PropertyValue("name1"); + property_map["age"] = memgraph::storage::v3::PropertyValue(25); + EXPECT_NO_THROW(Interpret("CREATE (:Person {name: 'test', age: 30})")); + auto stream = Interpret("MATCH (m:Person) CREATE (n:Person $prop) RETURN n", + { + {"prop", memgraph::storage::v3::PropertyValue(property_map)}, + }); + ASSERT_EQ(stream.GetHeader().size(), 1U); + ASSERT_EQ(stream.GetHeader()[0], "n"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + auto result = stream.GetResults()[0][0].ValueVertex(); + EXPECT_EQ(result.properties["name"].ValueString(), "name1"); + EXPECT_EQ(result.properties["age"].ValueInt(), 25); + } + { + EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :L1(name STRING)")); + std::map<std::string, memgraph::storage::v3::PropertyValue> property_map{}; + property_map["name"] = memgraph::storage::v3::PropertyValue("name1"); + property_map["weight"] = memgraph::storage::v3::PropertyValue(121); + auto stream = Interpret("CREATE (:L1 {name: 'name1'})-[r:TO $prop]->(:L1 {name: 'name2'}) RETURN r", + { + {"prop", memgraph::storage::v3::PropertyValue(property_map)}, + }); + ASSERT_EQ(stream.GetHeader().size(), 1U); + ASSERT_EQ(stream.GetHeader()[0], "r"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + auto result = stream.GetResults()[0][0].ValueEdge(); + EXPECT_EQ(result.properties["name"].ValueString(), "name1"); + EXPECT_EQ(result.properties["weight"].ValueInt(), 121); + } + { + std::map<std::string, memgraph::storage::v3::PropertyValue> property_map{}; + property_map["name"] = memgraph::storage::v3::PropertyValue("name1"); + property_map["age"] = memgraph::storage::v3::PropertyValue(15); + 
ASSERT_THROW(Interpret("MATCH (n $prop) RETURN n", + { + {"prop", memgraph::storage::v3::PropertyValue(property_map)}, + }), + memgraph::query::v2::SemanticException); + } + { + EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :L2(name STRING, age INTEGER)")); + std::map<std::string, memgraph::storage::v3::PropertyValue> property_map{}; + property_map["name"] = memgraph::storage::v3::PropertyValue("name1"); + property_map["age"] = memgraph::storage::v3::PropertyValue(15); + ASSERT_THROW(Interpret("MERGE (n:L2 $prop) RETURN n", + { + {"prop", memgraph::storage::v3::PropertyValue(property_map)}, + }), + memgraph::query::v2::SemanticException); + } +} + +// Test bfs end to end. +TEST_F(InterpreterTest, Bfs) { + srand(0); + const auto kNumLevels = 10; + const auto kNumNodesPerLevel = 100; + const auto kNumEdgesPerNode = 100; + const auto kNumUnreachableNodes = 1000; + const auto kNumUnreachableEdges = 100000; + const auto kReachable = "reachable"; + const auto kId = "id"; + + std::vector<std::vector<memgraph::query::v2::VertexAccessor>> levels(kNumLevels); + int id = 0; + + // Set up. + { + auto storage_dba = db_.Access(); + memgraph::query::v2::DbAccessor dba(&storage_dba); + auto add_node = [&](int level, bool reachable) { + auto node = dba.InsertVertex(); + MG_ASSERT(node.SetProperty(dba.NameToProperty(kId), memgraph::storage::v3::PropertyValue(id++)).HasValue()); + MG_ASSERT( + node.SetProperty(dba.NameToProperty(kReachable), memgraph::storage::v3::PropertyValue(reachable)).HasValue()); + levels[level].push_back(node); + return node; + }; + + auto add_edge = [&](auto &v1, auto &v2, bool reachable) { + auto edge = dba.InsertEdge(&v1, &v2, dba.NameToEdgeType("edge")); + MG_ASSERT(edge->SetProperty(dba.NameToProperty(kReachable), memgraph::storage::v3::PropertyValue(reachable)) + .HasValue()); + }; + + // Add source node. + add_node(0, true); + + // Add reachable nodes. 
+ for (int i = 1; i < kNumLevels; ++i) { + for (int j = 0; j < kNumNodesPerLevel; ++j) { + auto node = add_node(i, true); + for (int k = 0; k < kNumEdgesPerNode; ++k) { + auto &node2 = levels[i - 1][rand() % levels[i - 1].size()]; + add_edge(node2, node, true); + } + } + } + + // Add unreachable nodes. + for (int i = 0; i < kNumUnreachableNodes; ++i) { + auto node = add_node(rand() % kNumLevels, // Not really important. + false); + for (int j = 0; j < kNumEdgesPerNode; ++j) { + auto &level = levels[rand() % kNumLevels]; + auto &node2 = level[rand() % level.size()]; + add_edge(node2, node, true); + add_edge(node, node2, true); + } + } + + // Add unreachable edges. + for (int i = 0; i < kNumUnreachableEdges; ++i) { + auto &level1 = levels[rand() % kNumLevels]; + auto &node1 = level1[rand() % level1.size()]; + auto &level2 = levels[rand() % kNumLevels]; + auto &node2 = level2[rand() % level2.size()]; + add_edge(node1, node2, false); + } + + ASSERT_FALSE(dba.Commit().HasError()); + } + + auto stream = Interpret( + "MATCH (n {id: 0})-[r *bfs..5 (e, n | n.reachable and " + "e.reachable)]->(m) RETURN n, r, m"); + + ASSERT_EQ(stream.GetHeader().size(), 3U); + EXPECT_EQ(stream.GetHeader()[0], "n"); + EXPECT_EQ(stream.GetHeader()[1], "r"); + EXPECT_EQ(stream.GetHeader()[2], "m"); + ASSERT_EQ(stream.GetResults().size(), 5 * kNumNodesPerLevel); + + auto dba = db_.Access(); + int expected_level = 1; + int remaining_nodes_in_level = kNumNodesPerLevel; + std::unordered_set<int64_t> matched_ids; + + for (const auto &result : stream.GetResults()) { + const auto &begin = result[0].ValueVertex(); + const auto &edges = ToEdgeList(result[1]); + const auto &end = result[2].ValueVertex(); + + // Check that path is of expected length. Returned paths should be from + // shorter to longer ones. + EXPECT_EQ(edges.size(), expected_level); + // Check that starting node is correct. 
+ EXPECT_EQ(edges.front().from, begin.id); + EXPECT_EQ(begin.properties.at(kId).ValueInt(), 0); + for (int i = 1; i < static_cast<int>(edges.size()); ++i) { + // Check that edges form a connected path. + EXPECT_EQ(edges[i - 1].to.AsInt(), edges[i].from.AsInt()); + } + auto matched_id = end.properties.at(kId).ValueInt(); + EXPECT_EQ(edges.back().to, end.id); + // Check that we didn't match that node already. + EXPECT_TRUE(matched_ids.insert(matched_id).second); + // Check that shortest path was found. + EXPECT_TRUE(matched_id > kNumNodesPerLevel * (expected_level - 1) && + matched_id <= kNumNodesPerLevel * expected_level); + if (!--remaining_nodes_in_level) { + remaining_nodes_in_level = kNumNodesPerLevel; + ++expected_level; + } + } +} + +// Test shortest path end to end. +TEST_F(InterpreterTest, ShortestPath) { + EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :A(x INTEGER)")); + EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :B(x INTEGER)")); + EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :C(x INTEGER)")); + const auto test_shortest_path = [this](const bool use_duration) { + const auto get_weight = [use_duration](const auto value) { + return fmt::format(fmt::runtime(use_duration ? 
"DURATION('PT{}S')" : "{}"), value); + }; + + Interpret( + fmt::format("CREATE (n:A {{x: 1}}), (m:B {{x: 2}}), (l:C {{x: 1}}), (n)-[:r1 {{w: {} " + "}}]->(m)-[:r2 {{w: {}}}]->(l), (n)-[:r3 {{w: {}}}]->(l)", + get_weight(1), get_weight(2), get_weight(4))); + + auto stream = Interpret("MATCH (n)-[e *wshortest 5 (e, n | e.w) ]->(m) return e"); + + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "e"); + ASSERT_EQ(stream.GetResults().size(), 3U); + + auto dba = db_.Access(); + std::vector<std::vector<std::string>> expected_results{{"r1"}, {"r2"}, {"r1", "r2"}}; + + for (const auto &result : stream.GetResults()) { + const auto &edges = ToEdgeList(result[0]); + + std::vector<std::string> datum; + datum.reserve(edges.size()); + + for (const auto &edge : edges) { + datum.push_back(edge.type); + } + + bool any_match = false; + for (const auto &expected : expected_results) { + if (expected == datum) { + any_match = true; + break; + } + } + + EXPECT_TRUE(any_match); + } + + Interpret("MATCH (n) DETACH DELETE n"); + }; + + static constexpr bool kUseNumeric{false}; + static constexpr bool kUseDuration{true}; + { + SCOPED_TRACE("Test with numeric values"); + test_shortest_path(kUseNumeric); + } + { + SCOPED_TRACE("Test with Duration values"); + test_shortest_path(kUseDuration); + } +} + +TEST_F(InterpreterTest, CreateLabelIndexInMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("CREATE INDEX ON :X"), memgraph::query::v2::IndexInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, CreateLabelPropertyIndexInMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("CREATE INDEX ON :X(y)"), memgraph::query::v2::IndexInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, CreateExistenceConstraintInMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("CREATE CONSTRAINT ON (n:A) ASSERT EXISTS (n.a)"), + 
memgraph::query::v2::ConstraintInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, CreateUniqueConstraintInMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("CREATE CONSTRAINT ON (n:A) ASSERT n.a, n.b IS UNIQUE"), + memgraph::query::v2::ConstraintInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, ShowIndexInfoInMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("SHOW INDEX INFO"), memgraph::query::v2::InfoInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, ShowConstraintInfoInMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("SHOW CONSTRAINT INFO"), memgraph::query::v2::InfoInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, ShowStorageInfoInMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("SHOW STORAGE INFO"), memgraph::query::v2::InfoInMulticommandTxException); + Interpret("ROLLBACK"); +} + +// // NOLINTNEXTLINE(hicpp-special-member-functions) +TEST_F(InterpreterTest, ExistenceConstraintTest) { + ASSERT_NO_THROW(Interpret("CREATE SCHEMA ON :A(a INTEGER);")); + + Interpret("CREATE CONSTRAINT ON (n:A) ASSERT EXISTS (n.b);"); + Interpret("CREATE (:A{a: 3, b:1})"); + Interpret("CREATE (:A{a: 3, b:2})"); + ASSERT_THROW(Interpret("CREATE (:A {a: 12})"), memgraph::query::v2::QueryException); + Interpret("MATCH (n:A{a:3, b: 2}) SET n.b=5"); + Interpret("CREATE (:A{a:2, b: 3})"); + Interpret("MATCH (n:A{a:3, b: 1}) DETACH DELETE n"); + Interpret("CREATE (n:A{a:2, b: 3})"); + ASSERT_THROW(Interpret("CREATE CONSTRAINT ON (n:A) ASSERT EXISTS (n.c);"), + memgraph::query::v2::QueryRuntimeException); +} + +TEST_F(InterpreterTest, UniqueConstraintTest) { + ASSERT_NO_THROW(Interpret("CREATE SCHEMA ON :A(a INTEGER);")); + + // Empty property list should result with syntax exception. 
+ ASSERT_THROW(Interpret("CREATE CONSTRAINT ON (n:A) ASSERT IS UNIQUE;"), memgraph::query::v2::SyntaxException); + ASSERT_THROW(Interpret("DROP CONSTRAINT ON (n:A) ASSERT IS UNIQUE;"), memgraph::query::v2::SyntaxException); + + // Too large list of properties should also result with syntax exception. + { + std::stringstream stream; + stream << " ON (n:A) ASSERT "; + for (size_t i = 0; i < 33; ++i) { + if (i > 0) stream << ", "; + stream << "n.prop" << i; + } + stream << " IS UNIQUE;"; + std::string create_query = "CREATE CONSTRAINT" + stream.str(); + std::string drop_query = "DROP CONSTRAINT" + stream.str(); + ASSERT_THROW(Interpret(create_query), memgraph::query::v2::SyntaxException); + ASSERT_THROW(Interpret(drop_query), memgraph::query::v2::SyntaxException); + } + + // Providing property list with duplicates results with syntax exception. + ASSERT_THROW(Interpret("CREATE CONSTRAINT ON (n:A) ASSERT n.a, n.b, n.a IS UNIQUE;"), + memgraph::query::v2::SyntaxException); + ASSERT_THROW(Interpret("DROP CONSTRAINT ON (n:A) ASSERT n.a, n.b, n.a IS UNIQUE;"), + memgraph::query::v2::SyntaxException); + + // Commit of vertex should fail if a constraint is violated. + Interpret("CREATE CONSTRAINT ON (n:A) ASSERT n.a, n.b IS UNIQUE;"); + Interpret("CREATE (:A{a:1, b:2})"); + Interpret("CREATE (:A{a:1, b:3})"); + ASSERT_THROW(Interpret("CREATE (:A{a:1, b:2})"), memgraph::query::v2::QueryException); + + // Attempt to create a constraint should fail if it's violated. + Interpret("CREATE (:A{a:1, c:2})"); + Interpret("CREATE (:A{a:1, c:2})"); + ASSERT_THROW(Interpret("CREATE CONSTRAINT ON (n:A) ASSERT n.a, n.c IS UNIQUE;"), + memgraph::query::v2::QueryRuntimeException); + + Interpret("MATCH (n:A{a:2, b:2}) SET n.a=1"); + Interpret("CREATE (:A{a:2})"); + Interpret("MATCH (n:A{a:2}) DETACH DELETE n"); + Interpret("CREATE (n:A{a:2})"); + + // Show constraint info. 
+ { + auto stream = Interpret("SHOW CONSTRAINT INFO"); + ASSERT_EQ(stream.GetHeader().size(), 3U); + const auto &header = stream.GetHeader(); + ASSERT_EQ(header[0], "constraint type"); + ASSERT_EQ(header[1], "label"); + ASSERT_EQ(header[2], "properties"); + ASSERT_EQ(stream.GetResults().size(), 1U); + const auto &result = stream.GetResults().front(); + ASSERT_EQ(result.size(), 3U); + ASSERT_EQ(result[0].ValueString(), "unique"); + ASSERT_EQ(result[1].ValueString(), "A"); + const auto &properties = result[2].ValueList(); + ASSERT_EQ(properties.size(), 2U); + ASSERT_EQ(properties[0].ValueString(), "a"); + ASSERT_EQ(properties[1].ValueString(), "b"); + } + + // Drop constraint. + Interpret("DROP CONSTRAINT ON (n:A) ASSERT n.a, n.b IS UNIQUE;"); + // Removing the same constraint twice should not throw any exception. + Interpret("DROP CONSTRAINT ON (n:A) ASSERT n.a, n.b IS UNIQUE;"); +} + +TEST_F(InterpreterTest, ExplainQuery) { + const auto &interpreter_context = default_interpreter.interpreter_context; + + EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); + auto stream = Interpret("EXPLAIN MATCH (n) RETURN *;"); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader().front(), "QUERY PLAN"); + std::vector<std::string> expected_rows{" * Produce {n}", " * ScanAll (n)", " * Once"}; + ASSERT_EQ(stream.GetResults().size(), expected_rows.size()); + auto expected_it = expected_rows.begin(); + for (const auto &row : stream.GetResults()) { + ASSERT_EQ(row.size(), 1U); + EXPECT_EQ(row.front().ValueString(), *expected_it); + ++expected_it; + } + // We should have a plan cache for MATCH ... + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + // We should have AST cache for EXPLAIN ... and for inner MATCH ... 
+ EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); + Interpret("MATCH (n) RETURN *;"); + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); +} + +TEST_F(InterpreterTest, ExplainQueryMultiplePulls) { + const auto &interpreter_context = default_interpreter.interpreter_context; + + EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); + auto [stream, qid] = Prepare("EXPLAIN MATCH (n) RETURN *;"); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader().front(), "QUERY PLAN"); + std::vector<std::string> expected_rows{" * Produce {n}", " * ScanAll (n)", " * Once"}; + Pull(&stream, 1); + ASSERT_EQ(stream.GetResults().size(), 1); + auto expected_it = expected_rows.begin(); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + EXPECT_EQ(stream.GetResults()[0].front().ValueString(), *expected_it); + ++expected_it; + + Pull(&stream, 1); + ASSERT_EQ(stream.GetResults().size(), 2); + ASSERT_EQ(stream.GetResults()[1].size(), 1U); + EXPECT_EQ(stream.GetResults()[1].front().ValueString(), *expected_it); + ++expected_it; + + Pull(&stream); + ASSERT_EQ(stream.GetResults().size(), 3); + ASSERT_EQ(stream.GetResults()[2].size(), 1U); + EXPECT_EQ(stream.GetResults()[2].front().ValueString(), *expected_it); + // We should have a plan cache for MATCH ... + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + // We should have AST cache for EXPLAIN ... and for inner MATCH ... 
+ EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); + Interpret("MATCH (n) RETURN *;"); + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); +} + +TEST_F(InterpreterTest, ExplainQueryInMulticommandTransaction) { + const auto &interpreter_context = default_interpreter.interpreter_context; + + EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); + Interpret("BEGIN"); + auto stream = Interpret("EXPLAIN MATCH (n) RETURN *;"); + Interpret("COMMIT"); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader().front(), "QUERY PLAN"); + std::vector<std::string> expected_rows{" * Produce {n}", " * ScanAll (n)", " * Once"}; + ASSERT_EQ(stream.GetResults().size(), expected_rows.size()); + auto expected_it = expected_rows.begin(); + for (const auto &row : stream.GetResults()) { + ASSERT_EQ(row.size(), 1U); + EXPECT_EQ(row.front().ValueString(), *expected_it); + ++expected_it; + } + // We should have a plan cache for MATCH ... + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + // We should have AST cache for EXPLAIN ... and for inner MATCH ... 
+ EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); + Interpret("MATCH (n) RETURN *;"); + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); +} + +TEST_F(InterpreterTest, ExplainQueryWithParams) { + const auto &interpreter_context = default_interpreter.interpreter_context; + + EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); + auto stream = + Interpret("EXPLAIN MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::v3::PropertyValue(42)}}); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader().front(), "QUERY PLAN"); + std::vector<std::string> expected_rows{" * Produce {n}", " * Filter", " * ScanAll (n)", " * Once"}; + ASSERT_EQ(stream.GetResults().size(), expected_rows.size()); + auto expected_it = expected_rows.begin(); + for (const auto &row : stream.GetResults()) { + ASSERT_EQ(row.size(), 1U); + EXPECT_EQ(row.front().ValueString(), *expected_it); + ++expected_it; + } + // We should have a plan cache for MATCH ... + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + // We should have AST cache for EXPLAIN ... and for inner MATCH ... 
+ EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); + Interpret("MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::v3::PropertyValue("something else")}}); + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); +} + +TEST_F(InterpreterTest, ProfileQuery) { + const auto &interpreter_context = default_interpreter.interpreter_context; + + EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); + auto stream = Interpret("PROFILE MATCH (n) RETURN *;"); + std::vector<std::string> expected_header{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}; + EXPECT_EQ(stream.GetHeader(), expected_header); + std::vector<std::string> expected_rows{"* Produce", "* ScanAll", "* Once"}; + ASSERT_EQ(stream.GetResults().size(), expected_rows.size()); + auto expected_it = expected_rows.begin(); + for (const auto &row : stream.GetResults()) { + ASSERT_EQ(row.size(), 4U); + EXPECT_EQ(row.front().ValueString(), *expected_it); + ++expected_it; + } + // We should have a plan cache for MATCH ... + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + // We should have AST cache for PROFILE ... and for inner MATCH ... 
+ EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); + Interpret("MATCH (n) RETURN *;"); + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); +} + +TEST_F(InterpreterTest, ProfileQueryMultiplePulls) { + const auto &interpreter_context = default_interpreter.interpreter_context; + + EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); + auto [stream, qid] = Prepare("PROFILE MATCH (n) RETURN *;"); + std::vector<std::string> expected_header{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}; + EXPECT_EQ(stream.GetHeader(), expected_header); + + std::vector<std::string> expected_rows{"* Produce", "* ScanAll", "* Once"}; + auto expected_it = expected_rows.begin(); + + Pull(&stream, 1); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 4U); + ASSERT_EQ(stream.GetResults()[0][0].ValueString(), *expected_it); + ++expected_it; + + Pull(&stream, 1); + ASSERT_EQ(stream.GetResults().size(), 2U); + ASSERT_EQ(stream.GetResults()[1].size(), 4U); + ASSERT_EQ(stream.GetResults()[1][0].ValueString(), *expected_it); + ++expected_it; + + Pull(&stream); + ASSERT_EQ(stream.GetResults().size(), 3U); + ASSERT_EQ(stream.GetResults()[2].size(), 4U); + ASSERT_EQ(stream.GetResults()[2][0].ValueString(), *expected_it); + + // We should have a plan cache for MATCH ... + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + // We should have AST cache for PROFILE ... and for inner MATCH ... 
+ EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); + Interpret("MATCH (n) RETURN *;"); + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); +} + +TEST_F(InterpreterTest, ProfileQueryInMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("PROFILE MATCH (n) RETURN *;"), memgraph::query::v2::ProfileInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, ProfileQueryWithParams) { + const auto &interpreter_context = default_interpreter.interpreter_context; + + EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); + auto stream = + Interpret("PROFILE MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::v3::PropertyValue(42)}}); + std::vector<std::string> expected_header{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}; + EXPECT_EQ(stream.GetHeader(), expected_header); + std::vector<std::string> expected_rows{"* Produce", "* Filter", "* ScanAll", "* Once"}; + ASSERT_EQ(stream.GetResults().size(), expected_rows.size()); + auto expected_it = expected_rows.begin(); + for (const auto &row : stream.GetResults()) { + ASSERT_EQ(row.size(), 4U); + EXPECT_EQ(row.front().ValueString(), *expected_it); + ++expected_it; + } + // We should have a plan cache for MATCH ... + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + // We should have AST cache for PROFILE ... and for inner MATCH ... 
+ EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); + Interpret("MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::v3::PropertyValue("something else")}}); + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); +} + +TEST_F(InterpreterTest, ProfileQueryWithLiterals) { + const auto &interpreter_context = default_interpreter.interpreter_context; + ASSERT_NO_THROW(Interpret("CREATE SCHEMA ON :Node(id INTEGER)")); + + EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 1U); + auto stream = Interpret("PROFILE UNWIND range(1, 1000) AS x CREATE (:Node {id: x});", {}); + std::vector<std::string> expected_header{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}; + EXPECT_EQ(stream.GetHeader(), expected_header); + std::vector<std::string> expected_rows{"* CreateNode", "* Unwind", "* Once"}; + ASSERT_EQ(stream.GetResults().size(), expected_rows.size()); + auto expected_it = expected_rows.begin(); + for (const auto &row : stream.GetResults()) { + ASSERT_EQ(row.size(), 4U); + EXPECT_EQ(row.front().ValueString(), *expected_it); + ++expected_it; + } + // We should have a plan cache for UNWIND ... + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + // We should have AST cache for PROFILE ... and for inner UNWIND ... 
+ EXPECT_EQ(interpreter_context.ast_cache.size(), 3U); + Interpret("UNWIND range(42, 4242) AS x CREATE (:Node {id: x});", {}); + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 3U); +} + +TEST_F(InterpreterTest, Transactions) { + auto &interpreter = default_interpreter.interpreter; + { + ASSERT_THROW(interpreter.CommitTransaction(), memgraph::query::v2::ExplicitTransactionUsageException); + ASSERT_THROW(interpreter.RollbackTransaction(), memgraph::query::v2::ExplicitTransactionUsageException); + interpreter.BeginTransaction(); + ASSERT_THROW(interpreter.BeginTransaction(), memgraph::query::v2::ExplicitTransactionUsageException); + auto [stream, qid] = Prepare("RETURN 2"); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "2"); + Pull(&stream, 1); + ASSERT_EQ(stream.GetSummary().count("has_more"), 1); + ASSERT_FALSE(stream.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueInt(), 2); + interpreter.CommitTransaction(); + } + { + interpreter.BeginTransaction(); + auto [stream, qid] = Prepare("RETURN 2"); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "2"); + Pull(&stream, 1); + ASSERT_EQ(stream.GetSummary().count("has_more"), 1); + ASSERT_FALSE(stream.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueInt(), 2); + interpreter.RollbackTransaction(); + } +} + +TEST_F(InterpreterTest, Qid) { + auto &interpreter = default_interpreter.interpreter; + { + interpreter.BeginTransaction(); + auto [stream, qid] = Prepare("RETURN 2"); + ASSERT_TRUE(qid); + ASSERT_THROW(Pull(&stream, {}, *qid + 1), memgraph::query::v2::InvalidArgumentsException); + interpreter.RollbackTransaction(); + } + { + interpreter.BeginTransaction(); + auto [stream1, qid1] = Prepare("UNWIND(range(1,3)) as n RETURN n"); + 
ASSERT_TRUE(qid1); + ASSERT_EQ(stream1.GetHeader().size(), 1U); + EXPECT_EQ(stream1.GetHeader()[0], "n"); + + auto [stream2, qid2] = Prepare("UNWIND(range(4,6)) as n RETURN n"); + ASSERT_TRUE(qid2); + ASSERT_EQ(stream2.GetHeader().size(), 1U); + EXPECT_EQ(stream2.GetHeader()[0], "n"); + + Pull(&stream1, 1, qid1); + ASSERT_EQ(stream1.GetSummary().count("has_more"), 1); + ASSERT_TRUE(stream1.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream1.GetResults().size(), 1U); + ASSERT_EQ(stream1.GetResults()[0].size(), 1U); + ASSERT_EQ(stream1.GetResults()[0][0].ValueInt(), 1); + + auto [stream3, qid3] = Prepare("UNWIND(range(7,9)) as n RETURN n"); + ASSERT_TRUE(qid3); + ASSERT_EQ(stream3.GetHeader().size(), 1U); + EXPECT_EQ(stream3.GetHeader()[0], "n"); + + Pull(&stream2, {}, qid2); + ASSERT_EQ(stream2.GetSummary().count("has_more"), 1); + ASSERT_FALSE(stream2.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream2.GetResults().size(), 3U); + ASSERT_EQ(stream2.GetResults()[0].size(), 1U); + ASSERT_EQ(stream2.GetResults()[0][0].ValueInt(), 4); + ASSERT_EQ(stream2.GetResults()[1][0].ValueInt(), 5); + ASSERT_EQ(stream2.GetResults()[2][0].ValueInt(), 6); + + Pull(&stream3, 1, qid3); + ASSERT_EQ(stream3.GetSummary().count("has_more"), 1); + ASSERT_TRUE(stream3.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream3.GetResults().size(), 1U); + ASSERT_EQ(stream3.GetResults()[0].size(), 1U); + ASSERT_EQ(stream3.GetResults()[0][0].ValueInt(), 7); + + Pull(&stream1, {}, qid1); + ASSERT_EQ(stream1.GetSummary().count("has_more"), 1); + ASSERT_FALSE(stream1.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream1.GetResults().size(), 3U); + ASSERT_EQ(stream1.GetResults()[1].size(), 1U); + ASSERT_EQ(stream1.GetResults()[1][0].ValueInt(), 2); + ASSERT_EQ(stream1.GetResults()[2][0].ValueInt(), 3); + + Pull(&stream3); + ASSERT_EQ(stream3.GetSummary().count("has_more"), 1); + ASSERT_FALSE(stream3.GetSummary().at("has_more").ValueBool()); + 
ASSERT_EQ(stream3.GetResults().size(), 3U); + ASSERT_EQ(stream3.GetResults()[1].size(), 1U); + ASSERT_EQ(stream3.GetResults()[1][0].ValueInt(), 8); + ASSERT_EQ(stream3.GetResults()[2][0].ValueInt(), 9); + + interpreter.CommitTransaction(); + } +} + +namespace { +// copied from utils_csv_parsing.cpp - tmp dir management and csv file writer +class TmpDirManager final { + public: + explicit TmpDirManager(const std::string_view directory) + : tmp_dir_{std::filesystem::temp_directory_path() / directory} { + CreateDir(); + } + ~TmpDirManager() { Clear(); } + + const std::filesystem::path &Path() const { return tmp_dir_; } + + private: + std::filesystem::path tmp_dir_; + + void CreateDir() { + if (!std::filesystem::exists(tmp_dir_)) { + std::filesystem::create_directory(tmp_dir_); + } + } + + void Clear() { + if (!std::filesystem::exists(tmp_dir_)) return; + std::filesystem::remove_all(tmp_dir_); + } +}; + +class FileWriter { + public: + explicit FileWriter(const std::filesystem::path path) { stream_.open(path); } + + FileWriter(const FileWriter &) = delete; + FileWriter &operator=(const FileWriter &) = delete; + + FileWriter(FileWriter &&) = delete; + FileWriter &operator=(FileWriter &&) = delete; + + void Close() { stream_.close(); } + + size_t WriteLine(const std::string_view line) { + if (!stream_.is_open()) { + return 0; + } + + stream_ << line << std::endl; + + // including the newline character + return line.size() + 1; + } + + private: + std::ofstream stream_; +}; + +std::string CreateRow(const std::vector<std::string> &columns, const std::string_view delim) { + return memgraph::utils::Join(columns, delim); +} +} // namespace + +TEST_F(InterpreterTest, LoadCsvClause) { + auto dir_manager = TmpDirManager("csv_directory"); + const auto csv_path = dir_manager.Path() / "file.csv"; + auto writer = FileWriter(csv_path); + + const std::string delimiter{"|"}; + + const std::vector<std::string> header{"A", "B", "C"}; + writer.WriteLine(CreateRow(header, delimiter)); + + 
const std::vector<std::string> good_columns_1{"a", "b", "c"}; + writer.WriteLine(CreateRow(good_columns_1, delimiter)); + + const std::vector<std::string> bad_columns{"\"\"1", "2", "3"}; + writer.WriteLine(CreateRow(bad_columns, delimiter)); + + const std::vector<std::string> good_columns_2{"d", "e", "f"}; + writer.WriteLine(CreateRow(good_columns_2, delimiter)); + + writer.Close(); + + { + const std::string query = fmt::format(R"(LOAD CSV FROM "{}" WITH HEADER IGNORE BAD DELIMITER "{}" AS x RETURN + x.A)", + csv_path.string(), delimiter); + auto [stream, qid] = Prepare(query); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "x.A"); + + Pull(&stream, 1); + ASSERT_EQ(stream.GetSummary().count("has_more"), 1); + ASSERT_TRUE(stream.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueString(), "a"); + + Pull(&stream, 1); + ASSERT_EQ(stream.GetSummary().count("has_more"), 1); + ASSERT_FALSE(stream.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream.GetResults().size(), 2U); + ASSERT_EQ(stream.GetResults()[1][0].ValueString(), "d"); + } + + { + const std::string query = fmt::format(R"(LOAD CSV FROM "{}" WITH HEADER IGNORE BAD DELIMITER "{}" AS x RETURN + x.C)", + csv_path.string(), delimiter); + auto [stream, qid] = Prepare(query); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "x.C"); + + Pull(&stream); + ASSERT_EQ(stream.GetSummary().count("has_more"), 1); + ASSERT_FALSE(stream.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream.GetResults().size(), 2U); + ASSERT_EQ(stream.GetResults()[0][0].ValueString(), "c"); + ASSERT_EQ(stream.GetResults()[1][0].ValueString(), "f"); + } +} + +TEST_F(InterpreterTest, CacheableQueries) { + const auto &interpreter_context = default_interpreter.interpreter_context; + // This should be cached + { + SCOPED_TRACE("Cacheable query"); + Interpret("RETURN 1"); + 
EXPECT_EQ(interpreter_context.ast_cache.size(), 1U); + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + } + + { + SCOPED_TRACE("Uncacheable query"); + // Queries which are calling procedure should not be cached because the + // result signature could be changed + Interpret("CALL mg.load_all()"); + EXPECT_EQ(interpreter_context.ast_cache.size(), 1U); + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + } +} + +TEST_F(InterpreterTest, AllowLoadCsvConfig) { + const auto check_load_csv_queries = [&](const bool allow_load_csv) { + TmpDirManager directory_manager{"allow_load_csv"}; + const auto csv_path = directory_manager.Path() / "file.csv"; + auto writer = FileWriter(csv_path); + const std::vector<std::string> data{"A", "B", "C"}; + writer.WriteLine(CreateRow(data, ",")); + writer.Close(); + + const std::array<std::string, 2> queries = { + fmt::format("LOAD CSV FROM \"{}\" WITH HEADER AS row RETURN row", csv_path.string()), + "CREATE TRIGGER trigger ON CREATE BEFORE COMMIT EXECUTE LOAD CSV FROM 'file.csv' WITH HEADER AS row RETURN " + "row"}; + + InterpreterFaker interpreter_faker{&db_, {.query = {.allow_load_csv = allow_load_csv}}, directory_manager.Path()}; + for (const auto &query : queries) { + if (allow_load_csv) { + SCOPED_TRACE(fmt::format("'{}' should not throw because LOAD CSV is allowed", query)); + ASSERT_NO_THROW(interpreter_faker.Interpret(query)); + } else { + SCOPED_TRACE(fmt::format("'{}' should throw becuase LOAD CSV is not allowed", query)); + ASSERT_THROW(interpreter_faker.Interpret(query), memgraph::utils::BasicException); + } + SCOPED_TRACE(fmt::format("Normal query should not throw (allow_load_csv: {})", allow_load_csv)); + ASSERT_NO_THROW(interpreter_faker.Interpret("RETURN 1")); + } + }; + + check_load_csv_queries(true); + check_load_csv_queries(false); +} + +void AssertAllValuesAreZero(const std::map<std::string, memgraph::communication::bolt::Value> &map, + const std::vector<std::string> &exceptions) { + for (const auto &[key, 
value] : map) { + if (const auto it = std::find(exceptions.begin(), exceptions.end(), key); it != exceptions.end()) continue; + ASSERT_EQ(value.ValueInt(), 0) << "Value " << key << " actual: " << value.ValueInt() << ", expected 0"; + } +} + +TEST_F(InterpreterTest, ExecutionStatsIsValid) { + { + auto [stream, qid] = Prepare("MATCH (n) DELETE n;"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("stats"), 0); + } + { + EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :L1(name STRING)")); + std::array stats_keys{"nodes-created", "nodes-deleted", "relationships-created", "relationships-deleted", + "properties-set", "labels-added", "labels-removed"}; + auto [stream, qid] = Prepare("CREATE (:L1 {name: 'name1'});"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("stats"), 1); + ASSERT_TRUE(stream.GetSummary().at("stats").IsMap()); + auto stats = stream.GetSummary().at("stats").ValueMap(); + ASSERT_TRUE( + std::all_of(stats_keys.begin(), stats_keys.end(), [&stats](const auto &key) { return stats.contains(key); })); + AssertAllValuesAreZero(stats, {"nodes-created"}); + } +} + +TEST_F(InterpreterTest, ExecutionStatsValues) { + EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :L1(name STRING)")); + { + auto [stream, qid] = + Prepare("CREATE (:L1{name: 'name1'}),(:L1{name: 'name2'}),(:L1{name: 'name3'}),(:L1{name: 'name4'});"); + + Pull(&stream); + auto stats = stream.GetSummary().at("stats").ValueMap(); + ASSERT_EQ(stats["nodes-created"].ValueInt(), 4); + AssertAllValuesAreZero(stats, {"nodes-created", "labels-added"}); + } + { + auto [stream, qid] = Prepare("MATCH (n) DELETE n;"); + Pull(&stream); + + auto stats = stream.GetSummary().at("stats").ValueMap(); + ASSERT_EQ(stats["nodes-deleted"].ValueInt(), 4); + AssertAllValuesAreZero(stats, {"nodes-deleted"}); + } + { + auto [stream, qid] = + Prepare("CREATE (n:L1 {name: 'name5'})-[:TO]->(m:L1{name: 'name6'}), (n)-[:TO]->(m), (n)-[:TO]->(m);"); + + Pull(&stream); + + auto stats = 
stream.GetSummary().at("stats").ValueMap(); + ASSERT_EQ(stats["nodes-created"].ValueInt(), 2); + ASSERT_EQ(stats["relationships-created"].ValueInt(), 3); + AssertAllValuesAreZero(stats, {"nodes-created", "relationships-created"}); + } + { + auto [stream, qid] = Prepare("MATCH (n) DETACH DELETE n;"); + Pull(&stream); + + auto stats = stream.GetSummary().at("stats").ValueMap(); + ASSERT_EQ(stats["nodes-deleted"].ValueInt(), 2); + ASSERT_EQ(stats["relationships-deleted"].ValueInt(), 3); + AssertAllValuesAreZero(stats, {"nodes-deleted", "relationships-deleted"}); + } + { + auto [stream, qid] = Prepare("CREATE (n:L1 {name: 'name7'}) SET n:L2:L3:L4"); + Pull(&stream); + + auto stats = stream.GetSummary().at("stats").ValueMap(); + ASSERT_EQ(stats["nodes-created"].ValueInt(), 1); + ASSERT_EQ(stats["labels-added"].ValueInt(), 3); + AssertAllValuesAreZero(stats, {"nodes-created", "labels-added"}); + } + { + auto [stream, qid] = Prepare("MATCH (n:L1) SET n.name2='test';"); + Pull(&stream); + + auto stats = stream.GetSummary().at("stats").ValueMap(); + ASSERT_EQ(stats["properties-set"].ValueInt(), 1); + AssertAllValuesAreZero(stats, {"properties-set"}); + } +} + +TEST_F(InterpreterTest, NotificationsValidStructure) { + { + auto [stream, qid] = Prepare("MATCH (n) DELETE n;"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 0); + } + { + auto [stream, qid] = Prepare("CREATE INDEX ON :Person(id);"); + Pull(&stream); + + // Assert notifications list + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + ASSERT_TRUE(stream.GetSummary().at("notifications").IsList()); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + // Assert one notification structure + ASSERT_EQ(notifications.size(), 1); + ASSERT_TRUE(notifications[0].IsMap()); + auto notification = notifications[0].ValueMap(); + ASSERT_TRUE(notification.contains("severity")); + ASSERT_TRUE(notification.contains("code")); + ASSERT_TRUE(notification.contains("title")); 
+ ASSERT_TRUE(notification.contains("description")); + ASSERT_TRUE(notification["severity"].IsString()); + ASSERT_TRUE(notification["code"].IsString()); + ASSERT_TRUE(notification["title"].IsString()); + ASSERT_TRUE(notification["description"].IsString()); + } +} + +TEST_F(InterpreterTest, IndexInfoNotifications) { + { + auto [stream, qid] = Prepare("CREATE INDEX ON :Person;"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "CreateIndex"); + ASSERT_EQ(notification["title"].ValueString(), "Created index on label Person on properties ."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("CREATE INDEX ON :Person(id);"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "CreateIndex"); + ASSERT_EQ(notification["title"].ValueString(), "Created index on label Person on properties id."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("CREATE INDEX ON :Person(id);"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "IndexAlreadyExists"); + ASSERT_EQ(notification["title"].ValueString(), "Index on label Person on properties id already exists."); + 
ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("DROP INDEX ON :Person(id);"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "DropIndex"); + ASSERT_EQ(notification["title"].ValueString(), "Dropped index on label Person on properties id."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("DROP INDEX ON :Person(id);"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "IndexDoesNotExist"); + ASSERT_EQ(notification["title"].ValueString(), "Index on label Person on properties id doesn't exist."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } +} + +TEST_F(InterpreterTest, ConstraintUniqueInfoNotifications) { + { + auto [stream, qid] = Prepare("CREATE CONSTRAINT ON (n:Person) ASSERT n.email, n.id IS UNIQUE;"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "CreateConstraint"); + ASSERT_EQ(notification["title"].ValueString(), + "Created UNIQUE constraint on label Person on properties email, id."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("CREATE CONSTRAINT ON (n:Person) ASSERT n.email, n.id IS UNIQUE;"); + 
Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "ConstraintAlreadyExists"); + ASSERT_EQ(notification["title"].ValueString(), + "Constraint UNIQUE on label Person on properties email, id already exists."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("DROP CONSTRAINT ON (n:Person) ASSERT n.email, n.id IS UNIQUE;"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "DropConstraint"); + ASSERT_EQ(notification["title"].ValueString(), + "Dropped UNIQUE constraint on label Person on properties email, id."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("DROP CONSTRAINT ON (n:Person) ASSERT n.email, n.id IS UNIQUE;"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "ConstraintDoesNotExist"); + ASSERT_EQ(notification["title"].ValueString(), + "Constraint UNIQUE on label Person on properties email, id doesn't exist."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } +} + +TEST_F(InterpreterTest, ConstraintExistsInfoNotifications) { + { + auto [stream, qid] = Prepare("CREATE CONSTRAINT ON (n:L1) ASSERT EXISTS (n.name);"); + Pull(&stream); + + 
ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "CreateConstraint"); + ASSERT_EQ(notification["title"].ValueString(), "Created EXISTS constraint on label L1 on properties name."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("CREATE CONSTRAINT ON (n:L1) ASSERT EXISTS (n.name);"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "ConstraintAlreadyExists"); + ASSERT_EQ(notification["title"].ValueString(), "Constraint EXISTS on label L1 on properties name already exists."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("DROP CONSTRAINT ON (n:L1) ASSERT EXISTS (n.name);"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "DropConstraint"); + ASSERT_EQ(notification["title"].ValueString(), "Dropped EXISTS constraint on label L1 on properties name."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("DROP CONSTRAINT ON (n:L1) ASSERT EXISTS (n.name);"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = 
notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "ConstraintDoesNotExist"); + ASSERT_EQ(notification["title"].ValueString(), "Constraint EXISTS on label L1 on properties name doesn't exist."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } +} + +TEST_F(InterpreterTest, TriggerInfoNotifications) { + { + auto [stream, qid] = Prepare( + "CREATE TRIGGER bestTriggerEver ON CREATE AFTER COMMIT EXECUTE " + "CREATE ();"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "CreateTrigger"); + ASSERT_EQ(notification["title"].ValueString(), "Created trigger bestTriggerEver."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("DROP TRIGGER bestTriggerEver;"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "DropTrigger"); + ASSERT_EQ(notification["title"].ValueString(), "Dropped trigger bestTriggerEver."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } +} + +TEST_F(InterpreterTest, LoadCsvClauseNotification) { + auto dir_manager = TmpDirManager("csv_directory"); + const auto csv_path = dir_manager.Path() / "file.csv"; + auto writer = FileWriter(csv_path); + + const std::string delimiter{"|"}; + + const std::vector<std::string> header{"A", "B", "C"}; + writer.WriteLine(CreateRow(header, delimiter)); + + const std::vector<std::string> good_columns_1{"a", "b", "c"}; + 
writer.WriteLine(CreateRow(good_columns_1, delimiter)); + + writer.Close(); + + const std::string query = fmt::format(R"(LOAD CSV FROM "{}" WITH HEADER IGNORE BAD DELIMITER "{}" AS x RETURN x;)", + csv_path.string(), delimiter); + auto [stream, qid] = Prepare(query); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "LoadCSVTip"); + ASSERT_EQ(notification["title"].ValueString(), + "It's important to note that the parser parses the values as strings. It's up to the user to " + "convert the parsed row values to the appropriate type. This can be done using the built-in " + "conversion functions such as ToInteger, ToFloat, ToBoolean etc."); + ASSERT_EQ(notification["description"].ValueString(), ""); +} + +TEST_F(InterpreterTest, CreateSchemaMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING, age INTEGER)"), + memgraph::query::v2::ConstraintInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, ShowSchemasMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("SHOW SCHEMAS"), memgraph::query::v2::ConstraintInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, ShowSchemaMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("SHOW SCHEMA ON :label"), memgraph::query::v2::ConstraintInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, DropSchemaMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("DROP SCHEMA ON :label"), memgraph::query::v2::ConstraintInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, SchemaTestCreateAndShow) { + // Empty schema type map should result with 
syntax exception. + ASSERT_THROW(Interpret("CREATE SCHEMA ON :label();"), memgraph::query::v2::SyntaxException); + + // Duplicate properties should also cause an exception + ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING, name STRING);"), memgraph::query::v2::SemanticException); + ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING, name INTEGER);"), + memgraph::query::v2::SemanticException); + + { + // Cannot create same schema twice + Interpret("CREATE SCHEMA ON :label(name STRING, age INTEGER)"); + ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING);"), memgraph::query::v2::QueryException); + } + // Show schema + { + auto stream = Interpret("SHOW SCHEMA ON :label"); + ASSERT_EQ(stream.GetHeader().size(), 2U); + const auto &header = stream.GetHeader(); + ASSERT_EQ(header[0], "property_name"); + ASSERT_EQ(header[1], "property_type"); + ASSERT_EQ(stream.GetResults().size(), 2U); + std::unordered_map<std::string, std::string> result_table{{"age", "Integer"}, {"name", "String"}}; + + const auto &result = stream.GetResults().front(); + ASSERT_EQ(result.size(), 2U); + const auto key1 = result[0].ValueString(); + ASSERT_TRUE(result_table.contains(key1)); + ASSERT_EQ(result[1].ValueString(), result_table[key1]); + + const auto &result2 = stream.GetResults().back(); + ASSERT_EQ(result2.size(), 2U); + const auto key2 = result2[0].ValueString(); + ASSERT_TRUE(result_table.contains(key2)); + ASSERT_EQ(result2[1].ValueString(), result_table[key2]); + } + // Create Another Schema + Interpret("CREATE SCHEMA ON :label2(place STRING, dur DURATION)"); + + // Show schemas + { + auto stream = Interpret("SHOW SCHEMAS"); + ASSERT_EQ(stream.GetHeader().size(), 2U); + const auto &header = stream.GetHeader(); + ASSERT_EQ(header[0], "label"); + ASSERT_EQ(header[1], "primary_key"); + ASSERT_EQ(stream.GetResults().size(), 2U); + std::unordered_map<std::string, std::unordered_set<std::string>> result_table{ + {"label", {"name::String", "age::Integer"}},
{"label2", {"place::String", "dur::Duration"}}}; + + const auto &result = stream.GetResults().front(); + ASSERT_EQ(result.size(), 2U); + const auto key1 = result[0].ValueString(); + ASSERT_TRUE(result_table.contains(key1)); + const auto primary_key_split = StringToUnorderedSet(result[1].ValueString()); + ASSERT_EQ(primary_key_split.size(), 2); + ASSERT_TRUE(primary_key_split == result_table[key1]) << "actual value is: " << result[1].ValueString(); + + const auto &result2 = stream.GetResults().back(); + ASSERT_EQ(result2.size(), 2U); + const auto key2 = result2[0].ValueString(); + ASSERT_TRUE(result_table.contains(key2)); + const auto primary_key_split2 = StringToUnorderedSet(result2[1].ValueString()); + ASSERT_EQ(primary_key_split2.size(), 2); + ASSERT_TRUE(primary_key_split2 == result_table[key2]) << "Real value is: " << result2[1].ValueString(); + } +} + +TEST_F(InterpreterTest, SchemaTestCreateDropAndShow) { + Interpret("CREATE SCHEMA ON :label(name STRING, age INTEGER)"); + // Wrong syntax for dropping schema. + ASSERT_THROW(Interpret("DROP SCHEMA ON :label();"), memgraph::query::v2::SyntaxException); + // Cannot drop non-existent schema.
+ ASSERT_THROW(Interpret("DROP SCHEMA ON :label1;"), memgraph::query::v2::QueryException); + + // Create Schema and Drop + auto get_number_of_schemas = [this]() { + auto stream = Interpret("SHOW SCHEMAS"); + return stream.GetResults().size(); + }; + + ASSERT_EQ(get_number_of_schemas(), 1); + Interpret("CREATE SCHEMA ON :label1(name STRING, age INTEGER)"); + ASSERT_EQ(get_number_of_schemas(), 2); + Interpret("CREATE SCHEMA ON :label2(name STRING, alive BOOL)"); + ASSERT_EQ(get_number_of_schemas(), 3); + Interpret("DROP SCHEMA ON :label1"); + ASSERT_EQ(get_number_of_schemas(), 2); + Interpret("CREATE SCHEMA ON :label3(name STRING, birthday LOCALDATETIME)"); + ASSERT_EQ(get_number_of_schemas(), 3); + Interpret("DROP SCHEMA ON :label2"); + ASSERT_EQ(get_number_of_schemas(), 2); + Interpret("CREATE SCHEMA ON :label4(name STRING, age DURATION)"); + ASSERT_EQ(get_number_of_schemas(), 3); + Interpret("DROP SCHEMA ON :label3"); + ASSERT_EQ(get_number_of_schemas(), 2); + Interpret("DROP SCHEMA ON :label"); + ASSERT_EQ(get_number_of_schemas(), 1); + + // Show schemas + auto stream = Interpret("SHOW SCHEMAS"); + ASSERT_EQ(stream.GetHeader().size(), 2U); + const auto &header = stream.GetHeader(); + ASSERT_EQ(header[0], "label"); + ASSERT_EQ(header[1], "primary_key"); + ASSERT_EQ(stream.GetResults().size(), 1U); + std::unordered_map<std::string, std::unordered_set<std::string>> result_table{ + {"label4", {"name::String", "age::Duration"}}}; + + const auto &result = stream.GetResults().front(); + ASSERT_EQ(result.size(), 2U); + const auto key1 = result[0].ValueString(); + ASSERT_TRUE(result_table.contains(key1)); + const auto primary_key_split = StringToUnorderedSet(result[1].ValueString()); + ASSERT_EQ(primary_key_split.size(), 2); + ASSERT_TRUE(primary_key_split == result_table[key1]); +} diff --git a/tests/unit/query_v2_query_common.hpp b/tests/unit/query_v2_query_common.hpp new file mode 100644 index 000000000..5471ff42a --- /dev/null +++ 
b/tests/unit/query_v2_query_common.hpp @@ -0,0 +1,594 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +/// @file +/// This file provides macros for easier construction of openCypher query AST. +/// The usage of macros is very similar to how one would write openCypher. For +/// example: +/// +/// AstStorage storage; // Macros rely on storage being in scope. +/// // PROPERTY_LOOKUP and PROPERTY_PAIR macros +/// // rely on a DbAccessor *reference* named dba. +/// database::GraphDb db; +/// auto dba_ptr = db.Access(); +/// auto &dba = *dba_ptr; +/// +/// QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), EDGE("e"), NODE("m"))), +/// WHERE(LESS(PROPERTY_LOOKUP("e", edge_prop), LITERAL(3))), +/// RETURN(SUM(PROPERTY_LOOKUP("m", prop)), AS("sum"), +/// ORDER_BY(IDENT("sum")), +/// SKIP(ADD(LITERAL(1), LITERAL(2)))))); +/// +/// Each of the macros is accompanied by a function. The functions use overload +/// resolution and template magic to provide a type safe way of constructing +/// queries. Although the functions can be used by themselves, it is more +/// convenient to use the macros.
+ +#pragma once + +#include <map> +#include <sstream> +#include <string> +#include <unordered_map> +#include <utility> +#include <vector> + +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/frontend/ast/pretty_print.hpp" +#include "storage/v3/id_types.hpp" +#include "utils/string.hpp" + +namespace memgraph::query::v2 { + +namespace test_common { + +auto ToIntList(const TypedValue &t) { + std::vector<int64_t> list; + for (const auto &x : t.ValueList()) { + list.push_back(x.ValueInt()); + } + return list; +} + +auto ToIntMap(const TypedValue &t) { + std::map<std::string, int64_t> map; + for (const auto &kv : t.ValueMap()) map.emplace(kv.first, kv.second.ValueInt()); + return map; +} + +std::string ToString(Expression *expr) { + std::ostringstream ss; + PrintExpression(expr, &ss); + return ss.str(); +} + +std::string ToString(NamedExpression *expr) { + std::ostringstream ss; + PrintExpression(expr, &ss); + return ss.str(); +} + +// Custom types for ORDER BY, SKIP, LIMIT, ON MATCH and ON CREATE expressions, +// so that they can be used to resolve function calls. +struct OrderBy { + std::vector<SortItem> expressions; +}; +struct Skip { + Expression *expression = nullptr; +}; +struct Limit { + Expression *expression = nullptr; +}; +struct OnMatch { + std::vector<Clause *> set; +}; +struct OnCreate { + std::vector<Clause *> set; +}; + +// Helper functions for filling the OrderBy with expressions. +auto FillOrderBy(OrderBy &order_by, Expression *expression, Ordering ordering = Ordering::ASC) { + order_by.expressions.push_back({ordering, expression}); +} +template <class... T> +auto FillOrderBy(OrderBy &order_by, Expression *expression, Ordering ordering, T... rest) { + FillOrderBy(order_by, expression, ordering); + FillOrderBy(order_by, rest...); +} +template <class... T> +auto FillOrderBy(OrderBy &order_by, Expression *expression, T... rest) { + FillOrderBy(order_by, expression); + FillOrderBy(order_by, rest...); +} + +/// Create OrderBy expressions.
+/// +/// The supported combination of arguments is: (Expression, [Ordering])+ +/// Since the Ordering is optional, by default it is ascending. +template <class... T> +auto GetOrderBy(T... exprs) { + OrderBy order_by; + FillOrderBy(order_by, exprs...); + return order_by; +} + +/// Create PropertyLookup with given name and property. +/// +/// Name is used to create the Identifier which is used for property lookup. +template <class TDbAccessor> +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &dba, const std::string &name, + memgraph::storage::v3::PropertyId property) { + return storage.Create<PropertyLookup>(storage.Create<Identifier>(name), + storage.GetPropertyIx(dba.PropertyToName(property))); +} + +template <class TDbAccessor> +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &dba, Expression *expr, + memgraph::storage::v3::PropertyId property) { + return storage.Create<PropertyLookup>(expr, storage.GetPropertyIx(dba.PropertyToName(property))); +} + +template <class TDbAccessor> +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &dba, Expression *expr, const std::string &property) { + return storage.Create<PropertyLookup>(expr, storage.GetPropertyIx(property)); +} + +template <class TDbAccessor> +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &, const std::string &name, + const std::pair<std::string, memgraph::storage::v3::PropertyId> &prop_pair) { + return storage.Create<PropertyLookup>(storage.Create<Identifier>(name), storage.GetPropertyIx(prop_pair.first)); +} + +template <class TDbAccessor> +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &, Expression *expr, + const std::pair<std::string, memgraph::storage::v3::PropertyId> &prop_pair) { + return storage.Create<PropertyLookup>(expr, storage.GetPropertyIx(prop_pair.first)); +} + +/// Create an EdgeAtom with given name, direction and edge_type. +/// +/// Name is used to create the Identifier which is assigned to the edge. 
+auto GetEdge(AstStorage &storage, const std::string &name, EdgeAtom::Direction dir = EdgeAtom::Direction::BOTH, + const std::vector<std::string> &edge_types = {}) { + std::vector<EdgeTypeIx> types; + types.reserve(edge_types.size()); + for (const auto &type : edge_types) { + types.push_back(storage.GetEdgeTypeIx(type)); + } + return storage.Create<EdgeAtom>(storage.Create<Identifier>(name), EdgeAtom::Type::SINGLE, dir, types); +} + +/// Create a variable length expansion EdgeAtom with given name, direction and +/// edge_type. +/// +/// Name is used to create the Identifier which is assigned to the edge. +auto GetEdgeVariable(AstStorage &storage, const std::string &name, EdgeAtom::Type type = EdgeAtom::Type::DEPTH_FIRST, + EdgeAtom::Direction dir = EdgeAtom::Direction::BOTH, + const std::vector<std::string> &edge_types = {}, Identifier *flambda_inner_edge = nullptr, + Identifier *flambda_inner_node = nullptr, Identifier *wlambda_inner_edge = nullptr, + Identifier *wlambda_inner_node = nullptr, Expression *wlambda_expression = nullptr, + Identifier *total_weight = nullptr) { + std::vector<EdgeTypeIx> types; + types.reserve(edge_types.size()); + for (const auto &type : edge_types) { + types.push_back(storage.GetEdgeTypeIx(type)); + } + auto r_val = storage.Create<EdgeAtom>(storage.Create<Identifier>(name), type, dir, types); + + r_val->filter_lambda_.inner_edge = + flambda_inner_edge ? flambda_inner_edge : storage.Create<Identifier>(memgraph::utils::RandomString(20)); + r_val->filter_lambda_.inner_node = + flambda_inner_node ? flambda_inner_node : storage.Create<Identifier>(memgraph::utils::RandomString(20)); + + if (type == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) { + r_val->weight_lambda_.inner_edge = + wlambda_inner_edge ? wlambda_inner_edge : storage.Create<Identifier>(memgraph::utils::RandomString(20)); + r_val->weight_lambda_.inner_node = + wlambda_inner_node ? 
wlambda_inner_node : storage.Create<Identifier>(memgraph::utils::RandomString(20)); + r_val->weight_lambda_.expression = + wlambda_expression ? wlambda_expression : storage.Create<memgraph::query::v2::PrimitiveLiteral>(1); + + r_val->total_weight_ = total_weight; + } + + return r_val; +} + +/// Create a NodeAtom with given name and label. +/// +/// Name is used to create the Identifier which is assigned to the node. +auto GetNode(AstStorage &storage, const std::string &name, std::optional<std::string> label = std::nullopt) { + auto node = storage.Create<NodeAtom>(storage.Create<Identifier>(name)); + if (label) node->labels_.emplace_back(storage.GetLabelIx(*label)); + return node; +} + +/// Create a Pattern with given atoms. +auto GetPattern(AstStorage &storage, std::vector<PatternAtom *> atoms) { + auto pattern = storage.Create<Pattern>(); + pattern->identifier_ = storage.Create<Identifier>(memgraph::utils::RandomString(20), false); + pattern->atoms_.insert(pattern->atoms_.begin(), atoms.begin(), atoms.end()); + return pattern; +} + +/// Create a Pattern with given name and atoms. +auto GetPattern(AstStorage &storage, const std::string &name, std::vector<PatternAtom *> atoms) { + auto pattern = storage.Create<Pattern>(); + pattern->identifier_ = storage.Create<Identifier>(name, true); + pattern->atoms_.insert(pattern->atoms_.begin(), atoms.begin(), atoms.end()); + return pattern; +} + +/// This function fills an AST node which with given patterns. +/// +/// The function is most commonly used to create Match and Create clauses. +template <class TWithPatterns> +auto GetWithPatterns(TWithPatterns *with_patterns, std::vector<Pattern *> patterns) { + with_patterns->patterns_.insert(with_patterns->patterns_.begin(), patterns.begin(), patterns.end()); + return with_patterns; +} + +/// Create a query with given clauses. 
+ +auto GetSingleQuery(SingleQuery *single_query, Clause *clause) { + single_query->clauses_.emplace_back(clause); + return single_query; +} +auto GetSingleQuery(SingleQuery *single_query, Match *match, Where *where) { + match->where_ = where; + single_query->clauses_.emplace_back(match); + return single_query; +} +auto GetSingleQuery(SingleQuery *single_query, With *with, Where *where) { + with->where_ = where; + single_query->clauses_.emplace_back(with); + return single_query; +} +template <class... T> +auto GetSingleQuery(SingleQuery *single_query, Match *match, Where *where, T *...clauses) { + match->where_ = where; + single_query->clauses_.emplace_back(match); + return GetSingleQuery(single_query, clauses...); +} +template <class... T> +auto GetSingleQuery(SingleQuery *single_query, With *with, Where *where, T *...clauses) { + with->where_ = where; + single_query->clauses_.emplace_back(with); + return GetSingleQuery(single_query, clauses...); +} + +template <class... T> +auto GetSingleQuery(SingleQuery *single_query, Clause *clause, T *...clauses) { + single_query->clauses_.emplace_back(clause); + return GetSingleQuery(single_query, clauses...); +} + +auto GetCypherUnion(CypherUnion *cypher_union, SingleQuery *single_query) { + cypher_union->single_query_ = single_query; + return cypher_union; +} + +auto GetQuery(AstStorage &storage, SingleQuery *single_query) { + auto *query = storage.Create<CypherQuery>(); + query->single_query_ = single_query; + return query; +} + +template <class... T> +auto GetQuery(AstStorage &storage, SingleQuery *single_query, T *...cypher_unions) { + auto *query = storage.Create<CypherQuery>(); + query->single_query_ = single_query; + query->cypher_unions_ = std::vector<CypherUnion *>{cypher_unions...}; + return query; +} + +// Helper functions for constructing RETURN and WITH clauses. 
+void FillReturnBody(AstStorage &, ReturnBody &body, NamedExpression *named_expr) { + body.named_expressions.emplace_back(named_expr); +} +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name) { + if (name == "*") { + body.all_identifiers = true; + } else { + auto *ident = storage.Create<memgraph::query::v2::Identifier>(name); + auto *named_expr = storage.Create<memgraph::query::v2::NamedExpression>(name, ident); + body.named_expressions.emplace_back(named_expr); + } +} +void FillReturnBody(AstStorage &, ReturnBody &body, Limit limit) { body.limit = limit.expression; } +void FillReturnBody(AstStorage &, ReturnBody &body, Skip skip, Limit limit = Limit{}) { + body.skip = skip.expression; + body.limit = limit.expression; +} +void FillReturnBody(AstStorage &, ReturnBody &body, OrderBy order_by, Limit limit = Limit{}) { + body.order_by = order_by.expressions; + body.limit = limit.expression; +} +void FillReturnBody(AstStorage &, ReturnBody &body, OrderBy order_by, Skip skip, Limit limit = Limit{}) { + body.order_by = order_by.expressions; + body.skip = skip.expression; + body.limit = limit.expression; +} +void FillReturnBody(AstStorage &, ReturnBody &body, Expression *expr, NamedExpression *named_expr) { + // This overload supports `RETURN(expr, AS(name))` construct, since + // NamedExpression does not inherit Expression. + named_expr->expression_ = expr; + body.named_expressions.emplace_back(named_expr); +} +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name, NamedExpression *named_expr) { + named_expr->expression_ = storage.Create<memgraph::query::v2::Identifier>(name); + body.named_expressions.emplace_back(named_expr); +} +template <class... T> +void FillReturnBody(AstStorage &storage, ReturnBody &body, Expression *expr, NamedExpression *named_expr, T... 
rest) { + named_expr->expression_ = expr; + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} +template <class... T> +void FillReturnBody(AstStorage &storage, ReturnBody &body, NamedExpression *named_expr, T... rest) { + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} +template <class... T> +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name, NamedExpression *named_expr, + T... rest) { + named_expr->expression_ = storage.Create<memgraph::query::v2::Identifier>(name); + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} +template <class... T> +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name, T... rest) { + auto *ident = storage.Create<memgraph::query::v2::Identifier>(name); + auto *named_expr = storage.Create<memgraph::query::v2::NamedExpression>(name, ident); + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} + +/// Create the return clause with given expressions. +/// +/// The supported expression combination of arguments is: +/// +/// (String | NamedExpression | (Expression NamedExpression))+ +/// [OrderBy] [Skip] [Limit] +/// +/// When the pair (Expression NamedExpression) is given, the Expression will be +/// moved inside the NamedExpression. This is done, so that the constructs like +/// RETURN(expr, AS("name"), ...) are supported. Taking a String is a shorthand +/// for RETURN(IDENT(string), AS(string), ....). +/// +/// @sa GetWith +template <class... T> +auto GetReturn(AstStorage &storage, bool distinct, T... exprs) { + auto ret = storage.Create<Return>(); + ret->body_.distinct = distinct; + FillReturnBody(storage, ret->body_, exprs...); + return ret; +} + +/// Create the with clause with given expressions. +/// +/// The supported expression combination is the same as for @c GetReturn. 
+/// +/// @sa GetReturn +template <class... T> +auto GetWith(AstStorage &storage, bool distinct, T... exprs) { + auto with = storage.Create<With>(); + with->body_.distinct = distinct; + FillReturnBody(storage, with->body_, exprs...); + return with; +} + +/// Create the UNWIND clause with given named expression. +auto GetUnwind(AstStorage &storage, NamedExpression *named_expr) { + return storage.Create<memgraph::query::v2::Unwind>(named_expr); +} +auto GetUnwind(AstStorage &storage, Expression *expr, NamedExpression *as) { + as->expression_ = expr; + return GetUnwind(storage, as); +} + +/// Create the delete clause with given named expressions. +auto GetDelete(AstStorage &storage, std::vector<Expression *> exprs, bool detach = false) { + auto del = storage.Create<Delete>(); + del->expressions_.insert(del->expressions_.begin(), exprs.begin(), exprs.end()); + del->detach_ = detach; + return del; +} + +/// Create a set property clause for given property lookup and the right hand +/// side expression. +auto GetSet(AstStorage &storage, PropertyLookup *prop_lookup, Expression *expr) { + return storage.Create<SetProperty>(prop_lookup, expr); +} + +/// Create a set properties clause for given identifier name and the right hand +/// side expression. +auto GetSet(AstStorage &storage, const std::string &name, Expression *expr, bool update = false) { + return storage.Create<SetProperties>(storage.Create<Identifier>(name), expr, update); +} + +/// Create a set labels clause for given identifier name and labels. 
+auto GetSet(AstStorage &storage, const std::string &name, std::vector<std::string> label_names) { + std::vector<LabelIx> labels; + labels.reserve(label_names.size()); + for (const auto &label : label_names) { + labels.push_back(storage.GetLabelIx(label)); + } + return storage.Create<SetLabels>(storage.Create<Identifier>(name), labels); +} + +/// Create a remove property clause for given property lookup +auto GetRemove(AstStorage &storage, PropertyLookup *prop_lookup) { return storage.Create<RemoveProperty>(prop_lookup); } + +/// Create a remove labels clause for given identifier name and labels. +auto GetRemove(AstStorage &storage, const std::string &name, std::vector<std::string> label_names) { + std::vector<LabelIx> labels; + labels.reserve(label_names.size()); + for (const auto &label : label_names) { + labels.push_back(storage.GetLabelIx(label)); + } + return storage.Create<RemoveLabels>(storage.Create<Identifier>(name), labels); +} + +/// Create a Merge clause for given Pattern with optional OnMatch and OnCreate +/// parts. +auto GetMerge(AstStorage &storage, Pattern *pattern, OnCreate on_create = OnCreate{}) { + auto *merge = storage.Create<memgraph::query::v2::Merge>(); + merge->pattern_ = pattern; + merge->on_create_ = on_create.set; + return merge; +} +auto GetMerge(AstStorage &storage, Pattern *pattern, OnMatch on_match, OnCreate on_create = OnCreate{}) { + auto *merge = storage.Create<memgraph::query::v2::Merge>(); + merge->pattern_ = pattern; + merge->on_match_ = on_match.set; + merge->on_create_ = on_create.set; + return merge; +} + +auto GetCallProcedure(AstStorage &storage, std::string procedure_name, + std::vector<memgraph::query::v2::Expression *> arguments = {}) { + auto *call_procedure = storage.Create<memgraph::query::v2::CallProcedure>(); + call_procedure->procedure_name_ = std::move(procedure_name); + call_procedure->arguments_ = std::move(arguments); + return call_procedure; +} + +/// Create the FOREACH clause with given named expression. 
+auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vector<query::v2::Clause *> &clauses) { + return storage.Create<query::v2::Foreach>(named_expr, clauses); +} + +} // namespace test_common + +} // namespace memgraph::query::v2 + +/// All the following macros implicitly pass `storage` variable to functions. +/// You need to have `AstStorage storage;` somewhere in scope to use them. +/// Refer to function documentation to see what the macro does. +/// +/// Example usage: +/// +/// // Create MATCH (n) -[r]- (m) RETURN m AS new_name +/// AstStorage storage; +/// auto query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), +/// RETURN(NEXPR("new_name", IDENT("m"))))); +#define NODE(...) memgraph::query::v2::test_common::GetNode(storage, __VA_ARGS__) +#define EDGE(...) memgraph::query::v2::test_common::GetEdge(storage, __VA_ARGS__) +#define EDGE_VARIABLE(...) memgraph::query::v2::test_common::GetEdgeVariable(storage, __VA_ARGS__) +#define PATTERN(...) memgraph::query::v2::test_common::GetPattern(storage, {__VA_ARGS__}) +#define NAMED_PATTERN(name, ...) memgraph::query::v2::test_common::GetPattern(storage, name, {__VA_ARGS__}) +#define OPTIONAL_MATCH(...) \ + memgraph::query::v2::test_common::GetWithPatterns(storage.Create<memgraph::query::v2::Match>(true), {__VA_ARGS__}) +#define MATCH(...) \ + memgraph::query::v2::test_common::GetWithPatterns(storage.Create<memgraph::query::v2::Match>(), {__VA_ARGS__}) +#define WHERE(expr) storage.Create<memgraph::query::v2::Where>((expr)) +#define CREATE(...) \ + memgraph::query::v2::test_common::GetWithPatterns(storage.Create<memgraph::query::v2::Create>(), {__VA_ARGS__}) +#define IDENT(...) storage.Create<memgraph::query::v2::Identifier>(__VA_ARGS__) +#define LITERAL(val) storage.Create<memgraph::query::v2::PrimitiveLiteral>((val)) +#define LIST(...) \ + storage.Create<memgraph::query::v2::ListLiteral>(std::vector<memgraph::query::v2::Expression *>{__VA_ARGS__}) +#define MAP(...)
\ + storage.Create<memgraph::query::v2::MapLiteral>( \ + std::unordered_map<memgraph::query::v2::PropertyIx, memgraph::query::v2::Expression *>{__VA_ARGS__}) +#define PROPERTY_PAIR(property_name) std::make_pair(property_name, dba.NameToProperty(property_name)) +#define PROPERTY_LOOKUP(...) memgraph::query::v2::test_common::GetPropertyLookup(storage, dba, __VA_ARGS__) +#define PARAMETER_LOOKUP(token_position) storage.Create<memgraph::query::v2::ParameterLookup>((token_position)) +#define NEXPR(name, expr) storage.Create<memgraph::query::v2::NamedExpression>((name), (expr)) +// AS is alternative to NEXPR which does not initialize NamedExpression with +// Expression. It should be used with RETURN or WITH. For example: +// RETURN(IDENT("n"), AS("n")) vs. RETURN(NEXPR("n", IDENT("n"))). +#define AS(name) storage.Create<memgraph::query::v2::NamedExpression>((name)) +#define RETURN(...) memgraph::query::v2::test_common::GetReturn(storage, false, __VA_ARGS__) +#define WITH(...) memgraph::query::v2::test_common::GetWith(storage, false, __VA_ARGS__) +#define RETURN_DISTINCT(...) memgraph::query::v2::test_common::GetReturn(storage, true, __VA_ARGS__) +#define WITH_DISTINCT(...) memgraph::query::v2::test_common::GetWith(storage, true, __VA_ARGS__) +#define UNWIND(...) memgraph::query::v2::test_common::GetUnwind(storage, __VA_ARGS__) +#define ORDER_BY(...) memgraph::query::v2::test_common::GetOrderBy(__VA_ARGS__) +#define SKIP(expr) \ + memgraph::query::v2::test_common::Skip { (expr) } +#define LIMIT(expr) \ + memgraph::query::v2::test_common::Limit { (expr) } +#define DELETE(...) memgraph::query::v2::test_common::GetDelete(storage, {__VA_ARGS__}) +#define DETACH_DELETE(...) memgraph::query::v2::test_common::GetDelete(storage, {__VA_ARGS__}, true) +#define SET(...) memgraph::query::v2::test_common::GetSet(storage, __VA_ARGS__) +#define REMOVE(...) memgraph::query::v2::test_common::GetRemove(storage, __VA_ARGS__) +#define MERGE(...) 
memgraph::query::v2::test_common::GetMerge(storage, __VA_ARGS__) +#define ON_MATCH(...) \ + memgraph::query::v2::test_common::OnMatch { \ + std::vector<memgraph::query::v2::Clause *> { __VA_ARGS__ } \ + } +#define ON_CREATE(...) \ + memgraph::query::v2::test_common::OnCreate { \ + std::vector<memgraph::query::v2::Clause *> { __VA_ARGS__ } \ + } +#define CREATE_INDEX_ON(label, property) \ + storage.Create<memgraph::query::v2::IndexQuery>(memgraph::query::v2::IndexQuery::Action::CREATE, (label), \ + std::vector<memgraph::query::v2::PropertyIx>{(property)}) +#define QUERY(...) memgraph::query::v2::test_common::GetQuery(storage, __VA_ARGS__) +#define SINGLE_QUERY(...) memgraph::query::v2::test_common::GetSingleQuery(storage.Create<memgraph::query::v2::SingleQuery>(), __VA_ARGS__) +#define UNION(...) memgraph::query::v2::test_common::GetCypherUnion(storage.Create<memgraph::query::v2::CypherUnion>(true), __VA_ARGS__) +#define UNION_ALL(...) memgraph::query::v2::test_common::GetCypherUnion(storage.Create<memgraph::query::v2::CypherUnion>(false), __VA_ARGS__) +#define FOREACH(...)
memgraph::query::v2::test_common::GetForeach(storage, __VA_ARGS__) +// Various operators +#define NOT(expr) storage.Create<memgraph::query::v2::NotOperator>((expr)) +#define UPLUS(expr) storage.Create<memgraph::query::v2::UnaryPlusOperator>((expr)) +#define UMINUS(expr) storage.Create<memgraph::query::v2::UnaryMinusOperator>((expr)) +#define IS_NULL(expr) storage.Create<memgraph::query::v2::IsNullOperator>((expr)) +#define ADD(expr1, expr2) storage.Create<memgraph::query::v2::AdditionOperator>((expr1), (expr2)) +#define LESS(expr1, expr2) storage.Create<memgraph::query::v2::LessOperator>((expr1), (expr2)) +#define LESS_EQ(expr1, expr2) storage.Create<memgraph::query::v2::LessEqualOperator>((expr1), (expr2)) +#define GREATER(expr1, expr2) storage.Create<memgraph::query::v2::GreaterOperator>((expr1), (expr2)) +#define GREATER_EQ(expr1, expr2) storage.Create<memgraph::query::v2::GreaterEqualOperator>((expr1), (expr2)) +#define SUM(expr) \ + storage.Create<memgraph::query::v2::Aggregation>((expr), nullptr, memgraph::query::v2::Aggregation::Op::SUM) +#define COUNT(expr) \ + storage.Create<memgraph::query::v2::Aggregation>((expr), nullptr, memgraph::query::v2::Aggregation::Op::COUNT) +#define AVG(expr) \ + storage.Create<memgraph::query::v2::Aggregation>((expr), nullptr, memgraph::query::v2::Aggregation::Op::AVG) +#define COLLECT_LIST(expr) \ + storage.Create<memgraph::query::v2::Aggregation>((expr), nullptr, memgraph::query::v2::Aggregation::Op::COLLECT_LIST) +#define EQ(expr1, expr2) storage.Create<memgraph::query::v2::EqualOperator>((expr1), (expr2)) +#define NEQ(expr1, expr2) storage.Create<memgraph::query::v2::NotEqualOperator>((expr1), (expr2)) +#define AND(expr1, expr2) storage.Create<memgraph::query::v2::AndOperator>((expr1), (expr2)) +#define OR(expr1, expr2) storage.Create<memgraph::query::v2::OrOperator>((expr1), (expr2)) +#define IN_LIST(expr1, expr2) storage.Create<memgraph::query::v2::InListOperator>((expr1), (expr2)) +#define IF(cond, then, else) 
storage.Create<memgraph::query::v2::IfOperator>((cond), (then), (else)) +// Function call +#define FN(function_name, ...) \ + storage.Create<memgraph::query::v2::Function>(memgraph::utils::ToUpperCase(function_name), \ + std::vector<memgraph::query::v2::Expression *>{__VA_ARGS__}) +// List slicing +#define SLICE(list, lower_bound, upper_bound) \ + storage.Create<memgraph::query::v2::ListSlicingOperator>(list, lower_bound, upper_bound) +// all(variable IN list WHERE predicate) +#define ALL(variable, list, where) \ + storage.Create<memgraph::query::v2::All>(storage.Create<memgraph::query::v2::Identifier>(variable), list, where) +#define SINGLE(variable, list, where) \ + storage.Create<memgraph::query::v2::Single>(storage.Create<memgraph::query::v2::Identifier>(variable), list, where) +#define ANY(variable, list, where) \ + storage.Create<memgraph::query::v2::Any>(storage.Create<memgraph::query::v2::Identifier>(variable), list, where) +#define NONE(variable, list, where) \ + storage.Create<memgraph::query::v2::None>(storage.Create<memgraph::query::v2::Identifier>(variable), list, where) +#define REDUCE(accumulator, initializer, variable, list, expr) \ + storage.Create<memgraph::query::v2::Reduce>(storage.Create<memgraph::query::v2::Identifier>(accumulator), \ + initializer, storage.Create<memgraph::query::v2::Identifier>(variable), \ + list, expr) +#define COALESCE(...) \ + storage.Create<memgraph::query::v2::Coalesce>(std::vector<memgraph::query::v2::Expression *>{__VA_ARGS__}) +#define EXTRACT(variable, list, expr) \ + storage.Create<memgraph::query::v2::Extract>(storage.Create<memgraph::query::v2::Identifier>(variable), list, expr) +#define AUTH_QUERY(action, user, role, user_or_role, password, privileges) \ + storage.Create<memgraph::query::v2::AuthQuery>((action), (user), (role), (user_or_role), password, (privileges)) +#define DROP_USER(usernames) storage.Create<memgraph::query::v2::DropUser>((usernames)) +#define CALL_PROCEDURE(...) 
memgraph::query::v2::test_common::GetCallProcedure(storage, __VA_ARGS__) diff --git a/tests/unit/query_v2_query_plan_accumulate_aggregate.cpp b/tests/unit/query_v2_query_plan_accumulate_aggregate.cpp new file mode 100644 index 000000000..784b45a8d --- /dev/null +++ b/tests/unit/query_v2_query_plan_accumulate_aggregate.cpp @@ -0,0 +1,634 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include <algorithm> +#include <iterator> +#include <memory> +#include <vector> + +#include "common/types.hpp" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "query/v2/context.hpp" +#include "query/v2/exceptions.hpp" +#include "query/v2/plan/operator.hpp" +#include "query_v2_query_plan_common.hpp" +#include "storage/v3/property_value.hpp" +#include "storage/v3/schemas.hpp" + +using namespace memgraph::query::v2; +using namespace memgraph::query::v2::plan; +using test_common::ToIntList; +using test_common::ToIntMap; +using testing::UnorderedElementsAre; + +namespace memgraph::query::v2::tests { + +class QueryPlanAccumulateAggregateTest : public ::testing::Test { + protected: + void SetUp() override { + ASSERT_TRUE(db.CreateSchema(label, {storage::v3::SchemaProperty{property, common::SchemaType::INT}})); + } + + storage::v3::Storage db; + const storage::v3::LabelId label{db.NameToLabel("label")}; + const storage::v3::PropertyId property{db.NameToProperty("property")}; +}; + +TEST_F(QueryPlanAccumulateAggregateTest, Accumulate) { + // simulate the following two query 
execution on an empty db + // CREATE ({x:0})-[:T]->({x:0}) + // MATCH (n)--(m) SET n.x = n.x + 1, m.x = m.x + 1 RETURN n.x, m.x + // without accumulation we expected results to be [[1, 1], [2, 2]] + // with accumulation we expect them to be [[2, 2], [2, 2]] + + auto check = [&](bool accumulate) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto prop = dba.NameToProperty("x"); + + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(v1.SetProperty(prop, storage::v3::PropertyValue(0)).HasValue()); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + ASSERT_TRUE(v2.SetProperty(prop, storage::v3::PropertyValue(0)).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, dba.NameToEdgeType("T")).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::BOTH, {}, "m", false, + storage::v3::View::OLD); + + auto one = LITERAL(1); + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + auto set_n_p = std::make_shared<plan::SetProperty>(r_m.op_, prop, n_p, ADD(n_p, one)); + auto m_p = PROPERTY_LOOKUP(IDENT("m")->MapTo(r_m.node_sym_), prop); + auto set_m_p = std::make_shared<plan::SetProperty>(set_n_p, prop, m_p, ADD(m_p, one)); + + std::shared_ptr<LogicalOperator> last_op = set_m_p; + if (accumulate) { + last_op = std::make_shared<Accumulate>(last_op, std::vector<Symbol>{n.sym_, r_m.node_sym_}); + } + + auto n_p_ne = NEXPR("n.p", n_p)->MapTo(symbol_table.CreateSymbol("n_p_ne", true)); + auto m_p_ne = NEXPR("m.p", m_p)->MapTo(symbol_table.CreateSymbol("m_p_ne", true)); + auto produce = MakeProduce(last_op, n_p_ne, m_p_ne); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + std::vector<int> results_data; + for 
(const auto &row : results) + for (const auto &column : row) results_data.emplace_back(column.ValueInt()); + if (accumulate) + EXPECT_THAT(results_data, ::testing::ElementsAre(2, 2, 2, 2)); + else + EXPECT_THAT(results_data, ::testing::ElementsAre(1, 1, 2, 2)); + }; + + check(false); + check(true); +} + +TEST_F(QueryPlanAccumulateAggregateTest, AccumulateAdvance) { + // we simulate 'CREATE (n) WITH n AS n MATCH (m) RETURN m' + // to get correct results we need to advance the command + auto check = [&](bool advance) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + NodeCreationInfo node; + node.symbol = symbol_table.CreateSymbol("n", true); + node.labels = {label}; + std::get<std::vector<std::pair<storage::v3::PropertyId, Expression *>>>(node.properties) + .emplace_back(property, LITERAL(1)); + auto create = std::make_shared<CreateNode>(nullptr, node); + auto accumulate = std::make_shared<Accumulate>(create, std::vector<Symbol>{node.symbol}, advance); + auto match = MakeScanAll(storage, symbol_table, "m", accumulate); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(advance ? 
1 : 0, PullAll(*match.op_, &context)); + }; + check(false); + check(true); +} + +std::shared_ptr<Produce> MakeAggregationProduce(std::shared_ptr<LogicalOperator> input, SymbolTable &symbol_table, + AstStorage &storage, const std::vector<Expression *> aggr_inputs, + const std::vector<Aggregation::Op> aggr_ops, + const std::vector<Expression *> group_by_exprs, + const std::vector<Symbol> remember) { + // prepare all the aggregations + std::vector<Aggregate::Element> aggregates; + std::vector<NamedExpression *> named_expressions; + + auto aggr_inputs_it = aggr_inputs.begin(); + for (auto aggr_op : aggr_ops) { + // TODO change this from using IDENT to using AGGREGATION + // once AGGREGATION is handled properly in ExpressionEvaluation + auto aggr_sym = symbol_table.CreateSymbol("aggregation", true); + auto named_expr = + NEXPR("", IDENT("aggregation")->MapTo(aggr_sym))->MapTo(symbol_table.CreateSymbol("named_expression", true)); + named_expressions.push_back(named_expr); + // the key expression is only used in COLLECT_MAP + Expression *key_expr_ptr = aggr_op == Aggregation::Op::COLLECT_MAP ? LITERAL("key") : nullptr; + aggregates.emplace_back(Aggregate::Element{*aggr_inputs_it++, key_expr_ptr, aggr_op, aggr_sym}); + } + + // Produce will also evaluate group_by expressions and return them after the + // aggregations. + for (auto group_by_expr : group_by_exprs) { + auto named_expr = NEXPR("", group_by_expr)->MapTo(symbol_table.CreateSymbol("named_expression", true)); + named_expressions.push_back(named_expr); + } + auto aggregation = std::make_shared<Aggregate>(input, aggregates, group_by_exprs, remember); + return std::make_shared<Produce>(aggregation, named_expressions); +} + +// /** Test fixture for all the aggregation ops in one return. 
*/ +class QueryPlanAggregateOps : public ::testing::Test { + protected: + void SetUp() override { + ASSERT_TRUE(db.CreateSchema(label, {storage::v3::SchemaProperty{property, common::SchemaType::INT}})); + } + storage::v3::Storage db; + storage::v3::Storage::Accessor storage_dba{db.Access()}; + DbAccessor dba{&storage_dba}; + storage::v3::LabelId label = db.NameToLabel("label"); + storage::v3::PropertyId property = db.NameToProperty("property"); + storage::v3::PropertyId prop = db.NameToProperty("prop"); + + AstStorage storage; + SymbolTable symbol_table; + + void AddData() { + // setup is several nodes most of which have an int property set + // we will take the sum, avg, min, max and count + // we won't group by anything + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}) + ->SetProperty(prop, storage::v3::PropertyValue(5)) + .HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}) + ->SetProperty(prop, storage::v3::PropertyValue(7)) + .HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}) + ->SetProperty(prop, storage::v3::PropertyValue(12)) + .HasValue()); + // a missing property (null) gets ignored by all aggregations except + // COUNT(*) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(4)}}).HasValue()); + dba.AdvanceCommand(); + } + + auto AggregationResults(bool with_group_by, std::vector<Aggregation::Op> ops = { + Aggregation::Op::COUNT, Aggregation::Op::COUNT, Aggregation::Op::MIN, + Aggregation::Op::MAX, Aggregation::Op::SUM, Aggregation::Op::AVG, + Aggregation::Op::COLLECT_LIST, Aggregation::Op::COLLECT_MAP}) { + // match all nodes and perform aggregations + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + + std::vector<Expression *> aggregation_expressions(ops.size(), n_p); + 
std::vector<Expression *> group_bys; + if (with_group_by) group_bys.push_back(n_p); + aggregation_expressions[0] = nullptr; + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, aggregation_expressions, ops, group_bys, {}); + auto context = MakeContext(storage, symbol_table, &dba); + return CollectProduce(*produce, &context); + } +}; + +TEST_F(QueryPlanAggregateOps, WithData) { + AddData(); + auto results = AggregationResults(false); + + ASSERT_EQ(results.size(), 1); + ASSERT_EQ(results[0].size(), 8); + // count(*) + ASSERT_EQ(results[0][0].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][0].ValueInt(), 4); + // count + ASSERT_EQ(results[0][1].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][1].ValueInt(), 3); + // min + ASSERT_EQ(results[0][2].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][2].ValueInt(), 5); + // max + ASSERT_EQ(results[0][3].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][3].ValueInt(), 12); + // sum + ASSERT_EQ(results[0][4].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][4].ValueInt(), 24); + // avg + ASSERT_EQ(results[0][5].type(), TypedValue::Type::Double); + EXPECT_FLOAT_EQ(results[0][5].ValueDouble(), 24 / 3.0); + // collect list + ASSERT_EQ(results[0][6].type(), TypedValue::Type::List); + EXPECT_THAT(ToIntList(results[0][6]), UnorderedElementsAre(5, 7, 12)); + // collect map + ASSERT_EQ(results[0][7].type(), TypedValue::Type::Map); + auto map = ToIntMap(results[0][7]); + ASSERT_EQ(map.size(), 1); + EXPECT_EQ(map.begin()->first, "key"); + EXPECT_FALSE(std::set<int>({5, 7, 12}).insert(map.begin()->second).second); +} + +TEST_F(QueryPlanAggregateOps, WithoutDataWithGroupBy) { + { + auto results = AggregationResults(true, {Aggregation::Op::COUNT}); + EXPECT_EQ(results.size(), 0); + } + { + auto results = AggregationResults(true, {Aggregation::Op::SUM}); + EXPECT_EQ(results.size(), 0); + } + { + auto results = AggregationResults(true, {Aggregation::Op::AVG}); + EXPECT_EQ(results.size(), 0); + } 
+ { + auto results = AggregationResults(true, {Aggregation::Op::MIN}); + EXPECT_EQ(results.size(), 0); + } + { + auto results = AggregationResults(true, {Aggregation::Op::MAX}); + EXPECT_EQ(results.size(), 0); + } + { + auto results = AggregationResults(true, {Aggregation::Op::COLLECT_LIST}); + EXPECT_EQ(results.size(), 0); + } + { + auto results = AggregationResults(true, {Aggregation::Op::COLLECT_MAP}); + EXPECT_EQ(results.size(), 0); + } +} + +TEST_F(QueryPlanAggregateOps, WithoutDataWithoutGroupBy) { + auto results = AggregationResults(false); + ASSERT_EQ(results.size(), 1); + ASSERT_EQ(results[0].size(), 8); + // count(*) + ASSERT_EQ(results[0][0].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][0].ValueInt(), 0); + // count + ASSERT_EQ(results[0][1].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][1].ValueInt(), 0); + // min + EXPECT_TRUE(results[0][2].IsNull()); + // max + EXPECT_TRUE(results[0][3].IsNull()); + // sum + EXPECT_TRUE(results[0][4].IsNull()); + // avg + EXPECT_TRUE(results[0][5].IsNull()); + // collect list + ASSERT_EQ(results[0][6].type(), TypedValue::Type::List); + EXPECT_EQ(ToIntList(results[0][6]).size(), 0); + // collect map + ASSERT_EQ(results[0][7].type(), TypedValue::Type::Map); + EXPECT_EQ(ToIntMap(results[0][7]).size(), 0); +} + +TEST_F(QueryPlanAccumulateAggregateTest, AggregateGroupByValues) { + // Tests that distinct groups are aggregated properly for values of all types. + // Also test the "remember" part of the Aggregation API as final results are + // obtained via a property lookup of a remembered node. 
+ auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // a vector of storage::v3::PropertyValue to be set as property values on vertices + // most of them should result in a distinct group (commented where not) + std::vector<storage::v3::PropertyValue> group_by_vals; + group_by_vals.emplace_back(4); + group_by_vals.emplace_back(7); + group_by_vals.emplace_back(7.3); + group_by_vals.emplace_back(7.2); + group_by_vals.emplace_back("Johhny"); + group_by_vals.emplace_back("Jane"); + group_by_vals.emplace_back("1"); + group_by_vals.emplace_back(true); + group_by_vals.emplace_back(false); + group_by_vals.emplace_back(std::vector<storage::v3::PropertyValue>{storage::v3::PropertyValue(1)}); + group_by_vals.emplace_back( + std::vector<storage::v3::PropertyValue>{storage::v3::PropertyValue(1), storage::v3::PropertyValue(2)}); + group_by_vals.emplace_back( + std::vector<storage::v3::PropertyValue>{storage::v3::PropertyValue(2), storage::v3::PropertyValue(1)}); + group_by_vals.emplace_back(storage::v3::PropertyValue()); + // should NOT result in another group because 7.0 == 7 + group_by_vals.emplace_back(7.0); + // should NOT result in another group + group_by_vals.emplace_back( + std::vector<storage::v3::PropertyValue>{storage::v3::PropertyValue(1), storage::v3::PropertyValue(2.0)}); + + // generate a lot of vertices and set props on them + auto prop = dba.NameToProperty("prop"); + for (int i = 0; i < 1000; ++i) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}) + ->SetProperty(prop, group_by_vals[i % group_by_vals.size()]) + .HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + // match all nodes and perform aggregations + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, {Aggregation::Op::COUNT}, {n_p}, {n.sym_}); + + auto context = 
MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + ASSERT_EQ(results.size(), group_by_vals.size() - 2); + std::unordered_set<TypedValue, TypedValue::Hash, TypedValue::BoolEqual> result_group_bys; + for (const auto &row : results) { + ASSERT_EQ(2, row.size()); + result_group_bys.insert(row[1]); + } + ASSERT_EQ(result_group_bys.size(), group_by_vals.size() - 2); + std::vector<TypedValue> group_by_tvals; + group_by_tvals.reserve(group_by_vals.size()); + for (const auto &v : group_by_vals) group_by_tvals.emplace_back(v); + EXPECT_TRUE(std::is_permutation(group_by_tvals.begin(), group_by_tvals.end() - 2, result_group_bys.begin(), + TypedValue::BoolEqual{})); +} + +TEST_F(QueryPlanAccumulateAggregateTest, AggregateMultipleGroupBy) { + // in this test we have 3 different properties that have different values + // for different records and assert that we get the correct combination + // of values in our groups + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto prop1 = dba.NameToProperty("prop1"); + auto prop2 = dba.NameToProperty("prop2"); + auto prop3 = dba.NameToProperty("prop3"); + for (int i = 0; i < 2 * 3 * 5; ++i) { + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i)}}); + ASSERT_TRUE(v.SetProperty(prop1, storage::v3::PropertyValue(static_cast<bool>(i % 2))).HasValue()); + ASSERT_TRUE(v.SetProperty(prop2, storage::v3::PropertyValue(i % 3)).HasValue()); + ASSERT_TRUE(v.SetProperty(prop3, storage::v3::PropertyValue("value" + std::to_string(i % 5))).HasValue()); + } + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + // match all nodes and perform aggregations + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p1 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop1); + auto n_p2 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop2); + auto n_p3 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop3); + + auto produce = 
MakeAggregationProduce(n.op_, symbol_table, storage, {n_p1}, {Aggregation::Op::COUNT}, + {n_p1, n_p2, n_p3}, {n.sym_}); + + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), 2 * 3 * 5); +} + +TEST(QueryPlan, AggregateNoInput) { + storage::v3::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + + auto two = LITERAL(2); + auto produce = MakeAggregationProduce(nullptr, symbol_table, storage, {two}, {Aggregation::Op::COUNT}, {}, {}); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(1, results.size()); + EXPECT_EQ(1, results[0].size()); + EXPECT_EQ(TypedValue::Type::Int, results[0][0].type()); + EXPECT_EQ(1, results[0][0].ValueInt()); +} + +TEST_F(QueryPlanAccumulateAggregateTest, AggregateCountEdgeCases) { + // tests for detected bugs in the COUNT aggregation behavior + // ensure that COUNT returns correctly for + // - 0 vertices in database + // - 1 vertex in database, property not set + // - 1 vertex in database, property set + // - 2 vertices in database, property set on one + // - 2 vertices in database, property set on both + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto prop = dba.NameToProperty("prop"); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + + // returns -1 when there are no results + // otherwise returns MATCH (n) RETURN count(n.prop) + auto count = [&]() { + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, {Aggregation::Op::COUNT}, {}, {}); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + if (results.size() == 0) return -1L; + EXPECT_EQ(1, results.size()); + EXPECT_EQ(1, 
results[0].size()); + EXPECT_EQ(TypedValue::Type::Int, results[0][0].type()); + return results[0][0].ValueInt(); + }; + + // no vertices yet in database + EXPECT_EQ(0, count()); + + // one vertex, no property set + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(0, count()); + + // one vertex, property set + for (auto va : dba.Vertices(storage::v3::View::OLD)) + ASSERT_TRUE(va.SetProperty(prop, storage::v3::PropertyValue(42)).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(1, count()); + + // two vertices, one with property set + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(1, count()); + + // two vertices, both with property set + for (auto va : dba.Vertices(storage::v3::View::OLD)) + ASSERT_TRUE(va.SetProperty(prop, storage::v3::PropertyValue(42)).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(2, count()); +} + +TEST_F(QueryPlanAccumulateAggregateTest, AggregateFirstValueTypes) { + // testing exceptions that get emitted by the first-value + // type check + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto prop_string = dba.NameToProperty("string"); + ASSERT_TRUE(v1.SetProperty(prop_string, storage::v3::PropertyValue("johhny")).HasValue()); + auto prop_int = dba.NameToProperty("int"); + ASSERT_TRUE(v1.SetProperty(prop_int, storage::v3::PropertyValue(12)).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_prop_string = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop_string); + auto n_prop_int = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop_int); + auto n_id = n_prop_string->expression_; + + auto aggregate = [&](Expression *expression, 
Aggregation::Op aggr_op) { + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {expression}, {aggr_op}, {}, {}); + auto context = MakeContext(storage, symbol_table, &dba); + CollectProduce(*produce, &context); + }; + + // everything except for COUNT and COLLECT fails on a Vertex + aggregate(n_id, Aggregation::Op::COUNT); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::MIN), QueryRuntimeException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::MAX), QueryRuntimeException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::AVG), QueryRuntimeException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::SUM), QueryRuntimeException); + + // on strings AVG and SUM fail + aggregate(n_prop_string, Aggregation::Op::COUNT); + aggregate(n_prop_string, Aggregation::Op::MIN); + aggregate(n_prop_string, Aggregation::Op::MAX); + EXPECT_THROW(aggregate(n_prop_string, Aggregation::Op::AVG), QueryRuntimeException); + EXPECT_THROW(aggregate(n_prop_string, Aggregation::Op::SUM), QueryRuntimeException); + + // on ints nothing fails + aggregate(n_prop_int, Aggregation::Op::COUNT); + aggregate(n_prop_int, Aggregation::Op::MIN); + aggregate(n_prop_int, Aggregation::Op::MAX); + aggregate(n_prop_int, Aggregation::Op::AVG); + aggregate(n_prop_int, Aggregation::Op::SUM); + aggregate(n_prop_int, Aggregation::Op::COLLECT_LIST); + aggregate(n_prop_int, Aggregation::Op::COLLECT_MAP); +} + +TEST_F(QueryPlanAccumulateAggregateTest, AggregateTypes) { + // testing exceptions that can get emitted by an aggregation + // does not check all combinations that can result in an exception + // (that logic is defined and tested by TypedValue) + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto p1 = dba.NameToProperty("p1"); // has only string props + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}) + ->SetProperty(p1, storage::v3::PropertyValue("string")) + .HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, 
{{property, storage::v3::PropertyValue(1)}}) + ->SetProperty(p1, storage::v3::PropertyValue("str2")) + .HasValue()); + auto p2 = dba.NameToProperty("p2"); // combines int and bool + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}) + ->SetProperty(p2, storage::v3::PropertyValue(42)) + .HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}) + ->SetProperty(p2, storage::v3::PropertyValue(true)) + .HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p1 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), p1); + auto n_p2 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), p2); + + auto aggregate = [&](Expression *expression, Aggregation::Op aggr_op) { + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {expression}, {aggr_op}, {}, {}); + auto context = MakeContext(storage, symbol_table, &dba); + CollectProduce(*produce, &context); + }; + + // everything except for COUNT and COLLECT fails on a Vertex + auto n_id = n_p1->expression_; + aggregate(n_id, Aggregation::Op::COUNT); + aggregate(n_id, Aggregation::Op::COLLECT_LIST); + aggregate(n_id, Aggregation::Op::COLLECT_MAP); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::MIN), QueryRuntimeException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::MAX), QueryRuntimeException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::AVG), QueryRuntimeException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::SUM), QueryRuntimeException); + + // on strings AVG and SUM fail + aggregate(n_p1, Aggregation::Op::COUNT); + aggregate(n_p1, Aggregation::Op::COLLECT_LIST); + aggregate(n_p1, Aggregation::Op::COLLECT_MAP); + aggregate(n_p1, Aggregation::Op::MIN); + aggregate(n_p1, Aggregation::Op::MAX); + EXPECT_THROW(aggregate(n_p1, Aggregation::Op::AVG), QueryRuntimeException); + EXPECT_THROW(aggregate(n_p1, Aggregation::Op::SUM), 
QueryRuntimeException); + + // combination of int and bool, everything except COUNT and COLLECT fails + aggregate(n_p2, Aggregation::Op::COUNT); + aggregate(n_p2, Aggregation::Op::COLLECT_LIST); + aggregate(n_p2, Aggregation::Op::COLLECT_MAP); + EXPECT_THROW(aggregate(n_p2, Aggregation::Op::MIN), QueryRuntimeException); + EXPECT_THROW(aggregate(n_p2, Aggregation::Op::MAX), QueryRuntimeException); + EXPECT_THROW(aggregate(n_p2, Aggregation::Op::AVG), QueryRuntimeException); + EXPECT_THROW(aggregate(n_p2, Aggregation::Op::SUM), QueryRuntimeException); +} + +TEST(QueryPlan, Unwind) { + storage::v3::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + + // UNWIND [ [1, true, "x"], [], ["bla"] ] AS x UNWIND x as y RETURN x, y + auto input_expr = storage.Create<PrimitiveLiteral>(std::vector<storage::v3::PropertyValue>{ + storage::v3::PropertyValue(std::vector<storage::v3::PropertyValue>{ + storage::v3::PropertyValue(1), storage::v3::PropertyValue(true), storage::v3::PropertyValue("x")}), + storage::v3::PropertyValue(std::vector<storage::v3::PropertyValue>{}), + storage::v3::PropertyValue(std::vector<storage::v3::PropertyValue>{storage::v3::PropertyValue("bla")})}); + + auto x = symbol_table.CreateSymbol("x", true); + auto unwind_0 = std::make_shared<plan::Unwind>(nullptr, input_expr, x); + auto x_expr = IDENT("x")->MapTo(x); + auto y = symbol_table.CreateSymbol("y", true); + auto unwind_1 = std::make_shared<plan::Unwind>(unwind_0, x_expr, y); + + auto x_ne = NEXPR("x", x_expr)->MapTo(symbol_table.CreateSymbol("x_ne", true)); + auto y_ne = NEXPR("y", IDENT("y")->MapTo(y))->MapTo(symbol_table.CreateSymbol("y_ne", true)); + auto produce = MakeProduce(unwind_1, x_ne, y_ne); + + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + ASSERT_EQ(4, results.size()); + const std::vector<int> expected_x_card{3, 3, 3, 1}; + auto expected_x_card_it = 
expected_x_card.begin(); + const std::vector<TypedValue> expected_y{TypedValue(1), TypedValue(true), TypedValue("x"), TypedValue("bla")}; + auto expected_y_it = expected_y.begin(); + for (const auto &row : results) { + ASSERT_EQ(2, row.size()); + ASSERT_EQ(row[0].type(), TypedValue::Type::List); + EXPECT_EQ(row[0].ValueList().size(), *expected_x_card_it); + EXPECT_EQ(row[1].type(), expected_y_it->type()); + expected_x_card_it++; + expected_y_it++; + } +} +} // namespace memgraph::query::v2::tests diff --git a/tests/unit/query_v2_query_plan_bag_semantics.cpp b/tests/unit/query_v2_query_plan_bag_semantics.cpp new file mode 100644 index 000000000..a07ced087 --- /dev/null +++ b/tests/unit/query_v2_query_plan_bag_semantics.cpp @@ -0,0 +1,311 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#include <algorithm> +#include <iterator> +#include <memory> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "query/v2/context.hpp" +#include "query/v2/exceptions.hpp" +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/plan/operator.hpp" + +#include "query_v2_query_plan_common.hpp" +#include "storage/v3/property_value.hpp" + +using namespace memgraph::query::v2; +using namespace memgraph::query::v2::plan; + +namespace memgraph::query::v2::tests { + +class QueryPlanBagSemanticsTest : public testing::Test { + protected: + void SetUp() override { + ASSERT_TRUE(db.CreateSchema(label, {storage::v3::SchemaProperty{property, common::SchemaType::INT}})); + } + + storage::v3::Storage db; + const storage::v3::LabelId label{db.NameToLabel("label")}; + const storage::v3::PropertyId property{db.NameToProperty("property")}; +}; + +TEST_F(QueryPlanBagSemanticsTest, Skip) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n1"); + auto skip = std::make_shared<plan::Skip>(n.op_, LITERAL(2)); + + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(0, PullAll(*skip, &context)); + + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(0, PullAll(*skip, &context)); + + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(0, PullAll(*skip, &context)); + + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(1, PullAll(*skip, &context)); + + for (int i = 0; i < 10; ++i) { + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i + 3)}}).HasValue()); + } + dba.AdvanceCommand(); + EXPECT_EQ(11, 
PullAll(*skip, &context)); +} + +TEST_F(QueryPlanBagSemanticsTest, Limit) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n1"); + auto skip = std::make_shared<plan::Limit>(n.op_, LITERAL(2)); + + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(0, PullAll(*skip, &context)); + + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(1, PullAll(*skip, &context)); + + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(2, PullAll(*skip, &context)); + + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(2, PullAll(*skip, &context)); + + for (int i = 0; i < 10; ++i) { + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i + 3)}}).HasValue()); + } + dba.AdvanceCommand(); + EXPECT_EQ(2, PullAll(*skip, &context)); +} + +TEST_F(QueryPlanBagSemanticsTest, CreateLimit) { + // CREATE (n), (m) + // MATCH (n) CREATE (m) LIMIT 1 + // in the end we need to have 3 vertices in the db + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n1"); + NodeCreationInfo m; + m.symbol = symbol_table.CreateSymbol("m", true); + m.labels = {label}; + std::get<std::vector<std::pair<storage::v3::PropertyId, Expression *>>>(m.properties) + .emplace_back(property, LITERAL(3)); + auto c = 
std::make_shared<CreateNode>(n.op_, m); + auto skip = std::make_shared<plan::Limit>(c, LITERAL(1)); + + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*skip, &context)); + dba.AdvanceCommand(); + EXPECT_EQ(3, CountIterable(dba.Vertices(storage::v3::View::OLD))); +} + +TEST_F(QueryPlanBagSemanticsTest, OrderBy) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + auto prop = dba.NameToProperty("prop"); + + // contains a series of tests + // each test defines the ordering a vector of values in the desired order + auto Null = storage::v3::PropertyValue(); + std::vector<std::pair<Ordering, std::vector<storage::v3::PropertyValue>>> orderable{ + {Ordering::ASC, + {storage::v3::PropertyValue(0), storage::v3::PropertyValue(0), storage::v3::PropertyValue(0.5), + storage::v3::PropertyValue(1), storage::v3::PropertyValue(2), storage::v3::PropertyValue(12.6), + storage::v3::PropertyValue(42), Null, Null}}, + {Ordering::ASC, + {storage::v3::PropertyValue(false), storage::v3::PropertyValue(false), storage::v3::PropertyValue(true), + storage::v3::PropertyValue(true), Null, Null}}, + {Ordering::ASC, + {storage::v3::PropertyValue("A"), storage::v3::PropertyValue("B"), storage::v3::PropertyValue("a"), + storage::v3::PropertyValue("a"), storage::v3::PropertyValue("aa"), storage::v3::PropertyValue("ab"), + storage::v3::PropertyValue("aba"), Null, Null}}, + {Ordering::DESC, + {Null, Null, storage::v3::PropertyValue(33), storage::v3::PropertyValue(33), storage::v3::PropertyValue(32.5), + storage::v3::PropertyValue(32), storage::v3::PropertyValue(2.2), storage::v3::PropertyValue(2.1), + storage::v3::PropertyValue(0)}}, + {Ordering::DESC, {Null, storage::v3::PropertyValue(true), storage::v3::PropertyValue(false)}}, + {Ordering::DESC, {Null, storage::v3::PropertyValue("zorro"), storage::v3::PropertyValue("borro")}}}; + + for (const auto &order_value_pair : orderable) { + 
std::vector<TypedValue> values; + values.reserve(order_value_pair.second.size()); + for (const auto &v : order_value_pair.second) values.emplace_back(v); + // empty database + for (auto vertex : dba.Vertices(storage::v3::View::OLD)) ASSERT_TRUE(dba.DetachRemoveVertex(&vertex).HasValue()); + dba.AdvanceCommand(); + ASSERT_EQ(0, CountIterable(dba.Vertices(storage::v3::View::OLD))); + + // take some effort to shuffle the values + // because we are testing that something not ordered gets ordered + // and need to take care it does not happen by accident + auto shuffled = values; + auto order_equal = [&values, &shuffled]() { + return std::equal(values.begin(), values.end(), shuffled.begin(), TypedValue::BoolEqual{}); + }; + for (int i = 0; i < 50 && order_equal(); ++i) { + std::random_shuffle(shuffled.begin(), shuffled.end()); + } + ASSERT_FALSE(order_equal()); + + // create the vertices + for (const auto &value : shuffled) { + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}) + ->SetProperty(prop, storage::v3::PropertyValue(value)) + .HasValue()); + } + dba.AdvanceCommand(); + + // order by and collect results + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + auto order_by = std::make_shared<plan::OrderBy>(n.op_, std::vector<SortItem>{{order_value_pair.first, n_p}}, + std::vector<Symbol>{n.sym_}); + auto n_p_ne = NEXPR("n.p", n_p)->MapTo(symbol_table.CreateSymbol("n.p", true)); + auto produce = MakeProduce(order_by, n_p_ne); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + ASSERT_EQ(values.size(), results.size()); + for (int j = 0; j < results.size(); ++j) EXPECT_TRUE(TypedValue::BoolEqual{}(results[j][0], values[j])); + } +} + +TEST_F(QueryPlanBagSemanticsTest, OrderByMultiple) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; 
+ + auto p1 = dba.NameToProperty("p1"); + auto p2 = dba.NameToProperty("p2"); + + // create a bunch of vertices that in two properties + // have all the variations (with repetition) of N values. + // ensure that those vertices are not created in the + // "right" sequence, but randomized + const int N = 20; + std::vector<std::pair<int, int>> prop_values; + for (int i = 0; i < N * N; ++i) prop_values.emplace_back(i % N, i / N); + std::random_shuffle(prop_values.begin(), prop_values.end()); + for (const auto &pair : prop_values) { + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(v.SetProperty(p1, storage::v3::PropertyValue(pair.first)).HasValue()); + ASSERT_TRUE(v.SetProperty(p2, storage::v3::PropertyValue(pair.second)).HasValue()); + } + dba.AdvanceCommand(); + + // order by and collect results + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p1 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), p1); + auto n_p2 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), p2); + // order the results so we get + // (p1: 0, p2: N-1) + // (p1: 0, p2: N-2) + // ... 
+ // (p1: N-1, p2:0) + auto order_by = std::make_shared<plan::OrderBy>(n.op_, + std::vector<SortItem>{ + {Ordering::ASC, n_p1}, + {Ordering::DESC, n_p2}, + }, + std::vector<Symbol>{n.sym_}); + auto n_p1_ne = NEXPR("n.p1", n_p1)->MapTo(symbol_table.CreateSymbol("n.p1", true)); + auto n_p2_ne = NEXPR("n.p2", n_p2)->MapTo(symbol_table.CreateSymbol("n.p2", true)); + auto produce = MakeProduce(order_by, n_p1_ne, n_p2_ne); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + ASSERT_EQ(N * N, results.size()); + for (int j = 0; j < N * N; ++j) { + ASSERT_EQ(results[j][0].type(), TypedValue::Type::Int); + EXPECT_EQ(results[j][0].ValueInt(), j / N); + ASSERT_EQ(results[j][1].type(), TypedValue::Type::Int); + EXPECT_EQ(results[j][1].ValueInt(), N - 1 - j % N); + } +} + +TEST_F(QueryPlanBagSemanticsTest, OrderByExceptions) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + auto prop = dba.NameToProperty("prop"); + + // a vector of pairs of typed values that should result + // in an exception when trying to order on them + std::vector<std::pair<storage::v3::PropertyValue, storage::v3::PropertyValue>> exception_pairs{ + {storage::v3::PropertyValue(42), storage::v3::PropertyValue(true)}, + {storage::v3::PropertyValue(42), storage::v3::PropertyValue("bla")}, + {storage::v3::PropertyValue(42), + storage::v3::PropertyValue(std::vector<storage::v3::PropertyValue>{storage::v3::PropertyValue(42)})}, + {storage::v3::PropertyValue(true), storage::v3::PropertyValue("bla")}, + {storage::v3::PropertyValue(true), + storage::v3::PropertyValue(std::vector<storage::v3::PropertyValue>{storage::v3::PropertyValue(true)})}, + {storage::v3::PropertyValue("bla"), + storage::v3::PropertyValue(std::vector<storage::v3::PropertyValue>{storage::v3::PropertyValue("bla")})}, + // illegal comparisons of same-type values + 
{storage::v3::PropertyValue(std::vector<storage::v3::PropertyValue>{storage::v3::PropertyValue(42)}), + storage::v3::PropertyValue(std::vector<storage::v3::PropertyValue>{storage::v3::PropertyValue(42)})}}; + + for (const auto &pair : exception_pairs) { + // empty database + for (auto vertex : dba.Vertices(storage::v3::View::OLD)) ASSERT_TRUE(dba.DetachRemoveVertex(&vertex).HasValue()); + dba.AdvanceCommand(); + ASSERT_EQ(0, CountIterable(dba.Vertices(storage::v3::View::OLD))); + + // make two vertices, and set values + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}) + ->SetProperty(prop, pair.first) + .HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}) + ->SetProperty(prop, pair.second) + .HasValue()); + dba.AdvanceCommand(); + ASSERT_EQ(2, CountIterable(dba.Vertices(storage::v3::View::OLD))); + for (const auto &va : dba.Vertices(storage::v3::View::OLD)) + ASSERT_NE(va.GetProperty(storage::v3::View::OLD, prop).GetValue().type(), storage::v3::PropertyValue::Type::Null); + + // order by and expect an exception + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + auto order_by = + std::make_shared<plan::OrderBy>(n.op_, std::vector<SortItem>{{Ordering::ASC, n_p}}, std::vector<Symbol>{}); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*order_by, &context), QueryRuntimeException); + } +} +} // namespace memgraph::query::v2::tests diff --git a/tests/unit/query_v2_query_plan_common.hpp b/tests/unit/query_v2_query_plan_common.hpp new file mode 100644 index 000000000..7e535b1d5 --- /dev/null +++ b/tests/unit/query_v2_query_plan_common.hpp @@ -0,0 +1,225 @@ +// Copyright 2022 Memgraph Ltd. 
+// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <iterator> +#include <memory> +#include <vector> + +#include "query/v2/common.hpp" +#include "query/v2/context.hpp" +#include "query/v2/db_accessor.hpp" +#include "query/v2/frontend/semantic/symbol_table.hpp" +#include "query/v2/interpret/frame.hpp" +#include "query/v2/plan/operator.hpp" +#include "storage/v3/storage.hpp" +#include "utils/logging.hpp" + +#include "query_v2_query_common.hpp" + +using namespace memgraph::query::v2; +using namespace memgraph::query::v2::plan; + +using Bound = ScanAllByLabelPropertyRange::Bound; + +ExecutionContext MakeContext(const AstStorage &storage, const SymbolTable &symbol_table, + memgraph::query::v2::DbAccessor *dba) { + ExecutionContext context{dba}; + context.symbol_table = symbol_table; + context.evaluation_context.properties = NamesToProperties(storage.properties_, dba); + context.evaluation_context.labels = NamesToLabels(storage.labels_, dba); + return context; +} + +/** Helper function that collects all the results from the given Produce. 
*/ +std::vector<std::vector<TypedValue>> CollectProduce(const Produce &produce, ExecutionContext *context) { + Frame frame(context->symbol_table.max_position()); + + // top level node in the operator tree is a produce (return) + // so stream out results + + // collect the symbols from the return clause + std::vector<Symbol> symbols; + for (auto named_expression : produce.named_expressions_) + symbols.emplace_back(context->symbol_table.at(*named_expression)); + + // stream out results + auto cursor = produce.MakeCursor(memgraph::utils::NewDeleteResource()); + std::vector<std::vector<TypedValue>> results; + while (cursor->Pull(frame, *context)) { + std::vector<TypedValue> values; + for (auto &symbol : symbols) values.emplace_back(frame[symbol]); + results.emplace_back(values); + } + + return results; +} + +int PullAll(const LogicalOperator &logical_op, ExecutionContext *context) { + Frame frame(context->symbol_table.max_position()); + auto cursor = logical_op.MakeCursor(memgraph::utils::NewDeleteResource()); + int count = 0; + while (cursor->Pull(frame, *context)) count++; + return count; +} + +template <typename... TNamedExpressions> +auto MakeProduce(std::shared_ptr<LogicalOperator> input, TNamedExpressions... named_expressions) { + return std::make_shared<Produce>(input, std::vector<NamedExpression *>{named_expressions...}); +} + +struct ScanAllTuple { + NodeAtom *node_; + std::shared_ptr<LogicalOperator> op_; + Symbol sym_; +}; + +/** + * Creates and returns a tuple of stuff for a scan-all starting + * from the node with the given name. + * + * Returns ScanAllTuple(node_atom, scan_all_logical_op, symbol). 
+ */ +ScanAllTuple MakeScanAll(AstStorage &storage, SymbolTable &symbol_table, const std::string &identifier, + std::shared_ptr<LogicalOperator> input = {nullptr}, + memgraph::storage::v3::View view = memgraph::storage::v3::View::OLD) { + auto node = NODE(identifier); + auto symbol = symbol_table.CreateSymbol(identifier, true); + node->identifier_->MapTo(symbol); + auto logical_op = std::make_shared<ScanAll>(input, symbol, view); + return ScanAllTuple{node, logical_op, symbol}; +} + +ScanAllTuple MakeScanAllNew(AstStorage &storage, SymbolTable &symbol_table, const std::string &identifier, + std::shared_ptr<LogicalOperator> input = {nullptr}, + memgraph::storage::v3::View view = memgraph::storage::v3::View::OLD) { + auto *node = NODE(identifier, "label"); + auto symbol = symbol_table.CreateSymbol(identifier, true); + node->identifier_->MapTo(symbol); + auto logical_op = std::make_shared<ScanAll>(input, symbol, view); + return ScanAllTuple{node, logical_op, symbol}; +} + +/** + * Creates and returns a tuple of stuff for a scan-all starting + * from the node with the given name and label. + * + * Returns ScanAllTuple(node_atom, scan_all_logical_op, symbol). + */ +ScanAllTuple MakeScanAllByLabel(AstStorage &storage, SymbolTable &symbol_table, const std::string &identifier, + memgraph::storage::v3::LabelId label, + std::shared_ptr<LogicalOperator> input = {nullptr}, + memgraph::storage::v3::View view = memgraph::storage::v3::View::OLD) { + auto node = NODE(identifier); + auto symbol = symbol_table.CreateSymbol(identifier, true); + node->identifier_->MapTo(symbol); + auto logical_op = std::make_shared<ScanAllByLabel>(input, symbol, label, view); + return ScanAllTuple{node, logical_op, symbol}; +} + +/** + * Creates and returns a tuple of stuff for a scan-all starting from the node + * with the given name and label whose property values are in range. + * + * Returns ScanAllTuple(node_atom, scan_all_logical_op, symbol). 
+ */ +ScanAllTuple MakeScanAllByLabelPropertyRange(AstStorage &storage, SymbolTable &symbol_table, std::string identifier, + memgraph::storage::v3::LabelId label, + memgraph::storage::v3::PropertyId property, + const std::string &property_name, std::optional<Bound> lower_bound, + std::optional<Bound> upper_bound, + std::shared_ptr<LogicalOperator> input = {nullptr}, + memgraph::storage::v3::View view = memgraph::storage::v3::View::OLD) { + auto node = NODE(identifier); + auto symbol = symbol_table.CreateSymbol(identifier, true); + node->identifier_->MapTo(symbol); + auto logical_op = std::make_shared<ScanAllByLabelPropertyRange>(input, symbol, label, property, property_name, + lower_bound, upper_bound, view); + return ScanAllTuple{node, logical_op, symbol}; +} + +/** + * Creates and returns a tuple of stuff for a scan-all starting from the node + * with the given name and label whose property value is equal to given value. + * + * Returns ScanAllTuple(node_atom, scan_all_logical_op, symbol). 
+ */ +ScanAllTuple MakeScanAllByLabelPropertyValue(AstStorage &storage, SymbolTable &symbol_table, std::string identifier, + memgraph::storage::v3::LabelId label, + memgraph::storage::v3::PropertyId property, + const std::string &property_name, Expression *value, + std::shared_ptr<LogicalOperator> input = {nullptr}, + memgraph::storage::v3::View view = memgraph::storage::v3::View::OLD) { + auto node = NODE(identifier); + auto symbol = symbol_table.CreateSymbol(identifier, true); + node->identifier_->MapTo(symbol); + auto logical_op = + std::make_shared<ScanAllByLabelPropertyValue>(input, symbol, label, property, property_name, value, view); + return ScanAllTuple{node, logical_op, symbol}; +} + +struct ExpandTuple { + EdgeAtom *edge_; + Symbol edge_sym_; + NodeAtom *node_; + Symbol node_sym_; + std::shared_ptr<LogicalOperator> op_; +}; + +ExpandTuple MakeExpand(AstStorage &storage, SymbolTable &symbol_table, std::shared_ptr<LogicalOperator> input, + Symbol input_symbol, const std::string &edge_identifier, EdgeAtom::Direction direction, + const std::vector<memgraph::storage::v3::EdgeTypeId> &edge_types, + const std::string &node_identifier, bool existing_node, memgraph::storage::v3::View view) { + auto edge = EDGE(edge_identifier, direction); + auto edge_sym = symbol_table.CreateSymbol(edge_identifier, true); + edge->identifier_->MapTo(edge_sym); + + auto node = NODE(node_identifier); + auto node_sym = symbol_table.CreateSymbol(node_identifier, true); + node->identifier_->MapTo(node_sym); + + auto op = + std::make_shared<Expand>(input, input_symbol, node_sym, edge_sym, direction, edge_types, existing_node, view); + + return ExpandTuple{edge, edge_sym, node, node_sym, op}; +} + +struct UnwindTuple { + Symbol sym_; + std::shared_ptr<LogicalOperator> op_; +}; + +UnwindTuple MakeUnwind(SymbolTable &symbol_table, const std::string &symbol_name, + std::shared_ptr<LogicalOperator> input, Expression *input_expression) { + auto sym = symbol_table.CreateSymbol(symbol_name, 
true); + auto op = std::make_shared<memgraph::query::v2::plan::Unwind>(input, input_expression, sym); + return UnwindTuple{sym, op}; +} + +template <typename TIterable> +auto CountIterable(TIterable &&iterable) { + uint64_t count = 0; + for (auto it = iterable.begin(); it != iterable.end(); ++it) { + ++count; + } + return count; +} + +inline uint64_t CountEdges(memgraph::query::v2::DbAccessor *dba, memgraph::storage::v3::View view) { + uint64_t count = 0; + for (auto vertex : dba->Vertices(view)) { + auto maybe_edges = vertex.OutEdges(view); + MG_ASSERT(maybe_edges.HasValue()); + count += CountIterable(*maybe_edges); + } + return count; +} diff --git a/tests/unit/query_v2_query_plan_create_set_remove_delete.cpp b/tests/unit/query_v2_query_plan_create_set_remove_delete.cpp new file mode 100644 index 000000000..723aec472 --- /dev/null +++ b/tests/unit/query_v2_query_plan_create_set_remove_delete.cpp @@ -0,0 +1,1097 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#include <iterator> +#include <memory> +#include <variant> +#include <vector> + +#include "common/types.hpp" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "query/v2/context.hpp" +#include "query/v2/db_accessor.hpp" +#include "query/v2/exceptions.hpp" +#include "query/v2/interpret/frame.hpp" +#include "query/v2/plan/operator.hpp" + +#include "query_v2_query_plan_common.hpp" +#include "storage/v3/id_types.hpp" +#include "storage/v3/property_value.hpp" +#include "storage/v3/schemas.hpp" +#include "storage/v3/storage.hpp" +#include "storage/v3/vertex.hpp" +#include "storage/v3/view.hpp" + +using namespace memgraph::query::v2; +using namespace memgraph::query::v2::plan; + +namespace memgraph::query::tests { + +class QueryPlanCRUDTest : public testing::Test { + protected: + void SetUp() override { + ASSERT_TRUE(db.CreateSchema(label, {storage::v3::SchemaProperty{property, common::SchemaType::INT}})); + } + + storage::v3::Storage db; + const storage::v3::LabelId label{db.NameToLabel("label")}; + const storage::v3::PropertyId property{db.NameToProperty("property")}; +}; + +TEST_F(QueryPlanCRUDTest, CreateNodeWithAttributes) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + AstStorage storage; + SymbolTable symbol_table; + + NodeCreationInfo node; + node.symbol = symbol_table.CreateSymbol("n", true); + node.labels.emplace_back(label); + std::get<std::vector<std::pair<storage::v3::PropertyId, Expression *>>>(node.properties) + .emplace_back(property, LITERAL(42)); + + auto create = std::make_shared<CreateNode>(nullptr, node); + auto context = MakeContext(storage, symbol_table, &dba); + PullAll(*create, &context); + dba.AdvanceCommand(); + + // count the number of vertices + int vertex_count = 0; + for (auto vertex : dba.Vertices(storage::v3::View::OLD)) { + vertex_count++; + auto maybe_labels = vertex.Labels(storage::v3::View::OLD); + ASSERT_TRUE(maybe_labels.HasValue()); + const auto &labels = *maybe_labels; + 
EXPECT_EQ(labels.size(), 0); + + auto maybe_properties = vertex.Properties(storage::v3::View::OLD); + ASSERT_TRUE(maybe_properties.HasValue()); + const auto &properties = *maybe_properties; + EXPECT_EQ(properties.size(), 1); + auto maybe_prop = vertex.GetProperty(storage::v3::View::OLD, property); + ASSERT_TRUE(maybe_prop.HasValue()); + auto prop_eq = TypedValue(*maybe_prop) == TypedValue(42); + ASSERT_EQ(prop_eq.type(), TypedValue::Type::Bool); + EXPECT_TRUE(prop_eq.ValueBool()); + } + EXPECT_EQ(vertex_count, 1); +} + +TEST(QueryPlan, CreateReturn) { + // test CREATE (n:Person {age: 42}) RETURN n, n.age + storage::v3::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + storage::v3::LabelId label = dba.NameToLabel("Person"); + auto property = PROPERTY_PAIR("property"); + db.CreateSchema(label, {storage::v3::SchemaProperty{property.second, common::SchemaType::INT}}); + + AstStorage storage; + SymbolTable symbol_table; + + NodeCreationInfo node; + node.symbol = symbol_table.CreateSymbol("n", true); + node.labels.emplace_back(label); + std::get<std::vector<std::pair<storage::v3::PropertyId, Expression *>>>(node.properties) + .emplace_back(property.second, LITERAL(42)); + + auto create = std::make_shared<CreateNode>(nullptr, node); + auto named_expr_n = + NEXPR("n", IDENT("n")->MapTo(node.symbol))->MapTo(symbol_table.CreateSymbol("named_expr_n", true)); + auto prop_lookup = PROPERTY_LOOKUP(IDENT("n")->MapTo(node.symbol), property); + auto named_expr_n_p = NEXPR("n", prop_lookup)->MapTo(symbol_table.CreateSymbol("named_expr_n_p", true)); + + auto produce = MakeProduce(create, named_expr_n, named_expr_n_p); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(1, results.size()); + EXPECT_EQ(2, results[0].size()); + EXPECT_EQ(TypedValue::Type::Vertex, results[0][0].type()); + auto maybe_labels = results[0][0].ValueVertex().Labels(storage::v3::View::NEW); + 
EXPECT_EQ(maybe_labels->size(), 0); + + EXPECT_EQ(TypedValue::Type::Int, results[0][1].type()); + EXPECT_EQ(42, results[0][1].ValueInt()); + + dba.AdvanceCommand(); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::v3::View::OLD))); +} + +TEST(QueryPlan, CreateExpand) { + storage::v3::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + storage::v3::LabelId label_node_1 = dba.NameToLabel("Node1"); + storage::v3::LabelId label_node_2 = dba.NameToLabel("Node2"); + auto property = PROPERTY_PAIR("property"); + storage::v3::EdgeTypeId edge_type = dba.NameToEdgeType("edge_type"); + db.CreateSchema(label_node_1, {storage::v3::SchemaProperty{property.second, common::SchemaType::INT}}); + db.CreateSchema(label_node_2, {storage::v3::SchemaProperty{property.second, common::SchemaType::INT}}); + + SymbolTable symbol_table; + AstStorage storage; + + auto test_create_path = [&](bool cycle, int expected_nodes_created, int expected_edges_created) { + int before_v = CountIterable(dba.Vertices(storage::v3::View::OLD)); + int before_e = CountEdges(&dba, storage::v3::View::OLD); + + // data for the first node + NodeCreationInfo n; + n.symbol = symbol_table.CreateSymbol("n", true); + n.labels.emplace_back(label_node_1); + std::get<std::vector<std::pair<storage::v3::PropertyId, Expression *>>>(n.properties) + .emplace_back(property.second, LITERAL(1)); + + // data for the second node + NodeCreationInfo m; + m.symbol = cycle ? 
n.symbol : symbol_table.CreateSymbol("m", true); + m.labels.emplace_back(label_node_2); + std::get<std::vector<std::pair<storage::v3::PropertyId, Expression *>>>(m.properties) + .emplace_back(property.second, LITERAL(2)); + + EdgeCreationInfo r; + r.symbol = symbol_table.CreateSymbol("r", true); + r.edge_type = edge_type; + std::get<0>(r.properties).emplace_back(property.second, LITERAL(3)); + + auto create_op = std::make_shared<CreateNode>(nullptr, n); + auto create_expand = std::make_shared<CreateExpand>(m, r, create_op, n.symbol, cycle); + auto context = MakeContext(storage, symbol_table, &dba); + PullAll(*create_expand, &context); + dba.AdvanceCommand(); + + EXPECT_EQ(CountIterable(dba.Vertices(storage::v3::View::OLD)) - before_v, expected_nodes_created); + EXPECT_EQ(CountEdges(&dba, storage::v3::View::OLD) - before_e, expected_edges_created); + }; + + test_create_path(false, 2, 1); + test_create_path(true, 1, 1); + + for (auto vertex : dba.Vertices(storage::v3::View::OLD)) { + auto maybe_labels = vertex.Labels(storage::v3::View::OLD); + MG_ASSERT(maybe_labels.HasValue()); + const auto &labels = *maybe_labels; + EXPECT_EQ(labels.size(), 0); + auto maybe_primary_label = vertex.PrimaryLabel(storage::v3::View::OLD); + ASSERT_TRUE(maybe_primary_label.HasValue()); + if (*maybe_primary_label == label_node_1) { + // node created by first op + EXPECT_EQ(vertex.GetProperty(storage::v3::View::OLD, property.second)->ValueInt(), 1); + } else if (*maybe_primary_label == label_node_2) { + // node create by expansion + EXPECT_EQ(vertex.GetProperty(storage::v3::View::OLD, property.second)->ValueInt(), 2); + } else { + // should not happen + FAIL(); + } + + for (auto vertex : dba.Vertices(storage::v3::View::OLD)) { + auto maybe_edges = vertex.OutEdges(storage::v3::View::OLD); + MG_ASSERT(maybe_edges.HasValue()); + for (auto edge : *maybe_edges) { + EXPECT_EQ(edge.EdgeType(), edge_type); + EXPECT_EQ(edge.GetProperty(storage::v3::View::OLD, property.second)->ValueInt(), 3); + } + 
} + } +} + +TEST_F(QueryPlanCRUDTest, MatchCreateNode) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}).HasValue()); + dba.AdvanceCommand(); + + SymbolTable symbol_table; + AstStorage storage; + + // first node + auto n_scan_all = MakeScanAll(storage, symbol_table, "n"); + // second node + NodeCreationInfo m; + m.symbol = symbol_table.CreateSymbol("m", true); + m.labels = {label}; + std::get<std::vector<std::pair<storage::v3::PropertyId, Expression *>>>(m.properties) + .emplace_back(property, LITERAL(1)); + + // creation op + auto create_node = std::make_shared<CreateNode>(n_scan_all.op_, m); + + EXPECT_EQ(CountIterable(dba.Vertices(storage::v3::View::OLD)), 3); + auto context = MakeContext(storage, symbol_table, &dba); + PullAll(*create_node, &context); + dba.AdvanceCommand(); + EXPECT_EQ(CountIterable(dba.Vertices(storage::v3::View::OLD)), 6); +} + +TEST_F(QueryPlanCRUDTest, MatchCreateExpand) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}).HasValue()); + dba.AdvanceCommand(); + + // storage::v3::LabelId label_node_1 = dba.NameToLabel("Node1"); + // storage::v3::LabelId label_node_2 = dba.NameToLabel("Node2"); + // storage::v3::PropertyId property = dba.NameToLabel("prop"); + storage::v3::EdgeTypeId edge_type = dba.NameToEdgeType("edge_type"); + + SymbolTable symbol_table; + AstStorage 
storage; + + auto test_create_path = [&](bool cycle, int expected_nodes_created, int expected_edges_created) { + int before_v = CountIterable(dba.Vertices(storage::v3::View::OLD)); + int before_e = CountEdges(&dba, storage::v3::View::OLD); + + // data for the first node + auto n_scan_all = MakeScanAll(storage, symbol_table, "n"); + + // data for the second node + NodeCreationInfo m; + m.symbol = cycle ? n_scan_all.sym_ : symbol_table.CreateSymbol("m", true); + m.labels = {label}; + std::get<std::vector<std::pair<storage::v3::PropertyId, Expression *>>>(m.properties) + .emplace_back(property, LITERAL(1)); + + EdgeCreationInfo r; + r.symbol = symbol_table.CreateSymbol("r", true); + r.direction = EdgeAtom::Direction::OUT; + r.edge_type = edge_type; + + auto create_expand = std::make_shared<CreateExpand>(m, r, n_scan_all.op_, n_scan_all.sym_, cycle); + auto context = MakeContext(storage, symbol_table, &dba); + PullAll(*create_expand, &context); + dba.AdvanceCommand(); + + EXPECT_EQ(CountIterable(dba.Vertices(storage::v3::View::OLD)) - before_v, expected_nodes_created); + EXPECT_EQ(CountEdges(&dba, storage::v3::View::OLD) - before_e, expected_edges_created); + }; + + test_create_path(false, 3, 3); + test_create_path(true, 0, 6); +} + +TEST_F(QueryPlanCRUDTest, Delete) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // make a fully-connected (one-direction, no cycles) with 4 nodes + std::vector<VertexAccessor> vertices; + for (int i = 0; i < 4; ++i) { + vertices.push_back(*dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i)}})); + } + auto type = dba.NameToEdgeType("type"); + for (int j = 0; j < 4; ++j) + for (int k = j + 1; k < 4; ++k) ASSERT_TRUE(dba.InsertEdge(&vertices[j], &vertices[k], type).HasValue()); + + dba.AdvanceCommand(); + EXPECT_EQ(4, CountIterable(dba.Vertices(storage::v3::View::OLD))); + EXPECT_EQ(6, CountEdges(&dba, storage::v3::View::OLD)); + + AstStorage storage; + SymbolTable symbol_table; + + // 
attempt to delete a vertex, and fail + { + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); + auto delete_op = std::make_shared<plan::Delete>(n.op_, std::vector<Expression *>{n_get}, false); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*delete_op, &context), QueryRuntimeException); + dba.AdvanceCommand(); + EXPECT_EQ(4, CountIterable(dba.Vertices(storage::v3::View::OLD))); + EXPECT_EQ(6, CountEdges(&dba, storage::v3::View::OLD)); + } + + // detach delete a single vertex + { + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); + auto delete_op = std::make_shared<plan::Delete>(n.op_, std::vector<Expression *>{n_get}, true); + Frame frame(symbol_table.max_position()); + auto context = MakeContext(storage, symbol_table, &dba); + delete_op->MakeCursor(utils::NewDeleteResource())->Pull(frame, context); + dba.AdvanceCommand(); + EXPECT_EQ(3, CountIterable(dba.Vertices(storage::v3::View::OLD))); + EXPECT_EQ(3, CountEdges(&dba, storage::v3::View::OLD)); + } + + // delete all remaining edges + { + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "m", false, + storage::v3::View::NEW); + auto r_get = storage.Create<Identifier>("r")->MapTo(r_m.edge_sym_); + auto delete_op = std::make_shared<plan::Delete>(r_m.op_, std::vector<Expression *>{r_get}, false); + auto context = MakeContext(storage, symbol_table, &dba); + PullAll(*delete_op, &context); + dba.AdvanceCommand(); + EXPECT_EQ(3, CountIterable(dba.Vertices(storage::v3::View::OLD))); + EXPECT_EQ(0, CountEdges(&dba, storage::v3::View::OLD)); + } + + // delete all remaining vertices + { + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); + auto delete_op = std::make_shared<plan::Delete>(n.op_, 
std::vector<Expression *>{n_get}, false); + auto context = MakeContext(storage, symbol_table, &dba); + PullAll(*delete_op, &context); + dba.AdvanceCommand(); + EXPECT_EQ(0, CountIterable(dba.Vertices(storage::v3::View::OLD))); + EXPECT_EQ(0, CountEdges(&dba, storage::v3::View::OLD)); + } +} + +TEST_F(QueryPlanCRUDTest, DeleteTwiceDeleteBlockingEdge) { + // test deleting the same vertex and edge multiple times + // + // also test vertex deletion succeeds if the prohibiting + // edge is deleted in the same logical op + // + // we test both with the following queries (note the + // undirected edge in MATCH): + // + // CREATE (:label{property: 1})-[:T]->(:label{property: 2}) + // MATCH (n)-[r]-(m) [DETACH] DELETE n, r, m + + auto test_delete = [this](bool detach) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, dba.NameToEdgeType("T")).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::v3::View::OLD))); + EXPECT_EQ(1, CountEdges(&dba, storage::v3::View::OLD)); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::BOTH, {}, "m", false, + storage::v3::View::OLD); + + // getter expressions for deletion + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); + auto r_get = storage.Create<Identifier>("r")->MapTo(r_m.edge_sym_); + auto m_get = storage.Create<Identifier>("m")->MapTo(r_m.node_sym_); + + auto delete_op = std::make_shared<plan::Delete>(r_m.op_, std::vector<Expression *>{n_get, r_get, m_get}, detach); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(2, PullAll(*delete_op, &context)); + dba.AdvanceCommand(); + 
EXPECT_EQ(0, CountIterable(dba.Vertices(storage::v3::View::OLD))); + EXPECT_EQ(0, CountEdges(&dba, storage::v3::View::OLD)); + }; + + test_delete(true); + test_delete(false); +} + +TEST_F(QueryPlanCRUDTest, DeleteReturn) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // create 4 unconnected vertices (this test inserts no edges) + for (int i = 0; i < 4; ++i) { + const auto property_value = storage::v3::PropertyValue(i); + auto va = *dba.InsertVertexAndValidate(label, {}, {{property, property_value}}); + EXPECT_EQ(*va.GetProperty(storage::v3::View::NEW, property), property_value); + } + + dba.AdvanceCommand(); + EXPECT_EQ(4, CountIterable(dba.Vertices(storage::v3::View::OLD))); + EXPECT_EQ(0, CountEdges(&dba, storage::v3::View::OLD)); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); + auto delete_op = std::make_shared<plan::Delete>(n.op_, std::vector<Expression *>{n_get}, true); + + auto prop_lookup = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), property); + auto n_p = storage.Create<NamedExpression>("n", prop_lookup)->MapTo(symbol_table.CreateSymbol("bla", true)); + auto produce = MakeProduce(delete_op, n_p); + + auto context = MakeContext(storage, symbol_table, &dba); + ASSERT_THROW(CollectProduce(*produce, &context), QueryRuntimeException); +} + +TEST(QueryPlan, DeleteNull) { + // test (simplified) WITH Null as x delete x + storage::v3::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + + auto once = std::make_shared<Once>(); + auto delete_op = std::make_shared<plan::Delete>(once, std::vector<Expression *>{LITERAL(TypedValue())}, false); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*delete_op, &context)); +} + +TEST_F(QueryPlanCRUDTest, DeleteAdvance) { + // test queries on empty DB: + // CREATE (n: 
label{property: 1}) + // MATCH (n) DELETE n WITH n ... + // this fails only if the deleted record `n` is actually used in subsequent + // clauses, which is compatible with Neo's behavior. + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); + auto delete_op = std::make_shared<plan::Delete>(n.op_, std::vector<Expression *>{n_get}, false); + auto advance = std::make_shared<Accumulate>(delete_op, std::vector<Symbol>{n.sym_}, true); + auto res_sym = symbol_table.CreateSymbol("res", true); + { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); + dba.AdvanceCommand(); + auto produce = MakeProduce(advance, NEXPR("res", LITERAL(42))->MapTo(res_sym)); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*produce, &context)); + } + { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}).HasValue()); + dba.AdvanceCommand(); + auto n_prop = PROPERTY_LOOKUP(n_get, dba.NameToProperty("prop")); + auto produce = MakeProduce(advance, NEXPR("res", n_prop)->MapTo(res_sym)); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*produce, &context), QueryRuntimeException); + } +} + +TEST_F(QueryPlanCRUDTest, SetProperty) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // graph with 4 vertices in connected pairs + // the origin vertex in each pair and both edges + // have a property set + + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, 
storage::v3::PropertyValue(3)}}); + auto v4 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(4)}}); + auto edge_type = dba.NameToEdgeType("edge_type"); + ASSERT_TRUE(dba.InsertEdge(&v1, &v3, edge_type).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v2, &v4, edge_type).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + // scan (n)-[r]->(m) + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "m", false, + storage::v3::View::OLD); + + // set prop1 to 42 on n and r + auto prop1 = dba.NameToProperty("prop1"); + auto literal = LITERAL(42); + + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop1); + auto set_n_p = std::make_shared<plan::SetProperty>(r_m.op_, prop1, n_p, literal); + + auto r_p = PROPERTY_LOOKUP(IDENT("r")->MapTo(r_m.edge_sym_), prop1); + auto set_r_p = std::make_shared<plan::SetProperty>(set_n_p, prop1, r_p, literal); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(2, PullAll(*set_r_p, &context)); + dba.AdvanceCommand(); + + EXPECT_EQ(CountEdges(&dba, storage::v3::View::OLD), 2); + for (auto vertex : dba.Vertices(storage::v3::View::OLD)) { + auto maybe_edges = vertex.OutEdges(storage::v3::View::OLD); + ASSERT_TRUE(maybe_edges.HasValue()); + for (auto edge : *maybe_edges) { + ASSERT_EQ(edge.GetProperty(storage::v3::View::OLD, prop1)->type(), storage::v3::PropertyValue::Type::Int); + EXPECT_EQ(edge.GetProperty(storage::v3::View::OLD, prop1)->ValueInt(), 42); + auto from = edge.From(); + auto to = edge.To(); + ASSERT_EQ(from.GetProperty(storage::v3::View::OLD, prop1)->type(), storage::v3::PropertyValue::Type::Int); + EXPECT_EQ(from.GetProperty(storage::v3::View::OLD, prop1)->ValueInt(), 42); + ASSERT_EQ(to.GetProperty(storage::v3::View::OLD, prop1)->type(), storage::v3::PropertyValue::Type::Null); + } + } +} + +TEST_F(QueryPlanCRUDTest, SetProperties) { + auto 
test_set_properties = [this](bool update) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // graph: ({a: 0})-[:R {b:1}]->({c:2}) + auto prop_a = dba.NameToProperty("a"); + auto prop_b = dba.NameToProperty("b"); + auto prop_c = dba.NameToProperty("c"); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + dba.AdvanceCommand(); + + auto e = dba.InsertEdge(&v1, &v2, dba.NameToEdgeType("R")); + ASSERT_TRUE(v1.SetPropertyAndValidate(prop_a, storage::v3::PropertyValue(0)).HasValue()); + ASSERT_TRUE(e->SetProperty(prop_b, storage::v3::PropertyValue(1)).HasValue()); + ASSERT_TRUE(v2.SetPropertyAndValidate(prop_c, storage::v3::PropertyValue(2)).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + // scan (n)-[r]->(m) + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "m", false, + storage::v3::View::OLD); + + auto op = update ? 
plan::SetProperties::Op::UPDATE : plan::SetProperties::Op::REPLACE; + + // set the properties of n from r, and then the properties of r from m + auto r_ident = IDENT("r")->MapTo(r_m.edge_sym_); + auto m_ident = IDENT("m")->MapTo(r_m.node_sym_); + auto set_r_to_n = std::make_shared<plan::SetProperties>(r_m.op_, n.sym_, r_ident, op); + auto set_m_to_r = std::make_shared<plan::SetProperties>(set_r_to_n, r_m.edge_sym_, m_ident, op); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*set_m_to_r, &context)); + dba.AdvanceCommand(); + + EXPECT_EQ(CountEdges(&dba, storage::v3::View::OLD), 1); + for (auto vertex : dba.Vertices(storage::v3::View::OLD)) { + auto maybe_edges = vertex.OutEdges(storage::v3::View::OLD); + ASSERT_TRUE(maybe_edges.HasValue()); + for (auto edge : *maybe_edges) { + auto from = edge.From(); + EXPECT_EQ(from.Properties(storage::v3::View::OLD)->size(), update ? 3 : 1); + if (update) { + ASSERT_EQ(from.GetProperty(storage::v3::View::OLD, prop_a)->type(), storage::v3::PropertyValue::Type::Int); + EXPECT_EQ(from.GetProperty(storage::v3::View::OLD, prop_a)->ValueInt(), 0); + } + ASSERT_EQ(from.GetProperty(storage::v3::View::OLD, prop_b)->type(), storage::v3::PropertyValue::Type::Int); + EXPECT_EQ(from.GetProperty(storage::v3::View::OLD, prop_b)->ValueInt(), 1); + + EXPECT_EQ(edge.Properties(storage::v3::View::OLD)->size(), update ? 
3 : 2); + if (update) { + ASSERT_EQ(edge.GetProperty(storage::v3::View::OLD, prop_b)->type(), storage::v3::PropertyValue::Type::Int); + EXPECT_EQ(edge.GetProperty(storage::v3::View::OLD, prop_b)->ValueInt(), 1); + } + ASSERT_EQ(edge.GetProperty(storage::v3::View::OLD, prop_c)->type(), storage::v3::PropertyValue::Type::Int); + EXPECT_EQ(edge.GetProperty(storage::v3::View::OLD, prop_c)->ValueInt(), 2); + + auto to = edge.To(); + EXPECT_EQ(to.Properties(storage::v3::View::OLD)->size(), 2); + ASSERT_EQ(to.GetProperty(storage::v3::View::OLD, prop_c)->type(), storage::v3::PropertyValue::Type::Int); + EXPECT_EQ(to.GetProperty(storage::v3::View::OLD, prop_c)->ValueInt(), 2); + } + } + }; + + test_set_properties(true); + test_set_properties(false); +} + +TEST_F(QueryPlanCRUDTest, SetSecondaryLabels) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + + auto label1 = dba.NameToLabel("label1"); + auto label2 = dba.NameToLabel("label2"); + auto label3 = dba.NameToLabel("label3"); + ASSERT_TRUE(v1.AddLabel(label1).HasValue()); + ASSERT_TRUE(v2.AddLabel(label1).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto label_set = std::make_shared<plan::SetLabels>(n.op_, n.sym_, std::vector<storage::v3::LabelId>{label2, label3}); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(2, PullAll(*label_set, &context)); + + for (auto vertex : dba.Vertices(storage::v3::View::OLD)) { + EXPECT_EQ(3, vertex.Labels(storage::v3::View::NEW)->size()); + EXPECT_TRUE(*vertex.HasLabel(storage::v3::View::NEW, label2)); + EXPECT_TRUE(*vertex.HasLabel(storage::v3::View::NEW, label3)); + } +} + +TEST_F(QueryPlanCRUDTest, RemoveProperty) { + auto storage_dba = db.Access(); + 
DbAccessor dba(&storage_dba); + + // graph with 4 vertices in connected pairs + // the origin vertex in each pair and both edges + // have a property set + auto prop1 = dba.NameToProperty("prop1"); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}); + auto v4 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(4)}}); + auto edge_type = dba.NameToEdgeType("edge_type"); + { + auto e = dba.InsertEdge(&v1, &v3, edge_type); + ASSERT_TRUE(e.HasValue()); + ASSERT_TRUE(e->SetProperty(prop1, storage::v3::PropertyValue(42)).HasValue()); + } + ASSERT_TRUE(dba.InsertEdge(&v2, &v4, edge_type).HasValue()); + ASSERT_TRUE(v2.SetProperty(prop1, storage::v3::PropertyValue(42)).HasValue()); + ASSERT_TRUE(v3.SetProperty(prop1, storage::v3::PropertyValue(42)).HasValue()); + ASSERT_TRUE(v4.SetProperty(prop1, storage::v3::PropertyValue(42)).HasValue()); + auto prop2 = dba.NameToProperty("prop2"); + ASSERT_TRUE(v1.SetProperty(prop2, storage::v3::PropertyValue(0)).HasValue()); + ASSERT_TRUE(v2.SetProperty(prop2, storage::v3::PropertyValue(0)).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + // scan (n)-[r]->(m) + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "m", false, + storage::v3::View::OLD); + + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop1); + auto set_n_p = std::make_shared<plan::RemoveProperty>(r_m.op_, prop1, n_p); + + auto r_p = PROPERTY_LOOKUP(IDENT("r")->MapTo(r_m.edge_sym_), prop1); + auto set_r_p = std::make_shared<plan::RemoveProperty>(set_n_p, prop1, r_p); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(2, PullAll(*set_r_p, 
&context)); + dba.AdvanceCommand(); + + EXPECT_EQ(CountEdges(&dba, storage::v3::View::OLD), 2); + for (auto vertex : dba.Vertices(storage::v3::View::OLD)) { + auto maybe_edges = vertex.OutEdges(storage::v3::View::OLD); + ASSERT_TRUE(maybe_edges.HasValue()); + for (auto edge : *maybe_edges) { + EXPECT_EQ(edge.GetProperty(storage::v3::View::OLD, prop1)->type(), storage::v3::PropertyValue::Type::Null); + auto from = edge.From(); + auto to = edge.To(); + EXPECT_EQ(from.GetProperty(storage::v3::View::OLD, prop1)->type(), storage::v3::PropertyValue::Type::Null); + EXPECT_EQ(from.GetProperty(storage::v3::View::OLD, prop2)->type(), storage::v3::PropertyValue::Type::Int); + EXPECT_EQ(to.GetProperty(storage::v3::View::OLD, prop1)->type(), storage::v3::PropertyValue::Type::Int); + } + } +} + +TEST_F(QueryPlanCRUDTest, RemoveLabels) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto label1 = dba.NameToLabel("label1"); + auto label2 = dba.NameToLabel("label2"); + auto label3 = dba.NameToLabel("label3"); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(v1.AddLabel(label1).HasValue()); + ASSERT_TRUE(v1.AddLabel(label2).HasValue()); + ASSERT_TRUE(v1.AddLabel(label3).HasValue()); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + ASSERT_TRUE(v2.AddLabel(label1).HasValue()); + ASSERT_TRUE(v2.AddLabel(label3).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto label_remove = + std::make_shared<plan::RemoveLabels>(n.op_, n.sym_, std::vector<storage::v3::LabelId>{label1, label2}); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(2, PullAll(*label_remove, &context)); + + for (auto vertex : dba.Vertices(storage::v3::View::OLD)) { + EXPECT_EQ(1, vertex.Labels(storage::v3::View::NEW)->size()); + 
EXPECT_FALSE(*vertex.HasLabel(storage::v3::View::NEW, label1)); + EXPECT_FALSE(*vertex.HasLabel(storage::v3::View::NEW, label2)); + } +} + +TEST_F(QueryPlanCRUDTest, NodeFilterSet) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // Create a graph such that (v1 {prop: 42}) is connected to v2 and v3. + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto prop = PROPERTY_PAIR("prop"); + ASSERT_TRUE(v1.SetProperty(prop.second, storage::v3::PropertyValue(42)).HasValue()); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}); + auto edge_type = dba.NameToEdgeType("Edge"); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v1, &v3, edge_type).HasValue()); + dba.AdvanceCommand(); + // Create operations which match (v1 {prop: 42}) -- (v) and increment the + // v1.prop. The expected result is two increments, since v1 is matched + // twice for 2 edges it has. 
+ AstStorage storage; + SymbolTable symbol_table; + // MATCH (n {prop: 42}) -[r]- (m) + auto scan_all = MakeScanAll(storage, symbol_table, "n"); + std::get<0>(scan_all.node_->properties_)[storage.GetPropertyIx(prop.first)] = LITERAL(42); + auto expand = MakeExpand(storage, symbol_table, scan_all.op_, scan_all.sym_, "r", EdgeAtom::Direction::BOTH, {}, "m", + false, storage::v3::View::OLD); + auto *filter_expr = + EQ(storage.Create<PropertyLookup>(scan_all.node_->identifier_, storage.GetPropertyIx(prop.first)), LITERAL(42)); + auto node_filter = std::make_shared<Filter>(expand.op_, filter_expr); + // SET n.prop = n.prop + 1 + auto set_prop = PROPERTY_LOOKUP(IDENT("n")->MapTo(scan_all.sym_), prop); + auto add = ADD(set_prop, LITERAL(1)); + auto set = std::make_shared<plan::SetProperty>(node_filter, prop.second, set_prop, add); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(2, PullAll(*set, &context)); + dba.AdvanceCommand(); + auto prop_eq = TypedValue(*v1.GetProperty(storage::v3::View::OLD, prop.second)) == TypedValue(42 + 2); + ASSERT_EQ(prop_eq.type(), TypedValue::Type::Bool); + EXPECT_TRUE(prop_eq.ValueBool()); +} + +TEST_F(QueryPlanCRUDTest, FilterRemove) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // Create a graph such that (v1 {prop: 42}) is connected to v2 and v3. 
+ auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto prop = PROPERTY_PAIR("prop"); + ASSERT_TRUE(v1.SetProperty(prop.second, storage::v3::PropertyValue(42)).HasValue()); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}); + auto edge_type = dba.NameToEdgeType("Edge"); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v1, &v3, edge_type).HasValue()); + dba.AdvanceCommand(); + // Create operations which match (v1 {prop: 42}) -- (v) and remove v1.prop. + // The expected result is two matches, for each edge of v1. + AstStorage storage; + SymbolTable symbol_table; + // MATCH (n) -[r]- (m) WHERE n.prop < 43 + auto scan_all = MakeScanAll(storage, symbol_table, "n"); + std::get<0>(scan_all.node_->properties_)[storage.GetPropertyIx(prop.first)] = LITERAL(42); + auto expand = MakeExpand(storage, symbol_table, scan_all.op_, scan_all.sym_, "r", EdgeAtom::Direction::BOTH, {}, "m", + false, storage::v3::View::OLD); + auto filter_prop = PROPERTY_LOOKUP(IDENT("n")->MapTo(scan_all.sym_), prop); + auto filter = std::make_shared<Filter>(expand.op_, LESS(filter_prop, LITERAL(43))); + // REMOVE n.prop + auto rem_prop = PROPERTY_LOOKUP(IDENT("n")->MapTo(scan_all.sym_), prop); + auto rem = std::make_shared<plan::RemoveProperty>(filter, prop.second, rem_prop); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(2, PullAll(*rem, &context)); + dba.AdvanceCommand(); + EXPECT_EQ(v1.GetProperty(storage::v3::View::OLD, prop.second)->type(), storage::v3::PropertyValue::Type::Null); +} + +TEST_F(QueryPlanCRUDTest, SetRemove) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto label1 = dba.NameToLabel("label1"); + auto 
label2 = dba.NameToLabel("label2"); + dba.AdvanceCommand(); + // Create operations which match (v) and set and remove v :label. + // The expected result is single (v) as it was at the start. + AstStorage storage; + SymbolTable symbol_table; + // MATCH (n) SET n :label1 :label2 REMOVE n :label1 :label2 + auto scan_all = MakeScanAll(storage, symbol_table, "n"); + auto set = + std::make_shared<plan::SetLabels>(scan_all.op_, scan_all.sym_, std::vector<storage::v3::LabelId>{label1, label2}); + auto rem = + std::make_shared<plan::RemoveLabels>(set, scan_all.sym_, std::vector<storage::v3::LabelId>{label1, label2}); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*rem, &context)); + dba.AdvanceCommand(); + EXPECT_FALSE(*v.HasLabel(storage::v3::View::OLD, label1)); + EXPECT_FALSE(*v.HasLabel(storage::v3::View::OLD, label2)); +} + +TEST_F(QueryPlanCRUDTest, Merge) { + // test setup: + // - three nodes, two of them connected with T + // - merge input branch matches all nodes + // - merge_match branch looks for an expansion (any direction) + // and sets some property (for result validation) + // - merge_create branch just sets some other property + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, dba.NameToEdgeType("Type")).HasValue()); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto prop = PROPERTY_PAIR("prop"); + auto n = MakeScanAll(storage, symbol_table, "n"); + + // merge_match branch + auto r_m = MakeExpand(storage, symbol_table, std::make_shared<Once>(), n.sym_, "r", EdgeAtom::Direction::BOTH, {}, + "m", false, storage::v3::View::OLD); + auto m_p = 
PROPERTY_LOOKUP(IDENT("m")->MapTo(r_m.node_sym_), prop); + auto m_set = std::make_shared<plan::SetProperty>(r_m.op_, prop.second, m_p, LITERAL(1)); + + // merge_create branch + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + auto n_set = std::make_shared<plan::SetProperty>(std::make_shared<Once>(), prop.second, n_p, LITERAL(2)); + + auto merge = std::make_shared<plan::Merge>(n.op_, m_set, n_set); + auto context = MakeContext(storage, symbol_table, &dba); + ASSERT_EQ(3, PullAll(*merge, &context)); + dba.AdvanceCommand(); + + ASSERT_EQ(v1.GetProperty(storage::v3::View::OLD, prop.second)->type(), storage::v3::PropertyValue::Type::Int); + ASSERT_EQ(v1.GetProperty(storage::v3::View::OLD, prop.second)->ValueInt(), 1); + ASSERT_EQ(v2.GetProperty(storage::v3::View::OLD, prop.second)->type(), storage::v3::PropertyValue::Type::Int); + ASSERT_EQ(v2.GetProperty(storage::v3::View::OLD, prop.second)->ValueInt(), 1); + ASSERT_EQ(v3.GetProperty(storage::v3::View::OLD, prop.second)->type(), storage::v3::PropertyValue::Type::Int); + ASSERT_EQ(v3.GetProperty(storage::v3::View::OLD, prop.second)->ValueInt(), 2); +} + +TEST_F(QueryPlanCRUDTest, MergeNoInput) { + // merge with no input, creates a single node + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + + NodeCreationInfo node; + node.symbol = symbol_table.CreateSymbol("n", true); + node.labels = {label}; + std::get<std::vector<std::pair<storage::v3::PropertyId, Expression *>>>(node.properties) + .emplace_back(property, LITERAL(1)); + auto create = std::make_shared<CreateNode>(nullptr, node); + auto merge = std::make_shared<plan::Merge>(nullptr, create, create); + + EXPECT_EQ(0, CountIterable(dba.Vertices(storage::v3::View::OLD))); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*merge, &context)); + dba.AdvanceCommand(); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::v3::View::OLD))); +} + +TEST(QueryPlan, 
SetPropertyOnNull) { + // SET (Null).prop = 42 + storage::v3::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + auto prop = PROPERTY_PAIR("property"); + auto null = LITERAL(TypedValue()); + auto literal = LITERAL(42); + auto n_prop = PROPERTY_LOOKUP(null, prop); + auto once = std::make_shared<Once>(); + auto set_op = std::make_shared<plan::SetProperty>(once, prop.second, n_prop, literal); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*set_op, &context)); +} + +TEST(QueryPlan, SetPropertiesOnNull) { + // OPTIONAL MATCH (n) SET n = n + storage::v3::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_ident = IDENT("n")->MapTo(n.sym_); + auto optional = std::make_shared<plan::Optional>(nullptr, n.op_, std::vector<Symbol>{n.sym_}); + auto set_op = std::make_shared<plan::SetProperties>(optional, n.sym_, n_ident, plan::SetProperties::Op::REPLACE); + EXPECT_EQ(0, CountIterable(dba.Vertices(storage::v3::View::OLD))); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*set_op, &context)); +} + +TEST(QueryPlan, SetLabelsOnNull) { + // OPTIONAL MATCH (n) SET n :label + storage::v3::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto label = dba.NameToLabel("label"); + AstStorage storage; + SymbolTable symbol_table; + auto n = MakeScanAll(storage, symbol_table, "n"); + auto optional = std::make_shared<plan::Optional>(nullptr, n.op_, std::vector<Symbol>{n.sym_}); + auto set_op = std::make_shared<plan::SetLabels>(optional, n.sym_, std::vector<storage::v3::LabelId>{label}); + EXPECT_EQ(0, CountIterable(dba.Vertices(storage::v3::View::OLD))); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*set_op, &context)); +} + +TEST(QueryPlan, 
RemovePropertyOnNull) { + // REMOVE (Null).prop + storage::v3::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + auto prop = PROPERTY_PAIR("property"); + auto null = LITERAL(TypedValue()); + auto n_prop = PROPERTY_LOOKUP(null, prop); + auto once = std::make_shared<Once>(); + auto remove_op = std::make_shared<plan::RemoveProperty>(once, prop.second, n_prop); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*remove_op, &context)); +} + +TEST(QueryPlan, RemoveLabelsOnNull) { + // OPTIONAL MATCH (n) REMOVE n :label + storage::v3::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto label = dba.NameToLabel("label"); + AstStorage storage; + SymbolTable symbol_table; + auto n = MakeScanAll(storage, symbol_table, "n"); + auto optional = std::make_shared<plan::Optional>(nullptr, n.op_, std::vector<Symbol>{n.sym_}); + auto remove_op = std::make_shared<plan::RemoveLabels>(optional, n.sym_, std::vector<storage::v3::LabelId>{label}); + EXPECT_EQ(0, CountIterable(dba.Vertices(storage::v3::View::OLD))); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*remove_op, &context)); +} + +TEST_F(QueryPlanCRUDTest, DeleteSetProperty) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // Add a single vertex. 
+ ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::v3::View::OLD))); + AstStorage storage; + SymbolTable symbol_table; + // MATCH (n) DELETE n SET n.prop = 42 + auto n = MakeScanAllNew(storage, symbol_table, "n"); + + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); + auto delete_op = std::make_shared<plan::Delete>(n.op_, std::vector<Expression *>{n_get}, false); + auto prop = PROPERTY_PAIR("prop"); + auto n_prop = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + auto set_op = std::make_shared<plan::SetProperty>(delete_op, prop.second, n_prop, LITERAL(42)); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*set_op, &context), QueryRuntimeException); +} + +TEST_F(QueryPlanCRUDTest, DeleteSetPropertiesFromMap) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // Add a single vertex. + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::v3::View::OLD))); + AstStorage storage; + SymbolTable symbol_table; + // MATCH (n) DELETE n SET n = {prop: 42} + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); + auto delete_op = std::make_shared<plan::Delete>(n.op_, std::vector<Expression *>{n_get}, false); + auto prop = PROPERTY_PAIR("prop"); + std::unordered_map<PropertyIx, Expression *> prop_map; + prop_map.emplace(storage.GetPropertyIx(prop.first), LITERAL(42)); + auto *rhs = storage.Create<MapLiteral>(prop_map); + for (auto op_type : {plan::SetProperties::Op::REPLACE, plan::SetProperties::Op::UPDATE}) { + auto set_op = std::make_shared<plan::SetProperties>(delete_op, n.sym_, rhs, op_type); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*set_op, 
&context), QueryRuntimeException); + } +} + +TEST_F(QueryPlanCRUDTest, DeleteSetPropertiesFrom) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // Add a single vertex. + { + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(v.SetProperty(dba.NameToProperty("prop"), storage::v3::PropertyValue(1)).HasValue()); + } + dba.AdvanceCommand(); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::v3::View::OLD))); + AstStorage storage; + SymbolTable symbol_table; + // MATCH (n) DELETE n SET n = n + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); + auto delete_op = std::make_shared<plan::Delete>(n.op_, std::vector<Expression *>{n_get}, false); + auto *rhs = IDENT("n")->MapTo(n.sym_); + for (auto op_type : {plan::SetProperties::Op::REPLACE, plan::SetProperties::Op::UPDATE}) { + auto set_op = std::make_shared<plan::SetProperties>(delete_op, n.sym_, rhs, op_type); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*set_op, &context), QueryRuntimeException); + } +} + +TEST_F(QueryPlanCRUDTest, DeleteRemoveLabels) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // Add a single vertex. 
+ ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::v3::View::OLD))); + AstStorage storage; + SymbolTable symbol_table; + // MATCH (n) DELETE n REMOVE n :label + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); + auto delete_op = std::make_shared<plan::Delete>(n.op_, std::vector<Expression *>{n_get}, false); + std::vector<storage::v3::LabelId> labels{dba.NameToLabel("label1")}; + auto rem_op = std::make_shared<plan::RemoveLabels>(delete_op, n.sym_, labels); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*rem_op, &context), QueryRuntimeException); +} + +TEST_F(QueryPlanCRUDTest, DeleteRemoveProperty) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // Add a single vertex. + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::v3::View::OLD))); + AstStorage storage; + SymbolTable symbol_table; + // MATCH (n) DELETE n REMOVE n.prop + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_get = storage.Create<Identifier>("n")->MapTo(n.sym_); + auto delete_op = std::make_shared<plan::Delete>(n.op_, std::vector<Expression *>{n_get}, false); + auto prop = PROPERTY_PAIR("prop"); + auto n_prop = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + auto rem_op = std::make_shared<plan::RemoveProperty>(delete_op, prop.second, n_prop); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*rem_op, &context), QueryRuntimeException); +} +} // namespace memgraph::query::tests diff --git a/tests/unit/query_v2_query_plan_edge_cases.cpp b/tests/unit/query_v2_query_plan_edge_cases.cpp new file mode 100644 index 000000000..cf268bc27 --- /dev/null +++ 
b/tests/unit/query_v2_query_plan_edge_cases.cpp @@ -0,0 +1,116 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +// tests in this suite deal with edge cases in logical operator behavior +// that's not easily testable with single-phase testing. instead, for +// easy testing and later readability they are tested end-to-end. + +#include <filesystem> +#include <optional> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include "query/v2/interpreter.hpp" +#include "result_stream_faker.hpp" +#include "storage/v3/storage.hpp" + +DECLARE_bool(query_cost_planner); + +namespace memgraph::query::v2::tests { + +class QueryExecution : public testing::Test { + protected: + // Storage, interpreter context and interpreter are created in SetUp() and released in TearDown(). + std::optional<storage::v3::Storage> db_; + std::optional<InterpreterContext> interpreter_context_; + std::optional<Interpreter> interpreter_; + + std::filesystem::path data_directory{std::filesystem::temp_directory_path() / + "MG_tests_unit_query_v2_query_plan_edge_cases"}; + + void SetUp() override { + db_.emplace(); + interpreter_context_.emplace(&*db_, InterpreterConfig{}, data_directory); + interpreter_.emplace(&*interpreter_context_); + } + + void TearDown() override { + interpreter_ = std::nullopt; + interpreter_context_ = std::nullopt; + db_ = std::nullopt; + } + + /** + * Execute the given query and commit the transaction. + * + * Return the query results.
+ */ + auto Execute(const std::string &query) { + ResultStreamFaker stream(&*db_); + + auto [header, _, qid] = interpreter_->Prepare(query, {}, nullptr); + stream.Header(header); + auto summary = interpreter_->PullAll(&stream); + stream.Summary(summary); + + return stream.GetResults(); + } +}; + +TEST_F(QueryExecution, MissingOptionalIntoExpand) { + Execute("CREATE SCHEMA ON :Person(id INTEGER)"); + Execute("CREATE SCHEMA ON :Dog(id INTEGER)"); + Execute("CREATE SCHEMA ON :Food(id INTEGER)"); + // validating bug where expanding from Null (due to a preceding optional + // match) exhausts the expansion cursor, even if its input is still not + // exhausted + Execute( + "CREATE (a:Person {id: 1}), (b:Person " + "{id:2})-[:Has]->(:Dog {id: 1})-[:Likes]->(:Food {id: 1})"); + ASSERT_EQ(Execute("MATCH (n) RETURN n").size(), 4); + + auto Exec = [this](bool desc, const std::string &edge_pattern) { + // this test depends on left-to-right query planning + FLAGS_query_cost_planner = false; + return Execute(std::string("MATCH (p:Person) WITH p ORDER BY p.id ") + (desc ? "DESC " : "") + + "OPTIONAL MATCH (p)-->(d:Dog) WITH p, d " + "MATCH (d)" + + edge_pattern + + "(f:Food) " + "RETURN p, d, f") + .size(); + }; + + std::string expand = "-->"; + std::string variable = "-[*1]->"; + std::string bfs = "-[*bfs..1]->"; + + // every edge pattern must yield exactly one row for both orderings of p + for (const auto &pattern : {expand, variable, bfs}) { + EXPECT_EQ(Exec(false, pattern), 1) << pattern; + EXPECT_EQ(Exec(true, pattern), 1) << pattern; + } +} + +TEST_F(QueryExecution, EdgeUniquenessInOptional) { + Execute("CREATE SCHEMA ON :label(id INTEGER)"); + // Validating that an edge uniqueness check can't fail when the edge is Null + // due to optional match. Since edge-uniqueness only happens in one OPTIONAL + // MATCH, we only need to check that scenario.
+ Execute("CREATE (:label {id: 1}), (:label {id: 2})-[:Type]->(:label {id: 3})"); + ASSERT_EQ(Execute("MATCH (n) RETURN n").size(), 3); + EXPECT_EQ(Execute("MATCH (n) OPTIONAL MATCH (n)-[r1]->(), (n)-[r2]->() " + "RETURN n, r1, r2") + .size(), + 3); +} +} // namespace memgraph::query::v2::tests diff --git a/tests/unit/query_v2_query_plan_match_filter_return.cpp b/tests/unit/query_v2_query_plan_match_filter_return.cpp new file mode 100644 index 000000000..8be276772 --- /dev/null +++ b/tests/unit/query_v2_query_plan_match_filter_return.cpp @@ -0,0 +1,2066 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#include <iterator> +#include <memory> +#include <optional> +#include <unordered_map> +#include <variant> +#include <vector> + +#include <fmt/format.h> +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include <cppitertools/enumerate.hpp> +#include <cppitertools/product.hpp> +#include <cppitertools/range.hpp> +#include <cppitertools/repeat.hpp> + +#include "query/v2/context.hpp" +#include "query/v2/exceptions.hpp" +#include "query/v2/plan/operator.hpp" +#include "query_v2_query_common.hpp" +#include "storage/v3/property_value.hpp" + +#include "query_v2_query_plan_common.hpp" + +using namespace memgraph::query::v2; +using namespace memgraph::query::v2::plan; + +namespace std { +template <> +struct hash<std::pair<int, int>> { + size_t operator()(const std::pair<int, int> &p) const { return p.first + 31 * p.second; } +}; +} // namespace std + +namespace memgraph::query::tests { + +class MatchReturnFixture : public testing::Test { + protected: + void SetUp() override { + ASSERT_TRUE(db.CreateSchema(label, {storage::v3::SchemaProperty{property, common::SchemaType::INT}})); + } + + storage::v3::Storage db; + storage::v3::Storage::Accessor storage_dba{db.Access()}; + DbAccessor dba{&storage_dba}; + const storage::v3::LabelId label{db.NameToLabel("label")}; + const storage::v3::PropertyId property{db.NameToProperty("property")}; + AstStorage storage; + SymbolTable symbol_table; + + void AddVertices(int count) { + for (int i = 0; i < count; i++) { + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i)}}).HasValue()); + } + } + + std::vector<Path> PathResults(std::shared_ptr<Produce> &op) { + std::vector<Path> res; + auto context = MakeContext(storage, symbol_table, &dba); + for (const auto &row : CollectProduce(*op, &context)) res.emplace_back(row[0].ValuePath()); + return res; + } +}; + +TEST_F(MatchReturnFixture, MatchReturn) { + AddVertices(2); + dba.AdvanceCommand(); + + auto test_pull_count = [&](storage::v3::View view) { 
+ auto scan_all = MakeScanAll(storage, symbol_table, "n", nullptr, view); + auto output = + NEXPR("n", IDENT("n")->MapTo(scan_all.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(scan_all.op_, output); + auto context = MakeContext(storage, symbol_table, &dba); + return PullAll(*produce, &context); + }; + + EXPECT_EQ(2, test_pull_count(storage::v3::View::NEW)); + EXPECT_EQ(2, test_pull_count(storage::v3::View::OLD)); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); + EXPECT_EQ(3, test_pull_count(storage::v3::View::NEW)); + EXPECT_EQ(2, test_pull_count(storage::v3::View::OLD)); + dba.AdvanceCommand(); + EXPECT_EQ(3, test_pull_count(storage::v3::View::OLD)); +} + +TEST_F(MatchReturnFixture, MatchReturnPath) { + AddVertices(2); + dba.AdvanceCommand(); + + auto scan_all = MakeScanAll(storage, symbol_table, "n", nullptr); + Symbol path_sym = symbol_table.CreateSymbol("path", true); + auto make_path = std::make_shared<ConstructNamedPath>(scan_all.op_, path_sym, std::vector<Symbol>{scan_all.sym_}); + auto output = + NEXPR("path", IDENT("path")->MapTo(path_sym))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(make_path, output); + auto results = PathResults(produce); + ASSERT_EQ(results.size(), 2); + std::vector<Path> expected_paths; + for (const auto &v : dba.Vertices(storage::v3::View::OLD)) expected_paths.emplace_back(v); + ASSERT_EQ(expected_paths.size(), 2); + EXPECT_TRUE(std::is_permutation(expected_paths.begin(), expected_paths.end(), results.begin())); +} + +class QueryPlanMatchFilterTest : public testing::Test { + protected: + QueryPlanMatchFilterTest() { + EXPECT_TRUE(db.CreateSchema(label, {storage::v3::SchemaProperty{property, common::SchemaType::INT}})); + } + + storage::v3::Storage db; + storage::v3::LabelId label = db.NameToLabel("label"); + storage::v3::PropertyId property = db.NameToProperty("property"); 
+}; + +TEST_F(QueryPlanMatchFilterTest, MatchReturnCartesian) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}) + ->AddLabel(dba.NameToLabel("l1")) + .HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}) + ->AddLabel(dba.NameToLabel("l2")) + .HasValue()); + + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto m = MakeScanAll(storage, symbol_table, "m", n.op_); + auto return_n = NEXPR("n", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto return_m = NEXPR("m", IDENT("m")->MapTo(m.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_2", true)); + auto produce = MakeProduce(m.op_, return_n, return_m); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), 4); + // ensure the result ordering is OK: + // "n" from the results is the same for the first two rows, while "m" isn't + EXPECT_EQ(results[0][0].ValueVertex(), results[1][0].ValueVertex()); + EXPECT_NE(results[0][1].ValueVertex(), results[1][1].ValueVertex()); +} + +TEST_F(QueryPlanMatchFilterTest, StandaloneReturn) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // add a few nodes to the database + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto output = NEXPR("n", LITERAL(42)); + auto produce = MakeProduce(std::shared_ptr<LogicalOperator>(nullptr), output); + output->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + + auto context = 
MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), 1); + EXPECT_EQ(results[0].size(), 1); + EXPECT_EQ(results[0][0].ValueInt(), 42); +} + +TEST_F(QueryPlanMatchFilterTest, NodeFilterLabelsAndProperties) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // add a few nodes to the database + storage::v3::LabelId label1 = dba.NameToLabel("Label1"); + auto property1 = PROPERTY_PAIR("Property1"); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}); + auto v4 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(4)}}); + auto v5 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(5)}}); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(6)}}).HasValue()); + + // test all combination of (label | no_label) * (no_prop | wrong_prop | + // right_prop) + // only v1-v3 will have the right labels + ASSERT_TRUE(v1.AddLabel(label1).HasValue()); + ASSERT_TRUE(v2.AddLabel(label1).HasValue()); + ASSERT_TRUE(v3.AddLabel(label1).HasValue()); + // v1 and v4 will have the right properties + ASSERT_TRUE(v1.SetProperty(property1.second, storage::v3::PropertyValue(42)).HasValue()); + ASSERT_TRUE(v2.SetProperty(property1.second, storage::v3::PropertyValue(1)).HasValue()); + ASSERT_TRUE(v4.SetProperty(property1.second, storage::v3::PropertyValue(42)).HasValue()); + ASSERT_TRUE(v5.SetProperty(property1.second, storage::v3::PropertyValue(1)).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + // make a scan all + auto n = MakeScanAll(storage, symbol_table, "n"); + 
n.node_->labels_.emplace_back(storage.GetLabelIx(dba.LabelToName(label1))); + std::get<0>(n.node_->properties_)[storage.GetPropertyIx(property1.first)] = LITERAL(42); + + // node filtering + auto *filter_expr = AND(storage.Create<LabelsTest>(n.node_->identifier_, n.node_->labels_), + EQ(PROPERTY_LOOKUP(n.node_->identifier_, property1), LITERAL(42))); + auto node_filter = std::make_shared<Filter>(n.op_, filter_expr); + + // make a named expression and a produce + auto output = NEXPR("x", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(node_filter, output); + + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*produce, &context)); + + // test that filtering works with old records + ASSERT_TRUE(v4.AddLabel(label1).HasValue()); + EXPECT_EQ(1, PullAll(*produce, &context)); + dba.AdvanceCommand(); + EXPECT_EQ(2, PullAll(*produce, &context)); +} + +TEST_F(QueryPlanMatchFilterTest, NodeFilterMultipleLabels) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // add a few nodes to the database + storage::v3::LabelId label1 = dba.NameToLabel("label1"); + storage::v3::LabelId label2 = dba.NameToLabel("label2"); + storage::v3::LabelId label3 = dba.NameToLabel("label3"); + // the test will look for nodes that have label1 and label2 + ASSERT_TRUE( + dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); // NOT accepted + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}) + ->AddLabel(label1) + .HasValue()); // NOT accepted + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}) + ->AddLabel(label2) + .HasValue()); // NOT accepted + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(4)}}) + ->AddLabel(label3) + .HasValue()); // NOT accepted + auto v1 = *dba.InsertVertexAndValidate(label, {}, 
{{property, storage::v3::PropertyValue(5)}}); // YES accepted + ASSERT_TRUE(v1.AddLabel(label1).HasValue()); + ASSERT_TRUE(v1.AddLabel(label2).HasValue()); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(6)}}); // NOT accepted + ASSERT_TRUE(v2.AddLabel(label1).HasValue()); + ASSERT_TRUE(v2.AddLabel(label3).HasValue()); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(7)}}); // YES accepted + ASSERT_TRUE(v3.AddLabel(label1).HasValue()); + ASSERT_TRUE(v3.AddLabel(label2).HasValue()); + ASSERT_TRUE(v3.AddLabel(label3).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + // make a scan all + auto n = MakeScanAll(storage, symbol_table, "n"); + n.node_->labels_.emplace_back(storage.GetLabelIx(dba.LabelToName(label1))); + n.node_->labels_.emplace_back(storage.GetLabelIx(dba.LabelToName(label2))); + + // node filtering + auto *filter_expr = storage.Create<LabelsTest>(n.node_->identifier_, n.node_->labels_); + auto node_filter = std::make_shared<Filter>(n.op_, filter_expr); + + // make a named expression and a produce + auto output = NEXPR("n", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(node_filter, output); + + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), 2); +} + +TEST_F(QueryPlanMatchFilterTest, Cartesian) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto add_vertex = [&dba, this](std::string label1) { + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + MG_ASSERT(vertex.AddLabel(dba.NameToLabel(label1)).HasValue()); + return vertex; + }; + + std::vector<VertexAccessor> vertices{add_vertex("v1"), add_vertex("v2"), add_vertex("v3")}; + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable 
symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto m = MakeScanAll(storage, symbol_table, "m"); + auto return_n = NEXPR("n", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto return_m = NEXPR("m", IDENT("m")->MapTo(m.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_2", true)); + + std::vector<Symbol> left_symbols{n.sym_}; + std::vector<Symbol> right_symbols{m.sym_}; + auto cartesian_op = std::make_shared<Cartesian>(n.op_, left_symbols, m.op_, right_symbols); + + auto produce = MakeProduce(cartesian_op, return_n, return_m); + + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), 9); + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 3; ++j) { + EXPECT_EQ(results[3 * i + j][0].ValueVertex(), vertices[j]); + EXPECT_EQ(results[3 * i + j][1].ValueVertex(), vertices[i]); + } + } +} + +TEST_F(QueryPlanMatchFilterTest, CartesianEmptySet) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAllNew(storage, symbol_table, "n"); + auto m = MakeScanAllNew(storage, symbol_table, "m"); + auto return_n = NEXPR("n", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto return_m = NEXPR("m", IDENT("m")->MapTo(m.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_2", true)); + + std::vector<Symbol> left_symbols{n.sym_}; + std::vector<Symbol> right_symbols{m.sym_}; + auto cartesian_op = std::make_shared<Cartesian>(n.op_, left_symbols, m.op_, right_symbols); + + auto produce = MakeProduce(cartesian_op, return_n, return_m); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), 0); +} + +TEST_F(QueryPlanMatchFilterTest, CartesianThreeWay) { + auto storage_dba = db.Access(); + DbAccessor 
dba(&storage_dba); + auto add_vertex = [&dba, this](std::string label1) { + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + MG_ASSERT(vertex.AddLabel(dba.NameToLabel(label1)).HasValue()); + return vertex; + }; + + std::vector<VertexAccessor> vertices{add_vertex("v1"), add_vertex("v2"), add_vertex("v3")}; + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAllNew(storage, symbol_table, "n"); + auto m = MakeScanAllNew(storage, symbol_table, "m"); + auto l = MakeScanAllNew(storage, symbol_table, "l"); + auto *return_n = NEXPR("n", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto *return_m = NEXPR("m", IDENT("m")->MapTo(m.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_2", true)); + auto *return_l = NEXPR("l", IDENT("l")->MapTo(l.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_3", true)); + + std::vector<Symbol> n_symbols{n.sym_}; + std::vector<Symbol> m_symbols{m.sym_}; + std::vector<Symbol> n_m_symbols{n.sym_, m.sym_}; + std::vector<Symbol> l_symbols{l.sym_}; + auto cartesian_op_1 = std::make_shared<Cartesian>(n.op_, n_symbols, m.op_, m_symbols); + + auto cartesian_op_2 = std::make_shared<Cartesian>(cartesian_op_1, n_m_symbols, l.op_, l_symbols); + + auto produce = MakeProduce(cartesian_op_2, return_n, return_m, return_l); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), 27); + int id = 0; + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 3; ++k) { + EXPECT_EQ(results[id][0].ValueVertex(), vertices[k]); + EXPECT_EQ(results[id][1].ValueVertex(), vertices[j]); + EXPECT_EQ(results[id][2].ValueVertex(), vertices[i]); + ++id; + } + } + } +} + +class ExpandFixture : public QueryPlanMatchFilterTest { + protected: + storage::v3::Storage::Accessor storage_dba{db.Access()}; + DbAccessor 
dba{&storage_dba}; + AstStorage storage; + SymbolTable symbol_table; + + // make a V-graph (v3)<-[r2]-(v1)-[r1]->(v2) + VertexAccessor v1{*dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}})}; + VertexAccessor v2{*dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}})}; + VertexAccessor v3{*dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}})}; + storage::v3::EdgeTypeId edge_type{db.NameToEdgeType("Edge")}; + EdgeAccessor r1{*dba.InsertEdge(&v1, &v2, edge_type)}; + EdgeAccessor r2{*dba.InsertEdge(&v1, &v3, edge_type)}; + + void SetUp() override { + ASSERT_TRUE(v1.AddLabel(dba.NameToLabel("l1")).HasValue()); + ASSERT_TRUE(v2.AddLabel(dba.NameToLabel("l2")).HasValue()); + ASSERT_TRUE(v3.AddLabel(dba.NameToLabel("l3")).HasValue()); + dba.AdvanceCommand(); + } +}; + +TEST_F(ExpandFixture, Expand) { + auto test_expand = [&](EdgeAtom::Direction direction, storage::v3::View view) { + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", direction, {}, "m", false, view); + + // make a named expression and a produce + auto *output = + NEXPR("m", IDENT("m")->MapTo(r_m.node_sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(r_m.op_, output); + auto context = MakeContext(storage, symbol_table, &dba); + return PullAll(*produce, &context); + }; + // test that expand works well for both old and new graph state + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v1, &v3, edge_type).HasValue()); + EXPECT_EQ(2, test_expand(EdgeAtom::Direction::OUT, storage::v3::View::OLD)); + EXPECT_EQ(2, test_expand(EdgeAtom::Direction::IN, storage::v3::View::OLD)); + EXPECT_EQ(4, test_expand(EdgeAtom::Direction::BOTH, storage::v3::View::OLD)); + EXPECT_EQ(4, test_expand(EdgeAtom::Direction::OUT, storage::v3::View::NEW)); + EXPECT_EQ(4, 
test_expand(EdgeAtom::Direction::IN, storage::v3::View::NEW)); + EXPECT_EQ(8, test_expand(EdgeAtom::Direction::BOTH, storage::v3::View::NEW)); + dba.AdvanceCommand(); + EXPECT_EQ(4, test_expand(EdgeAtom::Direction::OUT, storage::v3::View::OLD)); + EXPECT_EQ(4, test_expand(EdgeAtom::Direction::IN, storage::v3::View::OLD)); + EXPECT_EQ(8, test_expand(EdgeAtom::Direction::BOTH, storage::v3::View::OLD)); +} + +TEST_F(ExpandFixture, ExpandPath) { + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "m", false, + storage::v3::View::OLD); + Symbol path_sym = symbol_table.CreateSymbol("path", true); + auto path = std::make_shared<ConstructNamedPath>(r_m.op_, path_sym, + std::vector<Symbol>{n.sym_, r_m.edge_sym_, r_m.node_sym_}); + auto output = + NEXPR("path", IDENT("path")->MapTo(path_sym))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(path, output); + + std::vector<Path> expected_paths{Path(v1, r2, v3), Path(v1, r1, v2)}; + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + ASSERT_EQ(results.size(), 2); + std::vector<Path> results_paths; + for (const auto &result : results) results_paths.emplace_back(result[0].ValuePath()); + EXPECT_TRUE(std::is_permutation(expected_paths.begin(), expected_paths.end(), results_paths.begin())); +} + +// /** +// * A fixture that sets a graph up and provides some functions. +// * +// * The graph is a double chain: +// * (v:0)-(v:1)-(v:2) +// * X X +// * (v:0)-(v:1)-(v:2) +// * +// * Each vertex is labeled (the labels are available as a +// * member in this class). Edges have properties set that +// * indicate origin and destination vertex for debugging. 
+// */ +class QueryPlanExpandVariable : public QueryPlanMatchFilterTest { + protected: + // type returned by the GetEdgeListSizes function, used + // a lot below in test declaration + using map_int = std::unordered_map<int, int>; + + storage::v3::Storage::Accessor storage_dba{db.Access()}; + DbAccessor dba{&storage_dba}; + // labels for layers in the double chain + std::vector<storage::v3::LabelId> labels; + // for all the edges + storage::v3::EdgeTypeId edge_type = dba.NameToEdgeType("edge_type"); + + AstStorage storage; + SymbolTable symbol_table; + + // using std::nullopt + std::nullopt_t nullopt = std::nullopt; + + void SetUp() override { + // create the graph + int chain_length = 3; + std::vector<VertexAccessor> layer; + for (int from_layer_ind = -1; from_layer_ind < chain_length - 1; from_layer_ind++) { + std::vector<VertexAccessor> new_layer{ + *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}), + *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}})}; + auto layer_label = dba.NameToLabel(std::to_string(from_layer_ind + 1)); + labels.push_back(layer_label); + for (size_t v_to_ind = 0; v_to_ind < new_layer.size(); v_to_ind++) { + auto &v_to = new_layer[v_to_ind]; + ASSERT_TRUE(v_to.AddLabel(layer_label).HasValue()); + for (size_t v_from_ind = 0; v_from_ind < layer.size(); v_from_ind++) { + auto &v_from = layer[v_from_ind]; + auto edge = dba.InsertEdge(&v_from, &v_to, edge_type); + ASSERT_TRUE(edge.HasValue() && edge->SetProperty(dba.NameToProperty("p"), + storage::v3::PropertyValue(fmt::format( + "V{}{}->V{}{}", from_layer_ind, v_from_ind, from_layer_ind + 1, v_to_ind))) + .HasValue()); + } + } + layer = new_layer; + } + dba.AdvanceCommand(); + ASSERT_EQ(CountIterable(dba.Vertices(storage::v3::View::OLD)), 2 * chain_length); + ASSERT_EQ(CountEdges(&dba, storage::v3::View::OLD), 4 * (chain_length - 1)); + } + + /** + * Expands the given LogicalOperator input with a match + * (ScanAll->Filter(label)->Expand).
Can create both VariableExpand + * ops and plain Expand (depending on template param). + * When creating plain Expand the bound arguments (lower, upper) are ignored. + * + * @param is_reverse Set to true if ExpandVariable should produce the list of + * edges in reverse order. As if ExpandVariable starts from `node_to` and ends + * with `node_from`. + * @return the last created logical op. + */ + template <typename TExpansionOperator> + std::shared_ptr<LogicalOperator> AddMatch(std::shared_ptr<LogicalOperator> input_op, const std::string &node_from, + int layer, EdgeAtom::Direction direction, + const std::vector<storage::v3::EdgeTypeId> &edge_types, + std::optional<size_t> lower, std::optional<size_t> upper, Symbol edge_sym, + const std::string &node_to, storage::v3::View view, + bool is_reverse = false) { + auto n_from = MakeScanAll(storage, symbol_table, node_from, input_op); + auto filter_op = std::make_shared<Filter>( + n_from.op_, + storage.Create<LabelsTest>(n_from.node_->identifier_, + std::vector<LabelIx>{storage.GetLabelIx(dba.LabelToName(labels[layer]))})); + + auto n_to = NODE(node_to); + auto n_to_sym = symbol_table.CreateSymbol(node_to, true); + n_to->identifier_->MapTo(n_to_sym); + + if constexpr (std::is_same_v<TExpansionOperator, ExpandVariable>) { + // convert optional ints to optional expressions + auto convert = [this](std::optional<size_t> bound) { + return bound ? 
LITERAL(static_cast<int64_t>(bound.value())) : nullptr; + }; + MG_ASSERT(view == storage::v3::View::OLD, "ExpandVariable should only be planned with storage::v3::View::OLD"); + + return std::make_shared<ExpandVariable>(filter_op, n_from.sym_, n_to_sym, edge_sym, EdgeAtom::Type::DEPTH_FIRST, + direction, edge_types, is_reverse, convert(lower), convert(upper), false, + ExpansionLambda{symbol_table.CreateSymbol("inner_edge", false), + symbol_table.CreateSymbol("inner_node", false), nullptr}, + std::nullopt, std::nullopt); + } else + return std::make_shared<Expand>(filter_op, n_from.sym_, n_to_sym, edge_sym, direction, edge_types, false, view); + } + + /* Creates an edge (in the frame and symbol table). Returns the symbol. */ + auto Edge(const std::string &identifier, EdgeAtom::Direction direction) { + auto edge = EDGE(identifier, direction); + auto edge_sym = symbol_table.CreateSymbol(identifier, true); + edge->identifier_->MapTo(edge_sym); + return edge_sym; + } + + /** + * Pulls from the given input and returns the results under the given symbol. + */ + auto GetListResults(std::shared_ptr<LogicalOperator> input_op, Symbol symbol) { + Frame frame(symbol_table.max_position()); + auto cursor = input_op->MakeCursor(utils::NewDeleteResource()); + auto context = MakeContext(storage, symbol_table, &dba); + std::vector<utils::pmr::vector<TypedValue>> results; + while (cursor->Pull(frame, context)) results.emplace_back(frame[symbol].ValueList()); + return results; + } + + /** + * Pulls from the given input and returns the results under the given symbol. 
+ */ + auto GetPathResults(std::shared_ptr<LogicalOperator> input_op, Symbol symbol) { + Frame frame(symbol_table.max_position()); + auto cursor = input_op->MakeCursor(utils::NewDeleteResource()); + auto context = MakeContext(storage, symbol_table, &dba); + std::vector<Path> results; + while (cursor->Pull(frame, context)) results.emplace_back(frame[symbol].ValuePath()); + return results; + } + + /** + * Pulls from the given input and analyses the edge-list (result of variable + * length expansion) found in the results under the given symbol. + * + * @return a map {edge_list_length -> number_of_results} + */ + auto GetEdgeListSizes(std::shared_ptr<LogicalOperator> input_op, Symbol symbol) { + map_int count_per_length; + for (const auto &edge_list : GetListResults(input_op, symbol)) { + auto length = edge_list.size(); + auto found = count_per_length.find(length); + if (found == count_per_length.end()) + count_per_length[length] = 1; + else + found->second++; + } + return count_per_length; + } +}; + +TEST_F(QueryPlanExpandVariable, OneVariableExpansion) { + auto test_expand = [&](int layer, EdgeAtom::Direction direction, std::optional<size_t> lower, + std::optional<size_t> upper, bool reverse) { + auto e = Edge("r", direction); + return GetEdgeListSizes(AddMatch<ExpandVariable>(nullptr, "n", layer, direction, {}, lower, upper, e, "m", + storage::v3::View::OLD, reverse), + e); + }; + + for (int reverse = 0; reverse < 2; ++reverse) { + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::IN, 0, 0, reverse), (map_int{{0, 2}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, 0, 0, reverse), (map_int{{0, 2}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::BOTH, 0, 0, reverse), (map_int{{0, 2}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::IN, 1, 1, reverse), (map_int{})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, 1, 1, reverse), (map_int{{1, 4}})); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::IN, 1, 1, reverse), (map_int{{1, 4}})); + 
EXPECT_EQ(test_expand(1, EdgeAtom::Direction::OUT, 1, 1, reverse), (map_int{{1, 4}})); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::BOTH, 1, 1, reverse), (map_int{{1, 8}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, 2, 2, reverse), (map_int{{2, 8}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, 2, 3, reverse), (map_int{{2, 8}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, 1, 2, reverse), (map_int{{1, 4}, {2, 8}})); + + // the following tests also check edge-uniqueness (cyphermorphisim) + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::BOTH, 1, 2, reverse), (map_int{{1, 4}, {2, 12}})); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::BOTH, 4, 4, reverse), (map_int{{4, 24}})); + + // default bound values (lower default is 1, upper default is inf) + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, nullopt, 0, reverse), (map_int{})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, nullopt, 1, reverse), (map_int{{1, 4}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, nullopt, 2, reverse), (map_int{{1, 4}, {2, 8}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::BOTH, 7, nullopt, reverse), (map_int{{7, 24}, {8, 24}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::BOTH, 8, nullopt, reverse), (map_int{{8, 24}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::BOTH, 9, nullopt, reverse), (map_int{})); + } +} + +TEST_F(QueryPlanExpandVariable, EdgeUniquenessSingleAndVariableExpansion) { + auto test_expand = [&](int layer, EdgeAtom::Direction direction, std::optional<size_t> lower, + std::optional<size_t> upper, bool single_expansion_before, bool add_uniqueness_check) { + std::shared_ptr<LogicalOperator> last_op{nullptr}; + std::vector<Symbol> symbols; + + if (single_expansion_before) { + symbols.push_back(Edge("r0", direction)); + last_op = AddMatch<Expand>(last_op, "n0", layer, direction, {}, lower, upper, symbols.back(), "m0", + storage::v3::View::OLD); + } + + auto var_length_sym = Edge("r1", direction); + 
symbols.push_back(var_length_sym); + last_op = AddMatch<ExpandVariable>(last_op, "n1", layer, direction, {}, lower, upper, var_length_sym, "m1", + storage::v3::View::OLD); + + if (!single_expansion_before) { + symbols.push_back(Edge("r2", direction)); + last_op = AddMatch<Expand>(last_op, "n2", layer, direction, {}, lower, upper, symbols.back(), "m2", + storage::v3::View::OLD); + } + + if (add_uniqueness_check) { + auto last_symbol = symbols.back(); + symbols.pop_back(); + last_op = std::make_shared<EdgeUniquenessFilter>(last_op, last_symbol, symbols); + } + + return GetEdgeListSizes(last_op, var_length_sym); + }; + + // no uniqueness between variable and single expansion + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, 2, 3, true, false), (map_int{{2, 4 * 8}})); + // with uniqueness test, different ordering of (variable, single) expansion + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, 2, 3, true, true), (map_int{{2, 3 * 8}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, 2, 3, false, true), (map_int{{2, 3 * 8}})); +} + +TEST_F(QueryPlanExpandVariable, EdgeUniquenessTwoVariableExpansions) { + auto test_expand = [&](int layer, EdgeAtom::Direction direction, std::optional<size_t> lower, + std::optional<size_t> upper, bool add_uniqueness_check) { + auto e1 = Edge("r1", direction); + auto first = + AddMatch<ExpandVariable>(nullptr, "n1", layer, direction, {}, lower, upper, e1, "m1", storage::v3::View::OLD); + auto e2 = Edge("r2", direction); + auto last_op = + AddMatch<ExpandVariable>(first, "n2", layer, direction, {}, lower, upper, e2, "m2", storage::v3::View::OLD); + if (add_uniqueness_check) { + last_op = std::make_shared<EdgeUniquenessFilter>(last_op, e2, std::vector<Symbol>{e1}); + } + + return GetEdgeListSizes(last_op, e2); + }; + + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, 2, 2, false), (map_int{{2, 8 * 8}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, 2, 2, true), (map_int{{2, 5 * 8}})); +} + +TEST_F(QueryPlanExpandVariable, 
NamedPath) { + auto e = Edge("r", EdgeAtom::Direction::OUT); + auto expand = + AddMatch<ExpandVariable>(nullptr, "n", 0, EdgeAtom::Direction::OUT, {}, 2, 2, e, "m", storage::v3::View::OLD); + auto find_symbol = [this](const std::string &name) { + for (const auto &sym : symbol_table.table()) + if (sym.second.name() == name) return sym.second; + throw std::runtime_error("Symbol not found"); + }; + + auto path_symbol = symbol_table.CreateSymbol("path", true, Symbol::Type::PATH); + auto create_path = std::make_shared<ConstructNamedPath>(expand, path_symbol, + std::vector<Symbol>{find_symbol("n"), e, find_symbol("m")}); + + std::vector<Path> expected_paths; + for (const auto &v : dba.Vertices(storage::v3::View::OLD)) { + if (!*v.HasLabel(storage::v3::View::OLD, labels[0])) continue; + auto maybe_edges1 = v.OutEdges(storage::v3::View::OLD); + for (const auto &e1 : *maybe_edges1) { + auto maybe_edges2 = e1.To().OutEdges(storage::v3::View::OLD); + for (const auto &e2 : *maybe_edges2) { + expected_paths.emplace_back(v, e1, e1.To(), e2, e2.To()); + } + } + } + ASSERT_EQ(expected_paths.size(), 8); + + auto results = GetPathResults(create_path, path_symbol); + ASSERT_EQ(results.size(), 8); + EXPECT_TRUE(std::is_permutation(results.begin(), results.end(), expected_paths.begin())); +} + +TEST_F(QueryPlanExpandVariable, ExpandToSameSymbol) { + auto test_expand = [&](int layer, EdgeAtom::Direction direction, std::optional<size_t> lower, + std::optional<size_t> upper, bool reverse) { + auto e = Edge("r", direction); + + auto node = NODE("n"); + auto symbol = symbol_table.CreateSymbol("n", true); + node->identifier_->MapTo(symbol); + auto logical_op = std::make_shared<ScanAll>(nullptr, symbol, storage::v3::View::OLD); + auto n_from = ScanAllTuple{node, logical_op, symbol}; + + auto filter_op = std::make_shared<Filter>( + n_from.op_, + storage.Create<LabelsTest>(n_from.node_->identifier_, + std::vector<LabelIx>{storage.GetLabelIx(dba.LabelToName(labels[layer]))})); + + // convert 
optional ints to optional expressions + auto convert = [this](std::optional<size_t> bound) { + return bound ? LITERAL(static_cast<int64_t>(bound.value())) : nullptr; + }; + + return GetEdgeListSizes(std::make_shared<ExpandVariable>( + filter_op, symbol, symbol, e, EdgeAtom::Type::DEPTH_FIRST, direction, + std::vector<storage::v3::EdgeTypeId>{}, reverse, convert(lower), convert(upper), + /* existing = */ true, + ExpansionLambda{symbol_table.CreateSymbol("inner_edge", false), + symbol_table.CreateSymbol("inner_node", false), nullptr}, + std::nullopt, std::nullopt), + e); + }; + + // The graph is a double chain: + // chain 0: (v:0)-(v:1)-(v:2) + // X X + // chain 1: (v:0)-(v:1)-(v:2) + + // Expand from chain 0 v:0 to itself. + // + // It has a total of 3 cycles: + // 1. C0 v:0 -> C0 v:1 -> C1 v:2 -> C1 v:1 -> C0 v:0 + // 2. C0 v:0 -> C0 v:1 -> C0 v:2 -> C1 v:1 -> C0 v:0 + // 3. C0 v:0 -> C0 v:1 -> C1 v:0 -> C1 v:1 -> C0 v:0 + // + // Each cycle can be in two directions, also, we have two starting nodes: one + // in chain 0 and the other in chain 1. + for (auto reverse : {false, true}) { + // Tests with both bounds set. + for (int lower_bound = 0; lower_bound < 10; ++lower_bound) { + for (int upper_bound = lower_bound; upper_bound < 10; ++upper_bound) { + map_int expected_directed; + map_int expected_undirected; + if (lower_bound == 0) { + expected_directed.emplace(0, 2); + expected_undirected.emplace(0, 2); + } + if (lower_bound <= 4 && upper_bound >= 4) { + expected_undirected.emplace(4, 12); + } + if (lower_bound <= 8 && upper_bound >= 8) { + expected_undirected.emplace(8, 24); + } + + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::IN, lower_bound, upper_bound, reverse), expected_directed); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, lower_bound, upper_bound, reverse), expected_directed); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::BOTH, lower_bound, upper_bound, reverse), expected_undirected); + } + } + + // Test only upper bound. 
+ for (int upper_bound = 0; upper_bound < 10; ++upper_bound) { + map_int expected_directed; + map_int expected_undirected; + if (upper_bound >= 4) { + expected_undirected.emplace(4, 12); + } + if (upper_bound >= 8) { + expected_undirected.emplace(8, 24); + } + + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::IN, std::nullopt, upper_bound, reverse), expected_directed); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, std::nullopt, upper_bound, reverse), expected_directed); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::BOTH, std::nullopt, upper_bound, reverse), expected_undirected); + } + + // Test only lower bound. + for (int lower_bound = 0; lower_bound < 10; ++lower_bound) { + map_int expected_directed; + map_int expected_undirected; + if (lower_bound == 0) { + expected_directed.emplace(0, 2); + expected_undirected.emplace(0, 2); + } + if (lower_bound <= 4) { + expected_undirected.emplace(4, 12); + } + if (lower_bound <= 8) { + expected_undirected.emplace(8, 24); + } + + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::IN, lower_bound, std::nullopt, reverse), expected_directed); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, lower_bound, std::nullopt, reverse), expected_directed); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::BOTH, lower_bound, std::nullopt, reverse), expected_undirected); + } + + // Test no bounds. + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::IN, std::nullopt, std::nullopt, reverse), (map_int{})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, std::nullopt, std::nullopt, reverse), (map_int{})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::BOTH, std::nullopt, std::nullopt, reverse), + (map_int{{4, 12}, {8, 24}})); + } + + // Expand from chain 0 v:1 to itself. + // + // It has a total of 6 cycles: + // 1. C0 v:1 -> C1 v:0 -> C1 v:1 -> C1 v:2 -> C0 v:1 + // 2. C0 v:1 -> C1 v:0 -> C1 v:1 -> C0 v:2 -> C0 v:1 + // 3. C0 v:1 -> C0 v:0 -> C1 v:1 -> C1 v:2 -> C0 v:1 + // 4. C0 v:1 -> C0 v:0 -> C1 v:1 -> C0 v:2 -> C0 v:1 + // 5. 
C0 v:1 -> C1 v:0 -> C1 v:1 -> C0 v:0 -> C0 v:1 + // 6. C0 v:1 -> C1 v:2 -> C1 v:1 -> C0 v:2 -> C0 v:1 + // + // Each cycle can be in two directions, also, we have two starting nodes: one + // in chain 0 and the other in chain 1. + for (auto reverse : {false, true}) { + // Tests with both bounds set. + for (int lower_bound = 0; lower_bound < 10; ++lower_bound) { + for (int upper_bound = lower_bound; upper_bound < 10; ++upper_bound) { + map_int expected_directed; + map_int expected_undirected; + if (lower_bound == 0) { + expected_directed.emplace(0, 2); + expected_undirected.emplace(0, 2); + } + if (lower_bound <= 4 && upper_bound >= 4) { + expected_undirected.emplace(4, 24); + } + if (lower_bound <= 8 && upper_bound >= 8) { + expected_undirected.emplace(8, 48); + } + + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::IN, lower_bound, upper_bound, reverse), expected_directed); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::OUT, lower_bound, upper_bound, reverse), expected_directed); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::BOTH, lower_bound, upper_bound, reverse), expected_undirected); + } + } + + // Test only upper bound. + for (int upper_bound = 0; upper_bound < 10; ++upper_bound) { + map_int expected_directed; + map_int expected_undirected; + if (upper_bound >= 4) { + expected_undirected.emplace(4, 24); + } + if (upper_bound >= 8) { + expected_undirected.emplace(8, 48); + } + + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::IN, std::nullopt, upper_bound, reverse), expected_directed); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::OUT, std::nullopt, upper_bound, reverse), expected_directed); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::BOTH, std::nullopt, upper_bound, reverse), expected_undirected); + } + + // Test only lower bound. 
+ for (int lower_bound = 0; lower_bound < 10; ++lower_bound) { + map_int expected_directed; + map_int expected_undirected; + if (lower_bound == 0) { + expected_directed.emplace(0, 2); + expected_undirected.emplace(0, 2); + } + if (lower_bound <= 4) { + expected_undirected.emplace(4, 24); + } + if (lower_bound <= 8) { + expected_undirected.emplace(8, 48); + } + + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::IN, lower_bound, std::nullopt, reverse), expected_directed); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::OUT, lower_bound, std::nullopt, reverse), expected_directed); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::BOTH, lower_bound, std::nullopt, reverse), expected_undirected); + } + + // Test no bounds. + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::IN, std::nullopt, std::nullopt, reverse), (map_int{})); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::OUT, std::nullopt, std::nullopt, reverse), (map_int{})); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::BOTH, std::nullopt, std::nullopt, reverse), + (map_int{{4, 24}, {8, 48}})); + } +} + +/** A test fixture for weighted shortest path expansion */ +class QueryPlanExpandWeightedShortestPath : public QueryPlanMatchFilterTest { + public: + struct ResultType { + std::vector<EdgeAccessor> path; + VertexAccessor vertex; + double total_weight; + }; + + protected: + storage::v3::Storage::Accessor storage_dba{db.Access()}; + DbAccessor dba{&storage_dba}; + std::pair<std::string, storage::v3::PropertyId> prop = PROPERTY_PAIR("property1"); + storage::v3::EdgeTypeId edge_type = dba.NameToEdgeType("edge_type"); + + // make 5 vertices because we'll need to compare against them exactly + // v[0] has `prop` with the value 0 + std::vector<VertexAccessor> v; + + // make some edges too, in a map (from, to) vertex indices + std::unordered_map<std::pair<int, int>, EdgeAccessor> e; + + AstStorage storage; + SymbolTable symbol_table; + + // inner edge and vertex symbols + Symbol filter_edge = symbol_table.CreateSymbol("f_edge", 
true); + Symbol filter_node = symbol_table.CreateSymbol("f_node", true); + + Symbol weight_edge = symbol_table.CreateSymbol("w_edge", true); + Symbol weight_node = symbol_table.CreateSymbol("w_node", true); + + Symbol total_weight = symbol_table.CreateSymbol("total_weight", true); + + void SetUp() { + for (int i = 0; i < 5; i++) { + v.push_back(*dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i)}})); + ASSERT_TRUE(v.back().SetProperty(prop.second, storage::v3::PropertyValue(i)).HasValue()); + } + + auto add_edge = [&](int from, int to, double weight) { + auto edge = dba.InsertEdge(&v[from], &v[to], edge_type); + ASSERT_TRUE(edge->SetProperty(prop.second, storage::v3::PropertyValue(weight)).HasValue()); + e.emplace(std::make_pair(from, to), *edge); + }; + + add_edge(0, 1, 5); + add_edge(1, 4, 5); + add_edge(0, 2, 3); + add_edge(2, 3, 3); + add_edge(3, 4, 3); + add_edge(4, 0, 12); + + dba.AdvanceCommand(); + } + + // defines and performs a weighted shortest expansion with the given + // params returns a vector of pairs. each pair is (vector-of-edges, + // vertex) + auto ExpandWShortest(EdgeAtom::Direction direction, std::optional<int> max_depth, Expression *where, + std::optional<int> node_id = 0, ScanAllTuple *existing_node_input = nullptr) { + // scan the nodes optionally filtering on property value + auto n = MakeScanAll(storage, symbol_table, "n", existing_node_input ? existing_node_input->op_ : nullptr); + auto last_op = n.op_; + if (node_id) { + last_op = std::make_shared<Filter>(last_op, EQ(PROPERTY_LOOKUP(n.node_->identifier_, prop), LITERAL(*node_id))); + } + + auto ident_e = IDENT("e"); + ident_e->MapTo(weight_edge); + + // expand wshortest + auto node_sym = existing_node_input ? 
existing_node_input->sym_ : symbol_table.CreateSymbol("node", true); + auto edge_list_sym = symbol_table.CreateSymbol("edgelist_", true); + auto filter_lambda = last_op = std::make_shared<ExpandVariable>( + last_op, n.sym_, node_sym, edge_list_sym, EdgeAtom::Type::WEIGHTED_SHORTEST_PATH, direction, + std::vector<storage::v3::EdgeTypeId>{}, false, nullptr, max_depth ? LITERAL(max_depth.value()) : nullptr, + existing_node_input != nullptr, ExpansionLambda{filter_edge, filter_node, where}, + ExpansionLambda{weight_edge, weight_node, PROPERTY_LOOKUP(ident_e, prop)}, total_weight); + + Frame frame(symbol_table.max_position()); + auto cursor = last_op->MakeCursor(utils::NewDeleteResource()); + std::vector<ResultType> results; + auto context = MakeContext(storage, symbol_table, &dba); + while (cursor->Pull(frame, context)) { + results.push_back( + ResultType{std::vector<EdgeAccessor>(), frame[node_sym].ValueVertex(), frame[total_weight].ValueDouble()}); + for (const TypedValue &edge : frame[edge_list_sym].ValueList()) + results.back().path.emplace_back(edge.ValueEdge()); + } + + return results; + } + + template <typename TAccessor> + auto GetProp(const TAccessor &accessor) { + return accessor.GetProperty(storage::v3::View::OLD, prop.second)->ValueInt(); + } + + template <typename TAccessor> + auto GetDoubleProp(const TAccessor &accessor) { + return accessor.GetProperty(storage::v3::View::OLD, prop.second)->ValueDouble(); + } + + Expression *PropNe(Symbol symbol, int value) { + auto ident = IDENT("inner_element"); + ident->MapTo(symbol); + return NEQ(PROPERTY_LOOKUP(ident, prop), LITERAL(value)); + } +}; + +// // Testing weighted shortest path on this graph: +// // +// // 5 5 +// // /-->--[1]-->--\ +// // / \ +// // / 12 \ 2 +// // [0]--------<--------[4]------->-------[5] +// // \ / (on some tests only) +// // \ / +// // \->[2]->-[3]->/ +// // 3 3 3 + +TEST_F(QueryPlanExpandWeightedShortestPath, Basic) { + auto results = ExpandWShortest(EdgeAtom::Direction::BOTH, 1000, 
LITERAL(true)); + + ASSERT_EQ(results.size(), 4); + + // check end nodes + EXPECT_EQ(GetProp(results[0].vertex), 2); + EXPECT_EQ(GetProp(results[1].vertex), 1); + EXPECT_EQ(GetProp(results[2].vertex), 3); + EXPECT_EQ(GetProp(results[3].vertex), 4); + + // check paths and total weights + EXPECT_EQ(results[0].path.size(), 1); + EXPECT_EQ(GetDoubleProp(results[0].path[0]), 3); + EXPECT_EQ(results[0].total_weight, 3); + + EXPECT_EQ(results[1].path.size(), 1); + EXPECT_EQ(GetDoubleProp(results[1].path[0]), 5); + EXPECT_EQ(results[1].total_weight, 5); + + EXPECT_EQ(results[2].path.size(), 2); + EXPECT_EQ(GetDoubleProp(results[2].path[0]), 3); + EXPECT_EQ(GetDoubleProp(results[2].path[1]), 3); + EXPECT_EQ(results[2].total_weight, 6); + + EXPECT_EQ(results[3].path.size(), 3); + EXPECT_EQ(GetDoubleProp(results[3].path[0]), 3); + EXPECT_EQ(GetDoubleProp(results[3].path[1]), 3); + EXPECT_EQ(GetDoubleProp(results[3].path[2]), 3); + EXPECT_EQ(results[3].total_weight, 9); +} + +TEST_F(QueryPlanExpandWeightedShortestPath, EdgeDirection) { + { + auto results = ExpandWShortest(EdgeAtom::Direction::OUT, 1000, LITERAL(true)); + ASSERT_EQ(results.size(), 4); + EXPECT_EQ(GetProp(results[0].vertex), 2); + EXPECT_EQ(results[0].total_weight, 3); + EXPECT_EQ(GetProp(results[1].vertex), 1); + EXPECT_EQ(results[1].total_weight, 5); + EXPECT_EQ(GetProp(results[2].vertex), 3); + EXPECT_EQ(results[2].total_weight, 6); + EXPECT_EQ(GetProp(results[3].vertex), 4); + EXPECT_EQ(results[3].total_weight, 9); + } + { + auto results = ExpandWShortest(EdgeAtom::Direction::IN, 1000, LITERAL(true)); + ASSERT_EQ(results.size(), 4); + EXPECT_EQ(GetProp(results[0].vertex), 4); + EXPECT_EQ(results[0].total_weight, 12); + EXPECT_EQ(GetProp(results[1].vertex), 3); + EXPECT_EQ(results[1].total_weight, 15); + EXPECT_EQ(GetProp(results[2].vertex), 1); + EXPECT_EQ(results[2].total_weight, 17); + EXPECT_EQ(GetProp(results[3].vertex), 2); + EXPECT_EQ(results[3].total_weight, 18); + } +} + 
+// Checks that the filter lambda (`where`) passed to ExpandWShortest prunes
+// the expansion. The fixture's SetUp builds vertices whose `prop` is 0..4 and
+// the edges 0->1 (5), 1->4 (5), 0->2 (3), 2->3 (3), 3->4 (3), 4->0 (12);
+// ExpandWShortest starts from the vertex with prop == 0 by default.
+TEST_F(QueryPlanExpandWeightedShortestPath, Where) {
+  {
+    // Vertex 2 is filtered out, so 4 is reached via 0-1-4 (5 + 5) and 3 only
+    // through 4 (10 + 3). Results are asserted in ascending total weight.
+    auto results = ExpandWShortest(EdgeAtom::Direction::BOTH, 1000, PropNe(filter_node, 2));
+    ASSERT_EQ(results.size(), 3);
+    EXPECT_EQ(GetProp(results[0].vertex), 1);
+    EXPECT_EQ(results[0].total_weight, 5);
+    EXPECT_EQ(GetProp(results[1].vertex), 4);
+    EXPECT_EQ(results[1].total_weight, 10);
+    EXPECT_EQ(GetProp(results[2].vertex), 3);
+    EXPECT_EQ(results[2].total_weight, 13);
+  }
+  {
+    // Vertex 1 is filtered out, leaving the 0-2-3-4 chain (3, 6, 9).
+    auto results = ExpandWShortest(EdgeAtom::Direction::BOTH, 1000, PropNe(filter_node, 1));
+    ASSERT_EQ(results.size(), 3);
+    EXPECT_EQ(GetProp(results[0].vertex), 2);
+    EXPECT_EQ(results[0].total_weight, 3);
+    EXPECT_EQ(GetProp(results[1].vertex), 3);
+    EXPECT_EQ(results[1].total_weight, 6);
+    EXPECT_EQ(GetProp(results[2].vertex), 4);
+    EXPECT_EQ(results[2].total_weight, 9);
+  }
+}
+
+// Checks weighted shortest path expansion into an already bound node: a
+// preceding ScanAll ("n0") is handed to ExpandWShortest as
+// `existing_node_input`, and the start-node filter is disabled (std::nullopt).
+TEST_F(QueryPlanExpandWeightedShortestPath, ExistingNode) {
+  // Builds the preceding scan, optionally restricted to the vertex whose
+  // `prop` equals `preceeding_node_id`, and runs an OUT expansion towards it.
+  auto ExpandPreceeding = [this](std::optional<int> preceeding_node_id) {
+    // scan the nodes optionally filtering on property value
+    auto n0 = MakeScanAll(storage, symbol_table, "n0");
+    if (preceeding_node_id) {
+      auto filter = std::make_shared<Filter>(
+          n0.op_, EQ(PROPERTY_LOOKUP(n0.node_->identifier_, prop), LITERAL(*preceeding_node_id)));
+      // inject the filter op into the ScanAllTuple. that way the filter
+      // op can be passed into the ExpandWShortest function without too
+      // much refactor
+      n0.op_ = filter;
+    }
+
+    return ExpandWShortest(EdgeAtom::Direction::OUT, 1000, LITERAL(true), std::nullopt, &n0);
+  };
+
+  // Unrestricted: presumably 5 bound vertices x 4 other vertices reaching each
+  // of them along the directed cycle -- TODO confirm.
+  EXPECT_EQ(ExpandPreceeding(std::nullopt).size(), 20);
+  {
+    // Restricted to vertex 3: every returned end vertex must be vertex 3.
+    auto results = ExpandPreceeding(3);
+    ASSERT_EQ(results.size(), 4);
+    for (int i = 0; i < 4; i++) EXPECT_EQ(GetProp(results[i].vertex), 3);
+  }
+}
+
+// Checks how the optional upper bound (`max_depth`) changes which paths are
+// found and therefore which total weights are reported.
+TEST_F(QueryPlanExpandWeightedShortestPath, UpperBound) {
+  {
+    // No upper bound: cheapest paths overall (4 is reached via 0-2-3-4 = 9).
+    auto results = ExpandWShortest(EdgeAtom::Direction::BOTH, std::nullopt, LITERAL(true));
+    ASSERT_EQ(results.size(), 4);
+    EXPECT_EQ(GetProp(results[0].vertex), 2);
+    EXPECT_EQ(results[0].total_weight, 3);
+    EXPECT_EQ(GetProp(results[1].vertex), 1);
+    EXPECT_EQ(results[1].total_weight, 5);
+    EXPECT_EQ(GetProp(results[2].vertex), 3);
+    EXPECT_EQ(results[2].total_weight, 6);
+    EXPECT_EQ(GetProp(results[3].vertex), 4);
+    EXPECT_EQ(results[3].total_weight, 9);
+  }
+  {
+    // At most 2 edges: the 3-edge path to 4 is out of reach, so 4 costs
+    // 5 + 5 = 10 via vertex 1.
+    auto results = ExpandWShortest(EdgeAtom::Direction::BOTH, 2, LITERAL(true));
+    ASSERT_EQ(results.size(), 4);
+    EXPECT_EQ(GetProp(results[0].vertex), 2);
+    EXPECT_EQ(results[0].total_weight, 3);
+    EXPECT_EQ(GetProp(results[1].vertex), 1);
+    EXPECT_EQ(results[1].total_weight, 5);
+    EXPECT_EQ(GetProp(results[2].vertex), 3);
+    EXPECT_EQ(results[2].total_weight, 6);
+    EXPECT_EQ(GetProp(results[3].vertex), 4);
+    EXPECT_EQ(results[3].total_weight, 10);
+  }
+  {
+    // At most 1 edge: only direct neighbours; 4 is reached over the 4->0 edge
+    // (weight 12) because the direction is BOTH.
+    auto results = ExpandWShortest(EdgeAtom::Direction::BOTH, 1, LITERAL(true));
+    ASSERT_EQ(results.size(), 3);
+    EXPECT_EQ(GetProp(results[0].vertex), 2);
+    EXPECT_EQ(results[0].total_weight, 3);
+    EXPECT_EQ(GetProp(results[1].vertex), 1);
+    EXPECT_EQ(results[1].total_weight, 5);
+    EXPECT_EQ(GetProp(results[2].vertex), 4);
+    EXPECT_EQ(results[2].total_weight, 12);
+  }
+  {
+    // Add a sixth vertex (prop == 5) hanging off v[4] with edge weight 2.
+    // NOTE(review): the vertex is created with `property` == 1, the same value
+    // SetUp gave v[1] -- confirm that duplicate is intended/allowed.
+    auto new_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}});
+    ASSERT_TRUE(new_vertex.SetProperty(prop.second, storage::v3::PropertyValue(5)).HasValue());
+    auto edge = dba.InsertEdge(&v[4], &new_vertex, edge_type);
+    ASSERT_TRUE(edge.HasValue());
+    ASSERT_TRUE(edge->SetProperty(prop.second, storage::v3::PropertyValue(2)).HasValue());
+    dba.AdvanceCommand();
+
+    // At most 3 edges: 4 is still reached for 9, but 5 needs the 3-edge path
+    // 0-1-4-5 (5 + 5 + 2 = 12) since 0-2-3-4-5 would take 4 edges.
+    auto results = ExpandWShortest(EdgeAtom::Direction::BOTH, 3, LITERAL(true));
+
+    ASSERT_EQ(results.size(), 5);
+    EXPECT_EQ(GetProp(results[0].vertex), 2);
+    EXPECT_EQ(results[0].total_weight, 3);
+    EXPECT_EQ(GetProp(results[1].vertex), 1);
+    EXPECT_EQ(results[1].total_weight, 5);
+    EXPECT_EQ(GetProp(results[2].vertex), 3);
+    EXPECT_EQ(results[2].total_weight, 6);
+    EXPECT_EQ(GetProp(results[3].vertex), 4);
+    EXPECT_EQ(results[3].total_weight, 9);
+    EXPECT_EQ(GetProp(results[4].vertex), 5);
+    EXPECT_EQ(results[4].total_weight, 12);
+  }
+}
+
+// An edge whose weight property is a string must make the expansion throw
+// QueryRuntimeException.
+TEST_F(QueryPlanExpandWeightedShortestPath, NonNumericWeight) {
+  auto new_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}});
+  ASSERT_TRUE(new_vertex.SetProperty(prop.second, storage::v3::PropertyValue(5)).HasValue());
+  auto edge = dba.InsertEdge(&v[4], &new_vertex, edge_type);
+  ASSERT_TRUE(edge.HasValue());
+  ASSERT_TRUE(edge->SetProperty(prop.second, storage::v3::PropertyValue("not a number")).HasValue());
+  // Advance the command before running the expansion, as SetUp does after building the graph.
+  dba.AdvanceCommand();
+  EXPECT_THROW(ExpandWShortest(EdgeAtom::Direction::BOTH, 1000, LITERAL(true)), QueryRuntimeException);
+}
+
+// An edge with a negative weight must make the expansion throw
+// QueryRuntimeException.
+TEST_F(QueryPlanExpandWeightedShortestPath, NegativeWeight) {
+  auto new_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}});
+  ASSERT_TRUE(new_vertex.SetProperty(prop.second, storage::v3::PropertyValue(5)).HasValue());
+  auto edge = dba.InsertEdge(&v[4], &new_vertex, edge_type);
+  ASSERT_TRUE(edge.HasValue());
+  ASSERT_TRUE(edge->SetProperty(prop.second, storage::v3::PropertyValue(-10)).HasValue());  // negative weight
+  dba.AdvanceCommand();
+  EXPECT_THROW(ExpandWShortest(EdgeAtom::Direction::BOTH, 1000, LITERAL(true)), QueryRuntimeException);
+}
+
+TEST_F(QueryPlanExpandWeightedShortestPath, NegativeUpperBound) { + EXPECT_THROW(ExpandWShortest(EdgeAtom::Direction::BOTH, -1, LITERAL(true)), QueryRuntimeException); +} + +TEST_F(QueryPlanMatchFilterTest, ExpandOptional) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + AstStorage storage; + SymbolTable symbol_table; + + // graph (v2 {p: 2})<-[:T]-(v1 {p: 1})-[:T]->(v3 {p: 2}) + auto prop = dba.NameToProperty("p"); + auto edge_type = dba.NameToEdgeType("T"); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(v1.SetProperty(prop, storage::v3::PropertyValue(1)).HasValue()); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + ASSERT_TRUE(v2.SetProperty(prop, storage::v3::PropertyValue(2)).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}); + ASSERT_TRUE(v3.SetProperty(prop, storage::v3::PropertyValue(2)).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v1, &v3, edge_type).HasValue()); + dba.AdvanceCommand(); + + // MATCH (n) OPTIONAL MATCH (n)-[r]->(m) + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, nullptr, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "m", false, + storage::v3::View::OLD); + auto optional = std::make_shared<plan::Optional>(n.op_, r_m.op_, std::vector<Symbol>{r_m.edge_sym_, r_m.node_sym_}); + + // RETURN n, r, m + auto n_ne = NEXPR("n", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("n", true)); + auto r_ne = NEXPR("r", IDENT("r")->MapTo(r_m.edge_sym_))->MapTo(symbol_table.CreateSymbol("r", true)); + auto m_ne = NEXPR("m", IDENT("m")->MapTo(r_m.node_sym_))->MapTo(symbol_table.CreateSymbol("m", true)); + auto produce = MakeProduce(optional, n_ne, r_ne, m_ne); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = 
CollectProduce(*produce, &context); + ASSERT_EQ(4, results.size()); + int v1_is_n_count = 0; + for (auto &row : results) { + ASSERT_EQ(row[0].type(), TypedValue::Type::Vertex); + auto &va = row[0].ValueVertex(); + auto va_p = *va.GetProperty(storage::v3::View::OLD, prop); + ASSERT_EQ(va_p.type(), storage::v3::PropertyValue::Type::Int); + if (va_p.ValueInt() == 1) { + v1_is_n_count++; + EXPECT_EQ(row[1].type(), TypedValue::Type::Edge); + EXPECT_EQ(row[2].type(), TypedValue::Type::Vertex); + } else { + EXPECT_EQ(row[1].type(), TypedValue::Type::Null); + EXPECT_EQ(row[2].type(), TypedValue::Type::Null); + } + } + EXPECT_EQ(2, v1_is_n_count); +} + +TEST(QueryPlan, OptionalMatchEmptyDB) { + storage::v3::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + AstStorage storage; + SymbolTable symbol_table; + + // OPTIONAL MATCH (n) + auto n = MakeScanAllNew(storage, symbol_table, "n"); + // RETURN n + auto n_ne = NEXPR("n", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("n", true)); + auto optional = std::make_shared<plan::Optional>(nullptr, n.op_, std::vector<Symbol>{n.sym_}); + auto produce = MakeProduce(optional, n_ne); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + ASSERT_EQ(1, results.size()); + EXPECT_EQ(results[0][0].type(), TypedValue::Type::Null); +} + +TEST(QueryPlan, OptionalMatchEmptyDBExpandFromNode) { + storage::v3::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + // OPTIONAL MATCH (n) + auto n = MakeScanAll(storage, symbol_table, "n"); + auto optional = std::make_shared<plan::Optional>(nullptr, n.op_, std::vector<Symbol>{n.sym_}); + // WITH n + auto n_ne = NEXPR("n", IDENT("n")->MapTo(n.sym_)); + auto with_n_sym = symbol_table.CreateSymbol("n", true); + n_ne->MapTo(with_n_sym); + auto with = MakeProduce(optional, n_ne); + // MATCH (n) -[r]-> (m) + auto r_m = 
MakeExpand(storage, symbol_table, with, with_n_sym, "r", EdgeAtom::Direction::OUT, {}, "m", false, + storage::v3::View::OLD); + // RETURN m + auto m_ne = NEXPR("m", IDENT("m")->MapTo(r_m.node_sym_))->MapTo(symbol_table.CreateSymbol("m", true)); + auto produce = MakeProduce(r_m.op_, m_ne); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(0, results.size()); +} + +TEST_F(QueryPlanMatchFilterTest, OptionalMatchThenExpandToMissingNode) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + // Make a graph with 2 connected, unlabeled nodes. + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + auto edge_type = dba.NameToEdgeType("edge_type"); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::v3::View::OLD))); + EXPECT_EQ(1, CountEdges(&dba, storage::v3::View::OLD)); + AstStorage storage; + SymbolTable symbol_table; + // OPTIONAL MATCH (n :missing) + auto n = MakeScanAll(storage, symbol_table, "n"); + auto label_missing = "missing"; + n.node_->labels_.emplace_back(storage.GetLabelIx(label_missing)); + + auto *filter_expr = storage.Create<LabelsTest>(n.node_->identifier_, n.node_->labels_); + auto node_filter = std::make_shared<Filter>(n.op_, filter_expr); + auto optional = std::make_shared<plan::Optional>(nullptr, node_filter, std::vector<Symbol>{n.sym_}); + // WITH n + auto n_ne = NEXPR("n", IDENT("n")->MapTo(n.sym_)); + auto with_n_sym = symbol_table.CreateSymbol("n", true); + n_ne->MapTo(with_n_sym); + auto with = MakeProduce(optional, n_ne); + // MATCH (m) -[r]-> (n) + auto m = MakeScanAll(storage, symbol_table, "m", with); + auto edge_direction = EdgeAtom::Direction::OUT; + auto edge = EDGE("r", edge_direction); + auto edge_sym = 
symbol_table.CreateSymbol("r", true); + edge->identifier_->MapTo(edge_sym); + auto node = NODE("n"); + node->identifier_->MapTo(with_n_sym); + auto expand = std::make_shared<plan::Expand>(m.op_, m.sym_, with_n_sym, edge_sym, edge_direction, + std::vector<storage::v3::EdgeTypeId>{}, true, storage::v3::View::OLD); + // RETURN m + auto m_ne = NEXPR("m", IDENT("m")->MapTo(m.sym_))->MapTo(symbol_table.CreateSymbol("m", true)); + auto produce = MakeProduce(expand, m_ne); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(0, results.size()); +} + +TEST_F(QueryPlanMatchFilterTest, ExpandExistingNode) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // make a graph (v1)->(v2) that + // has a recursive edge (v1)->(v1) + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + auto edge_type = dba.NameToEdgeType("Edge"); + ASSERT_TRUE(dba.InsertEdge(&v1, &v1, edge_type).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto test_existing = [&](bool with_existing, int expected_result_count) { + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_n = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "n", with_existing, + storage::v3::View::OLD); + if (with_existing) + r_n.op_ = std::make_shared<Expand>(n.op_, n.sym_, n.sym_, r_n.edge_sym_, r_n.edge_->direction_, + std::vector<storage::v3::EdgeTypeId>{}, with_existing, storage::v3::View::OLD); + + // make a named expression and a produce + auto output = NEXPR("n", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(r_n.op_, output); + auto context = MakeContext(storage, symbol_table, 
&dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), expected_result_count); + }; + + test_existing(true, 1); + test_existing(false, 2); +} + +TEST_F(QueryPlanMatchFilterTest, ExpandBothCycleEdgeCase) { + // we're testing that expanding on BOTH + // does only one expansion for a cycle + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(dba.InsertEdge(&v, &v, dba.NameToEdgeType("et")).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_ = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::BOTH, {}, "_", false, + storage::v3::View::OLD); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*r_.op_, &context)); +} + +TEST_F(QueryPlanMatchFilterTest, EdgeFilter) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // make an N-star expanding from (v1) + // where only one edge will qualify + // and there are all combinations of + // (edge_type yes|no) * (property yes|absent|no) + std::vector<storage::v3::EdgeTypeId> edge_types; + for (int j = 0; j < 2; ++j) edge_types.push_back(dba.NameToEdgeType("et" + std::to_string(j))); + std::vector<VertexAccessor> vertices; + for (int i = 0; i < 7; ++i) { + vertices.push_back(*dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i)}})); + } + auto prop = PROPERTY_PAIR("property1"); + std::vector<EdgeAccessor> edges; + for (int i = 0; i < 6; ++i) { + edges.push_back(*dba.InsertEdge(&vertices[0], &vertices[i + 1], edge_types[i % 2])); + switch (i % 3) { + case 0: + ASSERT_TRUE(edges.back().SetProperty(prop.second, storage::v3::PropertyValue(42)).HasValue()); + break; + case 1: + ASSERT_TRUE(edges.back().SetProperty(prop.second, storage::v3::PropertyValue(100)).HasValue()); + 
break; + default: + break; + } + } + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto test_filter = [&]() { + // define an operator tree for query + // MATCH (n)-[r :et0 {property: 42}]->(m) RETURN m + + auto n = MakeScanAll(storage, symbol_table, "n"); + const auto &edge_type = edge_types[0]; + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {edge_type}, "m", false, + storage::v3::View::OLD); + r_m.edge_->edge_types_.push_back(storage.GetEdgeTypeIx(dba.EdgeTypeToName(edge_type))); + std::get<0>(r_m.edge_->properties_)[storage.GetPropertyIx(prop.first)] = LITERAL(42); + auto *filter_expr = EQ(PROPERTY_LOOKUP(r_m.edge_->identifier_, prop), LITERAL(42)); + auto edge_filter = std::make_shared<Filter>(r_m.op_, filter_expr); + + // make a named expression and a produce + auto output = + NEXPR("m", IDENT("m")->MapTo(r_m.node_sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(edge_filter, output); + auto context = MakeContext(storage, symbol_table, &dba); + return PullAll(*produce, &context); + }; + + EXPECT_EQ(1, test_filter()); + // test that edge filtering always filters on old state + for (auto &edge : edges) ASSERT_TRUE(edge.SetProperty(prop.second, storage::v3::PropertyValue(42)).HasValue()); + EXPECT_EQ(1, test_filter()); + dba.AdvanceCommand(); + EXPECT_EQ(3, test_filter()); +} + +TEST_F(QueryPlanMatchFilterTest, EdgeFilterMultipleTypes) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + auto type_1 = dba.NameToEdgeType("type_1"); + auto type_2 = dba.NameToEdgeType("type_2"); + auto type_3 = dba.NameToEdgeType("type_3"); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, type_1).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, 
type_2).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, type_3).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + // make a scan all + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {type_1, type_2}, "m", + false, storage::v3::View::OLD); + + // make a named expression and a produce + auto output = + NEXPR("m", IDENT("m")->MapTo(r_m.node_sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(r_m.op_, output); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), 2); +} + +TEST_F(QueryPlanMatchFilterTest, Filter) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // add a 6 nodes with property 'prop', 2 have true as value + auto property1 = PROPERTY_PAIR("property1"); + for (int i = 0; i < 6; ++i) { + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i)}}) + ->SetProperty(property1.second, storage::v3::PropertyValue(i % 3 == 0)) + .HasValue()); + } + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}) + .HasValue()); // prop not set, gives NULL + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto e = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), property1); + auto f = std::make_shared<Filter>(n.op_, e); + + auto output = NEXPR("x", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(f, output); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(CollectProduce(*produce, &context).size(), 2); +} + +TEST_F(QueryPlanMatchFilterTest, EdgeUniquenessFilter) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // make a 
graph that has (v1)->(v2) and a recursive edge (v1)->(v1) + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + auto edge_type = dba.NameToEdgeType("edge_type"); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v1, &v1, edge_type).HasValue()); + dba.AdvanceCommand(); + + auto check_expand_results = [&](bool edge_uniqueness) { + AstStorage storage; + SymbolTable symbol_table; + + auto n1 = MakeScanAll(storage, symbol_table, "n1"); + auto r1_n2 = MakeExpand(storage, symbol_table, n1.op_, n1.sym_, "r1", EdgeAtom::Direction::OUT, {}, "n2", false, + storage::v3::View::OLD); + std::shared_ptr<LogicalOperator> last_op = r1_n2.op_; + auto r2_n3 = MakeExpand(storage, symbol_table, last_op, r1_n2.node_sym_, "r2", EdgeAtom::Direction::OUT, {}, "n3", + false, storage::v3::View::OLD); + last_op = r2_n3.op_; + if (edge_uniqueness) + last_op = std::make_shared<EdgeUniquenessFilter>(last_op, r2_n3.edge_sym_, std::vector<Symbol>{r1_n2.edge_sym_}); + auto context = MakeContext(storage, symbol_table, &dba); + return PullAll(*last_op, &context); + }; + + EXPECT_EQ(2, check_expand_results(false)); + EXPECT_EQ(1, check_expand_results(true)); +} + +TEST(QueryPlan, Distinct) { + // test queries like + // UNWIND [1, 2, 3, 3] AS x RETURN DISTINCT x + + storage::v3::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + + auto check_distinct = [&](const std::vector<TypedValue> input, const std::vector<TypedValue> output, + bool assume_int_value) { + auto input_expr = LITERAL(TypedValue(input)); + + auto x = symbol_table.CreateSymbol("x", true); + auto unwind = std::make_shared<plan::Unwind>(nullptr, input_expr, x); + auto x_expr = IDENT("x"); + x_expr->MapTo(x); + + auto distinct = std::make_shared<plan::Distinct>(unwind, 
std::vector<Symbol>{x}); + + auto x_ne = NEXPR("x", x_expr); + x_ne->MapTo(symbol_table.CreateSymbol("x_ne", true)); + auto produce = MakeProduce(distinct, x_ne); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + ASSERT_EQ(output.size(), results.size()); + auto output_it = output.begin(); + for (const auto &row : results) { + ASSERT_EQ(1, row.size()); + ASSERT_EQ(row[0].type(), output_it->type()); + if (assume_int_value) EXPECT_EQ(output_it->ValueInt(), row[0].ValueInt()); + output_it++; + } + }; + + check_distinct({TypedValue(1), TypedValue(1), TypedValue(2), TypedValue(3), TypedValue(3), TypedValue(3)}, + {TypedValue(1), TypedValue(2), TypedValue(3)}, true); + check_distinct({TypedValue(3), TypedValue(2), TypedValue(3), TypedValue(5), TypedValue(3), TypedValue(5), + TypedValue(2), TypedValue(1), TypedValue(2)}, + {TypedValue(3), TypedValue(2), TypedValue(5), TypedValue(1)}, true); + check_distinct( + {TypedValue(3), TypedValue("two"), TypedValue(), TypedValue(3), TypedValue(true), TypedValue(false), + TypedValue("TWO"), TypedValue()}, + {TypedValue(3), TypedValue("two"), TypedValue(), TypedValue(true), TypedValue(false), TypedValue("TWO")}, false); +} + +TEST_F(QueryPlanMatchFilterTest, ScanAllByLabel) { + auto label1 = db.NameToLabel("label"); + db.CreateIndex(label1); + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + // Add a vertex with a label and one without. 
+ auto labeled_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(labeled_vertex.AddLabel(label1).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); + + dba.AdvanceCommand(); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::v3::View::OLD))); + // MATCH (n :label) + AstStorage storage; + SymbolTable symbol_table; + auto scan_all_by_label = MakeScanAllByLabel(storage, symbol_table, "n", label1); + // RETURN n + auto output = NEXPR("n", IDENT("n")->MapTo(scan_all_by_label.sym_))->MapTo(symbol_table.CreateSymbol("n", true)); + auto produce = MakeProduce(scan_all_by_label.op_, output); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + ASSERT_EQ(results.size(), 1); + auto result_row = results[0]; + ASSERT_EQ(result_row.size(), 1); + EXPECT_EQ(result_row[0].ValueVertex(), labeled_vertex); +} + +TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelProperty) { + // Add 5 vertices with same label, but with different property values. 
+ auto label1 = db.NameToLabel("label1"); + auto prop = db.NameToProperty("prop"); + // vertex property values that will be stored into the DB + std::vector<storage::v3::PropertyValue> values{ + storage::v3::PropertyValue(true), + storage::v3::PropertyValue(false), + storage::v3::PropertyValue("a"), + storage::v3::PropertyValue("b"), + storage::v3::PropertyValue("c"), + storage::v3::PropertyValue(0), + storage::v3::PropertyValue(1), + storage::v3::PropertyValue(2), + storage::v3::PropertyValue(0.5), + storage::v3::PropertyValue(1.5), + storage::v3::PropertyValue(2.5), + storage::v3::PropertyValue(std::vector<storage::v3::PropertyValue>{storage::v3::PropertyValue(0)}), + storage::v3::PropertyValue(std::vector<storage::v3::PropertyValue>{storage::v3::PropertyValue(1)}), + storage::v3::PropertyValue(std::vector<storage::v3::PropertyValue>{storage::v3::PropertyValue(2)})}; + { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + for (const auto &value : values) { + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(vertex.AddLabel(label1).HasValue()); + ASSERT_TRUE(vertex.SetProperty(prop, value).HasValue()); + } + ASSERT_FALSE(dba.Commit().HasError()); + } + db.CreateIndex(label1, prop); + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + ASSERT_EQ(14, CountIterable(dba.Vertices(storage::v3::View::OLD))); + + auto run_scan_all = [&](const TypedValue &lower, Bound::Type lower_type, const TypedValue &upper, + Bound::Type upper_type) { + AstStorage storage; + SymbolTable symbol_table; + auto scan_all = + MakeScanAllByLabelPropertyRange(storage, symbol_table, "n", label1, prop, "prop", + Bound{LITERAL(lower), lower_type}, Bound{LITERAL(upper), upper_type}); + // RETURN n + auto output = NEXPR("n", IDENT("n")->MapTo(scan_all.sym_))->MapTo(symbol_table.CreateSymbol("n", true)); + auto produce = MakeProduce(scan_all.op_, output); + auto context = MakeContext(storage, symbol_table, 
&dba); + return CollectProduce(*produce, &context); + }; + + auto check = [&](TypedValue lower, Bound::Type lower_type, TypedValue upper, Bound::Type upper_type, + const std::vector<TypedValue> &expected) { + auto results = run_scan_all(lower, lower_type, upper, upper_type); + ASSERT_EQ(results.size(), expected.size()); + for (size_t i = 0; i < expected.size(); i++) { + TypedValue equal = + TypedValue(*results[i][0].ValueVertex().GetProperty(storage::v3::View::OLD, prop)) == expected[i]; + ASSERT_EQ(equal.type(), TypedValue::Type::Bool); + EXPECT_TRUE(equal.ValueBool()); + } + }; + + // normal ranges that return something + check(TypedValue("a"), Bound::Type::EXCLUSIVE, TypedValue("c"), Bound::Type::EXCLUSIVE, {TypedValue("b")}); + check(TypedValue(0), Bound::Type::EXCLUSIVE, TypedValue(2), Bound::Type::INCLUSIVE, + {TypedValue(0.5), TypedValue(1), TypedValue(1.5), TypedValue(2)}); + check(TypedValue(1.5), Bound::Type::EXCLUSIVE, TypedValue(2.5), Bound::Type::INCLUSIVE, + {TypedValue(2), TypedValue(2.5)}); + + auto are_comparable = [](storage::v3::PropertyValue::Type a, storage::v3::PropertyValue::Type b) { + auto is_numeric = [](const storage::v3::PropertyValue::Type t) { + return t == storage::v3::PropertyValue::Type::Int || t == storage::v3::PropertyValue::Type::Double; + }; + + return a == b || (is_numeric(a) && is_numeric(b)); + }; + + auto is_orderable = [](const storage::v3::PropertyValue &t) { + return t.IsNull() || t.IsInt() || t.IsDouble() || t.IsString(); + }; + + // when a range contains different types, nothing should get returned + for (const auto &value_a : values) { + for (const auto &value_b : values) { + if (are_comparable(static_cast<storage::v3::PropertyValue>(value_a).type(), + static_cast<storage::v3::PropertyValue>(value_b).type())) + continue; + if (is_orderable(value_a) && is_orderable(value_b)) { + check(TypedValue(value_a), Bound::Type::INCLUSIVE, TypedValue(value_b), Bound::Type::INCLUSIVE, {}); + } else { + EXPECT_THROW( + 
run_scan_all(TypedValue(value_a), Bound::Type::INCLUSIVE, TypedValue(value_b), Bound::Type::INCLUSIVE), + QueryRuntimeException); + } + } + } + // These should all raise an exception due to type mismatch when using + // `operator<`. + EXPECT_THROW(run_scan_all(TypedValue(false), Bound::Type::INCLUSIVE, TypedValue(true), Bound::Type::EXCLUSIVE), + QueryRuntimeException); + EXPECT_THROW(run_scan_all(TypedValue(false), Bound::Type::EXCLUSIVE, TypedValue(true), Bound::Type::INCLUSIVE), + QueryRuntimeException); + EXPECT_THROW(run_scan_all(TypedValue(std::vector<TypedValue>{TypedValue(0.5)}), Bound::Type::EXCLUSIVE, + TypedValue(std::vector<TypedValue>{TypedValue(1.5)}), Bound::Type::INCLUSIVE), + QueryRuntimeException); +} + +TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyEqualityNoError) { + // Add 2 vertices with same label, but with property values that cannot be + // compared. On the other hand, equality works fine. + auto label1 = db.NameToLabel("label1"); + auto prop = db.NameToProperty("prop"); + { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto number_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(number_vertex.AddLabel(label1).HasValue()); + ASSERT_TRUE(number_vertex.SetProperty(prop, storage::v3::PropertyValue(42)).HasValue()); + auto string_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(string_vertex.AddLabel(label1).HasValue()); + ASSERT_TRUE(string_vertex.SetProperty(prop, storage::v3::PropertyValue("string")).HasValue()); + ASSERT_FALSE(dba.Commit().HasError()); + } + db.CreateIndex(label1, prop); + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::v3::View::OLD))); + // MATCH (n :label {prop: 42}) + AstStorage storage; + SymbolTable symbol_table; + auto scan_all = MakeScanAllByLabelPropertyValue(storage, symbol_table, "n", label1, 
prop, "prop", LITERAL(42)); + // RETURN n + auto output = NEXPR("n", IDENT("n")->MapTo(scan_all.sym_))->MapTo(symbol_table.CreateSymbol("n", true)); + auto produce = MakeProduce(scan_all.op_, output); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + ASSERT_EQ(results.size(), 1); + const auto &row = results[0]; + ASSERT_EQ(row.size(), 1); + auto vertex = row[0].ValueVertex(); + TypedValue value(*vertex.GetProperty(storage::v3::View::OLD, prop)); + TypedValue::BoolEqual eq; + EXPECT_TRUE(eq(value, TypedValue(42))); +} + +TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyValueError) { + auto label1 = db.NameToLabel("label1"); + auto prop = db.NameToProperty("prop"); + { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + for (int i = 0; i < 2; ++i) { + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(vertex.AddLabel(label1).HasValue()); + ASSERT_TRUE(vertex.SetProperty(prop, storage::v3::PropertyValue(i)).HasValue()); + } + ASSERT_FALSE(dba.Commit().HasError()); + } + db.CreateIndex(label1, prop); + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::v3::View::OLD))); + // MATCH (m), (n :label1 {prop: m}) + AstStorage storage; + SymbolTable symbol_table; + auto scan_all = MakeScanAll(storage, symbol_table, "m"); + auto *ident_m = IDENT("m"); + ident_m->MapTo(scan_all.sym_); + auto scan_index = + MakeScanAllByLabelPropertyValue(storage, symbol_table, "n", label1, prop, "prop", ident_m, scan_all.op_); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*scan_index.op_, &context), QueryRuntimeException); +} + +TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyRangeError) { + auto label1 = db.NameToLabel("label1"); + auto prop = db.NameToProperty("prop"); + { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + 
for (int i = 0; i < 2; ++i) { + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(vertex.AddLabel(label1).HasValue()); + ASSERT_TRUE(vertex.SetProperty(prop, storage::v3::PropertyValue(i)).HasValue()); + } + ASSERT_FALSE(dba.Commit().HasError()); + } + db.CreateIndex(label1, prop); + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::v3::View::OLD))); + // MATCH (m), (n :label1 {prop: m}) + AstStorage storage; + SymbolTable symbol_table; + auto scan_all = MakeScanAll(storage, symbol_table, "m"); + auto *ident_m = IDENT("m"); + ident_m->MapTo(scan_all.sym_); + { + // Lower bound isn't property value + auto scan_index = + MakeScanAllByLabelPropertyRange(storage, symbol_table, "n", label1, prop, "prop", + Bound{ident_m, Bound::Type::INCLUSIVE}, std::nullopt, scan_all.op_); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*scan_index.op_, &context), QueryRuntimeException); + } + { + // Upper bound isn't property value + auto scan_index = MakeScanAllByLabelPropertyRange(storage, symbol_table, "n", label1, prop, "prop", std::nullopt, + Bound{ident_m, Bound::Type::INCLUSIVE}, scan_all.op_); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*scan_index.op_, &context), QueryRuntimeException); + } + { + // Both bounds aren't property value + auto scan_index = MakeScanAllByLabelPropertyRange(storage, symbol_table, "n", label1, prop, "prop", + Bound{ident_m, Bound::Type::INCLUSIVE}, + Bound{ident_m, Bound::Type::INCLUSIVE}, scan_all.op_); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*scan_index.op_, &context), QueryRuntimeException); + } +} + +TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyEqualNull) { + // Add 2 vertices with the same label, but one has a property value while + // the other does not. 
Checking if the value is equal to null, should + // yield no results. + auto label1 = db.NameToLabel("label1"); + auto prop = db.NameToProperty("prop"); + { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(vertex.AddLabel(label1).HasValue()); + auto vertex_with_prop = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(vertex_with_prop.AddLabel(label1).HasValue()); + ASSERT_TRUE(vertex_with_prop.SetProperty(prop, storage::v3::PropertyValue(42)).HasValue()); + ASSERT_FALSE(dba.Commit().HasError()); + } + db.CreateIndex(label1, prop); + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::v3::View::OLD))); + // MATCH (n :label1 {prop: 42}) + AstStorage storage; + SymbolTable symbol_table; + auto scan_all = + MakeScanAllByLabelPropertyValue(storage, symbol_table, "n", label1, prop, "prop", LITERAL(TypedValue())); + // RETURN n + auto output = NEXPR("n", IDENT("n")->MapTo(scan_all.sym_))->MapTo(symbol_table.CreateSymbol("n", true)); + auto produce = MakeProduce(scan_all.op_, output); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), 0); +} + +TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyRangeNull) { + // Add 2 vertices with the same label, but one has a property value while + // the other does not. Checking if the value is between nulls, should + // yield no results. 
+ auto label1 = db.NameToLabel("label1"); + auto prop = db.NameToProperty("prop"); + { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(vertex.AddLabel(label).HasValue()); + auto vertex_with_prop = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + ASSERT_TRUE(vertex_with_prop.AddLabel(label).HasValue()); + ASSERT_TRUE(vertex_with_prop.SetProperty(prop, storage::v3::PropertyValue(42)).HasValue()); + ASSERT_FALSE(dba.Commit().HasError()); + } + db.CreateIndex(label1, prop); + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::v3::View::OLD))); + // MATCH (n :label1) WHERE null <= n.prop < null + AstStorage storage; + SymbolTable symbol_table; + auto scan_all = MakeScanAllByLabelPropertyRange(storage, symbol_table, "n", label1, prop, "prop", + Bound{LITERAL(TypedValue()), Bound::Type::INCLUSIVE}, + Bound{LITERAL(TypedValue()), Bound::Type::EXCLUSIVE}); + // RETURN n + auto output = NEXPR("n", IDENT("n")->MapTo(scan_all.sym_))->MapTo(symbol_table.CreateSymbol("n", true)); + auto produce = MakeProduce(scan_all.op_, output); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), 0); +} + +TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyNoValueInIndexContinuation) { + auto label1 = db.NameToLabel("label1"); + auto prop = db.NameToProperty("prop"); + { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(v.AddLabel(label1).HasValue()); + ASSERT_TRUE(v.SetProperty(prop, storage::v3::PropertyValue(2)).HasValue()); + ASSERT_FALSE(dba.Commit().HasError()); + } + db.CreateIndex(label1, prop); + + auto storage_dba = db.Access(); + 
DbAccessor dba(&storage_dba); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::v3::View::OLD))); + + AstStorage storage; + SymbolTable symbol_table; + + // UNWIND [1, 2, 3] as x + auto input_expr = LIST(LITERAL(1), LITERAL(2), LITERAL(3)); + auto x = symbol_table.CreateSymbol("x", true); + auto unwind = std::make_shared<plan::Unwind>(nullptr, input_expr, x); + auto x_expr = IDENT("x"); + x_expr->MapTo(x); + + // MATCH (n :label1 {prop: x}) + auto scan_all = MakeScanAllByLabelPropertyValue(storage, symbol_table, "n", label1, prop, "prop", x_expr, unwind); + + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(PullAll(*scan_all.op_, &context), 1); +} + +TEST_F(QueryPlanMatchFilterTest, ScanAllEqualsScanAllByLabelProperty) { + auto label1 = db.NameToLabel("label1"); + auto prop = db.NameToProperty("prop"); + + // Insert vertices + const int vertex_count = 300, vertex_prop_count = 50; + const int prop_value1 = 42, prop_value2 = 69; + + for (int i = 0; i < vertex_count; ++i) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(v.AddLabel(label1).HasValue()); + ASSERT_TRUE( + v.SetProperty(prop, storage::v3::PropertyValue(i < vertex_prop_count ? 
prop_value1 : prop_value2)).HasValue()); + ASSERT_FALSE(dba.Commit().HasError()); + } + + db.CreateIndex(label1, prop); + + // Make sure there are `vertex_count` vertices + { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + EXPECT_EQ(vertex_count, CountIterable(dba.Vertices(storage::v3::View::OLD))); + } + + // Make sure there are `vertex_prop_count` results when using index + auto count_with_index = [this, &label1, &prop](int prop_value, int prop_count) { + AstStorage storage; + SymbolTable symbol_table; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto scan_all_by_label_property_value = + MakeScanAllByLabelPropertyValue(storage, symbol_table, "n", label1, prop, "prop", LITERAL(prop_value)); + auto output = NEXPR("n", IDENT("n")->MapTo(scan_all_by_label_property_value.sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(scan_all_by_label_property_value.op_, output); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(PullAll(*produce, &context), prop_count); + }; + + // Make sure there are `vertex_count` results when using scan all + auto count_with_scan_all = [this, &prop](int prop_value, int prop_count) { + AstStorage storage; + SymbolTable symbol_table; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto scan_all = MakeScanAll(storage, symbol_table, "n"); + auto e = PROPERTY_LOOKUP(IDENT("n")->MapTo(scan_all.sym_), std::make_pair("prop", prop)); + auto filter = std::make_shared<Filter>(scan_all.op_, EQ(e, LITERAL(prop_value))); + auto output = + NEXPR("n", IDENT("n")->MapTo(scan_all.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(filter, output); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(PullAll(*produce, &context), prop_count); + }; + + count_with_index(prop_value1, vertex_prop_count); + count_with_scan_all(prop_value1, vertex_prop_count); + + 
count_with_index(prop_value2, vertex_count - vertex_prop_count); + count_with_scan_all(prop_value2, vertex_count - vertex_prop_count); +} +} // namespace memgraph::query::tests diff --git a/tests/unit/query_v2_query_plan_v2_create_set_remove_delete.cpp b/tests/unit/query_v2_query_plan_v2_create_set_remove_delete.cpp new file mode 100644 index 000000000..ed119722f --- /dev/null +++ b/tests/unit/query_v2_query_plan_v2_create_set_remove_delete.cpp @@ -0,0 +1,146 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+
+#include <gtest/gtest.h>
+
+#include "query/v2/frontend/semantic/symbol_table.hpp"
+#include "query/v2/plan/operator.hpp"
+#include "query_v2_query_plan_common.hpp"
+#include "storage/v3/property_value.hpp"
+#include "storage/v3/storage.hpp"
+
+namespace memgraph::query::v2::tests {
+
+class QueryPlanCRUDTest : public testing::Test {  // Fixture: fresh storage with a schema registered for `label`.
+ protected:
+  void SetUp() override {  // Registers `label` with a single INT schema property (`property`) before each test.
+    ASSERT_TRUE(db.CreateSchema(label, {storage::v3::SchemaProperty{property, common::SchemaType::INT}}));
+  }
+
+  storage::v3::Storage db;
+  const storage::v3::LabelId label{db.NameToLabel("label")};              // Label the schema is created for.
+  const storage::v3::PropertyId property{db.NameToProperty("property")};  // The schema's INT property.
+};
+
+TEST_F(QueryPlanCRUDTest, CreateNodeWithAttributes) {  // CreateNode must emit one vertex with `label` and property 42.
+  auto dba = db.Access();
+
+  AstStorage ast;
+  SymbolTable symbol_table;
+
+  plan::NodeCreationInfo node;
+  node.symbol = symbol_table.CreateSymbol("n", true);
+  node.labels.emplace_back(label);
+  std::get<std::vector<std::pair<storage::v3::PropertyId, Expression *>>>(node.properties)
+      .emplace_back(property, ast.Create<PrimitiveLiteral>(42));
+
+  plan::CreateNode create_node(nullptr, node);
+  DbAccessor execution_dba(&dba);
+  auto context = MakeContext(ast, symbol_table, &execution_dba);
+  Frame frame(context.symbol_table.max_position());
+  auto cursor = create_node.MakeCursor(utils::NewDeleteResource());
+  int count = 0;
+  while (cursor->Pull(frame, context)) {
+    ++count;
+    const auto &node_value = frame[node.symbol];
+    ASSERT_EQ(node_value.type(), TypedValue::Type::Vertex);  // Fatal: ValueVertex() below presumes a vertex.
+    const auto &v = node_value.ValueVertex();
+    EXPECT_TRUE(*v.HasLabel(storage::v3::View::NEW, label));
+    EXPECT_EQ(v.GetProperty(storage::v3::View::NEW, property)->ValueInt(), 42);
+    EXPECT_EQ(CountIterable(*v.InEdges(storage::v3::View::NEW)), 0);
+    EXPECT_EQ(CountIterable(*v.OutEdges(storage::v3::View::NEW)), 0);
+    // Disabled: the OLD-view check below invokes LOG(FATAL) instead of erroring out.
+    // EXPECT_TRUE(v.HasLabel(storage::v3::View::OLD, label).IsError());
+  }
+  EXPECT_EQ(count, 1);
+}
+
+TEST_F(QueryPlanCRUDTest, ScanAllEmpty) {  // With no vertices stored, ScanAll pulls nothing in the OLD or NEW view.
+  AstStorage ast;
+  SymbolTable symbol_table;
+  auto dba = db.Access();
+  DbAccessor execution_dba(&dba);
+  auto node_symbol = symbol_table.CreateSymbol("n", true);
+  {
+    plan::ScanAll scan_all(nullptr, node_symbol, storage::v3::View::OLD);
+    auto context = MakeContext(ast, symbol_table, &execution_dba);
+    Frame frame(context.symbol_table.max_position());
+    auto cursor = scan_all.MakeCursor(utils::NewDeleteResource());
+    int count = 0;
+    while (cursor->Pull(frame, context)) ++count;
+    EXPECT_EQ(count, 0);
+  }
+  {
+    plan::ScanAll scan_all(nullptr, node_symbol, storage::v3::View::NEW);
+    auto context = MakeContext(ast, symbol_table, &execution_dba);
+    Frame frame(context.symbol_table.max_position());
+    auto cursor = scan_all.MakeCursor(utils::NewDeleteResource());
+    int count = 0;
+    while (cursor->Pull(frame, context)) ++count;
+    EXPECT_EQ(count, 0);
+  }
+}
+
+TEST_F(QueryPlanCRUDTest, ScanAll) {  // 42 committed vertices must all be pulled by ScanAll (view argument defaulted).
+  {
+    auto dba = db.Access();
+    for (int i = 0; i < 42; ++i) {
+      auto v = *dba.CreateVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i)}});
+      ASSERT_TRUE(v.SetProperty(property, storage::v3::PropertyValue(i)).HasValue());
+    }
+    ASSERT_FALSE(dba.Commit().HasError());  // Fatal: the count below is meaningless if setup did not commit.
+  }
+  AstStorage ast;
+  SymbolTable symbol_table;
+  auto dba = db.Access();
+  DbAccessor execution_dba(&dba);
+  auto node_symbol = symbol_table.CreateSymbol("n", true);
+  plan::ScanAll scan_all(nullptr, node_symbol);
+  auto context = MakeContext(ast, symbol_table, &execution_dba);
+  Frame frame(context.symbol_table.max_position());
+  auto cursor = scan_all.MakeCursor(utils::NewDeleteResource());
+  int count = 0;
+  while (cursor->Pull(frame, context)) ++count;
+  EXPECT_EQ(count, 42);
+}
+
+TEST_F(QueryPlanCRUDTest, ScanAllByLabel) {  // Of 54 vertices, only the 42 that also carry label2 must be pulled.
+  auto label2 = db.NameToLabel("label2");
+  ASSERT_TRUE(db.CreateIndex(label2));
+  {
+    auto dba = db.Access();
+    // Add vertices that carry only `label` (no label2)
+    for (int i = 0; i < 12; ++i) {
+      auto v = *dba.CreateVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i)}});
+      ASSERT_TRUE(v.SetProperty(property, storage::v3::PropertyValue(i)).HasValue());
+    }
+    // Add vertices that additionally carry label2
+    for (int i = 0; i < 42; ++i) {
+      auto v = *dba.CreateVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i)}});
+      ASSERT_TRUE(v.SetProperty(property, storage::v3::PropertyValue(i)).HasValue());
+      ASSERT_TRUE(v.AddLabel(label2).HasValue());
+    }
+    ASSERT_FALSE(dba.Commit().HasError());  // Fatal: the count below is meaningless if setup did not commit.
+  }
+  auto dba = db.Access();
+  AstStorage ast;
+  SymbolTable symbol_table;
+  auto node_symbol = symbol_table.CreateSymbol("n", true);
+  DbAccessor execution_dba(&dba);
+  plan::ScanAllByLabel scan_all(nullptr, node_symbol, label2);
+  auto context = MakeContext(ast, symbol_table, &execution_dba);
+  Frame frame(context.symbol_table.max_position());
+  auto cursor = scan_all.MakeCursor(utils::NewDeleteResource());
+  int count = 0;
+  while (cursor->Pull(frame, context)) ++count;
+  EXPECT_EQ(count, 42);
+}
+}  // namespace memgraph::query::v2::tests
diff --git a/tests/unit/query_v2_query_required_privileges.cpp b/tests/unit/query_v2_query_required_privileges.cpp
new file mode 100644
index 000000000..c59fc604e
--- /dev/null
+++ b/tests/unit/query_v2_query_required_privileges.cpp
@@ -0,0 +1,222 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+ +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/frontend/ast/ast_visitor.hpp" +#include "query/v2/frontend/semantic/required_privileges.hpp" +#include "storage/v3/id_types.hpp" + +#include "query_v2_query_common.hpp" + +using namespace memgraph::query::v2; + +class FakeDbAccessor {}; + +const std::string EDGE_TYPE = "0"; +const std::string LABEL_0 = "label0"; +const std::string LABEL_1 = "label1"; +const std::string PROP_0 = "prop0"; + +using ::testing::UnorderedElementsAre; + +class TestPrivilegeExtractor : public ::testing::Test { + protected: + AstStorage storage; + FakeDbAccessor dba; +}; + +TEST_F(TestPrivilegeExtractor, CreateNode) { + auto *query = QUERY(SINGLE_QUERY(CREATE(PATTERN(NODE("n"))))); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::CREATE)); +} + +TEST_F(TestPrivilegeExtractor, MatchNodeDelete) { + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), DELETE(IDENT("n")))); + EXPECT_THAT(GetRequiredPrivileges(query), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, AuthQuery::Privilege::DELETE)); +} + +TEST_F(TestPrivilegeExtractor, MatchNodeReturn) { + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN("n"))); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::MATCH)); +} + +TEST_F(TestPrivilegeExtractor, MatchCreateExpand) { + auto *query = + QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), + CREATE(PATTERN(NODE("n"), EDGE("r", EdgeAtom::Direction::OUT, {EDGE_TYPE}), NODE("m"))))); + EXPECT_THAT(GetRequiredPrivileges(query), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, AuthQuery::Privilege::CREATE)); +} + +TEST_F(TestPrivilegeExtractor, MatchNodeSetLabels) { + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), SET("n", {LABEL_0, LABEL_1}))); + EXPECT_THAT(GetRequiredPrivileges(query), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, AuthQuery::Privilege::SET)); 
+} + +TEST_F(TestPrivilegeExtractor, MatchNodeSetProperty) { + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), + SET(PROPERTY_LOOKUP(storage.Create<Identifier>("n"), PROP_0), LITERAL(42)))); + EXPECT_THAT(GetRequiredPrivileges(query), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, AuthQuery::Privilege::SET)); +} + +TEST_F(TestPrivilegeExtractor, MatchNodeSetProperties) { + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), SET("n", LIST()))); + EXPECT_THAT(GetRequiredPrivileges(query), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, AuthQuery::Privilege::SET)); +} + +TEST_F(TestPrivilegeExtractor, MatchNodeRemoveLabels) { + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), REMOVE("n", {LABEL_0, LABEL_1}))); + EXPECT_THAT(GetRequiredPrivileges(query), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, AuthQuery::Privilege::REMOVE)); +} + +TEST_F(TestPrivilegeExtractor, MatchNodeRemoveProperty) { + auto *query = + QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), REMOVE(PROPERTY_LOOKUP(storage.Create<Identifier>("n"), PROP_0)))); + EXPECT_THAT(GetRequiredPrivileges(query), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, AuthQuery::Privilege::REMOVE)); +} + +TEST_F(TestPrivilegeExtractor, CreateIndex) { + auto *query = CREATE_INDEX_ON(storage.GetLabelIx(LABEL_0), storage.GetPropertyIx(PROP_0)); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::INDEX)); +} + +TEST_F(TestPrivilegeExtractor, AuthQuery) { + auto *query = + AUTH_QUERY(AuthQuery::Action::CREATE_ROLE, "", "role", "", nullptr, std::vector<AuthQuery::Privilege>{}); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::AUTH)); +} + +TEST_F(TestPrivilegeExtractor, ShowIndexInfo) { + auto *query = storage.Create<InfoQuery>(); + query->info_type_ = InfoQuery::InfoType::INDEX; + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::INDEX)); +} + 
+TEST_F(TestPrivilegeExtractor, ShowStatsInfo) { + auto *query = storage.Create<InfoQuery>(); + query->info_type_ = InfoQuery::InfoType::STORAGE; + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::STATS)); +} + +TEST_F(TestPrivilegeExtractor, ShowConstraintInfo) { + auto *query = storage.Create<InfoQuery>(); + query->info_type_ = InfoQuery::InfoType::CONSTRAINT; + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::CONSTRAINT)); +} + +TEST_F(TestPrivilegeExtractor, CreateConstraint) { + auto *query = storage.Create<ConstraintQuery>(); + query->action_type_ = ConstraintQuery::ActionType::CREATE; + query->constraint_.label = storage.GetLabelIx("label"); + query->constraint_.properties.push_back(storage.GetPropertyIx("prop0")); + query->constraint_.properties.push_back(storage.GetPropertyIx("prop1")); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::CONSTRAINT)); +} + +TEST_F(TestPrivilegeExtractor, DropConstraint) { + auto *query = storage.Create<ConstraintQuery>(); + query->action_type_ = ConstraintQuery::ActionType::DROP; + query->constraint_.label = storage.GetLabelIx("label"); + query->constraint_.properties.push_back(storage.GetPropertyIx("prop0")); + query->constraint_.properties.push_back(storage.GetPropertyIx("prop1")); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::CONSTRAINT)); +} + +// NOLINTNEXTLINE(hicpp-special-member-functions) +TEST_F(TestPrivilegeExtractor, DumpDatabase) { + auto *query = storage.Create<DumpQuery>(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::DUMP)); +} + +TEST_F(TestPrivilegeExtractor, ReadFile) { + auto load_csv = storage.Create<LoadCsv>(); + load_csv->row_var_ = IDENT("row"); + auto *query = QUERY(SINGLE_QUERY(load_csv)); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::READ_FILE)); +} + 
+TEST_F(TestPrivilegeExtractor, LockPathQuery) { + auto *query = storage.Create<LockPathQuery>(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::DURABILITY)); +} + +TEST_F(TestPrivilegeExtractor, FreeMemoryQuery) { + auto *query = storage.Create<FreeMemoryQuery>(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::FREE_MEMORY)); +} + +TEST_F(TestPrivilegeExtractor, TriggerQuery) { + auto *query = storage.Create<TriggerQuery>(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::TRIGGER)); +} + +TEST_F(TestPrivilegeExtractor, SetIsolationLevelQuery) { + auto *query = storage.Create<IsolationLevelQuery>(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::CONFIG)); +} + +TEST_F(TestPrivilegeExtractor, CreateSnapshotQuery) { + auto *query = storage.Create<CreateSnapshotQuery>(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::DURABILITY)); +} + +TEST_F(TestPrivilegeExtractor, StreamQuery) { + auto *query = storage.Create<StreamQuery>(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::STREAM)); +} + +TEST_F(TestPrivilegeExtractor, SettingQuery) { + auto *query = storage.Create<SettingQuery>(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::CONFIG)); +} + +TEST_F(TestPrivilegeExtractor, ShowVersion) { + auto *query = storage.Create<VersionQuery>(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::STATS)); +} + +TEST_F(TestPrivilegeExtractor, SchemaQuery) { + auto *query = storage.Create<SchemaQuery>(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::SCHEMA)); +} + +TEST_F(TestPrivilegeExtractor, CallProcedureQuery) { + { + auto *query = QUERY(SINGLE_QUERY(CALL_PROCEDURE("mg.get_module_files"))); + 
EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::MODULE_READ)); + } + { + auto *query = QUERY(SINGLE_QUERY(CALL_PROCEDURE("mg.create_module_file", {LITERAL("some_name.py")}))); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::MODULE_WRITE)); + } + { + auto *query = QUERY( + SINGLE_QUERY(CALL_PROCEDURE("mg.update_module_file", {LITERAL("some_name.py"), LITERAL("some content")}))); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::MODULE_WRITE)); + } + { + auto *query = QUERY(SINGLE_QUERY(CALL_PROCEDURE("mg.get_module_file", {LITERAL("some_name.py")}))); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::MODULE_READ)); + } + { + auto *query = QUERY(SINGLE_QUERY(CALL_PROCEDURE("mg.delete_module_file", {LITERAL("some_name.py")}))); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::MODULE_WRITE)); + } +} diff --git a/tests/unit/result_stream_faker.hpp b/tests/unit/result_stream_faker.hpp new file mode 100644 index 000000000..60c5884e7 --- /dev/null +++ b/tests/unit/result_stream_faker.hpp @@ -0,0 +1,132 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <map> + +#include "glue/v2/communication.hpp" +#include "query/v2/typed_value.hpp" +#include "storage/v3/storage.hpp" +#include "utils/algorithm.hpp" + +/** + * A mocker for the data output record stream. 
+ * This implementation checks that messages are + * sent to it in an acceptable order, and tracks + * the content of those messages. + */ +class ResultStreamFaker { + public: + explicit ResultStreamFaker(memgraph::storage::v3::Storage *store) : store_(store) {} + + ResultStreamFaker(const ResultStreamFaker &) = delete; + ResultStreamFaker &operator=(const ResultStreamFaker &) = delete; + ResultStreamFaker(ResultStreamFaker &&) = default; + ResultStreamFaker &operator=(ResultStreamFaker &&) = default; + + void Header(const std::vector<std::string> &fields) { header_ = fields; } + + void Result(const std::vector<memgraph::communication::bolt::Value> &values) { results_.push_back(values); } + + void Result(const std::vector<memgraph::query::v2::TypedValue> &values) { + std::vector<memgraph::communication::bolt::Value> bvalues; + bvalues.reserve(values.size()); + for (const auto &value : values) { + auto maybe_value = memgraph::glue::v2::ToBoltValue(value, *store_, memgraph::storage::v3::View::NEW); + MG_ASSERT(maybe_value.HasValue()); + bvalues.push_back(std::move(*maybe_value)); + } + results_.push_back(std::move(bvalues)); + } + + void Summary(const std::map<std::string, memgraph::communication::bolt::Value> &summary) { summary_ = summary; } + + void Summary(const std::map<std::string, memgraph::query::v2::TypedValue> &summary) { + std::map<std::string, memgraph::communication::bolt::Value> bsummary; + for (const auto &item : summary) { + auto maybe_value = memgraph::glue::v2::ToBoltValue(item.second, *store_, memgraph::storage::v3::View::NEW); + MG_ASSERT(maybe_value.HasValue()); + bsummary.insert({item.first, std::move(*maybe_value)}); + } + summary_ = std::move(bsummary); + } + + const auto &GetHeader() const { return header_; } + + const auto &GetResults() const { return results_; } + + const auto &GetSummary() const { return summary_; } + + friend std::ostream &operator<<(std::ostream &os, const ResultStreamFaker &results) { + auto decoded_value_to_string = 
[](const auto &value) { + std::stringstream ss; + ss << value; + return ss.str(); + }; + const std::vector<std::string> &header = results.GetHeader(); + std::vector<int> column_widths(header.size()); + std::transform(header.begin(), header.end(), column_widths.begin(), [](const auto &s) { return s.size(); }); + + // convert all the results into strings, and track max column width + auto &results_data = results.GetResults(); + std::vector<std::vector<std::string>> result_strings(results_data.size(), + std::vector<std::string>(column_widths.size())); + for (int row_ind = 0; row_ind < static_cast<int>(results_data.size()); ++row_ind) { + for (int col_ind = 0; col_ind < static_cast<int>(column_widths.size()); ++col_ind) { + std::string string_val = decoded_value_to_string(results_data[row_ind][col_ind]); + column_widths[col_ind] = std::max(column_widths[col_ind], (int)string_val.size()); + result_strings[row_ind][col_ind] = string_val; + } + } + + // output a results table + // first define some helper functions + auto emit_horizontal_line = [&]() { + os << "+"; + for (auto col_width : column_widths) os << std::string((unsigned long)col_width + 2, '-') << "+"; + os << std::endl; + }; + + auto emit_result_vec = [&](const std::vector<std::string> result_vec) { + os << "| "; + for (int col_ind = 0; col_ind < static_cast<int>(column_widths.size()); ++col_ind) { + const std::string &res = result_vec[col_ind]; + os << res << std::string(column_widths[col_ind] - res.size(), ' '); + os << " | "; + } + os << std::endl; + }; + + // final output of results + emit_horizontal_line(); + emit_result_vec(results.GetHeader()); + emit_horizontal_line(); + for (const auto &result_vec : result_strings) emit_result_vec(result_vec); + emit_horizontal_line(); + os << "Found " << results_data.size() << " matching results" << std::endl; + + // output the summary + os << "Query summary: {"; + memgraph::utils::PrintIterable(os, results.GetSummary(), ", ", + [&](auto &stream, const auto &kv) { 
stream << kv.first << ": " << kv.second; }); + os << "}" << std::endl; + + return os; + } + + private: + memgraph::storage::v3::Storage *store_; + // the data that the record stream can accept + std::vector<std::string> header_; + std::vector<std::vector<memgraph::communication::bolt::Value>> results_; + std::map<std::string, memgraph::communication::bolt::Value> summary_; +}; diff --git a/tests/unit/storage_v3_schema.cpp b/tests/unit/storage_v3_schema.cpp new file mode 100644 index 000000000..3d090fc2f --- /dev/null +++ b/tests/unit/storage_v3_schema.cpp @@ -0,0 +1,294 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#include <gmock/gmock-matchers.h> +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include <fmt/format.h> +#include <optional> +#include <string> +#include <vector> + +#include "common/types.hpp" +#include "storage/v3/id_types.hpp" +#include "storage/v3/property_value.hpp" +#include "storage/v3/schema_validator.hpp" +#include "storage/v3/schemas.hpp" +#include "storage/v3/storage.hpp" +#include "storage/v3/temporal.hpp" + +using testing::Pair; +using testing::UnorderedElementsAre; +using SchemaType = memgraph::common::SchemaType; + +namespace memgraph::storage::v3::tests { + +class SchemaTest : public testing::Test { + private: + NameIdMapper label_mapper_; + NameIdMapper property_mapper_; + + protected: + LabelId NameToLabel(const std::string &name) { return LabelId::FromUint(label_mapper_.NameToId(name)); } + + PropertyId NameToProperty(const std::string &name) { return PropertyId::FromUint(property_mapper_.NameToId(name)); } + + PropertyId prop1{NameToProperty("prop1")}; + PropertyId prop2{NameToProperty("prop2")}; + LabelId label1{NameToLabel("label1")}; + LabelId label2{NameToLabel("label2")}; + SchemaProperty schema_prop_string{prop1, SchemaType::STRING}; + SchemaProperty schema_prop_int{prop2, SchemaType::INT}; +}; + +TEST_F(SchemaTest, TestSchemaCreate) { + Schemas schemas; + EXPECT_EQ(schemas.ListSchemas().size(), 0); + + EXPECT_TRUE(schemas.CreateSchema(label1, {schema_prop_string})); + EXPECT_EQ(schemas.ListSchemas().size(), 1); + + { + EXPECT_TRUE(schemas.CreateSchema(label2, {schema_prop_string, schema_prop_int})); + const auto current_schemas = schemas.ListSchemas(); + EXPECT_EQ(current_schemas.size(), 2); + EXPECT_THAT(current_schemas, + UnorderedElementsAre(Pair(label1, std::vector<SchemaProperty>{schema_prop_string}), + Pair(label2, std::vector<SchemaProperty>{schema_prop_string, schema_prop_int}))); + } + { + // Assert after unsuccessful creation, number of schemas remains the same + EXPECT_FALSE(schemas.CreateSchema(label2, 
{schema_prop_int})); + const auto current_schemas = schemas.ListSchemas(); + EXPECT_EQ(current_schemas.size(), 2); + EXPECT_THAT(current_schemas, + UnorderedElementsAre(Pair(label1, std::vector<SchemaProperty>{schema_prop_string}), + Pair(label2, std::vector<SchemaProperty>{schema_prop_string, schema_prop_int}))); + } +} + +TEST_F(SchemaTest, TestSchemaList) { + Schemas schemas; + + EXPECT_TRUE(schemas.CreateSchema(label1, {schema_prop_string})); + EXPECT_TRUE(schemas.CreateSchema(label2, {{NameToProperty("prop1"), SchemaType::STRING}, + {NameToProperty("prop2"), SchemaType::INT}, + {NameToProperty("prop3"), SchemaType::BOOL}, + {NameToProperty("prop4"), SchemaType::DATE}, + {NameToProperty("prop5"), SchemaType::LOCALDATETIME}, + {NameToProperty("prop6"), SchemaType::DURATION}, + {NameToProperty("prop7"), SchemaType::LOCALTIME}})); + { + const auto current_schemas = schemas.ListSchemas(); + EXPECT_EQ(current_schemas.size(), 2); + EXPECT_THAT(current_schemas, + UnorderedElementsAre( + Pair(label1, std::vector<SchemaProperty>{schema_prop_string}), + Pair(label2, std::vector<SchemaProperty>{{NameToProperty("prop1"), SchemaType::STRING}, + {NameToProperty("prop2"), SchemaType::INT}, + {NameToProperty("prop3"), SchemaType::BOOL}, + {NameToProperty("prop4"), SchemaType::DATE}, + {NameToProperty("prop5"), SchemaType::LOCALDATETIME}, + {NameToProperty("prop6"), SchemaType::DURATION}, + {NameToProperty("prop7"), SchemaType::LOCALTIME}}))); + } + { + const auto *const schema1 = schemas.GetSchema(label1); + ASSERT_NE(schema1, nullptr); + EXPECT_EQ(*schema1, (Schemas::Schema{label1, std::vector<SchemaProperty>{schema_prop_string}})); + } + { + const auto *const schema2 = schemas.GetSchema(label2); + ASSERT_NE(schema2, nullptr); + EXPECT_EQ(schema2->first, label2); + EXPECT_EQ(schema2->second.size(), 7); + } +} + +TEST_F(SchemaTest, TestSchemaDrop) { + Schemas schemas; + EXPECT_EQ(schemas.ListSchemas().size(), 0); + + EXPECT_TRUE(schemas.CreateSchema(label1, 
{schema_prop_string})); + EXPECT_EQ(schemas.ListSchemas().size(), 1); + + EXPECT_TRUE(schemas.DropSchema(label1)); + EXPECT_EQ(schemas.ListSchemas().size(), 0); + + EXPECT_TRUE(schemas.CreateSchema(label1, {schema_prop_string})); + EXPECT_TRUE(schemas.CreateSchema(label2, {schema_prop_string, schema_prop_int})); + EXPECT_EQ(schemas.ListSchemas().size(), 2); + + { + EXPECT_TRUE(schemas.DropSchema(label1)); + const auto current_schemas = schemas.ListSchemas(); + EXPECT_EQ(current_schemas.size(), 1); + EXPECT_THAT(current_schemas, + UnorderedElementsAre(Pair(label2, std::vector<SchemaProperty>{schema_prop_string, schema_prop_int}))); + } + + { + // Cannot drop nonexisting schema + EXPECT_FALSE(schemas.DropSchema(label1)); + const auto current_schemas = schemas.ListSchemas(); + EXPECT_EQ(current_schemas.size(), 1); + EXPECT_THAT(current_schemas, + UnorderedElementsAre(Pair(label2, std::vector<SchemaProperty>{schema_prop_string, schema_prop_int}))); + } + + EXPECT_TRUE(schemas.DropSchema(label2)); + EXPECT_EQ(schemas.ListSchemas().size(), 0); +} + +class SchemaValidatorTest : public testing::Test { + protected: + void SetUp() override { + ASSERT_TRUE(schemas.CreateSchema(label1, {schema_prop_string})); + ASSERT_TRUE(schemas.CreateSchema(label2, {schema_prop_string, schema_prop_int, schema_prop_duration})); + } + + LabelId NameToLabel(const std::string &name) { return LabelId::FromUint(label_mapper_.NameToId(name)); } + + PropertyId NameToProperty(const std::string &name) { return PropertyId::FromUint(property_mapper_.NameToId(name)); } + + private: + NameIdMapper label_mapper_; + NameIdMapper property_mapper_; + + protected: + Schemas schemas; + SchemaValidator schema_validator{schemas}; + PropertyId prop_string{NameToProperty("prop1")}; + PropertyId prop_int{NameToProperty("prop2")}; + PropertyId prop_duration{NameToProperty("prop3")}; + LabelId label1{NameToLabel("label1")}; + LabelId label2{NameToLabel("label2")}; + SchemaProperty schema_prop_string{prop_string, 
SchemaType::STRING}; + SchemaProperty schema_prop_int{prop_int, SchemaType::INT}; + SchemaProperty schema_prop_duration{prop_duration, SchemaType::DURATION}; +}; + +TEST_F(SchemaValidatorTest, TestSchemaValidateVertexCreate) { + // Validate against secondary label + { + const auto schema_violation = + schema_validator.ValidateVertexCreate(NameToLabel("test"), {}, {{prop_string, PropertyValue(1)}}); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, + SchemaViolation(SchemaViolation::ValidationStatus::NO_SCHEMA_DEFINED_FOR_LABEL, NameToLabel("test"))); + } + // Validate missing property + { + const auto schema_violation = schema_validator.ValidateVertexCreate(label1, {}, {{prop_int, PropertyValue(1)}}); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_HAS_NO_PRIMARY_PROPERTY, + label1, schema_prop_string)); + } + { + const auto schema_violation = schema_validator.ValidateVertexCreate(label2, {}, {}); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_HAS_NO_PRIMARY_PROPERTY, + label2, schema_prop_string)); + } + // Validate wrong secondary label + { + const auto schema_violation = + schema_validator.ValidateVertexCreate(label1, {label1}, {{prop_string, PropertyValue("test")}}); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, + SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_SECONDARY_LABEL_IS_PRIMARY, label1)); + } + { + const auto schema_violation = + schema_validator.ValidateVertexCreate(label1, {label2}, {{prop_string, PropertyValue("test")}}); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, + SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_SECONDARY_LABEL_IS_PRIMARY, label2)); + } + // Validate wrong property type + { + const auto schema_violation = schema_validator.ValidateVertexCreate(label1, {}, 
{{prop_string, PropertyValue(1)}}); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_PROPERTY_WRONG_TYPE, label1, + schema_prop_string, PropertyValue(1))); + } + { + const auto schema_violation = schema_validator.ValidateVertexCreate( + label2, {}, + {{prop_string, PropertyValue("test")}, {prop_int, PropertyValue(12)}, {prop_duration, PropertyValue(1)}}); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_PROPERTY_WRONG_TYPE, label2, + schema_prop_duration, PropertyValue(1))); + } + { + const auto wrong_prop = PropertyValue(TemporalData(TemporalType::Date, 1234)); + const auto schema_violation = schema_validator.ValidateVertexCreate( + label2, {}, {{prop_string, PropertyValue("test")}, {prop_int, PropertyValue(12)}, {prop_duration, wrong_prop}}); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_PROPERTY_WRONG_TYPE, label2, + schema_prop_duration, wrong_prop)); + } + // Passing validations + EXPECT_EQ(schema_validator.ValidateVertexCreate(label1, {}, {{prop_string, PropertyValue("test")}}), std::nullopt); + EXPECT_EQ(schema_validator.ValidateVertexCreate(label1, {NameToLabel("label3"), NameToLabel("label4")}, + {{prop_string, PropertyValue("test")}}), + std::nullopt); + EXPECT_EQ(schema_validator.ValidateVertexCreate( + label2, {}, + {{prop_string, PropertyValue("test")}, + {prop_int, PropertyValue(122)}, + {prop_duration, PropertyValue(TemporalData(TemporalType::Duration, 1234))}}), + std::nullopt); + EXPECT_EQ(schema_validator.ValidateVertexCreate( + label2, {NameToLabel("label5"), NameToLabel("label6")}, + {{prop_string, PropertyValue("test123")}, + {prop_int, PropertyValue(122221)}, + {prop_duration, PropertyValue(TemporalData(TemporalType::Duration, 12344321))}}), + std::nullopt); +} + 
+TEST_F(SchemaValidatorTest, TestSchemaValidatePropertyUpdate) { + // Validate updating of primary key + { + const auto schema_violation = schema_validator.ValidatePropertyUpdate(label1, prop_string); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_UPDATE_PRIMARY_KEY, label1, + schema_prop_string)); + } + { + const auto schema_violation = schema_validator.ValidatePropertyUpdate(label2, prop_duration); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_UPDATE_PRIMARY_KEY, label2, + schema_prop_duration)); + } + EXPECT_EQ(schema_validator.ValidatePropertyUpdate(label1, prop_int), std::nullopt); + EXPECT_EQ(schema_validator.ValidatePropertyUpdate(label1, prop_duration), std::nullopt); + EXPECT_EQ(schema_validator.ValidatePropertyUpdate(label2, NameToProperty("test")), std::nullopt); +} + +TEST_F(SchemaValidatorTest, TestSchemaValidatePropertyUpdateLabel) { + // Validate adding primary label + { + const auto schema_violation = schema_validator.ValidateLabelUpdate(label1); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, + SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_MODIFY_PRIMARY_LABEL, label1)); + } + { + const auto schema_violation = schema_validator.ValidateLabelUpdate(label2); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, + SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_MODIFY_PRIMARY_LABEL, label2)); + } + EXPECT_EQ(schema_validator.ValidateLabelUpdate(NameToLabel("test")), std::nullopt); +} +} // namespace memgraph::storage::v3::tests