From 2998f925952b17c6159f296941b71a4c94ff9fc6 Mon Sep 17 00:00:00 2001 From: Jure Bajic Date: Thu, 23 Jun 2022 14:04:44 +0200 Subject: [PATCH 1/5] Add initial schema implementation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add initial schema implementation * Add index to schema * List schemas and enable multiple properties * Implement SchemaTypes * Apply suggestions from code review Co-authored-by: Jeremy B <97525434+42jeremy@users.noreply.github.com> Co-authored-by: János Benjamin Antal * Address review comments * Remove Map and List * Apply suggestions from code review Co-authored-by: Kostas Kyrimis Co-authored-by: Jeremy B <97525434+42jeremy@users.noreply.github.com> Co-authored-by: János Benjamin Antal Co-authored-by: Kostas Kyrimis --- src/query/db_accessor.hpp | 2 + src/storage/v2/CMakeLists.txt | 1 + src/storage/v2/schemas.cpp | 74 ++++++++++++++++ src/storage/v2/schemas.hpp | 155 ++++++++++++++++++++++++++++++++++ src/storage/v2/storage.cpp | 11 ++- src/storage/v2/storage.hpp | 21 ++++- 6 files changed, 261 insertions(+), 3 deletions(-) create mode 100644 src/storage/v2/schemas.cpp create mode 100644 src/storage/v2/schemas.hpp diff --git a/src/query/db_accessor.hpp b/src/query/db_accessor.hpp index f325e1282..0369b462b 100644 --- a/src/query/db_accessor.hpp +++ b/src/query/db_accessor.hpp @@ -356,6 +356,8 @@ class DbAccessor final { storage::IndicesInfo ListAllIndices() const { return accessor_->ListAllIndices(); } storage::ConstraintsInfo ListAllConstraints() const { return accessor_->ListAllConstraints(); } + + storage::SchemasInfo ListAllSchemas() const { return accessor_->ListAllSchemas(); } }; } // namespace memgraph::query diff --git a/src/storage/v2/CMakeLists.txt b/src/storage/v2/CMakeLists.txt index f33a8553d..dab088c93 100644 --- a/src/storage/v2/CMakeLists.txt +++ b/src/storage/v2/CMakeLists.txt @@ -10,6 +10,7 @@ set(storage_v2_src_files indices.cpp property_store.cpp vertex_accessor.cpp + schemas.cpp storage.cpp) ##### Replication ##### diff --git a/src/storage/v2/schemas.cpp b/src/storage/v2/schemas.cpp new file mode 100644 index 000000000..2e4fbbe3b --- /dev/null +++ b/src/storage/v2/schemas.cpp @@ -0,0 +1,74 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include +#include +#include + +#include "storage/v2/property_value.hpp" +#include "storage/v2/schemas.hpp" + +namespace memgraph::storage { + +SchemaViolation::SchemaViolation(ValidationStatus status, LabelId label) : status{status}, label{label} {} +SchemaViolation::SchemaViolation(ValidationStatus status, LabelId label, SchemaProperty violated_type) + : status{status}, label{label}, violated_type{violated_type} {} + +SchemaViolation::SchemaViolation(ValidationStatus status, LabelId label, SchemaProperty violated_type, + PropertyValue violated_property_value) + : status{status}, label{label}, violated_type{violated_type}, violated_property_value{violated_property_value} {} + +bool Schemas::CreateSchema(const LabelId primary_label, const std::vector &schemas_types) { + return schemas_.insert({primary_label, schemas_types}).second; +} + +bool Schemas::DeleteSchema(const LabelId primary_label) { + return schemas_.erase(primary_label); +} + +std::optional Schemas::ValidateVertex(const LabelId primary_label, const Vertex &vertex) { + // TODO Check for multiple defined primary labels + const auto schema = schemas_.find(primary_label); + if (schema == schemas_.end()) { + return SchemaViolation(SchemaViolation::ValidationStatus::NO_SCHEMA_DEFINED_FOR_LABEL, primary_label); + } + if (!utils::Contains(vertex.labels, primary_label)) { + return SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_HAS_NO_PRIMARY_LABEL, primary_label); + } + + for (const auto &schema_type : schema->second) { + if (!vertex.properties.HasProperty(schema_type.property_id)) { + return SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_HAS_NO_PROPERTY, primary_label, schema_type); + } + // Property type check + // TODO Can this be replaced with just property id check? + if (auto vertex_property = vertex.properties.GetProperty(schema_type.property_id); + PropertyValueTypeToSchemaProperty(vertex_property) != schema_type.type) { + return SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_PROPERTY_WRONG_TYPE, primary_label, schema_type, + vertex_property); + } + } + // TODO after the introduction of vertex hashing introduce check for vertex + // primary key uniqueness + + return std::nullopt; +} + +Schemas::SchemasList Schemas::ListSchemas() const { + Schemas::SchemasList ret; + ret.reserve(schemas_.size()); + for (const auto &[label_props, schema_property] : schemas_) { + ret.emplace_back(label_props, schema_property); + } + return ret; +} + +} // namespace memgraph::storage diff --git a/src/storage/v2/schemas.hpp b/src/storage/v2/schemas.hpp new file mode 100644 index 000000000..113707069 --- /dev/null +++ b/src/storage/v2/schemas.hpp @@ -0,0 +1,155 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include +#include +#include +#include + +#include "storage/v2/id_types.hpp" +#include "storage/v2/indices.hpp" +#include "storage/v2/property_value.hpp" +#include "storage/v2/temporal.hpp" +#include "storage/v2/transaction.hpp" +#include "storage/v2/vertex.hpp" +#include "utils/result.hpp" + +namespace memgraph::storage { + +class SchemaViolationException : public utils::BasicException { + using utils::BasicException::BasicException; +}; + +struct SchemaProperty { + enum class Type : uint8_t { Bool, Int, Double, String, Date, LocalTime, LocalDateTime, Duration }; + + Type type; + PropertyId property_id; +}; + +struct SchemaViolation { + enum class ValidationStatus : uint8_t { + VERTEX_HAS_NO_PRIMARY_LABEL, + VERTEX_HAS_NO_PROPERTY, + NO_SCHEMA_DEFINED_FOR_LABEL, + VERTEX_PROPERTY_WRONG_TYPE + }; + + SchemaViolation(ValidationStatus status, LabelId label); + + SchemaViolation(ValidationStatus status, LabelId label, SchemaProperty violated_type); + + SchemaViolation(ValidationStatus status, LabelId label, SchemaProperty violated_type, + PropertyValue violated_property_value); + + ValidationStatus status; + LabelId label; + std::optional violated_type; + std::optional violated_property_value; +}; + +/// Structure that represents a collection of schemas +/// Schema can be mapped under only one label => primary label +class Schemas { + public: + using SchemasMap = std::unordered_map>; + using SchemasList = std::vector>>; + + Schemas() = default; + Schemas(const Schemas &) = delete; + Schemas(Schemas &&) = delete; + Schemas &operator=(const Schemas &) = delete; + Schemas &operator=(Schemas &&) = delete; + ~Schemas() = default; + + [[nodiscard]] bool CreateSchema(LabelId label, const std::vector &schemas_types); + + [[nodiscard]] bool DeleteSchema(LabelId label); + + [[nodiscard]] std::optional ValidateVertex(LabelId primary_label, const Vertex &vertex); + + [[nodiscard]] SchemasList ListSchemas() const; + + private: + SchemasMap schemas_; +}; + +inline std::optional PropertyValueTypeToSchemaProperty(const PropertyValue &property_value) { + switch (property_value.type()) { + case PropertyValue::Type::Bool: { + return SchemaProperty::Type::Bool; + } + case PropertyValue::Type::Int: { + return SchemaProperty::Type::Int; + } + case PropertyValue::Type::Double: { + return SchemaProperty::Type::Double; + } + case PropertyValue::Type::String: { + return SchemaProperty::Type::String; + } + case PropertyValue::Type::TemporalData: { + switch (property_value.ValueTemporalData().type) { + case TemporalType::Date: { + return SchemaProperty::Type::Date; + } + case TemporalType::LocalDateTime: { + return SchemaProperty::Type::LocalDateTime; + } + case TemporalType::LocalTime: { + return SchemaProperty::Type::LocalTime; + } + case TemporalType::Duration: { + return SchemaProperty::Type::Duration; + } + } + } + case PropertyValue::Type::Null: + case PropertyValue::Type::Map: + case PropertyValue::Type::List: { + return std::nullopt; + } + } +} + +inline std::string SchemaPropertyToString(const SchemaProperty::Type type) { + switch (type) { + case SchemaProperty::Type::Bool: { + return "Bool"; + } + case SchemaProperty::Type::Int: { + return "Integer"; + } + case SchemaProperty::Type::Double: { + return "Double"; + } + case SchemaProperty::Type::String: { + return "String"; + } + case SchemaProperty::Type::Date: { + return "Date"; + } + case SchemaProperty::Type::LocalTime: { + return "LocalTime"; + } + case SchemaProperty::Type::LocalDateTime: { + return "LocalDateTime"; + } + case SchemaProperty::Type::Duration: { + return "Duration"; + } + } +} + +} // namespace memgraph::storage diff --git a/src/storage/v2/storage.cpp b/src/storage/v2/storage.cpp index e0fbdf0da..5b1f3084b 100644 --- a/src/storage/v2/storage.cpp +++ b/src/storage/v2/storage.cpp @@ -28,6 +28,7 @@ #include "storage/v2/indices.hpp" #include "storage/v2/mvcc.hpp" #include "storage/v2/replication/config.hpp" +#include "storage/v2/schemas.hpp" #include "storage/v2/transaction.hpp" #include "storage/v2/vertex_accessor.hpp" #include "utils/file.hpp" @@ -463,12 +464,13 @@ VertexAccessor Storage::Accessor::CreateVertex() { OOMExceptionEnabler oom_exception; auto gid = storage_->vertex_id_.fetch_add(1, std::memory_order_acq_rel); auto acc = storage_->vertices_.access(); - auto delta = CreateDeleteObjectDelta(&transaction_); + auto *delta = CreateDeleteObjectDelta(&transaction_); auto [it, inserted] = acc.insert(Vertex{storage::Gid::FromUint(gid), delta}); MG_ASSERT(inserted, "The vertex must be inserted here!"); MG_ASSERT(it != acc.end(), "Invalid Vertex accessor!"); + delta->prev.Set(&*it); - return VertexAccessor(&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_); + return {&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_}; } VertexAccessor Storage::Accessor::CreateVertex(storage::Gid gid) { @@ -1234,6 +1236,11 @@ ConstraintsInfo Storage::ListAllConstraints() const { return {ListExistenceConstraints(constraints_), constraints_.unique_constraints.ListConstraints()}; } +SchemasInfo Storage::ListAllSchemas() const { + std::shared_lock storage_guard_(main_lock_); + return {schemas_.ListSchemas()}; +} + StorageInfo Storage::GetInfo() const { auto vertex_count = vertices_.size(); auto edge_count = edge_count_.load(std::memory_order_acquire); diff --git a/src/storage/v2/storage.hpp b/src/storage/v2/storage.hpp index 90a5daf16..4839ff3ca 100644 --- a/src/storage/v2/storage.hpp +++ b/src/storage/v2/storage.hpp @@ -16,6 +16,7 @@ #include #include #include +#include #include "io/network/endpoint.hpp" #include "storage/v2/commit_log.hpp" @@ -25,14 +26,18 @@ #include "storage/v2/durability/wal.hpp" #include "storage/v2/edge.hpp" #include "storage/v2/edge_accessor.hpp" +#include "storage/v2/id_types.hpp" #include "storage/v2/indices.hpp" #include "storage/v2/isolation_level.hpp" #include "storage/v2/mvcc.hpp" #include "storage/v2/name_id_mapper.hpp" +#include "storage/v2/property_value.hpp" #include "storage/v2/result.hpp" +#include "storage/v2/schemas.hpp" #include "storage/v2/transaction.hpp" #include "storage/v2/vertex.hpp" #include "storage/v2/vertex_accessor.hpp" +#include "utils/exceptions.hpp" #include "utils/file_locker.hpp" #include "utils/on_scope_exit.hpp" #include "utils/rw_lock.hpp" @@ -173,6 +178,11 @@ struct ConstraintsInfo { std::vector>> unique; }; +/// Structure used to return information about existing schemas in the storage +struct SchemasInfo { + Schemas::SchemasList schemas; +}; + /// Structure used to return information about the storage. struct StorageInfo { uint64_t vertex_count; @@ -306,6 +316,8 @@ class Storage final { storage_->constraints_.unique_constraints.ListConstraints()}; } + SchemasInfo ListAllSchemas() const { return {storage_->schemas_.ListSchemas()}; } + void AdvanceCommand(); /// Commit returns `ConstraintViolation` if the changes made by this @@ -364,7 +376,7 @@ class Storage final { IndicesInfo ListAllIndices() const; /// Creates an existence constraint. Returns true if the constraint was - /// successfuly added, false if it already exists and a `ConstraintViolation` + /// successfully added, false if it already exists and a `ConstraintViolation` /// if there is an existing vertex violating the constraint. /// /// @throw std::bad_alloc @@ -402,6 +414,12 @@ class Storage final { ConstraintsInfo ListAllConstraints() const; + bool CreateSchema(LabelId primary_label, std::vector &schemas_types); + + bool DeleteSchema(LabelId primary_label); + + SchemasInfo ListAllSchemas() const; + StorageInfo GetInfo() const; bool LockPath(); @@ -497,6 +515,7 @@ class Storage final { Constraints constraints_; Indices indices_; + Schemas schemas_; // Transaction engine utils::SpinLock engine_lock_; From 3f4f66b57fa80245125ab8a07722d07797bdda42 Mon Sep 17 00:00:00 2001 From: Jure Bajic Date: Mon, 11 Jul 2022 09:20:15 +0200 Subject: [PATCH 2/5] Create schema DDL expressions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add initial schema implementation * Add index to schema * List schemas and enable multiple properties * Implement SchemaTypes * Apply suggestions from code review Co-authored-by: Jeremy B <97525434+42jeremy@users.noreply.github.com> Co-authored-by: János Benjamin Antal * Address review comments * Remove Map and List * Add schema operations in storage * Add create and show schema queries * Add privileges for schema * Add missing keywords into lexer * Add drop schema query * Add schema visitors * Update metadata * Add PrepareSchemaQuery function * Implement show schemas * Add show schema query * Fix schema visitor * Add common schema type * Fix grammar * Temporary create ddl logic * Fix naming for schemaproperty type to schema type * Rename schemaproperty to schemapropertytype * Enable Create schema ddl * Override visitPropertyType * Add initial schema implementation * Add initial schema implementation * Add index to schema * List schemas and enable multiple properties * Implement SchemaTypes * Apply suggestions from code review Co-authored-by: Jeremy B <97525434+42jeremy@users.noreply.github.com> Co-authored-by: János Benjamin Antal * Address review comments * Remove Map and List * Apply suggestions from code review Co-authored-by: Kostas Kyrimis Co-authored-by: Jeremy B <97525434+42jeremy@users.noreply.github.com> Co-authored-by: János Benjamin Antal Co-authored-by: Kostas Kyrimis * Add verification on creation and deletion * Rename DeleteSchema to DropSchema * Remove list and map from lexer * Fix grammar with schemaTypeMap * Add privilege and cypher visitor tests * Catch repeating type name in schema definition * Fix conflicting keywords * Add notifications * Drop float support * Finish interpreter tests * Fix tests * Fix clang tidy errors * Fix GetSchema * Replace for with transfrom * Add cloning og schema_property_map * Address review comments * Rename SchemaPropertyType to SchemaType * Remove inline * Assert of schema properties Co-authored-by: Jeremy B <97525434+42jeremy@users.noreply.github.com> Co-authored-by: János Benjamin Antal Co-authored-by: Kostas Kyrimis --- src/auth/models.cpp | 4 +- src/auth/models.hpp | 3 +- src/common/types.hpp | 19 +++ src/glue/auth.cpp | 2 + src/query/frontend/ast/ast.lcp | 47 +++++- src/query/frontend/ast/ast_visitor.hpp | 9 +- .../frontend/ast/cypher_main_visitor.cpp | 90 +++++++++++ .../frontend/ast/cypher_main_visitor.hpp | 35 ++++ .../opencypher/grammar/MemgraphCypher.g4 | 28 +++- .../opencypher/grammar/MemgraphCypherLexer.g4 | 2 + .../frontend/semantic/required_privileges.cpp | 2 + .../frontend/stripped_lexer_constants.hpp | 5 +- src/query/interpreter.cpp | 126 +++++++++++++++ src/query/metadata.cpp | 10 +- src/query/metadata.hpp | 4 + src/storage/v2/schemas.cpp | 92 +++++++++-- src/storage/v2/schemas.hpp | 87 ++-------- src/storage/v2/storage.cpp | 11 ++ src/storage/v2/storage.hpp | 10 +- tests/unit/cypher_main_visitor.cpp | 114 ++++++++++++++ tests/unit/interpreter.cpp | 149 ++++++++++++++++++ tests/unit/query_required_privileges.cpp | 5 + 22 files changed, 754 insertions(+), 100 deletions(-) create mode 100644 src/common/types.hpp diff --git a/src/auth/models.cpp b/src/auth/models.cpp index 54bf24da6..32b2ca291 100644 --- a/src/auth/models.cpp +++ b/src/auth/models.cpp @@ -37,7 +37,7 @@ const std::vector kPermissionsAll = { Permission::CONSTRAINT, Permission::DUMP, Permission::AUTH, Permission::REPLICATION, Permission::DURABILITY, Permission::READ_FILE, Permission::FREE_MEMORY, Permission::TRIGGER, Permission::CONFIG, Permission::STREAM, Permission::MODULE_READ, Permission::MODULE_WRITE, - Permission::WEBSOCKET}; + Permission::WEBSOCKET, Permission::SCHEMA}; } // namespace std::string PermissionToString(Permission permission) { @@ -84,6 +84,8 @@ std::string PermissionToString(Permission permission) { return "MODULE_WRITE"; case Permission::WEBSOCKET: return "WEBSOCKET"; + case Permission::SCHEMA: + return "SCHEMA"; } } diff --git a/src/auth/models.hpp b/src/auth/models.hpp index 0f01c0a39..00c26464b 100644 --- a/src/auth/models.hpp +++ b/src/auth/models.hpp @@ -38,7 +38,8 @@ enum class Permission : uint64_t { STREAM = 1U << 17U, MODULE_READ = 1U << 18U, MODULE_WRITE = 1U << 19U, - WEBSOCKET = 1U << 20U + WEBSOCKET = 1U << 20U, + SCHEMA = 1U << 21U }; // clang-format on diff --git a/src/common/types.hpp b/src/common/types.hpp new file mode 100644 index 000000000..09a0aecf5 --- /dev/null +++ b/src/common/types.hpp @@ -0,0 +1,19 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include + +namespace memgraph::common { +enum class SchemaType : uint8_t { BOOL, INT, STRING, DATE, LOCALTIME, LOCALDATETIME, DURATION }; + +} // namespace memgraph::common diff --git a/src/glue/auth.cpp b/src/glue/auth.cpp index 7f05d8045..5d9ffbb84 100644 --- a/src/glue/auth.cpp +++ b/src/glue/auth.cpp @@ -57,6 +57,8 @@ auth::Permission PrivilegeToPermission(query::AuthQuery::Privilege privilege) { return auth::Permission::MODULE_WRITE; case query::AuthQuery::Privilege::WEBSOCKET: return auth::Permission::WEBSOCKET; + case query::AuthQuery::Privilege::SCHEMA: + return auth::Permission::SCHEMA; } } } // namespace memgraph::glue diff --git a/src/query/frontend/ast/ast.lcp b/src/query/frontend/ast/ast.lcp index 33d397754..de24ae37a 100644 --- a/src/query/frontend/ast/ast.lcp +++ b/src/query/frontend/ast/ast.lcp @@ -17,6 +17,7 @@ #include #include +#include "common/types.hpp" #include "query/frontend/ast/ast_visitor.hpp" #include "query/frontend/semantic/symbol.hpp" #include "query/interpret/awesome_memgraph_functions.hpp" @@ -133,6 +134,15 @@ cpp<# } cpp<#)) + +(defun clone-schema-property-vector (source dest) + #>cpp + ${dest}.reserve(${source}.size()); + for (const auto &[property_ix, property_type]: ${source}) { + ${dest}.emplace_back(storage->GetPropertyIx(property_ix.name), property_type); + } + cpp<#) + ;; The following index structs serve as a decoupling point of AST from ;; concrete database types. All the names are collected in AstStorage, and can ;; be indexed through these instances. This means that we can create a vector @@ -2253,7 +2263,7 @@ cpp<# (lcp:define-enum privilege (create delete match merge set remove index stats auth constraint dump replication durability read_file free_memory trigger config stream module_read module_write - websocket) + websocket schema) (:serialize)) #>cpp AuthQuery() = default; @@ -2295,7 +2305,7 @@ const std::vector kPrivilegesAll = { AuthQuery::Privilege::FREE_MEMORY, AuthQuery::Privilege::TRIGGER, AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM, AuthQuery::Privilege::MODULE_READ, AuthQuery::Privilege::MODULE_WRITE, - AuthQuery::Privilege::WEBSOCKET}; + AuthQuery::Privilege::WEBSOCKET, AuthQuery::Privilege::SCHEMA}; cpp<# (lcp:define-class info-query (query) @@ -2668,5 +2678,38 @@ cpp<# (:serialize (:slk)) (:clone)) +(lcp:define-class schema-query (query) + ((action "Action" :scope :public) + (label "LabelIx" :scope :public + :slk-load (lambda (member) + #>cpp + slk::Load(&self->${member}, reader, storage); + cpp<#) + :clone (lambda (source dest) + #>cpp + ${dest} = storage->GetLabelIx(${source}.name); + cpp<#)) + (schema_type_map "std::vector>" + :slk-save #'slk-save-property-map + :slk-load #'slk-load-property-map + :clone #'clone-schema-property-vector + :scope :public)) + + (:public + (lcp:define-enum action + (create-schema drop-schema show-schema show-schemas) + (:serialize)) + #>cpp + SchemaQuery() = default; + + DEFVISITABLE(QueryVisitor); + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + (lcp:pop-namespace) ;; namespace query (lcp:pop-namespace) ;; namespace memgraph diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index 0e4a6012c..307b96907 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -94,6 +94,7 @@ class StreamQuery; class SettingQuery; class VersionQuery; class Foreach; +class SchemaQuery; using TreeCompositeVisitor = utils::CompositeVisitor< SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator, @@ -125,9 +126,9 @@ class ExpressionVisitor None, ParameterLookup, Identifier, PrimitiveLiteral, RegexMatch> {}; template -class QueryVisitor - : public utils::Visitor {}; +class QueryVisitor : public utils::Visitor {}; } // namespace memgraph::query diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index f4a269dfd..2e73b2fbb 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -27,6 +28,7 @@ #include +#include "common/types.hpp" #include "query/exceptions.hpp" #include "query/frontend/ast/ast.hpp" #include "query/frontend/ast/ast_visitor.hpp" @@ -1355,6 +1357,7 @@ antlrcpp::Any CypherMainVisitor::visitPrivilege(MemgraphCypher::PrivilegeContext if (ctx->MODULE_READ()) return AuthQuery::Privilege::MODULE_READ; if (ctx->MODULE_WRITE()) return AuthQuery::Privilege::MODULE_WRITE; if (ctx->WEBSOCKET()) return AuthQuery::Privilege::WEBSOCKET; + if (ctx->SCHEMA()) return AuthQuery::Privilege::SCHEMA; LOG_FATAL("Should not get here - unknown privilege!"); } @@ -2353,6 +2356,93 @@ antlrcpp::Any CypherMainVisitor::visitForeach(MemgraphCypher::ForeachContext *ct return for_each; } +antlrcpp::Any CypherMainVisitor::visitSchemaQuery(MemgraphCypher::SchemaQueryContext *ctx) { + MG_ASSERT(ctx->children.size() == 1, "SchemaQuery should have exactly one child!"); + auto *schema_query = ctx->children[0]->accept(this).as(); + query_ = schema_query; + return schema_query; +} + +antlrcpp::Any CypherMainVisitor::visitShowSchema(MemgraphCypher::ShowSchemaContext *ctx) { + auto *schema_query = storage_->Create(); + schema_query->action_ = SchemaQuery::Action::SHOW_SCHEMA; + schema_query->label_ = AddLabel(ctx->labelName()->accept(this)); + query_ = schema_query; + return schema_query; +} + +antlrcpp::Any CypherMainVisitor::visitShowSchemas(MemgraphCypher::ShowSchemasContext * /*ctx*/) { + auto *schema_query = storage_->Create(); + schema_query->action_ = SchemaQuery::Action::SHOW_SCHEMAS; + query_ = schema_query; + return schema_query; +} + +antlrcpp::Any CypherMainVisitor::visitPropertyType(MemgraphCypher::PropertyTypeContext *ctx) { + MG_ASSERT(ctx->symbolicName()); + const auto property_type = utils::ToLowerCase(ctx->symbolicName()->accept(this).as()); + if (property_type == "bool") { + return common::SchemaType::BOOL; + } + if (property_type == "string") { + return common::SchemaType::STRING; + } + if (property_type == "integer") { + return common::SchemaType::INT; + } + if (property_type == "date") { + return common::SchemaType::DATE; + } + if (property_type == "duration") { + return common::SchemaType::DURATION; + } + if (property_type == "localdatetime") { + return common::SchemaType::LOCALDATETIME; + } + if (property_type == "localtime") { + return common::SchemaType::LOCALTIME; + } + throw SyntaxException("Property type must be one of the supported types!"); +} + +/** + * @return Schema* + */ +antlrcpp::Any CypherMainVisitor::visitSchemaPropertyMap(MemgraphCypher::SchemaPropertyMapContext *ctx) { + std::vector> schema_property_map; + for (auto *property_key_pair : ctx->propertyKeyTypePair()) { + PropertyIx key = property_key_pair->propertyKeyName()->accept(this); + common::SchemaType type = property_key_pair->propertyType()->accept(this); + if (std::ranges::find_if(schema_property_map, [&key](const auto &elem) { return elem.first == key; }) != + schema_property_map.end()) { + throw SemanticException("Same property name can't appear twice in a schema map."); + } + schema_property_map.emplace_back(key, type); + } + return schema_property_map; +} + +antlrcpp::Any CypherMainVisitor::visitCreateSchema(MemgraphCypher::CreateSchemaContext *ctx) { + auto *schema_query = storage_->Create(); + schema_query->action_ = SchemaQuery::Action::CREATE_SCHEMA; + schema_query->label_ = AddLabel(ctx->labelName()->accept(this)); + schema_query->schema_type_map_ = + ctx->schemaPropertyMap()->accept(this).as>>(); + query_ = schema_query; + return schema_query; +} + +/** + * @return Schema* + */ +antlrcpp::Any CypherMainVisitor::visitDropSchema(MemgraphCypher::DropSchemaContext *ctx) { + auto *schema_query = storage_->Create(); + schema_query->action_ = SchemaQuery::Action::DROP_SCHEMA; + schema_query->label_ = AddLabel(ctx->labelName()->accept(this)); + query_ = schema_query; + return schema_query; +} + LabelIx CypherMainVisitor::AddLabel(const std::string &name) { return storage_->GetLabelIx(name); } PropertyIx CypherMainVisitor::AddProperty(const std::string &name) { return storage_->GetPropertyIx(name); } diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index 2a6b8ff5e..a4216e702 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -849,6 +849,41 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { */ antlrcpp::Any visitForeach(MemgraphCypher::ForeachContext *ctx) override; + /** + * @return Schema* + */ + antlrcpp::Any visitPropertyType(MemgraphCypher::PropertyTypeContext *ctx) override; + + /** + * @return Schema* + */ + antlrcpp::Any visitSchemaPropertyMap(MemgraphCypher::SchemaPropertyMapContext *ctx) override; + + /** + * @return Schema* + */ + antlrcpp::Any visitSchemaQuery(MemgraphCypher::SchemaQueryContext *ctx) override; + + /** + * @return Schema* + */ + antlrcpp::Any visitShowSchema(MemgraphCypher::ShowSchemaContext *ctx) override; + + /** + * @return Schema* + */ + antlrcpp::Any visitShowSchemas(MemgraphCypher::ShowSchemasContext *ctx) override; + + /** + * @return Schema* + */ + antlrcpp::Any visitCreateSchema(MemgraphCypher::CreateSchemaContext *ctx) override; + + /** + * @return Schema* + */ + antlrcpp::Any visitDropSchema(MemgraphCypher::DropSchemaContext *ctx) override; + public: Query *query() { return query_; } const static std::string kAnonPrefix; diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 index b412a474a..df51de704 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 @@ -46,10 +46,10 @@ memgraphCypherKeyword : cypherKeyword | DROP | DUMP | EXECUTE - | FOR - | FOREACH | FREE | FROM + | FOR + | FOREACH | GLOBAL | GRANT | HEADER @@ -76,6 +76,8 @@ memgraphCypherKeyword : cypherKeyword | ROLE | ROLES | QUOTE + | SCHEMA + | SCHEMAS | SESSION | SETTING | SETTINGS @@ -122,6 +124,7 @@ query : cypherQuery | streamQuery | settingQuery | versionQuery + | schemaQuery ; authQuery : createRole @@ -192,6 +195,12 @@ settingQuery : setSetting | showSettings ; +schemaQuery : showSchema + | showSchemas + | createSchema + | dropSchema + ; + loadCsv : LOAD CSV FROM csvFile ( WITH | NO ) HEADER ( IGNORE BAD ) ? ( DELIMITER delimiter ) ? @@ -254,6 +263,7 @@ privilege : CREATE | MODULE_READ | MODULE_WRITE | WEBSOCKET + | SCHEMA ; privilegeList : privilege ( ',' privilege )* ; @@ -374,3 +384,17 @@ showSetting : SHOW DATABASE SETTING settingName ; showSettings : SHOW DATABASE SETTINGS ; versionQuery : SHOW VERSION ; + +showSchema : SHOW SCHEMA ON ':' labelName ; + +showSchemas : SHOW SCHEMAS ; + +propertyType : symbolicName ; + +propertyKeyTypePair : propertyKeyName propertyType ; + +schemaPropertyMap : '(' propertyKeyTypePair ( ',' propertyKeyTypePair )* ')' ; + +createSchema : CREATE SCHEMA ON ':' labelName schemaPropertyMap ; + +dropSchema : DROP SCHEMA ON ':' labelName ; diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 index 55e5d53a2..869141033 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 @@ -89,6 +89,8 @@ REVOKE : R E V O K E ; ROLE : R O L E ; ROLES : R O L E S ; QUOTE : Q U O T E ; +SCHEMA : S C H E M A ; +SCHEMAS : S C H E M A S ; SERVICE_URL : S E R V I C E UNDERSCORE U R L ; SESSION : S E S S I O N ; SETTING : S E T T I N G ; diff --git a/src/query/frontend/semantic/required_privileges.cpp b/src/query/frontend/semantic/required_privileges.cpp index e8dbd21e5..160004ac2 100644 --- a/src/query/frontend/semantic/required_privileges.cpp +++ b/src/query/frontend/semantic/required_privileges.cpp @@ -80,6 +80,8 @@ class PrivilegeExtractor : public QueryVisitor, public HierarchicalTreeVis void Visit(VersionQuery & /*version_query*/) override { AddPrivilege(AuthQuery::Privilege::STATS); } + void Visit(SchemaQuery & /*schema_query*/) override { AddPrivilege(AuthQuery::Privilege::SCHEMA); } + bool PreVisit(Create & /*unused*/) override { AddPrivilege(AuthQuery::Privilege::CREATE); return false; diff --git a/src/query/frontend/stripped_lexer_constants.hpp b/src/query/frontend/stripped_lexer_constants.hpp index 42b7b4aeb..be516aee6 100644 --- a/src/query/frontend/stripped_lexer_constants.hpp +++ b/src/query/frontend/stripped_lexer_constants.hpp @@ -204,8 +204,9 @@ const trie::Trie kKeywords = {"union", "pulsar", "service_url", "version", - "websocket" - "foreach"}; + "websocket", + "foreach", + "schema"}; // Unicode codepoints that are allowed at the start of the unescaped name. const std::bitset kUnescapedNameAllowedStarts( diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index a8bb42dc9..fbb1cbc55 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -44,6 +44,7 @@ #include "query/trigger.hpp" #include "query/typed_value.hpp" #include "storage/v2/property_value.hpp" +#include "storage/v2/schemas.hpp" #include "utils/algorithm.hpp" #include "utils/csv_parsing.hpp" #include "utils/event_counter.hpp" @@ -891,6 +892,102 @@ Callback HandleSettingQuery(SettingQuery *setting_query, const Parameters ¶m } } +Callback HandleSchemaQuery(SchemaQuery *schema_query, InterpreterContext *interpreter_context, + std::vector *notifications) { + Callback callback; + switch (schema_query->action_) { + case SchemaQuery::Action::SHOW_SCHEMAS: { + callback.header = {"label", "primary_key"}; + callback.fn = [interpreter_context]() { + auto *db = interpreter_context->db; + auto schemas_info = db->ListAllSchemas(); + std::vector> results; + results.reserve(schemas_info.schemas.size()); + + for (const auto &[label_id, schema_types] : schemas_info.schemas) { + std::vector schema_info_row; + schema_info_row.reserve(3); + + schema_info_row.emplace_back(db->LabelToName(label_id)); + std::vector primary_key_properties; + primary_key_properties.reserve(schema_types.size()); + std::transform(schema_types.begin(), schema_types.end(), std::back_inserter(primary_key_properties), + [&db](const auto &schema_type) { + return db->PropertyToName(schema_type.property_id) + + "::" + storage::SchemaTypeToString(schema_type.type); + }); + + schema_info_row.emplace_back(utils::Join(primary_key_properties, ", ")); + results.push_back(std::move(schema_info_row)); + } + return results; + }; + return callback; + } + case SchemaQuery::Action::SHOW_SCHEMA: { + callback.header = {"property_name", "property_type"}; + callback.fn = [interpreter_context, primary_label = schema_query->label_]() { + auto *db = interpreter_context->db; + const auto label = db->NameToLabel(primary_label.name); + const auto schema = db->GetSchema(label); + std::vector> results; + if (schema) { + for (const auto &schema_property : schema->second) { + std::vector schema_info_row; + schema_info_row.reserve(2); + schema_info_row.emplace_back(db->PropertyToName(schema_property.property_id)); + schema_info_row.emplace_back(storage::SchemaTypeToString(schema_property.type)); + results.push_back(std::move(schema_info_row)); + } + return results; + } + throw QueryException(fmt::format("Schema on label :{} not found!", primary_label.name)); + }; + return callback; + } + case SchemaQuery::Action::CREATE_SCHEMA: { + auto schema_type_map = schema_query->schema_type_map_; + if (schema_query->schema_type_map_.empty()) { + throw SyntaxException("One or more types have to be defined in schema definition."); + } + callback.fn = [interpreter_context, primary_label = schema_query->label_, + schema_type_map = std::move(schema_type_map)]() { + auto *db = interpreter_context->db; + const auto label = db->NameToLabel(primary_label.name); + std::vector schemas_types; + schemas_types.reserve(schema_type_map.size()); + for (const auto &schema_type : schema_type_map) { + auto property_id = db->NameToProperty(schema_type.first.name); + schemas_types.push_back({property_id, schema_type.second}); + } + if (!db->CreateSchema(label, schemas_types)) { + throw QueryException(fmt::format("Schema on label :{} already exists!", primary_label.name)); + } + return std::vector>{}; + }; + notifications->emplace_back(SeverityLevel::INFO, NotificationCode::CREATE_SCHEMA, + fmt::format("Create schema on label :{}", schema_query->label_.name)); + return callback; + } + case SchemaQuery::Action::DROP_SCHEMA: { + callback.fn = [interpreter_context, primary_label = schema_query->label_]() { + auto *db = interpreter_context->db; + const auto label = db->NameToLabel(primary_label.name); + + if (!db->DropSchema(label)) { + throw QueryException(fmt::format("Schema on label :{} does not exist!", primary_label.name)); + } + + return std::vector>{}; + }; + notifications->emplace_back(SeverityLevel::INFO, NotificationCode::DROP_SCHEMA, + fmt::format("Dropped schema on label :{}", schema_query->label_.name)); + return callback; + } + } + return callback; +} + // Struct for lazy pulling from a vector struct PullPlanVector { explicit PullPlanVector(std::vector> values) : values_(std::move(values)) {} @@ -2086,6 +2183,32 @@ PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_ RWType::NONE}; } +PreparedQuery PrepareSchemaQuery(ParsedQuery parsed_query, bool in_explicit_transaction, + InterpreterContext *interpreter_context, std::vector *notifications) { + if (in_explicit_transaction) { + throw ConstraintInMulticommandTxException(); + } + auto *schema_query = utils::Downcast(parsed_query.query); + MG_ASSERT(schema_query); + auto callback = HandleSchemaQuery(schema_query, interpreter_context, notifications); + + return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), + [handler = std::move(callback.fn), action = QueryHandlerResult::NOTHING, + pull_plan = std::shared_ptr(nullptr)]( + AnyStream *stream, std::optional n) mutable -> std::optional { + if (!pull_plan) { + auto results = handler(); + pull_plan = std::make_shared(std::move(results)); + } + + if (pull_plan->Pull(stream, n)) { + return action; + } + return std::nullopt; + }, + RWType::NONE}; +} + void Interpreter::BeginTransaction() { const auto prepared_query = PrepareTransactionQuery("BEGIN"); prepared_query.query_handler(nullptr, {}); @@ -2219,6 +2342,9 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, prepared_query = PrepareSettingQuery(std::move(parsed_query), in_explicit_transaction_, &*execution_db_accessor_); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareVersionQuery(std::move(parsed_query), in_explicit_transaction_); + } else if (utils::Downcast(parsed_query.query)) { + prepared_query = PrepareSchemaQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_, + &query_execution->notifications); } else { LOG_FATAL("Should not get here -- unknown query type!"); } diff --git a/src/query/metadata.cpp b/src/query/metadata.cpp index f4e8512fd..2e25ce8a4 100644 --- a/src/query/metadata.cpp +++ b/src/query/metadata.cpp @@ -38,6 +38,8 @@ constexpr std::string_view GetCodeString(const NotificationCode code) { return "CreateIndex"sv; case NotificationCode::CREATE_STREAM: return "CreateStream"sv; + case NotificationCode::CREATE_SCHEMA: + return "CreateSchema"sv; case NotificationCode::CHECK_STREAM: return "CheckStream"sv; case NotificationCode::CREATE_TRIGGER: @@ -48,6 +50,8 @@ constexpr std::string_view GetCodeString(const NotificationCode code) { return "DropReplica"sv; case NotificationCode::DROP_INDEX: return "DropIndex"sv; + case NotificationCode::DROP_SCHEMA: + return "DropSchema"sv; case NotificationCode::DROP_STREAM: return "DropStream"sv; case NotificationCode::DROP_TRIGGER: @@ -68,6 +72,10 @@ constexpr std::string_view GetCodeString(const NotificationCode code) { return "ReplicaPortWarning"sv; case NotificationCode::SET_REPLICA: return "SetReplica"sv; + case NotificationCode::SHOW_SCHEMA: + return "ShowSchema"sv; + case NotificationCode::SHOW_SCHEMAS: + return "ShowSchemas"sv; case NotificationCode::START_STREAM: return "StartStream"sv; case NotificationCode::START_ALL_STREAMS: @@ -114,4 +122,4 @@ std::string ExecutionStatsKeyToString(const ExecutionStats::Key key) { } } -} // namespace memgraph::query \ No newline at end of file +} // namespace memgraph::query diff --git a/src/query/metadata.hpp b/src/query/metadata.hpp index 67f784fa8..e557ca72e 100644 --- a/src/query/metadata.hpp +++ b/src/query/metadata.hpp @@ -26,12 +26,14 @@ enum class SeverityLevel : uint8_t { INFO, WARNING }; enum class NotificationCode : uint8_t { CREATE_CONSTRAINT, CREATE_INDEX, + CREATE_SCHEMA, CHECK_STREAM, CREATE_STREAM, CREATE_TRIGGER, DROP_CONSTRAINT, DROP_INDEX, DROP_REPLICA, + DROP_SCHEMA, DROP_STREAM, DROP_TRIGGER, EXISTANT_INDEX, @@ -42,6 +44,8 @@ enum class NotificationCode : uint8_t { REPLICA_PORT_WARNING, REGISTER_REPLICA, SET_REPLICA, + SHOW_SCHEMA, + SHOW_SCHEMAS, START_STREAM, START_ALL_STREAMS, STOP_STREAM, diff --git a/src/storage/v2/schemas.cpp b/src/storage/v2/schemas.cpp index 2e4fbbe3b..1bec8455a 100644 --- a/src/storage/v2/schemas.cpp +++ b/src/storage/v2/schemas.cpp @@ -26,14 +26,31 @@ SchemaViolation::SchemaViolation(ValidationStatus status, LabelId label, SchemaP PropertyValue violated_property_value) : status{status}, label{label}, violated_type{violated_type}, violated_property_value{violated_property_value} {} -bool Schemas::CreateSchema(const LabelId primary_label, const std::vector &schemas_types) { - return schemas_.insert({primary_label, schemas_types}).second; +Schemas::SchemasList Schemas::ListSchemas() const { + Schemas::SchemasList ret; + ret.reserve(schemas_.size()); + std::transform(schemas_.begin(), schemas_.end(), std::back_inserter(ret), + [](const auto &schema_property_type) { return schema_property_type; }); + return ret; } -bool Schemas::DeleteSchema(const LabelId primary_label) { - return schemas_.erase(primary_label); +std::optional Schemas::GetSchema(const LabelId primary_label) const { + if (auto schema_map = schemas_.find(primary_label); schema_map != schemas_.end()) { + return Schema{schema_map->first, schema_map->second}; + } + return std::nullopt; } +bool Schemas::CreateSchema(const LabelId primary_label, const std::vector &schemas_types) { + if (schemas_.contains(primary_label)) { + return false; + } + schemas_.emplace(primary_label, schemas_types); + return true; +} + +bool Schemas::DropSchema(const LabelId primary_label) { return schemas_.erase(primary_label); } + std::optional Schemas::ValidateVertex(const LabelId primary_label, const Vertex &vertex) { // TODO Check for multiple defined primary labels const auto schema = schemas_.find(primary_label); @@ -51,7 +68,7 @@ std::optional Schemas::ValidateVertex(const LabelId primary_lab // Property type check // TODO Can this be replaced with just property id check? if (auto vertex_property = vertex.properties.GetProperty(schema_type.property_id); - PropertyValueTypeToSchemaProperty(vertex_property) != schema_type.type) { + PropertyTypeToSchemaType(vertex_property) != schema_type.type) { return SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_PROPERTY_WRONG_TYPE, primary_label, schema_type, vertex_property); } @@ -62,13 +79,66 @@ std::optional Schemas::ValidateVertex(const LabelId primary_lab return std::nullopt; } -Schemas::SchemasList Schemas::ListSchemas() const { - Schemas::SchemasList ret; - ret.reserve(schemas_.size()); - for (const auto &[label_props, schema_property] : schemas_) { - ret.emplace_back(label_props, schema_property); +std::optional PropertyTypeToSchemaType(const PropertyValue &property_value) { + switch (property_value.type()) { + case PropertyValue::Type::Bool: { + return common::SchemaType::BOOL; + } + case PropertyValue::Type::Int: { + return common::SchemaType::INT; + } + case PropertyValue::Type::String: { + return common::SchemaType::STRING; + } + case PropertyValue::Type::TemporalData: { + switch (property_value.ValueTemporalData().type) { + case TemporalType::Date: { + return common::SchemaType::DATE; + } + case TemporalType::LocalDateTime: { + return common::SchemaType::LOCALDATETIME; + } + case TemporalType::LocalTime: { + return common::SchemaType::LOCALTIME; + } + case TemporalType::Duration: { + return common::SchemaType::DURATION; + } + } + } + case PropertyValue::Type::Double: + case PropertyValue::Type::Null: + case PropertyValue::Type::Map: + case PropertyValue::Type::List: { + return std::nullopt; + } + } +} + +std::string SchemaTypeToString(const common::SchemaType type) { + switch (type) { + case common::SchemaType::BOOL: { + return "Bool"; + } + case common::SchemaType::INT: { + return "Integer"; + } + case common::SchemaType::STRING: { + return "String"; + } + case common::SchemaType::DATE: { + return "Date"; + } + case common::SchemaType::LOCALTIME: { + return "LocalTime"; + } + case common::SchemaType::LOCALDATETIME: { + return "LocalDateTime"; + } + case common::SchemaType::DURATION: { + return "Duration"; + } } - return ret; } } // namespace memgraph::storage diff --git a/src/storage/v2/schemas.hpp b/src/storage/v2/schemas.hpp index 113707069..898288b6b 100644 --- a/src/storage/v2/schemas.hpp +++ b/src/storage/v2/schemas.hpp @@ -17,6 +17,7 @@ #include #include +#include "common/types.hpp" #include "storage/v2/id_types.hpp" #include "storage/v2/indices.hpp" #include "storage/v2/property_value.hpp" @@ -32,10 +33,8 @@ class SchemaViolationException : public utils::BasicException { }; struct SchemaProperty { - enum class Type : uint8_t { Bool, Int, Double, String, Date, LocalTime, LocalDateTime, Duration }; - - Type type; PropertyId property_id; + common::SchemaType type; }; struct SchemaViolation { @@ -63,8 +62,9 @@ struct SchemaViolation { /// Schema can be mapped under only one label => primary label class Schemas { public: + using Schema = std::pair>; using SchemasMap = std::unordered_map>; - using SchemasList = std::vector>>; + using SchemasList = std::vector; Schemas() = default; Schemas(const Schemas &) = delete; @@ -73,83 +73,26 @@ class Schemas { Schemas &operator=(Schemas &&) = delete; ~Schemas() = default; + [[nodiscard]] SchemasList ListSchemas() const; + + [[nodiscard]] std::optional GetSchema(LabelId primary_label) const; + + // Returns true if it was successfully created or false if the schema + // already exists [[nodiscard]] bool CreateSchema(LabelId label, const std::vector &schemas_types); - [[nodiscard]] bool DeleteSchema(LabelId label); + // Returns true if it was successfully dropped or false if the schema + // does not exist + [[nodiscard]] bool DropSchema(LabelId label); [[nodiscard]] std::optional ValidateVertex(LabelId primary_label, const Vertex &vertex); - [[nodiscard]] SchemasList ListSchemas() const; - private: SchemasMap schemas_; }; -inline std::optional PropertyValueTypeToSchemaProperty(const PropertyValue &property_value) { - switch (property_value.type()) { - case PropertyValue::Type::Bool: { - return SchemaProperty::Type::Bool; - } - case PropertyValue::Type::Int: { - return SchemaProperty::Type::Int; - } - case PropertyValue::Type::Double: { - return SchemaProperty::Type::Double; - } - case PropertyValue::Type::String: { - return SchemaProperty::Type::String; - } - case PropertyValue::Type::TemporalData: { - switch (property_value.ValueTemporalData().type) { - case TemporalType::Date: { - return SchemaProperty::Type::Date; - } - case TemporalType::LocalDateTime: { - return SchemaProperty::Type::LocalDateTime; - } - case TemporalType::LocalTime: { - return SchemaProperty::Type::LocalTime; - } - case TemporalType::Duration: { - return SchemaProperty::Type::Duration; - } - } - } - case PropertyValue::Type::Null: - case PropertyValue::Type::Map: - case PropertyValue::Type::List: { - return std::nullopt; - } - } -} +std::optional PropertyTypeToSchemaType(const PropertyValue &property_value); -inline std::string SchemaPropertyToString(const SchemaProperty::Type type) { - switch (type) { - case SchemaProperty::Type::Bool: { - return "Bool"; - } - case SchemaProperty::Type::Int: { - return "Integer"; - } - case SchemaProperty::Type::Double: { - return "Double"; - } - case SchemaProperty::Type::String: { - return "String"; - } - case SchemaProperty::Type::Date: { - return "Date"; - } - case SchemaProperty::Type::LocalTime: { - return "LocalTime"; - } - case SchemaProperty::Type::LocalDateTime: { - return "LocalDateTime"; - } - case SchemaProperty::Type::Duration: { - return "Duration"; - } - } -} +std::string SchemaTypeToString(common::SchemaType type); } // namespace memgraph::storage diff --git a/src/storage/v2/storage.cpp b/src/storage/v2/storage.cpp index 5b1f3084b..36f306a62 100644 --- a/src/storage/v2/storage.cpp +++ b/src/storage/v2/storage.cpp @@ -1241,6 +1241,17 @@ SchemasInfo Storage::ListAllSchemas() const { return {schemas_.ListSchemas()}; } +std::optional Storage::GetSchema(const LabelId primary_label) const { + std::shared_lock storage_guard_(main_lock_); + return schemas_.GetSchema(primary_label); +} + +bool Storage::CreateSchema(const LabelId primary_label, const std::vector &schemas_types) { + return schemas_.CreateSchema(primary_label, schemas_types); +} + +bool Storage::DropSchema(const LabelId primary_label) { return schemas_.DropSchema(primary_label); } + StorageInfo Storage::GetInfo() const { auto vertex_count = vertices_.size(); auto edge_count = edge_count_.load(std::memory_order_acquire); diff --git a/src/storage/v2/storage.hpp b/src/storage/v2/storage.hpp index 4839ff3ca..7f8ac6ddc 100644 --- a/src/storage/v2/storage.hpp +++ b/src/storage/v2/storage.hpp @@ -414,12 +414,14 @@ class Storage final { ConstraintsInfo ListAllConstraints() const; - bool CreateSchema(LabelId primary_label, std::vector &schemas_types); - - bool DeleteSchema(LabelId primary_label); - SchemasInfo ListAllSchemas() const; + std::optional GetSchema(LabelId primary_label) const; + + bool CreateSchema(LabelId primary_label, const std::vector &schemas_types); + + bool DropSchema(LabelId primary_label); + StorageInfo GetInfo() const; bool LockPath(); diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index 39663e2ea..f0d5a6cef 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -32,6 +32,7 @@ #include #include +#include "common/types.hpp" #include "query/exceptions.hpp" #include "query/frontend/ast/ast.hpp" #include "query/frontend/ast/cypher_main_visitor.hpp" @@ -2213,6 +2214,8 @@ TEST_P(CypherMainVisitorTest, GrantPrivilege) { {AuthQuery::Privilege::MODULE_READ}); check_auth_query(&ast_generator, "GRANT MODULE_WRITE TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, {AuthQuery::Privilege::MODULE_WRITE}); + check_auth_query(&ast_generator, "GRANT SCHEMA TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::SCHEMA}); } TEST_P(CypherMainVisitorTest, DenyPrivilege) { @@ -2253,6 +2256,8 @@ TEST_P(CypherMainVisitorTest, DenyPrivilege) { {AuthQuery::Privilege::MODULE_READ}); check_auth_query(&ast_generator, "DENY MODULE_WRITE TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, {AuthQuery::Privilege::MODULE_WRITE}); + check_auth_query(&ast_generator, "DENY SCHEMA TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::SCHEMA}); } TEST_P(CypherMainVisitorTest, RevokePrivilege) { @@ -2295,6 +2300,8 @@ TEST_P(CypherMainVisitorTest, RevokePrivilege) { {}, {AuthQuery::Privilege::MODULE_READ}); check_auth_query(&ast_generator, "REVOKE MODULE_WRITE FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, {AuthQuery::Privilege::MODULE_WRITE}); + check_auth_query(&ast_generator, "REVOKE SCHEMA FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::SCHEMA}); } TEST_P(CypherMainVisitorTest, ShowPrivileges) { @@ -4211,3 +4218,110 @@ TEST_P(CypherMainVisitorTest, Foreach) { ASSERT_TRUE(dynamic_cast(*++clauses.begin())); } } + +TEST_P(CypherMainVisitorTest, TestShowSchemas) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("SHOW SCHEMAS")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::SHOW_SCHEMAS); +} + +TEST_P(CypherMainVisitorTest, TestShowSchema) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("SHOW SCHEMA ON label"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("SHOW SCHEMA :label"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("SHOW SCHEMA label"), SyntaxException); + + auto *query = dynamic_cast(ast_generator.ParseQuery("SHOW SCHEMA ON :label")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::SHOW_SCHEMA); + EXPECT_EQ(query->label_, ast_generator.Label("label")); +} + +void AssertSchemaPropertyMap(auto &schema_property_map, + std::vector> properties_type, + auto &ast_generator) { + EXPECT_EQ(schema_property_map.size(), properties_type.size()); + for (size_t i{0}; i < schema_property_map.size(); ++i) { + // Assert PropertyId + EXPECT_EQ(schema_property_map[i].first, ast_generator.Prop(properties_type[i].first)); + // Assert Property Type + EXPECT_EQ(schema_property_map[i].second, properties_type[i].second); + } +} + +TEST_P(CypherMainVisitorTest, TestCreateSchema) { + { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label()"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(123 INTEGER)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(name TYPE)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(name, age)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(name, DURATION)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON label(name INTEGER)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(name INTEGER, name INTEGER)"), SemanticException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(name INTEGER, name STRING)"), SemanticException); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("CREATE SCHEMA ON :label1(name STRING)")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::CREATE_SCHEMA); + EXPECT_EQ(query->label_, ast_generator.Label("label1")); + AssertSchemaPropertyMap(query->schema_type_map_, {{"name", memgraph::common::SchemaType::STRING}}, ast_generator); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("CREATE SCHEMA ON :label2(name string)")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::CREATE_SCHEMA); + EXPECT_EQ(query->label_, ast_generator.Label("label2")); + AssertSchemaPropertyMap(query->schema_type_map_, {{"name", memgraph::common::SchemaType::STRING}}, ast_generator); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast( + ast_generator.ParseQuery("CREATE SCHEMA ON :label3(first_name STRING, last_name STRING)")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::CREATE_SCHEMA); + EXPECT_EQ(query->label_, ast_generator.Label("label3")); + AssertSchemaPropertyMap( + query->schema_type_map_, + {{"first_name", memgraph::common::SchemaType::STRING}, {"last_name", memgraph::common::SchemaType::STRING}}, + ast_generator); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast( + ast_generator.ParseQuery("CREATE SCHEMA ON :label4(name STRING, age INTEGER, dur DURATION, birthday1 " + "LOCALDATETIME, birthday2 DATE, some_time LOCALTIME, speaks_truth BOOL)")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::CREATE_SCHEMA); + EXPECT_EQ(query->label_, ast_generator.Label("label4")); + AssertSchemaPropertyMap(query->schema_type_map_, + { + {"name", memgraph::common::SchemaType::STRING}, + {"age", memgraph::common::SchemaType::INT}, + {"dur", memgraph::common::SchemaType::DURATION}, + {"birthday1", memgraph::common::SchemaType::LOCALDATETIME}, + {"birthday2", memgraph::common::SchemaType::DATE}, + {"some_time", memgraph::common::SchemaType::LOCALTIME}, + {"speaks_truth", memgraph::common::SchemaType::BOOL}, + }, + ast_generator); + } +} + +TEST_P(CypherMainVisitorTest, TestDropSchema) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("DROP SCHEMA"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("DROP SCHEMA ON label"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("DROP SCHEMA :label"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("DROP SCHEMA ON :label()"), SyntaxException); + + auto *query = dynamic_cast(ast_generator.ParseQuery("DROP SCHEMA ON :label")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::DROP_SCHEMA); + EXPECT_EQ(query->label_, ast_generator.Label("label")); +} diff --git a/tests/unit/interpreter.cpp b/tests/unit/interpreter.cpp index f5a3e03b3..466079578 100644 --- a/tests/unit/interpreter.cpp +++ b/tests/unit/interpreter.cpp @@ -10,8 +10,10 @@ // licenses/APL.txt. #include +#include #include #include +#include #include "communication/bolt/v1/value.hpp" #include "communication/result_stream_faker.hpp" @@ -38,6 +40,11 @@ auto ToEdgeList(const memgraph::communication::bolt::Value &v) { list.push_back(x.ValueEdge()); } return list; +} + +auto StringToUnorderedSet(const std::string &element) { + const auto element_split = memgraph::utils::Split(element, ", "); + return std::unordered_set(element_split.begin(), element_split.end()); }; struct InterpreterFaker { @@ -1465,3 +1472,145 @@ TEST_F(InterpreterTest, LoadCsvClauseNotification) { "conversion functions such as ToInteger, ToFloat, ToBoolean etc."); ASSERT_EQ(notification["description"].ValueString(), ""); } + +TEST_F(InterpreterTest, CreateSchemaMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING, age INTEGER)"), + memgraph::query::ConstraintInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, ShowSchemasMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("SHOW SCHEMAS"), memgraph::query::ConstraintInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, ShowSchemaMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("SHOW SCHEMA ON :label"), memgraph::query::ConstraintInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, DropSchemaMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("DROP SCHEMA ON :label"), memgraph::query::ConstraintInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, SchemaTestCreateAndShow) { + // Empty schema type map should result with syntax exception. + ASSERT_THROW(Interpret("CREATE SCHEMA ON :label();"), memgraph::query::SyntaxException); + + // Duplicate properties are should also cause an exception + ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING, name STRING);"), memgraph::query::SemanticException); + ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING, name INTEGER);"), memgraph::query::SemanticException); + + { + // Cannot create same schema twice + Interpret("CREATE SCHEMA ON :label(name STRING, age INTEGER)"); + ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING);"), memgraph::query::QueryException); + } + // Show schema + { + auto stream = Interpret("SHOW SCHEMA ON :label"); + ASSERT_EQ(stream.GetHeader().size(), 2U); + const auto &header = stream.GetHeader(); + ASSERT_EQ(header[0], "property_name"); + ASSERT_EQ(header[1], "property_type"); + ASSERT_EQ(stream.GetResults().size(), 2U); + std::unordered_map result_table{{"age", "Integer"}, {"name", "String"}}; + + const auto &result = stream.GetResults().front(); + ASSERT_EQ(result.size(), 2U); + const auto key1 = result[0].ValueString(); + ASSERT_TRUE(result_table.contains(key1)); + ASSERT_EQ(result[1].ValueString(), result_table[key1]); + + const auto &result2 = stream.GetResults().front(); + ASSERT_EQ(result2.size(), 2U); + const auto key2 = result2[0].ValueString(); + ASSERT_TRUE(result_table.contains(key2)); + ASSERT_EQ(result[1].ValueString(), result_table[key2]); + } + // Create Another Schema + Interpret("CREATE SCHEMA ON :label2(place STRING, dur DURATION)"); + + // Show schemas + { + auto stream = Interpret("SHOW SCHEMAS"); + ASSERT_EQ(stream.GetHeader().size(), 2U); + const auto &header = stream.GetHeader(); + ASSERT_EQ(header[0], "label"); + ASSERT_EQ(header[1], "primary_key"); + ASSERT_EQ(stream.GetResults().size(), 2U); + std::unordered_map> result_table{ + {"label", {"name::String", "age::Integer"}}, {"label2", {"place::String", "dur::Duration"}}}; + + const auto &result = stream.GetResults().front(); + ASSERT_EQ(result.size(), 2U); + const auto key1 = result[0].ValueString(); + ASSERT_TRUE(result_table.contains(key1)); + const auto primary_key_split = StringToUnorderedSet(result[1].ValueString()); + ASSERT_EQ(primary_key_split.size(), 2); + ASSERT_TRUE(primary_key_split == result_table[key1]) << "actual value is: " << result[1].ValueString(); + + const auto &result2 = stream.GetResults().front(); + ASSERT_EQ(result2.size(), 2U); + const auto key2 = result2[0].ValueString(); + ASSERT_TRUE(result_table.contains(key2)); + const auto primary_key_split2 = StringToUnorderedSet(result2[1].ValueString()); + ASSERT_EQ(primary_key_split2.size(), 2); + ASSERT_TRUE(primary_key_split2 == result_table[key2]) << "Real value is: " << result[1].ValueString(); + } +} + +TEST_F(InterpreterTest, SchemaTestCreateDropAndShow) { + Interpret("CREATE SCHEMA ON :label(name STRING, age INTEGER)"); + // Wrong syntax for dropping schema. + ASSERT_THROW(Interpret("DROP SCHEMA ON :label();"), memgraph::query::SyntaxException); + // Cannot drop non existant schema. + ASSERT_THROW(Interpret("DROP SCHEMA ON :label1;"), memgraph::query::QueryException); + + // Create Schema and Drop + auto get_number_of_schemas = [this]() { + auto stream = Interpret("SHOW SCHEMAS"); + return stream.GetResults().size(); + }; + + ASSERT_EQ(get_number_of_schemas(), 1); + Interpret("CREATE SCHEMA ON :label1(name STRING, age INTEGER)"); + ASSERT_EQ(get_number_of_schemas(), 2); + Interpret("CREATE SCHEMA ON :label2(name STRING, sex BOOL)"); + ASSERT_EQ(get_number_of_schemas(), 3); + Interpret("DROP SCHEMA ON :label1"); + ASSERT_EQ(get_number_of_schemas(), 2); + Interpret("CREATE SCHEMA ON :label3(name STRING, birthday LOCALDATETIME)"); + ASSERT_EQ(get_number_of_schemas(), 3); + Interpret("DROP SCHEMA ON :label2"); + ASSERT_EQ(get_number_of_schemas(), 2); + Interpret("CREATE SCHEMA ON :label4(name STRING, age DURATION)"); + ASSERT_EQ(get_number_of_schemas(), 3); + Interpret("DROP SCHEMA ON :label3"); + ASSERT_EQ(get_number_of_schemas(), 2); + Interpret("DROP SCHEMA ON :label"); + ASSERT_EQ(get_number_of_schemas(), 1); + + // Show schemas + auto stream = Interpret("SHOW SCHEMAS"); + ASSERT_EQ(stream.GetHeader().size(), 2U); + const auto &header = stream.GetHeader(); + ASSERT_EQ(header[0], "label"); + ASSERT_EQ(header[1], "primary_key"); + ASSERT_EQ(stream.GetResults().size(), 1U); + std::unordered_map> result_table{ + {"label4", {"name::String", "age::Duration"}}}; + + const auto &result = stream.GetResults().front(); + ASSERT_EQ(result.size(), 2U); + const auto key1 = result[0].ValueString(); + ASSERT_TRUE(result_table.contains(key1)); + const auto primary_key_split = StringToUnorderedSet(result[1].ValueString()); + ASSERT_EQ(primary_key_split.size(), 2); + ASSERT_TRUE(primary_key_split == result_table[key1]); +} diff --git a/tests/unit/query_required_privileges.cpp b/tests/unit/query_required_privileges.cpp index ad21b10c4..4aab492e1 100644 --- a/tests/unit/query_required_privileges.cpp +++ b/tests/unit/query_required_privileges.cpp @@ -192,6 +192,11 @@ TEST_F(TestPrivilegeExtractor, ShowVersion) { EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::STATS)); } +TEST_F(TestPrivilegeExtractor, SchemaQuery) { + auto *query = storage.Create(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::SCHEMA)); +} + TEST_F(TestPrivilegeExtractor, CallProcedureQuery) { { auto *query = QUERY(SINGLE_QUERY(CALL_PROCEDURE("mg.get_module_files"))); From 462daf3a2bf81dfab1e7093114be0faf82add594 Mon Sep 17 00:00:00 2001 From: Jure Bajic Date: Fri, 29 Jul 2022 13:38:17 +0200 Subject: [PATCH 3/5] Enforce schema on vertex creation - Separating schema definition from schema validation - Updating vertex_accessor and db_accessors with necessary methods - Adding a primary label to Vertex - Adding schema tests - Updating existing tests for storage v3, and deprecating old: - interpreter => interpreter_v2 - query_plan_accumulate_aggregate => storage_v3_query_plan_accumulate_aggregate - query_plan_create_set_remove_delete => storage_v3_query_plan_create_set_remove_delete - query_plan_bag_semantics => storage_v3_query_plan_bag_semantics - query_plan_edge_cases => storage_v3_query_plan_edge_cases - query_plan_v2_create_set_remove_delete => storage_v3_query_plan_v2_create_set_remove_delete - query_plan_match_filter_return => storage_v3_query_plan_match_filter_return --- src/query/common.hpp | 85 +- src/query/db_accessor.hpp | 41 +- src/query/exceptions.hpp | 8 + src/query/plan/operator.cpp | 173 +- src/storage/v2/CMakeLists.txt | 1 + src/storage/v2/constraints.cpp | 11 +- src/storage/v2/constraints.hpp | 4 +- src/storage/v2/durability/snapshot.cpp | 12 +- src/storage/v2/durability/snapshot.hpp | 6 +- src/storage/v2/edge_accessor.cpp | 5 +- src/storage/v2/edge_accessor.hpp | 6 +- src/storage/v2/indices.cpp | 23 +- src/storage/v2/indices.hpp | 33 +- .../v2/replication/replication_server.cpp | 10 +- src/storage/v2/schema_validator.cpp | 106 + src/storage/v2/schema_validator.hpp | 69 + src/storage/v2/schemas.cpp | 48 +- src/storage/v2/schemas.hpp | 34 +- src/storage/v2/storage.cpp | 108 +- src/storage/v2/storage.hpp | 27 +- src/storage/v2/vertex.hpp | 27 +- src/storage/v2/vertex_accessor.cpp | 138 +- src/storage/v2/vertex_accessor.hpp | 46 +- tests/unit/CMakeLists.txt | 99 +- tests/unit/interpreter.cpp | 149 -- tests/unit/interpreter_v2.cpp | 1636 +++++++++++++ tests/unit/query_common.hpp | 4 +- tests/unit/query_plan_bag_semantics.cpp | 5 - tests/unit/query_plan_common.hpp | 10 + ...ery_v2_query_plan_accumulate_aggregate.cpp | 631 +++++ .../query_v2_query_plan_bag_semantics.cpp | 309 +++ ...v2_query_plan_create_set_remove_delete.cpp | 1095 +++++++++ tests/unit/query_v2_query_plan_edge_cases.cpp | 115 + ...uery_v2_query_plan_match_filter_return.cpp | 2062 +++++++++++++++++ ...query_plan_v2_create_set_remove_delete.cpp | 146 ++ tests/unit/storage_v3_schema.cpp | 294 +++ 36 files changed, 7120 insertions(+), 456 deletions(-) create mode 100644 src/storage/v2/schema_validator.cpp create mode 100644 src/storage/v2/schema_validator.hpp create mode 100644 tests/unit/interpreter_v2.cpp create mode 100644 tests/unit/query_v2_query_plan_accumulate_aggregate.cpp create mode 100644 tests/unit/query_v2_query_plan_bag_semantics.cpp create mode 100644 tests/unit/query_v2_query_plan_create_set_remove_delete.cpp create mode 100644 tests/unit/query_v2_query_plan_edge_cases.cpp create mode 100644 tests/unit/query_v2_query_plan_match_filter_return.cpp create mode 100644 tests/unit/query_v2_query_plan_v2_create_set_remove_delete.cpp create mode 100644 tests/unit/storage_v3_schema.cpp diff --git a/src/query/common.hpp b/src/query/common.hpp index c51c34dee..f6526494c 100644 --- a/src/query/common.hpp +++ b/src/query/common.hpp @@ -16,6 +16,7 @@ #include #include #include +#include #include "query/db_accessor.hpp" #include "query/exceptions.hpp" @@ -24,8 +25,12 @@ #include "query/typed_value.hpp" #include "storage/v2/id_types.hpp" #include "storage/v2/property_value.hpp" +#include "storage/v2/result.hpp" +#include "storage/v2/schema_validator.hpp" #include "storage/v2/view.hpp" +#include "utils/exceptions.hpp" #include "utils/logging.hpp" +#include "utils/variant_helpers.hpp" namespace memgraph::query { @@ -81,27 +86,79 @@ concept AccessorWithSetProperty = requires(T accessor, const storage::PropertyId { accessor.SetProperty(key, new_value) } -> std::same_as>; }; +inline void HandleSchemaViolation(const storage::SchemaViolation &schema_violation, const DbAccessor &dba) { + switch (schema_violation.status) { + case storage::SchemaViolation::ValidationStatus::VERTEX_HAS_NO_PRIMARY_PROPERTY: { + throw SchemaViolationException( + fmt::format("Primary key {} not defined on label :{}", + storage::SchemaTypeToString(schema_violation.violated_schema_property->type), + dba.LabelToName(schema_violation.label))); + } + case storage::SchemaViolation::ValidationStatus::NO_SCHEMA_DEFINED_FOR_LABEL: { + throw SchemaViolationException( + fmt::format("Label :{} is not a primary label", dba.LabelToName(schema_violation.label))); + } + case storage::SchemaViolation::ValidationStatus::VERTEX_PROPERTY_WRONG_TYPE: { + throw SchemaViolationException( + fmt::format("Wrong type of property {} in schema :{}, should be of type {}", + *schema_violation.violated_property_value, dba.LabelToName(schema_violation.label), + storage::SchemaTypeToString(schema_violation.violated_schema_property->type))); + } + case storage::SchemaViolation::ValidationStatus::VERTEX_UPDATE_PRIMARY_KEY: { + throw SchemaViolationException(fmt::format("Updating of primary key {} on schema :{} not supported", + *schema_violation.violated_property_value, + dba.LabelToName(schema_violation.label))); + } + case storage::SchemaViolation::ValidationStatus::VERTEX_MODIFY_PRIMARY_LABEL: { + throw SchemaViolationException(fmt::format("Cannot add or remove label :{} since it is a primary label", + dba.LabelToName(schema_violation.label))); + } + case storage::SchemaViolation::ValidationStatus::VERTEX_SECONDARY_LABEL_IS_PRIMARY: { + throw SchemaViolationException( + fmt::format("Cannot create vertex with secondary label :{}", dba.LabelToName(schema_violation.label))); + } + } +} + +inline void HandleErrorOnPropertyUpdate(const storage::Error error) { + switch (error) { + case storage::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to set properties on a deleted object."); + case storage::Error::PROPERTIES_DISABLED: + throw QueryRuntimeException("Can't set property because properties on edges are disabled."); + case storage::Error::VERTEX_HAS_EDGES: + case storage::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when setting a property."); + } +} + /// Set a property `value` mapped with given `key` on a `record`. /// /// @throw QueryRuntimeException if value cannot be set as a property value template -storage::PropertyValue PropsSetChecked(T *record, const storage::PropertyId &key, const TypedValue &value) { +storage::PropertyValue PropsSetChecked(T *record, const DbAccessor &dba, const storage::PropertyId &key, + const TypedValue &value) { try { - auto maybe_old_value = record->SetProperty(key, storage::PropertyValue(value)); - if (maybe_old_value.HasError()) { - switch (maybe_old_value.GetError()) { - case storage::Error::SERIALIZATION_ERROR: - throw TransactionSerializationException(); - case storage::Error::DELETED_OBJECT: - throw QueryRuntimeException("Trying to set properties on a deleted object."); - case storage::Error::PROPERTIES_DISABLED: - throw QueryRuntimeException("Can't set property because properties on edges are disabled."); - case storage::Error::VERTEX_HAS_EDGES: - case storage::Error::NONEXISTENT_OBJECT: - throw QueryRuntimeException("Unexpected error when setting a property."); + if constexpr (std::is_same_v) { + const auto maybe_old_value = record->SetPropertyAndValidate(key, storage::PropertyValue(value)); + if (maybe_old_value.HasError()) { + std::visit(utils::Overloaded{[](const storage::Error error) { HandleErrorOnPropertyUpdate(error); }, + [&dba](const storage::SchemaViolation &schema_violation) { + HandleSchemaViolation(schema_violation, dba); + }}, + maybe_old_value.GetError()); } + return std::move(*maybe_old_value); + } else { + // No validation on edge properties + const auto maybe_old_value = record->SetProperty(key, storage::PropertyValue(value)); + if (maybe_old_value.HasError()) { + HandleErrorOnPropertyUpdate(maybe_old_value.GetError()); + } + return std::move(*maybe_old_value); } - return std::move(*maybe_old_value); } catch (const TypedValueException &) { throw QueryRuntimeException("'{}' cannot be used as a property value.", value.type()); } diff --git a/src/query/db_accessor.hpp b/src/query/db_accessor.hpp index 0369b462b..1fad73101 100644 --- a/src/query/db_accessor.hpp +++ b/src/query/db_accessor.hpp @@ -12,6 +12,7 @@ #pragma once #include +#include #include #include @@ -23,7 +24,7 @@ /////////////////////////////////////////////////////////// // Our communication layer and query engine don't mix -// very well on Centos because OpenSSL version avaialable +// very well on Centos because OpenSSL version available // on Centos 7 include libkrb5 which has brilliant macros // called TRUE and FALSE. For more detailed explanation go // to memgraph.cpp. @@ -34,6 +35,8 @@ // simply undefine those macros as we're sure that libkrb5 // won't and can't be used anywhere in the query engine. #include "storage/v2/storage.hpp" +#include "utils/logging.hpp" +#include "utils/result.hpp" #undef FALSE #undef TRUE @@ -51,7 +54,6 @@ class EdgeAccessor final { public: storage::EdgeAccessor impl_; - public: explicit EdgeAccessor(storage::EdgeAccessor impl) : impl_(std::move(impl)) {} bool IsVisible(storage::View view) const { return impl_.IsVisible(view); } @@ -97,17 +99,24 @@ class VertexAccessor final { static EdgeAccessor MakeEdgeAccessor(const storage::EdgeAccessor impl) { return EdgeAccessor(impl); } - public: explicit VertexAccessor(storage::VertexAccessor impl) : impl_(impl) {} bool IsVisible(storage::View view) const { return impl_.IsVisible(view); } auto Labels(storage::View view) const { return impl_.Labels(view); } + auto PrimaryLabel(storage::View view) const { return impl_.PrimaryLabel(view); } + storage::Result AddLabel(storage::LabelId label) { return impl_.AddLabel(label); } + storage::ResultSchema AddLabelAndValidate(storage::LabelId label) { return impl_.AddLabelAndValidate(label); } + storage::Result RemoveLabel(storage::LabelId label) { return impl_.RemoveLabel(label); } + storage::ResultSchema RemoveLabelAndValidate(storage::LabelId label) { + return impl_.RemoveLabelAndValidate(label); + } + storage::Result HasLabel(storage::View view, storage::LabelId label) const { return impl_.HasLabel(label, view); } @@ -122,8 +131,13 @@ class VertexAccessor final { return impl_.SetProperty(key, value); } - storage::Result RemoveProperty(storage::PropertyId key) { - return SetProperty(key, storage::PropertyValue()); + storage::ResultSchema SetPropertyAndValidate(storage::PropertyId key, + const storage::PropertyValue &value) { + return impl_.SetPropertyAndValidate(key, value); + } + + storage::ResultSchema RemovePropertyAndValidate(storage::PropertyId key) { + return SetPropertyAndValidate(key, storage::PropertyValue{}); } storage::Result> ClearProperties() { @@ -249,7 +263,18 @@ class DbAccessor final { return VerticesIterable(accessor_->Vertices(label, property, lower, upper, view)); } - VertexAccessor InsertVertex() { return VertexAccessor(accessor_->CreateVertex()); } + // TODO Remove when query modules have been fixed + [[deprecated]] VertexAccessor InsertVertex() { return VertexAccessor(accessor_->CreateVertex()); } + + storage::ResultSchema InsertVertexAndValidate( + const storage::LabelId primary_label, const std::vector &labels, + const std::vector> &properties) { + auto maybe_vertex_acc = accessor_->CreateVertexAndValidate(primary_label, labels, properties); + if (maybe_vertex_acc.HasError()) { + return {std::move(maybe_vertex_acc.GetError())}; + } + return VertexAccessor{maybe_vertex_acc.GetValue()}; + } storage::Result InsertEdge(VertexAccessor *from, VertexAccessor *to, const storage::EdgeTypeId &edge_type) { @@ -307,7 +332,7 @@ class DbAccessor final { return std::optional{}; } - return std::make_optional(*value); + return {std::make_optional(*value)}; } storage::PropertyId NameToProperty(const std::string_view name) { return accessor_->NameToProperty(name); } @@ -357,6 +382,8 @@ class DbAccessor final { storage::ConstraintsInfo ListAllConstraints() const { return accessor_->ListAllConstraints(); } + const storage::SchemaValidator &GetSchemaValidator() const { return accessor_->GetSchemaValidator(); } + storage::SchemasInfo ListAllSchemas() const { return accessor_->ListAllSchemas(); } }; diff --git a/src/query/exceptions.hpp b/src/query/exceptions.hpp index a18ce0c43..daba98e7a 100644 --- a/src/query/exceptions.hpp +++ b/src/query/exceptions.hpp @@ -224,4 +224,12 @@ class VersionInfoInMulticommandTxException : public QueryException { : QueryException("Version info query not allowed in multicommand transactions.") {} }; +/** + * An exception for an illegal operation that violates schema + */ +class SchemaViolationException : public QueryRuntimeException { + public: + using QueryRuntimeException::QueryRuntimeException; +}; + } // namespace memgraph::query diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 74463ead0..07442572c 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -37,7 +37,12 @@ #include "query/procedure/cypher_types.hpp" #include "query/procedure/mg_procedure_impl.hpp" #include "query/procedure/module.hpp" +#include "query/typed_value.hpp" +#include "storage/v2/id_types.hpp" #include "storage/v2/property_value.hpp" +#include "storage/v2/result.hpp" +#include "storage/v2/schema_validator.hpp" +#include "storage/v2/schemas.hpp" #include "utils/algorithm.hpp" #include "utils/csv_parsing.hpp" #include "utils/event_counter.hpp" @@ -52,6 +57,7 @@ #include "utils/readable_size.hpp" #include "utils/string.hpp" #include "utils/temporal.hpp" +#include "utils/variant_helpers.hpp" // macro for the default implementation of LogicalOperator::Accept // that accepts the visitor and visits it's input_ operator @@ -174,45 +180,56 @@ CreateNode::CreateNode(const std::shared_ptr &input, const Node // Creates a vertex on this GraphDb. Returns a reference to vertex placed on the // frame. -VertexAccessor &CreateLocalVertex(const NodeCreationInfo &node_info, Frame *frame, ExecutionContext &context) { +VertexAccessor &CreateLocalVertexAtomically(const NodeCreationInfo &node_info, Frame *frame, + ExecutionContext &context) { auto &dba = *context.db_accessor; - auto new_node = dba.InsertVertex(); - context.execution_stats[ExecutionStats::Key::CREATED_NODES] += 1; - for (auto label : node_info.labels) { - auto maybe_error = new_node.AddLabel(label); - if (maybe_error.HasError()) { - switch (maybe_error.GetError()) { - case storage::Error::SERIALIZATION_ERROR: - throw TransactionSerializationException(); - case storage::Error::DELETED_OBJECT: - throw QueryRuntimeException("Trying to set a label on a deleted node."); - case storage::Error::VERTEX_HAS_EDGES: - case storage::Error::PROPERTIES_DISABLED: - case storage::Error::NONEXISTENT_OBJECT: - throw QueryRuntimeException("Unexpected error when setting a label."); - } - } - context.execution_stats[ExecutionStats::Key::CREATED_LABELS] += 1; - } // Evaluator should use the latest accessors, as modified in this query, when // setting properties on new nodes. ExpressionEvaluator evaluator(frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::NEW); - // TODO: PropsSetChecked allocates a PropertyValue, make it use context.memory - // when we update PropertyValue with custom allocator. + + std::vector> properties; if (const auto *node_info_properties = std::get_if(&node_info.properties)) { + properties.reserve(node_info_properties->size()); for (const auto &[key, value_expression] : *node_info_properties) { - PropsSetChecked(&new_node, key, value_expression->Accept(evaluator)); + properties.emplace_back(key, storage::PropertyValue(value_expression->Accept(evaluator))); } } else { - auto property_map = evaluator.Visit(*std::get(node_info.properties)); - for (const auto &[key, value] : property_map.ValueMap()) { + auto property_map = evaluator.Visit(*std::get(node_info.properties)).ValueMap(); + properties.reserve(property_map.size()); + + for (const auto &[key, value] : property_map) { auto property_id = dba.NameToProperty(key); - PropsSetChecked(&new_node, property_id, value); + properties.emplace_back(property_id, value); } } + // TODO Remove later on since that will be enforced from grammar side + MG_ASSERT(!node_info.labels.empty(), "There must be at least one label!"); + const auto primary_label = node_info.labels[0]; + std::vector secondary_labels(node_info.labels.begin() + 1, node_info.labels.end()); + auto maybe_new_node = dba.InsertVertexAndValidate(primary_label, secondary_labels, properties); + if (maybe_new_node.HasError()) { + std::visit(utils::Overloaded{[&dba](const storage::SchemaViolation &schema_violation) { + HandleSchemaViolation(schema_violation, dba); + }, + [](const storage::Error error) { + switch (error) { + case storage::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to set a label on a deleted node."); + case storage::Error::VERTEX_HAS_EDGES: + case storage::Error::PROPERTIES_DISABLED: + case storage::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when setting a label."); + } + }}, + maybe_new_node.GetError()); + } - (*frame)[node_info.symbol] = new_node; + context.execution_stats[ExecutionStats::Key::CREATED_NODES] += 1; + + (*frame)[node_info.symbol] = *maybe_new_node; return (*frame)[node_info.symbol].ValueVertex(); } @@ -237,7 +254,7 @@ bool CreateNode::CreateNodeCursor::Pull(Frame &frame, ExecutionContext &context) SCOPED_PROFILE_OP("CreateNode"); if (input_cursor_->Pull(frame, context)) { - auto created_vertex = CreateLocalVertex(self_.node_info_, &frame, context); + auto created_vertex = CreateLocalVertexAtomically(self_.node_info_, &frame, context); if (context.trigger_context_collector) { context.trigger_context_collector->RegisterCreatedObject(created_vertex); } @@ -286,13 +303,13 @@ EdgeAccessor CreateEdge(const EdgeCreationInfo &edge_info, DbAccessor *dba, Vert auto &edge = *maybe_edge; if (const auto *properties = std::get_if(&edge_info.properties)) { for (const auto &[key, value_expression] : *properties) { - PropsSetChecked(&edge, key, value_expression->Accept(*evaluator)); + PropsSetChecked(&edge, *dba, key, value_expression->Accept(*evaluator)); } } else { auto property_map = evaluator->Visit(*std::get(edge_info.properties)); for (const auto &[key, value] : property_map.ValueMap()) { auto property_id = dba->NameToProperty(key); - PropsSetChecked(&edge, property_id, value); + PropsSetChecked(&edge, *dba, property_id, value); } } @@ -368,13 +385,12 @@ VertexAccessor &CreateExpand::CreateExpandCursor::OtherVertex(Frame &frame, Exec TypedValue &dest_node_value = frame[self_.node_info_.symbol]; ExpectType(self_.node_info_.symbol, dest_node_value, TypedValue::Type::Vertex); return dest_node_value.ValueVertex(); - } else { - auto &created_vertex = CreateLocalVertex(self_.node_info_, &frame, context); - if (context.trigger_context_collector) { - context.trigger_context_collector->RegisterCreatedObject(created_vertex); - } - return created_vertex; } + auto &created_vertex = CreateLocalVertexAtomically(self_.node_info_, &frame, context); + if (context.trigger_context_collector) { + context.trigger_context_collector->RegisterCreatedObject(created_vertex); + } + return created_vertex; } template @@ -2047,7 +2063,7 @@ bool SetProperty::SetPropertyCursor::Pull(Frame &frame, ExecutionContext &contex switch (lhs.type()) { case TypedValue::Type::Vertex: { - auto old_value = PropsSetChecked(&lhs.ValueVertex(), self_.property_, rhs); + auto old_value = PropsSetChecked(&lhs.ValueVertex(), *context.db_accessor, self_.property_, rhs); context.execution_stats[ExecutionStats::Key::UPDATED_PROPERTIES] += 1; if (context.trigger_context_collector) { // rhs cannot be moved because it was created with the allocator that is only valid during current pull @@ -2057,7 +2073,7 @@ bool SetProperty::SetPropertyCursor::Pull(Frame &frame, ExecutionContext &contex break; } case TypedValue::Type::Edge: { - auto old_value = PropsSetChecked(&lhs.ValueEdge(), self_.property_, rhs); + auto old_value = PropsSetChecked(&lhs.ValueEdge(), *context.db_accessor, self_.property_, rhs); context.execution_stats[ExecutionStats::Key::UPDATED_PROPERTIES] += 1; if (context.trigger_context_collector) { // rhs cannot be moved because it was created with the allocator that is only valid during current pull @@ -2211,7 +2227,7 @@ void SetPropertiesOnRecord(TRecordAccessor *record, const TypedValue &rhs, SetPr case TypedValue::Type::Map: { for (const auto &kv : rhs.ValueMap()) { auto key = context->db_accessor->NameToProperty(kv.first); - auto old_value = PropsSetChecked(record, key, kv.second); + auto old_value = PropsSetChecked(record, *context->db_accessor, key, kv.second); if (should_register_change) { register_set_property(std::move(old_value), key, kv.second); } @@ -2295,22 +2311,31 @@ bool SetLabels::SetLabelsCursor::Pull(Frame &frame, ExecutionContext &context) { // Skip setting labels on Null (can occur in optional match). if (vertex_value.IsNull()) return true; ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); + + auto &dba = *context.db_accessor; auto &vertex = vertex_value.ValueVertex(); - for (auto label : self_.labels_) { - auto maybe_value = vertex.AddLabel(label); + for (const auto label : self_.labels_) { + auto maybe_value = vertex.AddLabelAndValidate(label); if (maybe_value.HasError()) { - switch (maybe_value.GetError()) { - case storage::Error::SERIALIZATION_ERROR: - throw TransactionSerializationException(); - case storage::Error::DELETED_OBJECT: - throw QueryRuntimeException("Trying to set a label on a deleted node."); - case storage::Error::VERTEX_HAS_EDGES: - case storage::Error::PROPERTIES_DISABLED: - case storage::Error::NONEXISTENT_OBJECT: - throw QueryRuntimeException("Unexpected error when setting a label."); - } + std::visit(utils::Overloaded{[](const storage::Error error) { + switch (error) { + case storage::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to set a label on a deleted node."); + case storage::Error::VERTEX_HAS_EDGES: + case storage::Error::PROPERTIES_DISABLED: + case storage::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when setting a label."); + } + }, + [&dba](const storage::SchemaViolation schema_violation) { + HandleSchemaViolation(schema_violation, dba); + }}, + maybe_value.GetError()); } + context.execution_stats[ExecutionStats::Key::CREATED_LABELS]++; if (context.trigger_context_collector && *maybe_value) { context.trigger_context_collector->RegisterSetVertexLabel(vertex, label); } @@ -2353,26 +2378,11 @@ bool RemoveProperty::RemovePropertyCursor::Pull(Frame &frame, ExecutionContext & TypedValue lhs = self_.lhs_->expression_->Accept(evaluator); auto remove_prop = [property = self_.property_, &context](auto *record) { - auto maybe_old_value = record->RemoveProperty(property); - if (maybe_old_value.HasError()) { - switch (maybe_old_value.GetError()) { - case storage::Error::DELETED_OBJECT: - throw QueryRuntimeException("Trying to remove a property on a deleted graph element."); - case storage::Error::SERIALIZATION_ERROR: - throw TransactionSerializationException(); - case storage::Error::PROPERTIES_DISABLED: - throw QueryRuntimeException( - "Can't remove property because properties on edges are " - "disabled."); - case storage::Error::VERTEX_HAS_EDGES: - case storage::Error::NONEXISTENT_OBJECT: - throw QueryRuntimeException("Unexpected error when removing property."); - } - } + auto old_value = PropsSetChecked(record, *context.db_accessor, property, TypedValue{}); if (context.trigger_context_collector) { context.trigger_context_collector->RegisterRemovedObjectProperty(*record, property, - TypedValue(std::move(*maybe_old_value))); + TypedValue(std::move(old_value))); } }; @@ -2426,18 +2436,25 @@ bool RemoveLabels::RemoveLabelsCursor::Pull(Frame &frame, ExecutionContext &cont ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); auto &vertex = vertex_value.ValueVertex(); for (auto label : self_.labels_) { - auto maybe_value = vertex.RemoveLabel(label); + auto maybe_value = vertex.RemoveLabelAndValidate(label); if (maybe_value.HasError()) { - switch (maybe_value.GetError()) { - case storage::Error::SERIALIZATION_ERROR: - throw TransactionSerializationException(); - case storage::Error::DELETED_OBJECT: - throw QueryRuntimeException("Trying to remove labels from a deleted node."); - case storage::Error::VERTEX_HAS_EDGES: - case storage::Error::PROPERTIES_DISABLED: - case storage::Error::NONEXISTENT_OBJECT: - throw QueryRuntimeException("Unexpected error when removing labels from a node."); - } + std::visit( + utils::Overloaded{[](const storage::Error error) { + switch (error) { + case storage::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to remove labels from a deleted node."); + case storage::Error::VERTEX_HAS_EDGES: + case storage::Error::PROPERTIES_DISABLED: + case storage::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when removing labels from a node."); + } + }, + [&context](const storage::SchemaViolation &schema_violation) { + HandleSchemaViolation(schema_violation, *context.db_accessor); + }}, + maybe_value.GetError()); } context.execution_stats[ExecutionStats::Key::DELETED_LABELS] += 1; diff --git a/src/storage/v2/CMakeLists.txt b/src/storage/v2/CMakeLists.txt index bcdafc6fc..52ebfdd1f 100644 --- a/src/storage/v2/CMakeLists.txt +++ b/src/storage/v2/CMakeLists.txt @@ -11,6 +11,7 @@ set(storage_v2_src_files property_store.cpp vertex_accessor.cpp schemas.cpp + schema_validator.cpp storage.cpp) ##### Replication ##### diff --git a/src/storage/v2/constraints.cpp b/src/storage/v2/constraints.cpp index fab6ee4c4..5e5988099 100644 --- a/src/storage/v2/constraints.cpp +++ b/src/storage/v2/constraints.cpp @@ -16,6 +16,7 @@ #include #include "storage/v2/mvcc.hpp" +#include "storage/v2/vertex.hpp" #include "utils/logging.hpp" namespace memgraph::storage { @@ -59,7 +60,7 @@ bool LastCommittedVersionHasLabelProperty(const Vertex &vertex, LabelId label, c std::lock_guard guard(vertex.lock); delta = vertex.delta; deleted = vertex.deleted; - has_label = utils::Contains(vertex.labels, label); + has_label = VertexHasLabel(vertex, label); size_t i = 0; for (const auto &property : properties) { @@ -142,7 +143,7 @@ bool AnyVersionHasLabelProperty(const Vertex &vertex, LabelId label, const std:: Delta *delta; { std::lock_guard guard(vertex.lock); - has_label = utils::Contains(vertex.labels, label); + has_label = VertexHasLabel(vertex, label); deleted = vertex.deleted; delta = vertex.delta; @@ -267,7 +268,7 @@ bool UniqueConstraints::Entry::operator==(const std::vector &rhs) void UniqueConstraints::UpdateBeforeCommit(const Vertex *vertex, const Transaction &tx) { for (auto &[label_props, storage] : constraints_) { - if (!utils::Contains(vertex->labels, label_props.first)) { + if (!VertexHasLabel(*vertex, label_props.first)) { continue; } auto values = ExtractPropertyValues(*vertex, label_props.second); @@ -301,7 +302,7 @@ utils::BasicResult Uniqu auto acc = constraint->second.access(); for (const Vertex &vertex : vertices) { - if (vertex.deleted || !utils::Contains(vertex.labels, label)) { + if (vertex.deleted || !VertexHasLabel(vertex, label)) { continue; } auto values = ExtractPropertyValues(vertex, properties); @@ -352,7 +353,7 @@ std::optional UniqueConstraints::Validate(const Vertex &ver for (const auto &[label_props, storage] : constraints_) { const auto &label = label_props.first; const auto &properties = label_props.second; - if (!utils::Contains(vertex.labels, label)) { + if (!VertexHasLabel(vertex, label)) { continue; } diff --git a/src/storage/v2/constraints.hpp b/src/storage/v2/constraints.hpp index b209437f8..427b6ca4f 100644 --- a/src/storage/v2/constraints.hpp +++ b/src/storage/v2/constraints.hpp @@ -158,7 +158,7 @@ inline utils::BasicResult CreateExistenceConstraint( return false; } for (const auto &vertex : vertices) { - if (!vertex.deleted && utils::Contains(vertex.labels, label) && !vertex.properties.HasProperty(property)) { + if (!vertex.deleted && VertexHasLabel(vertex, label) && !vertex.properties.HasProperty(property)) { return ConstraintViolation{ConstraintViolation::Type::EXISTENCE, label, std::set{property}}; } } @@ -184,7 +184,7 @@ inline bool DropExistenceConstraint(Constraints *constraints, LabelId label, Pro [[nodiscard]] inline std::optional ValidateExistenceConstraints(const Vertex &vertex, const Constraints &constraints) { for (const auto &[label, property] : constraints.existence_constraints) { - if (!vertex.deleted && utils::Contains(vertex.labels, label) && !vertex.properties.HasProperty(property)) { + if (!vertex.deleted && VertexHasLabel(vertex, label) && !vertex.properties.HasProperty(property)) { return ConstraintViolation{ConstraintViolation::Type::EXISTENCE, label, std::set{property}}; } } diff --git a/src/storage/v2/durability/snapshot.cpp b/src/storage/v2/durability/snapshot.cpp index 16c7d017c..002da25fa 100644 --- a/src/storage/v2/durability/snapshot.cpp +++ b/src/storage/v2/durability/snapshot.cpp @@ -628,8 +628,9 @@ RecoveredSnapshot LoadSnapshot(const std::filesystem::path &path, utils::SkipLis void CreateSnapshot(Transaction *transaction, const std::filesystem::path &snapshot_directory, const std::filesystem::path &wal_directory, uint64_t snapshot_retention_count, utils::SkipList *vertices, utils::SkipList *edges, NameIdMapper *name_id_mapper, - Indices *indices, Constraints *constraints, Config::Items items, const std::string &uuid, - const std::string_view epoch_id, const std::deque> &epoch_history, + Indices *indices, Constraints *constraints, Config::Items items, + const SchemaValidator &schema_validator, const std::string &uuid, const std::string_view epoch_id, + const std::deque> &epoch_history, utils::FileRetainer *file_retainer) { // Ensure that the storage directory exists. utils::EnsureDirOrDie(snapshot_directory); @@ -713,8 +714,9 @@ void CreateSnapshot(Transaction *transaction, const std::filesystem::path &snaps // type and invalid from/to pointers because we don't know them here, // but that isn't an issue because we won't use that part of the API // here. - auto ea = - EdgeAccessor{edge_ref, EdgeTypeId::FromUint(0UL), nullptr, nullptr, transaction, indices, constraints, items}; + // TODO(jbajic) Fix snapshot with new schema rules + auto ea = EdgeAccessor{edge_ref, EdgeTypeId::FromUint(0UL), nullptr, nullptr, transaction, indices, constraints, + items, schema_validator}; // Get edge data. auto maybe_props = ea.Properties(View::OLD); @@ -742,7 +744,7 @@ void CreateSnapshot(Transaction *transaction, const std::filesystem::path &snaps auto acc = vertices->access(); for (auto &vertex : acc) { // The visibility check is implemented for vertices so we use it here. - auto va = VertexAccessor::Create(&vertex, transaction, indices, constraints, items, View::OLD); + auto va = VertexAccessor::Create(&vertex, transaction, indices, constraints, items, schema_validator, View::OLD); if (!va) continue; // Get vertex data. diff --git a/src/storage/v2/durability/snapshot.hpp b/src/storage/v2/durability/snapshot.hpp index b1cfad63c..643c1a34c 100644 --- a/src/storage/v2/durability/snapshot.hpp +++ b/src/storage/v2/durability/snapshot.hpp @@ -21,6 +21,7 @@ #include "storage/v2/edge.hpp" #include "storage/v2/indices.hpp" #include "storage/v2/name_id_mapper.hpp" +#include "storage/v2/schema_validator.hpp" #include "storage/v2/transaction.hpp" #include "storage/v2/vertex.hpp" #include "utils/file_locker.hpp" @@ -68,8 +69,9 @@ RecoveredSnapshot LoadSnapshot(const std::filesystem::path &path, utils::SkipLis void CreateSnapshot(Transaction *transaction, const std::filesystem::path &snapshot_directory, const std::filesystem::path &wal_directory, uint64_t snapshot_retention_count, utils::SkipList *vertices, utils::SkipList *edges, NameIdMapper *name_id_mapper, - Indices *indices, Constraints *constraints, Config::Items items, const std::string &uuid, - std::string_view epoch_id, const std::deque> &epoch_history, + Indices *indices, Constraints *constraints, Config::Items items, + const SchemaValidator &schema_validator, const std::string &uuid, std::string_view epoch_id, + const std::deque> &epoch_history, utils::FileRetainer *file_retainer); } // namespace memgraph::storage::durability diff --git a/src/storage/v2/edge_accessor.cpp b/src/storage/v2/edge_accessor.cpp index ef0444422..acb3ec288 100644 --- a/src/storage/v2/edge_accessor.cpp +++ b/src/storage/v2/edge_accessor.cpp @@ -15,6 +15,7 @@ #include "storage/v2/mvcc.hpp" #include "storage/v2/property_value.hpp" +#include "storage/v2/schema_validator.hpp" #include "storage/v2/vertex_accessor.hpp" #include "utils/memory_tracker.hpp" @@ -54,11 +55,11 @@ bool EdgeAccessor::IsVisible(const View view) const { } VertexAccessor EdgeAccessor::FromVertex() const { - return VertexAccessor{from_vertex_, transaction_, indices_, constraints_, config_}; + return VertexAccessor{from_vertex_, transaction_, indices_, constraints_, config_, *schema_validator_}; } VertexAccessor EdgeAccessor::ToVertex() const { - return VertexAccessor{to_vertex_, transaction_, indices_, constraints_, config_}; + return VertexAccessor{to_vertex_, transaction_, indices_, constraints_, config_, *schema_validator_}; } Result EdgeAccessor::SetProperty(PropertyId property, const PropertyValue &value) { diff --git a/src/storage/v2/edge_accessor.hpp b/src/storage/v2/edge_accessor.hpp index b0a1e1151..11abad3a1 100644 --- a/src/storage/v2/edge_accessor.hpp +++ b/src/storage/v2/edge_accessor.hpp @@ -18,6 +18,7 @@ #include "storage/v2/config.hpp" #include "storage/v2/result.hpp" +#include "storage/v2/schema_validator.hpp" #include "storage/v2/transaction.hpp" #include "storage/v2/view.hpp" @@ -34,7 +35,8 @@ class EdgeAccessor final { public: EdgeAccessor(EdgeRef edge, EdgeTypeId edge_type, Vertex *from_vertex, Vertex *to_vertex, Transaction *transaction, - Indices *indices, Constraints *constraints, Config::Items config, bool for_deleted = false) + Indices *indices, Constraints *constraints, Config::Items config, + const SchemaValidator &schema_validator, bool for_deleted = false) : edge_(edge), edge_type_(edge_type), from_vertex_(from_vertex), @@ -43,6 +45,7 @@ class EdgeAccessor final { indices_(indices), constraints_(constraints), config_(config), + schema_validator_{&schema_validator}, for_deleted_(for_deleted) {} /// @return true if the object is visible from the current transaction @@ -92,6 +95,7 @@ class EdgeAccessor final { Indices *indices_; Constraints *constraints_; Config::Items config_; + const SchemaValidator *schema_validator_; // if the accessor was created for a deleted edge. // Accessor behaves differently for some methods based on this diff --git a/src/storage/v2/indices.cpp b/src/storage/v2/indices.cpp index fb83ff166..2a1c1ff02 100644 --- a/src/storage/v2/indices.cpp +++ b/src/storage/v2/indices.cpp @@ -14,6 +14,7 @@ #include "storage/v2/mvcc.hpp" #include "storage/v2/property_value.hpp" +#include "storage/v2/schema_validator.hpp" #include "utils/bound.hpp" #include "utils/logging.hpp" #include "utils/memory_tracker.hpp" @@ -327,7 +328,7 @@ void LabelIndex::RemoveObsoleteEntries(uint64_t oldest_active_start_timestamp) { LabelIndex::Iterable::Iterator::Iterator(Iterable *self, utils::SkipList::Iterator index_iterator) : self_(self), index_iterator_(index_iterator), - current_vertex_accessor_(nullptr, nullptr, nullptr, nullptr, self_->config_), + current_vertex_accessor_(nullptr, nullptr, nullptr, nullptr, self_->config_, *self_->schema_validator_), current_vertex_(nullptr) { AdvanceUntilValid(); } @@ -345,8 +346,8 @@ void LabelIndex::Iterable::Iterator::AdvanceUntilValid() { } if (CurrentVersionHasLabel(*index_iterator_->vertex, self_->label_, self_->transaction_, self_->view_)) { current_vertex_ = index_iterator_->vertex; - current_vertex_accessor_ = - VertexAccessor{current_vertex_, self_->transaction_, self_->indices_, self_->constraints_, self_->config_}; + current_vertex_accessor_ = VertexAccessor{current_vertex_, self_->transaction_, self_->indices_, + self_->constraints_, self_->config_, *self_->schema_validator_}; break; } } @@ -354,14 +355,15 @@ void LabelIndex::Iterable::Iterator::AdvanceUntilValid() { LabelIndex::Iterable::Iterable(utils::SkipList::Accessor index_accessor, LabelId label, View view, Transaction *transaction, Indices *indices, Constraints *constraints, - Config::Items config) + Config::Items config, const SchemaValidator &schema_validator) : index_accessor_(std::move(index_accessor)), label_(label), view_(view), transaction_(transaction), indices_(indices), constraints_(constraints), - config_(config) {} + config_(config), + schema_validator_(&schema_validator) {} void LabelIndex::RunGC() { for (auto &index_entry : index_) { @@ -478,7 +480,7 @@ void LabelPropertyIndex::RemoveObsoleteEntries(uint64_t oldest_active_start_time LabelPropertyIndex::Iterable::Iterator::Iterator(Iterable *self, utils::SkipList::Iterator index_iterator) : self_(self), index_iterator_(index_iterator), - current_vertex_accessor_(nullptr, nullptr, nullptr, nullptr, self_->config_), + current_vertex_accessor_(nullptr, nullptr, nullptr, nullptr, self_->config_, *self_->schema_validator_), current_vertex_(nullptr) { AdvanceUntilValid(); } @@ -517,8 +519,8 @@ void LabelPropertyIndex::Iterable::Iterator::AdvanceUntilValid() { if (CurrentVersionHasLabelProperty(*index_iterator_->vertex, self_->label_, self_->property_, index_iterator_->value, self_->transaction_, self_->view_)) { current_vertex_ = index_iterator_->vertex; - current_vertex_accessor_ = - VertexAccessor(current_vertex_, self_->transaction_, self_->indices_, self_->constraints_, self_->config_); + current_vertex_accessor_ = VertexAccessor(current_vertex_, self_->transaction_, self_->indices_, + self_->constraints_, self_->config_, *self_->schema_validator_); break; } } @@ -541,7 +543,7 @@ LabelPropertyIndex::Iterable::Iterable(utils::SkipList::Accessor index_ac const std::optional> &lower_bound, const std::optional> &upper_bound, View view, Transaction *transaction, Indices *indices, Constraints *constraints, - Config::Items config) + Config::Items config, const SchemaValidator &schema_validator) : index_accessor_(std::move(index_accessor)), label_(label), property_(property), @@ -551,7 +553,8 @@ LabelPropertyIndex::Iterable::Iterable(utils::SkipList::Accessor index_ac transaction_(transaction), indices_(indices), constraints_(constraints), - config_(config) { + config_(config), + schema_validator_(&schema_validator) { // We have to fix the bounds that the user provided to us. If the user // provided only one bound we should make sure that only values of that type // are returned by the iterator. We ensure this by supplying either an diff --git a/src/storage/v2/indices.hpp b/src/storage/v2/indices.hpp index eed22e8b5..64e1501f3 100644 --- a/src/storage/v2/indices.hpp +++ b/src/storage/v2/indices.hpp @@ -17,6 +17,7 @@ #include "storage/v2/config.hpp" #include "storage/v2/property_value.hpp" +#include "storage/v2/schema_validator.hpp" #include "storage/v2/transaction.hpp" #include "storage/v2/vertex_accessor.hpp" #include "utils/bound.hpp" @@ -51,8 +52,8 @@ class LabelIndex { }; public: - LabelIndex(Indices *indices, Constraints *constraints, Config::Items config) - : indices_(indices), constraints_(constraints), config_(config) {} + LabelIndex(Indices *indices, Constraints *constraints, Config::Items config, const SchemaValidator &schema_validator) + : indices_(indices), constraints_(constraints), config_(config), schema_validator_{&schema_validator} {} /// @throw std::bad_alloc void UpdateOnAddLabel(LabelId label, Vertex *vertex, const Transaction &tx); @@ -72,7 +73,7 @@ class LabelIndex { class Iterable { public: Iterable(utils::SkipList::Accessor index_accessor, LabelId label, View view, Transaction *transaction, - Indices *indices, Constraints *constraints, Config::Items config); + Indices *indices, Constraints *constraints, Config::Items config, const SchemaValidator &schema_validator); class Iterator { public: @@ -105,13 +106,14 @@ class LabelIndex { Indices *indices_; Constraints *constraints_; Config::Items config_; + const SchemaValidator *schema_validator_; }; /// Returns an self with vertices visible from the given transaction. Iterable Vertices(LabelId label, View view, Transaction *transaction) { auto it = index_.find(label); MG_ASSERT(it != index_.end(), "Index for label {} doesn't exist", label.AsUint()); - return Iterable(it->second.access(), label, view, transaction, indices_, constraints_, config_); + return {it->second.access(), label, view, transaction, indices_, constraints_, config_, *schema_validator_}; } int64_t ApproximateVertexCount(LabelId label) { @@ -129,6 +131,7 @@ class LabelIndex { Indices *indices_; Constraints *constraints_; Config::Items config_; + const SchemaValidator *schema_validator_; }; class LabelPropertyIndex { @@ -146,8 +149,9 @@ class LabelPropertyIndex { }; public: - LabelPropertyIndex(Indices *indices, Constraints *constraints, Config::Items config) - : indices_(indices), constraints_(constraints), config_(config) {} + LabelPropertyIndex(Indices *indices, Constraints *constraints, Config::Items config, + const SchemaValidator &schema_validator) + : indices_(indices), constraints_(constraints), config_(config), schema_validator_{&schema_validator} {} /// @throw std::bad_alloc void UpdateOnAddLabel(LabelId label, Vertex *vertex, const Transaction &tx); @@ -171,7 +175,7 @@ class LabelPropertyIndex { Iterable(utils::SkipList::Accessor index_accessor, LabelId label, PropertyId property, const std::optional> &lower_bound, const std::optional> &upper_bound, View view, Transaction *transaction, - Indices *indices, Constraints *constraints, Config::Items config); + Indices *indices, Constraints *constraints, Config::Items config, const SchemaValidator &schema_validator); class Iterator { public: @@ -208,16 +212,17 @@ class LabelPropertyIndex { Indices *indices_; Constraints *constraints_; Config::Items config_; + const SchemaValidator *schema_validator_; }; Iterable Vertices(LabelId label, PropertyId property, const std::optional> &lower_bound, - const std::optional> &upper_bound, View view, - Transaction *transaction) { + const std::optional> &upper_bound, View view, Transaction *transaction, + const SchemaValidator &schema_validator_) { auto it = index_.find({label, property}); MG_ASSERT(it != index_.end(), "Index for label {} and property {} doesn't exist", label.AsUint(), property.AsUint()); - return Iterable(it->second.access(), label, property, lower_bound, upper_bound, view, transaction, indices_, - constraints_, config_); + return {it->second.access(), label, property, lower_bound, upper_bound, view, + transaction, indices_, constraints_, config_, schema_validator_}; } int64_t ApproximateVertexCount(LabelId label, PropertyId property) const { @@ -246,11 +251,13 @@ class LabelPropertyIndex { Indices *indices_; Constraints *constraints_; Config::Items config_; + const SchemaValidator *schema_validator_; }; struct Indices { - Indices(Constraints *constraints, Config::Items config) - : label_index(this, constraints, config), label_property_index(this, constraints, config) {} + Indices(Constraints *constraints, Config::Items config, const SchemaValidator &schema_validator) + : label_index(this, constraints, config, schema_validator), + label_property_index(this, constraints, config, schema_validator) {} // Disable copy and move because members hold pointer to `this`. Indices(const Indices &) = delete; diff --git a/src/storage/v2/replication/replication_server.cpp b/src/storage/v2/replication/replication_server.cpp index fed501d6e..f8f533cac 100644 --- a/src/storage/v2/replication/replication_server.cpp +++ b/src/storage/v2/replication/replication_server.cpp @@ -166,9 +166,10 @@ void Storage::ReplicationServer::SnapshotHandler(slk::Reader *req_reader, slk::B storage_->edges_.clear(); storage_->constraints_ = Constraints(); - storage_->indices_.label_index = LabelIndex(&storage_->indices_, &storage_->constraints_, storage_->config_.items); - storage_->indices_.label_property_index = - LabelPropertyIndex(&storage_->indices_, &storage_->constraints_, storage_->config_.items); + storage_->indices_.label_index = + LabelIndex(&storage_->indices_, &storage_->constraints_, storage_->config_.items, storage_->schema_validator_); + storage_->indices_.label_property_index = LabelPropertyIndex(&storage_->indices_, &storage_->constraints_, + storage_->config_.items, storage_->schema_validator_); try { spdlog::debug("Loading snapshot"); auto recovered_snapshot = durability::LoadSnapshot(*maybe_snapshot_path, &storage_->vertices_, &storage_->edges_, @@ -473,7 +474,8 @@ uint64_t Storage::ReplicationServer::ReadAndApplyDelta(durability::BaseDecoder * &transaction->transaction_, &storage_->indices_, &storage_->constraints_, - storage_->config_.items}; + storage_->config_.items, + storage_->schema_validator_}; auto ret = ea.SetProperty(transaction->NameToProperty(delta.vertex_edge_set_property.property), delta.vertex_edge_set_property.value); diff --git a/src/storage/v2/schema_validator.cpp b/src/storage/v2/schema_validator.cpp new file mode 100644 index 000000000..4c3689a9f --- /dev/null +++ b/src/storage/v2/schema_validator.cpp @@ -0,0 +1,106 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "storage/v2/schema_validator.hpp" + +#include +#include +#include + +#include "storage/v2/schemas.hpp" + +namespace memgraph::storage { + +bool operator==(const SchemaViolation &lhs, const SchemaViolation &rhs) { + return lhs.status == rhs.status && lhs.label == rhs.label && + lhs.violated_schema_property == rhs.violated_schema_property && + lhs.violated_property_value == rhs.violated_property_value; +} + +SchemaViolation::SchemaViolation(ValidationStatus status, LabelId label) : status{status}, label{label} {} + +SchemaViolation::SchemaViolation(ValidationStatus status, LabelId label, SchemaProperty violated_schema_property) + : status{status}, label{label}, violated_schema_property{violated_schema_property} {} + +SchemaViolation::SchemaViolation(ValidationStatus status, LabelId label, SchemaProperty violated_schema_property, + PropertyValue violated_property_value) + : status{status}, + label{label}, + violated_schema_property{violated_schema_property}, + violated_property_value{violated_property_value} {} + +SchemaValidator::SchemaValidator(Schemas &schemas) : schemas_{schemas} {} + +[[nodiscard]] std::optional SchemaValidator::ValidateVertexCreate( + LabelId primary_label, const std::vector &labels, + const std::vector> &properties) const { + // Schema on primary label + const auto *schema = schemas_.GetSchema(primary_label); + if (schema == nullptr) { + return SchemaViolation(SchemaViolation::ValidationStatus::NO_SCHEMA_DEFINED_FOR_LABEL, primary_label); + } + + // Is there another primary label among secondary labels + for (const auto &secondary_label : labels) { + if (schemas_.GetSchema(secondary_label)) { + return SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_SECONDARY_LABEL_IS_PRIMARY, secondary_label); + } + } + + // Check only properties defined by schema + for (const auto &schema_type : schema->second) { + // Check schema property existence + auto property_pair = std::ranges::find_if( + properties, [schema_property_id = schema_type.property_id](const auto &property_type_value) { + return property_type_value.first == schema_property_id; + }); + if (property_pair == properties.end()) { + return SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_HAS_NO_PRIMARY_PROPERTY, primary_label, + schema_type); + } + + // Check schema property type + if (auto property_schema_type = PropertyTypeToSchemaType(property_pair->second); + property_schema_type && *property_schema_type != schema_type.type) { + return SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_PROPERTY_WRONG_TYPE, primary_label, schema_type, + property_pair->second); + } + } + + return std::nullopt; +} + +[[nodiscard]] std::optional SchemaValidator::ValidatePropertyUpdate( + const LabelId primary_label, const PropertyId property_id) const { + // Verify existence of schema on primary label + const auto *schema = schemas_.GetSchema(primary_label); + MG_ASSERT(schema, "Cannot validate against non existing schema!"); + + // Verify that updating property is not part of schema + if (const auto schema_property = std::ranges::find_if( + schema->second, + [property_id](const auto &schema_property) { return property_id == schema_property.property_id; }); + schema_property != schema->second.end()) { + return SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_UPDATE_PRIMARY_KEY, primary_label, + *schema_property); + } + return std::nullopt; +} + +[[nodiscard]] std::optional SchemaValidator::ValidateLabelUpdate(const LabelId label) const { + const auto *schema = schemas_.GetSchema(label); + if (schema) { + return SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_MODIFY_PRIMARY_LABEL, label); + } + return std::nullopt; +} + +} // namespace memgraph::storage diff --git a/src/storage/v2/schema_validator.hpp b/src/storage/v2/schema_validator.hpp new file mode 100644 index 000000000..6ad260138 --- /dev/null +++ b/src/storage/v2/schema_validator.hpp @@ -0,0 +1,69 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include + +#include "storage/v2/id_types.hpp" +#include "storage/v2/property_value.hpp" +#include "storage/v2/result.hpp" +#include "storage/v2/schemas.hpp" + +namespace memgraph::storage { + +struct SchemaViolation { + enum class ValidationStatus : uint8_t { + VERTEX_HAS_NO_PRIMARY_PROPERTY, + NO_SCHEMA_DEFINED_FOR_LABEL, + VERTEX_PROPERTY_WRONG_TYPE, + VERTEX_UPDATE_PRIMARY_KEY, + VERTEX_MODIFY_PRIMARY_LABEL, + VERTEX_SECONDARY_LABEL_IS_PRIMARY, + }; + + SchemaViolation(ValidationStatus status, LabelId label); + + SchemaViolation(ValidationStatus status, LabelId label, SchemaProperty violated_schema_property); + + SchemaViolation(ValidationStatus status, LabelId label, SchemaProperty violated_schema_property, + PropertyValue violated_property_value); + + friend bool operator==(const SchemaViolation &lhs, const SchemaViolation &rhs); + + ValidationStatus status; + LabelId label; + std::optional violated_schema_property; + std::optional violated_property_value; +}; + +class SchemaValidator { + public: + explicit SchemaValidator(Schemas &schemas); + + [[nodiscard]] std::optional ValidateVertexCreate( + LabelId primary_label, const std::vector &labels, + const std::vector> &properties) const; + + [[nodiscard]] std::optional ValidatePropertyUpdate(LabelId primary_label, + PropertyId property_id) const; + + [[nodiscard]] std::optional ValidateLabelUpdate(LabelId label) const; + + private: + storage::Schemas &schemas_; +}; + +template +using ResultSchema = utils::BasicResult, TValue>; + +} // namespace memgraph::storage diff --git a/src/storage/v2/schemas.cpp b/src/storage/v2/schemas.cpp index 1bec8455a..167d2946f 100644 --- a/src/storage/v2/schemas.cpp +++ b/src/storage/v2/schemas.cpp @@ -9,22 +9,18 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. +#include "storage/v2/schemas.hpp" + #include -#include #include #include "storage/v2/property_value.hpp" -#include "storage/v2/schemas.hpp" namespace memgraph::storage { -SchemaViolation::SchemaViolation(ValidationStatus status, LabelId label) : status{status}, label{label} {} -SchemaViolation::SchemaViolation(ValidationStatus status, LabelId label, SchemaProperty violated_type) - : status{status}, label{label}, violated_type{violated_type} {} - -SchemaViolation::SchemaViolation(ValidationStatus status, LabelId label, SchemaProperty violated_type, - PropertyValue violated_property_value) - : status{status}, label{label}, violated_type{violated_type}, violated_property_value{violated_property_value} {} +bool operator==(const SchemaProperty &lhs, const SchemaProperty &rhs) { + return lhs.property_id == rhs.property_id && lhs.type == rhs.type; +} Schemas::SchemasList Schemas::ListSchemas() const { Schemas::SchemasList ret; @@ -34,11 +30,11 @@ Schemas::SchemasList Schemas::ListSchemas() const { return ret; } -std::optional Schemas::GetSchema(const LabelId primary_label) const { +const Schemas::Schema *Schemas::GetSchema(const LabelId primary_label) const { if (auto schema_map = schemas_.find(primary_label); schema_map != schemas_.end()) { - return Schema{schema_map->first, schema_map->second}; + return &*schema_map; } - return std::nullopt; + return nullptr; } bool Schemas::CreateSchema(const LabelId primary_label, const std::vector &schemas_types) { @@ -51,34 +47,6 @@ bool Schemas::CreateSchema(const LabelId primary_label, const std::vector Schemas::ValidateVertex(const LabelId primary_label, const Vertex &vertex) { - // TODO Check for multiple defined primary labels - const auto schema = schemas_.find(primary_label); - if (schema == schemas_.end()) { - return SchemaViolation(SchemaViolation::ValidationStatus::NO_SCHEMA_DEFINED_FOR_LABEL, primary_label); - } - if (!utils::Contains(vertex.labels, primary_label)) { - return SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_HAS_NO_PRIMARY_LABEL, primary_label); - } - - for (const auto &schema_type : schema->second) { - if (!vertex.properties.HasProperty(schema_type.property_id)) { - return SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_HAS_NO_PROPERTY, primary_label, schema_type); - } - // Property type check - // TODO Can this be replaced with just property id check? - if (auto vertex_property = vertex.properties.GetProperty(schema_type.property_id); - PropertyTypeToSchemaType(vertex_property) != schema_type.type) { - return SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_PROPERTY_WRONG_TYPE, primary_label, schema_type, - vertex_property); - } - } - // TODO after the introduction of vertex hashing introduce check for vertex - // primary key uniqueness - - return std::nullopt; -} - std::optional PropertyTypeToSchemaType(const PropertyValue &property_value) { switch (property_value.type()) { case PropertyValue::Type::Bool: { diff --git a/src/storage/v2/schemas.hpp b/src/storage/v2/schemas.hpp index 898288b6b..c248c5b12 100644 --- a/src/storage/v2/schemas.hpp +++ b/src/storage/v2/schemas.hpp @@ -19,51 +19,25 @@ #include "common/types.hpp" #include "storage/v2/id_types.hpp" -#include "storage/v2/indices.hpp" #include "storage/v2/property_value.hpp" #include "storage/v2/temporal.hpp" -#include "storage/v2/transaction.hpp" -#include "storage/v2/vertex.hpp" #include "utils/result.hpp" namespace memgraph::storage { -class SchemaViolationException : public utils::BasicException { - using utils::BasicException::BasicException; -}; - struct SchemaProperty { PropertyId property_id; common::SchemaType type; -}; -struct SchemaViolation { - enum class ValidationStatus : uint8_t { - VERTEX_HAS_NO_PRIMARY_LABEL, - VERTEX_HAS_NO_PROPERTY, - NO_SCHEMA_DEFINED_FOR_LABEL, - VERTEX_PROPERTY_WRONG_TYPE - }; - - SchemaViolation(ValidationStatus status, LabelId label); - - SchemaViolation(ValidationStatus status, LabelId label, SchemaProperty violated_type); - - SchemaViolation(ValidationStatus status, LabelId label, SchemaProperty violated_type, - PropertyValue violated_property_value); - - ValidationStatus status; - LabelId label; - std::optional violated_type; - std::optional violated_property_value; + friend bool operator==(const SchemaProperty &lhs, const SchemaProperty &rhs); }; /// Structure that represents a collection of schemas /// Schema can be mapped under only one label => primary label class Schemas { public: - using Schema = std::pair>; using SchemasMap = std::unordered_map>; + using Schema = SchemasMap::value_type; using SchemasList = std::vector; Schemas() = default; @@ -75,7 +49,7 @@ class Schemas { [[nodiscard]] SchemasList ListSchemas() const; - [[nodiscard]] std::optional GetSchema(LabelId primary_label) const; + [[nodiscard]] const Schema *GetSchema(LabelId primary_label) const; // Returns true if it was successfully created or false if the schema // already exists @@ -85,8 +59,6 @@ class Schemas { // does not exist [[nodiscard]] bool DropSchema(LabelId label); - [[nodiscard]] std::optional ValidateVertex(LabelId primary_label, const Vertex &vertex); - private: SchemasMap schemas_; }; diff --git a/src/storage/v2/storage.cpp b/src/storage/v2/storage.cpp index bd95e4f9d..ff03c6c77 100644 --- a/src/storage/v2/storage.cpp +++ b/src/storage/v2/storage.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -26,18 +27,22 @@ #include "storage/v2/durability/snapshot.hpp" #include "storage/v2/durability/wal.hpp" #include "storage/v2/edge_accessor.hpp" +#include "storage/v2/id_types.hpp" #include "storage/v2/indices.hpp" #include "storage/v2/mvcc.hpp" #include "storage/v2/replication/config.hpp" #include "storage/v2/replication/enums.hpp" #include "storage/v2/replication/replication_persistence_helper.hpp" +#include "storage/v2/schema_validator.hpp" #include "storage/v2/schemas.hpp" #include "storage/v2/transaction.hpp" #include "storage/v2/vertex_accessor.hpp" +#include "utils/exceptions.hpp" #include "utils/file.hpp" #include "utils/logging.hpp" #include "utils/memory_tracker.hpp" #include "utils/message.hpp" +#include "utils/result.hpp" #include "utils/rw_lock.hpp" #include "utils/spin_lock.hpp" #include "utils/stat.hpp" @@ -71,9 +76,9 @@ std::string RegisterReplicaErrorToString(Storage::RegisterReplicaError error) { auto AdvanceToVisibleVertex(utils::SkipList::Iterator it, utils::SkipList::Iterator end, std::optional *vertex, Transaction *tx, View view, Indices *indices, - Constraints *constraints, Config::Items config) { + Constraints *constraints, Config::Items config, const SchemaValidator &schema_validator) { while (it != end) { - *vertex = VertexAccessor::Create(&*it, tx, indices, constraints, config, view); + *vertex = VertexAccessor::Create(&*it, tx, indices, constraints, config, schema_validator, view); if (!*vertex) { ++it; continue; @@ -86,14 +91,14 @@ auto AdvanceToVisibleVertex(utils::SkipList::Iterator it, utils::SkipLis AllVerticesIterable::Iterator::Iterator(AllVerticesIterable *self, utils::SkipList::Iterator it) : self_(self), it_(AdvanceToVisibleVertex(it, self->vertices_accessor_.end(), &self->vertex_, self->transaction_, self->view_, - self->indices_, self_->constraints_, self->config_)) {} + self->indices_, self_->constraints_, self->config_, *self->schema_validator_)) {} VertexAccessor AllVerticesIterable::Iterator::operator*() const { return *self_->vertex_; } AllVerticesIterable::Iterator &AllVerticesIterable::Iterator::operator++() { ++it_; it_ = AdvanceToVisibleVertex(it_, self_->vertices_accessor_.end(), &self_->vertex_, self_->transaction_, self_->view_, - self_->indices_, self_->constraints_, self_->config_); + self_->indices_, self_->constraints_, self_->config_, *self_->schema_validator_); return *this; } @@ -314,7 +319,8 @@ bool VerticesIterable::Iterator::operator==(const Iterator &other) const { } Storage::Storage(Config config) - : indices_(&constraints_, config.items), + : schema_validator_(schemas_), + indices_(&constraints_, config.items, schema_validator_), isolation_level_(config.transaction.isolation_level), config_(config), snapshot_directory_(config_.durability.storage_directory / durability::kSnapshotDirectory), @@ -486,7 +492,8 @@ Storage::Accessor::~Accessor() { FinalizeTransaction(); } -VertexAccessor Storage::Accessor::CreateVertex() { +// TODO Remove when import csv is fixed +[[deprecated]] VertexAccessor Storage::Accessor::CreateVertex() { OOMExceptionEnabler oom_exception; auto gid = storage_->vertex_id_.fetch_add(1, std::memory_order_acq_rel); auto acc = storage_->vertices_.access(); @@ -496,33 +503,69 @@ VertexAccessor Storage::Accessor::CreateVertex() { MG_ASSERT(it != acc.end(), "Invalid Vertex accessor!"); delta->prev.Set(&*it); - return {&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_}; + return {&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, storage_->schema_validator_}; } +// TODO Remove when replication is fixed VertexAccessor Storage::Accessor::CreateVertex(storage::Gid gid) { OOMExceptionEnabler oom_exception; // NOTE: When we update the next `vertex_id_` here we perform a RMW // (read-modify-write) operation that ISN'T atomic! But, that isn't an issue // because this function is only called from the replication delta applier - // that runs single-threadedly and while this instance is set-up to apply + // that runs single-threaded and while this instance is set-up to apply // threads (it is the replica), it is guaranteed that no other writes are // possible. storage_->vertex_id_.store(std::max(storage_->vertex_id_.load(std::memory_order_acquire), gid.AsUint() + 1), std::memory_order_release); auto acc = storage_->vertices_.access(); - auto delta = CreateDeleteObjectDelta(&transaction_); - auto [it, inserted] = acc.insert(Vertex{gid, delta}); + auto *delta = CreateDeleteObjectDelta(&transaction_); + auto [it, inserted] = acc.insert(Vertex{gid}); MG_ASSERT(inserted, "The vertex must be inserted here!"); MG_ASSERT(it != acc.end(), "Invalid Vertex accessor!"); delta->prev.Set(&*it); - return VertexAccessor(&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_); + return {&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, storage_->schema_validator_}; +} + +ResultSchema Storage::Accessor::CreateVertexAndValidate( + storage::LabelId primary_label, const std::vector &labels, + const std::vector> &properties) { + auto maybe_schema_violation = GetSchemaValidator().ValidateVertexCreate(primary_label, labels, properties); + if (maybe_schema_violation) { + return {std::move(*maybe_schema_violation)}; + } + OOMExceptionEnabler oom_exception; + auto gid = storage_->vertex_id_.fetch_add(1, std::memory_order_acq_rel); + auto acc = storage_->vertices_.access(); + auto *delta = CreateDeleteObjectDelta(&transaction_); + auto [it, inserted] = acc.insert(Vertex{storage::Gid::FromUint(gid), delta, primary_label}); + MG_ASSERT(inserted, "The vertex must be inserted here!"); + MG_ASSERT(it != acc.end(), "Invalid Vertex accessor!"); + delta->prev.Set(&*it); + + auto va = VertexAccessor{ + &*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, storage_->schema_validator_}; + for (const auto label : labels) { + const auto maybe_error = va.AddLabel(label); + if (maybe_error.HasError()) { + return {maybe_error.GetError()}; + } + } + // Set properties + for (auto [property_id, property_value] : properties) { + const auto maybe_error = va.SetProperty(property_id, property_value); + if (maybe_error.HasError()) { + return {maybe_error.GetError()}; + } + } + return va; } std::optional Storage::Accessor::FindVertex(Gid gid, View view) { auto acc = storage_->vertices_.access(); auto it = acc.find(gid); if (it == acc.end()) return std::nullopt; - return VertexAccessor::Create(&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, view); + return VertexAccessor::Create(&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, + storage_->schema_validator_, view); } Result> Storage::Accessor::DeleteVertex(VertexAccessor *vertex) { @@ -545,7 +588,7 @@ Result> Storage::Accessor::DeleteVertex(VertexAcce vertex_ptr->deleted = true; return std::make_optional(vertex_ptr, &transaction_, &storage_->indices_, &storage_->constraints_, - config_, true); + config_, storage_->schema_validator_, true); } Result>>> Storage::Accessor::DetachDeleteVertex( @@ -575,7 +618,7 @@ Result>>> Stor for (const auto &item : in_edges) { auto [edge_type, from_vertex, edge] = item; EdgeAccessor e(edge, edge_type, from_vertex, vertex_ptr, &transaction_, &storage_->indices_, - &storage_->constraints_, config_); + &storage_->constraints_, config_, storage_->schema_validator_); auto ret = DeleteEdge(&e); if (ret.HasError()) { MG_ASSERT(ret.GetError() == Error::SERIALIZATION_ERROR, "Invalid database state!"); @@ -589,7 +632,7 @@ Result>>> Stor for (const auto &item : out_edges) { auto [edge_type, to_vertex, edge] = item; EdgeAccessor e(edge, edge_type, vertex_ptr, to_vertex, &transaction_, &storage_->indices_, &storage_->constraints_, - config_); + config_, storage_->schema_validator_); auto ret = DeleteEdge(&e); if (ret.HasError()) { MG_ASSERT(ret.GetError() == Error::SERIALIZATION_ERROR, "Invalid database state!"); @@ -615,7 +658,8 @@ Result>>> Stor vertex_ptr->deleted = true; return std::make_optional( - VertexAccessor{vertex_ptr, &transaction_, &storage_->indices_, &storage_->constraints_, config_, true}, + VertexAccessor{vertex_ptr, &transaction_, &storage_->indices_, &storage_->constraints_, config_, + storage_->schema_validator_, true}, std::move(deleted_edges)); } @@ -675,7 +719,7 @@ Result Storage::Accessor::CreateEdge(VertexAccessor *from, VertexA storage_->edge_count_.fetch_add(1, std::memory_order_acq_rel); return EdgeAccessor(edge, edge_type, from_vertex, to_vertex, &transaction_, &storage_->indices_, - &storage_->constraints_, config_); + &storage_->constraints_, config_, storage_->schema_validator_); } Result Storage::Accessor::CreateEdge(VertexAccessor *from, VertexAccessor *to, EdgeTypeId edge_type, @@ -743,7 +787,7 @@ Result Storage::Accessor::CreateEdge(VertexAccessor *from, VertexA storage_->edge_count_.fetch_add(1, std::memory_order_acq_rel); return EdgeAccessor(edge, edge_type, from_vertex, to_vertex, &transaction_, &storage_->indices_, - &storage_->constraints_, config_); + &storage_->constraints_, config_, storage_->schema_validator_); } Result> Storage::Accessor::DeleteEdge(EdgeAccessor *edge) { @@ -827,7 +871,8 @@ Result> Storage::Accessor::DeleteEdge(EdgeAccessor * storage_->edge_count_.fetch_add(-1, std::memory_order_acq_rel); return std::make_optional(edge_ref, edge_type, from_vertex, to_vertex, &transaction_, - &storage_->indices_, &storage_->constraints_, config_, true); + &storage_->indices_, &storage_->constraints_, config_, + storage_->schema_validator_, true); } const std::string &Storage::Accessor::LabelToName(LabelId label) const { return storage_->LabelToName(label); } @@ -871,11 +916,11 @@ utils::BasicResult Storage::Accessor::Commit( auto validation_result = ValidateExistenceConstraints(*prev.vertex, storage_->constraints_); if (validation_result) { Abort(); - return *validation_result; + return {*validation_result}; } } - // Result of validating the vertex against unqiue constraints. It has to be + // Result of validating the vertex against unique constraints. It has to be // declared outside of the critical section scope because its value is // tested for Abort call which has to be done out of the scope. std::optional unique_constraint_violation; @@ -956,7 +1001,7 @@ utils::BasicResult Storage::Accessor::Commit( if (unique_constraint_violation) { Abort(); - return *unique_constraint_violation; + return {*unique_constraint_violation}; } } is_transaction_active_ = false; @@ -1257,6 +1302,8 @@ UniqueConstraints::DeletionStatus Storage::DropUniqueConstraint( return UniqueConstraints::DeletionStatus::SUCCESS; } +const SchemaValidator &Storage::Accessor::GetSchemaValidator() const { return storage_->schema_validator_; } + ConstraintsInfo Storage::ListAllConstraints() const { std::shared_lock storage_guard_(main_lock_); return {ListExistenceConstraints(constraints_), constraints_.unique_constraints.ListConstraints()}; @@ -1267,7 +1314,7 @@ SchemasInfo Storage::ListAllSchemas() const { return {schemas_.ListSchemas()}; } -std::optional Storage::GetSchema(const LabelId primary_label) const { +const Schemas::Schema *Storage::GetSchema(const LabelId primary_label) const { std::shared_lock storage_guard_(main_lock_); return schemas_.GetSchema(primary_label); } @@ -1294,21 +1341,22 @@ VerticesIterable Storage::Accessor::Vertices(LabelId label, View view) { } VerticesIterable Storage::Accessor::Vertices(LabelId label, PropertyId property, View view) { - return VerticesIterable(storage_->indices_.label_property_index.Vertices(label, property, std::nullopt, std::nullopt, - view, &transaction_)); + return VerticesIterable(storage_->indices_.label_property_index.Vertices( + label, property, std::nullopt, std::nullopt, view, &transaction_, storage_->schema_validator_)); } VerticesIterable Storage::Accessor::Vertices(LabelId label, PropertyId property, const PropertyValue &value, View view) { return VerticesIterable(storage_->indices_.label_property_index.Vertices( - label, property, utils::MakeBoundInclusive(value), utils::MakeBoundInclusive(value), view, &transaction_)); + label, property, utils::MakeBoundInclusive(value), utils::MakeBoundInclusive(value), view, &transaction_, + storage_->schema_validator_)); } VerticesIterable Storage::Accessor::Vertices(LabelId label, PropertyId property, const std::optional> &lower_bound, const std::optional> &upper_bound, View view) { - return VerticesIterable( - storage_->indices_.label_property_index.Vertices(label, property, lower_bound, upper_bound, view, &transaction_)); + return VerticesIterable(storage_->indices_.label_property_index.Vertices( + label, property, lower_bound, upper_bound, view, &transaction_, storage_->schema_validator_)); } Transaction Storage::CreateTransaction(IsolationLevel isolation_level) { @@ -1836,8 +1884,8 @@ utils::BasicResult Storage::CreateSnapshot() { // Create snapshot. durability::CreateSnapshot(&transaction, snapshot_directory_, wal_directory_, config_.durability.snapshot_retention_count, &vertices_, &edges_, &name_id_mapper_, - &indices_, &constraints_, config_.items, uuid_, epoch_id_, epoch_history_, - &file_retainer_); + &indices_, &constraints_, config_.items, schema_validator_, uuid_, epoch_id_, + epoch_history_, &file_retainer_); // Finalize snapshot transaction. commit_log_->MarkFinished(transaction.start_timestamp); diff --git a/src/storage/v2/storage.hpp b/src/storage/v2/storage.hpp index e0f93dc6a..4addce561 100644 --- a/src/storage/v2/storage.hpp +++ b/src/storage/v2/storage.hpp @@ -34,6 +34,7 @@ #include "storage/v2/name_id_mapper.hpp" #include "storage/v2/property_value.hpp" #include "storage/v2/result.hpp" +#include "storage/v2/schema_validator.hpp" #include "storage/v2/schemas.hpp" #include "storage/v2/transaction.hpp" #include "storage/v2/vertex.hpp" @@ -72,6 +73,7 @@ class AllVerticesIterable final { Indices *indices_; Constraints *constraints_; Config::Items config_; + const SchemaValidator *schema_validator_; std::optional vertex_; public: @@ -92,13 +94,15 @@ class AllVerticesIterable final { }; AllVerticesIterable(utils::SkipList::Accessor vertices_accessor, Transaction *transaction, View view, - Indices *indices, Constraints *constraints, Config::Items config) + Indices *indices, Constraints *constraints, Config::Items config, + SchemaValidator *schema_validator) : vertices_accessor_(std::move(vertices_accessor)), transaction_(transaction), view_(view), indices_(indices), constraints_(constraints), - config_(config) {} + config_(config), + schema_validator_(schema_validator) {} Iterator begin() { return Iterator(this, vertices_accessor_.begin()); } Iterator end() { return Iterator(this, vertices_accessor_.end()); } @@ -220,15 +224,21 @@ class Storage final { ~Accessor(); - /// @throw std::bad_alloc VertexAccessor CreateVertex(); + VertexAccessor CreateVertex(storage::Gid gid); + + /// @throw std::bad_alloc + ResultSchema CreateVertexAndValidate( + storage::LabelId primary_label, const std::vector &labels, + const std::vector> &properties); + std::optional FindVertex(Gid gid, View view); VerticesIterable Vertices(View view) { return VerticesIterable(AllVerticesIterable(storage_->vertices_.access(), &transaction_, view, - &storage_->indices_, &storage_->constraints_, - storage_->config_.items)); + &storage_->indices_, &storage_->constraints_, storage_->config_.items, + &storage_->schema_validator_)); } VerticesIterable Vertices(LabelId label, View view); @@ -317,6 +327,8 @@ class Storage final { storage_->constraints_.unique_constraints.ListConstraints()}; } + const SchemaValidator &GetSchemaValidator() const; + SchemasInfo ListAllSchemas() const { return {storage_->schemas_.ListSchemas()}; } void AdvanceCommand(); @@ -334,7 +346,7 @@ class Storage final { private: /// @throw std::bad_alloc - VertexAccessor CreateVertex(storage::Gid gid); + VertexAccessor CreateVertex(storage::Gid gid, storage::LabelId primary_label); /// @throw std::bad_alloc Result CreateEdge(VertexAccessor *from, VertexAccessor *to, EdgeTypeId edge_type, storage::Gid gid); @@ -417,7 +429,7 @@ class Storage final { SchemasInfo ListAllSchemas() const; - std::optional GetSchema(LabelId primary_label) const; + const Schemas::Schema *GetSchema(LabelId primary_label) const; bool CreateSchema(LabelId primary_label, const std::vector &schemas_types); @@ -524,6 +536,7 @@ class Storage final { NameIdMapper name_id_mapper_; + SchemaValidator schema_validator_; Constraints constraints_; Indices indices_; Schemas schemas_; diff --git a/src/storage/v2/vertex.hpp b/src/storage/v2/vertex.hpp index 83f517c46..c2a63144f 100644 --- a/src/storage/v2/vertex.hpp +++ b/src/storage/v2/vertex.hpp @@ -19,18 +19,39 @@ #include "storage/v2/edge_ref.hpp" #include "storage/v2/id_types.hpp" #include "storage/v2/property_store.hpp" +#include "utils/algorithm.hpp" #include "utils/spin_lock.hpp" namespace memgraph::storage { struct Vertex { - Vertex(Gid gid, Delta *delta) : gid(gid), deleted(false), delta(delta) { + Vertex(Gid gid, Delta *delta, LabelId primary_label) + : gid(gid), primary_label{primary_label}, deleted(false), delta(delta) { + MG_ASSERT(delta == nullptr || delta->action == Delta::Action::DELETE_OBJECT, + "Vertex must be created with an initial DELETE_OBJECT delta!"); + } + + // TODO remove this when import replication is solved + Vertex(Gid gid, LabelId primary_label) : gid(gid), primary_label{primary_label}, deleted(false) { + MG_ASSERT(delta == nullptr || delta->action == Delta::Action::DELETE_OBJECT, + "Vertex must be created with an initial DELETE_OBJECT delta!"); + } + + // TODO remove this when import csv is solved + [[deprecated]] Vertex(Gid gid, Delta *delta) : gid(gid), deleted(false), delta(delta) { + MG_ASSERT(delta == nullptr || delta->action == Delta::Action::DELETE_OBJECT, + "Vertex must be created with an initial DELETE_OBJECT delta!"); + } + + // TODO remove this when import replication is solved + [[deprecated]] explicit Vertex(Gid gid) : gid(gid), deleted(false) { MG_ASSERT(delta == nullptr || delta->action == Delta::Action::DELETE_OBJECT, "Vertex must be created with an initial DELETE_OBJECT delta!"); } Gid gid; + LabelId primary_label; std::vector labels; PropertyStore properties; @@ -52,4 +73,8 @@ inline bool operator<(const Vertex &first, const Vertex &second) { return first. inline bool operator==(const Vertex &first, const Gid &second) { return first.gid == second; } inline bool operator<(const Vertex &first, const Gid &second) { return first.gid < second; } +inline bool VertexHasLabel(const Vertex &vertex, const LabelId label) { + return vertex.primary_label == label || utils::Contains(vertex.labels, label); +} + } // namespace memgraph::storage diff --git a/src/storage/v2/vertex_accessor.cpp b/src/storage/v2/vertex_accessor.cpp index 05ba1ebcc..dacdedb1b 100644 --- a/src/storage/v2/vertex_accessor.cpp +++ b/src/storage/v2/vertex_accessor.cpp @@ -18,6 +18,8 @@ #include "storage/v2/indices.hpp" #include "storage/v2/mvcc.hpp" #include "storage/v2/property_value.hpp" +#include "storage/v2/schema_validator.hpp" +#include "storage/v2/vertex.hpp" #include "utils/logging.hpp" #include "utils/memory_tracker.hpp" @@ -61,12 +63,13 @@ std::pair IsVisible(Vertex *vertex, Transaction *transaction, View v } // namespace detail std::optional VertexAccessor::Create(Vertex *vertex, Transaction *transaction, Indices *indices, - Constraints *constraints, Config::Items config, View view) { + Constraints *constraints, Config::Items config, + const SchemaValidator &schema_validator, View view) { if (const auto [exists, deleted] = detail::IsVisible(vertex, transaction, view); !exists || deleted) { return std::nullopt; } - return VertexAccessor{vertex, transaction, indices, constraints, config}; + return VertexAccessor{vertex, transaction, indices, constraints, config, schema_validator}; } bool VertexAccessor::IsVisible(View view) const { @@ -93,6 +96,28 @@ Result VertexAccessor::AddLabel(LabelId label) { return true; } +storage::ResultSchema VertexAccessor::AddLabelAndValidate(LabelId label) { + if (const auto maybe_violation_error = vertex_validator_.ValidateAddLabel(label); maybe_violation_error) { + return {*maybe_violation_error}; + } + utils::MemoryTracker::OutOfMemoryExceptionEnabler oom_exception; + std::lock_guard guard(vertex_->lock); + + if (!PrepareForWrite(transaction_, vertex_)) return {Error::SERIALIZATION_ERROR}; + + if (vertex_->deleted) return {Error::DELETED_OBJECT}; + + if (std::find(vertex_->labels.begin(), vertex_->labels.end(), label) != vertex_->labels.end()) return false; + + CreateAndLinkDelta(transaction_, vertex_, Delta::RemoveLabelTag(), label); + + vertex_->labels.push_back(label); + + UpdateOnAddLabel(indices_, label, vertex_, *transaction_); + + return true; +} + Result VertexAccessor::RemoveLabel(LabelId label) { std::lock_guard guard(vertex_->lock); @@ -110,6 +135,26 @@ Result VertexAccessor::RemoveLabel(LabelId label) { return true; } +ResultSchema VertexAccessor::RemoveLabelAndValidate(LabelId label) { + if (const auto maybe_violation_error = vertex_validator_.ValidateRemoveLabel(label); maybe_violation_error) { + return {*maybe_violation_error}; + } + std::lock_guard guard(vertex_->lock); + + if (!PrepareForWrite(transaction_, vertex_)) return {Error::SERIALIZATION_ERROR}; + + if (vertex_->deleted) return {Error::DELETED_OBJECT}; + + auto it = std::find(vertex_->labels.begin(), vertex_->labels.end(), label); + if (it == vertex_->labels.end()) return false; + + CreateAndLinkDelta(transaction_, vertex_, Delta::AddLabelTag(), label); + + std::swap(*it, *vertex_->labels.rbegin()); + vertex_->labels.pop_back(); + return true; +} + Result VertexAccessor::HasLabel(LabelId label, View view) const { bool exists = true; bool deleted = false; @@ -118,7 +163,7 @@ Result VertexAccessor::HasLabel(LabelId label, View view) const { { std::lock_guard guard(vertex_->lock); deleted = vertex_->deleted; - has_label = std::find(vertex_->labels.begin(), vertex_->labels.end(), label) != vertex_->labels.end(); + has_label = VertexHasLabel(*vertex_, label); delta = vertex_->delta; } ApplyDeltasForRead(transaction_, delta, view, [&exists, &deleted, &has_label, label](const Delta &delta) { @@ -158,6 +203,40 @@ Result VertexAccessor::HasLabel(LabelId label, View view) const { return has_label; } +Result VertexAccessor::PrimaryLabel(const View view) const { + bool exists = true; + bool deleted = false; + Delta *delta = nullptr; + { + std::lock_guard guard(vertex_->lock); + deleted = vertex_->deleted; + delta = vertex_->delta; + } + ApplyDeltasForRead(transaction_, delta, view, [&exists, &deleted](const Delta &delta) { + switch (delta.action) { + case Delta::Action::DELETE_OBJECT: { + exists = false; + break; + } + case Delta::Action::RECREATE_OBJECT: { + deleted = false; + break; + } + case Delta::Action::ADD_LABEL: + case Delta::Action::REMOVE_LABEL: + case Delta::Action::SET_PROPERTY: + case Delta::Action::ADD_IN_EDGE: + case Delta::Action::ADD_OUT_EDGE: + case Delta::Action::REMOVE_IN_EDGE: + case Delta::Action::REMOVE_OUT_EDGE: + break; + } + }); + if (!exists) return Error::NONEXISTENT_OBJECT; + if (!for_deleted_ && deleted) return Error::DELETED_OBJECT; + return vertex_->primary_label; +} + Result> VertexAccessor::Labels(View view) const { bool exists = true; bool deleted = false; @@ -230,6 +309,36 @@ Result VertexAccessor::SetProperty(PropertyId property, const Pro return std::move(current_value); } +ResultSchema VertexAccessor::SetPropertyAndValidate(PropertyId property, const PropertyValue &value) { + if (auto maybe_violation_error = vertex_validator_.ValidatePropertyUpdate(property); maybe_violation_error) { + return {*maybe_violation_error}; + } + utils::MemoryTracker::OutOfMemoryExceptionEnabler oom_exception; + std::lock_guard guard(vertex_->lock); + + if (!PrepareForWrite(transaction_, vertex_)) { + return {Error::SERIALIZATION_ERROR}; + } + + if (vertex_->deleted) { + return {Error::DELETED_OBJECT}; + } + + auto current_value = vertex_->properties.GetProperty(property); + // We could skip setting the value if the previous one is the same to the new + // one. This would save some memory as a delta would not be created as well as + // avoid copying the value. The reason we are not doing that is because the + // current code always follows the logical pattern of "create a delta" and + // "modify in-place". Additionally, the created delta will make other + // transactions get a SERIALIZATION_ERROR. + CreateAndLinkDelta(transaction_, vertex_, Delta::SetPropertyTag(), property, current_value); + vertex_->properties.SetProperty(property, value); + + UpdateOnSetProperty(indices_, property, value, vertex_, *transaction_); + + return std::move(current_value); +} + Result> VertexAccessor::ClearProperties() { std::lock_guard guard(vertex_->lock); @@ -414,7 +523,8 @@ Result> VertexAccessor::InEdges(View view, const std:: ret.reserve(in_edges.size()); for (const auto &item : in_edges) { const auto &[edge_type, from_vertex, edge] = item; - ret.emplace_back(edge, edge_type, from_vertex, vertex_, transaction_, indices_, constraints_, config_); + ret.emplace_back(edge, edge_type, from_vertex, vertex_, transaction_, indices_, constraints_, config_, + *vertex_validator_.schema_validator); } return std::move(ret); } @@ -494,7 +604,8 @@ Result> VertexAccessor::OutEdges(View view, const std: ret.reserve(out_edges.size()); for (const auto &item : out_edges) { const auto &[edge_type, to_vertex, edge] = item; - ret.emplace_back(edge, edge_type, vertex_, to_vertex, transaction_, indices_, constraints_, config_); + ret.emplace_back(edge, edge_type, vertex_, to_vertex, transaction_, indices_, constraints_, config_, + *vertex_validator_.schema_validator); } return std::move(ret); } @@ -575,4 +686,21 @@ Result VertexAccessor::OutDegree(View view) const { return degree; } +VertexAccessor::VertexValidator::VertexValidator(const SchemaValidator &schema_validator, const Vertex *vertex) + : schema_validator{&schema_validator}, vertex_{vertex} {} + +[[nodiscard]] std::optional VertexAccessor::VertexValidator::ValidatePropertyUpdate( + PropertyId property_id) const { + MG_ASSERT(vertex_ != nullptr, "Cannot validate vertex which is nullptr"); + return schema_validator->ValidatePropertyUpdate(vertex_->primary_label, property_id); +}; + +[[nodiscard]] std::optional VertexAccessor::VertexValidator::ValidateAddLabel(LabelId label) const { + return schema_validator->ValidateLabelUpdate(label); +} + +[[nodiscard]] std::optional VertexAccessor::VertexValidator::ValidateRemoveLabel(LabelId label) const { + return schema_validator->ValidateLabelUpdate(label); +} + } // namespace memgraph::storage diff --git a/src/storage/v2/vertex_accessor.hpp b/src/storage/v2/vertex_accessor.hpp index 840eec910..eed4cb7e5 100644 --- a/src/storage/v2/vertex_accessor.hpp +++ b/src/storage/v2/vertex_accessor.hpp @@ -13,6 +13,8 @@ #include +#include "storage/v2/id_types.hpp" +#include "storage/v2/schema_validator.hpp" #include "storage/v2/vertex.hpp" #include "storage/v2/config.hpp" @@ -29,20 +31,39 @@ struct Constraints; class VertexAccessor final { private: + struct VertexValidator { + // TODO(jbajic) Beware since vertex is pointer it will be accessed even as nullptr + explicit VertexValidator(const SchemaValidator &schema_validator, const Vertex *vertex); + + [[nodiscard]] std::optional ValidatePropertyUpdate(PropertyId property_id) const; + + [[nodiscard]] std::optional ValidateAddLabel(LabelId label) const; + + [[nodiscard]] std::optional ValidateRemoveLabel(LabelId label) const; + + const SchemaValidator *schema_validator; + + private: + const Vertex *vertex_; + }; friend class Storage; public: + // Be careful when using VertexAccessor since it can be instantiated with + // nullptr values VertexAccessor(Vertex *vertex, Transaction *transaction, Indices *indices, Constraints *constraints, - Config::Items config, bool for_deleted = false) + Config::Items config, const SchemaValidator &schema_validator, bool for_deleted = false) : vertex_(vertex), transaction_(transaction), indices_(indices), constraints_(constraints), config_(config), + vertex_validator_{schema_validator, vertex}, for_deleted_(for_deleted) {} static std::optional Create(Vertex *vertex, Transaction *transaction, Indices *indices, - Constraints *constraints, Config::Items config, View view); + Constraints *constraints, Config::Items config, + const SchemaValidator &schema_validator, View view); /// @return true if the object is visible from the current transaction bool IsVisible(View view) const; @@ -52,11 +73,23 @@ class VertexAccessor final { /// @throw std::bad_alloc Result AddLabel(LabelId label); + /// Add a label and return `true` if insertion took place. + /// `false` is returned if the label already existed, or SchemaViolation + /// if adding the label has violated one of the schema constraints. + /// @throw std::bad_alloc + storage::ResultSchema AddLabelAndValidate(LabelId label); + /// Remove a label and return `true` if deletion took place. /// `false` is returned if the vertex did not have a label already. /// @throw std::bad_alloc Result RemoveLabel(LabelId label); + /// Remove a label and return `true` if deletion took place. + /// `false` is returned if the vertex did not have a label already. or SchemaViolation + /// if adding the label has violated one of the schema constraints. + /// @throw std::bad_alloc + ResultSchema RemoveLabelAndValidate(LabelId label); + Result HasLabel(LabelId label, View view) const; /// @throw std::bad_alloc @@ -64,10 +97,16 @@ class VertexAccessor final { /// std::vector::max_size(). Result> Labels(View view) const; + Result PrimaryLabel(View view) const; + /// Set a property value and return the old value. /// @throw std::bad_alloc Result SetProperty(PropertyId property, const PropertyValue &value); + /// Set a property value and return the old value or error. + /// @throw std::bad_alloc + ResultSchema SetPropertyAndValidate(PropertyId property, const PropertyValue &value); + /// Remove all properties and return the values of the removed properties. /// @throw std::bad_alloc Result> ClearProperties(); @@ -96,6 +135,8 @@ class VertexAccessor final { Gid Gid() const noexcept { return vertex_->gid; } + const SchemaValidator *GetSchemaValidator() const; + bool operator==(const VertexAccessor &other) const noexcept { return vertex_ == other.vertex_ && transaction_ == other.transaction_; } @@ -107,6 +148,7 @@ class VertexAccessor final { Indices *indices_; Constraints *constraints_; Config::Items config_; + VertexValidator vertex_validator_; // if the accessor was created for a deleted vertex. // Accessor behaves differently for some methods based on this diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 319627715..78e1e366d 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -75,46 +75,42 @@ target_link_libraries(${test_prefix}bfs_single_node mg-query) add_unit_test(cypher_main_visitor.cpp) target_link_libraries(${test_prefix}cypher_main_visitor mg-query) -add_unit_test(interpreter.cpp ${CMAKE_SOURCE_DIR}/src/glue/communication.cpp) -target_link_libraries(${test_prefix}interpreter mg-communication mg-query) - +# add_unit_test(interpreter.cpp ${CMAKE_SOURCE_DIR}/src/glue/communication.cpp) +# target_link_libraries(${test_prefix}interpreter mg-communication mg-query) add_unit_test(plan_pretty_print.cpp) target_link_libraries(${test_prefix}plan_pretty_print mg-query) add_unit_test(query_cost_estimator.cpp) target_link_libraries(${test_prefix}query_cost_estimator mg-query) -add_unit_test(query_dump.cpp ${CMAKE_SOURCE_DIR}/src/glue/communication.cpp) -target_link_libraries(${test_prefix}query_dump mg-communication mg-query) - +# TODO Fix later on +# add_unit_test(query_dump.cpp ${CMAKE_SOURCE_DIR}/src/glue/communication.cpp) +# target_link_libraries(${test_prefix}query_dump mg-communication mg-query) add_unit_test(query_expression_evaluator.cpp) target_link_libraries(${test_prefix}query_expression_evaluator mg-query) add_unit_test(query_plan.cpp) target_link_libraries(${test_prefix}query_plan mg-query) -add_unit_test(query_plan_accumulate_aggregate.cpp) -target_link_libraries(${test_prefix}query_plan_accumulate_aggregate mg-query) +# add_unit_test(query_plan_accumulate_aggregate.cpp) +# target_link_libraries(${test_prefix}query_plan_accumulate_aggregate mg-query) -add_unit_test(query_plan_bag_semantics.cpp) -target_link_libraries(${test_prefix}query_plan_bag_semantics mg-query) +# add_unit_test(query_plan_bag_semantics.cpp) +# target_link_libraries(${test_prefix}query_plan_bag_semantics mg-query) -add_unit_test(query_plan_create_set_remove_delete.cpp) -target_link_libraries(${test_prefix}query_plan_create_set_remove_delete mg-query) - -add_unit_test(query_plan_edge_cases.cpp ${CMAKE_SOURCE_DIR}/src/glue/communication.cpp) -target_link_libraries(${test_prefix}query_plan_edge_cases mg-communication mg-query) - -add_unit_test(query_plan_match_filter_return.cpp) -target_link_libraries(${test_prefix}query_plan_match_filter_return mg-query) +# add_unit_test(query_plan_create_set_remove_delete.cpp) +# target_link_libraries(${test_prefix}query_plan_create_set_remove_delete mg-query) +# add_unit_test(query_plan_edge_cases.cpp ${CMAKE_SOURCE_DIR}/src/glue/communication.cpp) +# target_link_libraries(${test_prefix}query_plan_edge_cases mg-communication mg-query) +# add_unit_test(query_plan_match_filter_return.cpp) +# target_link_libraries(${test_prefix}query_plan_match_filter_return mg-query) add_unit_test(query_plan_read_write_typecheck.cpp ${CMAKE_SOURCE_DIR}/src/query/plan/read_write_type_checker.cpp) target_link_libraries(${test_prefix}query_plan_read_write_typecheck mg-query) -add_unit_test(query_plan_v2_create_set_remove_delete.cpp) -target_link_libraries(${test_prefix}query_plan_v2_create_set_remove_delete mg-query) - +# add_unit_test(query_plan_v2_create_set_remove_delete.cpp) +# target_link_libraries(${test_prefix}query_plan_v2_create_set_remove_delete mg-query) add_unit_test(query_pretty_print.cpp) target_link_libraries(${test_prefix}query_pretty_print mg-query) @@ -282,42 +278,67 @@ target_link_libraries(${test_prefix}commit_log_v2 gflags mg-utils mg-storage-v2) add_unit_test(property_value_v2.cpp) target_link_libraries(${test_prefix}property_value_v2 mg-storage-v2 mg-utils) -add_unit_test(storage_v2.cpp) -target_link_libraries(${test_prefix}storage_v2 mg-storage-v2 storage_test_utils) - +# add_unit_test(storage_v2.cpp) +# target_link_libraries(${test_prefix}storage_v2 mg-storage-v2 storage_test_utils) add_unit_test(storage_v2_constraints.cpp) target_link_libraries(${test_prefix}storage_v2_constraints mg-storage-v2) add_unit_test(storage_v2_decoder_encoder.cpp) target_link_libraries(${test_prefix}storage_v2_decoder_encoder mg-storage-v2) -add_unit_test(storage_v2_durability.cpp) -target_link_libraries(${test_prefix}storage_v2_durability mg-storage-v2) - -add_unit_test(storage_v2_edge.cpp) -target_link_libraries(${test_prefix}storage_v2_edge mg-storage-v2) +# add_unit_test(storage_v2_durability.cpp) +# target_link_libraries(${test_prefix}storage_v2_durability mg-storage-v2) +# add_unit_test(storage_v2_edge.cpp) +# target_link_libraries(${test_prefix}storage_v2_edge mg-storage-v2) add_unit_test(storage_v2_gc.cpp) target_link_libraries(${test_prefix}storage_v2_gc mg-storage-v2) -add_unit_test(storage_v2_indices.cpp) -target_link_libraries(${test_prefix}storage_v2_indices mg-storage-v2 mg-utils) - +# add_unit_test(storage_v2_indices.cpp) +# target_link_libraries(${test_prefix}storage_v2_indices mg-storage-v2 mg-utils) add_unit_test(storage_v2_name_id_mapper.cpp) target_link_libraries(${test_prefix}storage_v2_name_id_mapper mg-storage-v2) add_unit_test(storage_v2_property_store.cpp) target_link_libraries(${test_prefix}storage_v2_property_store mg-storage-v2 fmt) -add_unit_test(storage_v2_wal_file.cpp) -target_link_libraries(${test_prefix}storage_v2_wal_file mg-storage-v2 fmt) - -add_unit_test(storage_v2_replication.cpp) -target_link_libraries(${test_prefix}storage_v2_replication mg-storage-v2 fmt) +# add_unit_test(storage_v2_wal_file.cpp) +# target_link_libraries(${test_prefix}storage_v2_wal_file mg-storage-v2 fmt) +# add_unit_test(storage_v2_replication.cpp) +# target_link_libraries(${test_prefix}storage_v2_replication mg-storage-v2 fmt) add_unit_test(storage_v2_isolation_level.cpp) target_link_libraries(${test_prefix}storage_v2_isolation_level mg-storage-v2) +# Test mg-storage-v3 + +add_unit_test(storage_v3.cpp) +target_link_libraries(${test_prefix}storage_v3 mg-storage-v3) + +add_unit_test(storage_v3_schema.cpp) +target_link_libraries(${test_prefix}storage_v3_schema mg-storage-v2) + +add_unit_test(interpreter_v2.cpp ${CMAKE_SOURCE_DIR}/src/glue/communication.cpp) +target_link_libraries(${test_prefix}interpreter_v2 mg-storage-v2 mg-query mg-communication) + +add_unit_test(query_v2_query_plan_accumulate_aggregate.cpp) +target_link_libraries(${test_prefix}query_v2_query_plan_accumulate_aggregate mg-query) + +add_unit_test(query_v2_query_plan_create_set_remove_delete.cpp) +target_link_libraries(${test_prefix}query_v2_query_plan_create_set_remove_delete mg-query) + +add_unit_test(query_v2_query_plan_bag_semantics.cpp) +target_link_libraries(${test_prefix}query_v2_query_plan_bag_semantics mg-query) + +add_unit_test(query_v2_query_plan_edge_cases.cpp ${CMAKE_SOURCE_DIR}/src/glue/communication.cpp) +target_link_libraries(${test_prefix}query_v2_query_plan_edge_cases mg-communication mg-query) + +add_unit_test(query_v2_query_plan_v2_create_set_remove_delete.cpp) +target_link_libraries(${test_prefix}query_v2_query_plan_v2_create_set_remove_delete mg-query) + +add_unit_test(query_v2_query_plan_match_filter_return.cpp) +target_link_libraries(${test_prefix}query_v2_query_plan_match_filter_return mg-query) + add_unit_test(replication_persistence_helper.cpp) target_link_libraries(${test_prefix}replication_persistence_helper mg-storage-v2) @@ -361,7 +382,3 @@ find_package(Boost REQUIRED) add_unit_test(websocket.cpp) target_link_libraries(${test_prefix}websocket mg-communication Boost::headers) - -# Test storage-v3 -add_unit_test(storage_v3.cpp) -target_link_libraries(${test_prefix}storage_v3 mg-storage-v3) diff --git a/tests/unit/interpreter.cpp b/tests/unit/interpreter.cpp index 466079578..f5a3e03b3 100644 --- a/tests/unit/interpreter.cpp +++ b/tests/unit/interpreter.cpp @@ -10,10 +10,8 @@ // licenses/APL.txt. #include -#include #include #include -#include #include "communication/bolt/v1/value.hpp" #include "communication/result_stream_faker.hpp" @@ -40,11 +38,6 @@ auto ToEdgeList(const memgraph::communication::bolt::Value &v) { list.push_back(x.ValueEdge()); } return list; -} - -auto StringToUnorderedSet(const std::string &element) { - const auto element_split = memgraph::utils::Split(element, ", "); - return std::unordered_set(element_split.begin(), element_split.end()); }; struct InterpreterFaker { @@ -1472,145 +1465,3 @@ TEST_F(InterpreterTest, LoadCsvClauseNotification) { "conversion functions such as ToInteger, ToFloat, ToBoolean etc."); ASSERT_EQ(notification["description"].ValueString(), ""); } - -TEST_F(InterpreterTest, CreateSchemaMulticommandTransaction) { - Interpret("BEGIN"); - ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING, age INTEGER)"), - memgraph::query::ConstraintInMulticommandTxException); - Interpret("ROLLBACK"); -} - -TEST_F(InterpreterTest, ShowSchemasMulticommandTransaction) { - Interpret("BEGIN"); - ASSERT_THROW(Interpret("SHOW SCHEMAS"), memgraph::query::ConstraintInMulticommandTxException); - Interpret("ROLLBACK"); -} - -TEST_F(InterpreterTest, ShowSchemaMulticommandTransaction) { - Interpret("BEGIN"); - ASSERT_THROW(Interpret("SHOW SCHEMA ON :label"), memgraph::query::ConstraintInMulticommandTxException); - Interpret("ROLLBACK"); -} - -TEST_F(InterpreterTest, DropSchemaMulticommandTransaction) { - Interpret("BEGIN"); - ASSERT_THROW(Interpret("DROP SCHEMA ON :label"), memgraph::query::ConstraintInMulticommandTxException); - Interpret("ROLLBACK"); -} - -TEST_F(InterpreterTest, SchemaTestCreateAndShow) { - // Empty schema type map should result with syntax exception. - ASSERT_THROW(Interpret("CREATE SCHEMA ON :label();"), memgraph::query::SyntaxException); - - // Duplicate properties are should also cause an exception - ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING, name STRING);"), memgraph::query::SemanticException); - ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING, name INTEGER);"), memgraph::query::SemanticException); - - { - // Cannot create same schema twice - Interpret("CREATE SCHEMA ON :label(name STRING, age INTEGER)"); - ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING);"), memgraph::query::QueryException); - } - // Show schema - { - auto stream = Interpret("SHOW SCHEMA ON :label"); - ASSERT_EQ(stream.GetHeader().size(), 2U); - const auto &header = stream.GetHeader(); - ASSERT_EQ(header[0], "property_name"); - ASSERT_EQ(header[1], "property_type"); - ASSERT_EQ(stream.GetResults().size(), 2U); - std::unordered_map result_table{{"age", "Integer"}, {"name", "String"}}; - - const auto &result = stream.GetResults().front(); - ASSERT_EQ(result.size(), 2U); - const auto key1 = result[0].ValueString(); - ASSERT_TRUE(result_table.contains(key1)); - ASSERT_EQ(result[1].ValueString(), result_table[key1]); - - const auto &result2 = stream.GetResults().front(); - ASSERT_EQ(result2.size(), 2U); - const auto key2 = result2[0].ValueString(); - ASSERT_TRUE(result_table.contains(key2)); - ASSERT_EQ(result[1].ValueString(), result_table[key2]); - } - // Create Another Schema - Interpret("CREATE SCHEMA ON :label2(place STRING, dur DURATION)"); - - // Show schemas - { - auto stream = Interpret("SHOW SCHEMAS"); - ASSERT_EQ(stream.GetHeader().size(), 2U); - const auto &header = stream.GetHeader(); - ASSERT_EQ(header[0], "label"); - ASSERT_EQ(header[1], "primary_key"); - ASSERT_EQ(stream.GetResults().size(), 2U); - std::unordered_map> result_table{ - {"label", {"name::String", "age::Integer"}}, {"label2", {"place::String", "dur::Duration"}}}; - - const auto &result = stream.GetResults().front(); - ASSERT_EQ(result.size(), 2U); - const auto key1 = result[0].ValueString(); - ASSERT_TRUE(result_table.contains(key1)); - const auto primary_key_split = StringToUnorderedSet(result[1].ValueString()); - ASSERT_EQ(primary_key_split.size(), 2); - ASSERT_TRUE(primary_key_split == result_table[key1]) << "actual value is: " << result[1].ValueString(); - - const auto &result2 = stream.GetResults().front(); - ASSERT_EQ(result2.size(), 2U); - const auto key2 = result2[0].ValueString(); - ASSERT_TRUE(result_table.contains(key2)); - const auto primary_key_split2 = StringToUnorderedSet(result2[1].ValueString()); - ASSERT_EQ(primary_key_split2.size(), 2); - ASSERT_TRUE(primary_key_split2 == result_table[key2]) << "Real value is: " << result[1].ValueString(); - } -} - -TEST_F(InterpreterTest, SchemaTestCreateDropAndShow) { - Interpret("CREATE SCHEMA ON :label(name STRING, age INTEGER)"); - // Wrong syntax for dropping schema. - ASSERT_THROW(Interpret("DROP SCHEMA ON :label();"), memgraph::query::SyntaxException); - // Cannot drop non existant schema. - ASSERT_THROW(Interpret("DROP SCHEMA ON :label1;"), memgraph::query::QueryException); - - // Create Schema and Drop - auto get_number_of_schemas = [this]() { - auto stream = Interpret("SHOW SCHEMAS"); - return stream.GetResults().size(); - }; - - ASSERT_EQ(get_number_of_schemas(), 1); - Interpret("CREATE SCHEMA ON :label1(name STRING, age INTEGER)"); - ASSERT_EQ(get_number_of_schemas(), 2); - Interpret("CREATE SCHEMA ON :label2(name STRING, sex BOOL)"); - ASSERT_EQ(get_number_of_schemas(), 3); - Interpret("DROP SCHEMA ON :label1"); - ASSERT_EQ(get_number_of_schemas(), 2); - Interpret("CREATE SCHEMA ON :label3(name STRING, birthday LOCALDATETIME)"); - ASSERT_EQ(get_number_of_schemas(), 3); - Interpret("DROP SCHEMA ON :label2"); - ASSERT_EQ(get_number_of_schemas(), 2); - Interpret("CREATE SCHEMA ON :label4(name STRING, age DURATION)"); - ASSERT_EQ(get_number_of_schemas(), 3); - Interpret("DROP SCHEMA ON :label3"); - ASSERT_EQ(get_number_of_schemas(), 2); - Interpret("DROP SCHEMA ON :label"); - ASSERT_EQ(get_number_of_schemas(), 1); - - // Show schemas - auto stream = Interpret("SHOW SCHEMAS"); - ASSERT_EQ(stream.GetHeader().size(), 2U); - const auto &header = stream.GetHeader(); - ASSERT_EQ(header[0], "label"); - ASSERT_EQ(header[1], "primary_key"); - ASSERT_EQ(stream.GetResults().size(), 1U); - std::unordered_map> result_table{ - {"label4", {"name::String", "age::Duration"}}}; - - const auto &result = stream.GetResults().front(); - ASSERT_EQ(result.size(), 2U); - const auto key1 = result[0].ValueString(); - ASSERT_TRUE(result_table.contains(key1)); - const auto primary_key_split = StringToUnorderedSet(result[1].ValueString()); - ASSERT_EQ(primary_key_split.size(), 2); - ASSERT_TRUE(primary_key_split == result_table[key1]); -} diff --git a/tests/unit/interpreter_v2.cpp b/tests/unit/interpreter_v2.cpp new file mode 100644 index 000000000..1c4da6865 --- /dev/null +++ b/tests/unit/interpreter_v2.cpp @@ -0,0 +1,1636 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include +#include +#include +#include +#include + +#include "communication/bolt/v1/value.hpp" +#include "communication/result_stream_faker.hpp" +#include "glue/communication.hpp" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "query/auth_checker.hpp" +#include "query/config.hpp" +#include "query/exceptions.hpp" +#include "query/interpreter.hpp" +#include "query/stream.hpp" +#include "query/typed_value.hpp" +#include "query_common.hpp" +#include "storage/v2/isolation_level.hpp" +#include "storage/v2/property_value.hpp" +#include "utils/csv_parsing.hpp" +#include "utils/logging.hpp" + +namespace { + +auto ToEdgeList(const memgraph::communication::bolt::Value &v) { + std::vector list; + for (auto x : v.ValueList()) { + list.push_back(x.ValueEdge()); + } + return list; +} + +auto StringToUnorderedSet(const std::string &element) { + const auto element_split = memgraph::utils::Split(element, ", "); + return std::unordered_set(element_split.begin(), element_split.end()); +}; + +struct InterpreterFaker { + InterpreterFaker(memgraph::storage::Storage *db, const memgraph::query::InterpreterConfig config, + const std::filesystem::path &data_directory) + : interpreter_context(db, config, data_directory), interpreter(&interpreter_context) { + interpreter_context.auth_checker = &auth_checker; + } + + auto Prepare(const std::string &query, const std::map ¶ms = {}) { + ResultStreamFaker stream(interpreter_context.db); + + const auto [header, _, qid] = interpreter.Prepare(query, params, nullptr); + stream.Header(header); + return std::make_pair(std::move(stream), qid); + } + + void Pull(ResultStreamFaker *stream, std::optional n = {}, std::optional qid = {}) { + const auto summary = interpreter.Pull(stream, n, qid); + stream->Summary(summary); + } + + /** + * Execute the given query and commit the transaction. + * + * Return the query stream. + */ + auto Interpret(const std::string &query, const std::map ¶ms = {}) { + auto prepare_result = Prepare(query, params); + + auto &stream = prepare_result.first; + auto summary = interpreter.Pull(&stream, {}, prepare_result.second); + stream.Summary(summary); + + return std::move(stream); + } + + memgraph::query::AllowEverythingAuthChecker auth_checker; + memgraph::query::InterpreterContext interpreter_context; + memgraph::query::Interpreter interpreter; +}; + +} // namespace + +// TODO: This is not a unit test, but tests/integration dir is chaotic at the +// moment. After tests refactoring is done, move/rename this. + +class InterpreterTest : public ::testing::Test { + protected: + memgraph::storage::Storage db_; + std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_interpreter"}; + + InterpreterFaker default_interpreter{&db_, {}, data_directory}; + + auto Prepare(const std::string &query, const std::map ¶ms = {}) { + return default_interpreter.Prepare(query, params); + } + + void Pull(ResultStreamFaker *stream, std::optional n = {}, std::optional qid = {}) { + default_interpreter.Pull(stream, n, qid); + } + + auto Interpret(const std::string &query, const std::map ¶ms = {}) { + return default_interpreter.Interpret(query, params); + } +}; + +TEST_F(InterpreterTest, MultiplePulls) { + { + auto [stream, qid] = Prepare("UNWIND [1,2,3,4,5] as n RETURN n"); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "n"); + Pull(&stream, 1); + ASSERT_EQ(stream.GetSummary().count("has_more"), 1); + ASSERT_TRUE(stream.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueInt(), 1); + Pull(&stream, 2); + ASSERT_EQ(stream.GetSummary().count("has_more"), 1); + ASSERT_TRUE(stream.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream.GetResults().size(), 3U); + ASSERT_EQ(stream.GetResults()[1][0].ValueInt(), 2); + ASSERT_EQ(stream.GetResults()[2][0].ValueInt(), 3); + Pull(&stream); + ASSERT_EQ(stream.GetSummary().count("has_more"), 1); + ASSERT_FALSE(stream.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream.GetResults().size(), 5U); + ASSERT_EQ(stream.GetResults()[3][0].ValueInt(), 4); + ASSERT_EQ(stream.GetResults()[4][0].ValueInt(), 5); + } +} + +// Run query with different ast twice to see if query executes correctly when +// ast is read from cache. +TEST_F(InterpreterTest, AstCache) { + { + auto stream = Interpret("RETURN 2 + 3"); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "2 + 3"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueInt(), 5); + } + { + // Cached ast, different literals. + auto stream = Interpret("RETURN 5 + 4"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueInt(), 9); + } + { + // Different ast (because of different types). + auto stream = Interpret("RETURN 5.5 + 4"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueDouble(), 9.5); + } + { + // Cached ast, same literals. + auto stream = Interpret("RETURN 2 + 3"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueInt(), 5); + } + { + // Cached ast, different literals. + auto stream = Interpret("RETURN 10.5 + 1"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueDouble(), 11.5); + } + { + // Cached ast, same literals, different whitespaces. + auto stream = Interpret("RETURN 10.5 + 1"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueDouble(), 11.5); + } + { + // Cached ast, same literals, different named header. + auto stream = Interpret("RETURN 10.5+1"); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "10.5+1"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueDouble(), 11.5); + } +} + +// Run query with same ast multiple times with different parameters. +TEST_F(InterpreterTest, Parameters) { + { + auto stream = Interpret("RETURN $2 + $`a b`", {{"2", memgraph::storage::PropertyValue(10)}, + {"a b", memgraph::storage::PropertyValue(15)}}); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "$2 + $`a b`"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueInt(), 25); + } + { + // Not needed parameter. + auto stream = Interpret("RETURN $2 + $`a b`", {{"2", memgraph::storage::PropertyValue(10)}, + {"a b", memgraph::storage::PropertyValue(15)}, + {"c", memgraph::storage::PropertyValue(10)}}); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "$2 + $`a b`"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueInt(), 25); + } + { + // Cached ast, different parameters. + auto stream = Interpret("RETURN $2 + $`a b`", {{"2", memgraph::storage::PropertyValue("da")}, + {"a b", memgraph::storage::PropertyValue("ne")}}); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueString(), "dane"); + } + { + // Non-primitive literal. + auto stream = + Interpret("RETURN $2", {{"2", memgraph::storage::PropertyValue(std::vector{ + memgraph::storage::PropertyValue(5), memgraph::storage::PropertyValue(2), + memgraph::storage::PropertyValue(3)})}}); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + auto result = memgraph::query::test_common::ToIntList(memgraph::glue::ToTypedValue(stream.GetResults()[0][0])); + ASSERT_THAT(result, testing::ElementsAre(5, 2, 3)); + } + { + // Cached ast, unprovided parameter. + ASSERT_THROW(Interpret("RETURN $2 + $`a b`", {{"2", memgraph::storage::PropertyValue("da")}, + {"ab", memgraph::storage::PropertyValue("ne")}}), + memgraph::query::UnprovidedParameterError); + } +} + +// Run CREATE/MATCH/MERGE queries with property map +TEST_F(InterpreterTest, ParametersAsPropertyMap) { + { + EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :label(name STRING, age INTEGER)")); + std::map property_map{}; + property_map["name"] = memgraph::storage::PropertyValue("name1"); + property_map["age"] = memgraph::storage::PropertyValue(25); + auto stream = + Interpret("CREATE (n:label $prop) RETURN n", { + {"prop", memgraph::storage::PropertyValue(property_map)}, + }); + ASSERT_EQ(stream.GetHeader().size(), 1U); + ASSERT_EQ(stream.GetHeader()[0], "n"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + auto result = stream.GetResults()[0][0].ValueVertex(); + EXPECT_EQ(result.properties["name"].ValueString(), "name1"); + EXPECT_EQ(result.properties["age"].ValueInt(), 25); + } + { + EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :Person(name STRING, age INTEGER)")); + std::map property_map{}; + property_map["name"] = memgraph::storage::PropertyValue("name1"); + property_map["age"] = memgraph::storage::PropertyValue(25); + EXPECT_NO_THROW(Interpret("CREATE (:Person {name: 'test', age: 30})")); + auto stream = Interpret("MATCH (m:Person) CREATE (n:Person $prop) RETURN n", + { + {"prop", memgraph::storage::PropertyValue(property_map)}, + }); + ASSERT_EQ(stream.GetHeader().size(), 1U); + ASSERT_EQ(stream.GetHeader()[0], "n"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + auto result = stream.GetResults()[0][0].ValueVertex(); + EXPECT_EQ(result.properties["name"].ValueString(), "name1"); + EXPECT_EQ(result.properties["age"].ValueInt(), 25); + } + { + EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :L1(name STRING)")); + std::map property_map{}; + property_map["name"] = memgraph::storage::PropertyValue("name1"); + property_map["weight"] = memgraph::storage::PropertyValue(121); + auto stream = Interpret("CREATE (:L1 {name: 'name1'})-[r:TO $prop]->(:L1 {name: 'name2'}) RETURN r", + { + {"prop", memgraph::storage::PropertyValue(property_map)}, + }); + ASSERT_EQ(stream.GetHeader().size(), 1U); + ASSERT_EQ(stream.GetHeader()[0], "r"); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + auto result = stream.GetResults()[0][0].ValueEdge(); + EXPECT_EQ(result.properties["name"].ValueString(), "name1"); + EXPECT_EQ(result.properties["weight"].ValueInt(), 121); + } + { + std::map property_map{}; + property_map["name"] = memgraph::storage::PropertyValue("name1"); + property_map["age"] = memgraph::storage::PropertyValue(15); + ASSERT_THROW(Interpret("MATCH (n $prop) RETURN n", + { + {"prop", memgraph::storage::PropertyValue(property_map)}, + }), + memgraph::query::SemanticException); + } + { + EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :L2(name STRING, age INTEGER)")); + std::map property_map{}; + property_map["name"] = memgraph::storage::PropertyValue("name1"); + property_map["age"] = memgraph::storage::PropertyValue(15); + ASSERT_THROW(Interpret("MERGE (n:L2 $prop) RETURN n", + { + {"prop", memgraph::storage::PropertyValue(property_map)}, + }), + memgraph::query::SemanticException); + } +} + +// Test bfs end to end. +TEST_F(InterpreterTest, Bfs) { + srand(0); + const auto kNumLevels = 10; + const auto kNumNodesPerLevel = 100; + const auto kNumEdgesPerNode = 100; + const auto kNumUnreachableNodes = 1000; + const auto kNumUnreachableEdges = 100000; + const auto kReachable = "reachable"; + const auto kId = "id"; + + std::vector> levels(kNumLevels); + int id = 0; + + // Set up. + { + auto storage_dba = db_.Access(); + memgraph::query::DbAccessor dba(&storage_dba); + auto add_node = [&](int level, bool reachable) { + auto node = dba.InsertVertex(); + MG_ASSERT(node.SetProperty(dba.NameToProperty(kId), memgraph::storage::PropertyValue(id++)).HasValue()); + MG_ASSERT( + node.SetProperty(dba.NameToProperty(kReachable), memgraph::storage::PropertyValue(reachable)).HasValue()); + levels[level].push_back(node); + return node; + }; + + auto add_edge = [&](auto &v1, auto &v2, bool reachable) { + auto edge = dba.InsertEdge(&v1, &v2, dba.NameToEdgeType("edge")); + MG_ASSERT( + edge->SetProperty(dba.NameToProperty(kReachable), memgraph::storage::PropertyValue(reachable)).HasValue()); + }; + + // Add source node. + add_node(0, true); + + // Add reachable nodes. + for (int i = 1; i < kNumLevels; ++i) { + for (int j = 0; j < kNumNodesPerLevel; ++j) { + auto node = add_node(i, true); + for (int k = 0; k < kNumEdgesPerNode; ++k) { + auto &node2 = levels[i - 1][rand() % levels[i - 1].size()]; + add_edge(node2, node, true); + } + } + } + + // Add unreachable nodes. + for (int i = 0; i < kNumUnreachableNodes; ++i) { + auto node = add_node(rand() % kNumLevels, // Not really important. + false); + for (int j = 0; j < kNumEdgesPerNode; ++j) { + auto &level = levels[rand() % kNumLevels]; + auto &node2 = level[rand() % level.size()]; + add_edge(node2, node, true); + add_edge(node, node2, true); + } + } + + // Add unreachable edges. + for (int i = 0; i < kNumUnreachableEdges; ++i) { + auto &level1 = levels[rand() % kNumLevels]; + auto &node1 = level1[rand() % level1.size()]; + auto &level2 = levels[rand() % kNumLevels]; + auto &node2 = level2[rand() % level2.size()]; + add_edge(node1, node2, false); + } + + ASSERT_FALSE(dba.Commit().HasError()); + } + + auto stream = Interpret( + "MATCH (n {id: 0})-[r *bfs..5 (e, n | n.reachable and " + "e.reachable)]->(m) RETURN n, r, m"); + + ASSERT_EQ(stream.GetHeader().size(), 3U); + EXPECT_EQ(stream.GetHeader()[0], "n"); + EXPECT_EQ(stream.GetHeader()[1], "r"); + EXPECT_EQ(stream.GetHeader()[2], "m"); + ASSERT_EQ(stream.GetResults().size(), 5 * kNumNodesPerLevel); + + auto dba = db_.Access(); + int expected_level = 1; + int remaining_nodes_in_level = kNumNodesPerLevel; + std::unordered_set matched_ids; + + for (const auto &result : stream.GetResults()) { + const auto &begin = result[0].ValueVertex(); + const auto &edges = ToEdgeList(result[1]); + const auto &end = result[2].ValueVertex(); + + // Check that path is of expected length. Returned paths should be from + // shorter to longer ones. + EXPECT_EQ(edges.size(), expected_level); + // Check that starting node is correct. + EXPECT_EQ(edges.front().from, begin.id); + EXPECT_EQ(begin.properties.at(kId).ValueInt(), 0); + for (int i = 1; i < static_cast(edges.size()); ++i) { + // Check that edges form a connected path. + EXPECT_EQ(edges[i - 1].to.AsInt(), edges[i].from.AsInt()); + } + auto matched_id = end.properties.at(kId).ValueInt(); + EXPECT_EQ(edges.back().to, end.id); + // Check that we didn't match that node already. + EXPECT_TRUE(matched_ids.insert(matched_id).second); + // Check that shortest path was found. + EXPECT_TRUE(matched_id > kNumNodesPerLevel * (expected_level - 1) && + matched_id <= kNumNodesPerLevel * expected_level); + if (!--remaining_nodes_in_level) { + remaining_nodes_in_level = kNumNodesPerLevel; + ++expected_level; + } + } +} + +// Test shortest path end to end. +TEST_F(InterpreterTest, ShortestPath) { + EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :A(x INTEGER)")); + EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :B(x INTEGER)")); + EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :C(x INTEGER)")); + const auto test_shortest_path = [this](const bool use_duration) { + const auto get_weight = [use_duration](const auto value) { + return fmt::format(fmt::runtime(use_duration ? "DURATION('PT{}S')" : "{}"), value); + }; + + Interpret( + fmt::format("CREATE (n:A {{x: 1}}), (m:B {{x: 2}}), (l:C {{x: 1}}), (n)-[:r1 {{w: {} " + "}}]->(m)-[:r2 {{w: {}}}]->(l), (n)-[:r3 {{w: {}}}]->(l)", + get_weight(1), get_weight(2), get_weight(4))); + + auto stream = Interpret("MATCH (n)-[e *wshortest 5 (e, n | e.w) ]->(m) return e"); + + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "e"); + ASSERT_EQ(stream.GetResults().size(), 3U); + + auto dba = db_.Access(); + std::vector> expected_results{{"r1"}, {"r2"}, {"r1", "r2"}}; + + for (const auto &result : stream.GetResults()) { + const auto &edges = ToEdgeList(result[0]); + + std::vector datum; + datum.reserve(edges.size()); + + for (const auto &edge : edges) { + datum.push_back(edge.type); + } + + bool any_match = false; + for (const auto &expected : expected_results) { + if (expected == datum) { + any_match = true; + break; + } + } + + EXPECT_TRUE(any_match); + } + + Interpret("MATCH (n) DETACH DELETE n"); + }; + + static constexpr bool kUseNumeric{false}; + static constexpr bool kUseDuration{true}; + { + SCOPED_TRACE("Test with numeric values"); + test_shortest_path(kUseNumeric); + } + { + SCOPED_TRACE("Test with Duration values"); + test_shortest_path(kUseDuration); + } +} + +TEST_F(InterpreterTest, CreateLabelIndexInMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("CREATE INDEX ON :X"), memgraph::query::IndexInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, CreateLabelPropertyIndexInMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("CREATE INDEX ON :X(y)"), memgraph::query::IndexInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, CreateExistenceConstraintInMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("CREATE CONSTRAINT ON (n:A) ASSERT EXISTS (n.a)"), + memgraph::query::ConstraintInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, CreateUniqueConstraintInMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("CREATE CONSTRAINT ON (n:A) ASSERT n.a, n.b IS UNIQUE"), + memgraph::query::ConstraintInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, ShowIndexInfoInMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("SHOW INDEX INFO"), memgraph::query::InfoInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, ShowConstraintInfoInMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("SHOW CONSTRAINT INFO"), memgraph::query::InfoInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, ShowStorageInfoInMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("SHOW STORAGE INFO"), memgraph::query::InfoInMulticommandTxException); + Interpret("ROLLBACK"); +} + +// // NOLINTNEXTLINE(hicpp-special-member-functions) +TEST_F(InterpreterTest, ExistenceConstraintTest) { + ASSERT_NO_THROW(Interpret("CREATE SCHEMA ON :A(a INTEGER);")); + + Interpret("CREATE CONSTRAINT ON (n:A) ASSERT EXISTS (n.b);"); + Interpret("CREATE (:A{a: 3, b:1})"); + Interpret("CREATE (:A{a: 3, b:2})"); + ASSERT_THROW(Interpret("CREATE (:A {a: 12})"), memgraph::query::QueryException); + Interpret("MATCH (n:A{a:3, b: 2}) SET n.b=5"); + Interpret("CREATE (:A{a:2, b: 3})"); + Interpret("MATCH (n:A{a:3, b: 1}) DETACH DELETE n"); + Interpret("CREATE (n:A{a:2, b: 3})"); + ASSERT_THROW(Interpret("CREATE CONSTRAINT ON (n:A) ASSERT EXISTS (n.c);"), memgraph::query::QueryRuntimeException); +} + +TEST_F(InterpreterTest, UniqueConstraintTest) { + ASSERT_NO_THROW(Interpret("CREATE SCHEMA ON :A(a INTEGER);")); + + // Empty property list should result with syntax exception. + ASSERT_THROW(Interpret("CREATE CONSTRAINT ON (n:A) ASSERT IS UNIQUE;"), memgraph::query::SyntaxException); + ASSERT_THROW(Interpret("DROP CONSTRAINT ON (n:A) ASSERT IS UNIQUE;"), memgraph::query::SyntaxException); + + // Too large list of properties should also result with syntax exception. + { + std::stringstream stream; + stream << " ON (n:A) ASSERT "; + for (size_t i = 0; i < 33; ++i) { + if (i > 0) stream << ", "; + stream << "n.prop" << i; + } + stream << " IS UNIQUE;"; + std::string create_query = "CREATE CONSTRAINT" + stream.str(); + std::string drop_query = "DROP CONSTRAINT" + stream.str(); + ASSERT_THROW(Interpret(create_query), memgraph::query::SyntaxException); + ASSERT_THROW(Interpret(drop_query), memgraph::query::SyntaxException); + } + + // Providing property list with duplicates results with syntax exception. + ASSERT_THROW(Interpret("CREATE CONSTRAINT ON (n:A) ASSERT n.a, n.b, n.a IS UNIQUE;"), + memgraph::query::SyntaxException); + ASSERT_THROW(Interpret("DROP CONSTRAINT ON (n:A) ASSERT n.a, n.b, n.a IS UNIQUE;"), memgraph::query::SyntaxException); + + // Commit of vertex should fail if a constraint is violated. + Interpret("CREATE CONSTRAINT ON (n:A) ASSERT n.a, n.b IS UNIQUE;"); + Interpret("CREATE (:A{a:1, b:2})"); + Interpret("CREATE (:A{a:1, b:3})"); + ASSERT_THROW(Interpret("CREATE (:A{a:1, b:2})"), memgraph::query::QueryException); + + // Attempt to create a constraint should fail if it's violated. + Interpret("CREATE (:A{a:1, c:2})"); + Interpret("CREATE (:A{a:1, c:2})"); + ASSERT_THROW(Interpret("CREATE CONSTRAINT ON (n:A) ASSERT n.a, n.c IS UNIQUE;"), + memgraph::query::QueryRuntimeException); + + Interpret("MATCH (n:A{a:2, b:2}) SET n.a=1"); + Interpret("CREATE (:A{a:2})"); + Interpret("MATCH (n:A{a:2}) DETACH DELETE n"); + Interpret("CREATE (n:A{a:2})"); + + // Show constraint info. + { + auto stream = Interpret("SHOW CONSTRAINT INFO"); + ASSERT_EQ(stream.GetHeader().size(), 3U); + const auto &header = stream.GetHeader(); + ASSERT_EQ(header[0], "constraint type"); + ASSERT_EQ(header[1], "label"); + ASSERT_EQ(header[2], "properties"); + ASSERT_EQ(stream.GetResults().size(), 1U); + const auto &result = stream.GetResults().front(); + ASSERT_EQ(result.size(), 3U); + ASSERT_EQ(result[0].ValueString(), "unique"); + ASSERT_EQ(result[1].ValueString(), "A"); + const auto &properties = result[2].ValueList(); + ASSERT_EQ(properties.size(), 2U); + ASSERT_EQ(properties[0].ValueString(), "a"); + ASSERT_EQ(properties[1].ValueString(), "b"); + } + + // Drop constraint. + Interpret("DROP CONSTRAINT ON (n:A) ASSERT n.a, n.b IS UNIQUE;"); + // Removing the same constraint twice should not throw any exception. + Interpret("DROP CONSTRAINT ON (n:A) ASSERT n.a, n.b IS UNIQUE;"); +} + +TEST_F(InterpreterTest, ExplainQuery) { + const auto &interpreter_context = default_interpreter.interpreter_context; + + EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); + auto stream = Interpret("EXPLAIN MATCH (n) RETURN *;"); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader().front(), "QUERY PLAN"); + std::vector expected_rows{" * Produce {n}", " * ScanAll (n)", " * Once"}; + ASSERT_EQ(stream.GetResults().size(), expected_rows.size()); + auto expected_it = expected_rows.begin(); + for (const auto &row : stream.GetResults()) { + ASSERT_EQ(row.size(), 1U); + EXPECT_EQ(row.front().ValueString(), *expected_it); + ++expected_it; + } + // We should have a plan cache for MATCH ... + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + // We should have AST cache for EXPLAIN ... and for inner MATCH ... + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); + Interpret("MATCH (n) RETURN *;"); + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); +} + +TEST_F(InterpreterTest, ExplainQueryMultiplePulls) { + const auto &interpreter_context = default_interpreter.interpreter_context; + + EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); + auto [stream, qid] = Prepare("EXPLAIN MATCH (n) RETURN *;"); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader().front(), "QUERY PLAN"); + std::vector expected_rows{" * Produce {n}", " * ScanAll (n)", " * Once"}; + Pull(&stream, 1); + ASSERT_EQ(stream.GetResults().size(), 1); + auto expected_it = expected_rows.begin(); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + EXPECT_EQ(stream.GetResults()[0].front().ValueString(), *expected_it); + ++expected_it; + + Pull(&stream, 1); + ASSERT_EQ(stream.GetResults().size(), 2); + ASSERT_EQ(stream.GetResults()[1].size(), 1U); + EXPECT_EQ(stream.GetResults()[1].front().ValueString(), *expected_it); + ++expected_it; + + Pull(&stream); + ASSERT_EQ(stream.GetResults().size(), 3); + ASSERT_EQ(stream.GetResults()[2].size(), 1U); + EXPECT_EQ(stream.GetResults()[2].front().ValueString(), *expected_it); + // We should have a plan cache for MATCH ... + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + // We should have AST cache for EXPLAIN ... and for inner MATCH ... + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); + Interpret("MATCH (n) RETURN *;"); + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); +} + +TEST_F(InterpreterTest, ExplainQueryInMulticommandTransaction) { + const auto &interpreter_context = default_interpreter.interpreter_context; + + EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); + Interpret("BEGIN"); + auto stream = Interpret("EXPLAIN MATCH (n) RETURN *;"); + Interpret("COMMIT"); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader().front(), "QUERY PLAN"); + std::vector expected_rows{" * Produce {n}", " * ScanAll (n)", " * Once"}; + ASSERT_EQ(stream.GetResults().size(), expected_rows.size()); + auto expected_it = expected_rows.begin(); + for (const auto &row : stream.GetResults()) { + ASSERT_EQ(row.size(), 1U); + EXPECT_EQ(row.front().ValueString(), *expected_it); + ++expected_it; + } + // We should have a plan cache for MATCH ... + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + // We should have AST cache for EXPLAIN ... and for inner MATCH ... + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); + Interpret("MATCH (n) RETURN *;"); + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); +} + +TEST_F(InterpreterTest, ExplainQueryWithParams) { + const auto &interpreter_context = default_interpreter.interpreter_context; + + EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); + auto stream = + Interpret("EXPLAIN MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::PropertyValue(42)}}); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader().front(), "QUERY PLAN"); + std::vector expected_rows{" * Produce {n}", " * Filter", " * ScanAll (n)", " * Once"}; + ASSERT_EQ(stream.GetResults().size(), expected_rows.size()); + auto expected_it = expected_rows.begin(); + for (const auto &row : stream.GetResults()) { + ASSERT_EQ(row.size(), 1U); + EXPECT_EQ(row.front().ValueString(), *expected_it); + ++expected_it; + } + // We should have a plan cache for MATCH ... + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + // We should have AST cache for EXPLAIN ... and for inner MATCH ... + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); + Interpret("MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::PropertyValue("something else")}}); + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); +} + +TEST_F(InterpreterTest, ProfileQuery) { + const auto &interpreter_context = default_interpreter.interpreter_context; + + EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); + auto stream = Interpret("PROFILE MATCH (n) RETURN *;"); + std::vector expected_header{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}; + EXPECT_EQ(stream.GetHeader(), expected_header); + std::vector expected_rows{"* Produce", "* ScanAll", "* Once"}; + ASSERT_EQ(stream.GetResults().size(), expected_rows.size()); + auto expected_it = expected_rows.begin(); + for (const auto &row : stream.GetResults()) { + ASSERT_EQ(row.size(), 4U); + EXPECT_EQ(row.front().ValueString(), *expected_it); + ++expected_it; + } + // We should have a plan cache for MATCH ... + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + // We should have AST cache for PROFILE ... and for inner MATCH ... + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); + Interpret("MATCH (n) RETURN *;"); + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); +} + +TEST_F(InterpreterTest, ProfileQueryMultiplePulls) { + const auto &interpreter_context = default_interpreter.interpreter_context; + + EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); + auto [stream, qid] = Prepare("PROFILE MATCH (n) RETURN *;"); + std::vector expected_header{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}; + EXPECT_EQ(stream.GetHeader(), expected_header); + + std::vector expected_rows{"* Produce", "* ScanAll", "* Once"}; + auto expected_it = expected_rows.begin(); + + Pull(&stream, 1); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0].size(), 4U); + ASSERT_EQ(stream.GetResults()[0][0].ValueString(), *expected_it); + ++expected_it; + + Pull(&stream, 1); + ASSERT_EQ(stream.GetResults().size(), 2U); + ASSERT_EQ(stream.GetResults()[1].size(), 4U); + ASSERT_EQ(stream.GetResults()[1][0].ValueString(), *expected_it); + ++expected_it; + + Pull(&stream); + ASSERT_EQ(stream.GetResults().size(), 3U); + ASSERT_EQ(stream.GetResults()[2].size(), 4U); + ASSERT_EQ(stream.GetResults()[2][0].ValueString(), *expected_it); + + // We should have a plan cache for MATCH ... + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + // We should have AST cache for PROFILE ... and for inner MATCH ... + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); + Interpret("MATCH (n) RETURN *;"); + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); +} + +TEST_F(InterpreterTest, ProfileQueryInMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("PROFILE MATCH (n) RETURN *;"), memgraph::query::ProfileInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, ProfileQueryWithParams) { + const auto &interpreter_context = default_interpreter.interpreter_context; + + EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); + auto stream = + Interpret("PROFILE MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::PropertyValue(42)}}); + std::vector expected_header{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}; + EXPECT_EQ(stream.GetHeader(), expected_header); + std::vector expected_rows{"* Produce", "* Filter", "* ScanAll", "* Once"}; + ASSERT_EQ(stream.GetResults().size(), expected_rows.size()); + auto expected_it = expected_rows.begin(); + for (const auto &row : stream.GetResults()) { + ASSERT_EQ(row.size(), 4U); + EXPECT_EQ(row.front().ValueString(), *expected_it); + ++expected_it; + } + // We should have a plan cache for MATCH ... + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + // We should have AST cache for PROFILE ... and for inner MATCH ... + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); + Interpret("MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::PropertyValue("something else")}}); + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); +} + +TEST_F(InterpreterTest, ProfileQueryWithLiterals) { + const auto &interpreter_context = default_interpreter.interpreter_context; + ASSERT_NO_THROW(Interpret("CREATE SCHEMA ON :Node(id INTEGER)")); + + EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 1U); + auto stream = Interpret("PROFILE UNWIND range(1, 1000) AS x CREATE (:Node {id: x});", {}); + std::vector expected_header{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}; + EXPECT_EQ(stream.GetHeader(), expected_header); + std::vector expected_rows{"* CreateNode", "* Unwind", "* Once"}; + ASSERT_EQ(stream.GetResults().size(), expected_rows.size()); + auto expected_it = expected_rows.begin(); + for (const auto &row : stream.GetResults()) { + ASSERT_EQ(row.size(), 4U); + EXPECT_EQ(row.front().ValueString(), *expected_it); + ++expected_it; + } + // We should have a plan cache for UNWIND ... + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + // We should have AST cache for PROFILE ... and for inner UNWIND ... + EXPECT_EQ(interpreter_context.ast_cache.size(), 3U); + Interpret("UNWIND range(42, 4242) AS x CREATE (:Node {id: x});", {}); + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(interpreter_context.ast_cache.size(), 3U); +} + +TEST_F(InterpreterTest, Transactions) { + auto &interpreter = default_interpreter.interpreter; + { + ASSERT_THROW(interpreter.CommitTransaction(), memgraph::query::ExplicitTransactionUsageException); + ASSERT_THROW(interpreter.RollbackTransaction(), memgraph::query::ExplicitTransactionUsageException); + interpreter.BeginTransaction(); + ASSERT_THROW(interpreter.BeginTransaction(), memgraph::query::ExplicitTransactionUsageException); + auto [stream, qid] = Prepare("RETURN 2"); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "2"); + Pull(&stream, 1); + ASSERT_EQ(stream.GetSummary().count("has_more"), 1); + ASSERT_FALSE(stream.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueInt(), 2); + interpreter.CommitTransaction(); + } + { + interpreter.BeginTransaction(); + auto [stream, qid] = Prepare("RETURN 2"); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "2"); + Pull(&stream, 1); + ASSERT_EQ(stream.GetSummary().count("has_more"), 1); + ASSERT_FALSE(stream.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream.GetResults()[0].size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueInt(), 2); + interpreter.RollbackTransaction(); + } +} + +TEST_F(InterpreterTest, Qid) { + auto &interpreter = default_interpreter.interpreter; + { + interpreter.BeginTransaction(); + auto [stream, qid] = Prepare("RETURN 2"); + ASSERT_TRUE(qid); + ASSERT_THROW(Pull(&stream, {}, *qid + 1), memgraph::query::InvalidArgumentsException); + interpreter.RollbackTransaction(); + } + { + interpreter.BeginTransaction(); + auto [stream1, qid1] = Prepare("UNWIND(range(1,3)) as n RETURN n"); + ASSERT_TRUE(qid1); + ASSERT_EQ(stream1.GetHeader().size(), 1U); + EXPECT_EQ(stream1.GetHeader()[0], "n"); + + auto [stream2, qid2] = Prepare("UNWIND(range(4,6)) as n RETURN n"); + ASSERT_TRUE(qid2); + ASSERT_EQ(stream2.GetHeader().size(), 1U); + EXPECT_EQ(stream2.GetHeader()[0], "n"); + + Pull(&stream1, 1, qid1); + ASSERT_EQ(stream1.GetSummary().count("has_more"), 1); + ASSERT_TRUE(stream1.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream1.GetResults().size(), 1U); + ASSERT_EQ(stream1.GetResults()[0].size(), 1U); + ASSERT_EQ(stream1.GetResults()[0][0].ValueInt(), 1); + + auto [stream3, qid3] = Prepare("UNWIND(range(7,9)) as n RETURN n"); + ASSERT_TRUE(qid3); + ASSERT_EQ(stream3.GetHeader().size(), 1U); + EXPECT_EQ(stream3.GetHeader()[0], "n"); + + Pull(&stream2, {}, qid2); + ASSERT_EQ(stream2.GetSummary().count("has_more"), 1); + ASSERT_FALSE(stream2.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream2.GetResults().size(), 3U); + ASSERT_EQ(stream2.GetResults()[0].size(), 1U); + ASSERT_EQ(stream2.GetResults()[0][0].ValueInt(), 4); + ASSERT_EQ(stream2.GetResults()[1][0].ValueInt(), 5); + ASSERT_EQ(stream2.GetResults()[2][0].ValueInt(), 6); + + Pull(&stream3, 1, qid3); + ASSERT_EQ(stream3.GetSummary().count("has_more"), 1); + ASSERT_TRUE(stream3.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream3.GetResults().size(), 1U); + ASSERT_EQ(stream3.GetResults()[0].size(), 1U); + ASSERT_EQ(stream3.GetResults()[0][0].ValueInt(), 7); + + Pull(&stream1, {}, qid1); + ASSERT_EQ(stream1.GetSummary().count("has_more"), 1); + ASSERT_FALSE(stream1.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream1.GetResults().size(), 3U); + ASSERT_EQ(stream1.GetResults()[1].size(), 1U); + ASSERT_EQ(stream1.GetResults()[1][0].ValueInt(), 2); + ASSERT_EQ(stream1.GetResults()[2][0].ValueInt(), 3); + + Pull(&stream3); + ASSERT_EQ(stream3.GetSummary().count("has_more"), 1); + ASSERT_FALSE(stream3.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream3.GetResults().size(), 3U); + ASSERT_EQ(stream3.GetResults()[1].size(), 1U); + ASSERT_EQ(stream3.GetResults()[1][0].ValueInt(), 8); + ASSERT_EQ(stream3.GetResults()[2][0].ValueInt(), 9); + + interpreter.CommitTransaction(); + } +} + +namespace { +// copied from utils_csv_parsing.cpp - tmp dir management and csv file writer +class TmpDirManager final { + public: + explicit TmpDirManager(const std::string_view directory) + : tmp_dir_{std::filesystem::temp_directory_path() / directory} { + CreateDir(); + } + ~TmpDirManager() { Clear(); } + + const std::filesystem::path &Path() const { return tmp_dir_; } + + private: + std::filesystem::path tmp_dir_; + + void CreateDir() { + if (!std::filesystem::exists(tmp_dir_)) { + std::filesystem::create_directory(tmp_dir_); + } + } + + void Clear() { + if (!std::filesystem::exists(tmp_dir_)) return; + std::filesystem::remove_all(tmp_dir_); + } +}; + +class FileWriter { + public: + explicit FileWriter(const std::filesystem::path path) { stream_.open(path); } + + FileWriter(const FileWriter &) = delete; + FileWriter &operator=(const FileWriter &) = delete; + + FileWriter(FileWriter &&) = delete; + FileWriter &operator=(FileWriter &&) = delete; + + void Close() { stream_.close(); } + + size_t WriteLine(const std::string_view line) { + if (!stream_.is_open()) { + return 0; + } + + stream_ << line << std::endl; + + // including the newline character + return line.size() + 1; + } + + private: + std::ofstream stream_; +}; + +std::string CreateRow(const std::vector &columns, const std::string_view delim) { + return memgraph::utils::Join(columns, delim); +} +} // namespace + +TEST_F(InterpreterTest, LoadCsvClause) { + auto dir_manager = TmpDirManager("csv_directory"); + const auto csv_path = dir_manager.Path() / "file.csv"; + auto writer = FileWriter(csv_path); + + const std::string delimiter{"|"}; + + const std::vector header{"A", "B", "C"}; + writer.WriteLine(CreateRow(header, delimiter)); + + const std::vector good_columns_1{"a", "b", "c"}; + writer.WriteLine(CreateRow(good_columns_1, delimiter)); + + const std::vector bad_columns{"\"\"1", "2", "3"}; + writer.WriteLine(CreateRow(bad_columns, delimiter)); + + const std::vector good_columns_2{"d", "e", "f"}; + writer.WriteLine(CreateRow(good_columns_2, delimiter)); + + writer.Close(); + + { + const std::string query = fmt::format(R"(LOAD CSV FROM "{}" WITH HEADER IGNORE BAD DELIMITER "{}" AS x RETURN + x.A)", + csv_path.string(), delimiter); + auto [stream, qid] = Prepare(query); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "x.A"); + + Pull(&stream, 1); + ASSERT_EQ(stream.GetSummary().count("has_more"), 1); + ASSERT_TRUE(stream.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream.GetResults().size(), 1U); + ASSERT_EQ(stream.GetResults()[0][0].ValueString(), "a"); + + Pull(&stream, 1); + ASSERT_EQ(stream.GetSummary().count("has_more"), 1); + ASSERT_FALSE(stream.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream.GetResults().size(), 2U); + ASSERT_EQ(stream.GetResults()[1][0].ValueString(), "d"); + } + + { + const std::string query = fmt::format(R"(LOAD CSV FROM "{}" WITH HEADER IGNORE BAD DELIMITER "{}" AS x RETURN + x.C)", + csv_path.string(), delimiter); + auto [stream, qid] = Prepare(query); + ASSERT_EQ(stream.GetHeader().size(), 1U); + EXPECT_EQ(stream.GetHeader()[0], "x.C"); + + Pull(&stream); + ASSERT_EQ(stream.GetSummary().count("has_more"), 1); + ASSERT_FALSE(stream.GetSummary().at("has_more").ValueBool()); + ASSERT_EQ(stream.GetResults().size(), 2U); + ASSERT_EQ(stream.GetResults()[0][0].ValueString(), "c"); + ASSERT_EQ(stream.GetResults()[1][0].ValueString(), "f"); + } +} + +TEST_F(InterpreterTest, CacheableQueries) { + const auto &interpreter_context = default_interpreter.interpreter_context; + // This should be cached + { + SCOPED_TRACE("Cacheable query"); + Interpret("RETURN 1"); + EXPECT_EQ(interpreter_context.ast_cache.size(), 1U); + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + } + + { + SCOPED_TRACE("Uncacheable query"); + // Queries which are calling procedure should not be cached because the + // result signature could be changed + Interpret("CALL mg.load_all()"); + EXPECT_EQ(interpreter_context.ast_cache.size(), 1U); + EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); + } +} + +TEST_F(InterpreterTest, AllowLoadCsvConfig) { + const auto check_load_csv_queries = [&](const bool allow_load_csv) { + TmpDirManager directory_manager{"allow_load_csv"}; + const auto csv_path = directory_manager.Path() / "file.csv"; + auto writer = FileWriter(csv_path); + const std::vector data{"A", "B", "C"}; + writer.WriteLine(CreateRow(data, ",")); + writer.Close(); + + const std::array queries = { + fmt::format("LOAD CSV FROM \"{}\" WITH HEADER AS row RETURN row", csv_path.string()), + "CREATE TRIGGER trigger ON CREATE BEFORE COMMIT EXECUTE LOAD CSV FROM 'file.csv' WITH HEADER AS row RETURN " + "row"}; + + InterpreterFaker interpreter_faker{&db_, {.query = {.allow_load_csv = allow_load_csv}}, directory_manager.Path()}; + for (const auto &query : queries) { + if (allow_load_csv) { + SCOPED_TRACE(fmt::format("'{}' should not throw because LOAD CSV is allowed", query)); + ASSERT_NO_THROW(interpreter_faker.Interpret(query)); + } else { + SCOPED_TRACE(fmt::format("'{}' should throw becuase LOAD CSV is not allowed", query)); + ASSERT_THROW(interpreter_faker.Interpret(query), memgraph::utils::BasicException); + } + SCOPED_TRACE(fmt::format("Normal query should not throw (allow_load_csv: {})", allow_load_csv)); + ASSERT_NO_THROW(interpreter_faker.Interpret("RETURN 1")); + } + }; + + check_load_csv_queries(true); + check_load_csv_queries(false); +} + +void AssertAllValuesAreZero(const std::map &map, + const std::vector &exceptions) { + for (const auto &[key, value] : map) { + if (const auto it = std::find(exceptions.begin(), exceptions.end(), key); it != exceptions.end()) continue; + ASSERT_EQ(value.ValueInt(), 0) << "Value " << key << " actual: " << value.ValueInt() << ", expected 0"; + } +} + +TEST_F(InterpreterTest, ExecutionStatsIsValid) { + { + auto [stream, qid] = Prepare("MATCH (n) DELETE n;"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("stats"), 0); + } + { + EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :L1(name STRING)")); + std::array stats_keys{"nodes-created", "nodes-deleted", "relationships-created", "relationships-deleted", + "properties-set", "labels-added", "labels-removed"}; + auto [stream, qid] = Prepare("CREATE (:L1 {name: 'name1'});"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("stats"), 1); + ASSERT_TRUE(stream.GetSummary().at("stats").IsMap()); + auto stats = stream.GetSummary().at("stats").ValueMap(); + ASSERT_TRUE( + std::all_of(stats_keys.begin(), stats_keys.end(), [&stats](const auto &key) { return stats.contains(key); })); + AssertAllValuesAreZero(stats, {"nodes-created"}); + } +} + +TEST_F(InterpreterTest, ExecutionStatsValues) { + EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :L1(name STRING)")); + { + auto [stream, qid] = + Prepare("CREATE (:L1{name: 'name1'}),(:L1{name: 'name2'}),(:L1{name: 'name3'}),(:L1{name: 'name4'});"); + + Pull(&stream); + auto stats = stream.GetSummary().at("stats").ValueMap(); + ASSERT_EQ(stats["nodes-created"].ValueInt(), 4); + AssertAllValuesAreZero(stats, {"nodes-created", "labels-added"}); + } + { + auto [stream, qid] = Prepare("MATCH (n) DELETE n;"); + Pull(&stream); + + auto stats = stream.GetSummary().at("stats").ValueMap(); + ASSERT_EQ(stats["nodes-deleted"].ValueInt(), 4); + AssertAllValuesAreZero(stats, {"nodes-deleted"}); + } + { + auto [stream, qid] = + Prepare("CREATE (n:L1 {name: 'name5'})-[:TO]->(m:L1{name: 'name6'}), (n)-[:TO]->(m), (n)-[:TO]->(m);"); + + Pull(&stream); + + auto stats = stream.GetSummary().at("stats").ValueMap(); + ASSERT_EQ(stats["nodes-created"].ValueInt(), 2); + ASSERT_EQ(stats["relationships-created"].ValueInt(), 3); + AssertAllValuesAreZero(stats, {"nodes-created", "relationships-created"}); + } + { + auto [stream, qid] = Prepare("MATCH (n) DETACH DELETE n;"); + Pull(&stream); + + auto stats = stream.GetSummary().at("stats").ValueMap(); + ASSERT_EQ(stats["nodes-deleted"].ValueInt(), 2); + ASSERT_EQ(stats["relationships-deleted"].ValueInt(), 3); + AssertAllValuesAreZero(stats, {"nodes-deleted", "relationships-deleted"}); + } + { + auto [stream, qid] = Prepare("CREATE (n:L1 {name: 'name7'}) SET n:L2:L3:L4"); + Pull(&stream); + + auto stats = stream.GetSummary().at("stats").ValueMap(); + ASSERT_EQ(stats["nodes-created"].ValueInt(), 1); + ASSERT_EQ(stats["labels-added"].ValueInt(), 3); + AssertAllValuesAreZero(stats, {"nodes-created", "labels-added"}); + } + { + auto [stream, qid] = Prepare("MATCH (n:L1) SET n.name2='test';"); + Pull(&stream); + + auto stats = stream.GetSummary().at("stats").ValueMap(); + ASSERT_EQ(stats["properties-set"].ValueInt(), 1); + AssertAllValuesAreZero(stats, {"properties-set"}); + } +} + +TEST_F(InterpreterTest, NotificationsValidStructure) { + { + auto [stream, qid] = Prepare("MATCH (n) DELETE n;"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 0); + } + { + auto [stream, qid] = Prepare("CREATE INDEX ON :Person(id);"); + Pull(&stream); + + // Assert notifications list + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + ASSERT_TRUE(stream.GetSummary().at("notifications").IsList()); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + // Assert one notification structure + ASSERT_EQ(notifications.size(), 1); + ASSERT_TRUE(notifications[0].IsMap()); + auto notification = notifications[0].ValueMap(); + ASSERT_TRUE(notification.contains("severity")); + ASSERT_TRUE(notification.contains("code")); + ASSERT_TRUE(notification.contains("title")); + ASSERT_TRUE(notification.contains("description")); + ASSERT_TRUE(notification["severity"].IsString()); + ASSERT_TRUE(notification["code"].IsString()); + ASSERT_TRUE(notification["title"].IsString()); + ASSERT_TRUE(notification["description"].IsString()); + } +} + +TEST_F(InterpreterTest, IndexInfoNotifications) { + { + auto [stream, qid] = Prepare("CREATE INDEX ON :Person;"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "CreateIndex"); + ASSERT_EQ(notification["title"].ValueString(), "Created index on label Person on properties ."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("CREATE INDEX ON :Person(id);"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "CreateIndex"); + ASSERT_EQ(notification["title"].ValueString(), "Created index on label Person on properties id."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("CREATE INDEX ON :Person(id);"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "IndexAlreadyExists"); + ASSERT_EQ(notification["title"].ValueString(), "Index on label Person on properties id already exists."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("DROP INDEX ON :Person(id);"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "DropIndex"); + ASSERT_EQ(notification["title"].ValueString(), "Dropped index on label Person on properties id."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("DROP INDEX ON :Person(id);"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "IndexDoesNotExist"); + ASSERT_EQ(notification["title"].ValueString(), "Index on label Person on properties id doesn't exist."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } +} + +TEST_F(InterpreterTest, ConstraintUniqueInfoNotifications) { + { + auto [stream, qid] = Prepare("CREATE CONSTRAINT ON (n:Person) ASSERT n.email, n.id IS UNIQUE;"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "CreateConstraint"); + ASSERT_EQ(notification["title"].ValueString(), + "Created UNIQUE constraint on label Person on properties email, id."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("CREATE CONSTRAINT ON (n:Person) ASSERT n.email, n.id IS UNIQUE;"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "ConstraintAlreadyExists"); + ASSERT_EQ(notification["title"].ValueString(), + "Constraint UNIQUE on label Person on properties email, id already exists."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("DROP CONSTRAINT ON (n:Person) ASSERT n.email, n.id IS UNIQUE;"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "DropConstraint"); + ASSERT_EQ(notification["title"].ValueString(), + "Dropped UNIQUE constraint on label Person on properties email, id."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("DROP CONSTRAINT ON (n:Person) ASSERT n.email, n.id IS UNIQUE;"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "ConstraintDoesNotExist"); + ASSERT_EQ(notification["title"].ValueString(), + "Constraint UNIQUE on label Person on properties email, id doesn't exist."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } +} + +TEST_F(InterpreterTest, ConstraintExistsInfoNotifications) { + { + auto [stream, qid] = Prepare("CREATE CONSTRAINT ON (n:L1) ASSERT EXISTS (n.name);"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "CreateConstraint"); + ASSERT_EQ(notification["title"].ValueString(), "Created EXISTS constraint on label L1 on properties name."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("CREATE CONSTRAINT ON (n:L1) ASSERT EXISTS (n.name);"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "ConstraintAlreadyExists"); + ASSERT_EQ(notification["title"].ValueString(), "Constraint EXISTS on label L1 on properties name already exists."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("DROP CONSTRAINT ON (n:L1) ASSERT EXISTS (n.name);"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "DropConstraint"); + ASSERT_EQ(notification["title"].ValueString(), "Dropped EXISTS constraint on label L1 on properties name."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("DROP CONSTRAINT ON (n:L1) ASSERT EXISTS (n.name);"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "ConstraintDoesNotExist"); + ASSERT_EQ(notification["title"].ValueString(), "Constraint EXISTS on label L1 on properties name doesn't exist."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } +} + +TEST_F(InterpreterTest, TriggerInfoNotifications) { + { + auto [stream, qid] = Prepare( + "CREATE TRIGGER bestTriggerEver ON CREATE AFTER COMMIT EXECUTE " + "CREATE ();"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "CreateTrigger"); + ASSERT_EQ(notification["title"].ValueString(), "Created trigger bestTriggerEver."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } + { + auto [stream, qid] = Prepare("DROP TRIGGER bestTriggerEver;"); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "DropTrigger"); + ASSERT_EQ(notification["title"].ValueString(), "Dropped trigger bestTriggerEver."); + ASSERT_EQ(notification["description"].ValueString(), ""); + } +} + +TEST_F(InterpreterTest, LoadCsvClauseNotification) { + auto dir_manager = TmpDirManager("csv_directory"); + const auto csv_path = dir_manager.Path() / "file.csv"; + auto writer = FileWriter(csv_path); + + const std::string delimiter{"|"}; + + const std::vector header{"A", "B", "C"}; + writer.WriteLine(CreateRow(header, delimiter)); + + const std::vector good_columns_1{"a", "b", "c"}; + writer.WriteLine(CreateRow(good_columns_1, delimiter)); + + writer.Close(); + + const std::string query = fmt::format(R"(LOAD CSV FROM "{}" WITH HEADER IGNORE BAD DELIMITER "{}" AS x RETURN x;)", + csv_path.string(), delimiter); + auto [stream, qid] = Prepare(query); + Pull(&stream); + + ASSERT_EQ(stream.GetSummary().count("notifications"), 1); + auto notifications = stream.GetSummary().at("notifications").ValueList(); + + auto notification = notifications[0].ValueMap(); + ASSERT_EQ(notification["severity"].ValueString(), "INFO"); + ASSERT_EQ(notification["code"].ValueString(), "LoadCSVTip"); + ASSERT_EQ(notification["title"].ValueString(), + "It's important to note that the parser parses the values as strings. It's up to the user to " + "convert the parsed row values to the appropriate type. This can be done using the built-in " + "conversion functions such as ToInteger, ToFloat, ToBoolean etc."); + ASSERT_EQ(notification["description"].ValueString(), ""); +} + +TEST_F(InterpreterTest, CreateSchemaMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING, age INTEGER)"), + memgraph::query::ConstraintInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, ShowSchemasMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("SHOW SCHEMAS"), memgraph::query::ConstraintInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, ShowSchemaMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("SHOW SCHEMA ON :label"), memgraph::query::ConstraintInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, DropSchemaMulticommandTransaction) { + Interpret("BEGIN"); + ASSERT_THROW(Interpret("DROP SCHEMA ON :label"), memgraph::query::ConstraintInMulticommandTxException); + Interpret("ROLLBACK"); +} + +TEST_F(InterpreterTest, SchemaTestCreateAndShow) { + // Empty schema type map should result with syntax exception. + ASSERT_THROW(Interpret("CREATE SCHEMA ON :label();"), memgraph::query::SyntaxException); + + // Duplicate properties are should also cause an exception + ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING, name STRING);"), memgraph::query::SemanticException); + ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING, name INTEGER);"), memgraph::query::SemanticException); + + { + // Cannot create same schema twice + Interpret("CREATE SCHEMA ON :label(name STRING, age INTEGER)"); + ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING);"), memgraph::query::QueryException); + } + // Show schema + { + auto stream = Interpret("SHOW SCHEMA ON :label"); + ASSERT_EQ(stream.GetHeader().size(), 2U); + const auto &header = stream.GetHeader(); + ASSERT_EQ(header[0], "property_name"); + ASSERT_EQ(header[1], "property_type"); + ASSERT_EQ(stream.GetResults().size(), 2U); + std::unordered_map result_table{{"age", "Integer"}, {"name", "String"}}; + + const auto &result = stream.GetResults().front(); + ASSERT_EQ(result.size(), 2U); + const auto key1 = result[0].ValueString(); + ASSERT_TRUE(result_table.contains(key1)); + ASSERT_EQ(result[1].ValueString(), result_table[key1]); + + const auto &result2 = stream.GetResults().front(); + ASSERT_EQ(result2.size(), 2U); + const auto key2 = result2[0].ValueString(); + ASSERT_TRUE(result_table.contains(key2)); + ASSERT_EQ(result[1].ValueString(), result_table[key2]); + } + // Create Another Schema + Interpret("CREATE SCHEMA ON :label2(place STRING, dur DURATION)"); + + // Show schemas + { + auto stream = Interpret("SHOW SCHEMAS"); + ASSERT_EQ(stream.GetHeader().size(), 2U); + const auto &header = stream.GetHeader(); + ASSERT_EQ(header[0], "label"); + ASSERT_EQ(header[1], "primary_key"); + ASSERT_EQ(stream.GetResults().size(), 2U); + std::unordered_map> result_table{ + {"label", {"name::String", "age::Integer"}}, {"label2", {"place::String", "dur::Duration"}}}; + + const auto &result = stream.GetResults().front(); + ASSERT_EQ(result.size(), 2U); + const auto key1 = result[0].ValueString(); + ASSERT_TRUE(result_table.contains(key1)); + const auto primary_key_split = StringToUnorderedSet(result[1].ValueString()); + ASSERT_EQ(primary_key_split.size(), 2); + ASSERT_TRUE(primary_key_split == result_table[key1]) << "actual value is: " << result[1].ValueString(); + + const auto &result2 = stream.GetResults().front(); + ASSERT_EQ(result2.size(), 2U); + const auto key2 = result2[0].ValueString(); + ASSERT_TRUE(result_table.contains(key2)); + const auto primary_key_split2 = StringToUnorderedSet(result2[1].ValueString()); + ASSERT_EQ(primary_key_split2.size(), 2); + ASSERT_TRUE(primary_key_split2 == result_table[key2]) << "Real value is: " << result[1].ValueString(); + } +} + +TEST_F(InterpreterTest, SchemaTestCreateDropAndShow) { + Interpret("CREATE SCHEMA ON :label(name STRING, age INTEGER)"); + // Wrong syntax for dropping schema. + ASSERT_THROW(Interpret("DROP SCHEMA ON :label();"), memgraph::query::SyntaxException); + // Cannot drop non existant schema. + ASSERT_THROW(Interpret("DROP SCHEMA ON :label1;"), memgraph::query::QueryException); + + // Create Schema and Drop + auto get_number_of_schemas = [this]() { + auto stream = Interpret("SHOW SCHEMAS"); + return stream.GetResults().size(); + }; + + ASSERT_EQ(get_number_of_schemas(), 1); + Interpret("CREATE SCHEMA ON :label1(name STRING, age INTEGER)"); + ASSERT_EQ(get_number_of_schemas(), 2); + Interpret("CREATE SCHEMA ON :label2(name STRING, alive BOOL)"); + ASSERT_EQ(get_number_of_schemas(), 3); + Interpret("DROP SCHEMA ON :label1"); + ASSERT_EQ(get_number_of_schemas(), 2); + Interpret("CREATE SCHEMA ON :label3(name STRING, birthday LOCALDATETIME)"); + ASSERT_EQ(get_number_of_schemas(), 3); + Interpret("DROP SCHEMA ON :label2"); + ASSERT_EQ(get_number_of_schemas(), 2); + Interpret("CREATE SCHEMA ON :label4(name STRING, age DURATION)"); + ASSERT_EQ(get_number_of_schemas(), 3); + Interpret("DROP SCHEMA ON :label3"); + ASSERT_EQ(get_number_of_schemas(), 2); + Interpret("DROP SCHEMA ON :label"); + ASSERT_EQ(get_number_of_schemas(), 1); + + // Show schemas + auto stream = Interpret("SHOW SCHEMAS"); + ASSERT_EQ(stream.GetHeader().size(), 2U); + const auto &header = stream.GetHeader(); + ASSERT_EQ(header[0], "label"); + ASSERT_EQ(header[1], "primary_key"); + ASSERT_EQ(stream.GetResults().size(), 1U); + std::unordered_map> result_table{ + {"label4", {"name::String", "age::Duration"}}}; + + const auto &result = stream.GetResults().front(); + ASSERT_EQ(result.size(), 2U); + const auto key1 = result[0].ValueString(); + ASSERT_TRUE(result_table.contains(key1)); + const auto primary_key_split = StringToUnorderedSet(result[1].ValueString()); + ASSERT_EQ(primary_key_split.size(), 2); + ASSERT_TRUE(primary_key_split == result_table[key1]); +} diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index b0dd41a72..903cacf12 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -531,9 +531,9 @@ auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vec memgraph::query::test_common::OnCreate { \ std::vector { __VA_ARGS__ } \ } -#define CREATE_INDEX_ON(label, property) \ +#define CREATE_INDEX_ON(label, property) \ storage.Create(memgraph::query::IndexQuery::Action::CREATE, (label), \ - std::vector{(property)}) + std::vector{(property)}) #define QUERY(...) memgraph::query::test_common::GetQuery(storage, __VA_ARGS__) #define SINGLE_QUERY(...) memgraph::query::test_common::GetSingleQuery(storage.Create(), __VA_ARGS__) #define UNION(...) memgraph::query::test_common::GetCypherUnion(storage.Create(true), __VA_ARGS__) diff --git a/tests/unit/query_plan_bag_semantics.cpp b/tests/unit/query_plan_bag_semantics.cpp index f0b0916f4..d4cbaecf5 100644 --- a/tests/unit/query_plan_bag_semantics.cpp +++ b/tests/unit/query_plan_bag_semantics.cpp @@ -9,11 +9,6 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. -// -// Copyright 2017 Memgraph -// Created by Florijan Stamenkovic on 14.03.17. -// - #include #include #include diff --git a/tests/unit/query_plan_common.hpp b/tests/unit/query_plan_common.hpp index 037875754..48667607d 100644 --- a/tests/unit/query_plan_common.hpp +++ b/tests/unit/query_plan_common.hpp @@ -99,6 +99,16 @@ ScanAllTuple MakeScanAll(AstStorage &storage, SymbolTable &symbol_table, const s return ScanAllTuple{node, logical_op, symbol}; } +ScanAllTuple MakeScanAllNew(AstStorage &storage, SymbolTable &symbol_table, const std::string &identifier, + std::shared_ptr input = {nullptr}, + memgraph::storage::View view = memgraph::storage::View::OLD) { + auto *node = NODE(identifier, "label"); + auto symbol = symbol_table.CreateSymbol(identifier, true); + node->identifier_->MapTo(symbol); + auto logical_op = std::make_shared(input, symbol, view); + return ScanAllTuple{node, logical_op, symbol}; +} + /** * Creates and returns a tuple of stuff for a scan-all starting * from the node with the given name and label. diff --git a/tests/unit/query_v2_query_plan_accumulate_aggregate.cpp b/tests/unit/query_v2_query_plan_accumulate_aggregate.cpp new file mode 100644 index 000000000..224a3cc69 --- /dev/null +++ b/tests/unit/query_v2_query_plan_accumulate_aggregate.cpp @@ -0,0 +1,631 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include +#include +#include +#include + +#include "common/types.hpp" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "query/context.hpp" +#include "query/exceptions.hpp" +#include "query/plan/operator.hpp" +#include "query_plan_common.hpp" +#include "storage/v2/property_value.hpp" + +using namespace memgraph::query; +using namespace memgraph::query::plan; +using memgraph::query::test_common::ToIntList; +using memgraph::query::test_common::ToIntMap; +using testing::UnorderedElementsAre; + +namespace memgraph::query::v2::tests { + +class QueryPlanAccumulateAggregateTest : public ::testing::Test { + protected: + void SetUp() override { + ASSERT_TRUE(db.CreateSchema(label, {storage::SchemaProperty{property, common::SchemaType::INT}})); + } + + storage::Storage db; + const storage::LabelId label{db.NameToLabel("label")}; + const storage::PropertyId property{db.NameToProperty("property")}; +}; + +TEST_F(QueryPlanAccumulateAggregateTest, Accumulate) { + // simulate the following two query execution on an empty db + // CREATE ({x:0})-[:T]->({x:0}) + // MATCH (n)--(m) SET n.x = n.x + 1, m.x = m.x + 1 RETURN n.x, m.x + // without accumulation we expected results to be [[1, 1], [2, 2]] + // with accumulation we expect them to be [[2, 2], [2, 2]] + + auto check = [&](bool accumulate) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto prop = dba.NameToProperty("x"); + + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(v1.SetProperty(prop, storage::PropertyValue(0)).HasValue()); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + ASSERT_TRUE(v2.SetProperty(prop, storage::PropertyValue(0)).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, dba.NameToEdgeType("T")).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::BOTH, {}, "m", false, + storage::View::OLD); + + auto one = LITERAL(1); + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + auto set_n_p = std::make_shared(r_m.op_, prop, n_p, ADD(n_p, one)); + auto m_p = PROPERTY_LOOKUP(IDENT("m")->MapTo(r_m.node_sym_), prop); + auto set_m_p = std::make_shared(set_n_p, prop, m_p, ADD(m_p, one)); + + std::shared_ptr last_op = set_m_p; + if (accumulate) { + last_op = std::make_shared(last_op, std::vector{n.sym_, r_m.node_sym_}); + } + + auto n_p_ne = NEXPR("n.p", n_p)->MapTo(symbol_table.CreateSymbol("n_p_ne", true)); + auto m_p_ne = NEXPR("m.p", m_p)->MapTo(symbol_table.CreateSymbol("m_p_ne", true)); + auto produce = MakeProduce(last_op, n_p_ne, m_p_ne); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + std::vector results_data; + for (const auto &row : results) + for (const auto &column : row) results_data.emplace_back(column.ValueInt()); + if (accumulate) + EXPECT_THAT(results_data, ::testing::ElementsAre(2, 2, 2, 2)); + else + EXPECT_THAT(results_data, ::testing::ElementsAre(1, 1, 2, 2)); + }; + + check(false); + check(true); +} + +TEST_F(QueryPlanAccumulateAggregateTest, AccumulateAdvance) { + // we simulate 'CREATE (n) WITH n AS n MATCH (m) RETURN m' + // to get correct results we need to advance the command + auto check = [&](bool advance) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + NodeCreationInfo node; + node.symbol = symbol_table.CreateSymbol("n", true); + node.labels = {label}; + std::get>>(node.properties) + .emplace_back(property, LITERAL(1)); + auto create = std::make_shared(nullptr, node); + auto accumulate = std::make_shared(create, std::vector{node.symbol}, advance); + auto match = MakeScanAll(storage, symbol_table, "m", accumulate); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(advance ? 1 : 0, PullAll(*match.op_, &context)); + }; + check(false); + check(true); +} + +std::shared_ptr MakeAggregationProduce(std::shared_ptr input, SymbolTable &symbol_table, + AstStorage &storage, const std::vector aggr_inputs, + const std::vector aggr_ops, + const std::vector group_by_exprs, + const std::vector remember) { + // prepare all the aggregations + std::vector aggregates; + std::vector named_expressions; + + auto aggr_inputs_it = aggr_inputs.begin(); + for (auto aggr_op : aggr_ops) { + // TODO change this from using IDENT to using AGGREGATION + // once AGGREGATION is handled properly in ExpressionEvaluation + auto aggr_sym = symbol_table.CreateSymbol("aggregation", true); + auto named_expr = + NEXPR("", IDENT("aggregation")->MapTo(aggr_sym))->MapTo(symbol_table.CreateSymbol("named_expression", true)); + named_expressions.push_back(named_expr); + // the key expression is only used in COLLECT_MAP + Expression *key_expr_ptr = aggr_op == Aggregation::Op::COLLECT_MAP ? LITERAL("key") : nullptr; + aggregates.emplace_back(Aggregate::Element{*aggr_inputs_it++, key_expr_ptr, aggr_op, aggr_sym}); + } + + // Produce will also evaluate group_by expressions and return them after the + // aggregations. + for (auto group_by_expr : group_by_exprs) { + auto named_expr = NEXPR("", group_by_expr)->MapTo(symbol_table.CreateSymbol("named_expression", true)); + named_expressions.push_back(named_expr); + } + auto aggregation = std::make_shared(input, aggregates, group_by_exprs, remember); + return std::make_shared(aggregation, named_expressions); +} + +// /** Test fixture for all the aggregation ops in one return. */ +class QueryPlanAggregateOps : public ::testing::Test { + protected: + void SetUp() override { + ASSERT_TRUE(db.CreateSchema(label, {storage::SchemaProperty{property, common::SchemaType::INT}})); + } + storage::Storage db; + storage::Storage::Accessor storage_dba{db.Access()}; + DbAccessor dba{&storage_dba}; + storage::LabelId label = db.NameToLabel("label"); + storage::PropertyId property = db.NameToProperty("property"); + storage::PropertyId prop = db.NameToProperty("prop"); + + AstStorage storage; + SymbolTable symbol_table; + + void AddData() { + // setup is several nodes most of which have an int property set + // we will take the sum, avg, min, max and count + // we won't group by anything + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}) + ->SetProperty(prop, storage::PropertyValue(5)) + .HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}) + ->SetProperty(prop, storage::PropertyValue(7)) + .HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}) + ->SetProperty(prop, storage::PropertyValue(12)) + .HasValue()); + // a missing property (null) gets ignored by all aggregations except + // COUNT(*) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(4)}}).HasValue()); + dba.AdvanceCommand(); + } + + auto AggregationResults(bool with_group_by, std::vector ops = { + Aggregation::Op::COUNT, Aggregation::Op::COUNT, Aggregation::Op::MIN, + Aggregation::Op::MAX, Aggregation::Op::SUM, Aggregation::Op::AVG, + Aggregation::Op::COLLECT_LIST, Aggregation::Op::COLLECT_MAP}) { + // match all nodes and perform aggregations + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + + std::vector aggregation_expressions(ops.size(), n_p); + std::vector group_bys; + if (with_group_by) group_bys.push_back(n_p); + aggregation_expressions[0] = nullptr; + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, aggregation_expressions, ops, group_bys, {}); + auto context = MakeContext(storage, symbol_table, &dba); + return CollectProduce(*produce, &context); + } +}; + +TEST_F(QueryPlanAggregateOps, WithData) { + AddData(); + auto results = AggregationResults(false); + + ASSERT_EQ(results.size(), 1); + ASSERT_EQ(results[0].size(), 8); + // count(*) + ASSERT_EQ(results[0][0].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][0].ValueInt(), 4); + // count + ASSERT_EQ(results[0][1].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][1].ValueInt(), 3); + // min + ASSERT_EQ(results[0][2].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][2].ValueInt(), 5); + // max + ASSERT_EQ(results[0][3].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][3].ValueInt(), 12); + // sum + ASSERT_EQ(results[0][4].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][4].ValueInt(), 24); + // avg + ASSERT_EQ(results[0][5].type(), TypedValue::Type::Double); + EXPECT_FLOAT_EQ(results[0][5].ValueDouble(), 24 / 3.0); + // collect list + ASSERT_EQ(results[0][6].type(), TypedValue::Type::List); + EXPECT_THAT(ToIntList(results[0][6]), UnorderedElementsAre(5, 7, 12)); + // collect map + ASSERT_EQ(results[0][7].type(), TypedValue::Type::Map); + auto map = ToIntMap(results[0][7]); + ASSERT_EQ(map.size(), 1); + EXPECT_EQ(map.begin()->first, "key"); + EXPECT_FALSE(std::set({5, 7, 12}).insert(map.begin()->second).second); +} + +TEST_F(QueryPlanAggregateOps, WithoutDataWithGroupBy) { + { + auto results = AggregationResults(true, {Aggregation::Op::COUNT}); + EXPECT_EQ(results.size(), 0); + } + { + auto results = AggregationResults(true, {Aggregation::Op::SUM}); + EXPECT_EQ(results.size(), 0); + } + { + auto results = AggregationResults(true, {Aggregation::Op::AVG}); + EXPECT_EQ(results.size(), 0); + } + { + auto results = AggregationResults(true, {Aggregation::Op::MIN}); + EXPECT_EQ(results.size(), 0); + } + { + auto results = AggregationResults(true, {Aggregation::Op::MAX}); + EXPECT_EQ(results.size(), 0); + } + { + auto results = AggregationResults(true, {Aggregation::Op::COLLECT_LIST}); + EXPECT_EQ(results.size(), 0); + } + { + auto results = AggregationResults(true, {Aggregation::Op::COLLECT_MAP}); + EXPECT_EQ(results.size(), 0); + } +} + +TEST_F(QueryPlanAggregateOps, WithoutDataWithoutGroupBy) { + auto results = AggregationResults(false); + ASSERT_EQ(results.size(), 1); + ASSERT_EQ(results[0].size(), 8); + // count(*) + ASSERT_EQ(results[0][0].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][0].ValueInt(), 0); + // count + ASSERT_EQ(results[0][1].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][1].ValueInt(), 0); + // min + EXPECT_TRUE(results[0][2].IsNull()); + // max + EXPECT_TRUE(results[0][3].IsNull()); + // sum + EXPECT_TRUE(results[0][4].IsNull()); + // avg + EXPECT_TRUE(results[0][5].IsNull()); + // collect list + ASSERT_EQ(results[0][6].type(), TypedValue::Type::List); + EXPECT_EQ(ToIntList(results[0][6]).size(), 0); + // collect map + ASSERT_EQ(results[0][7].type(), TypedValue::Type::Map); + EXPECT_EQ(ToIntMap(results[0][7]).size(), 0); +} + +TEST_F(QueryPlanAccumulateAggregateTest, AggregateGroupByValues) { + // Tests that distinct groups are aggregated properly for values of all types. + // Also test the "remember" part of the Aggregation API as final results are + // obtained via a property lookup of a remembered node. + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // a vector of storage::PropertyValue to be set as property values on vertices + // most of them should result in a distinct group (commented where not) + std::vector group_by_vals; + group_by_vals.emplace_back(4); + group_by_vals.emplace_back(7); + group_by_vals.emplace_back(7.3); + group_by_vals.emplace_back(7.2); + group_by_vals.emplace_back("Johhny"); + group_by_vals.emplace_back("Jane"); + group_by_vals.emplace_back("1"); + group_by_vals.emplace_back(true); + group_by_vals.emplace_back(false); + group_by_vals.emplace_back(std::vector{storage::PropertyValue(1)}); + group_by_vals.emplace_back(std::vector{storage::PropertyValue(1), storage::PropertyValue(2)}); + group_by_vals.emplace_back(std::vector{storage::PropertyValue(2), storage::PropertyValue(1)}); + group_by_vals.emplace_back(storage::PropertyValue()); + // should NOT result in another group because 7.0 == 7 + group_by_vals.emplace_back(7.0); + // should NOT result in another group + group_by_vals.emplace_back( + std::vector{storage::PropertyValue(1), storage::PropertyValue(2.0)}); + + // generate a lot of vertices and set props on them + auto prop = dba.NameToProperty("prop"); + for (int i = 0; i < 1000; ++i) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}) + ->SetProperty(prop, group_by_vals[i % group_by_vals.size()]) + .HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + // match all nodes and perform aggregations + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, {Aggregation::Op::COUNT}, {n_p}, {n.sym_}); + + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + ASSERT_EQ(results.size(), group_by_vals.size() - 2); + std::unordered_set result_group_bys; + for (const auto &row : results) { + ASSERT_EQ(2, row.size()); + result_group_bys.insert(row[1]); + } + ASSERT_EQ(result_group_bys.size(), group_by_vals.size() - 2); + std::vector group_by_tvals; + group_by_tvals.reserve(group_by_vals.size()); + for (const auto &v : group_by_vals) group_by_tvals.emplace_back(v); + EXPECT_TRUE(std::is_permutation(group_by_tvals.begin(), group_by_tvals.end() - 2, result_group_bys.begin(), + TypedValue::BoolEqual{})); +} + +TEST_F(QueryPlanAccumulateAggregateTest, AggregateMultipleGroupBy) { + // in this test we have 3 different properties that have different values + // for different records and assert that we get the correct combination + // of values in our groups + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto prop1 = dba.NameToProperty("prop1"); + auto prop2 = dba.NameToProperty("prop2"); + auto prop3 = dba.NameToProperty("prop3"); + for (int i = 0; i < 2 * 3 * 5; ++i) { + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(i)}}); + ASSERT_TRUE(v.SetProperty(prop1, storage::PropertyValue(static_cast(i % 2))).HasValue()); + ASSERT_TRUE(v.SetProperty(prop2, storage::PropertyValue(i % 3)).HasValue()); + ASSERT_TRUE(v.SetProperty(prop3, storage::PropertyValue("value" + std::to_string(i % 5))).HasValue()); + } + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + // match all nodes and perform aggregations + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p1 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop1); + auto n_p2 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop2); + auto n_p3 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop3); + + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p1}, {Aggregation::Op::COUNT}, + {n_p1, n_p2, n_p3}, {n.sym_}); + + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), 2 * 3 * 5); +} + +TEST(QueryPlan, AggregateNoInput) { + storage::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + + auto two = LITERAL(2); + auto produce = MakeAggregationProduce(nullptr, symbol_table, storage, {two}, {Aggregation::Op::COUNT}, {}, {}); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(1, results.size()); + EXPECT_EQ(1, results[0].size()); + EXPECT_EQ(TypedValue::Type::Int, results[0][0].type()); + EXPECT_EQ(1, results[0][0].ValueInt()); +} + +TEST_F(QueryPlanAccumulateAggregateTest, AggregateCountEdgeCases) { + // tests for detected bugs in the COUNT aggregation behavior + // ensure that COUNT returns correctly for + // - 0 vertices in database + // - 1 vertex in database, property not set + // - 1 vertex in database, property set + // - 2 vertices in database, property set on one + // - 2 vertices in database, property set on both + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto prop = dba.NameToProperty("prop"); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + + // returns -1 when there are no results + // otherwise returns MATCH (n) RETURN count(n.prop) + auto count = [&]() { + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, {Aggregation::Op::COUNT}, {}, {}); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + if (results.size() == 0) return -1L; + EXPECT_EQ(1, results.size()); + EXPECT_EQ(1, results[0].size()); + EXPECT_EQ(TypedValue::Type::Int, results[0][0].type()); + return results[0][0].ValueInt(); + }; + + // no vertices yet in database + EXPECT_EQ(0, count()); + + // one vertex, no property set + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(0, count()); + + // one vertex, property set + for (auto va : dba.Vertices(storage::View::OLD)) + ASSERT_TRUE(va.SetProperty(prop, storage::PropertyValue(42)).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(1, count()); + + // two vertices, one with property set + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(1, count()); + + // two vertices, both with property set + for (auto va : dba.Vertices(storage::View::OLD)) + ASSERT_TRUE(va.SetProperty(prop, storage::PropertyValue(42)).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(2, count()); +} + +TEST_F(QueryPlanAccumulateAggregateTest, AggregateFirstValueTypes) { + // testing exceptions that get emitted by the first-value + // type check + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto prop_string = dba.NameToProperty("string"); + ASSERT_TRUE(v1.SetProperty(prop_string, storage::PropertyValue("johhny")).HasValue()); + auto prop_int = dba.NameToProperty("int"); + ASSERT_TRUE(v1.SetProperty(prop_int, storage::PropertyValue(12)).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_prop_string = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop_string); + auto n_prop_int = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop_int); + auto n_id = n_prop_string->expression_; + + auto aggregate = [&](Expression *expression, Aggregation::Op aggr_op) { + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {expression}, {aggr_op}, {}, {}); + auto context = MakeContext(storage, symbol_table, &dba); + CollectProduce(*produce, &context); + }; + + // everything except for COUNT and COLLECT fails on a Vertex + aggregate(n_id, Aggregation::Op::COUNT); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::MIN), QueryRuntimeException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::MAX), QueryRuntimeException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::AVG), QueryRuntimeException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::SUM), QueryRuntimeException); + + // on strings AVG and SUM fail + aggregate(n_prop_string, Aggregation::Op::COUNT); + aggregate(n_prop_string, Aggregation::Op::MIN); + aggregate(n_prop_string, Aggregation::Op::MAX); + EXPECT_THROW(aggregate(n_prop_string, Aggregation::Op::AVG), QueryRuntimeException); + EXPECT_THROW(aggregate(n_prop_string, Aggregation::Op::SUM), QueryRuntimeException); + + // on ints nothing fails + aggregate(n_prop_int, Aggregation::Op::COUNT); + aggregate(n_prop_int, Aggregation::Op::MIN); + aggregate(n_prop_int, Aggregation::Op::MAX); + aggregate(n_prop_int, Aggregation::Op::AVG); + aggregate(n_prop_int, Aggregation::Op::SUM); + aggregate(n_prop_int, Aggregation::Op::COLLECT_LIST); + aggregate(n_prop_int, Aggregation::Op::COLLECT_MAP); +} + +TEST_F(QueryPlanAccumulateAggregateTest, AggregateTypes) { + // testing exceptions that can get emitted by an aggregation + // does not check all combinations that can result in an exception + // (that logic is defined and tested by TypedValue) + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto p1 = dba.NameToProperty("p1"); // has only string props + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}) + ->SetProperty(p1, storage::PropertyValue("string")) + .HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}) + ->SetProperty(p1, storage::PropertyValue("str2")) + .HasValue()); + auto p2 = dba.NameToProperty("p2"); // combines int and bool + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}) + ->SetProperty(p2, storage::PropertyValue(42)) + .HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}) + ->SetProperty(p2, storage::PropertyValue(true)) + .HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p1 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), p1); + auto n_p2 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), p2); + + auto aggregate = [&](Expression *expression, Aggregation::Op aggr_op) { + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {expression}, {aggr_op}, {}, {}); + auto context = MakeContext(storage, symbol_table, &dba); + CollectProduce(*produce, &context); + }; + + // everything except for COUNT and COLLECT fails on a Vertex + auto n_id = n_p1->expression_; + aggregate(n_id, Aggregation::Op::COUNT); + aggregate(n_id, Aggregation::Op::COLLECT_LIST); + aggregate(n_id, Aggregation::Op::COLLECT_MAP); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::MIN), QueryRuntimeException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::MAX), QueryRuntimeException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::AVG), QueryRuntimeException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::SUM), QueryRuntimeException); + + // on strings AVG and SUM fail + aggregate(n_p1, Aggregation::Op::COUNT); + aggregate(n_p1, Aggregation::Op::COLLECT_LIST); + aggregate(n_p1, Aggregation::Op::COLLECT_MAP); + aggregate(n_p1, Aggregation::Op::MIN); + aggregate(n_p1, Aggregation::Op::MAX); + EXPECT_THROW(aggregate(n_p1, Aggregation::Op::AVG), QueryRuntimeException); + EXPECT_THROW(aggregate(n_p1, Aggregation::Op::SUM), QueryRuntimeException); + + // combination of int and bool, everything except COUNT and COLLECT fails + aggregate(n_p2, Aggregation::Op::COUNT); + aggregate(n_p2, Aggregation::Op::COLLECT_LIST); + aggregate(n_p2, Aggregation::Op::COLLECT_MAP); + EXPECT_THROW(aggregate(n_p2, Aggregation::Op::MIN), QueryRuntimeException); + EXPECT_THROW(aggregate(n_p2, Aggregation::Op::MAX), QueryRuntimeException); + EXPECT_THROW(aggregate(n_p2, Aggregation::Op::AVG), QueryRuntimeException); + EXPECT_THROW(aggregate(n_p2, Aggregation::Op::SUM), QueryRuntimeException); +} + +TEST(QueryPlan, Unwind) { + storage::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + + // UNWIND [ [1, true, "x"], [], ["bla"] ] AS x UNWIND x as y RETURN x, y + auto input_expr = storage.Create(std::vector{ + storage::PropertyValue(std::vector{ + storage::PropertyValue(1), storage::PropertyValue(true), storage::PropertyValue("x")}), + storage::PropertyValue(std::vector{}), + storage::PropertyValue(std::vector{storage::PropertyValue("bla")})}); + + auto x = symbol_table.CreateSymbol("x", true); + auto unwind_0 = std::make_shared(nullptr, input_expr, x); + auto x_expr = IDENT("x")->MapTo(x); + auto y = symbol_table.CreateSymbol("y", true); + auto unwind_1 = std::make_shared(unwind_0, x_expr, y); + + auto x_ne = NEXPR("x", x_expr)->MapTo(symbol_table.CreateSymbol("x_ne", true)); + auto y_ne = NEXPR("y", IDENT("y")->MapTo(y))->MapTo(symbol_table.CreateSymbol("y_ne", true)); + auto produce = MakeProduce(unwind_1, x_ne, y_ne); + + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + ASSERT_EQ(4, results.size()); + const std::vector expected_x_card{3, 3, 3, 1}; + auto expected_x_card_it = expected_x_card.begin(); + const std::vector expected_y{TypedValue(1), TypedValue(true), TypedValue("x"), TypedValue("bla")}; + auto expected_y_it = expected_y.begin(); + for (const auto &row : results) { + ASSERT_EQ(2, row.size()); + ASSERT_EQ(row[0].type(), TypedValue::Type::List); + EXPECT_EQ(row[0].ValueList().size(), *expected_x_card_it); + EXPECT_EQ(row[1].type(), expected_y_it->type()); + expected_x_card_it++; + expected_y_it++; + } +} +} // namespace memgraph::query::v2::tests diff --git a/tests/unit/query_v2_query_plan_bag_semantics.cpp b/tests/unit/query_v2_query_plan_bag_semantics.cpp new file mode 100644 index 000000000..496f1dc9b --- /dev/null +++ b/tests/unit/query_v2_query_plan_bag_semantics.cpp @@ -0,0 +1,309 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "query/context.hpp" +#include "query/exceptions.hpp" +#include "query/frontend/ast/ast.hpp" +#include "query/plan/operator.hpp" + +#include "query_plan_common.hpp" +#include "storage/v2/property_value.hpp" + +using namespace memgraph::query; +using namespace memgraph::query::plan; + +namespace memgraph::query::tests { + +class QueryPlanBagSemanticsTest : public testing::Test { + protected: + void SetUp() override { + ASSERT_TRUE(db.CreateSchema(label, {storage::SchemaProperty{property, common::SchemaType::INT}})); + } + + storage::Storage db; + const storage::LabelId label{db.NameToLabel("label")}; + const storage::PropertyId property{db.NameToProperty("property")}; +}; + +TEST_F(QueryPlanBagSemanticsTest, Skip) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n1"); + auto skip = std::make_shared(n.op_, LITERAL(2)); + + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(0, PullAll(*skip, &context)); + + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(0, PullAll(*skip, &context)); + + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(0, PullAll(*skip, &context)); + + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(1, PullAll(*skip, &context)); + + for (int i = 0; i < 10; ++i) { + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(i + 3)}}).HasValue()); + } + dba.AdvanceCommand(); + EXPECT_EQ(11, PullAll(*skip, &context)); +} + +TEST_F(QueryPlanBagSemanticsTest, Limit) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n1"); + auto skip = std::make_shared(n.op_, LITERAL(2)); + + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(0, PullAll(*skip, &context)); + + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(1, PullAll(*skip, &context)); + + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(2, PullAll(*skip, &context)); + + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(2, PullAll(*skip, &context)); + + for (int i = 0; i < 10; ++i) { + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(i + 3)}}).HasValue()); + } + dba.AdvanceCommand(); + EXPECT_EQ(2, PullAll(*skip, &context)); +} + +TEST_F(QueryPlanBagSemanticsTest, CreateLimit) { + // CREATE (n), (m) + // MATCH (n) CREATE (m) LIMIT 1 + // in the end we need to have 3 vertices in the db + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n1"); + NodeCreationInfo m; + m.symbol = symbol_table.CreateSymbol("m", true); + m.labels = {label}; + std::get>>(m.properties).emplace_back(property, LITERAL(3)); + auto c = std::make_shared(n.op_, m); + auto skip = std::make_shared(c, LITERAL(1)); + + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*skip, &context)); + dba.AdvanceCommand(); + EXPECT_EQ(3, CountIterable(dba.Vertices(storage::View::OLD))); +} + +TEST_F(QueryPlanBagSemanticsTest, OrderBy) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + auto prop = dba.NameToProperty("prop"); + + // contains a series of tests + // each test defines the ordering a vector of values in the desired order + auto Null = storage::PropertyValue(); + std::vector>> orderable{ + {Ordering::ASC, + {storage::PropertyValue(0), storage::PropertyValue(0), storage::PropertyValue(0.5), storage::PropertyValue(1), + storage::PropertyValue(2), storage::PropertyValue(12.6), storage::PropertyValue(42), Null, Null}}, + {Ordering::ASC, + {storage::PropertyValue(false), storage::PropertyValue(false), storage::PropertyValue(true), + storage::PropertyValue(true), Null, Null}}, + {Ordering::ASC, + {storage::PropertyValue("A"), storage::PropertyValue("B"), storage::PropertyValue("a"), + storage::PropertyValue("a"), storage::PropertyValue("aa"), storage::PropertyValue("ab"), + storage::PropertyValue("aba"), Null, Null}}, + {Ordering::DESC, + {Null, Null, storage::PropertyValue(33), storage::PropertyValue(33), storage::PropertyValue(32.5), + storage::PropertyValue(32), storage::PropertyValue(2.2), storage::PropertyValue(2.1), + storage::PropertyValue(0)}}, + {Ordering::DESC, {Null, storage::PropertyValue(true), storage::PropertyValue(false)}}, + {Ordering::DESC, {Null, storage::PropertyValue("zorro"), storage::PropertyValue("borro")}}}; + + for (const auto &order_value_pair : orderable) { + std::vector values; + values.reserve(order_value_pair.second.size()); + for (const auto &v : order_value_pair.second) values.emplace_back(v); + // empty database + for (auto vertex : dba.Vertices(storage::View::OLD)) ASSERT_TRUE(dba.DetachRemoveVertex(&vertex).HasValue()); + dba.AdvanceCommand(); + ASSERT_EQ(0, CountIterable(dba.Vertices(storage::View::OLD))); + + // take some effort to shuffle the values + // because we are testing that something not ordered gets ordered + // and need to take care it does not happen by accident + auto shuffled = values; + auto order_equal = [&values, &shuffled]() { + return std::equal(values.begin(), values.end(), shuffled.begin(), TypedValue::BoolEqual{}); + }; + for (int i = 0; i < 50 && order_equal(); ++i) { + std::random_shuffle(shuffled.begin(), shuffled.end()); + } + ASSERT_FALSE(order_equal()); + + // create the vertices + for (const auto &value : shuffled) { + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}) + ->SetProperty(prop, storage::PropertyValue(value)) + .HasValue()); + } + dba.AdvanceCommand(); + + // order by and collect results + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + auto order_by = std::make_shared(n.op_, std::vector{{order_value_pair.first, n_p}}, + std::vector{n.sym_}); + auto n_p_ne = NEXPR("n.p", n_p)->MapTo(symbol_table.CreateSymbol("n.p", true)); + auto produce = MakeProduce(order_by, n_p_ne); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + ASSERT_EQ(values.size(), results.size()); + for (int j = 0; j < results.size(); ++j) EXPECT_TRUE(TypedValue::BoolEqual{}(results[j][0], values[j])); + } +} + +TEST_F(QueryPlanBagSemanticsTest, OrderByMultiple) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + + auto p1 = dba.NameToProperty("p1"); + auto p2 = dba.NameToProperty("p2"); + + // create a bunch of vertices that in two properties + // have all the variations (with repetition) of N values. + // ensure that those vertices are not created in the + // "right" sequence, but randomized + const int N = 20; + std::vector> prop_values; + for (int i = 0; i < N * N; ++i) prop_values.emplace_back(i % N, i / N); + std::random_shuffle(prop_values.begin(), prop_values.end()); + for (const auto &pair : prop_values) { + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(v.SetProperty(p1, storage::PropertyValue(pair.first)).HasValue()); + ASSERT_TRUE(v.SetProperty(p2, storage::PropertyValue(pair.second)).HasValue()); + } + dba.AdvanceCommand(); + + // order by and collect results + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p1 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), p1); + auto n_p2 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), p2); + // order the results so we get + // (p1: 0, p2: N-1) + // (p1: 0, p2: N-2) + // ... + // (p1: N-1, p2:0) + auto order_by = std::make_shared(n.op_, + std::vector{ + {Ordering::ASC, n_p1}, + {Ordering::DESC, n_p2}, + }, + std::vector{n.sym_}); + auto n_p1_ne = NEXPR("n.p1", n_p1)->MapTo(symbol_table.CreateSymbol("n.p1", true)); + auto n_p2_ne = NEXPR("n.p2", n_p2)->MapTo(symbol_table.CreateSymbol("n.p2", true)); + auto produce = MakeProduce(order_by, n_p1_ne, n_p2_ne); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + ASSERT_EQ(N * N, results.size()); + for (int j = 0; j < N * N; ++j) { + ASSERT_EQ(results[j][0].type(), TypedValue::Type::Int); + EXPECT_EQ(results[j][0].ValueInt(), j / N); + ASSERT_EQ(results[j][1].type(), TypedValue::Type::Int); + EXPECT_EQ(results[j][1].ValueInt(), N - 1 - j % N); + } +} + +TEST_F(QueryPlanBagSemanticsTest, OrderByExceptions) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + auto prop = dba.NameToProperty("prop"); + + // a vector of pairs of typed values that should result + // in an exception when trying to order on them + std::vector> exception_pairs{ + {storage::PropertyValue(42), storage::PropertyValue(true)}, + {storage::PropertyValue(42), storage::PropertyValue("bla")}, + {storage::PropertyValue(42), + storage::PropertyValue(std::vector{storage::PropertyValue(42)})}, + {storage::PropertyValue(true), storage::PropertyValue("bla")}, + {storage::PropertyValue(true), + storage::PropertyValue(std::vector{storage::PropertyValue(true)})}, + {storage::PropertyValue("bla"), + storage::PropertyValue(std::vector{storage::PropertyValue("bla")})}, + // illegal comparisons of same-type values + {storage::PropertyValue(std::vector{storage::PropertyValue(42)}), + storage::PropertyValue(std::vector{storage::PropertyValue(42)})}}; + + for (const auto &pair : exception_pairs) { + // empty database + for (auto vertex : dba.Vertices(storage::View::OLD)) ASSERT_TRUE(dba.DetachRemoveVertex(&vertex).HasValue()); + dba.AdvanceCommand(); + ASSERT_EQ(0, CountIterable(dba.Vertices(storage::View::OLD))); + + // make two vertices, and set values + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}) + ->SetProperty(prop, pair.first) + .HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}) + ->SetProperty(prop, pair.second) + .HasValue()); + dba.AdvanceCommand(); + ASSERT_EQ(2, CountIterable(dba.Vertices(storage::View::OLD))); + for (const auto &va : dba.Vertices(storage::View::OLD)) + ASSERT_NE(va.GetProperty(storage::View::OLD, prop).GetValue().type(), storage::PropertyValue::Type::Null); + + // order by and expect an exception + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + auto order_by = + std::make_shared(n.op_, std::vector{{Ordering::ASC, n_p}}, std::vector{}); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*order_by, &context), QueryRuntimeException); + } +} +} // namespace memgraph::query::tests diff --git a/tests/unit/query_v2_query_plan_create_set_remove_delete.cpp b/tests/unit/query_v2_query_plan_create_set_remove_delete.cpp new file mode 100644 index 000000000..81904c17e --- /dev/null +++ b/tests/unit/query_v2_query_plan_create_set_remove_delete.cpp @@ -0,0 +1,1095 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include +#include +#include +#include + +#include "common/types.hpp" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "query/context.hpp" +#include "query/db_accessor.hpp" +#include "query/exceptions.hpp" +#include "query/interpret/frame.hpp" +#include "query/plan/operator.hpp" + +#include "query_plan_common.hpp" +#include "storage/v2/id_types.hpp" +#include "storage/v2/property_value.hpp" +#include "storage/v2/schemas.hpp" +#include "storage/v2/storage.hpp" +#include "storage/v2/vertex.hpp" +#include "storage/v2/view.hpp" + +using namespace memgraph::query; +using namespace memgraph::query::plan; + +namespace memgraph::query::tests { + +class QueryPlanCRUDTest : public testing::Test { + protected: + void SetUp() override { + ASSERT_TRUE(db.CreateSchema(label, {storage::SchemaProperty{property, common::SchemaType::INT}})); + } + + storage::Storage db; + const storage::LabelId label{db.NameToLabel("label")}; + const storage::PropertyId property{db.NameToProperty("property")}; +}; + +TEST_F(QueryPlanCRUDTest, CreateNodeWithAttributes) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + AstStorage storage; + SymbolTable symbol_table; + + NodeCreationInfo node; + node.symbol = symbol_table.CreateSymbol("n", true); + node.labels.emplace_back(label); + std::get>>(node.properties) + .emplace_back(property, LITERAL(42)); + + auto create = std::make_shared(nullptr, node); + auto context = MakeContext(storage, symbol_table, &dba); + PullAll(*create, &context); + dba.AdvanceCommand(); + + // count the number of vertices + int vertex_count = 0; + for (auto vertex : dba.Vertices(storage::View::OLD)) { + vertex_count++; + auto maybe_labels = vertex.Labels(storage::View::OLD); + ASSERT_TRUE(maybe_labels.HasValue()); + const auto &labels = *maybe_labels; + EXPECT_EQ(labels.size(), 0); + + auto maybe_properties = vertex.Properties(storage::View::OLD); + ASSERT_TRUE(maybe_properties.HasValue()); + const auto &properties = *maybe_properties; + EXPECT_EQ(properties.size(), 1); + auto maybe_prop = vertex.GetProperty(storage::View::OLD, property); + ASSERT_TRUE(maybe_prop.HasValue()); + auto prop_eq = TypedValue(*maybe_prop) == TypedValue(42); + ASSERT_EQ(prop_eq.type(), TypedValue::Type::Bool); + EXPECT_TRUE(prop_eq.ValueBool()); + } + EXPECT_EQ(vertex_count, 1); +} + +TEST(QueryPlan, CreateReturn) { + // test CREATE (n:Person {age: 42}) RETURN n, n.age + storage::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + storage::LabelId label = dba.NameToLabel("Person"); + auto property = PROPERTY_PAIR("property"); + db.CreateSchema(label, {storage::SchemaProperty{property.second, common::SchemaType::INT}}); + + AstStorage storage; + SymbolTable symbol_table; + + NodeCreationInfo node; + node.symbol = symbol_table.CreateSymbol("n", true); + node.labels.emplace_back(label); + std::get>>(node.properties) + .emplace_back(property.second, LITERAL(42)); + + auto create = std::make_shared(nullptr, node); + auto named_expr_n = + NEXPR("n", IDENT("n")->MapTo(node.symbol))->MapTo(symbol_table.CreateSymbol("named_expr_n", true)); + auto prop_lookup = PROPERTY_LOOKUP(IDENT("n")->MapTo(node.symbol), property); + auto named_expr_n_p = NEXPR("n", prop_lookup)->MapTo(symbol_table.CreateSymbol("named_expr_n_p", true)); + + auto produce = MakeProduce(create, named_expr_n, named_expr_n_p); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(1, results.size()); + EXPECT_EQ(2, results[0].size()); + EXPECT_EQ(TypedValue::Type::Vertex, results[0][0].type()); + auto maybe_labels = results[0][0].ValueVertex().Labels(storage::View::NEW); + EXPECT_EQ(maybe_labels->size(), 0); + + EXPECT_EQ(TypedValue::Type::Int, results[0][1].type()); + EXPECT_EQ(42, results[0][1].ValueInt()); + + dba.AdvanceCommand(); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::View::OLD))); +} + +TEST(QueryPlan, CreateExpand) { + storage::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + storage::LabelId label_node_1 = dba.NameToLabel("Node1"); + storage::LabelId label_node_2 = dba.NameToLabel("Node2"); + auto property = PROPERTY_PAIR("property"); + storage::EdgeTypeId edge_type = dba.NameToEdgeType("edge_type"); + db.CreateSchema(label_node_1, {storage::SchemaProperty{property.second, common::SchemaType::INT}}); + db.CreateSchema(label_node_2, {storage::SchemaProperty{property.second, common::SchemaType::INT}}); + + SymbolTable symbol_table; + AstStorage storage; + + auto test_create_path = [&](bool cycle, int expected_nodes_created, int expected_edges_created) { + int before_v = CountIterable(dba.Vertices(storage::View::OLD)); + int before_e = CountEdges(&dba, storage::View::OLD); + + // data for the first node + NodeCreationInfo n; + n.symbol = symbol_table.CreateSymbol("n", true); + n.labels.emplace_back(label_node_1); + std::get>>(n.properties) + .emplace_back(property.second, LITERAL(1)); + + // data for the second node + NodeCreationInfo m; + m.symbol = cycle ? n.symbol : symbol_table.CreateSymbol("m", true); + m.labels.emplace_back(label_node_2); + std::get>>(m.properties) + .emplace_back(property.second, LITERAL(2)); + + EdgeCreationInfo r; + r.symbol = symbol_table.CreateSymbol("r", true); + r.edge_type = edge_type; + std::get<0>(r.properties).emplace_back(property.second, LITERAL(3)); + + auto create_op = std::make_shared(nullptr, n); + auto create_expand = std::make_shared(m, r, create_op, n.symbol, cycle); + auto context = MakeContext(storage, symbol_table, &dba); + PullAll(*create_expand, &context); + dba.AdvanceCommand(); + + EXPECT_EQ(CountIterable(dba.Vertices(storage::View::OLD)) - before_v, expected_nodes_created); + EXPECT_EQ(CountEdges(&dba, storage::View::OLD) - before_e, expected_edges_created); + }; + + test_create_path(false, 2, 1); + test_create_path(true, 1, 1); + + for (auto vertex : dba.Vertices(storage::View::OLD)) { + auto maybe_labels = vertex.Labels(storage::View::OLD); + MG_ASSERT(maybe_labels.HasValue()); + const auto &labels = *maybe_labels; + EXPECT_EQ(labels.size(), 0); + auto maybe_primary_label = vertex.PrimaryLabel(storage::View::OLD); + ASSERT_TRUE(maybe_primary_label.HasValue()); + if (*maybe_primary_label == label_node_1) { + // node created by first op + EXPECT_EQ(vertex.GetProperty(storage::View::OLD, property.second)->ValueInt(), 1); + } else if (*maybe_primary_label == label_node_2) { + // node create by expansion + EXPECT_EQ(vertex.GetProperty(storage::View::OLD, property.second)->ValueInt(), 2); + } else { + // should not happen + FAIL(); + } + + for (auto vertex : dba.Vertices(storage::View::OLD)) { + auto maybe_edges = vertex.OutEdges(storage::View::OLD); + MG_ASSERT(maybe_edges.HasValue()); + for (auto edge : *maybe_edges) { + EXPECT_EQ(edge.EdgeType(), edge_type); + EXPECT_EQ(edge.GetProperty(storage::View::OLD, property.second)->ValueInt(), 3); + } + } + } +} + +TEST_F(QueryPlanCRUDTest, MatchCreateNode) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}).HasValue()); + dba.AdvanceCommand(); + + SymbolTable symbol_table; + AstStorage storage; + + // first node + auto n_scan_all = MakeScanAll(storage, symbol_table, "n"); + // second node + NodeCreationInfo m; + m.symbol = symbol_table.CreateSymbol("m", true); + m.labels = {label}; + std::get>>(m.properties).emplace_back(property, LITERAL(1)); + + // creation op + auto create_node = std::make_shared(n_scan_all.op_, m); + + EXPECT_EQ(CountIterable(dba.Vertices(storage::View::OLD)), 3); + auto context = MakeContext(storage, symbol_table, &dba); + PullAll(*create_node, &context); + dba.AdvanceCommand(); + EXPECT_EQ(CountIterable(dba.Vertices(storage::View::OLD)), 6); +} + +TEST_F(QueryPlanCRUDTest, MatchCreateExpand) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}).HasValue()); + dba.AdvanceCommand(); + + // storage::LabelId label_node_1 = dba.NameToLabel("Node1"); + // storage::LabelId label_node_2 = dba.NameToLabel("Node2"); + // storage::PropertyId property = dba.NameToLabel("prop"); + storage::EdgeTypeId edge_type = dba.NameToEdgeType("edge_type"); + + SymbolTable symbol_table; + AstStorage storage; + + auto test_create_path = [&](bool cycle, int expected_nodes_created, int expected_edges_created) { + int before_v = CountIterable(dba.Vertices(storage::View::OLD)); + int before_e = CountEdges(&dba, storage::View::OLD); + + // data for the first node + auto n_scan_all = MakeScanAll(storage, symbol_table, "n"); + + // data for the second node + NodeCreationInfo m; + m.symbol = cycle ? n_scan_all.sym_ : symbol_table.CreateSymbol("m", true); + m.labels = {label}; + std::get>>(m.properties) + .emplace_back(property, LITERAL(1)); + + EdgeCreationInfo r; + r.symbol = symbol_table.CreateSymbol("r", true); + r.direction = EdgeAtom::Direction::OUT; + r.edge_type = edge_type; + + auto create_expand = std::make_shared(m, r, n_scan_all.op_, n_scan_all.sym_, cycle); + auto context = MakeContext(storage, symbol_table, &dba); + PullAll(*create_expand, &context); + dba.AdvanceCommand(); + + EXPECT_EQ(CountIterable(dba.Vertices(storage::View::OLD)) - before_v, expected_nodes_created); + EXPECT_EQ(CountEdges(&dba, storage::View::OLD) - before_e, expected_edges_created); + }; + + test_create_path(false, 3, 3); + test_create_path(true, 0, 6); +} + +TEST_F(QueryPlanCRUDTest, Delete) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // make a fully-connected (one-direction, no cycles) with 4 nodes + std::vector vertices; + for (int i = 0; i < 4; ++i) { + vertices.push_back(*dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(i)}})); + } + auto type = dba.NameToEdgeType("type"); + for (int j = 0; j < 4; ++j) + for (int k = j + 1; k < 4; ++k) ASSERT_TRUE(dba.InsertEdge(&vertices[j], &vertices[k], type).HasValue()); + + dba.AdvanceCommand(); + EXPECT_EQ(4, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(6, CountEdges(&dba, storage::View::OLD)); + + AstStorage storage; + SymbolTable symbol_table; + + // attempt to delete a vertex, and fail + { + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_get = storage.Create("n")->MapTo(n.sym_); + auto delete_op = std::make_shared(n.op_, std::vector{n_get}, false); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*delete_op, &context), QueryRuntimeException); + dba.AdvanceCommand(); + EXPECT_EQ(4, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(6, CountEdges(&dba, storage::View::OLD)); + } + + // detach delete a single vertex + { + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_get = storage.Create("n")->MapTo(n.sym_); + auto delete_op = std::make_shared(n.op_, std::vector{n_get}, true); + Frame frame(symbol_table.max_position()); + auto context = MakeContext(storage, symbol_table, &dba); + delete_op->MakeCursor(utils::NewDeleteResource())->Pull(frame, context); + dba.AdvanceCommand(); + EXPECT_EQ(3, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(3, CountEdges(&dba, storage::View::OLD)); + } + + // delete all remaining edges + { + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "m", false, + storage::View::NEW); + auto r_get = storage.Create("r")->MapTo(r_m.edge_sym_); + auto delete_op = std::make_shared(r_m.op_, std::vector{r_get}, false); + auto context = MakeContext(storage, symbol_table, &dba); + PullAll(*delete_op, &context); + dba.AdvanceCommand(); + EXPECT_EQ(3, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(0, CountEdges(&dba, storage::View::OLD)); + } + + // delete all remaining vertices + { + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_get = storage.Create("n")->MapTo(n.sym_); + auto delete_op = std::make_shared(n.op_, std::vector{n_get}, false); + auto context = MakeContext(storage, symbol_table, &dba); + PullAll(*delete_op, &context); + dba.AdvanceCommand(); + EXPECT_EQ(0, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(0, CountEdges(&dba, storage::View::OLD)); + } +} + +TEST_F(QueryPlanCRUDTest, DeleteTwiceDeleteBlockingEdge) { + // test deleting the same vertex and edge multiple times + // + // also test vertex deletion succeeds if the prohibiting + // edge is deleted in the same logical op + // + // we test both with the following queries (note the + // undirected edge in MATCH): + // + // CREATE (:label{property: 1})-[:T]->(:label{property: 2}) + // MATCH (n)-[r]-(m) [DETACH] DELETE n, r, m + + auto test_delete = [this](bool detach) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, dba.NameToEdgeType("T")).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(1, CountEdges(&dba, storage::View::OLD)); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::BOTH, {}, "m", false, + storage::View::OLD); + + // getter expressions for deletion + auto n_get = storage.Create("n")->MapTo(n.sym_); + auto r_get = storage.Create("r")->MapTo(r_m.edge_sym_); + auto m_get = storage.Create("m")->MapTo(r_m.node_sym_); + + auto delete_op = std::make_shared(r_m.op_, std::vector{n_get, r_get, m_get}, detach); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(2, PullAll(*delete_op, &context)); + dba.AdvanceCommand(); + EXPECT_EQ(0, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(0, CountEdges(&dba, storage::View::OLD)); + }; + + test_delete(true); + test_delete(false); +} + +TEST_F(QueryPlanCRUDTest, DeleteReturn) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // make a fully-connected (one-direction, no cycles) with 4 nodes + for (int i = 0; i < 4; ++i) { + const auto property_value = storage::PropertyValue(i); + auto va = *dba.InsertVertexAndValidate(label, {}, {{property, property_value}}); + EXPECT_EQ(*va.GetProperty(storage::View::NEW, property), property_value); + } + + dba.AdvanceCommand(); + EXPECT_EQ(4, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(0, CountEdges(&dba, storage::View::OLD)); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + + auto n_get = storage.Create("n")->MapTo(n.sym_); + auto delete_op = std::make_shared(n.op_, std::vector{n_get}, true); + + auto prop_lookup = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), property); + auto n_p = storage.Create("n", prop_lookup)->MapTo(symbol_table.CreateSymbol("bla", true)); + auto produce = MakeProduce(delete_op, n_p); + + auto context = MakeContext(storage, symbol_table, &dba); + ASSERT_THROW(CollectProduce(*produce, &context), QueryRuntimeException); +} + +TEST(QueryPlan, DeleteNull) { + // test (simplified) WITH Null as x delete x + storage::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + + auto once = std::make_shared(); + auto delete_op = std::make_shared(once, std::vector{LITERAL(TypedValue())}, false); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*delete_op, &context)); +} + +TEST_F(QueryPlanCRUDTest, DeleteAdvance) { + // test queries on empty DB: + // CREATE (n: label{property: 1}) + // MATCH (n) DELETE n WITH n ... + // this fails only if the deleted record `n` is actually used in subsequent + // clauses, which is compatible with Neo's behavior. + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_get = storage.Create("n")->MapTo(n.sym_); + auto delete_op = std::make_shared(n.op_, std::vector{n_get}, false); + auto advance = std::make_shared(delete_op, std::vector{n.sym_}, true); + auto res_sym = symbol_table.CreateSymbol("res", true); + { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + dba.AdvanceCommand(); + auto produce = MakeProduce(advance, NEXPR("res", LITERAL(42))->MapTo(res_sym)); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*produce, &context)); + } + { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}).HasValue()); + dba.AdvanceCommand(); + auto n_prop = PROPERTY_LOOKUP(n_get, dba.NameToProperty("prop")); + auto produce = MakeProduce(advance, NEXPR("res", n_prop)->MapTo(res_sym)); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*produce, &context), QueryRuntimeException); + } +} + +TEST_F(QueryPlanCRUDTest, SetProperty) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // graph with 4 vertices in connected pairs + // the origin vertex in each par and both edges + // have a property set + + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}); + auto v4 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(4)}}); + auto edge_type = dba.NameToEdgeType("edge_type"); + ASSERT_TRUE(dba.InsertEdge(&v1, &v3, edge_type).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v2, &v4, edge_type).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + // scan (n)-[r]->(m) + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "m", false, + storage::View::OLD); + + // set prop1 to 42 on n and r + auto prop1 = dba.NameToProperty("prop1"); + auto literal = LITERAL(42); + + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop1); + auto set_n_p = std::make_shared(r_m.op_, prop1, n_p, literal); + + auto r_p = PROPERTY_LOOKUP(IDENT("r")->MapTo(r_m.edge_sym_), prop1); + auto set_r_p = std::make_shared(set_n_p, prop1, r_p, literal); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(2, PullAll(*set_r_p, &context)); + dba.AdvanceCommand(); + + EXPECT_EQ(CountEdges(&dba, storage::View::OLD), 2); + for (auto vertex : dba.Vertices(storage::View::OLD)) { + auto maybe_edges = vertex.OutEdges(storage::View::OLD); + ASSERT_TRUE(maybe_edges.HasValue()); + for (auto edge : *maybe_edges) { + ASSERT_EQ(edge.GetProperty(storage::View::OLD, prop1)->type(), storage::PropertyValue::Type::Int); + EXPECT_EQ(edge.GetProperty(storage::View::OLD, prop1)->ValueInt(), 42); + auto from = edge.From(); + auto to = edge.To(); + ASSERT_EQ(from.GetProperty(storage::View::OLD, prop1)->type(), storage::PropertyValue::Type::Int); + EXPECT_EQ(from.GetProperty(storage::View::OLD, prop1)->ValueInt(), 42); + ASSERT_EQ(to.GetProperty(storage::View::OLD, prop1)->type(), storage::PropertyValue::Type::Null); + } + } +} + +TEST_F(QueryPlanCRUDTest, SetProperties) { + auto test_set_properties = [this](bool update) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // graph: ({a: 0})-[:R {b:1}]->({c:2}) + auto prop_a = dba.NameToProperty("a"); + auto prop_b = dba.NameToProperty("b"); + auto prop_c = dba.NameToProperty("c"); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + dba.AdvanceCommand(); + + auto e = dba.InsertEdge(&v1, &v2, dba.NameToEdgeType("R")); + ASSERT_TRUE(v1.SetPropertyAndValidate(prop_a, storage::PropertyValue(0)).HasValue()); + ASSERT_TRUE(e->SetProperty(prop_b, storage::PropertyValue(1)).HasValue()); + ASSERT_TRUE(v2.SetPropertyAndValidate(prop_c, storage::PropertyValue(2)).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + // scan (n)-[r]->(m) + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "m", false, + storage::View::OLD); + + auto op = update ? plan::SetProperties::Op::UPDATE : plan::SetProperties::Op::REPLACE; + + // set properties on r to n, and on r to m + auto r_ident = IDENT("r")->MapTo(r_m.edge_sym_); + auto m_ident = IDENT("m")->MapTo(r_m.node_sym_); + auto set_r_to_n = std::make_shared(r_m.op_, n.sym_, r_ident, op); + auto set_m_to_r = std::make_shared(set_r_to_n, r_m.edge_sym_, m_ident, op); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*set_m_to_r, &context)); + dba.AdvanceCommand(); + + EXPECT_EQ(CountEdges(&dba, storage::View::OLD), 1); + for (auto vertex : dba.Vertices(storage::View::OLD)) { + auto maybe_edges = vertex.OutEdges(storage::View::OLD); + ASSERT_TRUE(maybe_edges.HasValue()); + for (auto edge : *maybe_edges) { + auto from = edge.From(); + EXPECT_EQ(from.Properties(storage::View::OLD)->size(), update ? 3 : 1); + if (update) { + ASSERT_EQ(from.GetProperty(storage::View::OLD, prop_a)->type(), storage::PropertyValue::Type::Int); + EXPECT_EQ(from.GetProperty(storage::View::OLD, prop_a)->ValueInt(), 0); + } + ASSERT_EQ(from.GetProperty(storage::View::OLD, prop_b)->type(), storage::PropertyValue::Type::Int); + EXPECT_EQ(from.GetProperty(storage::View::OLD, prop_b)->ValueInt(), 1); + + EXPECT_EQ(edge.Properties(storage::View::OLD)->size(), update ? 3 : 2); + if (update) { + ASSERT_EQ(edge.GetProperty(storage::View::OLD, prop_b)->type(), storage::PropertyValue::Type::Int); + EXPECT_EQ(edge.GetProperty(storage::View::OLD, prop_b)->ValueInt(), 1); + } + ASSERT_EQ(edge.GetProperty(storage::View::OLD, prop_c)->type(), storage::PropertyValue::Type::Int); + EXPECT_EQ(edge.GetProperty(storage::View::OLD, prop_c)->ValueInt(), 2); + + auto to = edge.To(); + EXPECT_EQ(to.Properties(storage::View::OLD)->size(), 2); + ASSERT_EQ(to.GetProperty(storage::View::OLD, prop_c)->type(), storage::PropertyValue::Type::Int); + EXPECT_EQ(to.GetProperty(storage::View::OLD, prop_c)->ValueInt(), 2); + } + } + }; + + test_set_properties(true); + test_set_properties(false); +} + +TEST_F(QueryPlanCRUDTest, SetSecondaryLabels) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + + auto label1 = dba.NameToLabel("label1"); + auto label2 = dba.NameToLabel("label2"); + auto label3 = dba.NameToLabel("label3"); + ASSERT_TRUE(v1.AddLabel(label1).HasValue()); + ASSERT_TRUE(v2.AddLabel(label1).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto label_set = std::make_shared(n.op_, n.sym_, std::vector{label2, label3}); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(2, PullAll(*label_set, &context)); + + for (auto vertex : dba.Vertices(storage::View::OLD)) { + EXPECT_EQ(3, vertex.Labels(storage::View::NEW)->size()); + EXPECT_TRUE(*vertex.HasLabel(storage::View::NEW, label2)); + EXPECT_TRUE(*vertex.HasLabel(storage::View::NEW, label3)); + } +} + +TEST_F(QueryPlanCRUDTest, RemoveProperty) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // graph with 4 vertices in connected pairs + // the origin vertex in each par and both edges + // have a property set + auto prop1 = dba.NameToProperty("prop1"); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}); + auto v4 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(4)}}); + auto edge_type = dba.NameToEdgeType("edge_type"); + { + auto e = dba.InsertEdge(&v1, &v3, edge_type); + ASSERT_TRUE(e.HasValue()); + ASSERT_TRUE(e->SetProperty(prop1, storage::PropertyValue(42)).HasValue()); + } + ASSERT_TRUE(dba.InsertEdge(&v2, &v4, edge_type).HasValue()); + ASSERT_TRUE(v2.SetProperty(prop1, storage::PropertyValue(42)).HasValue()); + ASSERT_TRUE(v3.SetProperty(prop1, storage::PropertyValue(42)).HasValue()); + ASSERT_TRUE(v4.SetProperty(prop1, storage::PropertyValue(42)).HasValue()); + auto prop2 = dba.NameToProperty("prop2"); + ASSERT_TRUE(v1.SetProperty(prop2, storage::PropertyValue(0)).HasValue()); + ASSERT_TRUE(v2.SetProperty(prop2, storage::PropertyValue(0)).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + // scan (n)-[r]->(m) + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "m", false, + storage::View::OLD); + + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop1); + auto set_n_p = std::make_shared(r_m.op_, prop1, n_p); + + auto r_p = PROPERTY_LOOKUP(IDENT("r")->MapTo(r_m.edge_sym_), prop1); + auto set_r_p = std::make_shared(set_n_p, prop1, r_p); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(2, PullAll(*set_r_p, &context)); + dba.AdvanceCommand(); + + EXPECT_EQ(CountEdges(&dba, storage::View::OLD), 2); + for (auto vertex : dba.Vertices(storage::View::OLD)) { + auto maybe_edges = vertex.OutEdges(storage::View::OLD); + ASSERT_TRUE(maybe_edges.HasValue()); + for (auto edge : *maybe_edges) { + EXPECT_EQ(edge.GetProperty(storage::View::OLD, prop1)->type(), storage::PropertyValue::Type::Null); + auto from = edge.From(); + auto to = edge.To(); + EXPECT_EQ(from.GetProperty(storage::View::OLD, prop1)->type(), storage::PropertyValue::Type::Null); + EXPECT_EQ(from.GetProperty(storage::View::OLD, prop2)->type(), storage::PropertyValue::Type::Int); + EXPECT_EQ(to.GetProperty(storage::View::OLD, prop1)->type(), storage::PropertyValue::Type::Int); + } + } +} + +TEST_F(QueryPlanCRUDTest, RemoveLabels) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto label1 = dba.NameToLabel("label1"); + auto label2 = dba.NameToLabel("label2"); + auto label3 = dba.NameToLabel("label3"); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(v1.AddLabel(label1).HasValue()); + ASSERT_TRUE(v1.AddLabel(label2).HasValue()); + ASSERT_TRUE(v1.AddLabel(label3).HasValue()); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + ASSERT_TRUE(v2.AddLabel(label1).HasValue()); + ASSERT_TRUE(v2.AddLabel(label3).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto label_remove = + std::make_shared(n.op_, n.sym_, std::vector{label1, label2}); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(2, PullAll(*label_remove, &context)); + + for (auto vertex : dba.Vertices(storage::View::OLD)) { + EXPECT_EQ(1, vertex.Labels(storage::View::NEW)->size()); + EXPECT_FALSE(*vertex.HasLabel(storage::View::NEW, label1)); + EXPECT_FALSE(*vertex.HasLabel(storage::View::NEW, label2)); + } +} + +TEST_F(QueryPlanCRUDTest, NodeFilterSet) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // Create a graph such that (v1 {prop: 42}) is connected to v2 and v3. + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto prop = PROPERTY_PAIR("prop"); + ASSERT_TRUE(v1.SetProperty(prop.second, storage::PropertyValue(42)).HasValue()); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}); + auto edge_type = dba.NameToEdgeType("Edge"); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v1, &v3, edge_type).HasValue()); + dba.AdvanceCommand(); + // Create operations which match (v1 {prop: 42}) -- (v) and increment the + // v1.prop. The expected result is two incremenentations, since v1 is matched + // twice for 2 edges it has. + AstStorage storage; + SymbolTable symbol_table; + // MATCH (n {prop: 42}) -[r]- (m) + auto scan_all = MakeScanAll(storage, symbol_table, "n"); + std::get<0>(scan_all.node_->properties_)[storage.GetPropertyIx(prop.first)] = LITERAL(42); + auto expand = MakeExpand(storage, symbol_table, scan_all.op_, scan_all.sym_, "r", EdgeAtom::Direction::BOTH, {}, "m", + false, storage::View::OLD); + auto *filter_expr = + EQ(storage.Create(scan_all.node_->identifier_, storage.GetPropertyIx(prop.first)), LITERAL(42)); + auto node_filter = std::make_shared(expand.op_, filter_expr); + // SET n.prop = n.prop + 1 + auto set_prop = PROPERTY_LOOKUP(IDENT("n")->MapTo(scan_all.sym_), prop); + auto add = ADD(set_prop, LITERAL(1)); + auto set = std::make_shared(node_filter, prop.second, set_prop, add); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(2, PullAll(*set, &context)); + dba.AdvanceCommand(); + auto prop_eq = TypedValue(*v1.GetProperty(storage::View::OLD, prop.second)) == TypedValue(42 + 2); + ASSERT_EQ(prop_eq.type(), TypedValue::Type::Bool); + EXPECT_TRUE(prop_eq.ValueBool()); +} + +TEST_F(QueryPlanCRUDTest, FilterRemove) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // Create a graph such that (v1 {prop: 42}) is connected to v2 and v3. + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto prop = PROPERTY_PAIR("prop"); + ASSERT_TRUE(v1.SetProperty(prop.second, storage::PropertyValue(42)).HasValue()); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}); + auto edge_type = dba.NameToEdgeType("Edge"); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v1, &v3, edge_type).HasValue()); + dba.AdvanceCommand(); + // Create operations which match (v1 {prop: 42}) -- (v) and remove v1.prop. + // The expected result is two matches, for each edge of v1. + AstStorage storage; + SymbolTable symbol_table; + // MATCH (n) -[r]- (m) WHERE n.prop < 43 + auto scan_all = MakeScanAll(storage, symbol_table, "n"); + std::get<0>(scan_all.node_->properties_)[storage.GetPropertyIx(prop.first)] = LITERAL(42); + auto expand = MakeExpand(storage, symbol_table, scan_all.op_, scan_all.sym_, "r", EdgeAtom::Direction::BOTH, {}, "m", + false, storage::View::OLD); + auto filter_prop = PROPERTY_LOOKUP(IDENT("n")->MapTo(scan_all.sym_), prop); + auto filter = std::make_shared(expand.op_, LESS(filter_prop, LITERAL(43))); + // REMOVE n.prop + auto rem_prop = PROPERTY_LOOKUP(IDENT("n")->MapTo(scan_all.sym_), prop); + auto rem = std::make_shared(filter, prop.second, rem_prop); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(2, PullAll(*rem, &context)); + dba.AdvanceCommand(); + EXPECT_EQ(v1.GetProperty(storage::View::OLD, prop.second)->type(), storage::PropertyValue::Type::Null); +} + +TEST_F(QueryPlanCRUDTest, SetRemove) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto label1 = dba.NameToLabel("label1"); + auto label2 = dba.NameToLabel("label2"); + dba.AdvanceCommand(); + // Create operations which match (v) and set and remove v :label. + // The expected result is single (v) as it was at the start. + AstStorage storage; + SymbolTable symbol_table; + // MATCH (n) SET n :label1 :label2 REMOVE n :label1 :label2 + auto scan_all = MakeScanAll(storage, symbol_table, "n"); + auto set = + std::make_shared(scan_all.op_, scan_all.sym_, std::vector{label1, label2}); + auto rem = std::make_shared(set, scan_all.sym_, std::vector{label1, label2}); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*rem, &context)); + dba.AdvanceCommand(); + EXPECT_FALSE(*v.HasLabel(storage::View::OLD, label1)); + EXPECT_FALSE(*v.HasLabel(storage::View::OLD, label2)); +} + +TEST_F(QueryPlanCRUDTest, Merge) { + // test setup: + // - three nodes, two of them connected with T + // - merge input branch matches all nodes + // - merge_match branch looks for an expansion (any direction) + // and sets some property (for result validation) + // - merge_create branch just sets some other property + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, dba.NameToEdgeType("Type")).HasValue()); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto prop = PROPERTY_PAIR("prop"); + auto n = MakeScanAll(storage, symbol_table, "n"); + + // merge_match branch + auto r_m = MakeExpand(storage, symbol_table, std::make_shared(), n.sym_, "r", EdgeAtom::Direction::BOTH, {}, + "m", false, storage::View::OLD); + auto m_p = PROPERTY_LOOKUP(IDENT("m")->MapTo(r_m.node_sym_), prop); + auto m_set = std::make_shared(r_m.op_, prop.second, m_p, LITERAL(1)); + + // merge_create branch + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + auto n_set = std::make_shared(std::make_shared(), prop.second, n_p, LITERAL(2)); + + auto merge = std::make_shared(n.op_, m_set, n_set); + auto context = MakeContext(storage, symbol_table, &dba); + ASSERT_EQ(3, PullAll(*merge, &context)); + dba.AdvanceCommand(); + + ASSERT_EQ(v1.GetProperty(storage::View::OLD, prop.second)->type(), storage::PropertyValue::Type::Int); + ASSERT_EQ(v1.GetProperty(storage::View::OLD, prop.second)->ValueInt(), 1); + ASSERT_EQ(v2.GetProperty(storage::View::OLD, prop.second)->type(), storage::PropertyValue::Type::Int); + ASSERT_EQ(v2.GetProperty(storage::View::OLD, prop.second)->ValueInt(), 1); + ASSERT_EQ(v3.GetProperty(storage::View::OLD, prop.second)->type(), storage::PropertyValue::Type::Int); + ASSERT_EQ(v3.GetProperty(storage::View::OLD, prop.second)->ValueInt(), 2); +} + +TEST_F(QueryPlanCRUDTest, MergeNoInput) { + // merge with no input, creates a single node + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + + NodeCreationInfo node; + node.symbol = symbol_table.CreateSymbol("n", true); + node.labels = {label}; + std::get>>(node.properties) + .emplace_back(property, LITERAL(1)); + auto create = std::make_shared(nullptr, node); + auto merge = std::make_shared(nullptr, create, create); + + EXPECT_EQ(0, CountIterable(dba.Vertices(storage::View::OLD))); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*merge, &context)); + dba.AdvanceCommand(); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::View::OLD))); +} + +TEST(QueryPlan, SetPropertyOnNull) { + // SET (Null).prop = 42 + storage::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + auto prop = PROPERTY_PAIR("property"); + auto null = LITERAL(TypedValue()); + auto literal = LITERAL(42); + auto n_prop = PROPERTY_LOOKUP(null, prop); + auto once = std::make_shared(); + auto set_op = std::make_shared(once, prop.second, n_prop, literal); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*set_op, &context)); +} + +TEST(QueryPlan, SetPropertiesOnNull) { + // OPTIONAL MATCH (n) SET n = n + storage::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_ident = IDENT("n")->MapTo(n.sym_); + auto optional = std::make_shared(nullptr, n.op_, std::vector{n.sym_}); + auto set_op = std::make_shared(optional, n.sym_, n_ident, plan::SetProperties::Op::REPLACE); + EXPECT_EQ(0, CountIterable(dba.Vertices(storage::View::OLD))); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*set_op, &context)); +} + +TEST(QueryPlan, SetLabelsOnNull) { + // OPTIONAL MATCH (n) SET n :label + storage::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto label = dba.NameToLabel("label"); + AstStorage storage; + SymbolTable symbol_table; + auto n = MakeScanAll(storage, symbol_table, "n"); + auto optional = std::make_shared(nullptr, n.op_, std::vector{n.sym_}); + auto set_op = std::make_shared(optional, n.sym_, std::vector{label}); + EXPECT_EQ(0, CountIterable(dba.Vertices(storage::View::OLD))); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*set_op, &context)); +} + +TEST(QueryPlan, RemovePropertyOnNull) { + // REMOVE (Null).prop + storage::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + auto prop = PROPERTY_PAIR("property"); + auto null = LITERAL(TypedValue()); + auto n_prop = PROPERTY_LOOKUP(null, prop); + auto once = std::make_shared(); + auto remove_op = std::make_shared(once, prop.second, n_prop); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*remove_op, &context)); +} + +TEST(QueryPlan, RemoveLabelsOnNull) { + // OPTIONAL MATCH (n) REMOVE n :label + storage::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto label = dba.NameToLabel("label"); + AstStorage storage; + SymbolTable symbol_table; + auto n = MakeScanAll(storage, symbol_table, "n"); + auto optional = std::make_shared(nullptr, n.op_, std::vector{n.sym_}); + auto remove_op = std::make_shared(optional, n.sym_, std::vector{label}); + EXPECT_EQ(0, CountIterable(dba.Vertices(storage::View::OLD))); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*remove_op, &context)); +} + +TEST_F(QueryPlanCRUDTest, DeleteSetProperty) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // Add a single vertex. + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::View::OLD))); + AstStorage storage; + SymbolTable symbol_table; + // MATCH (n) DELETE n SET n.prop = 42 + auto n = MakeScanAllNew(storage, symbol_table, "n"); + + auto n_get = storage.Create("n")->MapTo(n.sym_); + auto delete_op = std::make_shared(n.op_, std::vector{n_get}, false); + auto prop = PROPERTY_PAIR("prop"); + auto n_prop = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + auto set_op = std::make_shared(delete_op, prop.second, n_prop, LITERAL(42)); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*set_op, &context), QueryRuntimeException); +} + +TEST_F(QueryPlanCRUDTest, DeleteSetPropertiesFromMap) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // Add a single vertex. + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::View::OLD))); + AstStorage storage; + SymbolTable symbol_table; + // MATCH (n) DELETE n SET n = {prop: 42} + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_get = storage.Create("n")->MapTo(n.sym_); + auto delete_op = std::make_shared(n.op_, std::vector{n_get}, false); + auto prop = PROPERTY_PAIR("prop"); + std::unordered_map prop_map; + prop_map.emplace(storage.GetPropertyIx(prop.first), LITERAL(42)); + auto *rhs = storage.Create(prop_map); + for (auto op_type : {plan::SetProperties::Op::REPLACE, plan::SetProperties::Op::UPDATE}) { + auto set_op = std::make_shared(delete_op, n.sym_, rhs, op_type); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*set_op, &context), QueryRuntimeException); + } +} + +TEST_F(QueryPlanCRUDTest, DeleteSetPropertiesFrom) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // Add a single vertex. + { + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(v.SetProperty(dba.NameToProperty("prop"), storage::PropertyValue(1)).HasValue()); + } + dba.AdvanceCommand(); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::View::OLD))); + AstStorage storage; + SymbolTable symbol_table; + // MATCH (n) DELETE n SET n = n + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_get = storage.Create("n")->MapTo(n.sym_); + auto delete_op = std::make_shared(n.op_, std::vector{n_get}, false); + auto *rhs = IDENT("n")->MapTo(n.sym_); + for (auto op_type : {plan::SetProperties::Op::REPLACE, plan::SetProperties::Op::UPDATE}) { + auto set_op = std::make_shared(delete_op, n.sym_, rhs, op_type); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*set_op, &context), QueryRuntimeException); + } +} + +TEST_F(QueryPlanCRUDTest, DeleteRemoveLabels) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // Add a single vertex. + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::View::OLD))); + AstStorage storage; + SymbolTable symbol_table; + // MATCH (n) DELETE n REMOVE n :label + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_get = storage.Create("n")->MapTo(n.sym_); + auto delete_op = std::make_shared(n.op_, std::vector{n_get}, false); + std::vector labels{dba.NameToLabel("label1")}; + auto rem_op = std::make_shared(delete_op, n.sym_, labels); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*rem_op, &context), QueryRuntimeException); +} + +TEST_F(QueryPlanCRUDTest, DeleteRemoveProperty) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // Add a single vertex. + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::View::OLD))); + AstStorage storage; + SymbolTable symbol_table; + // MATCH (n) DELETE n REMOVE n.prop + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_get = storage.Create("n")->MapTo(n.sym_); + auto delete_op = std::make_shared(n.op_, std::vector{n_get}, false); + auto prop = PROPERTY_PAIR("prop"); + auto n_prop = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + auto rem_op = std::make_shared(delete_op, prop.second, n_prop); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*rem_op, &context), QueryRuntimeException); +} +} // namespace memgraph::query::tests diff --git a/tests/unit/query_v2_query_plan_edge_cases.cpp b/tests/unit/query_v2_query_plan_edge_cases.cpp new file mode 100644 index 000000000..2abeab496 --- /dev/null +++ b/tests/unit/query_v2_query_plan_edge_cases.cpp @@ -0,0 +1,115 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +// tests in this suite deal with edge cases in logical operator behavior +// that's not easily testable with single-phase testing. instead, for +// easy testing and latter readability they are tested end-to-end. + +#include +#include + +#include +#include + +#include "communication/result_stream_faker.hpp" +#include "query/interpreter.hpp" +#include "storage/v2/storage.hpp" + +DECLARE_bool(query_cost_planner); + +namespace memgraph::query::tests { + +class QueryExecution : public testing::Test { + protected: + storage::Storage db; + std::optional db_; + std::optional interpreter_context_; + std::optional interpreter_; + + std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_query_plan_edge_cases"}; + + void SetUp() { + db_.emplace(); + interpreter_context_.emplace(&*db_, InterpreterConfig{}, data_directory); + interpreter_.emplace(&*interpreter_context_); + } + + void TearDown() { + interpreter_ = std::nullopt; + interpreter_context_ = std::nullopt; + db_ = std::nullopt; + } + + /** + * Execute the given query and commit the transaction. + * + * Return the query results. + */ + auto Execute(const std::string &query) { + ResultStreamFaker stream(&*db_); + + auto [header, _, qid] = interpreter_->Prepare(query, {}, nullptr); + stream.Header(header); + auto summary = interpreter_->PullAll(&stream); + stream.Summary(summary); + + return stream.GetResults(); + } +}; + +TEST_F(QueryExecution, MissingOptionalIntoExpand) { + Execute("CREATE SCHEMA ON :Person(id INTEGER)"); + Execute("CREATE SCHEMA ON :Dog(id INTEGER)"); + Execute("CREATE SCHEMA ON :Food(id INTEGER)"); + // validating bug where expanding from Null (due to a preceding optional + // match) exhausts the expansion cursor, even if it's input is still not + // exhausted + Execute( + "CREATE (a:Person {id: 1}), (b:Person " + "{id:2})-[:Has]->(:Dog {id: 1})-[:Likes]->(:Food {id: 1})"); + ASSERT_EQ(Execute("MATCH (n) RETURN n").size(), 4); + + auto Exec = [this](bool desc, const std::string &edge_pattern) { + // this test depends on left-to-right query planning + FLAGS_query_cost_planner = false; + return Execute(std::string("MATCH (p:Person) WITH p ORDER BY p.id ") + (desc ? "DESC " : "") + + "OPTIONAL MATCH (p)-->(d:Dog) WITH p, d " + "MATCH (d)" + + edge_pattern + + "(f:Food) " + "RETURN p, d, f") + .size(); + }; + + std::string expand = "-->"; + std::string variable = "-[*1]->"; + std::string bfs = "-[*bfs..1]->"; + + EXPECT_EQ(Exec(false, expand), 1); + EXPECT_EQ(Exec(true, expand), 1); + EXPECT_EQ(Exec(false, variable), 1); + EXPECT_EQ(Exec(true, bfs), 1); + EXPECT_EQ(Exec(true, bfs), 1); +} + +TEST_F(QueryExecution, EdgeUniquenessInOptional) { + Execute("CREATE SCHEMA ON :label(id INTEGER)"); + // Validating that an edge uniqueness check can't fail when the edge is Null + // due to optional match. Since edge-uniqueness only happens in one OPTIONAL + // MATCH, we only need to check that scenario. + Execute("CREATE (:label {id: 1}), (:label {id: 2})-[:Type]->(:label {id: 3})"); + ASSERT_EQ(Execute("MATCH (n) RETURN n").size(), 3); + EXPECT_EQ(Execute("MATCH (n) OPTIONAL MATCH (n)-[r1]->(), (n)-[r2]->() " + "RETURN n, r1, r2") + .size(), + 3); +} +} // namespace memgraph::query::tests diff --git a/tests/unit/query_v2_query_plan_match_filter_return.cpp b/tests/unit/query_v2_query_plan_match_filter_return.cpp new file mode 100644 index 000000000..079b3f5f4 --- /dev/null +++ b/tests/unit/query_v2_query_plan_match_filter_return.cpp @@ -0,0 +1,2062 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "query/context.hpp" +#include "query/exceptions.hpp" +#include "query/plan/operator.hpp" +#include "query_plan_common.hpp" +#include "storage/v2/property_value.hpp" + +using namespace memgraph::query; +using namespace memgraph::query::plan; + +namespace std { +template <> +struct hash> { + size_t operator()(const std::pair &p) const { return p.first + 31 * p.second; } +}; +} // namespace std + +namespace memgraph::query::tests { + +class MatchReturnFixture : public testing::Test { + protected: + void SetUp() override { + ASSERT_TRUE(db.CreateSchema(label, {storage::SchemaProperty{property, common::SchemaType::INT}})); + } + + storage::Storage db; + storage::Storage::Accessor storage_dba{db.Access()}; + DbAccessor dba{&storage_dba}; + const storage::LabelId label{db.NameToLabel("label")}; + const storage::PropertyId property{db.NameToProperty("property")}; + AstStorage storage; + SymbolTable symbol_table; + + void AddVertices(int count) { + for (int i = 0; i < count; i++) { + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(i)}}).HasValue()); + } + } + + std::vector PathResults(std::shared_ptr &op) { + std::vector res; + auto context = MakeContext(storage, symbol_table, &dba); + for (const auto &row : CollectProduce(*op, &context)) res.emplace_back(row[0].ValuePath()); + return res; + } +}; + +TEST_F(MatchReturnFixture, MatchReturn) { + AddVertices(2); + dba.AdvanceCommand(); + + auto test_pull_count = [&](storage::View view) { + auto scan_all = MakeScanAll(storage, symbol_table, "n", nullptr, view); + auto output = + NEXPR("n", IDENT("n")->MapTo(scan_all.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(scan_all.op_, output); + auto context = MakeContext(storage, symbol_table, &dba); + return PullAll(*produce, &context); + }; + + EXPECT_EQ(2, test_pull_count(storage::View::NEW)); + EXPECT_EQ(2, test_pull_count(storage::View::OLD)); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + EXPECT_EQ(3, test_pull_count(storage::View::NEW)); + EXPECT_EQ(2, test_pull_count(storage::View::OLD)); + dba.AdvanceCommand(); + EXPECT_EQ(3, test_pull_count(storage::View::OLD)); +} + +TEST_F(MatchReturnFixture, MatchReturnPath) { + AddVertices(2); + dba.AdvanceCommand(); + + auto scan_all = MakeScanAll(storage, symbol_table, "n", nullptr); + Symbol path_sym = symbol_table.CreateSymbol("path", true); + auto make_path = std::make_shared(scan_all.op_, path_sym, std::vector{scan_all.sym_}); + auto output = + NEXPR("path", IDENT("path")->MapTo(path_sym))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(make_path, output); + auto results = PathResults(produce); + ASSERT_EQ(results.size(), 2); + std::vector expected_paths; + for (const auto &v : dba.Vertices(storage::View::OLD)) expected_paths.emplace_back(v); + ASSERT_EQ(expected_paths.size(), 2); + EXPECT_TRUE(std::is_permutation(expected_paths.begin(), expected_paths.end(), results.begin())); +} + +class QueryPlanMatchFilterTest : public testing::Test { + protected: + QueryPlanMatchFilterTest() { + EXPECT_TRUE(db.CreateSchema(label, {storage::SchemaProperty{property, common::SchemaType::INT}})); + } + + storage::Storage db; + storage::LabelId label = db.NameToLabel("label"); + storage::PropertyId property = db.NameToProperty("property"); +}; + +TEST_F(QueryPlanMatchFilterTest, MatchReturnCartesian) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}) + ->AddLabel(dba.NameToLabel("l1")) + .HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}) + ->AddLabel(dba.NameToLabel("l2")) + .HasValue()); + + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto m = MakeScanAll(storage, symbol_table, "m", n.op_); + auto return_n = NEXPR("n", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto return_m = NEXPR("m", IDENT("m")->MapTo(m.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_2", true)); + auto produce = MakeProduce(m.op_, return_n, return_m); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), 4); + // ensure the result ordering is OK: + // "n" from the results is the same for the first two rows, while "m" isn't + EXPECT_EQ(results[0][0].ValueVertex(), results[1][0].ValueVertex()); + EXPECT_NE(results[0][1].ValueVertex(), results[1][1].ValueVertex()); +} + +TEST_F(QueryPlanMatchFilterTest, StandaloneReturn) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // add a few nodes to the database + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto output = NEXPR("n", LITERAL(42)); + auto produce = MakeProduce(std::shared_ptr(nullptr), output); + output->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), 1); + EXPECT_EQ(results[0].size(), 1); + EXPECT_EQ(results[0][0].ValueInt(), 42); +} + +TEST_F(QueryPlanMatchFilterTest, NodeFilterLabelsAndProperties) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // add a few nodes to the database + storage::LabelId label1 = dba.NameToLabel("Label1"); + auto property1 = PROPERTY_PAIR("Property1"); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}); + auto v4 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(4)}}); + auto v5 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(5)}}); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(6)}}).HasValue()); + + // test all combination of (label | no_label) * (no_prop | wrong_prop | + // right_prop) + // only v1-v3 will have the right labels + ASSERT_TRUE(v1.AddLabel(label1).HasValue()); + ASSERT_TRUE(v2.AddLabel(label1).HasValue()); + ASSERT_TRUE(v3.AddLabel(label1).HasValue()); + // v1 and v4 will have the right properties + ASSERT_TRUE(v1.SetProperty(property1.second, storage::PropertyValue(42)).HasValue()); + ASSERT_TRUE(v2.SetProperty(property1.second, storage::PropertyValue(1)).HasValue()); + ASSERT_TRUE(v4.SetProperty(property1.second, storage::PropertyValue(42)).HasValue()); + ASSERT_TRUE(v5.SetProperty(property1.second, storage::PropertyValue(1)).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + // make a scan all + auto n = MakeScanAll(storage, symbol_table, "n"); + n.node_->labels_.emplace_back(storage.GetLabelIx(dba.LabelToName(label1))); + std::get<0>(n.node_->properties_)[storage.GetPropertyIx(property1.first)] = LITERAL(42); + + // node filtering + auto *filter_expr = AND(storage.Create(n.node_->identifier_, n.node_->labels_), + EQ(PROPERTY_LOOKUP(n.node_->identifier_, property1), LITERAL(42))); + auto node_filter = std::make_shared(n.op_, filter_expr); + + // make a named expression and a produce + auto output = NEXPR("x", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(node_filter, output); + + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*produce, &context)); + + // test that filtering works with old records + ASSERT_TRUE(v4.AddLabel(label1).HasValue()); + EXPECT_EQ(1, PullAll(*produce, &context)); + dba.AdvanceCommand(); + EXPECT_EQ(2, PullAll(*produce, &context)); +} + +TEST_F(QueryPlanMatchFilterTest, NodeFilterMultipleLabels) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // add a few nodes to the database + storage::LabelId label1 = dba.NameToLabel("label1"); + storage::LabelId label2 = dba.NameToLabel("label2"); + storage::LabelId label3 = dba.NameToLabel("label3"); + // the test will look for nodes that have label1 and label2 + ASSERT_TRUE( + dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); // NOT accepted + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}) + ->AddLabel(label1) + .HasValue()); // NOT accepted + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}) + ->AddLabel(label2) + .HasValue()); // NOT accepted + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(4)}}) + ->AddLabel(label3) + .HasValue()); // NOT accepted + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(5)}}); // YES accepted + ASSERT_TRUE(v1.AddLabel(label1).HasValue()); + ASSERT_TRUE(v1.AddLabel(label2).HasValue()); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(6)}}); // NOT accepted + ASSERT_TRUE(v2.AddLabel(label1).HasValue()); + ASSERT_TRUE(v2.AddLabel(label3).HasValue()); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(7)}}); // YES accepted + ASSERT_TRUE(v3.AddLabel(label1).HasValue()); + ASSERT_TRUE(v3.AddLabel(label2).HasValue()); + ASSERT_TRUE(v3.AddLabel(label3).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + // make a scan all + auto n = MakeScanAll(storage, symbol_table, "n"); + n.node_->labels_.emplace_back(storage.GetLabelIx(dba.LabelToName(label1))); + n.node_->labels_.emplace_back(storage.GetLabelIx(dba.LabelToName(label2))); + + // node filtering + auto *filter_expr = storage.Create(n.node_->identifier_, n.node_->labels_); + auto node_filter = std::make_shared(n.op_, filter_expr); + + // make a named expression and a produce + auto output = NEXPR("n", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(node_filter, output); + + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), 2); +} + +TEST_F(QueryPlanMatchFilterTest, Cartesian) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto add_vertex = [&dba, this](std::string label1) { + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + MG_ASSERT(vertex.AddLabel(dba.NameToLabel(label1)).HasValue()); + return vertex; + }; + + std::vector vertices{add_vertex("v1"), add_vertex("v2"), add_vertex("v3")}; + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto m = MakeScanAll(storage, symbol_table, "m"); + auto return_n = NEXPR("n", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto return_m = NEXPR("m", IDENT("m")->MapTo(m.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_2", true)); + + std::vector left_symbols{n.sym_}; + std::vector right_symbols{m.sym_}; + auto cartesian_op = std::make_shared(n.op_, left_symbols, m.op_, right_symbols); + + auto produce = MakeProduce(cartesian_op, return_n, return_m); + + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), 9); + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 3; ++j) { + EXPECT_EQ(results[3 * i + j][0].ValueVertex(), vertices[j]); + EXPECT_EQ(results[3 * i + j][1].ValueVertex(), vertices[i]); + } + } +} + +TEST_F(QueryPlanMatchFilterTest, CartesianEmptySet) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAllNew(storage, symbol_table, "n"); + auto m = MakeScanAllNew(storage, symbol_table, "m"); + auto return_n = NEXPR("n", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto return_m = NEXPR("m", IDENT("m")->MapTo(m.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_2", true)); + + std::vector left_symbols{n.sym_}; + std::vector right_symbols{m.sym_}; + auto cartesian_op = std::make_shared(n.op_, left_symbols, m.op_, right_symbols); + + auto produce = MakeProduce(cartesian_op, return_n, return_m); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), 0); +} + +TEST_F(QueryPlanMatchFilterTest, CartesianThreeWay) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto add_vertex = [&dba, this](std::string label1) { + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + MG_ASSERT(vertex.AddLabel(dba.NameToLabel(label1)).HasValue()); + return vertex; + }; + + std::vector vertices{add_vertex("v1"), add_vertex("v2"), add_vertex("v3")}; + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAllNew(storage, symbol_table, "n"); + auto m = MakeScanAllNew(storage, symbol_table, "m"); + auto l = MakeScanAllNew(storage, symbol_table, "l"); + auto *return_n = NEXPR("n", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto *return_m = NEXPR("m", IDENT("m")->MapTo(m.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_2", true)); + auto *return_l = NEXPR("l", IDENT("l")->MapTo(l.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_3", true)); + + std::vector n_symbols{n.sym_}; + std::vector m_symbols{m.sym_}; + std::vector n_m_symbols{n.sym_, m.sym_}; + std::vector l_symbols{l.sym_}; + auto cartesian_op_1 = std::make_shared(n.op_, n_symbols, m.op_, m_symbols); + + auto cartesian_op_2 = std::make_shared(cartesian_op_1, n_m_symbols, l.op_, l_symbols); + + auto produce = MakeProduce(cartesian_op_2, return_n, return_m, return_l); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), 27); + int id = 0; + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 3; ++k) { + EXPECT_EQ(results[id][0].ValueVertex(), vertices[k]); + EXPECT_EQ(results[id][1].ValueVertex(), vertices[j]); + EXPECT_EQ(results[id][2].ValueVertex(), vertices[i]); + ++id; + } + } + } +} + +class ExpandFixture : public QueryPlanMatchFilterTest { + protected: + storage::Storage::Accessor storage_dba{db.Access()}; + DbAccessor dba{&storage_dba}; + AstStorage storage; + SymbolTable symbol_table; + + // make a V-graph (v3)<-[r2]-(v1)-[r1]->(v2) + VertexAccessor v1{*dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}})}; + VertexAccessor v2{*dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}})}; + VertexAccessor v3{*dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}})}; + storage::EdgeTypeId edge_type{db.NameToEdgeType("Edge")}; + EdgeAccessor r1{*dba.InsertEdge(&v1, &v2, edge_type)}; + EdgeAccessor r2{*dba.InsertEdge(&v1, &v3, edge_type)}; + + void SetUp() override { + ASSERT_TRUE(v1.AddLabel(dba.NameToLabel("l1")).HasValue()); + ASSERT_TRUE(v2.AddLabel(dba.NameToLabel("l2")).HasValue()); + ASSERT_TRUE(v3.AddLabel(dba.NameToLabel("l3")).HasValue()); + dba.AdvanceCommand(); + } +}; + +TEST_F(ExpandFixture, Expand) { + auto test_expand = [&](EdgeAtom::Direction direction, storage::View view) { + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", direction, {}, "m", false, view); + + // make a named expression and a produce + auto *output = + NEXPR("m", IDENT("m")->MapTo(r_m.node_sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(r_m.op_, output); + auto context = MakeContext(storage, symbol_table, &dba); + return PullAll(*produce, &context); + }; + // test that expand works well for both old and new graph state + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v1, &v3, edge_type).HasValue()); + EXPECT_EQ(2, test_expand(EdgeAtom::Direction::OUT, storage::View::OLD)); + EXPECT_EQ(2, test_expand(EdgeAtom::Direction::IN, storage::View::OLD)); + EXPECT_EQ(4, test_expand(EdgeAtom::Direction::BOTH, storage::View::OLD)); + EXPECT_EQ(4, test_expand(EdgeAtom::Direction::OUT, storage::View::NEW)); + EXPECT_EQ(4, test_expand(EdgeAtom::Direction::IN, storage::View::NEW)); + EXPECT_EQ(8, test_expand(EdgeAtom::Direction::BOTH, storage::View::NEW)); + dba.AdvanceCommand(); + EXPECT_EQ(4, test_expand(EdgeAtom::Direction::OUT, storage::View::OLD)); + EXPECT_EQ(4, test_expand(EdgeAtom::Direction::IN, storage::View::OLD)); + EXPECT_EQ(8, test_expand(EdgeAtom::Direction::BOTH, storage::View::OLD)); +} + +TEST_F(ExpandFixture, ExpandPath) { + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "m", false, + storage::View::OLD); + Symbol path_sym = symbol_table.CreateSymbol("path", true); + auto path = std::make_shared(r_m.op_, path_sym, + std::vector{n.sym_, r_m.edge_sym_, r_m.node_sym_}); + auto output = + NEXPR("path", IDENT("path")->MapTo(path_sym))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(path, output); + + std::vector expected_paths{Path(v1, r2, v3), Path(v1, r1, v2)}; + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + ASSERT_EQ(results.size(), 2); + std::vector results_paths; + for (const auto &result : results) results_paths.emplace_back(result[0].ValuePath()); + EXPECT_TRUE(std::is_permutation(expected_paths.begin(), expected_paths.end(), results_paths.begin())); +} + +// /** +// * A fixture that sets a graph up and provides some functions. +// * +// * The graph is a double chain: +// * (v:0)-(v:1)-(v:2) +// * X X +// * (v:0)-(v:1)-(v:2) +// * +// * Each vertex is labeled (the labels are available as a +// * member in this class). Edges have properties set that +// * indicate origin and destination vertex for debugging. +// */ +class QueryPlanExpandVariable : public QueryPlanMatchFilterTest { + protected: + // type returned by the GetEdgeListSizes function, used + // a lot below in test declaration + using map_int = std::unordered_map; + + storage::Storage::Accessor storage_dba{db.Access()}; + DbAccessor dba{&storage_dba}; + // labels for layers in the double chain + std::vector labels; + // for all the edges + storage::EdgeTypeId edge_type = dba.NameToEdgeType("edge_type"); + + AstStorage storage; + SymbolTable symbol_table; + + // using std::nullopt + std::nullopt_t nullopt = std::nullopt; + + void SetUp() { + // create the graph + int chain_length = 3; + std::vector layer; + for (int from_layer_ind = -1; from_layer_ind < chain_length - 1; from_layer_ind++) { + std::vector new_layer{ + *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}), + *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}})}; + auto label = dba.NameToLabel(std::to_string(from_layer_ind + 1)); + labels.push_back(label); + for (size_t v_to_ind = 0; v_to_ind < new_layer.size(); v_to_ind++) { + auto &v_to = new_layer[v_to_ind]; + ASSERT_TRUE(v_to.AddLabel(label).HasValue()); + for (size_t v_from_ind = 0; v_from_ind < layer.size(); v_from_ind++) { + auto &v_from = layer[v_from_ind]; + auto edge = dba.InsertEdge(&v_from, &v_to, edge_type); + ASSERT_TRUE(edge->SetProperty(dba.NameToProperty("p"), + storage::PropertyValue(fmt::format("V{}{}->V{}{}", from_layer_ind, v_from_ind, + from_layer_ind + 1, v_to_ind))) + .HasValue()); + } + } + layer = new_layer; + } + dba.AdvanceCommand(); + ASSERT_EQ(CountIterable(dba.Vertices(storage::View::OLD)), 2 * chain_length); + ASSERT_EQ(CountEdges(&dba, storage::View::OLD), 4 * (chain_length - 1)); + } + + /** + * Expands the given LogicalOperator input with a match + * (ScanAll->Filter(label)->Expand). Can create both VariableExpand + * ops and plain Expand (depending on template param). + * When creating plain Expand the bound arguments (lower, upper) are ignored. + * + * @param is_reverse Set to true if ExpandVariable should produce the list of + * edges in reverse order. As if ExpandVariable starts from `node_to` and ends + * with `node_from`. + * @return the last created logical op. + */ + template + std::shared_ptr AddMatch(std::shared_ptr input_op, const std::string &node_from, + int layer, EdgeAtom::Direction direction, + const std::vector &edge_types, + std::optional lower, std::optional upper, Symbol edge_sym, + const std::string &node_to, storage::View view, bool is_reverse = false) { + auto n_from = MakeScanAll(storage, symbol_table, node_from, input_op); + auto filter_op = std::make_shared( + n_from.op_, + storage.Create(n_from.node_->identifier_, + std::vector{storage.GetLabelIx(dba.LabelToName(labels[layer]))})); + + auto n_to = NODE(node_to); + auto n_to_sym = symbol_table.CreateSymbol(node_to, true); + n_to->identifier_->MapTo(n_to_sym); + + if (std::is_same::value) { + // convert optional ints to optional expressions + auto convert = [this](std::optional bound) { + return bound ? LITERAL(static_cast(bound.value())) : nullptr; + }; + MG_ASSERT(view == storage::View::OLD, "ExpandVariable should only be planned with storage::View::OLD"); + + return std::make_shared(filter_op, n_from.sym_, n_to_sym, edge_sym, EdgeAtom::Type::DEPTH_FIRST, + direction, edge_types, is_reverse, convert(lower), convert(upper), false, + ExpansionLambda{symbol_table.CreateSymbol("inner_edge", false), + symbol_table.CreateSymbol("inner_node", false), nullptr}, + std::nullopt, std::nullopt); + } else + return std::make_shared(filter_op, n_from.sym_, n_to_sym, edge_sym, direction, edge_types, false, view); + } + + /* Creates an edge (in the frame and symbol table). Returns the symbol. */ + auto Edge(const std::string &identifier, EdgeAtom::Direction direction) { + auto edge = EDGE(identifier, direction); + auto edge_sym = symbol_table.CreateSymbol(identifier, true); + edge->identifier_->MapTo(edge_sym); + return edge_sym; + } + + /** + * Pulls from the given input and returns the results under the given symbol. + */ + auto GetListResults(std::shared_ptr input_op, Symbol symbol) { + Frame frame(symbol_table.max_position()); + auto cursor = input_op->MakeCursor(utils::NewDeleteResource()); + auto context = MakeContext(storage, symbol_table, &dba); + std::vector> results; + while (cursor->Pull(frame, context)) results.emplace_back(frame[symbol].ValueList()); + return results; + } + + /** + * Pulls from the given input and returns the results under the given symbol. + */ + auto GetPathResults(std::shared_ptr input_op, Symbol symbol) { + Frame frame(symbol_table.max_position()); + auto cursor = input_op->MakeCursor(utils::NewDeleteResource()); + auto context = MakeContext(storage, symbol_table, &dba); + std::vector results; + while (cursor->Pull(frame, context)) results.emplace_back(frame[symbol].ValuePath()); + return results; + } + + /** + * Pulls from the given input and analyses the edge-list (result of variable + * length expansion) found in the results under the given symbol. + * + * @return a map {edge_list_length -> number_of_results} + */ + auto GetEdgeListSizes(std::shared_ptr input_op, Symbol symbol) { + map_int count_per_length; + for (const auto &edge_list : GetListResults(input_op, symbol)) { + auto length = edge_list.size(); + auto found = count_per_length.find(length); + if (found == count_per_length.end()) + count_per_length[length] = 1; + else + found->second++; + } + return count_per_length; + } +}; + +TEST_F(QueryPlanExpandVariable, OneVariableExpansion) { + auto test_expand = [&](int layer, EdgeAtom::Direction direction, std::optional lower, + std::optional upper, bool reverse) { + auto e = Edge("r", direction); + return GetEdgeListSizes( + AddMatch(nullptr, "n", layer, direction, {}, lower, upper, e, "m", storage::View::OLD, reverse), + e); + }; + + for (int reverse = 0; reverse < 2; ++reverse) { + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::IN, 0, 0, reverse), (map_int{{0, 2}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, 0, 0, reverse), (map_int{{0, 2}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::BOTH, 0, 0, reverse), (map_int{{0, 2}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::IN, 1, 1, reverse), (map_int{})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, 1, 1, reverse), (map_int{{1, 4}})); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::IN, 1, 1, reverse), (map_int{{1, 4}})); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::OUT, 1, 1, reverse), (map_int{{1, 4}})); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::BOTH, 1, 1, reverse), (map_int{{1, 8}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, 2, 2, reverse), (map_int{{2, 8}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, 2, 3, reverse), (map_int{{2, 8}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, 1, 2, reverse), (map_int{{1, 4}, {2, 8}})); + + // the following tests also check edge-uniqueness (cyphermorphisim) + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::BOTH, 1, 2, reverse), (map_int{{1, 4}, {2, 12}})); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::BOTH, 4, 4, reverse), (map_int{{4, 24}})); + + // default bound values (lower default is 1, upper default is inf) + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, nullopt, 0, reverse), (map_int{})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, nullopt, 1, reverse), (map_int{{1, 4}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, nullopt, 2, reverse), (map_int{{1, 4}, {2, 8}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::BOTH, 7, nullopt, reverse), (map_int{{7, 24}, {8, 24}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::BOTH, 8, nullopt, reverse), (map_int{{8, 24}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::BOTH, 9, nullopt, reverse), (map_int{})); + } +} + +TEST_F(QueryPlanExpandVariable, EdgeUniquenessSingleAndVariableExpansion) { + auto test_expand = [&](int layer, EdgeAtom::Direction direction, std::optional lower, + std::optional upper, bool single_expansion_before, bool add_uniqueness_check) { + std::shared_ptr last_op{nullptr}; + std::vector symbols; + + if (single_expansion_before) { + symbols.push_back(Edge("r0", direction)); + last_op = + AddMatch(last_op, "n0", layer, direction, {}, lower, upper, symbols.back(), "m0", storage::View::OLD); + } + + auto var_length_sym = Edge("r1", direction); + symbols.push_back(var_length_sym); + last_op = AddMatch(last_op, "n1", layer, direction, {}, lower, upper, var_length_sym, "m1", + storage::View::OLD); + + if (!single_expansion_before) { + symbols.push_back(Edge("r2", direction)); + last_op = + AddMatch(last_op, "n2", layer, direction, {}, lower, upper, symbols.back(), "m2", storage::View::OLD); + } + + if (add_uniqueness_check) { + auto last_symbol = symbols.back(); + symbols.pop_back(); + last_op = std::make_shared(last_op, last_symbol, symbols); + } + + return GetEdgeListSizes(last_op, var_length_sym); + }; + + // no uniqueness between variable and single expansion + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, 2, 3, true, false), (map_int{{2, 4 * 8}})); + // with uniqueness test, different ordering of (variable, single) expansion + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, 2, 3, true, true), (map_int{{2, 3 * 8}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, 2, 3, false, true), (map_int{{2, 3 * 8}})); +} + +TEST_F(QueryPlanExpandVariable, EdgeUniquenessTwoVariableExpansions) { + auto test_expand = [&](int layer, EdgeAtom::Direction direction, std::optional lower, + std::optional upper, bool add_uniqueness_check) { + auto e1 = Edge("r1", direction); + auto first = + AddMatch(nullptr, "n1", layer, direction, {}, lower, upper, e1, "m1", storage::View::OLD); + auto e2 = Edge("r2", direction); + auto last_op = + AddMatch(first, "n2", layer, direction, {}, lower, upper, e2, "m2", storage::View::OLD); + if (add_uniqueness_check) { + last_op = std::make_shared(last_op, e2, std::vector{e1}); + } + + return GetEdgeListSizes(last_op, e2); + }; + + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, 2, 2, false), (map_int{{2, 8 * 8}})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, 2, 2, true), (map_int{{2, 5 * 8}})); +} + +TEST_F(QueryPlanExpandVariable, NamedPath) { + auto e = Edge("r", EdgeAtom::Direction::OUT); + auto expand = + AddMatch(nullptr, "n", 0, EdgeAtom::Direction::OUT, {}, 2, 2, e, "m", storage::View::OLD); + auto find_symbol = [this](const std::string &name) { + for (const auto &sym : symbol_table.table()) + if (sym.second.name() == name) return sym.second; + throw std::runtime_error("Symbol not found"); + }; + + auto path_symbol = symbol_table.CreateSymbol("path", true, Symbol::Type::PATH); + auto create_path = std::make_shared(expand, path_symbol, + std::vector{find_symbol("n"), e, find_symbol("m")}); + + std::vector expected_paths; + for (const auto &v : dba.Vertices(storage::View::OLD)) { + if (!*v.HasLabel(storage::View::OLD, labels[0])) continue; + auto maybe_edges1 = v.OutEdges(storage::View::OLD); + for (const auto &e1 : *maybe_edges1) { + auto maybe_edges2 = e1.To().OutEdges(storage::View::OLD); + for (const auto &e2 : *maybe_edges2) { + expected_paths.emplace_back(v, e1, e1.To(), e2, e2.To()); + } + } + } + ASSERT_EQ(expected_paths.size(), 8); + + auto results = GetPathResults(create_path, path_symbol); + ASSERT_EQ(results.size(), 8); + EXPECT_TRUE(std::is_permutation(results.begin(), results.end(), expected_paths.begin())); +} + +TEST_F(QueryPlanExpandVariable, ExpandToSameSymbol) { + auto test_expand = [&](int layer, EdgeAtom::Direction direction, std::optional lower, + std::optional upper, bool reverse) { + auto e = Edge("r", direction); + + auto node = NODE("n"); + auto symbol = symbol_table.CreateSymbol("n", true); + node->identifier_->MapTo(symbol); + auto logical_op = std::make_shared(nullptr, symbol, storage::View::OLD); + auto n_from = ScanAllTuple{node, logical_op, symbol}; + + auto filter_op = std::make_shared( + n_from.op_, + storage.Create(n_from.node_->identifier_, + std::vector{storage.GetLabelIx(dba.LabelToName(labels[layer]))})); + + // convert optional ints to optional expressions + auto convert = [this](std::optional bound) { + return bound ? LITERAL(static_cast(bound.value())) : nullptr; + }; + + return GetEdgeListSizes( + std::make_shared(filter_op, symbol, symbol, e, EdgeAtom::Type::DEPTH_FIRST, direction, + std::vector{}, reverse, convert(lower), convert(upper), + /* existing = */ true, + ExpansionLambda{symbol_table.CreateSymbol("inner_edge", false), + symbol_table.CreateSymbol("inner_node", false), nullptr}, + std::nullopt, std::nullopt), + e); + }; + + // The graph is a double chain: + // chain 0: (v:0)-(v:1)-(v:2) + // X X + // chain 1: (v:0)-(v:1)-(v:2) + + // Expand from chain 0 v:0 to itself. + // + // It has a total of 3 cycles: + // 1. C0 v:0 -> C0 v:1 -> C1 v:2 -> C1 v:1 -> C0 v:0 + // 2. C0 v:0 -> C0 v:1 -> C0 v:2 -> C1 v:1 -> C0 v:0 + // 3. C0 v:0 -> C0 v:1 -> C1 v:0 -> C1 v:1 -> C0 v:0 + // + // Each cycle can be in two directions, also, we have two starting nodes: one + // in chain 0 and the other in chain 1. + for (auto reverse : {false, true}) { + // Tests with both bounds set. + for (int lower_bound = 0; lower_bound < 10; ++lower_bound) { + for (int upper_bound = lower_bound; upper_bound < 10; ++upper_bound) { + map_int expected_directed; + map_int expected_undirected; + if (lower_bound == 0) { + expected_directed.emplace(0, 2); + expected_undirected.emplace(0, 2); + } + if (lower_bound <= 4 && upper_bound >= 4) { + expected_undirected.emplace(4, 12); + } + if (lower_bound <= 8 && upper_bound >= 8) { + expected_undirected.emplace(8, 24); + } + + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::IN, lower_bound, upper_bound, reverse), expected_directed); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, lower_bound, upper_bound, reverse), expected_directed); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::BOTH, lower_bound, upper_bound, reverse), expected_undirected); + } + } + + // Test only upper bound. + for (int upper_bound = 0; upper_bound < 10; ++upper_bound) { + map_int expected_directed; + map_int expected_undirected; + if (upper_bound >= 4) { + expected_undirected.emplace(4, 12); + } + if (upper_bound >= 8) { + expected_undirected.emplace(8, 24); + } + + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::IN, std::nullopt, upper_bound, reverse), expected_directed); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, std::nullopt, upper_bound, reverse), expected_directed); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::BOTH, std::nullopt, upper_bound, reverse), expected_undirected); + } + + // Test only lower bound. + for (int lower_bound = 0; lower_bound < 10; ++lower_bound) { + map_int expected_directed; + map_int expected_undirected; + if (lower_bound == 0) { + expected_directed.emplace(0, 2); + expected_undirected.emplace(0, 2); + } + if (lower_bound <= 4) { + expected_undirected.emplace(4, 12); + } + if (lower_bound <= 8) { + expected_undirected.emplace(8, 24); + } + + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::IN, lower_bound, std::nullopt, reverse), expected_directed); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, lower_bound, std::nullopt, reverse), expected_directed); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::BOTH, lower_bound, std::nullopt, reverse), expected_undirected); + } + + // Test no bounds. + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::IN, std::nullopt, std::nullopt, reverse), (map_int{})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::OUT, std::nullopt, std::nullopt, reverse), (map_int{})); + EXPECT_EQ(test_expand(0, EdgeAtom::Direction::BOTH, std::nullopt, std::nullopt, reverse), + (map_int{{4, 12}, {8, 24}})); + } + + // Expand from chain 0 v:1 to itself. + // + // It has a total of 6 cycles: + // 1. C0 v:1 -> C1 v:0 -> C1 v:1 -> C1 v:2 -> C0 v:1 + // 2. C0 v:1 -> C1 v:0 -> C1 v:1 -> C0 v:2 -> C0 v:1 + // 3. C0 v:1 -> C0 v:0 -> C1 v:1 -> C1 v:2 -> C0 v:1 + // 4. C0 v:1 -> C0 v:0 -> C1 v:1 -> C0 v:2 -> C0 v:1 + // 5. C0 v:1 -> C1 v:0 -> C1 v:1 -> C0 v:0 -> C0 v:1 + // 6. C0 v:1 -> C1 v:2 -> C1 v:1 -> C0 v:2 -> C0 v:1 + // + // Each cycle can be in two directions, also, we have two starting nodes: one + // in chain 0 and the other in chain 1. + for (auto reverse : {false, true}) { + // Tests with both bounds set. + for (int lower_bound = 0; lower_bound < 10; ++lower_bound) { + for (int upper_bound = lower_bound; upper_bound < 10; ++upper_bound) { + map_int expected_directed; + map_int expected_undirected; + if (lower_bound == 0) { + expected_directed.emplace(0, 2); + expected_undirected.emplace(0, 2); + } + if (lower_bound <= 4 && upper_bound >= 4) { + expected_undirected.emplace(4, 24); + } + if (lower_bound <= 8 && upper_bound >= 8) { + expected_undirected.emplace(8, 48); + } + + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::IN, lower_bound, upper_bound, reverse), expected_directed); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::OUT, lower_bound, upper_bound, reverse), expected_directed); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::BOTH, lower_bound, upper_bound, reverse), expected_undirected); + } + } + + // Test only upper bound. + for (int upper_bound = 0; upper_bound < 10; ++upper_bound) { + map_int expected_directed; + map_int expected_undirected; + if (upper_bound >= 4) { + expected_undirected.emplace(4, 24); + } + if (upper_bound >= 8) { + expected_undirected.emplace(8, 48); + } + + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::IN, std::nullopt, upper_bound, reverse), expected_directed); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::OUT, std::nullopt, upper_bound, reverse), expected_directed); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::BOTH, std::nullopt, upper_bound, reverse), expected_undirected); + } + + // Test only lower bound. + for (int lower_bound = 0; lower_bound < 10; ++lower_bound) { + map_int expected_directed; + map_int expected_undirected; + if (lower_bound == 0) { + expected_directed.emplace(0, 2); + expected_undirected.emplace(0, 2); + } + if (lower_bound <= 4) { + expected_undirected.emplace(4, 24); + } + if (lower_bound <= 8) { + expected_undirected.emplace(8, 48); + } + + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::IN, lower_bound, std::nullopt, reverse), expected_directed); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::OUT, lower_bound, std::nullopt, reverse), expected_directed); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::BOTH, lower_bound, std::nullopt, reverse), expected_undirected); + } + + // Test no bounds. + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::IN, std::nullopt, std::nullopt, reverse), (map_int{})); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::OUT, std::nullopt, std::nullopt, reverse), (map_int{})); + EXPECT_EQ(test_expand(1, EdgeAtom::Direction::BOTH, std::nullopt, std::nullopt, reverse), + (map_int{{4, 24}, {8, 48}})); + } +} + +/** A test fixture for weighted shortest path expansion */ +class QueryPlanExpandWeightedShortestPath : public QueryPlanMatchFilterTest { + public: + struct ResultType { + std::vector path; + VertexAccessor vertex; + double total_weight; + }; + + protected: + storage::Storage::Accessor storage_dba{db.Access()}; + DbAccessor dba{&storage_dba}; + std::pair prop = PROPERTY_PAIR("property1"); + storage::EdgeTypeId edge_type = dba.NameToEdgeType("edge_type"); + + // make 5 vertices because we'll need to compare against them exactly + // v[0] has `prop` with the value 0 + std::vector v; + + // make some edges too, in a map (from, to) vertex indices + std::unordered_map, EdgeAccessor> e; + + AstStorage storage; + SymbolTable symbol_table; + + // inner edge and vertex symbols + Symbol filter_edge = symbol_table.CreateSymbol("f_edge", true); + Symbol filter_node = symbol_table.CreateSymbol("f_node", true); + + Symbol weight_edge = symbol_table.CreateSymbol("w_edge", true); + Symbol weight_node = symbol_table.CreateSymbol("w_node", true); + + Symbol total_weight = symbol_table.CreateSymbol("total_weight", true); + + void SetUp() { + for (int i = 0; i < 5; i++) { + v.push_back(*dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(i)}})); + ASSERT_TRUE(v.back().SetProperty(prop.second, storage::PropertyValue(i)).HasValue()); + } + + auto add_edge = [&](int from, int to, double weight) { + auto edge = dba.InsertEdge(&v[from], &v[to], edge_type); + ASSERT_TRUE(edge->SetProperty(prop.second, storage::PropertyValue(weight)).HasValue()); + e.emplace(std::make_pair(from, to), *edge); + }; + + add_edge(0, 1, 5); + add_edge(1, 4, 5); + add_edge(0, 2, 3); + add_edge(2, 3, 3); + add_edge(3, 4, 3); + add_edge(4, 0, 12); + + dba.AdvanceCommand(); + } + + // defines and performs a weighted shortest expansion with the given + // params returns a vector of pairs. each pair is (vector-of-edges, + // vertex) + auto ExpandWShortest(EdgeAtom::Direction direction, std::optional max_depth, Expression *where, + std::optional node_id = 0, ScanAllTuple *existing_node_input = nullptr) { + // scan the nodes optionally filtering on property value + auto n = MakeScanAll(storage, symbol_table, "n", existing_node_input ? existing_node_input->op_ : nullptr); + auto last_op = n.op_; + if (node_id) { + last_op = std::make_shared(last_op, EQ(PROPERTY_LOOKUP(n.node_->identifier_, prop), LITERAL(*node_id))); + } + + auto ident_e = IDENT("e"); + ident_e->MapTo(weight_edge); + + // expand wshortest + auto node_sym = existing_node_input ? existing_node_input->sym_ : symbol_table.CreateSymbol("node", true); + auto edge_list_sym = symbol_table.CreateSymbol("edgelist_", true); + auto filter_lambda = last_op = std::make_shared( + last_op, n.sym_, node_sym, edge_list_sym, EdgeAtom::Type::WEIGHTED_SHORTEST_PATH, direction, + std::vector{}, false, nullptr, max_depth ? LITERAL(max_depth.value()) : nullptr, + existing_node_input != nullptr, ExpansionLambda{filter_edge, filter_node, where}, + ExpansionLambda{weight_edge, weight_node, PROPERTY_LOOKUP(ident_e, prop)}, total_weight); + + Frame frame(symbol_table.max_position()); + auto cursor = last_op->MakeCursor(utils::NewDeleteResource()); + std::vector results; + auto context = MakeContext(storage, symbol_table, &dba); + while (cursor->Pull(frame, context)) { + results.push_back( + ResultType{std::vector(), frame[node_sym].ValueVertex(), frame[total_weight].ValueDouble()}); + for (const TypedValue &edge : frame[edge_list_sym].ValueList()) + results.back().path.emplace_back(edge.ValueEdge()); + } + + return results; + } + + template + auto GetProp(const TAccessor &accessor) { + return accessor.GetProperty(storage::View::OLD, prop.second)->ValueInt(); + } + + template + auto GetDoubleProp(const TAccessor &accessor) { + return accessor.GetProperty(storage::View::OLD, prop.second)->ValueDouble(); + } + + Expression *PropNe(Symbol symbol, int value) { + auto ident = IDENT("inner_element"); + ident->MapTo(symbol); + return NEQ(PROPERTY_LOOKUP(ident, prop), LITERAL(value)); + } +}; + +// // Testing weighted shortest path on this graph: +// // +// // 5 5 +// // /-->--[1]-->--\ +// // / \ +// // / 12 \ 2 +// // [0]--------<--------[4]------->-------[5] +// // \ / (on some tests only) +// // \ / +// // \->[2]->-[3]->/ +// // 3 3 3 + +TEST_F(QueryPlanExpandWeightedShortestPath, Basic) { + auto results = ExpandWShortest(EdgeAtom::Direction::BOTH, 1000, LITERAL(true)); + + ASSERT_EQ(results.size(), 4); + + // check end nodes + EXPECT_EQ(GetProp(results[0].vertex), 2); + EXPECT_EQ(GetProp(results[1].vertex), 1); + EXPECT_EQ(GetProp(results[2].vertex), 3); + EXPECT_EQ(GetProp(results[3].vertex), 4); + + // check paths and total weights + EXPECT_EQ(results[0].path.size(), 1); + EXPECT_EQ(GetDoubleProp(results[0].path[0]), 3); + EXPECT_EQ(results[0].total_weight, 3); + + EXPECT_EQ(results[1].path.size(), 1); + EXPECT_EQ(GetDoubleProp(results[1].path[0]), 5); + EXPECT_EQ(results[1].total_weight, 5); + + EXPECT_EQ(results[2].path.size(), 2); + EXPECT_EQ(GetDoubleProp(results[2].path[0]), 3); + EXPECT_EQ(GetDoubleProp(results[2].path[1]), 3); + EXPECT_EQ(results[2].total_weight, 6); + + EXPECT_EQ(results[3].path.size(), 3); + EXPECT_EQ(GetDoubleProp(results[3].path[0]), 3); + EXPECT_EQ(GetDoubleProp(results[3].path[1]), 3); + EXPECT_EQ(GetDoubleProp(results[3].path[2]), 3); + EXPECT_EQ(results[3].total_weight, 9); +} + +TEST_F(QueryPlanExpandWeightedShortestPath, EdgeDirection) { + { + auto results = ExpandWShortest(EdgeAtom::Direction::OUT, 1000, LITERAL(true)); + ASSERT_EQ(results.size(), 4); + EXPECT_EQ(GetProp(results[0].vertex), 2); + EXPECT_EQ(results[0].total_weight, 3); + EXPECT_EQ(GetProp(results[1].vertex), 1); + EXPECT_EQ(results[1].total_weight, 5); + EXPECT_EQ(GetProp(results[2].vertex), 3); + EXPECT_EQ(results[2].total_weight, 6); + EXPECT_EQ(GetProp(results[3].vertex), 4); + EXPECT_EQ(results[3].total_weight, 9); + } + { + auto results = ExpandWShortest(EdgeAtom::Direction::IN, 1000, LITERAL(true)); + ASSERT_EQ(results.size(), 4); + EXPECT_EQ(GetProp(results[0].vertex), 4); + EXPECT_EQ(results[0].total_weight, 12); + EXPECT_EQ(GetProp(results[1].vertex), 3); + EXPECT_EQ(results[1].total_weight, 15); + EXPECT_EQ(GetProp(results[2].vertex), 1); + EXPECT_EQ(results[2].total_weight, 17); + EXPECT_EQ(GetProp(results[3].vertex), 2); + EXPECT_EQ(results[3].total_weight, 18); + } +} + +TEST_F(QueryPlanExpandWeightedShortestPath, Where) { + { + auto results = ExpandWShortest(EdgeAtom::Direction::BOTH, 1000, PropNe(filter_node, 2)); + ASSERT_EQ(results.size(), 3); + EXPECT_EQ(GetProp(results[0].vertex), 1); + EXPECT_EQ(results[0].total_weight, 5); + EXPECT_EQ(GetProp(results[1].vertex), 4); + EXPECT_EQ(results[1].total_weight, 10); + EXPECT_EQ(GetProp(results[2].vertex), 3); + EXPECT_EQ(results[2].total_weight, 13); + } + { + auto results = ExpandWShortest(EdgeAtom::Direction::BOTH, 1000, PropNe(filter_node, 1)); + ASSERT_EQ(results.size(), 3); + EXPECT_EQ(GetProp(results[0].vertex), 2); + EXPECT_EQ(results[0].total_weight, 3); + EXPECT_EQ(GetProp(results[1].vertex), 3); + EXPECT_EQ(results[1].total_weight, 6); + EXPECT_EQ(GetProp(results[2].vertex), 4); + EXPECT_EQ(results[2].total_weight, 9); + } +} + +TEST_F(QueryPlanExpandWeightedShortestPath, ExistingNode) { + auto ExpandPreceeding = [this](std::optional preceeding_node_id) { + // scan the nodes optionally filtering on property value + auto n0 = MakeScanAll(storage, symbol_table, "n0"); + if (preceeding_node_id) { + auto filter = std::make_shared( + n0.op_, EQ(PROPERTY_LOOKUP(n0.node_->identifier_, prop), LITERAL(*preceeding_node_id))); + // inject the filter op into the ScanAllTuple. that way the filter + // op can be passed into the ExpandWShortest function without too + // much refactor + n0.op_ = filter; + } + + return ExpandWShortest(EdgeAtom::Direction::OUT, 1000, LITERAL(true), std::nullopt, &n0); + }; + + EXPECT_EQ(ExpandPreceeding(std::nullopt).size(), 20); + { + auto results = ExpandPreceeding(3); + ASSERT_EQ(results.size(), 4); + for (int i = 0; i < 4; i++) EXPECT_EQ(GetProp(results[i].vertex), 3); + } +} + +TEST_F(QueryPlanExpandWeightedShortestPath, UpperBound) { + { + auto results = ExpandWShortest(EdgeAtom::Direction::BOTH, std::nullopt, LITERAL(true)); + ASSERT_EQ(results.size(), 4); + EXPECT_EQ(GetProp(results[0].vertex), 2); + EXPECT_EQ(results[0].total_weight, 3); + EXPECT_EQ(GetProp(results[1].vertex), 1); + EXPECT_EQ(results[1].total_weight, 5); + EXPECT_EQ(GetProp(results[2].vertex), 3); + EXPECT_EQ(results[2].total_weight, 6); + EXPECT_EQ(GetProp(results[3].vertex), 4); + EXPECT_EQ(results[3].total_weight, 9); + } + { + auto results = ExpandWShortest(EdgeAtom::Direction::BOTH, 2, LITERAL(true)); + ASSERT_EQ(results.size(), 4); + EXPECT_EQ(GetProp(results[0].vertex), 2); + EXPECT_EQ(results[0].total_weight, 3); + EXPECT_EQ(GetProp(results[1].vertex), 1); + EXPECT_EQ(results[1].total_weight, 5); + EXPECT_EQ(GetProp(results[2].vertex), 3); + EXPECT_EQ(results[2].total_weight, 6); + EXPECT_EQ(GetProp(results[3].vertex), 4); + EXPECT_EQ(results[3].total_weight, 10); + } + { + auto results = ExpandWShortest(EdgeAtom::Direction::BOTH, 1, LITERAL(true)); + ASSERT_EQ(results.size(), 3); + EXPECT_EQ(GetProp(results[0].vertex), 2); + EXPECT_EQ(results[0].total_weight, 3); + EXPECT_EQ(GetProp(results[1].vertex), 1); + EXPECT_EQ(results[1].total_weight, 5); + EXPECT_EQ(GetProp(results[2].vertex), 4); + EXPECT_EQ(results[2].total_weight, 12); + } + { + auto new_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(new_vertex.SetProperty(prop.second, storage::PropertyValue(5)).HasValue()); + auto edge = dba.InsertEdge(&v[4], &new_vertex, edge_type); + ASSERT_TRUE(edge.HasValue()); + ASSERT_TRUE(edge->SetProperty(prop.second, storage::PropertyValue(2)).HasValue()); + dba.AdvanceCommand(); + + auto results = ExpandWShortest(EdgeAtom::Direction::BOTH, 3, LITERAL(true)); + + ASSERT_EQ(results.size(), 5); + EXPECT_EQ(GetProp(results[0].vertex), 2); + EXPECT_EQ(results[0].total_weight, 3); + EXPECT_EQ(GetProp(results[1].vertex), 1); + EXPECT_EQ(results[1].total_weight, 5); + EXPECT_EQ(GetProp(results[2].vertex), 3); + EXPECT_EQ(results[2].total_weight, 6); + EXPECT_EQ(GetProp(results[3].vertex), 4); + EXPECT_EQ(results[3].total_weight, 9); + EXPECT_EQ(GetProp(results[4].vertex), 5); + EXPECT_EQ(results[4].total_weight, 12); + } +} + +TEST_F(QueryPlanExpandWeightedShortestPath, NonNumericWeight) { + auto new_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(new_vertex.SetProperty(prop.second, storage::PropertyValue(5)).HasValue()); + auto edge = dba.InsertEdge(&v[4], &new_vertex, edge_type); + ASSERT_TRUE(edge.HasValue()); + ASSERT_TRUE(edge->SetProperty(prop.second, storage::PropertyValue("not a number")).HasValue()); + dba.AdvanceCommand(); + EXPECT_THROW(ExpandWShortest(EdgeAtom::Direction::BOTH, 1000, LITERAL(true)), QueryRuntimeException); +} + +TEST_F(QueryPlanExpandWeightedShortestPath, NegativeWeight) { + auto new_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(new_vertex.SetProperty(prop.second, storage::PropertyValue(5)).HasValue()); + auto edge = dba.InsertEdge(&v[4], &new_vertex, edge_type); + ASSERT_TRUE(edge.HasValue()); + ASSERT_TRUE(edge->SetProperty(prop.second, storage::PropertyValue(-10)).HasValue()); // negative weight + dba.AdvanceCommand(); + EXPECT_THROW(ExpandWShortest(EdgeAtom::Direction::BOTH, 1000, LITERAL(true)), QueryRuntimeException); +} + +TEST_F(QueryPlanExpandWeightedShortestPath, NegativeUpperBound) { + EXPECT_THROW(ExpandWShortest(EdgeAtom::Direction::BOTH, -1, LITERAL(true)), QueryRuntimeException); +} + +TEST_F(QueryPlanMatchFilterTest, ExpandOptional) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + AstStorage storage; + SymbolTable symbol_table; + + // graph (v2 {p: 2})<-[:T]-(v1 {p: 1})-[:T]->(v3 {p: 2}) + auto prop = dba.NameToProperty("p"); + auto edge_type = dba.NameToEdgeType("T"); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(v1.SetProperty(prop, storage::PropertyValue(1)).HasValue()); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + ASSERT_TRUE(v2.SetProperty(prop, storage::PropertyValue(2)).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}); + ASSERT_TRUE(v3.SetProperty(prop, storage::PropertyValue(2)).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v1, &v3, edge_type).HasValue()); + dba.AdvanceCommand(); + + // MATCH (n) OPTIONAL MATCH (n)-[r]->(m) + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, nullptr, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "m", false, + storage::View::OLD); + auto optional = std::make_shared(n.op_, r_m.op_, std::vector{r_m.edge_sym_, r_m.node_sym_}); + + // RETURN n, r, m + auto n_ne = NEXPR("n", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("n", true)); + auto r_ne = NEXPR("r", IDENT("r")->MapTo(r_m.edge_sym_))->MapTo(symbol_table.CreateSymbol("r", true)); + auto m_ne = NEXPR("m", IDENT("m")->MapTo(r_m.node_sym_))->MapTo(symbol_table.CreateSymbol("m", true)); + auto produce = MakeProduce(optional, n_ne, r_ne, m_ne); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + ASSERT_EQ(4, results.size()); + int v1_is_n_count = 0; + for (auto &row : results) { + ASSERT_EQ(row[0].type(), TypedValue::Type::Vertex); + auto &va = row[0].ValueVertex(); + auto va_p = *va.GetProperty(storage::View::OLD, prop); + ASSERT_EQ(va_p.type(), storage::PropertyValue::Type::Int); + if (va_p.ValueInt() == 1) { + v1_is_n_count++; + EXPECT_EQ(row[1].type(), TypedValue::Type::Edge); + EXPECT_EQ(row[2].type(), TypedValue::Type::Vertex); + } else { + EXPECT_EQ(row[1].type(), TypedValue::Type::Null); + EXPECT_EQ(row[2].type(), TypedValue::Type::Null); + } + } + EXPECT_EQ(2, v1_is_n_count); +} + +TEST(QueryPlan, OptionalMatchEmptyDB) { + storage::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + AstStorage storage; + SymbolTable symbol_table; + + // OPTIONAL MATCH (n) + auto n = MakeScanAllNew(storage, symbol_table, "n"); + // RETURN n + auto n_ne = NEXPR("n", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("n", true)); + auto optional = std::make_shared(nullptr, n.op_, std::vector{n.sym_}); + auto produce = MakeProduce(optional, n_ne); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + ASSERT_EQ(1, results.size()); + EXPECT_EQ(results[0][0].type(), TypedValue::Type::Null); +} + +TEST(QueryPlan, OptionalMatchEmptyDBExpandFromNode) { + storage::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + // OPTIONAL MATCH (n) + auto n = MakeScanAll(storage, symbol_table, "n"); + auto optional = std::make_shared(nullptr, n.op_, std::vector{n.sym_}); + // WITH n + auto n_ne = NEXPR("n", IDENT("n")->MapTo(n.sym_)); + auto with_n_sym = symbol_table.CreateSymbol("n", true); + n_ne->MapTo(with_n_sym); + auto with = MakeProduce(optional, n_ne); + // MATCH (n) -[r]-> (m) + auto r_m = MakeExpand(storage, symbol_table, with, with_n_sym, "r", EdgeAtom::Direction::OUT, {}, "m", false, + storage::View::OLD); + // RETURN m + auto m_ne = NEXPR("m", IDENT("m")->MapTo(r_m.node_sym_))->MapTo(symbol_table.CreateSymbol("m", true)); + auto produce = MakeProduce(r_m.op_, m_ne); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(0, results.size()); +} + +TEST_F(QueryPlanMatchFilterTest, OptionalMatchThenExpandToMissingNode) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + // Make a graph with 2 connected, unlabeled nodes. + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + auto edge_type = dba.NameToEdgeType("edge_type"); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(1, CountEdges(&dba, storage::View::OLD)); + AstStorage storage; + SymbolTable symbol_table; + // OPTIONAL MATCH (n :missing) + auto n = MakeScanAll(storage, symbol_table, "n"); + auto label_missing = "missing"; + n.node_->labels_.emplace_back(storage.GetLabelIx(label_missing)); + + auto *filter_expr = storage.Create(n.node_->identifier_, n.node_->labels_); + auto node_filter = std::make_shared(n.op_, filter_expr); + auto optional = std::make_shared(nullptr, node_filter, std::vector{n.sym_}); + // WITH n + auto n_ne = NEXPR("n", IDENT("n")->MapTo(n.sym_)); + auto with_n_sym = symbol_table.CreateSymbol("n", true); + n_ne->MapTo(with_n_sym); + auto with = MakeProduce(optional, n_ne); + // MATCH (m) -[r]-> (n) + auto m = MakeScanAll(storage, symbol_table, "m", with); + auto edge_direction = EdgeAtom::Direction::OUT; + auto edge = EDGE("r", edge_direction); + auto edge_sym = symbol_table.CreateSymbol("r", true); + edge->identifier_->MapTo(edge_sym); + auto node = NODE("n"); + node->identifier_->MapTo(with_n_sym); + auto expand = std::make_shared(m.op_, m.sym_, with_n_sym, edge_sym, edge_direction, + std::vector{}, true, storage::View::OLD); + // RETURN m + auto m_ne = NEXPR("m", IDENT("m")->MapTo(m.sym_))->MapTo(symbol_table.CreateSymbol("m", true)); + auto produce = MakeProduce(expand, m_ne); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(0, results.size()); +} + +TEST_F(QueryPlanMatchFilterTest, ExpandExistingNode) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // make a graph (v1)->(v2) that + // has a recursive edge (v1)->(v1) + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + auto edge_type = dba.NameToEdgeType("Edge"); + ASSERT_TRUE(dba.InsertEdge(&v1, &v1, edge_type).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto test_existing = [&](bool with_existing, int expected_result_count) { + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_n = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "n", with_existing, + storage::View::OLD); + if (with_existing) + r_n.op_ = std::make_shared(n.op_, n.sym_, n.sym_, r_n.edge_sym_, r_n.edge_->direction_, + std::vector{}, with_existing, storage::View::OLD); + + // make a named expression and a produce + auto output = NEXPR("n", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(r_n.op_, output); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), expected_result_count); + }; + + test_existing(true, 1); + test_existing(false, 2); +} + +TEST_F(QueryPlanMatchFilterTest, ExpandBothCycleEdgeCase) { + // we're testing that expanding on BOTH + // does only one expansion for a cycle + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(dba.InsertEdge(&v, &v, dba.NameToEdgeType("et")).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_ = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::BOTH, {}, "_", false, + storage::View::OLD); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(1, PullAll(*r_.op_, &context)); +} + +TEST_F(QueryPlanMatchFilterTest, EdgeFilter) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // make an N-star expanding from (v1) + // where only one edge will qualify + // and there are all combinations of + // (edge_type yes|no) * (property yes|absent|no) + std::vector edge_types; + for (int j = 0; j < 2; ++j) edge_types.push_back(dba.NameToEdgeType("et" + std::to_string(j))); + std::vector vertices; + for (int i = 0; i < 7; ++i) { + vertices.push_back(*dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(i)}})); + } + auto prop = PROPERTY_PAIR("property1"); + std::vector edges; + for (int i = 0; i < 6; ++i) { + edges.push_back(*dba.InsertEdge(&vertices[0], &vertices[i + 1], edge_types[i % 2])); + switch (i % 3) { + case 0: + ASSERT_TRUE(edges.back().SetProperty(prop.second, storage::PropertyValue(42)).HasValue()); + break; + case 1: + ASSERT_TRUE(edges.back().SetProperty(prop.second, storage::PropertyValue(100)).HasValue()); + break; + default: + break; + } + } + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto test_filter = [&]() { + // define an operator tree for query + // MATCH (n)-[r :et0 {property: 42}]->(m) RETURN m + + auto n = MakeScanAll(storage, symbol_table, "n"); + const auto &edge_type = edge_types[0]; + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {edge_type}, "m", false, + storage::View::OLD); + r_m.edge_->edge_types_.push_back(storage.GetEdgeTypeIx(dba.EdgeTypeToName(edge_type))); + std::get<0>(r_m.edge_->properties_)[storage.GetPropertyIx(prop.first)] = LITERAL(42); + auto *filter_expr = EQ(PROPERTY_LOOKUP(r_m.edge_->identifier_, prop), LITERAL(42)); + auto edge_filter = std::make_shared(r_m.op_, filter_expr); + + // make a named expression and a produce + auto output = + NEXPR("m", IDENT("m")->MapTo(r_m.node_sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(edge_filter, output); + auto context = MakeContext(storage, symbol_table, &dba); + return PullAll(*produce, &context); + }; + + EXPECT_EQ(1, test_filter()); + // test that edge filtering always filters on old state + for (auto &edge : edges) ASSERT_TRUE(edge.SetProperty(prop.second, storage::PropertyValue(42)).HasValue()); + EXPECT_EQ(1, test_filter()); + dba.AdvanceCommand(); + EXPECT_EQ(3, test_filter()); +} + +TEST_F(QueryPlanMatchFilterTest, EdgeFilterMultipleTypes) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + auto type_1 = dba.NameToEdgeType("type_1"); + auto type_2 = dba.NameToEdgeType("type_2"); + auto type_3 = dba.NameToEdgeType("type_3"); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, type_1).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, type_2).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, type_3).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + // make a scan all + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {type_1, type_2}, "m", + false, storage::View::OLD); + + // make a named expression and a produce + auto output = + NEXPR("m", IDENT("m")->MapTo(r_m.node_sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(r_m.op_, output); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), 2); +} + +TEST_F(QueryPlanMatchFilterTest, Filter) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // add a 6 nodes with property 'prop', 2 have true as value + auto property1 = PROPERTY_PAIR("property1"); + for (int i = 0; i < 6; ++i) { + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(i)}}) + ->SetProperty(property1.second, storage::PropertyValue(i % 3 == 0)) + .HasValue()); + } + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}) + .HasValue()); // prop not set, gives NULL + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto e = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), property1); + auto f = std::make_shared(n.op_, e); + + auto output = NEXPR("x", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(f, output); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(CollectProduce(*produce, &context).size(), 2); +} + +TEST_F(QueryPlanMatchFilterTest, EdgeUniquenessFilter) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + + // make a graph that has (v1)->(v2) and a recursive edge (v1)->(v1) + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + auto edge_type = dba.NameToEdgeType("edge_type"); + ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); + ASSERT_TRUE(dba.InsertEdge(&v1, &v1, edge_type).HasValue()); + dba.AdvanceCommand(); + + auto check_expand_results = [&](bool edge_uniqueness) { + AstStorage storage; + SymbolTable symbol_table; + + auto n1 = MakeScanAll(storage, symbol_table, "n1"); + auto r1_n2 = MakeExpand(storage, symbol_table, n1.op_, n1.sym_, "r1", EdgeAtom::Direction::OUT, {}, "n2", false, + storage::View::OLD); + std::shared_ptr last_op = r1_n2.op_; + auto r2_n3 = MakeExpand(storage, symbol_table, last_op, r1_n2.node_sym_, "r2", EdgeAtom::Direction::OUT, {}, "n3", + false, storage::View::OLD); + last_op = r2_n3.op_; + if (edge_uniqueness) + last_op = std::make_shared(last_op, r2_n3.edge_sym_, std::vector{r1_n2.edge_sym_}); + auto context = MakeContext(storage, symbol_table, &dba); + return PullAll(*last_op, &context); + }; + + EXPECT_EQ(2, check_expand_results(false)); + EXPECT_EQ(1, check_expand_results(true)); +} + +TEST(QueryPlan, Distinct) { + // test queries like + // UNWIND [1, 2, 3, 3] AS x RETURN DISTINCT x + + storage::Storage db; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + + auto check_distinct = [&](const std::vector input, const std::vector output, + bool assume_int_value) { + auto input_expr = LITERAL(TypedValue(input)); + + auto x = symbol_table.CreateSymbol("x", true); + auto unwind = std::make_shared(nullptr, input_expr, x); + auto x_expr = IDENT("x"); + x_expr->MapTo(x); + + auto distinct = std::make_shared(unwind, std::vector{x}); + + auto x_ne = NEXPR("x", x_expr); + x_ne->MapTo(symbol_table.CreateSymbol("x_ne", true)); + auto produce = MakeProduce(distinct, x_ne); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + ASSERT_EQ(output.size(), results.size()); + auto output_it = output.begin(); + for (const auto &row : results) { + ASSERT_EQ(1, row.size()); + ASSERT_EQ(row[0].type(), output_it->type()); + if (assume_int_value) EXPECT_EQ(output_it->ValueInt(), row[0].ValueInt()); + output_it++; + } + }; + + check_distinct({TypedValue(1), TypedValue(1), TypedValue(2), TypedValue(3), TypedValue(3), TypedValue(3)}, + {TypedValue(1), TypedValue(2), TypedValue(3)}, true); + check_distinct({TypedValue(3), TypedValue(2), TypedValue(3), TypedValue(5), TypedValue(3), TypedValue(5), + TypedValue(2), TypedValue(1), TypedValue(2)}, + {TypedValue(3), TypedValue(2), TypedValue(5), TypedValue(1)}, true); + check_distinct( + {TypedValue(3), TypedValue("two"), TypedValue(), TypedValue(3), TypedValue(true), TypedValue(false), + TypedValue("TWO"), TypedValue()}, + {TypedValue(3), TypedValue("two"), TypedValue(), TypedValue(true), TypedValue(false), TypedValue("TWO")}, false); +} + +TEST_F(QueryPlanMatchFilterTest, ScanAllByLabel) { + auto label1 = db.NameToLabel("label"); + db.CreateIndex(label1); + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + // Add a vertex with a label and one without. + auto labeled_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(labeled_vertex.AddLabel(label1).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + + dba.AdvanceCommand(); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::View::OLD))); + // MATCH (n :label) + AstStorage storage; + SymbolTable symbol_table; + auto scan_all_by_label = MakeScanAllByLabel(storage, symbol_table, "n", label1); + // RETURN n + auto output = NEXPR("n", IDENT("n")->MapTo(scan_all_by_label.sym_))->MapTo(symbol_table.CreateSymbol("n", true)); + auto produce = MakeProduce(scan_all_by_label.op_, output); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + ASSERT_EQ(results.size(), 1); + auto result_row = results[0]; + ASSERT_EQ(result_row.size(), 1); + EXPECT_EQ(result_row[0].ValueVertex(), labeled_vertex); +} + +TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelProperty) { + // Add 5 vertices with same label, but with different property values. + auto label1 = db.NameToLabel("label1"); + auto prop = db.NameToProperty("prop"); + // vertex property values that will be stored into the DB + std::vector values{ + storage::PropertyValue(true), + storage::PropertyValue(false), + storage::PropertyValue("a"), + storage::PropertyValue("b"), + storage::PropertyValue("c"), + storage::PropertyValue(0), + storage::PropertyValue(1), + storage::PropertyValue(2), + storage::PropertyValue(0.5), + storage::PropertyValue(1.5), + storage::PropertyValue(2.5), + storage::PropertyValue(std::vector{storage::PropertyValue(0)}), + storage::PropertyValue(std::vector{storage::PropertyValue(1)}), + storage::PropertyValue(std::vector{storage::PropertyValue(2)})}; + { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + for (const auto &value : values) { + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(vertex.AddLabel(label1).HasValue()); + ASSERT_TRUE(vertex.SetProperty(prop, value).HasValue()); + } + ASSERT_FALSE(dba.Commit().HasError()); + } + db.CreateIndex(label1, prop); + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + ASSERT_EQ(14, CountIterable(dba.Vertices(storage::View::OLD))); + + auto run_scan_all = [&](const TypedValue &lower, Bound::Type lower_type, const TypedValue &upper, + Bound::Type upper_type) { + AstStorage storage; + SymbolTable symbol_table; + auto scan_all = + MakeScanAllByLabelPropertyRange(storage, symbol_table, "n", label1, prop, "prop", + Bound{LITERAL(lower), lower_type}, Bound{LITERAL(upper), upper_type}); + // RETURN n + auto output = NEXPR("n", IDENT("n")->MapTo(scan_all.sym_))->MapTo(symbol_table.CreateSymbol("n", true)); + auto produce = MakeProduce(scan_all.op_, output); + auto context = MakeContext(storage, symbol_table, &dba); + return CollectProduce(*produce, &context); + }; + + auto check = [&](TypedValue lower, Bound::Type lower_type, TypedValue upper, Bound::Type upper_type, + const std::vector &expected) { + auto results = run_scan_all(lower, lower_type, upper, upper_type); + ASSERT_EQ(results.size(), expected.size()); + for (size_t i = 0; i < expected.size(); i++) { + TypedValue equal = TypedValue(*results[i][0].ValueVertex().GetProperty(storage::View::OLD, prop)) == expected[i]; + ASSERT_EQ(equal.type(), TypedValue::Type::Bool); + EXPECT_TRUE(equal.ValueBool()); + } + }; + + // normal ranges that return something + check(TypedValue("a"), Bound::Type::EXCLUSIVE, TypedValue("c"), Bound::Type::EXCLUSIVE, {TypedValue("b")}); + check(TypedValue(0), Bound::Type::EXCLUSIVE, TypedValue(2), Bound::Type::INCLUSIVE, + {TypedValue(0.5), TypedValue(1), TypedValue(1.5), TypedValue(2)}); + check(TypedValue(1.5), Bound::Type::EXCLUSIVE, TypedValue(2.5), Bound::Type::INCLUSIVE, + {TypedValue(2), TypedValue(2.5)}); + + auto are_comparable = [](storage::PropertyValue::Type a, storage::PropertyValue::Type b) { + auto is_numeric = [](const storage::PropertyValue::Type t) { + return t == storage::PropertyValue::Type::Int || t == storage::PropertyValue::Type::Double; + }; + + return a == b || (is_numeric(a) && is_numeric(b)); + }; + + auto is_orderable = [](const storage::PropertyValue &t) { + return t.IsNull() || t.IsInt() || t.IsDouble() || t.IsString(); + }; + + // when a range contains different types, nothing should get returned + for (const auto &value_a : values) { + for (const auto &value_b : values) { + if (are_comparable(static_cast(value_a).type(), + static_cast(value_b).type())) + continue; + if (is_orderable(value_a) && is_orderable(value_b)) { + check(TypedValue(value_a), Bound::Type::INCLUSIVE, TypedValue(value_b), Bound::Type::INCLUSIVE, {}); + } else { + EXPECT_THROW( + run_scan_all(TypedValue(value_a), Bound::Type::INCLUSIVE, TypedValue(value_b), Bound::Type::INCLUSIVE), + QueryRuntimeException); + } + } + } + // These should all raise an exception due to type mismatch when using + // `operator<`. + EXPECT_THROW(run_scan_all(TypedValue(false), Bound::Type::INCLUSIVE, TypedValue(true), Bound::Type::EXCLUSIVE), + QueryRuntimeException); + EXPECT_THROW(run_scan_all(TypedValue(false), Bound::Type::EXCLUSIVE, TypedValue(true), Bound::Type::INCLUSIVE), + QueryRuntimeException); + EXPECT_THROW(run_scan_all(TypedValue(std::vector{TypedValue(0.5)}), Bound::Type::EXCLUSIVE, + TypedValue(std::vector{TypedValue(1.5)}), Bound::Type::INCLUSIVE), + QueryRuntimeException); +} + +TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyEqualityNoError) { + // Add 2 vertices with same label, but with property values that cannot be + // compared. On the other hand, equality works fine. + auto label1 = db.NameToLabel("label1"); + auto prop = db.NameToProperty("prop"); + { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto number_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(number_vertex.AddLabel(label1).HasValue()); + ASSERT_TRUE(number_vertex.SetProperty(prop, storage::PropertyValue(42)).HasValue()); + auto string_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(string_vertex.AddLabel(label1).HasValue()); + ASSERT_TRUE(string_vertex.SetProperty(prop, storage::PropertyValue("string")).HasValue()); + ASSERT_FALSE(dba.Commit().HasError()); + } + db.CreateIndex(label1, prop); + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::View::OLD))); + // MATCH (n :label {prop: 42}) + AstStorage storage; + SymbolTable symbol_table; + auto scan_all = MakeScanAllByLabelPropertyValue(storage, symbol_table, "n", label1, prop, "prop", LITERAL(42)); + // RETURN n + auto output = NEXPR("n", IDENT("n")->MapTo(scan_all.sym_))->MapTo(symbol_table.CreateSymbol("n", true)); + auto produce = MakeProduce(scan_all.op_, output); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + ASSERT_EQ(results.size(), 1); + const auto &row = results[0]; + ASSERT_EQ(row.size(), 1); + auto vertex = row[0].ValueVertex(); + TypedValue value(*vertex.GetProperty(storage::View::OLD, prop)); + TypedValue::BoolEqual eq; + EXPECT_TRUE(eq(value, TypedValue(42))); +} + +TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyValueError) { + auto label1 = db.NameToLabel("label1"); + auto prop = db.NameToProperty("prop"); + { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + for (int i = 0; i < 2; ++i) { + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(vertex.AddLabel(label1).HasValue()); + ASSERT_TRUE(vertex.SetProperty(prop, storage::PropertyValue(i)).HasValue()); + } + ASSERT_FALSE(dba.Commit().HasError()); + } + db.CreateIndex(label1, prop); + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::View::OLD))); + // MATCH (m), (n :label1 {prop: m}) + AstStorage storage; + SymbolTable symbol_table; + auto scan_all = MakeScanAll(storage, symbol_table, "m"); + auto *ident_m = IDENT("m"); + ident_m->MapTo(scan_all.sym_); + auto scan_index = + MakeScanAllByLabelPropertyValue(storage, symbol_table, "n", label1, prop, "prop", ident_m, scan_all.op_); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*scan_index.op_, &context), QueryRuntimeException); +} + +TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyRangeError) { + auto label1 = db.NameToLabel("label1"); + auto prop = db.NameToProperty("prop"); + { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + for (int i = 0; i < 2; ++i) { + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(vertex.AddLabel(label1).HasValue()); + ASSERT_TRUE(vertex.SetProperty(prop, storage::PropertyValue(i)).HasValue()); + } + ASSERT_FALSE(dba.Commit().HasError()); + } + db.CreateIndex(label1, prop); + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::View::OLD))); + // MATCH (m), (n :label1 {prop: m}) + AstStorage storage; + SymbolTable symbol_table; + auto scan_all = MakeScanAll(storage, symbol_table, "m"); + auto *ident_m = IDENT("m"); + ident_m->MapTo(scan_all.sym_); + { + // Lower bound isn't property value + auto scan_index = + MakeScanAllByLabelPropertyRange(storage, symbol_table, "n", label1, prop, "prop", + Bound{ident_m, Bound::Type::INCLUSIVE}, std::nullopt, scan_all.op_); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*scan_index.op_, &context), QueryRuntimeException); + } + { + // Upper bound isn't property value + auto scan_index = MakeScanAllByLabelPropertyRange(storage, symbol_table, "n", label1, prop, "prop", std::nullopt, + Bound{ident_m, Bound::Type::INCLUSIVE}, scan_all.op_); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*scan_index.op_, &context), QueryRuntimeException); + } + { + // Both bounds aren't property value + auto scan_index = MakeScanAllByLabelPropertyRange(storage, symbol_table, "n", label1, prop, "prop", + Bound{ident_m, Bound::Type::INCLUSIVE}, + Bound{ident_m, Bound::Type::INCLUSIVE}, scan_all.op_); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_THROW(PullAll(*scan_index.op_, &context), QueryRuntimeException); + } +} + +TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyEqualNull) { + // Add 2 vertices with the same label, but one has a property value while + // the other does not. Checking if the value is equal to null, should + // yield no results. + auto label1 = db.NameToLabel("label1"); + auto prop = db.NameToProperty("prop"); + { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(vertex.AddLabel(label1).HasValue()); + auto vertex_with_prop = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(vertex_with_prop.AddLabel(label1).HasValue()); + ASSERT_TRUE(vertex_with_prop.SetProperty(prop, storage::PropertyValue(42)).HasValue()); + ASSERT_FALSE(dba.Commit().HasError()); + } + db.CreateIndex(label1, prop); + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::View::OLD))); + // MATCH (n :label1 {prop: 42}) + AstStorage storage; + SymbolTable symbol_table; + auto scan_all = + MakeScanAllByLabelPropertyValue(storage, symbol_table, "n", label1, prop, "prop", LITERAL(TypedValue())); + // RETURN n + auto output = NEXPR("n", IDENT("n")->MapTo(scan_all.sym_))->MapTo(symbol_table.CreateSymbol("n", true)); + auto produce = MakeProduce(scan_all.op_, output); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), 0); +} + +TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyRangeNull) { + // Add 2 vertices with the same label, but one has a property value while + // the other does not. Checking if the value is between nulls, should + // yield no results. + auto label1 = db.NameToLabel("label1"); + auto prop = db.NameToProperty("prop"); + { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(vertex.AddLabel(label).HasValue()); + auto vertex_with_prop = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + ASSERT_TRUE(vertex_with_prop.AddLabel(label).HasValue()); + ASSERT_TRUE(vertex_with_prop.SetProperty(prop, storage::PropertyValue(42)).HasValue()); + ASSERT_FALSE(dba.Commit().HasError()); + } + db.CreateIndex(label1, prop); + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::View::OLD))); + // MATCH (n :label1) WHERE null <= n.prop < null + AstStorage storage; + SymbolTable symbol_table; + auto scan_all = MakeScanAllByLabelPropertyRange(storage, symbol_table, "n", label1, prop, "prop", + Bound{LITERAL(TypedValue()), Bound::Type::INCLUSIVE}, + Bound{LITERAL(TypedValue()), Bound::Type::EXCLUSIVE}); + // RETURN n + auto output = NEXPR("n", IDENT("n")->MapTo(scan_all.sym_))->MapTo(symbol_table.CreateSymbol("n", true)); + auto produce = MakeProduce(scan_all.op_, output); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(results.size(), 0); +} + +TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyNoValueInIndexContinuation) { + auto label1 = db.NameToLabel("label1"); + auto prop = db.NameToProperty("prop"); + { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(v.AddLabel(label1).HasValue()); + ASSERT_TRUE(v.SetProperty(prop, storage::PropertyValue(2)).HasValue()); + ASSERT_FALSE(dba.Commit().HasError()); + } + db.CreateIndex(label1, prop); + + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::View::OLD))); + + AstStorage storage; + SymbolTable symbol_table; + + // UNWIND [1, 2, 3] as x + auto input_expr = LIST(LITERAL(1), LITERAL(2), LITERAL(3)); + auto x = symbol_table.CreateSymbol("x", true); + auto unwind = std::make_shared(nullptr, input_expr, x); + auto x_expr = IDENT("x"); + x_expr->MapTo(x); + + // MATCH (n :label1 {prop: x}) + auto scan_all = MakeScanAllByLabelPropertyValue(storage, symbol_table, "n", label1, prop, "prop", x_expr, unwind); + + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(PullAll(*scan_all.op_, &context), 1); +} + +TEST_F(QueryPlanMatchFilterTest, ScanAllEqualsScanAllByLabelProperty) { + auto label1 = db.NameToLabel("label1"); + auto prop = db.NameToProperty("prop"); + + // Insert vertices + const int vertex_count = 300, vertex_prop_count = 50; + const int prop_value1 = 42, prop_value2 = 69; + + for (int i = 0; i < vertex_count; ++i) { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(v.AddLabel(label1).HasValue()); + ASSERT_TRUE( + v.SetProperty(prop, storage::PropertyValue(i < vertex_prop_count ? prop_value1 : prop_value2)).HasValue()); + ASSERT_FALSE(dba.Commit().HasError()); + } + + db.CreateIndex(label1, prop); + + // Make sure there are `vertex_count` vertices + { + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + EXPECT_EQ(vertex_count, CountIterable(dba.Vertices(storage::View::OLD))); + } + + // Make sure there are `vertex_prop_count` results when using index + auto count_with_index = [this, &label1, &prop](int prop_value, int prop_count) { + AstStorage storage; + SymbolTable symbol_table; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto scan_all_by_label_property_value = + MakeScanAllByLabelPropertyValue(storage, symbol_table, "n", label1, prop, "prop", LITERAL(prop_value)); + auto output = NEXPR("n", IDENT("n")->MapTo(scan_all_by_label_property_value.sym_)) + ->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(scan_all_by_label_property_value.op_, output); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(PullAll(*produce, &context), prop_count); + }; + + // Make sure there are `vertex_count` results when using scan all + auto count_with_scan_all = [this, &prop](int prop_value, int prop_count) { + AstStorage storage; + SymbolTable symbol_table; + auto storage_dba = db.Access(); + DbAccessor dba(&storage_dba); + auto scan_all = MakeScanAll(storage, symbol_table, "n"); + auto e = PROPERTY_LOOKUP(IDENT("n")->MapTo(scan_all.sym_), std::make_pair("prop", prop)); + auto filter = std::make_shared(scan_all.op_, EQ(e, LITERAL(prop_value))); + auto output = + NEXPR("n", IDENT("n")->MapTo(scan_all.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); + auto produce = MakeProduce(filter, output); + auto context = MakeContext(storage, symbol_table, &dba); + EXPECT_EQ(PullAll(*produce, &context), prop_count); + }; + + count_with_index(prop_value1, vertex_prop_count); + count_with_scan_all(prop_value1, vertex_prop_count); + + count_with_index(prop_value2, vertex_count - vertex_prop_count); + count_with_scan_all(prop_value2, vertex_count - vertex_prop_count); +} +} // namespace memgraph::query::tests diff --git a/tests/unit/query_v2_query_plan_v2_create_set_remove_delete.cpp b/tests/unit/query_v2_query_plan_v2_create_set_remove_delete.cpp new file mode 100644 index 000000000..6f298e2ad --- /dev/null +++ b/tests/unit/query_v2_query_plan_v2_create_set_remove_delete.cpp @@ -0,0 +1,146 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include + +#include "query/frontend/semantic/symbol_table.hpp" +#include "query/plan/operator.hpp" +#include "query_plan_common.hpp" +#include "storage/v2/property_value.hpp" +#include "storage/v2/storage.hpp" + +namespace memgraph::query::tests { + +class QueryPlanCRUDTest : public testing::Test { + protected: + void SetUp() override { + ASSERT_TRUE(db.CreateSchema(label, {storage::SchemaProperty{property, common::SchemaType::INT}})); + } + + storage::Storage db; + const storage::LabelId label{db.NameToLabel("label")}; + const storage::PropertyId property{db.NameToProperty("property")}; +}; + +TEST_F(QueryPlanCRUDTest, CreateNodeWithAttributes) { + auto dba = db.Access(); + + AstStorage ast; + SymbolTable symbol_table; + + plan::NodeCreationInfo node; + node.symbol = symbol_table.CreateSymbol("n", true); + node.labels.emplace_back(label); + std::get>>(node.properties) + .emplace_back(property, ast.Create(42)); + + plan::CreateNode create_node(nullptr, node); + DbAccessor execution_dba(&dba); + auto context = MakeContext(ast, symbol_table, &execution_dba); + Frame frame(context.symbol_table.max_position()); + auto cursor = create_node.MakeCursor(utils::NewDeleteResource()); + int count = 0; + while (cursor->Pull(frame, context)) { + ++count; + const auto &node_value = frame[node.symbol]; + EXPECT_EQ(node_value.type(), TypedValue::Type::Vertex); + const auto &v = node_value.ValueVertex(); + EXPECT_TRUE(*v.HasLabel(storage::View::NEW, label)); + EXPECT_EQ(v.GetProperty(storage::View::NEW, property)->ValueInt(), 42); + EXPECT_EQ(CountIterable(*v.InEdges(storage::View::NEW)), 0); + EXPECT_EQ(CountIterable(*v.OutEdges(storage::View::NEW)), 0); + // Invokes LOG(FATAL) instead of erroring out. + // EXPECT_TRUE(v.HasLabel(label, storage::View::OLD).IsError()); + } + EXPECT_EQ(count, 1); +} + +TEST_F(QueryPlanCRUDTest, ScanAllEmpty) { + AstStorage ast; + SymbolTable symbol_table; + auto dba = db.Access(); + DbAccessor execution_dba(&dba); + auto node_symbol = symbol_table.CreateSymbol("n", true); + { + plan::ScanAll scan_all(nullptr, node_symbol, storage::View::OLD); + auto context = MakeContext(ast, symbol_table, &execution_dba); + Frame frame(context.symbol_table.max_position()); + auto cursor = scan_all.MakeCursor(utils::NewDeleteResource()); + int count = 0; + while (cursor->Pull(frame, context)) ++count; + EXPECT_EQ(count, 0); + } + { + plan::ScanAll scan_all(nullptr, node_symbol, storage::View::NEW); + auto context = MakeContext(ast, symbol_table, &execution_dba); + Frame frame(context.symbol_table.max_position()); + auto cursor = scan_all.MakeCursor(utils::NewDeleteResource()); + int count = 0; + while (cursor->Pull(frame, context)) ++count; + EXPECT_EQ(count, 0); + } +} + +TEST_F(QueryPlanCRUDTest, ScanAll) { + { + auto dba = db.Access(); + for (int i = 0; i < 42; ++i) { + auto v = *dba.CreateVertexAndValidate(label, {}, {{property, storage::PropertyValue(i)}}); + ASSERT_TRUE(v.SetProperty(property, storage::PropertyValue(i)).HasValue()); + } + EXPECT_FALSE(dba.Commit().HasError()); + } + AstStorage ast; + SymbolTable symbol_table; + auto dba = db.Access(); + DbAccessor execution_dba(&dba); + auto node_symbol = symbol_table.CreateSymbol("n", true); + plan::ScanAll scan_all(nullptr, node_symbol); + auto context = MakeContext(ast, symbol_table, &execution_dba); + Frame frame(context.symbol_table.max_position()); + auto cursor = scan_all.MakeCursor(utils::NewDeleteResource()); + int count = 0; + while (cursor->Pull(frame, context)) ++count; + EXPECT_EQ(count, 42); +} + +TEST_F(QueryPlanCRUDTest, ScanAllByLabel) { + auto label2 = db.NameToLabel("label2"); + ASSERT_TRUE(db.CreateIndex(label2)); + { + auto dba = db.Access(); + // Add some unlabeled vertices + for (int i = 0; i < 12; ++i) { + auto v = *dba.CreateVertexAndValidate(label, {}, {{property, storage::PropertyValue(i)}}); + ASSERT_TRUE(v.SetProperty(property, storage::PropertyValue(i)).HasValue()); + } + // Add labeled vertices + for (int i = 0; i < 42; ++i) { + auto v = *dba.CreateVertexAndValidate(label, {}, {{property, storage::PropertyValue(i)}}); + ASSERT_TRUE(v.SetProperty(property, storage::PropertyValue(i)).HasValue()); + ASSERT_TRUE(v.AddLabel(label2).HasValue()); + } + EXPECT_FALSE(dba.Commit().HasError()); + } + auto dba = db.Access(); + AstStorage ast; + SymbolTable symbol_table; + auto node_symbol = symbol_table.CreateSymbol("n", true); + DbAccessor execution_dba(&dba); + plan::ScanAllByLabel scan_all(nullptr, node_symbol, label2); + auto context = MakeContext(ast, symbol_table, &execution_dba); + Frame frame(context.symbol_table.max_position()); + auto cursor = scan_all.MakeCursor(utils::NewDeleteResource()); + int count = 0; + while (cursor->Pull(frame, context)) ++count; + EXPECT_EQ(count, 42); +} +} // namespace memgraph::query::tests diff --git a/tests/unit/storage_v3_schema.cpp b/tests/unit/storage_v3_schema.cpp new file mode 100644 index 000000000..bd9a82209 --- /dev/null +++ b/tests/unit/storage_v3_schema.cpp @@ -0,0 +1,294 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include +#include +#include + +#include +#include +#include +#include + +#include "common/types.hpp" +#include "storage/v2/id_types.hpp" +#include "storage/v2/property_value.hpp" +#include "storage/v2/schema_validator.hpp" +#include "storage/v2/schemas.hpp" +#include "storage/v2/storage.hpp" +#include "storage/v2/temporal.hpp" + +using testing::Pair; +using testing::UnorderedElementsAre; +using SchemaType = memgraph::common::SchemaType; + +namespace memgraph::storage::tests { + +class SchemaTest : public testing::Test { + private: + memgraph::storage::NameIdMapper label_mapper_; + memgraph::storage::NameIdMapper property_mapper_; + + protected: + LabelId NameToLabel(const std::string &name) { return LabelId::FromUint(label_mapper_.NameToId(name)); } + + PropertyId NameToProperty(const std::string &name) { return PropertyId::FromUint(property_mapper_.NameToId(name)); } + + PropertyId prop1{NameToProperty("prop1")}; + PropertyId prop2{NameToProperty("prop2")}; + LabelId label1{NameToLabel("label1")}; + LabelId label2{NameToLabel("label2")}; + SchemaProperty schema_prop_string{prop1, SchemaType::STRING}; + SchemaProperty schema_prop_int{prop2, SchemaType::INT}; +}; + +TEST_F(SchemaTest, TestSchemaCreate) { + Schemas schemas; + EXPECT_EQ(schemas.ListSchemas().size(), 0); + + EXPECT_TRUE(schemas.CreateSchema(label1, {schema_prop_string})); + EXPECT_EQ(schemas.ListSchemas().size(), 1); + + { + EXPECT_TRUE(schemas.CreateSchema(label2, {schema_prop_string, schema_prop_int})); + const auto current_schemas = schemas.ListSchemas(); + EXPECT_EQ(current_schemas.size(), 2); + EXPECT_THAT(current_schemas, + UnorderedElementsAre(Pair(label1, std::vector{schema_prop_string}), + Pair(label2, std::vector{schema_prop_string, schema_prop_int}))); + } + { + // Assert after unsuccessful creation, number oif schemas remains the same + EXPECT_FALSE(schemas.CreateSchema(label2, {schema_prop_int})); + const auto current_schemas = schemas.ListSchemas(); + EXPECT_EQ(current_schemas.size(), 2); + EXPECT_THAT(current_schemas, + UnorderedElementsAre(Pair(label1, std::vector{schema_prop_string}), + Pair(label2, std::vector{schema_prop_string, schema_prop_int}))); + } +} + +TEST_F(SchemaTest, TestSchemaList) { + Schemas schemas; + + EXPECT_TRUE(schemas.CreateSchema(label1, {schema_prop_string})); + EXPECT_TRUE(schemas.CreateSchema(label2, {{NameToProperty("prop1"), SchemaType::STRING}, + {NameToProperty("prop2"), SchemaType::INT}, + {NameToProperty("prop3"), SchemaType::BOOL}, + {NameToProperty("prop4"), SchemaType::DATE}, + {NameToProperty("prop5"), SchemaType::LOCALDATETIME}, + {NameToProperty("prop6"), SchemaType::DURATION}, + {NameToProperty("prop7"), SchemaType::LOCALTIME}})); + { + const auto current_schemas = schemas.ListSchemas(); + EXPECT_EQ(current_schemas.size(), 2); + EXPECT_THAT(current_schemas, + UnorderedElementsAre( + Pair(label1, std::vector{schema_prop_string}), + Pair(label2, std::vector{{NameToProperty("prop1"), SchemaType::STRING}, + {NameToProperty("prop2"), SchemaType::INT}, + {NameToProperty("prop3"), SchemaType::BOOL}, + {NameToProperty("prop4"), SchemaType::DATE}, + {NameToProperty("prop5"), SchemaType::LOCALDATETIME}, + {NameToProperty("prop6"), SchemaType::DURATION}, + {NameToProperty("prop7"), SchemaType::LOCALTIME}}))); + } + { + const auto *const schema1 = schemas.GetSchema(label1); + ASSERT_NE(schema1, nullptr); + EXPECT_EQ(*schema1, (Schemas::Schema{label1, std::vector{schema_prop_string}})); + } + { + const auto *const schema2 = schemas.GetSchema(label2); + ASSERT_NE(schema2, nullptr); + EXPECT_EQ(schema2->first, label2); + EXPECT_EQ(schema2->second.size(), 7); + } +} + +TEST_F(SchemaTest, TestSchemaDrop) { + Schemas schemas; + EXPECT_EQ(schemas.ListSchemas().size(), 0); + + EXPECT_TRUE(schemas.CreateSchema(label1, {schema_prop_string})); + EXPECT_EQ(schemas.ListSchemas().size(), 1); + + EXPECT_TRUE(schemas.DropSchema(label1)); + EXPECT_EQ(schemas.ListSchemas().size(), 0); + + EXPECT_TRUE(schemas.CreateSchema(label1, {schema_prop_string})); + EXPECT_TRUE(schemas.CreateSchema(label2, {schema_prop_string, schema_prop_int})); + EXPECT_EQ(schemas.ListSchemas().size(), 2); + + { + EXPECT_TRUE(schemas.DropSchema(label1)); + const auto current_schemas = schemas.ListSchemas(); + EXPECT_EQ(current_schemas.size(), 1); + EXPECT_THAT(current_schemas, + UnorderedElementsAre(Pair(label2, std::vector{schema_prop_string, schema_prop_int}))); + } + + { + // Cannot drop nonexisting schema + EXPECT_FALSE(schemas.DropSchema(label1)); + const auto current_schemas = schemas.ListSchemas(); + EXPECT_EQ(current_schemas.size(), 1); + EXPECT_THAT(current_schemas, + UnorderedElementsAre(Pair(label2, std::vector{schema_prop_string, schema_prop_int}))); + } + + EXPECT_TRUE(schemas.DropSchema(label2)); + EXPECT_EQ(schemas.ListSchemas().size(), 0); +} + +class SchemaValidatorTest : public testing::Test { + protected: + void SetUp() override { + ASSERT_TRUE(schemas.CreateSchema(label1, {schema_prop_string})); + ASSERT_TRUE(schemas.CreateSchema(label2, {schema_prop_string, schema_prop_int, schema_prop_duration})); + } + + LabelId NameToLabel(const std::string &name) { return LabelId::FromUint(label_mapper_.NameToId(name)); } + + PropertyId NameToProperty(const std::string &name) { return PropertyId::FromUint(property_mapper_.NameToId(name)); } + + private: + memgraph::storage::NameIdMapper label_mapper_; + memgraph::storage::NameIdMapper property_mapper_; + + protected: + Schemas schemas; + SchemaValidator schema_validator{schemas}; + PropertyId prop_string{NameToProperty("prop1")}; + PropertyId prop_int{NameToProperty("prop2")}; + PropertyId prop_duration{NameToProperty("prop3")}; + LabelId label1{NameToLabel("label1")}; + LabelId label2{NameToLabel("label2")}; + SchemaProperty schema_prop_string{prop_string, SchemaType::STRING}; + SchemaProperty schema_prop_int{prop_int, SchemaType::INT}; + SchemaProperty schema_prop_duration{prop_duration, SchemaType::DURATION}; +}; + +TEST_F(SchemaValidatorTest, TestSchemaValidateVertexCreate) { + // Validate against secondary label + { + const auto schema_violation = + schema_validator.ValidateVertexCreate(NameToLabel("test"), {}, {{prop_string, PropertyValue(1)}}); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, + SchemaViolation(SchemaViolation::ValidationStatus::NO_SCHEMA_DEFINED_FOR_LABEL, NameToLabel("test"))); + } + // Validate missing property + { + const auto schema_violation = schema_validator.ValidateVertexCreate(label1, {}, {{prop_int, PropertyValue(1)}}); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_HAS_NO_PRIMARY_PROPERTY, + label1, schema_prop_string)); + } + { + const auto schema_violation = schema_validator.ValidateVertexCreate(label2, {}, {}); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_HAS_NO_PRIMARY_PROPERTY, + label2, schema_prop_string)); + } + // Validate wrong secondary label + { + const auto schema_violation = + schema_validator.ValidateVertexCreate(label1, {label1}, {{prop_string, PropertyValue("test")}}); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, + SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_SECONDARY_LABEL_IS_PRIMARY, label1)); + } + { + const auto schema_violation = + schema_validator.ValidateVertexCreate(label1, {label2}, {{prop_string, PropertyValue("test")}}); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, + SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_SECONDARY_LABEL_IS_PRIMARY, label2)); + } + // Validate wrong property type + { + const auto schema_violation = schema_validator.ValidateVertexCreate(label1, {}, {{prop_string, PropertyValue(1)}}); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_PROPERTY_WRONG_TYPE, label1, + schema_prop_string, PropertyValue(1))); + } + { + const auto schema_violation = schema_validator.ValidateVertexCreate( + label2, {}, + {{prop_string, PropertyValue("test")}, {prop_int, PropertyValue(12)}, {prop_duration, PropertyValue(1)}}); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_PROPERTY_WRONG_TYPE, label2, + schema_prop_duration, PropertyValue(1))); + } + { + const auto wrong_prop = PropertyValue(TemporalData(TemporalType::Date, 1234)); + const auto schema_violation = schema_validator.ValidateVertexCreate( + label2, {}, {{prop_string, PropertyValue("test")}, {prop_int, PropertyValue(12)}, {prop_duration, wrong_prop}}); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_PROPERTY_WRONG_TYPE, label2, + schema_prop_duration, wrong_prop)); + } + // Passing validations + EXPECT_EQ(schema_validator.ValidateVertexCreate(label1, {}, {{prop_string, PropertyValue("test")}}), std::nullopt); + EXPECT_EQ(schema_validator.ValidateVertexCreate(label1, {NameToLabel("label3"), NameToLabel("label4")}, + {{prop_string, PropertyValue("test")}}), + std::nullopt); + EXPECT_EQ(schema_validator.ValidateVertexCreate( + label2, {}, + {{prop_string, PropertyValue("test")}, + {prop_int, PropertyValue(122)}, + {prop_duration, PropertyValue(TemporalData(TemporalType::Duration, 1234))}}), + std::nullopt); + EXPECT_EQ(schema_validator.ValidateVertexCreate( + label2, {NameToLabel("label5"), NameToLabel("label6")}, + {{prop_string, PropertyValue("test123")}, + {prop_int, PropertyValue(122221)}, + {prop_duration, PropertyValue(TemporalData(TemporalType::Duration, 12344321))}}), + std::nullopt); +} + +TEST_F(SchemaValidatorTest, TestSchemaValidatePropertyUpdate) { + // Validate updating of primary key + { + const auto schema_violation = schema_validator.ValidatePropertyUpdate(label1, prop_string); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_UPDATE_PRIMARY_KEY, label1, + schema_prop_string)); + } + { + const auto schema_violation = schema_validator.ValidatePropertyUpdate(label2, prop_duration); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_UPDATE_PRIMARY_KEY, label2, + schema_prop_duration)); + } + EXPECT_EQ(schema_validator.ValidatePropertyUpdate(label1, prop_int), std::nullopt); + EXPECT_EQ(schema_validator.ValidatePropertyUpdate(label1, prop_duration), std::nullopt); + EXPECT_EQ(schema_validator.ValidatePropertyUpdate(label2, NameToProperty("test")), std::nullopt); +} + +TEST_F(SchemaValidatorTest, TestSchemaValidatePropertyUpdateLabel) { + // Validate adding primary label + { + const auto schema_violation = schema_validator.ValidateLabelUpdate(label1); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, + SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_MODIFY_PRIMARY_LABEL, label1)); + } + { + const auto schema_violation = schema_validator.ValidateLabelUpdate(label2); + ASSERT_NE(schema_violation, std::nullopt); + EXPECT_EQ(*schema_violation, + SchemaViolation(SchemaViolation::ValidationStatus::VERTEX_MODIFY_PRIMARY_LABEL, label2)); + } + EXPECT_EQ(schema_validator.ValidateLabelUpdate(NameToLabel("test")), std::nullopt); +} +} // namespace memgraph::storage::tests From a12a1ea3582676c3f51d65379f3aadb3e973aa7f Mon Sep 17 00:00:00 2001 From: Jure Bajic Date: Thu, 4 Aug 2022 09:50:02 +0200 Subject: [PATCH 4/5] Move schema to storage v3 and query v2 * Move schema to storage v3 * Remove schema from v2 * Move schema to query v2 * Remove schema from query v1 * Make glue v2 * Move schema related tests to newer versions of query and storage * Fix typo in CMake * Fix interpreter test * Fix clang tidy errors * Change temp dir name --- src/glue/auth.cpp | 2 - src/glue/v2/auth.cpp | 64 + src/glue/v2/auth.hpp | 23 + src/glue/v2/communication.cpp | 275 ++ src/glue/v2/communication.hpp | 68 + src/query/common.hpp | 85 +- src/query/db_accessor.hpp | 41 +- src/query/exceptions.hpp | 8 - src/query/frontend/ast/ast.lcp | 47 +- src/query/frontend/ast/ast_visitor.hpp | 9 +- .../frontend/ast/cypher_main_visitor.cpp | 90 - .../frontend/ast/cypher_main_visitor.hpp | 35 - .../opencypher/grammar/MemgraphCypher.g4 | 28 +- .../opencypher/grammar/MemgraphCypherLexer.g4 | 2 - .../frontend/semantic/required_privileges.cpp | 2 - .../frontend/stripped_lexer_constants.hpp | 3 +- src/query/interpreter.cpp | 126 - src/query/metadata.cpp | 8 - src/query/metadata.hpp | 4 - src/query/plan/operator.cpp | 164 +- src/query/v2/common.hpp | 85 +- src/query/v2/db_accessor.hpp | 45 +- src/query/v2/exceptions.hpp | 8 + src/query/v2/frontend/ast/ast.lcp | 46 +- src/query/v2/frontend/ast/ast_visitor.hpp | 9 +- .../v2/frontend/ast/cypher_main_visitor.cpp | 101 +- .../v2/frontend/ast/cypher_main_visitor.hpp | 35 + .../opencypher/grammar/MemgraphCypher.g4 | 29 +- .../opencypher/grammar/MemgraphCypherLexer.g4 | 2 + .../frontend/semantic/required_privileges.cpp | 2 + .../v2/frontend/stripped_lexer_constants.hpp | 5 +- src/query/v2/interpreter.cpp | 125 + src/query/v2/metadata.cpp | 8 + src/query/v2/metadata.hpp | 4 + src/query/v2/plan/operator.cpp | 168 +- src/storage/v2/CMakeLists.txt | 2 - src/storage/v2/constraints.cpp | 11 +- src/storage/v2/constraints.hpp | 4 +- src/storage/v2/durability/snapshot.cpp | 12 +- src/storage/v2/durability/snapshot.hpp | 6 +- src/storage/v2/edge_accessor.cpp | 5 +- src/storage/v2/edge_accessor.hpp | 6 +- src/storage/v2/indices.cpp | 23 +- src/storage/v2/indices.hpp | 33 +- .../v2/replication/replication_server.cpp | 10 +- src/storage/v2/storage.cpp | 122 +- src/storage/v2/storage.hpp | 48 +- src/storage/v2/vertex.hpp | 27 +- src/storage/v2/vertex_accessor.cpp | 138 +- src/storage/v2/vertex_accessor.hpp | 46 +- src/storage/v3/CMakeLists.txt | 2 + src/storage/v3/constraints.cpp | 11 +- src/storage/v3/constraints.hpp | 4 +- src/storage/v3/durability/snapshot.cpp | 12 +- src/storage/v3/durability/snapshot.hpp | 6 +- src/storage/v3/edge_accessor.cpp | 5 +- src/storage/v3/edge_accessor.hpp | 6 +- src/storage/v3/indices.cpp | 24 +- src/storage/v3/indices.hpp | 34 +- .../v3/replication/replication_server.cpp | 11 +- src/storage/{v2 => v3}/schema_validator.cpp | 8 +- src/storage/{v2 => v3}/schema_validator.hpp | 14 +- src/storage/{v2 => v3}/schemas.cpp | 8 +- src/storage/{v2 => v3}/schemas.hpp | 10 +- src/storage/v3/storage.cpp | 131 +- src/storage/v3/storage.hpp | 72 +- src/storage/v3/vertex.hpp | 25 + src/storage/v3/vertex_accessor.cpp | 138 +- src/storage/v3/vertex_accessor.hpp | 46 +- tests/unit/CMakeLists.txt | 88 +- tests/unit/cypher_main_visitor.cpp | 114 - tests/unit/query_required_privileges.cpp | 5 - tests/unit/query_v2_cypher_main_visitor.cpp | 4326 +++++++++++++++++ ...preter_v2.cpp => query_v2_interpreter.cpp} | 211 +- tests/unit/query_v2_query_common.hpp | 594 +++ ...ery_v2_query_plan_accumulate_aggregate.cpp | 139 +- .../query_v2_query_plan_bag_semantics.cpp | 134 +- tests/unit/query_v2_query_plan_common.hpp | 225 + ...v2_query_plan_create_set_remove_delete.cpp | 426 +- tests/unit/query_v2_query_plan_edge_cases.cpp | 17 +- ...uery_v2_query_plan_match_filter_return.cpp | 436 +- ...query_plan_v2_create_set_remove_delete.cpp | 50 +- .../query_v2_query_required_privileges.cpp | 222 + tests/unit/result_stream_faker.hpp | 132 + tests/unit/storage_v3_schema.cpp | 24 +- 85 files changed, 7893 insertions(+), 2066 deletions(-) create mode 100644 src/glue/v2/auth.cpp create mode 100644 src/glue/v2/auth.hpp create mode 100644 src/glue/v2/communication.cpp create mode 100644 src/glue/v2/communication.hpp rename src/storage/{v2 => v3}/schema_validator.cpp (96%) rename src/storage/{v2 => v3}/schema_validator.hpp (90%) rename src/storage/{v2 => v3}/schemas.cpp (95%) rename src/storage/{v2 => v3}/schemas.hpp (91%) create mode 100644 tests/unit/query_v2_cypher_main_visitor.cpp rename tests/unit/{interpreter_v2.cpp => query_v2_interpreter.cpp} (90%) create mode 100644 tests/unit/query_v2_query_common.hpp create mode 100644 tests/unit/query_v2_query_plan_common.hpp create mode 100644 tests/unit/query_v2_query_required_privileges.cpp create mode 100644 tests/unit/result_stream_faker.hpp diff --git a/src/glue/auth.cpp b/src/glue/auth.cpp index 5d9ffbb84..7f05d8045 100644 --- a/src/glue/auth.cpp +++ b/src/glue/auth.cpp @@ -57,8 +57,6 @@ auth::Permission PrivilegeToPermission(query::AuthQuery::Privilege privilege) { return auth::Permission::MODULE_WRITE; case query::AuthQuery::Privilege::WEBSOCKET: return auth::Permission::WEBSOCKET; - case query::AuthQuery::Privilege::SCHEMA: - return auth::Permission::SCHEMA; } } } // namespace memgraph::glue diff --git a/src/glue/v2/auth.cpp b/src/glue/v2/auth.cpp new file mode 100644 index 000000000..c550f39a0 --- /dev/null +++ b/src/glue/v2/auth.cpp @@ -0,0 +1,64 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "glue/v2/auth.hpp" + +namespace memgraph::glue::v2 { + +auth::Permission PrivilegeToPermission(query::v2::AuthQuery::Privilege privilege) { + switch (privilege) { + case query::v2::AuthQuery::Privilege::MATCH: + return auth::Permission::MATCH; + case query::v2::AuthQuery::Privilege::CREATE: + return auth::Permission::CREATE; + case query::v2::AuthQuery::Privilege::MERGE: + return auth::Permission::MERGE; + case query::v2::AuthQuery::Privilege::DELETE: + return auth::Permission::DELETE; + case query::v2::AuthQuery::Privilege::SET: + return auth::Permission::SET; + case query::v2::AuthQuery::Privilege::REMOVE: + return auth::Permission::REMOVE; + case query::v2::AuthQuery::Privilege::INDEX: + return auth::Permission::INDEX; + case query::v2::AuthQuery::Privilege::STATS: + return auth::Permission::STATS; + case query::v2::AuthQuery::Privilege::CONSTRAINT: + return auth::Permission::CONSTRAINT; + case query::v2::AuthQuery::Privilege::DUMP: + return auth::Permission::DUMP; + case query::v2::AuthQuery::Privilege::REPLICATION: + return auth::Permission::REPLICATION; + case query::v2::AuthQuery::Privilege::DURABILITY: + return auth::Permission::DURABILITY; + case query::v2::AuthQuery::Privilege::READ_FILE: + return auth::Permission::READ_FILE; + case query::v2::AuthQuery::Privilege::FREE_MEMORY: + return auth::Permission::FREE_MEMORY; + case query::v2::AuthQuery::Privilege::TRIGGER: + return auth::Permission::TRIGGER; + case query::v2::AuthQuery::Privilege::CONFIG: + return auth::Permission::CONFIG; + case query::v2::AuthQuery::Privilege::AUTH: + return auth::Permission::AUTH; + case query::v2::AuthQuery::Privilege::STREAM: + return auth::Permission::STREAM; + case query::v2::AuthQuery::Privilege::MODULE_READ: + return auth::Permission::MODULE_READ; + case query::v2::AuthQuery::Privilege::MODULE_WRITE: + return auth::Permission::MODULE_WRITE; + case query::v2::AuthQuery::Privilege::WEBSOCKET: + return auth::Permission::WEBSOCKET; + case query::v2::AuthQuery::Privilege::SCHEMA: + return auth::Permission::SCHEMA; + } +} +} // namespace memgraph::glue::v2 diff --git a/src/glue/v2/auth.hpp b/src/glue/v2/auth.hpp new file mode 100644 index 000000000..3410f6d57 --- /dev/null +++ b/src/glue/v2/auth.hpp @@ -0,0 +1,23 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "auth/models.hpp" +#include "query/v2/frontend/ast/ast.hpp" + +namespace memgraph::glue::v2 { + +/** + * This function converts query::AuthQuery::Privilege to its corresponding + * auth::Permission. + */ +auth::Permission PrivilegeToPermission(query::v2::AuthQuery::Privilege privilege); + +} // namespace memgraph::glue::v2 diff --git a/src/glue/v2/communication.cpp b/src/glue/v2/communication.cpp new file mode 100644 index 000000000..6d99a9a92 --- /dev/null +++ b/src/glue/v2/communication.cpp @@ -0,0 +1,275 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "glue/v2/communication.hpp" + +#include +#include +#include + +#include "storage/v3/edge_accessor.hpp" +#include "storage/v3/storage.hpp" +#include "storage/v3/vertex_accessor.hpp" +#include "utils/temporal.hpp" + +using memgraph::communication::bolt::Value; + +namespace memgraph::glue::v2 { + +query::v2::TypedValue ToTypedValue(const Value &value) { + switch (value.type()) { + case Value::Type::Null: + return {}; + case Value::Type::Bool: + return query::v2::TypedValue(value.ValueBool()); + case Value::Type::Int: + return query::v2::TypedValue(value.ValueInt()); + case Value::Type::Double: + return query::v2::TypedValue(value.ValueDouble()); + case Value::Type::String: + return query::v2::TypedValue(value.ValueString()); + case Value::Type::List: { + std::vector list; + list.reserve(value.ValueList().size()); + for (const auto &v : value.ValueList()) list.push_back(ToTypedValue(v)); + return query::v2::TypedValue(std::move(list)); + } + case Value::Type::Map: { + std::map map; + for (const auto &kv : value.ValueMap()) map.emplace(kv.first, ToTypedValue(kv.second)); + return query::v2::TypedValue(std::move(map)); + } + case Value::Type::Vertex: + case Value::Type::Edge: + case Value::Type::UnboundedEdge: + case Value::Type::Path: + throw communication::bolt::ValueException("Unsupported conversion from Value to TypedValue"); + case Value::Type::Date: + return query::v2::TypedValue(value.ValueDate()); + case Value::Type::LocalTime: + return query::v2::TypedValue(value.ValueLocalTime()); + case Value::Type::LocalDateTime: + return query::v2::TypedValue(value.ValueLocalDateTime()); + case Value::Type::Duration: + return query::v2::TypedValue(value.ValueDuration()); + } +} + +storage::v3::Result ToBoltVertex(const query::v2::VertexAccessor &vertex, + const storage::v3::Storage &db, storage::v3::View view) { + return ToBoltVertex(vertex.impl_, db, view); +} + +storage::v3::Result ToBoltEdge(const query::v2::EdgeAccessor &edge, + const storage::v3::Storage &db, storage::v3::View view) { + return ToBoltEdge(edge.impl_, db, view); +} + +storage::v3::Result ToBoltValue(const query::v2::TypedValue &value, const storage::v3::Storage &db, + storage::v3::View view) { + switch (value.type()) { + case query::v2::TypedValue::Type::Null: + return Value(); + case query::v2::TypedValue::Type::Bool: + return Value(value.ValueBool()); + case query::v2::TypedValue::Type::Int: + return Value(value.ValueInt()); + case query::v2::TypedValue::Type::Double: + return Value(value.ValueDouble()); + case query::v2::TypedValue::Type::String: + return Value(std::string(value.ValueString())); + case query::v2::TypedValue::Type::List: { + std::vector values; + values.reserve(value.ValueList().size()); + for (const auto &v : value.ValueList()) { + auto maybe_value = ToBoltValue(v, db, view); + if (maybe_value.HasError()) return maybe_value.GetError(); + values.emplace_back(std::move(*maybe_value)); + } + return Value(std::move(values)); + } + case query::v2::TypedValue::Type::Map: { + std::map map; + for (const auto &kv : value.ValueMap()) { + auto maybe_value = ToBoltValue(kv.second, db, view); + if (maybe_value.HasError()) return maybe_value.GetError(); + map.emplace(kv.first, std::move(*maybe_value)); + } + return Value(std::move(map)); + } + case query::v2::TypedValue::Type::Vertex: { + auto maybe_vertex = ToBoltVertex(value.ValueVertex(), db, view); + if (maybe_vertex.HasError()) return maybe_vertex.GetError(); + return Value(std::move(*maybe_vertex)); + } + case query::v2::TypedValue::Type::Edge: { + auto maybe_edge = ToBoltEdge(value.ValueEdge(), db, view); + if (maybe_edge.HasError()) return maybe_edge.GetError(); + return Value(std::move(*maybe_edge)); + } + case query::v2::TypedValue::Type::Path: { + auto maybe_path = ToBoltPath(value.ValuePath(), db, view); + if (maybe_path.HasError()) return maybe_path.GetError(); + return Value(std::move(*maybe_path)); + } + case query::v2::TypedValue::Type::Date: + return Value(value.ValueDate()); + case query::v2::TypedValue::Type::LocalTime: + return Value(value.ValueLocalTime()); + case query::v2::TypedValue::Type::LocalDateTime: + return Value(value.ValueLocalDateTime()); + case query::v2::TypedValue::Type::Duration: + return Value(value.ValueDuration()); + } +} + +storage::v3::Result ToBoltVertex(const storage::v3::VertexAccessor &vertex, + const storage::v3::Storage &db, storage::v3::View view) { + auto id = communication::bolt::Id::FromUint(vertex.Gid().AsUint()); + auto maybe_labels = vertex.Labels(view); + if (maybe_labels.HasError()) return maybe_labels.GetError(); + std::vector labels; + labels.reserve(maybe_labels->size()); + for (const auto &label : *maybe_labels) { + labels.push_back(db.LabelToName(label)); + } + auto maybe_properties = vertex.Properties(view); + if (maybe_properties.HasError()) return maybe_properties.GetError(); + std::map properties; + for (const auto &prop : *maybe_properties) { + properties[db.PropertyToName(prop.first)] = ToBoltValue(prop.second); + } + return communication::bolt::Vertex{id, labels, properties}; +} + +storage::v3::Result ToBoltEdge(const storage::v3::EdgeAccessor &edge, + const storage::v3::Storage &db, storage::v3::View view) { + auto id = communication::bolt::Id::FromUint(edge.Gid().AsUint()); + auto from = communication::bolt::Id::FromUint(edge.FromVertex().Gid().AsUint()); + auto to = communication::bolt::Id::FromUint(edge.ToVertex().Gid().AsUint()); + const auto &type = db.EdgeTypeToName(edge.EdgeType()); + auto maybe_properties = edge.Properties(view); + if (maybe_properties.HasError()) return maybe_properties.GetError(); + std::map properties; + for (const auto &prop : *maybe_properties) { + properties[db.PropertyToName(prop.first)] = ToBoltValue(prop.second); + } + return communication::bolt::Edge{id, from, to, type, properties}; +} + +storage::v3::Result ToBoltPath(const query::v2::Path &path, const storage::v3::Storage &db, + storage::v3::View view) { + std::vector vertices; + vertices.reserve(path.vertices().size()); + for (const auto &v : path.vertices()) { + auto maybe_vertex = ToBoltVertex(v, db, view); + if (maybe_vertex.HasError()) return maybe_vertex.GetError(); + vertices.emplace_back(std::move(*maybe_vertex)); + } + std::vector edges; + edges.reserve(path.edges().size()); + for (const auto &e : path.edges()) { + auto maybe_edge = ToBoltEdge(e, db, view); + if (maybe_edge.HasError()) return maybe_edge.GetError(); + edges.emplace_back(std::move(*maybe_edge)); + } + return communication::bolt::Path(vertices, edges); +} + +storage::v3::PropertyValue ToPropertyValue(const Value &value) { + switch (value.type()) { + case Value::Type::Null: + return {}; + case Value::Type::Bool: + return storage::v3::PropertyValue(value.ValueBool()); + case Value::Type::Int: + return storage::v3::PropertyValue(value.ValueInt()); + case Value::Type::Double: + return storage::v3::PropertyValue(value.ValueDouble()); + case Value::Type::String: + return storage::v3::PropertyValue(value.ValueString()); + case Value::Type::List: { + std::vector vec; + vec.reserve(value.ValueList().size()); + for (const auto &value : value.ValueList()) vec.emplace_back(ToPropertyValue(value)); + return storage::v3::PropertyValue(std::move(vec)); + } + case Value::Type::Map: { + std::map map; + for (const auto &kv : value.ValueMap()) map.emplace(kv.first, ToPropertyValue(kv.second)); + return storage::v3::PropertyValue(std::move(map)); + } + case Value::Type::Vertex: + case Value::Type::Edge: + case Value::Type::UnboundedEdge: + case Value::Type::Path: + throw communication::bolt::ValueException("Unsupported conversion from Value to PropertyValue"); + case Value::Type::Date: + return storage::v3::PropertyValue( + storage::v3::TemporalData(storage::v3::TemporalType::Date, value.ValueDate().MicrosecondsSinceEpoch())); + case Value::Type::LocalTime: + return storage::v3::PropertyValue(storage::v3::TemporalData(storage::v3::TemporalType::LocalTime, + value.ValueLocalTime().MicrosecondsSinceEpoch())); + case Value::Type::LocalDateTime: + return storage::v3::PropertyValue(storage::v3::TemporalData(storage::v3::TemporalType::LocalDateTime, + value.ValueLocalDateTime().MicrosecondsSinceEpoch())); + case Value::Type::Duration: + return storage::v3::PropertyValue( + storage::v3::TemporalData(storage::v3::TemporalType::Duration, value.ValueDuration().microseconds)); + } +} + +Value ToBoltValue(const storage::v3::PropertyValue &value) { + switch (value.type()) { + case storage::v3::PropertyValue::Type::Null: + return {}; + case storage::v3::PropertyValue::Type::Bool: + return {value.ValueBool()}; + case storage::v3::PropertyValue::Type::Int: + return {value.ValueInt()}; + break; + case storage::v3::PropertyValue::Type::Double: + return {value.ValueDouble()}; + case storage::v3::PropertyValue::Type::String: + return {value.ValueString()}; + case storage::v3::PropertyValue::Type::List: { + const auto &values = value.ValueList(); + std::vector vec; + vec.reserve(values.size()); + for (const auto &v : values) { + vec.push_back(ToBoltValue(v)); + } + return {std::move(vec)}; + } + case storage::v3::PropertyValue::Type::Map: { + const auto &map = value.ValueMap(); + std::map dv_map; + for (const auto &kv : map) { + dv_map.emplace(kv.first, ToBoltValue(kv.second)); + } + return {std::move(dv_map)}; + } + case storage::v3::PropertyValue::Type::TemporalData: + const auto &type = value.ValueTemporalData(); + switch (type.type) { + case storage::v3::TemporalType::Date: + return {utils::Date(type.microseconds)}; + case storage::v3::TemporalType::LocalTime: + return {utils::LocalTime(type.microseconds)}; + case storage::v3::TemporalType::LocalDateTime: + return {utils::LocalDateTime(type.microseconds)}; + case storage::v3::TemporalType::Duration: + return {utils::Duration(type.microseconds)}; + } + } +} + +} // namespace memgraph::glue::v2 diff --git a/src/glue/v2/communication.hpp b/src/glue/v2/communication.hpp new file mode 100644 index 000000000..13bf96fca --- /dev/null +++ b/src/glue/v2/communication.hpp @@ -0,0 +1,68 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +/// @file Conversion functions between Value and other memgraph types. +#pragma once + +#include "communication/bolt/v1/value.hpp" +#include "query/v2/typed_value.hpp" +#include "storage/v3/property_value.hpp" +#include "storage/v3/result.hpp" +#include "storage/v3/view.hpp" + +namespace memgraph::storage::v3 { +class EdgeAccessor; +class Storage; +class VertexAccessor; +} // namespace memgraph::storage::v3 + +namespace memgraph::glue::v2 { + +/// @param storage::v3::VertexAccessor for converting to +/// communication::bolt::Vertex. +/// @param storage::v3::Storage for getting label and property names. +/// @param storage::v3::View for deciding which vertex attributes are visible. +/// +/// @throw std::bad_alloc +storage::v3::Result ToBoltVertex(const storage::v3::VertexAccessor &vertex, + const storage::v3::Storage &db, storage::v3::View view); + +/// @param storage::v3::EdgeAccessor for converting to communication::bolt::Edge. +/// @param storage::v3::Storage for getting edge type and property names. +/// @param storage::v3::View for deciding which edge attributes are visible. +/// +/// @throw std::bad_alloc +storage::v3::Result ToBoltEdge(const storage::v3::EdgeAccessor &edge, + const storage::v3::Storage &db, storage::v3::View view); + +/// @param query::v2::Path for converting to communication::bolt::Path. +/// @param storage::v3::Storage for ToBoltVertex and ToBoltEdge. +/// @param storage::v3::View for ToBoltVertex and ToBoltEdge. +/// +/// @throw std::bad_alloc +storage::v3::Result ToBoltPath(const query::v2::Path &path, const storage::v3::Storage &db, + storage::v3::View view); + +/// @param query::v2::TypedValue for converting to communication::bolt::Value. +/// @param storage::v3::Storage for ToBoltVertex and ToBoltEdge. +/// @param storage::v3::View for ToBoltVertex and ToBoltEdge. +/// +/// @throw std::bad_alloc +storage::v3::Result ToBoltValue(const query::v2::TypedValue &value, + const storage::v3::Storage &db, storage::v3::View view); + +query::v2::TypedValue ToTypedValue(const communication::bolt::Value &value); + +communication::bolt::Value ToBoltValue(const storage::v3::PropertyValue &value); + +storage::v3::PropertyValue ToPropertyValue(const communication::bolt::Value &value); + +} // namespace memgraph::glue::v2 diff --git a/src/query/common.hpp b/src/query/common.hpp index f6526494c..c51c34dee 100644 --- a/src/query/common.hpp +++ b/src/query/common.hpp @@ -16,7 +16,6 @@ #include #include #include -#include #include "query/db_accessor.hpp" #include "query/exceptions.hpp" @@ -25,12 +24,8 @@ #include "query/typed_value.hpp" #include "storage/v2/id_types.hpp" #include "storage/v2/property_value.hpp" -#include "storage/v2/result.hpp" -#include "storage/v2/schema_validator.hpp" #include "storage/v2/view.hpp" -#include "utils/exceptions.hpp" #include "utils/logging.hpp" -#include "utils/variant_helpers.hpp" namespace memgraph::query { @@ -86,79 +81,27 @@ concept AccessorWithSetProperty = requires(T accessor, const storage::PropertyId { accessor.SetProperty(key, new_value) } -> std::same_as>; }; -inline void HandleSchemaViolation(const storage::SchemaViolation &schema_violation, const DbAccessor &dba) { - switch (schema_violation.status) { - case storage::SchemaViolation::ValidationStatus::VERTEX_HAS_NO_PRIMARY_PROPERTY: { - throw SchemaViolationException( - fmt::format("Primary key {} not defined on label :{}", - storage::SchemaTypeToString(schema_violation.violated_schema_property->type), - dba.LabelToName(schema_violation.label))); - } - case storage::SchemaViolation::ValidationStatus::NO_SCHEMA_DEFINED_FOR_LABEL: { - throw SchemaViolationException( - fmt::format("Label :{} is not a primary label", dba.LabelToName(schema_violation.label))); - } - case storage::SchemaViolation::ValidationStatus::VERTEX_PROPERTY_WRONG_TYPE: { - throw SchemaViolationException( - fmt::format("Wrong type of property {} in schema :{}, should be of type {}", - *schema_violation.violated_property_value, dba.LabelToName(schema_violation.label), - storage::SchemaTypeToString(schema_violation.violated_schema_property->type))); - } - case storage::SchemaViolation::ValidationStatus::VERTEX_UPDATE_PRIMARY_KEY: { - throw SchemaViolationException(fmt::format("Updating of primary key {} on schema :{} not supported", - *schema_violation.violated_property_value, - dba.LabelToName(schema_violation.label))); - } - case storage::SchemaViolation::ValidationStatus::VERTEX_MODIFY_PRIMARY_LABEL: { - throw SchemaViolationException(fmt::format("Cannot add or remove label :{} since it is a primary label", - dba.LabelToName(schema_violation.label))); - } - case storage::SchemaViolation::ValidationStatus::VERTEX_SECONDARY_LABEL_IS_PRIMARY: { - throw SchemaViolationException( - fmt::format("Cannot create vertex with secondary label :{}", dba.LabelToName(schema_violation.label))); - } - } -} - -inline void HandleErrorOnPropertyUpdate(const storage::Error error) { - switch (error) { - case storage::Error::SERIALIZATION_ERROR: - throw TransactionSerializationException(); - case storage::Error::DELETED_OBJECT: - throw QueryRuntimeException("Trying to set properties on a deleted object."); - case storage::Error::PROPERTIES_DISABLED: - throw QueryRuntimeException("Can't set property because properties on edges are disabled."); - case storage::Error::VERTEX_HAS_EDGES: - case storage::Error::NONEXISTENT_OBJECT: - throw QueryRuntimeException("Unexpected error when setting a property."); - } -} - /// Set a property `value` mapped with given `key` on a `record`. /// /// @throw QueryRuntimeException if value cannot be set as a property value template -storage::PropertyValue PropsSetChecked(T *record, const DbAccessor &dba, const storage::PropertyId &key, - const TypedValue &value) { +storage::PropertyValue PropsSetChecked(T *record, const storage::PropertyId &key, const TypedValue &value) { try { - if constexpr (std::is_same_v) { - const auto maybe_old_value = record->SetPropertyAndValidate(key, storage::PropertyValue(value)); - if (maybe_old_value.HasError()) { - std::visit(utils::Overloaded{[](const storage::Error error) { HandleErrorOnPropertyUpdate(error); }, - [&dba](const storage::SchemaViolation &schema_violation) { - HandleSchemaViolation(schema_violation, dba); - }}, - maybe_old_value.GetError()); + auto maybe_old_value = record->SetProperty(key, storage::PropertyValue(value)); + if (maybe_old_value.HasError()) { + switch (maybe_old_value.GetError()) { + case storage::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to set properties on a deleted object."); + case storage::Error::PROPERTIES_DISABLED: + throw QueryRuntimeException("Can't set property because properties on edges are disabled."); + case storage::Error::VERTEX_HAS_EDGES: + case storage::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when setting a property."); } - return std::move(*maybe_old_value); - } else { - // No validation on edge properties - const auto maybe_old_value = record->SetProperty(key, storage::PropertyValue(value)); - if (maybe_old_value.HasError()) { - HandleErrorOnPropertyUpdate(maybe_old_value.GetError()); - } - return std::move(*maybe_old_value); } + return std::move(*maybe_old_value); } catch (const TypedValueException &) { throw QueryRuntimeException("'{}' cannot be used as a property value.", value.type()); } diff --git a/src/query/db_accessor.hpp b/src/query/db_accessor.hpp index 1fad73101..f8d400e54 100644 --- a/src/query/db_accessor.hpp +++ b/src/query/db_accessor.hpp @@ -12,7 +12,6 @@ #pragma once #include -#include #include #include @@ -24,7 +23,7 @@ /////////////////////////////////////////////////////////// // Our communication layer and query engine don't mix -// very well on Centos because OpenSSL version available +// very well on Centos because OpenSSL version avaialable // on Centos 7 include libkrb5 which has brilliant macros // called TRUE and FALSE. For more detailed explanation go // to memgraph.cpp. @@ -35,8 +34,6 @@ // simply undefine those macros as we're sure that libkrb5 // won't and can't be used anywhere in the query engine. #include "storage/v2/storage.hpp" -#include "utils/logging.hpp" -#include "utils/result.hpp" #undef FALSE #undef TRUE @@ -105,18 +102,10 @@ class VertexAccessor final { auto Labels(storage::View view) const { return impl_.Labels(view); } - auto PrimaryLabel(storage::View view) const { return impl_.PrimaryLabel(view); } - storage::Result AddLabel(storage::LabelId label) { return impl_.AddLabel(label); } - storage::ResultSchema AddLabelAndValidate(storage::LabelId label) { return impl_.AddLabelAndValidate(label); } - storage::Result RemoveLabel(storage::LabelId label) { return impl_.RemoveLabel(label); } - storage::ResultSchema RemoveLabelAndValidate(storage::LabelId label) { - return impl_.RemoveLabelAndValidate(label); - } - storage::Result HasLabel(storage::View view, storage::LabelId label) const { return impl_.HasLabel(label, view); } @@ -131,13 +120,8 @@ class VertexAccessor final { return impl_.SetProperty(key, value); } - storage::ResultSchema SetPropertyAndValidate(storage::PropertyId key, - const storage::PropertyValue &value) { - return impl_.SetPropertyAndValidate(key, value); - } - - storage::ResultSchema RemovePropertyAndValidate(storage::PropertyId key) { - return SetPropertyAndValidate(key, storage::PropertyValue{}); + storage::Result RemoveProperty(storage::PropertyId key) { + return SetProperty(key, storage::PropertyValue()); } storage::Result> ClearProperties() { @@ -263,18 +247,7 @@ class DbAccessor final { return VerticesIterable(accessor_->Vertices(label, property, lower, upper, view)); } - // TODO Remove when query modules have been fixed - [[deprecated]] VertexAccessor InsertVertex() { return VertexAccessor(accessor_->CreateVertex()); } - - storage::ResultSchema InsertVertexAndValidate( - const storage::LabelId primary_label, const std::vector &labels, - const std::vector> &properties) { - auto maybe_vertex_acc = accessor_->CreateVertexAndValidate(primary_label, labels, properties); - if (maybe_vertex_acc.HasError()) { - return {std::move(maybe_vertex_acc.GetError())}; - } - return VertexAccessor{maybe_vertex_acc.GetValue()}; - } + VertexAccessor InsertVertex() { return VertexAccessor(accessor_->CreateVertex()); } storage::Result InsertEdge(VertexAccessor *from, VertexAccessor *to, const storage::EdgeTypeId &edge_type) { @@ -332,7 +305,7 @@ class DbAccessor final { return std::optional{}; } - return {std::make_optional(*value)}; + return std::make_optional(*value); } storage::PropertyId NameToProperty(const std::string_view name) { return accessor_->NameToProperty(name); } @@ -381,10 +354,6 @@ class DbAccessor final { storage::IndicesInfo ListAllIndices() const { return accessor_->ListAllIndices(); } storage::ConstraintsInfo ListAllConstraints() const { return accessor_->ListAllConstraints(); } - - const storage::SchemaValidator &GetSchemaValidator() const { return accessor_->GetSchemaValidator(); } - - storage::SchemasInfo ListAllSchemas() const { return accessor_->ListAllSchemas(); } }; } // namespace memgraph::query diff --git a/src/query/exceptions.hpp b/src/query/exceptions.hpp index daba98e7a..a18ce0c43 100644 --- a/src/query/exceptions.hpp +++ b/src/query/exceptions.hpp @@ -224,12 +224,4 @@ class VersionInfoInMulticommandTxException : public QueryException { : QueryException("Version info query not allowed in multicommand transactions.") {} }; -/** - * An exception for an illegal operation that violates schema - */ -class SchemaViolationException : public QueryRuntimeException { - public: - using QueryRuntimeException::QueryRuntimeException; -}; - } // namespace memgraph::query diff --git a/src/query/frontend/ast/ast.lcp b/src/query/frontend/ast/ast.lcp index 24da369ca..6b5b5a16f 100644 --- a/src/query/frontend/ast/ast.lcp +++ b/src/query/frontend/ast/ast.lcp @@ -17,7 +17,6 @@ #include #include -#include "common/types.hpp" #include "query/frontend/ast/ast_visitor.hpp" #include "query/frontend/semantic/symbol.hpp" #include "query/interpret/awesome_memgraph_functions.hpp" @@ -134,15 +133,6 @@ cpp<# } cpp<#)) - -(defun clone-schema-property-vector (source dest) - #>cpp - ${dest}.reserve(${source}.size()); - for (const auto &[property_ix, property_type]: ${source}) { - ${dest}.emplace_back(storage->GetPropertyIx(property_ix.name), property_type); - } - cpp<#) - ;; The following index structs serve as a decoupling point of AST from ;; concrete database types. All the names are collected in AstStorage, and can ;; be indexed through these instances. This means that we can create a vector @@ -2263,7 +2253,7 @@ cpp<# (lcp:define-enum privilege (create delete match merge set remove index stats auth constraint dump replication durability read_file free_memory trigger config stream module_read module_write - websocket schema) + websocket) (:serialize)) #>cpp AuthQuery() = default; @@ -2305,7 +2295,7 @@ const std::vector kPrivilegesAll = { AuthQuery::Privilege::FREE_MEMORY, AuthQuery::Privilege::TRIGGER, AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM, AuthQuery::Privilege::MODULE_READ, AuthQuery::Privilege::MODULE_WRITE, - AuthQuery::Privilege::WEBSOCKET, AuthQuery::Privilege::SCHEMA}; + AuthQuery::Privilege::WEBSOCKET}; cpp<# (lcp:define-class info-query (query) @@ -2675,38 +2665,5 @@ cpp<# (:serialize (:slk)) (:clone)) -(lcp:define-class schema-query (query) - ((action "Action" :scope :public) - (label "LabelIx" :scope :public - :slk-load (lambda (member) - #>cpp - slk::Load(&self->${member}, reader, storage); - cpp<#) - :clone (lambda (source dest) - #>cpp - ${dest} = storage->GetLabelIx(${source}.name); - cpp<#)) - (schema_type_map "std::vector>" - :slk-save #'slk-save-property-map - :slk-load #'slk-load-property-map - :clone #'clone-schema-property-vector - :scope :public)) - - (:public - (lcp:define-enum action - (create-schema drop-schema show-schema show-schemas) - (:serialize)) - #>cpp - SchemaQuery() = default; - - DEFVISITABLE(QueryVisitor); - cpp<#) - (:private - #>cpp - friend class AstStorage; - cpp<#) - (:serialize (:slk)) - (:clone)) - (lcp:pop-namespace) ;; namespace query (lcp:pop-namespace) ;; namespace memgraph diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index 307b96907..0e4a6012c 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -94,7 +94,6 @@ class StreamQuery; class SettingQuery; class VersionQuery; class Foreach; -class SchemaQuery; using TreeCompositeVisitor = utils::CompositeVisitor< SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator, @@ -126,9 +125,9 @@ class ExpressionVisitor None, ParameterLookup, Identifier, PrimitiveLiteral, RegexMatch> {}; template -class QueryVisitor : public utils::Visitor {}; +class QueryVisitor + : public utils::Visitor {}; } // namespace memgraph::query diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index ac8af4c99..e97bdc05e 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -17,7 +17,6 @@ #include #include #include -#include #include #include #include @@ -28,7 +27,6 @@ #include -#include "common/types.hpp" #include "query/exceptions.hpp" #include "query/frontend/ast/ast.hpp" #include "query/frontend/ast/ast_visitor.hpp" @@ -1349,7 +1347,6 @@ antlrcpp::Any CypherMainVisitor::visitPrivilege(MemgraphCypher::PrivilegeContext if (ctx->MODULE_READ()) return AuthQuery::Privilege::MODULE_READ; if (ctx->MODULE_WRITE()) return AuthQuery::Privilege::MODULE_WRITE; if (ctx->WEBSOCKET()) return AuthQuery::Privilege::WEBSOCKET; - if (ctx->SCHEMA()) return AuthQuery::Privilege::SCHEMA; LOG_FATAL("Should not get here - unknown privilege!"); } @@ -2356,93 +2353,6 @@ antlrcpp::Any CypherMainVisitor::visitForeach(MemgraphCypher::ForeachContext *ct return for_each; } -antlrcpp::Any CypherMainVisitor::visitSchemaQuery(MemgraphCypher::SchemaQueryContext *ctx) { - MG_ASSERT(ctx->children.size() == 1, "SchemaQuery should have exactly one child!"); - auto *schema_query = ctx->children[0]->accept(this).as(); - query_ = schema_query; - return schema_query; -} - -antlrcpp::Any CypherMainVisitor::visitShowSchema(MemgraphCypher::ShowSchemaContext *ctx) { - auto *schema_query = storage_->Create(); - schema_query->action_ = SchemaQuery::Action::SHOW_SCHEMA; - schema_query->label_ = AddLabel(ctx->labelName()->accept(this)); - query_ = schema_query; - return schema_query; -} - -antlrcpp::Any CypherMainVisitor::visitShowSchemas(MemgraphCypher::ShowSchemasContext * /*ctx*/) { - auto *schema_query = storage_->Create(); - schema_query->action_ = SchemaQuery::Action::SHOW_SCHEMAS; - query_ = schema_query; - return schema_query; -} - -antlrcpp::Any CypherMainVisitor::visitPropertyType(MemgraphCypher::PropertyTypeContext *ctx) { - MG_ASSERT(ctx->symbolicName()); - const auto property_type = utils::ToLowerCase(ctx->symbolicName()->accept(this).as()); - if (property_type == "bool") { - return common::SchemaType::BOOL; - } - if (property_type == "string") { - return common::SchemaType::STRING; - } - if (property_type == "integer") { - return common::SchemaType::INT; - } - if (property_type == "date") { - return common::SchemaType::DATE; - } - if (property_type == "duration") { - return common::SchemaType::DURATION; - } - if (property_type == "localdatetime") { - return common::SchemaType::LOCALDATETIME; - } - if (property_type == "localtime") { - return common::SchemaType::LOCALTIME; - } - throw SyntaxException("Property type must be one of the supported types!"); -} - -/** - * @return Schema* - */ -antlrcpp::Any CypherMainVisitor::visitSchemaPropertyMap(MemgraphCypher::SchemaPropertyMapContext *ctx) { - std::vector> schema_property_map; - for (auto *property_key_pair : ctx->propertyKeyTypePair()) { - PropertyIx key = property_key_pair->propertyKeyName()->accept(this); - common::SchemaType type = property_key_pair->propertyType()->accept(this); - if (std::ranges::find_if(schema_property_map, [&key](const auto &elem) { return elem.first == key; }) != - schema_property_map.end()) { - throw SemanticException("Same property name can't appear twice in a schema map."); - } - schema_property_map.emplace_back(key, type); - } - return schema_property_map; -} - -antlrcpp::Any CypherMainVisitor::visitCreateSchema(MemgraphCypher::CreateSchemaContext *ctx) { - auto *schema_query = storage_->Create(); - schema_query->action_ = SchemaQuery::Action::CREATE_SCHEMA; - schema_query->label_ = AddLabel(ctx->labelName()->accept(this)); - schema_query->schema_type_map_ = - ctx->schemaPropertyMap()->accept(this).as>>(); - query_ = schema_query; - return schema_query; -} - -/** - * @return Schema* - */ -antlrcpp::Any CypherMainVisitor::visitDropSchema(MemgraphCypher::DropSchemaContext *ctx) { - auto *schema_query = storage_->Create(); - schema_query->action_ = SchemaQuery::Action::DROP_SCHEMA; - schema_query->label_ = AddLabel(ctx->labelName()->accept(this)); - query_ = schema_query; - return schema_query; -} - LabelIx CypherMainVisitor::AddLabel(const std::string &name) { return storage_->GetLabelIx(name); } PropertyIx CypherMainVisitor::AddProperty(const std::string &name) { return storage_->GetPropertyIx(name); } diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index 30914ba25..d3a12072e 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -849,41 +849,6 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { */ antlrcpp::Any visitForeach(MemgraphCypher::ForeachContext *ctx) override; - /** - * @return Schema* - */ - antlrcpp::Any visitPropertyType(MemgraphCypher::PropertyTypeContext *ctx) override; - - /** - * @return Schema* - */ - antlrcpp::Any visitSchemaPropertyMap(MemgraphCypher::SchemaPropertyMapContext *ctx) override; - - /** - * @return Schema* - */ - antlrcpp::Any visitSchemaQuery(MemgraphCypher::SchemaQueryContext *ctx) override; - - /** - * @return Schema* - */ - antlrcpp::Any visitShowSchema(MemgraphCypher::ShowSchemaContext *ctx) override; - - /** - * @return Schema* - */ - antlrcpp::Any visitShowSchemas(MemgraphCypher::ShowSchemasContext *ctx) override; - - /** - * @return Schema* - */ - antlrcpp::Any visitCreateSchema(MemgraphCypher::CreateSchemaContext *ctx) override; - - /** - * @return Schema* - */ - antlrcpp::Any visitDropSchema(MemgraphCypher::DropSchemaContext *ctx) override; - public: Query *query() { return query_; } const static std::string kAnonPrefix; diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 index 956320bbf..8790b8731 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 @@ -46,10 +46,10 @@ memgraphCypherKeyword : cypherKeyword | DROP | DUMP | EXECUTE - | FREE - | FROM | FOR | FOREACH + | FREE + | FROM | GLOBAL | GRANT | HEADER @@ -76,8 +76,6 @@ memgraphCypherKeyword : cypherKeyword | ROLE | ROLES | QUOTE - | SCHEMA - | SCHEMAS | SESSION | SETTING | SETTINGS @@ -124,7 +122,6 @@ query : cypherQuery | streamQuery | settingQuery | versionQuery - | schemaQuery ; authQuery : createRole @@ -195,12 +192,6 @@ settingQuery : setSetting | showSettings ; -schemaQuery : showSchema - | showSchemas - | createSchema - | dropSchema - ; - loadCsv : LOAD CSV FROM csvFile ( WITH | NO ) HEADER ( IGNORE BAD ) ? ( DELIMITER delimiter ) ? @@ -263,7 +254,6 @@ privilege : CREATE | MODULE_READ | MODULE_WRITE | WEBSOCKET - | SCHEMA ; privilegeList : privilege ( ',' privilege )* ; @@ -383,17 +373,3 @@ showSetting : SHOW DATABASE SETTING settingName ; showSettings : SHOW DATABASE SETTINGS ; versionQuery : SHOW VERSION ; - -showSchema : SHOW SCHEMA ON ':' labelName ; - -showSchemas : SHOW SCHEMAS ; - -propertyType : symbolicName ; - -propertyKeyTypePair : propertyKeyName propertyType ; - -schemaPropertyMap : '(' propertyKeyTypePair ( ',' propertyKeyTypePair )* ')' ; - -createSchema : CREATE SCHEMA ON ':' labelName schemaPropertyMap ; - -dropSchema : DROP SCHEMA ON ':' labelName ; diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 index 869141033..55e5d53a2 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 @@ -89,8 +89,6 @@ REVOKE : R E V O K E ; ROLE : R O L E ; ROLES : R O L E S ; QUOTE : Q U O T E ; -SCHEMA : S C H E M A ; -SCHEMAS : S C H E M A S ; SERVICE_URL : S E R V I C E UNDERSCORE U R L ; SESSION : S E S S I O N ; SETTING : S E T T I N G ; diff --git a/src/query/frontend/semantic/required_privileges.cpp b/src/query/frontend/semantic/required_privileges.cpp index 160004ac2..e8dbd21e5 100644 --- a/src/query/frontend/semantic/required_privileges.cpp +++ b/src/query/frontend/semantic/required_privileges.cpp @@ -80,8 +80,6 @@ class PrivilegeExtractor : public QueryVisitor, public HierarchicalTreeVis void Visit(VersionQuery & /*version_query*/) override { AddPrivilege(AuthQuery::Privilege::STATS); } - void Visit(SchemaQuery & /*schema_query*/) override { AddPrivilege(AuthQuery::Privilege::SCHEMA); } - bool PreVisit(Create & /*unused*/) override { AddPrivilege(AuthQuery::Privilege::CREATE); return false; diff --git a/src/query/frontend/stripped_lexer_constants.hpp b/src/query/frontend/stripped_lexer_constants.hpp index be516aee6..784692b53 100644 --- a/src/query/frontend/stripped_lexer_constants.hpp +++ b/src/query/frontend/stripped_lexer_constants.hpp @@ -205,8 +205,7 @@ const trie::Trie kKeywords = {"union", "service_url", "version", "websocket", - "foreach", - "schema"}; + "foreach"}; // Unicode codepoints that are allowed at the start of the unescaped name. const std::bitset kUnescapedNameAllowedStarts( diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 2989b2879..2925b1c14 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -45,7 +45,6 @@ #include "query/typed_value.hpp" #include "storage/v2/property_value.hpp" #include "storage/v2/replication/enums.hpp" -#include "storage/v2/schemas.hpp" #include "utils/algorithm.hpp" #include "utils/csv_parsing.hpp" #include "utils/event_counter.hpp" @@ -873,102 +872,6 @@ Callback HandleSettingQuery(SettingQuery *setting_query, const Parameters ¶m } } -Callback HandleSchemaQuery(SchemaQuery *schema_query, InterpreterContext *interpreter_context, - std::vector *notifications) { - Callback callback; - switch (schema_query->action_) { - case SchemaQuery::Action::SHOW_SCHEMAS: { - callback.header = {"label", "primary_key"}; - callback.fn = [interpreter_context]() { - auto *db = interpreter_context->db; - auto schemas_info = db->ListAllSchemas(); - std::vector> results; - results.reserve(schemas_info.schemas.size()); - - for (const auto &[label_id, schema_types] : schemas_info.schemas) { - std::vector schema_info_row; - schema_info_row.reserve(3); - - schema_info_row.emplace_back(db->LabelToName(label_id)); - std::vector primary_key_properties; - primary_key_properties.reserve(schema_types.size()); - std::transform(schema_types.begin(), schema_types.end(), std::back_inserter(primary_key_properties), - [&db](const auto &schema_type) { - return db->PropertyToName(schema_type.property_id) + - "::" + storage::SchemaTypeToString(schema_type.type); - }); - - schema_info_row.emplace_back(utils::Join(primary_key_properties, ", ")); - results.push_back(std::move(schema_info_row)); - } - return results; - }; - return callback; - } - case SchemaQuery::Action::SHOW_SCHEMA: { - callback.header = {"property_name", "property_type"}; - callback.fn = [interpreter_context, primary_label = schema_query->label_]() { - auto *db = interpreter_context->db; - const auto label = db->NameToLabel(primary_label.name); - const auto schema = db->GetSchema(label); - std::vector> results; - if (schema) { - for (const auto &schema_property : schema->second) { - std::vector schema_info_row; - schema_info_row.reserve(2); - schema_info_row.emplace_back(db->PropertyToName(schema_property.property_id)); - schema_info_row.emplace_back(storage::SchemaTypeToString(schema_property.type)); - results.push_back(std::move(schema_info_row)); - } - return results; - } - throw QueryException(fmt::format("Schema on label :{} not found!", primary_label.name)); - }; - return callback; - } - case SchemaQuery::Action::CREATE_SCHEMA: { - auto schema_type_map = schema_query->schema_type_map_; - if (schema_query->schema_type_map_.empty()) { - throw SyntaxException("One or more types have to be defined in schema definition."); - } - callback.fn = [interpreter_context, primary_label = schema_query->label_, - schema_type_map = std::move(schema_type_map)]() { - auto *db = interpreter_context->db; - const auto label = db->NameToLabel(primary_label.name); - std::vector schemas_types; - schemas_types.reserve(schema_type_map.size()); - for (const auto &schema_type : schema_type_map) { - auto property_id = db->NameToProperty(schema_type.first.name); - schemas_types.push_back({property_id, schema_type.second}); - } - if (!db->CreateSchema(label, schemas_types)) { - throw QueryException(fmt::format("Schema on label :{} already exists!", primary_label.name)); - } - return std::vector>{}; - }; - notifications->emplace_back(SeverityLevel::INFO, NotificationCode::CREATE_SCHEMA, - fmt::format("Create schema on label :{}", schema_query->label_.name)); - return callback; - } - case SchemaQuery::Action::DROP_SCHEMA: { - callback.fn = [interpreter_context, primary_label = schema_query->label_]() { - auto *db = interpreter_context->db; - const auto label = db->NameToLabel(primary_label.name); - - if (!db->DropSchema(label)) { - throw QueryException(fmt::format("Schema on label :{} does not exist!", primary_label.name)); - } - - return std::vector>{}; - }; - notifications->emplace_back(SeverityLevel::INFO, NotificationCode::DROP_SCHEMA, - fmt::format("Dropped schema on label :{}", schema_query->label_.name)); - return callback; - } - } - return callback; -} - // Struct for lazy pulling from a vector struct PullPlanVector { explicit PullPlanVector(std::vector> values) : values_(std::move(values)) {} @@ -2163,32 +2066,6 @@ PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_ RWType::NONE}; } -PreparedQuery PrepareSchemaQuery(ParsedQuery parsed_query, bool in_explicit_transaction, - InterpreterContext *interpreter_context, std::vector *notifications) { - if (in_explicit_transaction) { - throw ConstraintInMulticommandTxException(); - } - auto *schema_query = utils::Downcast(parsed_query.query); - MG_ASSERT(schema_query); - auto callback = HandleSchemaQuery(schema_query, interpreter_context, notifications); - - return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), - [handler = std::move(callback.fn), action = QueryHandlerResult::NOTHING, - pull_plan = std::shared_ptr(nullptr)]( - AnyStream *stream, std::optional n) mutable -> std::optional { - if (!pull_plan) { - auto results = handler(); - pull_plan = std::make_shared(std::move(results)); - } - - if (pull_plan->Pull(stream, n)) { - return action; - } - return std::nullopt; - }, - RWType::NONE}; -} - void Interpreter::BeginTransaction() { const auto prepared_query = PrepareTransactionQuery("BEGIN"); prepared_query.query_handler(nullptr, {}); @@ -2322,9 +2199,6 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, prepared_query = PrepareSettingQuery(std::move(parsed_query), in_explicit_transaction_, &*execution_db_accessor_); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareVersionQuery(std::move(parsed_query), in_explicit_transaction_); - } else if (utils::Downcast(parsed_query.query)) { - prepared_query = PrepareSchemaQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_, - &query_execution->notifications); } else { LOG_FATAL("Should not get here -- unknown query type!"); } diff --git a/src/query/metadata.cpp b/src/query/metadata.cpp index 2e25ce8a4..fa80c61f5 100644 --- a/src/query/metadata.cpp +++ b/src/query/metadata.cpp @@ -38,8 +38,6 @@ constexpr std::string_view GetCodeString(const NotificationCode code) { return "CreateIndex"sv; case NotificationCode::CREATE_STREAM: return "CreateStream"sv; - case NotificationCode::CREATE_SCHEMA: - return "CreateSchema"sv; case NotificationCode::CHECK_STREAM: return "CheckStream"sv; case NotificationCode::CREATE_TRIGGER: @@ -50,8 +48,6 @@ constexpr std::string_view GetCodeString(const NotificationCode code) { return "DropReplica"sv; case NotificationCode::DROP_INDEX: return "DropIndex"sv; - case NotificationCode::DROP_SCHEMA: - return "DropSchema"sv; case NotificationCode::DROP_STREAM: return "DropStream"sv; case NotificationCode::DROP_TRIGGER: @@ -72,10 +68,6 @@ constexpr std::string_view GetCodeString(const NotificationCode code) { return "ReplicaPortWarning"sv; case NotificationCode::SET_REPLICA: return "SetReplica"sv; - case NotificationCode::SHOW_SCHEMA: - return "ShowSchema"sv; - case NotificationCode::SHOW_SCHEMAS: - return "ShowSchemas"sv; case NotificationCode::START_STREAM: return "StartStream"sv; case NotificationCode::START_ALL_STREAMS: diff --git a/src/query/metadata.hpp b/src/query/metadata.hpp index e557ca72e..67f784fa8 100644 --- a/src/query/metadata.hpp +++ b/src/query/metadata.hpp @@ -26,14 +26,12 @@ enum class SeverityLevel : uint8_t { INFO, WARNING }; enum class NotificationCode : uint8_t { CREATE_CONSTRAINT, CREATE_INDEX, - CREATE_SCHEMA, CHECK_STREAM, CREATE_STREAM, CREATE_TRIGGER, DROP_CONSTRAINT, DROP_INDEX, DROP_REPLICA, - DROP_SCHEMA, DROP_STREAM, DROP_TRIGGER, EXISTANT_INDEX, @@ -44,8 +42,6 @@ enum class NotificationCode : uint8_t { REPLICA_PORT_WARNING, REGISTER_REPLICA, SET_REPLICA, - SHOW_SCHEMA, - SHOW_SCHEMAS, START_STREAM, START_ALL_STREAMS, STOP_STREAM, diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index fc55c13e8..72117f3ad 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -37,12 +37,7 @@ #include "query/procedure/cypher_types.hpp" #include "query/procedure/mg_procedure_impl.hpp" #include "query/procedure/module.hpp" -#include "query/typed_value.hpp" -#include "storage/v2/id_types.hpp" #include "storage/v2/property_value.hpp" -#include "storage/v2/result.hpp" -#include "storage/v2/schema_validator.hpp" -#include "storage/v2/schemas.hpp" #include "utils/algorithm.hpp" #include "utils/csv_parsing.hpp" #include "utils/event_counter.hpp" @@ -57,7 +52,6 @@ #include "utils/readable_size.hpp" #include "utils/string.hpp" #include "utils/temporal.hpp" -#include "utils/variant_helpers.hpp" // macro for the default implementation of LogicalOperator::Accept // that accepts the visitor and visits it's input_ operator @@ -180,56 +174,45 @@ CreateNode::CreateNode(const std::shared_ptr &input, const Node // Creates a vertex on this GraphDb. Returns a reference to vertex placed on the // frame. -VertexAccessor &CreateLocalVertexAtomically(const NodeCreationInfo &node_info, Frame *frame, - ExecutionContext &context) { +VertexAccessor &CreateLocalVertex(const NodeCreationInfo &node_info, Frame *frame, ExecutionContext &context) { auto &dba = *context.db_accessor; + auto new_node = dba.InsertVertex(); + context.execution_stats[ExecutionStats::Key::CREATED_NODES] += 1; + for (auto label : node_info.labels) { + auto maybe_error = new_node.AddLabel(label); + if (maybe_error.HasError()) { + switch (maybe_error.GetError()) { + case storage::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to set a label on a deleted node."); + case storage::Error::VERTEX_HAS_EDGES: + case storage::Error::PROPERTIES_DISABLED: + case storage::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when setting a label."); + } + } + context.execution_stats[ExecutionStats::Key::CREATED_LABELS] += 1; + } // Evaluator should use the latest accessors, as modified in this query, when // setting properties on new nodes. ExpressionEvaluator evaluator(frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::View::NEW); - - std::vector> properties; + // TODO: PropsSetChecked allocates a PropertyValue, make it use context.memory + // when we update PropertyValue with custom allocator. if (const auto *node_info_properties = std::get_if(&node_info.properties)) { - properties.reserve(node_info_properties->size()); for (const auto &[key, value_expression] : *node_info_properties) { - properties.emplace_back(key, storage::PropertyValue(value_expression->Accept(evaluator))); + PropsSetChecked(&new_node, key, value_expression->Accept(evaluator)); } } else { - auto property_map = evaluator.Visit(*std::get(node_info.properties)).ValueMap(); - properties.reserve(property_map.size()); - - for (const auto &[key, value] : property_map) { + auto property_map = evaluator.Visit(*std::get(node_info.properties)); + for (const auto &[key, value] : property_map.ValueMap()) { auto property_id = dba.NameToProperty(key); - properties.emplace_back(property_id, value); + PropsSetChecked(&new_node, property_id, value); } } - // TODO Remove later on since that will be enforced from grammar side - MG_ASSERT(!node_info.labels.empty(), "There must be at least one label!"); - const auto primary_label = node_info.labels[0]; - std::vector secondary_labels(node_info.labels.begin() + 1, node_info.labels.end()); - auto maybe_new_node = dba.InsertVertexAndValidate(primary_label, secondary_labels, properties); - if (maybe_new_node.HasError()) { - std::visit(utils::Overloaded{[&dba](const storage::SchemaViolation &schema_violation) { - HandleSchemaViolation(schema_violation, dba); - }, - [](const storage::Error error) { - switch (error) { - case storage::Error::SERIALIZATION_ERROR: - throw TransactionSerializationException(); - case storage::Error::DELETED_OBJECT: - throw QueryRuntimeException("Trying to set a label on a deleted node."); - case storage::Error::VERTEX_HAS_EDGES: - case storage::Error::PROPERTIES_DISABLED: - case storage::Error::NONEXISTENT_OBJECT: - throw QueryRuntimeException("Unexpected error when setting a label."); - } - }}, - maybe_new_node.GetError()); - } - context.execution_stats[ExecutionStats::Key::CREATED_NODES] += 1; - - (*frame)[node_info.symbol] = *maybe_new_node; + (*frame)[node_info.symbol] = new_node; return (*frame)[node_info.symbol].ValueVertex(); } @@ -254,7 +237,7 @@ bool CreateNode::CreateNodeCursor::Pull(Frame &frame, ExecutionContext &context) SCOPED_PROFILE_OP("CreateNode"); if (input_cursor_->Pull(frame, context)) { - auto created_vertex = CreateLocalVertexAtomically(self_.node_info_, &frame, context); + auto created_vertex = CreateLocalVertex(self_.node_info_, &frame, context); if (context.trigger_context_collector) { context.trigger_context_collector->RegisterCreatedObject(created_vertex); } @@ -303,13 +286,13 @@ EdgeAccessor CreateEdge(const EdgeCreationInfo &edge_info, DbAccessor *dba, Vert auto &edge = *maybe_edge; if (const auto *properties = std::get_if(&edge_info.properties)) { for (const auto &[key, value_expression] : *properties) { - PropsSetChecked(&edge, *dba, key, value_expression->Accept(*evaluator)); + PropsSetChecked(&edge, key, value_expression->Accept(*evaluator)); } } else { auto property_map = evaluator->Visit(*std::get(edge_info.properties)); for (const auto &[key, value] : property_map.ValueMap()) { auto property_id = dba->NameToProperty(key); - PropsSetChecked(&edge, *dba, property_id, value); + PropsSetChecked(&edge, property_id, value); } } @@ -386,7 +369,7 @@ VertexAccessor &CreateExpand::CreateExpandCursor::OtherVertex(Frame &frame, Exec ExpectType(self_.node_info_.symbol, dest_node_value, TypedValue::Type::Vertex); return dest_node_value.ValueVertex(); } - auto &created_vertex = CreateLocalVertexAtomically(self_.node_info_, &frame, context); + auto &created_vertex = CreateLocalVertex(self_.node_info_, &frame, context); if (context.trigger_context_collector) { context.trigger_context_collector->RegisterCreatedObject(created_vertex); } @@ -2063,7 +2046,7 @@ bool SetProperty::SetPropertyCursor::Pull(Frame &frame, ExecutionContext &contex switch (lhs.type()) { case TypedValue::Type::Vertex: { - auto old_value = PropsSetChecked(&lhs.ValueVertex(), *context.db_accessor, self_.property_, rhs); + auto old_value = PropsSetChecked(&lhs.ValueVertex(), self_.property_, rhs); context.execution_stats[ExecutionStats::Key::UPDATED_PROPERTIES] += 1; if (context.trigger_context_collector) { // rhs cannot be moved because it was created with the allocator that is only valid during current pull @@ -2073,7 +2056,7 @@ bool SetProperty::SetPropertyCursor::Pull(Frame &frame, ExecutionContext &contex break; } case TypedValue::Type::Edge: { - auto old_value = PropsSetChecked(&lhs.ValueEdge(), *context.db_accessor, self_.property_, rhs); + auto old_value = PropsSetChecked(&lhs.ValueEdge(), self_.property_, rhs); context.execution_stats[ExecutionStats::Key::UPDATED_PROPERTIES] += 1; if (context.trigger_context_collector) { // rhs cannot be moved because it was created with the allocator that is only valid during current pull @@ -2227,7 +2210,7 @@ void SetPropertiesOnRecord(TRecordAccessor *record, const TypedValue &rhs, SetPr case TypedValue::Type::Map: { for (const auto &kv : rhs.ValueMap()) { auto key = context->db_accessor->NameToProperty(kv.first); - auto old_value = PropsSetChecked(record, *context->db_accessor, key, kv.second); + auto old_value = PropsSetChecked(record, key, kv.second); if (should_register_change) { register_set_property(std::move(old_value), key, kv.second); } @@ -2311,31 +2294,22 @@ bool SetLabels::SetLabelsCursor::Pull(Frame &frame, ExecutionContext &context) { // Skip setting labels on Null (can occur in optional match). if (vertex_value.IsNull()) return true; ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); - - auto &dba = *context.db_accessor; auto &vertex = vertex_value.ValueVertex(); - for (const auto label : self_.labels_) { - auto maybe_value = vertex.AddLabelAndValidate(label); + for (auto label : self_.labels_) { + auto maybe_value = vertex.AddLabel(label); if (maybe_value.HasError()) { - std::visit(utils::Overloaded{[](const storage::Error error) { - switch (error) { - case storage::Error::SERIALIZATION_ERROR: - throw TransactionSerializationException(); - case storage::Error::DELETED_OBJECT: - throw QueryRuntimeException("Trying to set a label on a deleted node."); - case storage::Error::VERTEX_HAS_EDGES: - case storage::Error::PROPERTIES_DISABLED: - case storage::Error::NONEXISTENT_OBJECT: - throw QueryRuntimeException("Unexpected error when setting a label."); - } - }, - [&dba](const storage::SchemaViolation schema_violation) { - HandleSchemaViolation(schema_violation, dba); - }}, - maybe_value.GetError()); + switch (maybe_value.GetError()) { + case storage::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to set a label on a deleted node."); + case storage::Error::VERTEX_HAS_EDGES: + case storage::Error::PROPERTIES_DISABLED: + case storage::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when setting a label."); + } } - context.execution_stats[ExecutionStats::Key::CREATED_LABELS]++; if (context.trigger_context_collector && *maybe_value) { context.trigger_context_collector->RegisterSetVertexLabel(vertex, label); } @@ -2378,11 +2352,26 @@ bool RemoveProperty::RemovePropertyCursor::Pull(Frame &frame, ExecutionContext & TypedValue lhs = self_.lhs_->expression_->Accept(evaluator); auto remove_prop = [property = self_.property_, &context](auto *record) { - auto old_value = PropsSetChecked(record, *context.db_accessor, property, TypedValue{}); + auto maybe_old_value = record->RemoveProperty(property); + if (maybe_old_value.HasError()) { + switch (maybe_old_value.GetError()) { + case storage::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to remove a property on a deleted graph element."); + case storage::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::Error::PROPERTIES_DISABLED: + throw QueryRuntimeException( + "Can't remove property because properties on edges are " + "disabled."); + case storage::Error::VERTEX_HAS_EDGES: + case storage::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when removing property."); + } + } if (context.trigger_context_collector) { context.trigger_context_collector->RegisterRemovedObjectProperty(*record, property, - TypedValue(std::move(old_value))); + TypedValue(std::move(*maybe_old_value))); } }; @@ -2436,25 +2425,18 @@ bool RemoveLabels::RemoveLabelsCursor::Pull(Frame &frame, ExecutionContext &cont ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); auto &vertex = vertex_value.ValueVertex(); for (auto label : self_.labels_) { - auto maybe_value = vertex.RemoveLabelAndValidate(label); + auto maybe_value = vertex.RemoveLabel(label); if (maybe_value.HasError()) { - std::visit( - utils::Overloaded{[](const storage::Error error) { - switch (error) { - case storage::Error::SERIALIZATION_ERROR: - throw TransactionSerializationException(); - case storage::Error::DELETED_OBJECT: - throw QueryRuntimeException("Trying to remove labels from a deleted node."); - case storage::Error::VERTEX_HAS_EDGES: - case storage::Error::PROPERTIES_DISABLED: - case storage::Error::NONEXISTENT_OBJECT: - throw QueryRuntimeException("Unexpected error when removing labels from a node."); - } - }, - [&context](const storage::SchemaViolation &schema_violation) { - HandleSchemaViolation(schema_violation, *context.db_accessor); - }}, - maybe_value.GetError()); + switch (maybe_value.GetError()) { + case storage::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to remove labels from a deleted node."); + case storage::Error::VERTEX_HAS_EDGES: + case storage::Error::PROPERTIES_DISABLED: + case storage::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when removing labels from a node."); + } } context.execution_stats[ExecutionStats::Key::DELETED_LABELS] += 1; diff --git a/src/query/v2/common.hpp b/src/query/v2/common.hpp index e79ca996c..ee8f72ddf 100644 --- a/src/query/v2/common.hpp +++ b/src/query/v2/common.hpp @@ -16,6 +16,7 @@ #include #include #include +#include #include "query/v2/db_accessor.hpp" #include "query/v2/exceptions.hpp" @@ -24,8 +25,12 @@ #include "query/v2/typed_value.hpp" #include "storage/v3/id_types.hpp" #include "storage/v3/property_value.hpp" +#include "storage/v3/result.hpp" +#include "storage/v3/schema_validator.hpp" #include "storage/v3/view.hpp" +#include "utils/exceptions.hpp" #include "utils/logging.hpp" +#include "utils/variant_helpers.hpp" namespace memgraph::query::v2 { @@ -81,27 +86,79 @@ concept AccessorWithSetProperty = requires(T accessor, const storage::v3::Proper { accessor.SetProperty(key, new_value) } -> std::same_as>; }; +inline void HandleSchemaViolation(const storage::v3::SchemaViolation &schema_violation, const DbAccessor &dba) { + switch (schema_violation.status) { + case storage::v3::SchemaViolation::ValidationStatus::VERTEX_HAS_NO_PRIMARY_PROPERTY: { + throw SchemaViolationException( + fmt::format("Primary key {} not defined on label :{}", + storage::v3::SchemaTypeToString(schema_violation.violated_schema_property->type), + dba.LabelToName(schema_violation.label))); + } + case storage::v3::SchemaViolation::ValidationStatus::NO_SCHEMA_DEFINED_FOR_LABEL: { + throw SchemaViolationException( + fmt::format("Label :{} is not a primary label", dba.LabelToName(schema_violation.label))); + } + case storage::v3::SchemaViolation::ValidationStatus::VERTEX_PROPERTY_WRONG_TYPE: { + throw SchemaViolationException( + fmt::format("Wrong type of property {} in schema :{}, should be of type {}", + *schema_violation.violated_property_value, dba.LabelToName(schema_violation.label), + storage::v3::SchemaTypeToString(schema_violation.violated_schema_property->type))); + } + case storage::v3::SchemaViolation::ValidationStatus::VERTEX_UPDATE_PRIMARY_KEY: { + throw SchemaViolationException(fmt::format("Updating of primary key {} on schema :{} not supported", + *schema_violation.violated_property_value, + dba.LabelToName(schema_violation.label))); + } + case storage::v3::SchemaViolation::ValidationStatus::VERTEX_MODIFY_PRIMARY_LABEL: { + throw SchemaViolationException(fmt::format("Cannot add or remove label :{} since it is a primary label", + dba.LabelToName(schema_violation.label))); + } + case storage::v3::SchemaViolation::ValidationStatus::VERTEX_SECONDARY_LABEL_IS_PRIMARY: { + throw SchemaViolationException( + fmt::format("Cannot create vertex with secondary label :{}", dba.LabelToName(schema_violation.label))); + } + } +} + +inline void HandleErrorOnPropertyUpdate(const storage::v3::Error error) { + switch (error) { + case storage::v3::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to set properties on a deleted object."); + case storage::v3::Error::PROPERTIES_DISABLED: + throw QueryRuntimeException("Can't set property because properties on edges are disabled."); + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when setting a property."); + } +} + /// Set a property `value` mapped with given `key` on a `record`. /// /// @throw QueryRuntimeException if value cannot be set as a property value template -storage::v3::PropertyValue PropsSetChecked(T *record, const storage::v3::PropertyId &key, const TypedValue &value) { +storage::v3::PropertyValue PropsSetChecked(T *record, const DbAccessor &dba, const storage::v3::PropertyId &key, + const TypedValue &value) { try { - auto maybe_old_value = record->SetProperty(key, storage::v3::PropertyValue(value)); - if (maybe_old_value.HasError()) { - switch (maybe_old_value.GetError()) { - case storage::v3::Error::SERIALIZATION_ERROR: - throw TransactionSerializationException(); - case storage::v3::Error::DELETED_OBJECT: - throw QueryRuntimeException("Trying to set properties on a deleted object."); - case storage::v3::Error::PROPERTIES_DISABLED: - throw QueryRuntimeException("Can't set property because properties on edges are disabled."); - case storage::v3::Error::VERTEX_HAS_EDGES: - case storage::v3::Error::NONEXISTENT_OBJECT: - throw QueryRuntimeException("Unexpected error when setting a property."); + if constexpr (std::is_same_v) { + const auto maybe_old_value = record->SetPropertyAndValidate(key, storage::v3::PropertyValue(value)); + if (maybe_old_value.HasError()) { + std::visit(utils::Overloaded{[](const storage::v3::Error error) { HandleErrorOnPropertyUpdate(error); }, + [&dba](const storage::v3::SchemaViolation &schema_violation) { + HandleSchemaViolation(schema_violation, dba); + }}, + maybe_old_value.GetError()); } + return std::move(*maybe_old_value); + } else { + // No validation on edge properties + const auto maybe_old_value = record->SetProperty(key, storage::v3::PropertyValue(value)); + if (maybe_old_value.HasError()) { + HandleErrorOnPropertyUpdate(maybe_old_value.GetError()); + } + return std::move(*maybe_old_value); } - return std::move(*maybe_old_value); } catch (const TypedValueException &) { throw QueryRuntimeException("'{}' cannot be used as a property value.", value.type()); } diff --git a/src/query/v2/db_accessor.hpp b/src/query/v2/db_accessor.hpp index 90ea6d431..a09bcffcb 100644 --- a/src/query/v2/db_accessor.hpp +++ b/src/query/v2/db_accessor.hpp @@ -12,6 +12,7 @@ #pragma once #include +#include #include #include @@ -23,7 +24,7 @@ /////////////////////////////////////////////////////////// // Our communication layer and query engine don't mix -// very well on Centos because OpenSSL version avaialable +// very well on Centos because OpenSSL version available // on Centos 7 include libkrb5 which has brilliant macros // called TRUE and FALSE. For more detailed explanation go // to memgraph.cpp. @@ -34,6 +35,8 @@ // simply undefine those macros as we're sure that libkrb5 // won't and can't be used anywhere in the query engine. #include "storage/v3/storage.hpp" +#include "utils/logging.hpp" +#include "utils/result.hpp" #undef FALSE #undef TRUE @@ -51,7 +54,6 @@ class EdgeAccessor final { public: storage::v3::EdgeAccessor impl_; - public: explicit EdgeAccessor(storage::v3::EdgeAccessor impl) : impl_(std::move(impl)) {} bool IsVisible(storage::v3::View view) const { return impl_.IsVisible(view); } @@ -99,17 +101,26 @@ class VertexAccessor final { static EdgeAccessor MakeEdgeAccessor(const storage::v3::EdgeAccessor impl) { return EdgeAccessor(impl); } - public: explicit VertexAccessor(storage::v3::VertexAccessor impl) : impl_(impl) {} bool IsVisible(storage::v3::View view) const { return impl_.IsVisible(view); } auto Labels(storage::v3::View view) const { return impl_.Labels(view); } + auto PrimaryLabel(storage::v3::View view) const { return impl_.PrimaryLabel(view); } + storage::v3::Result AddLabel(storage::v3::LabelId label) { return impl_.AddLabel(label); } + storage::v3::ResultSchema AddLabelAndValidate(storage::v3::LabelId label) { + return impl_.AddLabelAndValidate(label); + } + storage::v3::Result RemoveLabel(storage::v3::LabelId label) { return impl_.RemoveLabel(label); } + storage::v3::ResultSchema RemoveLabelAndValidate(storage::v3::LabelId label) { + return impl_.RemoveLabelAndValidate(label); + } + storage::v3::Result HasLabel(storage::v3::View view, storage::v3::LabelId label) const { return impl_.HasLabel(label, view); } @@ -126,8 +137,13 @@ class VertexAccessor final { return impl_.SetProperty(key, value); } - storage::v3::Result RemoveProperty(storage::v3::PropertyId key) { - return SetProperty(key, storage::v3::PropertyValue()); + storage::v3::ResultSchema SetPropertyAndValidate( + storage::v3::PropertyId key, const storage::v3::PropertyValue &value) { + return impl_.SetPropertyAndValidate(key, value); + } + + storage::v3::ResultSchema RemovePropertyAndValidate(storage::v3::PropertyId key) { + return SetPropertyAndValidate(key, storage::v3::PropertyValue{}); } storage::v3::Result> ClearProperties() { @@ -254,7 +270,18 @@ class DbAccessor final { return VerticesIterable(accessor_->Vertices(label, property, lower, upper, view)); } - VertexAccessor InsertVertex() { return VertexAccessor(accessor_->CreateVertex()); } + // TODO Remove when query modules have been fixed + [[deprecated]] VertexAccessor InsertVertex() { return VertexAccessor(accessor_->CreateVertex()); } + + storage::v3::ResultSchema InsertVertexAndValidate( + const storage::v3::LabelId primary_label, const std::vector &labels, + const std::vector> &properties) { + auto maybe_vertex_acc = accessor_->CreateVertexAndValidate(primary_label, labels, properties); + if (maybe_vertex_acc.HasError()) { + return {std::move(maybe_vertex_acc.GetError())}; + } + return VertexAccessor{maybe_vertex_acc.GetValue()}; + } storage::v3::Result InsertEdge(VertexAccessor *from, VertexAccessor *to, const storage::v3::EdgeTypeId &edge_type) { @@ -312,7 +339,7 @@ class DbAccessor final { return std::optional{}; } - return std::make_optional(*value); + return {std::make_optional(*value)}; } storage::v3::PropertyId NameToProperty(const std::string_view name) { return accessor_->NameToProperty(name); } @@ -361,6 +388,10 @@ class DbAccessor final { storage::v3::IndicesInfo ListAllIndices() const { return accessor_->ListAllIndices(); } storage::v3::ConstraintsInfo ListAllConstraints() const { return accessor_->ListAllConstraints(); } + + const storage::v3::SchemaValidator &GetSchemaValidator() const { return accessor_->GetSchemaValidator(); } + + storage::v3::SchemasInfo ListAllSchemas() const { return accessor_->ListAllSchemas(); } }; } // namespace memgraph::query::v2 diff --git a/src/query/v2/exceptions.hpp b/src/query/v2/exceptions.hpp index e0802a6cc..959672eae 100644 --- a/src/query/v2/exceptions.hpp +++ b/src/query/v2/exceptions.hpp @@ -224,4 +224,12 @@ class VersionInfoInMulticommandTxException : public QueryException { : QueryException("Version info query not allowed in multicommand transactions.") {} }; +/** + * An exception for an illegal operation that violates schema + */ +class SchemaViolationException : public QueryRuntimeException { + public: + using QueryRuntimeException::QueryRuntimeException; +}; + } // namespace memgraph::query::v2 diff --git a/src/query/v2/frontend/ast/ast.lcp b/src/query/v2/frontend/ast/ast.lcp index b858ab71f..023be58f2 100644 --- a/src/query/v2/frontend/ast/ast.lcp +++ b/src/query/v2/frontend/ast/ast.lcp @@ -134,6 +134,15 @@ cpp<# } cpp<#)) + +(defun clone-schema-property-vector (source dest) + #>cpp + ${dest}.reserve(${source}.size()); + for (const auto &[property_ix, property_type]: ${source}) { + ${dest}.emplace_back(storage->GetPropertyIx(property_ix.name), property_type); + } + cpp<#) + ;; The following index structs serve as a decoupling point of AST from ;; concrete database types. All the names are collected in AstStorage, and can ;; be indexed through these instances. This means that we can create a vector @@ -2256,7 +2265,7 @@ cpp<# (lcp:define-enum privilege (create delete match merge set remove index stats auth constraint dump replication durability read_file free_memory trigger config stream module_read module_write - websocket) + websocket schema) (:serialize)) #>cpp AuthQuery() = default; @@ -2298,7 +2307,7 @@ const std::vector kPrivilegesAll = { AuthQuery::Privilege::FREE_MEMORY, AuthQuery::Privilege::TRIGGER, AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM, AuthQuery::Privilege::MODULE_READ, AuthQuery::Privilege::MODULE_WRITE, - AuthQuery::Privilege::WEBSOCKET}; + AuthQuery::Privilege::WEBSOCKET, AuthQuery::Privilege::SCHEMA}; cpp<# (lcp:define-class info-query (query) @@ -2671,6 +2680,39 @@ cpp<# (:serialize (:slk)) (:clone)) +(lcp:define-class schema-query (query) + ((action "Action" :scope :public) + (label "LabelIx" :scope :public + :slk-load (lambda (member) + #>cpp + slk::Load(&self->${member}, reader, storage); + cpp<#) + :clone (lambda (source dest) + #>cpp + ${dest} = storage->GetLabelIx(${source}.name); + cpp<#)) + (schema_type_map "std::vector>" + :slk-save #'slk-save-property-map + :slk-load #'slk-load-property-map + :clone #'clone-schema-property-vector + :scope :public)) + + (:public + (lcp:define-enum action + (create-schema drop-schema show-schema show-schemas) + (:serialize)) + #>cpp + SchemaQuery() = default; + + DEFVISITABLE(QueryVisitor); + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + (lcp:pop-namespace) ;; namespace v2 (lcp:pop-namespace) ;; namespace query (lcp:pop-namespace) ;; namespace memgraph diff --git a/src/query/v2/frontend/ast/ast_visitor.hpp b/src/query/v2/frontend/ast/ast_visitor.hpp index 3cd7f9074..77c25cffb 100644 --- a/src/query/v2/frontend/ast/ast_visitor.hpp +++ b/src/query/v2/frontend/ast/ast_visitor.hpp @@ -94,6 +94,7 @@ class StreamQuery; class SettingQuery; class VersionQuery; class Foreach; +class SchemaQuery; using TreeCompositeVisitor = utils::CompositeVisitor< SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator, @@ -125,9 +126,9 @@ class ExpressionVisitor None, ParameterLookup, Identifier, PrimitiveLiteral, RegexMatch> {}; template -class QueryVisitor - : public utils::Visitor {}; +class QueryVisitor : public utils::Visitor {}; } // namespace memgraph::query::v2 diff --git a/src/query/v2/frontend/ast/cypher_main_visitor.cpp b/src/query/v2/frontend/ast/cypher_main_visitor.cpp index 976e3abfc..5c8304b5b 100644 --- a/src/query/v2/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/v2/frontend/ast/cypher_main_visitor.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -27,6 +28,7 @@ #include +#include "common/types.hpp" #include "query/v2/exceptions.hpp" #include "query/v2/frontend/ast/ast.hpp" #include "query/v2/frontend/ast/ast_visitor.hpp" @@ -275,18 +277,7 @@ antlrcpp::Any CypherMainVisitor::visitRegisterReplica(MemgraphCypher::RegisterRe replication_query->replica_name_ = std::any_cast(ctx->replicaName()->symbolicName()->accept(this)); if (ctx->SYNC()) { replication_query->sync_mode_ = memgraph::query::v2::ReplicationQuery::SyncMode::SYNC; - if (ctx->WITH() && ctx->TIMEOUT()) { - if (ctx->timeout->numberLiteral()) { - // we accept both double and integer literals - replication_query->timeout_ = std::any_cast(ctx->timeout->accept(this)); - } else { - throw SemanticException("Timeout should be a integer or double literal!"); - } - } } else if (ctx->ASYNC()) { - if (ctx->WITH() && ctx->TIMEOUT()) { - throw SyntaxException("Timeout can be set only for the SYNC replication mode!"); - } replication_query->sync_mode_ = memgraph::query::v2::ReplicationQuery::SyncMode::ASYNC; } @@ -1358,6 +1349,7 @@ antlrcpp::Any CypherMainVisitor::visitPrivilege(MemgraphCypher::PrivilegeContext if (ctx->MODULE_READ()) return AuthQuery::Privilege::MODULE_READ; if (ctx->MODULE_WRITE()) return AuthQuery::Privilege::MODULE_WRITE; if (ctx->WEBSOCKET()) return AuthQuery::Privilege::WEBSOCKET; + if (ctx->SCHEMA()) return AuthQuery::Privilege::SCHEMA; LOG_FATAL("Should not get here - unknown privilege!"); } @@ -2364,6 +2356,93 @@ antlrcpp::Any CypherMainVisitor::visitForeach(MemgraphCypher::ForeachContext *ct return for_each; } +antlrcpp::Any CypherMainVisitor::visitSchemaQuery(MemgraphCypher::SchemaQueryContext *ctx) { + MG_ASSERT(ctx->children.size() == 1, "SchemaQuery should have exactly one child!"); + auto *schema_query = std::any_cast(ctx->children[0]->accept(this)); + query_ = schema_query; + return schema_query; +} + +antlrcpp::Any CypherMainVisitor::visitShowSchema(MemgraphCypher::ShowSchemaContext *ctx) { + auto *schema_query = storage_->Create(); + schema_query->action_ = SchemaQuery::Action::SHOW_SCHEMA; + schema_query->label_ = AddLabel(std::any_cast(ctx->labelName()->accept(this))); + query_ = schema_query; + return schema_query; +} + +antlrcpp::Any CypherMainVisitor::visitShowSchemas(MemgraphCypher::ShowSchemasContext * /*ctx*/) { + auto *schema_query = storage_->Create(); + schema_query->action_ = SchemaQuery::Action::SHOW_SCHEMAS; + query_ = schema_query; + return schema_query; +} + +antlrcpp::Any CypherMainVisitor::visitPropertyType(MemgraphCypher::PropertyTypeContext *ctx) { + MG_ASSERT(ctx->symbolicName()); + const auto property_type = utils::ToLowerCase(std::any_cast(ctx->symbolicName()->accept(this))); + if (property_type == "bool") { + return common::SchemaType::BOOL; + } + if (property_type == "string") { + return common::SchemaType::STRING; + } + if (property_type == "integer") { + return common::SchemaType::INT; + } + if (property_type == "date") { + return common::SchemaType::DATE; + } + if (property_type == "duration") { + return common::SchemaType::DURATION; + } + if (property_type == "localdatetime") { + return common::SchemaType::LOCALDATETIME; + } + if (property_type == "localtime") { + return common::SchemaType::LOCALTIME; + } + throw SyntaxException("Property type must be one of the supported types!"); +} + +/** + * @return Schema* + */ +antlrcpp::Any CypherMainVisitor::visitSchemaPropertyMap(MemgraphCypher::SchemaPropertyMapContext *ctx) { + std::vector> schema_property_map; + for (auto *property_key_pair : ctx->propertyKeyTypePair()) { + auto key = std::any_cast(property_key_pair->propertyKeyName()->accept(this)); + auto type = std::any_cast(property_key_pair->propertyType()->accept(this)); + if (std::ranges::find_if(schema_property_map, [&key](const auto &elem) { return elem.first == key; }) != + schema_property_map.end()) { + throw SemanticException("Same property name can't appear twice in a schema map."); + } + schema_property_map.emplace_back(key, type); + } + return schema_property_map; +} + +antlrcpp::Any CypherMainVisitor::visitCreateSchema(MemgraphCypher::CreateSchemaContext *ctx) { + auto *schema_query = storage_->Create(); + schema_query->action_ = SchemaQuery::Action::CREATE_SCHEMA; + schema_query->label_ = AddLabel(std::any_cast(ctx->labelName()->accept(this))); + schema_query->schema_type_map_ = + std::any_cast>>(ctx->schemaPropertyMap()->accept(this)); + query_ = schema_query; + return schema_query; +} + +/** + * @return Schema* + */ +antlrcpp::Any CypherMainVisitor::visitDropSchema(MemgraphCypher::DropSchemaContext *ctx) { + auto *schema_query = storage_->Create(); + schema_query->action_ = SchemaQuery::Action::DROP_SCHEMA; + schema_query->label_ = AddLabel(std::any_cast(ctx->labelName()->accept(this))); + query_ = schema_query; + return schema_query; +} + LabelIx CypherMainVisitor::AddLabel(const std::string &name) { return storage_->GetLabelIx(name); } PropertyIx CypherMainVisitor::AddProperty(const std::string &name) { return storage_->GetPropertyIx(name); } diff --git a/src/query/v2/frontend/ast/cypher_main_visitor.hpp b/src/query/v2/frontend/ast/cypher_main_visitor.hpp index 767c8ce65..0052cd279 100644 --- a/src/query/v2/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/v2/frontend/ast/cypher_main_visitor.hpp @@ -849,6 +849,41 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { */ antlrcpp::Any visitForeach(MemgraphCypher::ForeachContext *ctx) override; + /** + * @return Schema* + */ + antlrcpp::Any visitPropertyType(MemgraphCypher::PropertyTypeContext *ctx) override; + + /** + * @return Schema* + */ + antlrcpp::Any visitSchemaPropertyMap(MemgraphCypher::SchemaPropertyMapContext *ctx) override; + + /** + * @return Schema* + */ + antlrcpp::Any visitSchemaQuery(MemgraphCypher::SchemaQueryContext *ctx) override; + + /** + * @return Schema* + */ + antlrcpp::Any visitShowSchema(MemgraphCypher::ShowSchemaContext *ctx) override; + + /** + * @return Schema* + */ + antlrcpp::Any visitShowSchemas(MemgraphCypher::ShowSchemasContext *ctx) override; + + /** + * @return Schema* + */ + antlrcpp::Any visitCreateSchema(MemgraphCypher::CreateSchemaContext *ctx) override; + + /** + * @return Schema* + */ + antlrcpp::Any visitDropSchema(MemgraphCypher::DropSchemaContext *ctx) override; + public: Query *query() { return query_; } const static std::string kAnonPrefix; diff --git a/src/query/v2/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/v2/frontend/opencypher/grammar/MemgraphCypher.g4 index b412a474a..956320bbf 100644 --- a/src/query/v2/frontend/opencypher/grammar/MemgraphCypher.g4 +++ b/src/query/v2/frontend/opencypher/grammar/MemgraphCypher.g4 @@ -46,10 +46,10 @@ memgraphCypherKeyword : cypherKeyword | DROP | DUMP | EXECUTE - | FOR - | FOREACH | FREE | FROM + | FOR + | FOREACH | GLOBAL | GRANT | HEADER @@ -76,6 +76,8 @@ memgraphCypherKeyword : cypherKeyword | ROLE | ROLES | QUOTE + | SCHEMA + | SCHEMAS | SESSION | SETTING | SETTINGS @@ -122,6 +124,7 @@ query : cypherQuery | streamQuery | settingQuery | versionQuery + | schemaQuery ; authQuery : createRole @@ -192,6 +195,12 @@ settingQuery : setSetting | showSettings ; +schemaQuery : showSchema + | showSchemas + | createSchema + | dropSchema + ; + loadCsv : LOAD CSV FROM csvFile ( WITH | NO ) HEADER ( IGNORE BAD ) ? ( DELIMITER delimiter ) ? @@ -254,6 +263,7 @@ privilege : CREATE | MODULE_READ | MODULE_WRITE | WEBSOCKET + | SCHEMA ; privilegeList : privilege ( ',' privilege )* ; @@ -276,7 +286,6 @@ replicaName : symbolicName ; socketAddress : literal ; registerReplica : REGISTER REPLICA replicaName ( SYNC | ASYNC ) - ( WITH TIMEOUT timeout=literal ) ? TO socketAddress ; dropReplica : DROP REPLICA replicaName ; @@ -374,3 +383,17 @@ showSetting : SHOW DATABASE SETTING settingName ; showSettings : SHOW DATABASE SETTINGS ; versionQuery : SHOW VERSION ; + +showSchema : SHOW SCHEMA ON ':' labelName ; + +showSchemas : SHOW SCHEMAS ; + +propertyType : symbolicName ; + +propertyKeyTypePair : propertyKeyName propertyType ; + +schemaPropertyMap : '(' propertyKeyTypePair ( ',' propertyKeyTypePair )* ')' ; + +createSchema : CREATE SCHEMA ON ':' labelName schemaPropertyMap ; + +dropSchema : DROP SCHEMA ON ':' labelName ; diff --git a/src/query/v2/frontend/opencypher/grammar/MemgraphCypherLexer.g4 b/src/query/v2/frontend/opencypher/grammar/MemgraphCypherLexer.g4 index 55e5d53a2..869141033 100644 --- a/src/query/v2/frontend/opencypher/grammar/MemgraphCypherLexer.g4 +++ b/src/query/v2/frontend/opencypher/grammar/MemgraphCypherLexer.g4 @@ -89,6 +89,8 @@ REVOKE : R E V O K E ; ROLE : R O L E ; ROLES : R O L E S ; QUOTE : Q U O T E ; +SCHEMA : S C H E M A ; +SCHEMAS : S C H E M A S ; SERVICE_URL : S E R V I C E UNDERSCORE U R L ; SESSION : S E S S I O N ; SETTING : S E T T I N G ; diff --git a/src/query/v2/frontend/semantic/required_privileges.cpp b/src/query/v2/frontend/semantic/required_privileges.cpp index 0790529cf..df160fac1 100644 --- a/src/query/v2/frontend/semantic/required_privileges.cpp +++ b/src/query/v2/frontend/semantic/required_privileges.cpp @@ -80,6 +80,8 @@ class PrivilegeExtractor : public QueryVisitor, public HierarchicalTreeVis void Visit(VersionQuery & /*version_query*/) override { AddPrivilege(AuthQuery::Privilege::STATS); } + void Visit(SchemaQuery & /*schema_query*/) override { AddPrivilege(AuthQuery::Privilege::SCHEMA); } + bool PreVisit(Create & /*unused*/) override { AddPrivilege(AuthQuery::Privilege::CREATE); return false; diff --git a/src/query/v2/frontend/stripped_lexer_constants.hpp b/src/query/v2/frontend/stripped_lexer_constants.hpp index 4e52fbdc4..df52066fc 100644 --- a/src/query/v2/frontend/stripped_lexer_constants.hpp +++ b/src/query/v2/frontend/stripped_lexer_constants.hpp @@ -204,8 +204,9 @@ const trie::Trie kKeywords = {"union", "pulsar", "service_url", "version", - "websocket" - "foreach"}; + "websocket", + "foreach", + "schema"}; // Unicode codepoints that are allowed at the start of the unescaped name. const std::bitset kUnescapedNameAllowedStarts( diff --git a/src/query/v2/interpreter.cpp b/src/query/v2/interpreter.cpp index 73f3e00be..c65b5e971 100644 --- a/src/query/v2/interpreter.cpp +++ b/src/query/v2/interpreter.cpp @@ -877,6 +877,102 @@ Callback HandleSettingQuery(SettingQuery *setting_query, const Parameters ¶m } } +Callback HandleSchemaQuery(SchemaQuery *schema_query, InterpreterContext *interpreter_context, + std::vector *notifications) { + Callback callback; + switch (schema_query->action_) { + case SchemaQuery::Action::SHOW_SCHEMAS: { + callback.header = {"label", "primary_key"}; + callback.fn = [interpreter_context]() { + auto *db = interpreter_context->db; + auto schemas_info = db->ListAllSchemas(); + std::vector> results; + results.reserve(schemas_info.schemas.size()); + + for (const auto &[label_id, schema_types] : schemas_info.schemas) { + std::vector schema_info_row; + schema_info_row.reserve(3); + + schema_info_row.emplace_back(db->LabelToName(label_id)); + std::vector primary_key_properties; + primary_key_properties.reserve(schema_types.size()); + std::transform(schema_types.begin(), schema_types.end(), std::back_inserter(primary_key_properties), + [&db](const auto &schema_type) { + return db->PropertyToName(schema_type.property_id) + + "::" + storage::v3::SchemaTypeToString(schema_type.type); + }); + + schema_info_row.emplace_back(utils::Join(primary_key_properties, ", ")); + results.push_back(std::move(schema_info_row)); + } + return results; + }; + return callback; + } + case SchemaQuery::Action::SHOW_SCHEMA: { + callback.header = {"property_name", "property_type"}; + callback.fn = [interpreter_context, primary_label = schema_query->label_]() { + auto *db = interpreter_context->db; + const auto label = db->NameToLabel(primary_label.name); + const auto *schema = db->GetSchema(label); + std::vector> results; + if (schema) { + for (const auto &schema_property : schema->second) { + std::vector schema_info_row; + schema_info_row.reserve(2); + schema_info_row.emplace_back(db->PropertyToName(schema_property.property_id)); + schema_info_row.emplace_back(storage::v3::SchemaTypeToString(schema_property.type)); + results.push_back(std::move(schema_info_row)); + } + return results; + } + throw QueryException(fmt::format("Schema on label :{} not found!", primary_label.name)); + }; + return callback; + } + case SchemaQuery::Action::CREATE_SCHEMA: { + auto schema_type_map = schema_query->schema_type_map_; + if (schema_query->schema_type_map_.empty()) { + throw SyntaxException("One or more types have to be defined in schema definition."); + } + callback.fn = [interpreter_context, primary_label = schema_query->label_, + schema_type_map = std::move(schema_type_map)]() { + auto *db = interpreter_context->db; + const auto label = db->NameToLabel(primary_label.name); + std::vector schemas_types; + schemas_types.reserve(schema_type_map.size()); + for (const auto &schema_type : schema_type_map) { + auto property_id = db->NameToProperty(schema_type.first.name); + schemas_types.push_back({property_id, schema_type.second}); + } + if (!db->CreateSchema(label, schemas_types)) { + throw QueryException(fmt::format("Schema on label :{} already exists!", primary_label.name)); + } + return std::vector>{}; + }; + notifications->emplace_back(SeverityLevel::INFO, NotificationCode::CREATE_SCHEMA, + fmt::format("Create schema on label :{}", schema_query->label_.name)); + return callback; + } + case SchemaQuery::Action::DROP_SCHEMA: { + callback.fn = [interpreter_context, primary_label = schema_query->label_]() { + auto *db = interpreter_context->db; + const auto label = db->NameToLabel(primary_label.name); + + if (!db->DropSchema(label)) { + throw QueryException(fmt::format("Schema on label :{} does not exist!", primary_label.name)); + } + + return std::vector>{}; + }; + notifications->emplace_back(SeverityLevel::INFO, NotificationCode::DROP_SCHEMA, + fmt::format("Dropped schema on label :{}", schema_query->label_.name)); + return callback; + } + } + return callback; +} + // Struct for lazy pulling from a vector struct PullPlanVector { explicit PullPlanVector(std::vector> values) : values_(std::move(values)) {} @@ -2072,6 +2168,32 @@ PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_ RWType::NONE}; } +PreparedQuery PrepareSchemaQuery(ParsedQuery parsed_query, bool in_explicit_transaction, + InterpreterContext *interpreter_context, std::vector *notifications) { + if (in_explicit_transaction) { + throw ConstraintInMulticommandTxException(); + } + auto *schema_query = utils::Downcast(parsed_query.query); + MG_ASSERT(schema_query); + auto callback = HandleSchemaQuery(schema_query, interpreter_context, notifications); + + return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), + [handler = std::move(callback.fn), action = QueryHandlerResult::NOTHING, + pull_plan = std::shared_ptr(nullptr)]( + AnyStream *stream, std::optional n) mutable -> std::optional { + if (!pull_plan) { + auto results = handler(); + pull_plan = std::make_shared(std::move(results)); + } + + if (pull_plan->Pull(stream, n)) { + return action; + } + return std::nullopt; + }, + RWType::NONE}; +} + void Interpreter::BeginTransaction() { const auto prepared_query = PrepareTransactionQuery("BEGIN"); prepared_query.query_handler(nullptr, {}); @@ -2205,6 +2327,9 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, prepared_query = PrepareSettingQuery(std::move(parsed_query), in_explicit_transaction_, &*execution_db_accessor_); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareVersionQuery(std::move(parsed_query), in_explicit_transaction_); + } else if (utils::Downcast(parsed_query.query)) { + prepared_query = PrepareSchemaQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_, + &query_execution->notifications); } else { LOG_FATAL("Should not get here -- unknown query type!"); } diff --git a/src/query/v2/metadata.cpp b/src/query/v2/metadata.cpp index fe7461e79..7735edacf 100644 --- a/src/query/v2/metadata.cpp +++ b/src/query/v2/metadata.cpp @@ -38,6 +38,8 @@ constexpr std::string_view GetCodeString(const NotificationCode code) { return "CreateIndex"sv; case NotificationCode::CREATE_STREAM: return "CreateStream"sv; + case NotificationCode::CREATE_SCHEMA: + return "CreateSchema"sv; case NotificationCode::CHECK_STREAM: return "CheckStream"sv; case NotificationCode::CREATE_TRIGGER: @@ -48,6 +50,8 @@ constexpr std::string_view GetCodeString(const NotificationCode code) { return "DropReplica"sv; case NotificationCode::DROP_INDEX: return "DropIndex"sv; + case NotificationCode::DROP_SCHEMA: + return "DropSchema"sv; case NotificationCode::DROP_STREAM: return "DropStream"sv; case NotificationCode::DROP_TRIGGER: @@ -68,6 +72,10 @@ constexpr std::string_view GetCodeString(const NotificationCode code) { return "ReplicaPortWarning"sv; case NotificationCode::SET_REPLICA: return "SetReplica"sv; + case NotificationCode::SHOW_SCHEMA: + return "ShowSchema"sv; + case NotificationCode::SHOW_SCHEMAS: + return "ShowSchemas"sv; case NotificationCode::START_STREAM: return "StartStream"sv; case NotificationCode::START_ALL_STREAMS: diff --git a/src/query/v2/metadata.hpp b/src/query/v2/metadata.hpp index ffc621d64..4a5e7b9c9 100644 --- a/src/query/v2/metadata.hpp +++ b/src/query/v2/metadata.hpp @@ -26,12 +26,14 @@ enum class SeverityLevel : uint8_t { INFO, WARNING }; enum class NotificationCode : uint8_t { CREATE_CONSTRAINT, CREATE_INDEX, + CREATE_SCHEMA, CHECK_STREAM, CREATE_STREAM, CREATE_TRIGGER, DROP_CONSTRAINT, DROP_INDEX, DROP_REPLICA, + DROP_SCHEMA, DROP_STREAM, DROP_TRIGGER, EXISTANT_INDEX, @@ -42,6 +44,8 @@ enum class NotificationCode : uint8_t { REPLICA_PORT_WARNING, REGISTER_REPLICA, SET_REPLICA, + SHOW_SCHEMA, + SHOW_SCHEMAS, START_STREAM, START_ALL_STREAMS, STOP_STREAM, diff --git a/src/query/v2/plan/operator.cpp b/src/query/v2/plan/operator.cpp index 4dd0bf693..432a8c7e8 100644 --- a/src/query/v2/plan/operator.cpp +++ b/src/query/v2/plan/operator.cpp @@ -52,6 +52,7 @@ #include "utils/readable_size.hpp" #include "utils/string.hpp" #include "utils/temporal.hpp" +#include "utils/variant_helpers.hpp" // macro for the default implementation of LogicalOperator::Accept // that accepts the visitor and visits it's input_ operator @@ -174,45 +175,56 @@ CreateNode::CreateNode(const std::shared_ptr &input, const Node // Creates a vertex on this GraphDb. Returns a reference to vertex placed on the // frame. -VertexAccessor &CreateLocalVertex(const NodeCreationInfo &node_info, Frame *frame, ExecutionContext &context) { +VertexAccessor &CreateLocalVertexAtomically(const NodeCreationInfo &node_info, Frame *frame, + ExecutionContext &context) { auto &dba = *context.db_accessor; - auto new_node = dba.InsertVertex(); - context.execution_stats[ExecutionStats::Key::CREATED_NODES] += 1; - for (auto label : node_info.labels) { - auto maybe_error = new_node.AddLabel(label); - if (maybe_error.HasError()) { - switch (maybe_error.GetError()) { - case storage::v3::Error::SERIALIZATION_ERROR: - throw TransactionSerializationException(); - case storage::v3::Error::DELETED_OBJECT: - throw QueryRuntimeException("Trying to set a label on a deleted node."); - case storage::v3::Error::VERTEX_HAS_EDGES: - case storage::v3::Error::PROPERTIES_DISABLED: - case storage::v3::Error::NONEXISTENT_OBJECT: - throw QueryRuntimeException("Unexpected error when setting a label."); - } - } - context.execution_stats[ExecutionStats::Key::CREATED_LABELS] += 1; - } // Evaluator should use the latest accessors, as modified in this query, when // setting properties on new nodes. ExpressionEvaluator evaluator(frame, context.symbol_table, context.evaluation_context, context.db_accessor, storage::v3::View::NEW); - // TODO: PropsSetChecked allocates a PropertyValue, make it use context.memory - // when we update PropertyValue with custom allocator. + + std::vector> properties; if (const auto *node_info_properties = std::get_if(&node_info.properties)) { + properties.reserve(node_info_properties->size()); for (const auto &[key, value_expression] : *node_info_properties) { - PropsSetChecked(&new_node, key, value_expression->Accept(evaluator)); + properties.emplace_back(key, storage::v3::PropertyValue(value_expression->Accept(evaluator))); } } else { - auto property_map = evaluator.Visit(*std::get(node_info.properties)); - for (const auto &[key, value] : property_map.ValueMap()) { + auto property_map = evaluator.Visit(*std::get(node_info.properties)).ValueMap(); + properties.reserve(property_map.size()); + + for (const auto &[key, value] : property_map) { auto property_id = dba.NameToProperty(key); - PropsSetChecked(&new_node, property_id, value); + properties.emplace_back(property_id, value); } } + // TODO Remove later on since that will be enforced from grammar side + MG_ASSERT(!node_info.labels.empty(), "There must be at least one label!"); + const auto primary_label = node_info.labels[0]; + std::vector secondary_labels(node_info.labels.begin() + 1, node_info.labels.end()); + auto maybe_new_node = dba.InsertVertexAndValidate(primary_label, secondary_labels, properties); + if (maybe_new_node.HasError()) { + std::visit(utils::Overloaded{[&dba](const storage::v3::SchemaViolation &schema_violation) { + HandleSchemaViolation(schema_violation, dba); + }, + [](const storage::v3::Error error) { + switch (error) { + case storage::v3::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to set a label on a deleted node."); + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + case storage::v3::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when setting a label."); + } + }}, + maybe_new_node.GetError()); + } - (*frame)[node_info.symbol] = new_node; + context.execution_stats[ExecutionStats::Key::CREATED_NODES] += 1; + + (*frame)[node_info.symbol] = *maybe_new_node; return (*frame)[node_info.symbol].ValueVertex(); } @@ -237,7 +249,7 @@ bool CreateNode::CreateNodeCursor::Pull(Frame &frame, ExecutionContext &context) SCOPED_PROFILE_OP("CreateNode"); if (input_cursor_->Pull(frame, context)) { - auto created_vertex = CreateLocalVertex(self_.node_info_, &frame, context); + auto created_vertex = CreateLocalVertexAtomically(self_.node_info_, &frame, context); if (context.trigger_context_collector) { context.trigger_context_collector->RegisterCreatedObject(created_vertex); } @@ -286,13 +298,13 @@ EdgeAccessor CreateEdge(const EdgeCreationInfo &edge_info, DbAccessor *dba, Vert auto &edge = *maybe_edge; if (const auto *properties = std::get_if(&edge_info.properties)) { for (const auto &[key, value_expression] : *properties) { - PropsSetChecked(&edge, key, value_expression->Accept(*evaluator)); + PropsSetChecked(&edge, *dba, key, value_expression->Accept(*evaluator)); } } else { auto property_map = evaluator->Visit(*std::get(edge_info.properties)); for (const auto &[key, value] : property_map.ValueMap()) { auto property_id = dba->NameToProperty(key); - PropsSetChecked(&edge, property_id, value); + PropsSetChecked(&edge, *dba, property_id, value); } } @@ -368,13 +380,12 @@ VertexAccessor &CreateExpand::CreateExpandCursor::OtherVertex(Frame &frame, Exec TypedValue &dest_node_value = frame[self_.node_info_.symbol]; ExpectType(self_.node_info_.symbol, dest_node_value, TypedValue::Type::Vertex); return dest_node_value.ValueVertex(); - } else { - auto &created_vertex = CreateLocalVertex(self_.node_info_, &frame, context); - if (context.trigger_context_collector) { - context.trigger_context_collector->RegisterCreatedObject(created_vertex); - } - return created_vertex; } + auto &created_vertex = CreateLocalVertexAtomically(self_.node_info_, &frame, context); + if (context.trigger_context_collector) { + context.trigger_context_collector->RegisterCreatedObject(created_vertex); + } + return created_vertex; } template @@ -2050,7 +2061,7 @@ bool SetProperty::SetPropertyCursor::Pull(Frame &frame, ExecutionContext &contex switch (lhs.type()) { case TypedValue::Type::Vertex: { - auto old_value = PropsSetChecked(&lhs.ValueVertex(), self_.property_, rhs); + auto old_value = PropsSetChecked(&lhs.ValueVertex(), *context.db_accessor, self_.property_, rhs); context.execution_stats[ExecutionStats::Key::UPDATED_PROPERTIES] += 1; if (context.trigger_context_collector) { // rhs cannot be moved because it was created with the allocator that is only valid during current pull @@ -2060,7 +2071,7 @@ bool SetProperty::SetPropertyCursor::Pull(Frame &frame, ExecutionContext &contex break; } case TypedValue::Type::Edge: { - auto old_value = PropsSetChecked(&lhs.ValueEdge(), self_.property_, rhs); + auto old_value = PropsSetChecked(&lhs.ValueEdge(), *context.db_accessor, self_.property_, rhs); context.execution_stats[ExecutionStats::Key::UPDATED_PROPERTIES] += 1; if (context.trigger_context_collector) { // rhs cannot be moved because it was created with the allocator that is only valid during current pull @@ -2216,7 +2227,7 @@ void SetPropertiesOnRecord(TRecordAccessor *record, const TypedValue &rhs, SetPr case TypedValue::Type::Map: { for (const auto &kv : rhs.ValueMap()) { auto key = context->db_accessor->NameToProperty(kv.first); - auto old_value = PropsSetChecked(record, key, kv.second); + auto old_value = PropsSetChecked(record, *context->db_accessor, key, kv.second); if (should_register_change) { register_set_property(std::move(old_value), key, kv.second); } @@ -2300,22 +2311,31 @@ bool SetLabels::SetLabelsCursor::Pull(Frame &frame, ExecutionContext &context) { // Skip setting labels on Null (can occur in optional match). if (vertex_value.IsNull()) return true; ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); + + auto &dba = *context.db_accessor; auto &vertex = vertex_value.ValueVertex(); - for (auto label : self_.labels_) { - auto maybe_value = vertex.AddLabel(label); + for (const auto label : self_.labels_) { + auto maybe_value = vertex.AddLabelAndValidate(label); if (maybe_value.HasError()) { - switch (maybe_value.GetError()) { - case storage::v3::Error::SERIALIZATION_ERROR: - throw TransactionSerializationException(); - case storage::v3::Error::DELETED_OBJECT: - throw QueryRuntimeException("Trying to set a label on a deleted node."); - case storage::v3::Error::VERTEX_HAS_EDGES: - case storage::v3::Error::PROPERTIES_DISABLED: - case storage::v3::Error::NONEXISTENT_OBJECT: - throw QueryRuntimeException("Unexpected error when setting a label."); - } + std::visit(utils::Overloaded{[](const storage::v3::Error error) { + switch (error) { + case storage::v3::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to set a label on a deleted node."); + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + case storage::v3::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when setting a label."); + } + }, + [&dba](const storage::v3::SchemaViolation schema_violation) { + HandleSchemaViolation(schema_violation, dba); + }}, + maybe_value.GetError()); } + context.execution_stats[ExecutionStats::Key::CREATED_LABELS]++; if (context.trigger_context_collector && *maybe_value) { context.trigger_context_collector->RegisterSetVertexLabel(vertex, label); } @@ -2358,26 +2378,11 @@ bool RemoveProperty::RemovePropertyCursor::Pull(Frame &frame, ExecutionContext & TypedValue lhs = self_.lhs_->expression_->Accept(evaluator); auto remove_prop = [property = self_.property_, &context](auto *record) { - auto maybe_old_value = record->RemoveProperty(property); - if (maybe_old_value.HasError()) { - switch (maybe_old_value.GetError()) { - case storage::v3::Error::DELETED_OBJECT: - throw QueryRuntimeException("Trying to remove a property on a deleted graph element."); - case storage::v3::Error::SERIALIZATION_ERROR: - throw TransactionSerializationException(); - case storage::v3::Error::PROPERTIES_DISABLED: - throw QueryRuntimeException( - "Can't remove property because properties on edges are " - "disabled."); - case storage::v3::Error::VERTEX_HAS_EDGES: - case storage::v3::Error::NONEXISTENT_OBJECT: - throw QueryRuntimeException("Unexpected error when removing property."); - } - } + auto old_value = PropsSetChecked(record, *context.db_accessor, property, TypedValue{}); if (context.trigger_context_collector) { context.trigger_context_collector->RegisterRemovedObjectProperty(*record, property, - TypedValue(std::move(*maybe_old_value))); + TypedValue(std::move(old_value))); } }; @@ -2431,18 +2436,25 @@ bool RemoveLabels::RemoveLabelsCursor::Pull(Frame &frame, ExecutionContext &cont ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); auto &vertex = vertex_value.ValueVertex(); for (auto label : self_.labels_) { - auto maybe_value = vertex.RemoveLabel(label); + auto maybe_value = vertex.RemoveLabelAndValidate(label); if (maybe_value.HasError()) { - switch (maybe_value.GetError()) { - case storage::v3::Error::SERIALIZATION_ERROR: - throw TransactionSerializationException(); - case storage::v3::Error::DELETED_OBJECT: - throw QueryRuntimeException("Trying to remove labels from a deleted node."); - case storage::v3::Error::VERTEX_HAS_EDGES: - case storage::v3::Error::PROPERTIES_DISABLED: - case storage::v3::Error::NONEXISTENT_OBJECT: - throw QueryRuntimeException("Unexpected error when removing labels from a node."); - } + std::visit( + utils::Overloaded{[](const storage::v3::Error error) { + switch (error) { + case storage::v3::Error::SERIALIZATION_ERROR: + throw TransactionSerializationException(); + case storage::v3::Error::DELETED_OBJECT: + throw QueryRuntimeException("Trying to remove labels from a deleted node."); + case storage::v3::Error::VERTEX_HAS_EDGES: + case storage::v3::Error::PROPERTIES_DISABLED: + case storage::v3::Error::NONEXISTENT_OBJECT: + throw QueryRuntimeException("Unexpected error when removing labels from a node."); + } + }, + [&context](const storage::v3::SchemaViolation &schema_violation) { + HandleSchemaViolation(schema_violation, *context.db_accessor); + }}, + maybe_value.GetError()); } context.execution_stats[ExecutionStats::Key::DELETED_LABELS] += 1; diff --git a/src/storage/v2/CMakeLists.txt b/src/storage/v2/CMakeLists.txt index 52ebfdd1f..d46563183 100644 --- a/src/storage/v2/CMakeLists.txt +++ b/src/storage/v2/CMakeLists.txt @@ -10,8 +10,6 @@ set(storage_v2_src_files indices.cpp property_store.cpp vertex_accessor.cpp - schemas.cpp - schema_validator.cpp storage.cpp) ##### Replication ##### diff --git a/src/storage/v2/constraints.cpp b/src/storage/v2/constraints.cpp index 5e5988099..fab6ee4c4 100644 --- a/src/storage/v2/constraints.cpp +++ b/src/storage/v2/constraints.cpp @@ -16,7 +16,6 @@ #include #include "storage/v2/mvcc.hpp" -#include "storage/v2/vertex.hpp" #include "utils/logging.hpp" namespace memgraph::storage { @@ -60,7 +59,7 @@ bool LastCommittedVersionHasLabelProperty(const Vertex &vertex, LabelId label, c std::lock_guard guard(vertex.lock); delta = vertex.delta; deleted = vertex.deleted; - has_label = VertexHasLabel(vertex, label); + has_label = utils::Contains(vertex.labels, label); size_t i = 0; for (const auto &property : properties) { @@ -143,7 +142,7 @@ bool AnyVersionHasLabelProperty(const Vertex &vertex, LabelId label, const std:: Delta *delta; { std::lock_guard guard(vertex.lock); - has_label = VertexHasLabel(vertex, label); + has_label = utils::Contains(vertex.labels, label); deleted = vertex.deleted; delta = vertex.delta; @@ -268,7 +267,7 @@ bool UniqueConstraints::Entry::operator==(const std::vector &rhs) void UniqueConstraints::UpdateBeforeCommit(const Vertex *vertex, const Transaction &tx) { for (auto &[label_props, storage] : constraints_) { - if (!VertexHasLabel(*vertex, label_props.first)) { + if (!utils::Contains(vertex->labels, label_props.first)) { continue; } auto values = ExtractPropertyValues(*vertex, label_props.second); @@ -302,7 +301,7 @@ utils::BasicResult Uniqu auto acc = constraint->second.access(); for (const Vertex &vertex : vertices) { - if (vertex.deleted || !VertexHasLabel(vertex, label)) { + if (vertex.deleted || !utils::Contains(vertex.labels, label)) { continue; } auto values = ExtractPropertyValues(vertex, properties); @@ -353,7 +352,7 @@ std::optional UniqueConstraints::Validate(const Vertex &ver for (const auto &[label_props, storage] : constraints_) { const auto &label = label_props.first; const auto &properties = label_props.second; - if (!VertexHasLabel(vertex, label)) { + if (!utils::Contains(vertex.labels, label)) { continue; } diff --git a/src/storage/v2/constraints.hpp b/src/storage/v2/constraints.hpp index 427b6ca4f..b209437f8 100644 --- a/src/storage/v2/constraints.hpp +++ b/src/storage/v2/constraints.hpp @@ -158,7 +158,7 @@ inline utils::BasicResult CreateExistenceConstraint( return false; } for (const auto &vertex : vertices) { - if (!vertex.deleted && VertexHasLabel(vertex, label) && !vertex.properties.HasProperty(property)) { + if (!vertex.deleted && utils::Contains(vertex.labels, label) && !vertex.properties.HasProperty(property)) { return ConstraintViolation{ConstraintViolation::Type::EXISTENCE, label, std::set{property}}; } } @@ -184,7 +184,7 @@ inline bool DropExistenceConstraint(Constraints *constraints, LabelId label, Pro [[nodiscard]] inline std::optional ValidateExistenceConstraints(const Vertex &vertex, const Constraints &constraints) { for (const auto &[label, property] : constraints.existence_constraints) { - if (!vertex.deleted && VertexHasLabel(vertex, label) && !vertex.properties.HasProperty(property)) { + if (!vertex.deleted && utils::Contains(vertex.labels, label) && !vertex.properties.HasProperty(property)) { return ConstraintViolation{ConstraintViolation::Type::EXISTENCE, label, std::set{property}}; } } diff --git a/src/storage/v2/durability/snapshot.cpp b/src/storage/v2/durability/snapshot.cpp index 002da25fa..16c7d017c 100644 --- a/src/storage/v2/durability/snapshot.cpp +++ b/src/storage/v2/durability/snapshot.cpp @@ -628,9 +628,8 @@ RecoveredSnapshot LoadSnapshot(const std::filesystem::path &path, utils::SkipLis void CreateSnapshot(Transaction *transaction, const std::filesystem::path &snapshot_directory, const std::filesystem::path &wal_directory, uint64_t snapshot_retention_count, utils::SkipList *vertices, utils::SkipList *edges, NameIdMapper *name_id_mapper, - Indices *indices, Constraints *constraints, Config::Items items, - const SchemaValidator &schema_validator, const std::string &uuid, const std::string_view epoch_id, - const std::deque> &epoch_history, + Indices *indices, Constraints *constraints, Config::Items items, const std::string &uuid, + const std::string_view epoch_id, const std::deque> &epoch_history, utils::FileRetainer *file_retainer) { // Ensure that the storage directory exists. utils::EnsureDirOrDie(snapshot_directory); @@ -714,9 +713,8 @@ void CreateSnapshot(Transaction *transaction, const std::filesystem::path &snaps // type and invalid from/to pointers because we don't know them here, // but that isn't an issue because we won't use that part of the API // here. - // TODO(jbajic) Fix snapshot with new schema rules - auto ea = EdgeAccessor{edge_ref, EdgeTypeId::FromUint(0UL), nullptr, nullptr, transaction, indices, constraints, - items, schema_validator}; + auto ea = + EdgeAccessor{edge_ref, EdgeTypeId::FromUint(0UL), nullptr, nullptr, transaction, indices, constraints, items}; // Get edge data. auto maybe_props = ea.Properties(View::OLD); @@ -744,7 +742,7 @@ void CreateSnapshot(Transaction *transaction, const std::filesystem::path &snaps auto acc = vertices->access(); for (auto &vertex : acc) { // The visibility check is implemented for vertices so we use it here. - auto va = VertexAccessor::Create(&vertex, transaction, indices, constraints, items, schema_validator, View::OLD); + auto va = VertexAccessor::Create(&vertex, transaction, indices, constraints, items, View::OLD); if (!va) continue; // Get vertex data. diff --git a/src/storage/v2/durability/snapshot.hpp b/src/storage/v2/durability/snapshot.hpp index 643c1a34c..b1cfad63c 100644 --- a/src/storage/v2/durability/snapshot.hpp +++ b/src/storage/v2/durability/snapshot.hpp @@ -21,7 +21,6 @@ #include "storage/v2/edge.hpp" #include "storage/v2/indices.hpp" #include "storage/v2/name_id_mapper.hpp" -#include "storage/v2/schema_validator.hpp" #include "storage/v2/transaction.hpp" #include "storage/v2/vertex.hpp" #include "utils/file_locker.hpp" @@ -69,9 +68,8 @@ RecoveredSnapshot LoadSnapshot(const std::filesystem::path &path, utils::SkipLis void CreateSnapshot(Transaction *transaction, const std::filesystem::path &snapshot_directory, const std::filesystem::path &wal_directory, uint64_t snapshot_retention_count, utils::SkipList *vertices, utils::SkipList *edges, NameIdMapper *name_id_mapper, - Indices *indices, Constraints *constraints, Config::Items items, - const SchemaValidator &schema_validator, const std::string &uuid, std::string_view epoch_id, - const std::deque> &epoch_history, + Indices *indices, Constraints *constraints, Config::Items items, const std::string &uuid, + std::string_view epoch_id, const std::deque> &epoch_history, utils::FileRetainer *file_retainer); } // namespace memgraph::storage::durability diff --git a/src/storage/v2/edge_accessor.cpp b/src/storage/v2/edge_accessor.cpp index acb3ec288..ef0444422 100644 --- a/src/storage/v2/edge_accessor.cpp +++ b/src/storage/v2/edge_accessor.cpp @@ -15,7 +15,6 @@ #include "storage/v2/mvcc.hpp" #include "storage/v2/property_value.hpp" -#include "storage/v2/schema_validator.hpp" #include "storage/v2/vertex_accessor.hpp" #include "utils/memory_tracker.hpp" @@ -55,11 +54,11 @@ bool EdgeAccessor::IsVisible(const View view) const { } VertexAccessor EdgeAccessor::FromVertex() const { - return VertexAccessor{from_vertex_, transaction_, indices_, constraints_, config_, *schema_validator_}; + return VertexAccessor{from_vertex_, transaction_, indices_, constraints_, config_}; } VertexAccessor EdgeAccessor::ToVertex() const { - return VertexAccessor{to_vertex_, transaction_, indices_, constraints_, config_, *schema_validator_}; + return VertexAccessor{to_vertex_, transaction_, indices_, constraints_, config_}; } Result EdgeAccessor::SetProperty(PropertyId property, const PropertyValue &value) { diff --git a/src/storage/v2/edge_accessor.hpp b/src/storage/v2/edge_accessor.hpp index 11abad3a1..b0a1e1151 100644 --- a/src/storage/v2/edge_accessor.hpp +++ b/src/storage/v2/edge_accessor.hpp @@ -18,7 +18,6 @@ #include "storage/v2/config.hpp" #include "storage/v2/result.hpp" -#include "storage/v2/schema_validator.hpp" #include "storage/v2/transaction.hpp" #include "storage/v2/view.hpp" @@ -35,8 +34,7 @@ class EdgeAccessor final { public: EdgeAccessor(EdgeRef edge, EdgeTypeId edge_type, Vertex *from_vertex, Vertex *to_vertex, Transaction *transaction, - Indices *indices, Constraints *constraints, Config::Items config, - const SchemaValidator &schema_validator, bool for_deleted = false) + Indices *indices, Constraints *constraints, Config::Items config, bool for_deleted = false) : edge_(edge), edge_type_(edge_type), from_vertex_(from_vertex), @@ -45,7 +43,6 @@ class EdgeAccessor final { indices_(indices), constraints_(constraints), config_(config), - schema_validator_{&schema_validator}, for_deleted_(for_deleted) {} /// @return true if the object is visible from the current transaction @@ -95,7 +92,6 @@ class EdgeAccessor final { Indices *indices_; Constraints *constraints_; Config::Items config_; - const SchemaValidator *schema_validator_; // if the accessor was created for a deleted edge. // Accessor behaves differently for some methods based on this diff --git a/src/storage/v2/indices.cpp b/src/storage/v2/indices.cpp index 2a1c1ff02..fb83ff166 100644 --- a/src/storage/v2/indices.cpp +++ b/src/storage/v2/indices.cpp @@ -14,7 +14,6 @@ #include "storage/v2/mvcc.hpp" #include "storage/v2/property_value.hpp" -#include "storage/v2/schema_validator.hpp" #include "utils/bound.hpp" #include "utils/logging.hpp" #include "utils/memory_tracker.hpp" @@ -328,7 +327,7 @@ void LabelIndex::RemoveObsoleteEntries(uint64_t oldest_active_start_timestamp) { LabelIndex::Iterable::Iterator::Iterator(Iterable *self, utils::SkipList::Iterator index_iterator) : self_(self), index_iterator_(index_iterator), - current_vertex_accessor_(nullptr, nullptr, nullptr, nullptr, self_->config_, *self_->schema_validator_), + current_vertex_accessor_(nullptr, nullptr, nullptr, nullptr, self_->config_), current_vertex_(nullptr) { AdvanceUntilValid(); } @@ -346,8 +345,8 @@ void LabelIndex::Iterable::Iterator::AdvanceUntilValid() { } if (CurrentVersionHasLabel(*index_iterator_->vertex, self_->label_, self_->transaction_, self_->view_)) { current_vertex_ = index_iterator_->vertex; - current_vertex_accessor_ = VertexAccessor{current_vertex_, self_->transaction_, self_->indices_, - self_->constraints_, self_->config_, *self_->schema_validator_}; + current_vertex_accessor_ = + VertexAccessor{current_vertex_, self_->transaction_, self_->indices_, self_->constraints_, self_->config_}; break; } } @@ -355,15 +354,14 @@ void LabelIndex::Iterable::Iterator::AdvanceUntilValid() { LabelIndex::Iterable::Iterable(utils::SkipList::Accessor index_accessor, LabelId label, View view, Transaction *transaction, Indices *indices, Constraints *constraints, - Config::Items config, const SchemaValidator &schema_validator) + Config::Items config) : index_accessor_(std::move(index_accessor)), label_(label), view_(view), transaction_(transaction), indices_(indices), constraints_(constraints), - config_(config), - schema_validator_(&schema_validator) {} + config_(config) {} void LabelIndex::RunGC() { for (auto &index_entry : index_) { @@ -480,7 +478,7 @@ void LabelPropertyIndex::RemoveObsoleteEntries(uint64_t oldest_active_start_time LabelPropertyIndex::Iterable::Iterator::Iterator(Iterable *self, utils::SkipList::Iterator index_iterator) : self_(self), index_iterator_(index_iterator), - current_vertex_accessor_(nullptr, nullptr, nullptr, nullptr, self_->config_, *self_->schema_validator_), + current_vertex_accessor_(nullptr, nullptr, nullptr, nullptr, self_->config_), current_vertex_(nullptr) { AdvanceUntilValid(); } @@ -519,8 +517,8 @@ void LabelPropertyIndex::Iterable::Iterator::AdvanceUntilValid() { if (CurrentVersionHasLabelProperty(*index_iterator_->vertex, self_->label_, self_->property_, index_iterator_->value, self_->transaction_, self_->view_)) { current_vertex_ = index_iterator_->vertex; - current_vertex_accessor_ = VertexAccessor(current_vertex_, self_->transaction_, self_->indices_, - self_->constraints_, self_->config_, *self_->schema_validator_); + current_vertex_accessor_ = + VertexAccessor(current_vertex_, self_->transaction_, self_->indices_, self_->constraints_, self_->config_); break; } } @@ -543,7 +541,7 @@ LabelPropertyIndex::Iterable::Iterable(utils::SkipList::Accessor index_ac const std::optional> &lower_bound, const std::optional> &upper_bound, View view, Transaction *transaction, Indices *indices, Constraints *constraints, - Config::Items config, const SchemaValidator &schema_validator) + Config::Items config) : index_accessor_(std::move(index_accessor)), label_(label), property_(property), @@ -553,8 +551,7 @@ LabelPropertyIndex::Iterable::Iterable(utils::SkipList::Accessor index_ac transaction_(transaction), indices_(indices), constraints_(constraints), - config_(config), - schema_validator_(&schema_validator) { + config_(config) { // We have to fix the bounds that the user provided to us. If the user // provided only one bound we should make sure that only values of that type // are returned by the iterator. We ensure this by supplying either an diff --git a/src/storage/v2/indices.hpp b/src/storage/v2/indices.hpp index 64e1501f3..336b70ead 100644 --- a/src/storage/v2/indices.hpp +++ b/src/storage/v2/indices.hpp @@ -17,7 +17,6 @@ #include "storage/v2/config.hpp" #include "storage/v2/property_value.hpp" -#include "storage/v2/schema_validator.hpp" #include "storage/v2/transaction.hpp" #include "storage/v2/vertex_accessor.hpp" #include "utils/bound.hpp" @@ -52,8 +51,8 @@ class LabelIndex { }; public: - LabelIndex(Indices *indices, Constraints *constraints, Config::Items config, const SchemaValidator &schema_validator) - : indices_(indices), constraints_(constraints), config_(config), schema_validator_{&schema_validator} {} + LabelIndex(Indices *indices, Constraints *constraints, Config::Items config) + : indices_(indices), constraints_(constraints), config_(config) {} /// @throw std::bad_alloc void UpdateOnAddLabel(LabelId label, Vertex *vertex, const Transaction &tx); @@ -73,7 +72,7 @@ class LabelIndex { class Iterable { public: Iterable(utils::SkipList::Accessor index_accessor, LabelId label, View view, Transaction *transaction, - Indices *indices, Constraints *constraints, Config::Items config, const SchemaValidator &schema_validator); + Indices *indices, Constraints *constraints, Config::Items config); class Iterator { public: @@ -106,14 +105,13 @@ class LabelIndex { Indices *indices_; Constraints *constraints_; Config::Items config_; - const SchemaValidator *schema_validator_; }; /// Returns an self with vertices visible from the given transaction. Iterable Vertices(LabelId label, View view, Transaction *transaction) { auto it = index_.find(label); MG_ASSERT(it != index_.end(), "Index for label {} doesn't exist", label.AsUint()); - return {it->second.access(), label, view, transaction, indices_, constraints_, config_, *schema_validator_}; + return {it->second.access(), label, view, transaction, indices_, constraints_, config_}; } int64_t ApproximateVertexCount(LabelId label) { @@ -131,7 +129,6 @@ class LabelIndex { Indices *indices_; Constraints *constraints_; Config::Items config_; - const SchemaValidator *schema_validator_; }; class LabelPropertyIndex { @@ -149,9 +146,8 @@ class LabelPropertyIndex { }; public: - LabelPropertyIndex(Indices *indices, Constraints *constraints, Config::Items config, - const SchemaValidator &schema_validator) - : indices_(indices), constraints_(constraints), config_(config), schema_validator_{&schema_validator} {} + LabelPropertyIndex(Indices *indices, Constraints *constraints, Config::Items config) + : indices_(indices), constraints_(constraints), config_(config) {} /// @throw std::bad_alloc void UpdateOnAddLabel(LabelId label, Vertex *vertex, const Transaction &tx); @@ -175,7 +171,7 @@ class LabelPropertyIndex { Iterable(utils::SkipList::Accessor index_accessor, LabelId label, PropertyId property, const std::optional> &lower_bound, const std::optional> &upper_bound, View view, Transaction *transaction, - Indices *indices, Constraints *constraints, Config::Items config, const SchemaValidator &schema_validator); + Indices *indices, Constraints *constraints, Config::Items config); class Iterator { public: @@ -212,17 +208,16 @@ class LabelPropertyIndex { Indices *indices_; Constraints *constraints_; Config::Items config_; - const SchemaValidator *schema_validator_; }; Iterable Vertices(LabelId label, PropertyId property, const std::optional> &lower_bound, - const std::optional> &upper_bound, View view, Transaction *transaction, - const SchemaValidator &schema_validator_) { + const std::optional> &upper_bound, View view, + Transaction *transaction) { auto it = index_.find({label, property}); MG_ASSERT(it != index_.end(), "Index for label {} and property {} doesn't exist", label.AsUint(), property.AsUint()); - return {it->second.access(), label, property, lower_bound, upper_bound, view, - transaction, indices_, constraints_, config_, schema_validator_}; + return {it->second.access(), label, property, lower_bound, upper_bound, view, + transaction, indices_, constraints_, config_}; } int64_t ApproximateVertexCount(LabelId label, PropertyId property) const { @@ -251,13 +246,11 @@ class LabelPropertyIndex { Indices *indices_; Constraints *constraints_; Config::Items config_; - const SchemaValidator *schema_validator_; }; struct Indices { - Indices(Constraints *constraints, Config::Items config, const SchemaValidator &schema_validator) - : label_index(this, constraints, config, schema_validator), - label_property_index(this, constraints, config, schema_validator) {} + Indices(Constraints *constraints, Config::Items config) + : label_index(this, constraints, config), label_property_index(this, constraints, config) {} // Disable copy and move because members hold pointer to `this`. Indices(const Indices &) = delete; diff --git a/src/storage/v2/replication/replication_server.cpp b/src/storage/v2/replication/replication_server.cpp index f8f533cac..fed501d6e 100644 --- a/src/storage/v2/replication/replication_server.cpp +++ b/src/storage/v2/replication/replication_server.cpp @@ -166,10 +166,9 @@ void Storage::ReplicationServer::SnapshotHandler(slk::Reader *req_reader, slk::B storage_->edges_.clear(); storage_->constraints_ = Constraints(); - storage_->indices_.label_index = - LabelIndex(&storage_->indices_, &storage_->constraints_, storage_->config_.items, storage_->schema_validator_); - storage_->indices_.label_property_index = LabelPropertyIndex(&storage_->indices_, &storage_->constraints_, - storage_->config_.items, storage_->schema_validator_); + storage_->indices_.label_index = LabelIndex(&storage_->indices_, &storage_->constraints_, storage_->config_.items); + storage_->indices_.label_property_index = + LabelPropertyIndex(&storage_->indices_, &storage_->constraints_, storage_->config_.items); try { spdlog::debug("Loading snapshot"); auto recovered_snapshot = durability::LoadSnapshot(*maybe_snapshot_path, &storage_->vertices_, &storage_->edges_, @@ -474,8 +473,7 @@ uint64_t Storage::ReplicationServer::ReadAndApplyDelta(durability::BaseDecoder * &transaction->transaction_, &storage_->indices_, &storage_->constraints_, - storage_->config_.items, - storage_->schema_validator_}; + storage_->config_.items}; auto ret = ea.SetProperty(transaction->NameToProperty(delta.vertex_edge_set_property.property), delta.vertex_edge_set_property.value); diff --git a/src/storage/v2/storage.cpp b/src/storage/v2/storage.cpp index ff03c6c77..cee74574d 100644 --- a/src/storage/v2/storage.cpp +++ b/src/storage/v2/storage.cpp @@ -14,7 +14,6 @@ #include #include #include -#include #include #include @@ -27,22 +26,17 @@ #include "storage/v2/durability/snapshot.hpp" #include "storage/v2/durability/wal.hpp" #include "storage/v2/edge_accessor.hpp" -#include "storage/v2/id_types.hpp" #include "storage/v2/indices.hpp" #include "storage/v2/mvcc.hpp" #include "storage/v2/replication/config.hpp" #include "storage/v2/replication/enums.hpp" #include "storage/v2/replication/replication_persistence_helper.hpp" -#include "storage/v2/schema_validator.hpp" -#include "storage/v2/schemas.hpp" #include "storage/v2/transaction.hpp" #include "storage/v2/vertex_accessor.hpp" -#include "utils/exceptions.hpp" #include "utils/file.hpp" #include "utils/logging.hpp" #include "utils/memory_tracker.hpp" #include "utils/message.hpp" -#include "utils/result.hpp" #include "utils/rw_lock.hpp" #include "utils/spin_lock.hpp" #include "utils/stat.hpp" @@ -76,9 +70,9 @@ std::string RegisterReplicaErrorToString(Storage::RegisterReplicaError error) { auto AdvanceToVisibleVertex(utils::SkipList::Iterator it, utils::SkipList::Iterator end, std::optional *vertex, Transaction *tx, View view, Indices *indices, - Constraints *constraints, Config::Items config, const SchemaValidator &schema_validator) { + Constraints *constraints, Config::Items config) { while (it != end) { - *vertex = VertexAccessor::Create(&*it, tx, indices, constraints, config, schema_validator, view); + *vertex = VertexAccessor::Create(&*it, tx, indices, constraints, config, view); if (!*vertex) { ++it; continue; @@ -91,14 +85,14 @@ auto AdvanceToVisibleVertex(utils::SkipList::Iterator it, utils::SkipLis AllVerticesIterable::Iterator::Iterator(AllVerticesIterable *self, utils::SkipList::Iterator it) : self_(self), it_(AdvanceToVisibleVertex(it, self->vertices_accessor_.end(), &self->vertex_, self->transaction_, self->view_, - self->indices_, self_->constraints_, self->config_, *self->schema_validator_)) {} + self->indices_, self_->constraints_, self->config_)) {} VertexAccessor AllVerticesIterable::Iterator::operator*() const { return *self_->vertex_; } AllVerticesIterable::Iterator &AllVerticesIterable::Iterator::operator++() { ++it_; it_ = AdvanceToVisibleVertex(it_, self_->vertices_accessor_.end(), &self_->vertex_, self_->transaction_, self_->view_, - self_->indices_, self_->constraints_, self_->config_, *self_->schema_validator_); + self_->indices_, self_->constraints_, self_->config_); return *this; } @@ -319,8 +313,7 @@ bool VerticesIterable::Iterator::operator==(const Iterator &other) const { } Storage::Storage(Config config) - : schema_validator_(schemas_), - indices_(&constraints_, config.items, schema_validator_), + : indices_(&constraints_, config.items), isolation_level_(config.transaction.isolation_level), config_(config), snapshot_directory_(config_.durability.storage_directory / durability::kSnapshotDirectory), @@ -492,8 +485,7 @@ Storage::Accessor::~Accessor() { FinalizeTransaction(); } -// TODO Remove when import csv is fixed -[[deprecated]] VertexAccessor Storage::Accessor::CreateVertex() { +VertexAccessor Storage::Accessor::CreateVertex() { OOMExceptionEnabler oom_exception; auto gid = storage_->vertex_id_.fetch_add(1, std::memory_order_acq_rel); auto acc = storage_->vertices_.access(); @@ -501,71 +493,34 @@ Storage::Accessor::~Accessor() { auto [it, inserted] = acc.insert(Vertex{storage::Gid::FromUint(gid), delta}); MG_ASSERT(inserted, "The vertex must be inserted here!"); MG_ASSERT(it != acc.end(), "Invalid Vertex accessor!"); - delta->prev.Set(&*it); - return {&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, storage_->schema_validator_}; + return {&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_}; } -// TODO Remove when replication is fixed VertexAccessor Storage::Accessor::CreateVertex(storage::Gid gid) { OOMExceptionEnabler oom_exception; // NOTE: When we update the next `vertex_id_` here we perform a RMW // (read-modify-write) operation that ISN'T atomic! But, that isn't an issue // because this function is only called from the replication delta applier - // that runs single-threaded and while this instance is set-up to apply + // that runs single-threadedly and while this instance is set-up to apply // threads (it is the replica), it is guaranteed that no other writes are // possible. storage_->vertex_id_.store(std::max(storage_->vertex_id_.load(std::memory_order_acquire), gid.AsUint() + 1), std::memory_order_release); auto acc = storage_->vertices_.access(); auto *delta = CreateDeleteObjectDelta(&transaction_); - auto [it, inserted] = acc.insert(Vertex{gid}); + auto [it, inserted] = acc.insert(Vertex{gid, delta}); MG_ASSERT(inserted, "The vertex must be inserted here!"); MG_ASSERT(it != acc.end(), "Invalid Vertex accessor!"); delta->prev.Set(&*it); - return {&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, storage_->schema_validator_}; -} - -ResultSchema Storage::Accessor::CreateVertexAndValidate( - storage::LabelId primary_label, const std::vector &labels, - const std::vector> &properties) { - auto maybe_schema_violation = GetSchemaValidator().ValidateVertexCreate(primary_label, labels, properties); - if (maybe_schema_violation) { - return {std::move(*maybe_schema_violation)}; - } - OOMExceptionEnabler oom_exception; - auto gid = storage_->vertex_id_.fetch_add(1, std::memory_order_acq_rel); - auto acc = storage_->vertices_.access(); - auto *delta = CreateDeleteObjectDelta(&transaction_); - auto [it, inserted] = acc.insert(Vertex{storage::Gid::FromUint(gid), delta, primary_label}); - MG_ASSERT(inserted, "The vertex must be inserted here!"); - MG_ASSERT(it != acc.end(), "Invalid Vertex accessor!"); - delta->prev.Set(&*it); - - auto va = VertexAccessor{ - &*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, storage_->schema_validator_}; - for (const auto label : labels) { - const auto maybe_error = va.AddLabel(label); - if (maybe_error.HasError()) { - return {maybe_error.GetError()}; - } - } - // Set properties - for (auto [property_id, property_value] : properties) { - const auto maybe_error = va.SetProperty(property_id, property_value); - if (maybe_error.HasError()) { - return {maybe_error.GetError()}; - } - } - return va; + return {&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_}; } std::optional Storage::Accessor::FindVertex(Gid gid, View view) { auto acc = storage_->vertices_.access(); auto it = acc.find(gid); if (it == acc.end()) return std::nullopt; - return VertexAccessor::Create(&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, - storage_->schema_validator_, view); + return VertexAccessor::Create(&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, view); } Result> Storage::Accessor::DeleteVertex(VertexAccessor *vertex) { @@ -588,7 +543,7 @@ Result> Storage::Accessor::DeleteVertex(VertexAcce vertex_ptr->deleted = true; return std::make_optional(vertex_ptr, &transaction_, &storage_->indices_, &storage_->constraints_, - config_, storage_->schema_validator_, true); + config_, true); } Result>>> Storage::Accessor::DetachDeleteVertex( @@ -618,7 +573,7 @@ Result>>> Stor for (const auto &item : in_edges) { auto [edge_type, from_vertex, edge] = item; EdgeAccessor e(edge, edge_type, from_vertex, vertex_ptr, &transaction_, &storage_->indices_, - &storage_->constraints_, config_, storage_->schema_validator_); + &storage_->constraints_, config_); auto ret = DeleteEdge(&e); if (ret.HasError()) { MG_ASSERT(ret.GetError() == Error::SERIALIZATION_ERROR, "Invalid database state!"); @@ -632,7 +587,7 @@ Result>>> Stor for (const auto &item : out_edges) { auto [edge_type, to_vertex, edge] = item; EdgeAccessor e(edge, edge_type, vertex_ptr, to_vertex, &transaction_, &storage_->indices_, &storage_->constraints_, - config_, storage_->schema_validator_); + config_); auto ret = DeleteEdge(&e); if (ret.HasError()) { MG_ASSERT(ret.GetError() == Error::SERIALIZATION_ERROR, "Invalid database state!"); @@ -658,8 +613,7 @@ Result>>> Stor vertex_ptr->deleted = true; return std::make_optional( - VertexAccessor{vertex_ptr, &transaction_, &storage_->indices_, &storage_->constraints_, config_, - storage_->schema_validator_, true}, + VertexAccessor{vertex_ptr, &transaction_, &storage_->indices_, &storage_->constraints_, config_, true}, std::move(deleted_edges)); } @@ -719,7 +673,7 @@ Result Storage::Accessor::CreateEdge(VertexAccessor *from, VertexA storage_->edge_count_.fetch_add(1, std::memory_order_acq_rel); return EdgeAccessor(edge, edge_type, from_vertex, to_vertex, &transaction_, &storage_->indices_, - &storage_->constraints_, config_, storage_->schema_validator_); + &storage_->constraints_, config_); } Result Storage::Accessor::CreateEdge(VertexAccessor *from, VertexAccessor *to, EdgeTypeId edge_type, @@ -787,7 +741,7 @@ Result Storage::Accessor::CreateEdge(VertexAccessor *from, VertexA storage_->edge_count_.fetch_add(1, std::memory_order_acq_rel); return EdgeAccessor(edge, edge_type, from_vertex, to_vertex, &transaction_, &storage_->indices_, - &storage_->constraints_, config_, storage_->schema_validator_); + &storage_->constraints_, config_); } Result> Storage::Accessor::DeleteEdge(EdgeAccessor *edge) { @@ -871,8 +825,7 @@ Result> Storage::Accessor::DeleteEdge(EdgeAccessor * storage_->edge_count_.fetch_add(-1, std::memory_order_acq_rel); return std::make_optional(edge_ref, edge_type, from_vertex, to_vertex, &transaction_, - &storage_->indices_, &storage_->constraints_, config_, - storage_->schema_validator_, true); + &storage_->indices_, &storage_->constraints_, config_, true); } const std::string &Storage::Accessor::LabelToName(LabelId label) const { return storage_->LabelToName(label); } @@ -916,11 +869,11 @@ utils::BasicResult Storage::Accessor::Commit( auto validation_result = ValidateExistenceConstraints(*prev.vertex, storage_->constraints_); if (validation_result) { Abort(); - return {*validation_result}; + return *validation_result; } } - // Result of validating the vertex against unique constraints. It has to be + // Result of validating the vertex against unqiue constraints. It has to be // declared outside of the critical section scope because its value is // tested for Abort call which has to be done out of the scope. std::optional unique_constraint_violation; @@ -1001,7 +954,7 @@ utils::BasicResult Storage::Accessor::Commit( if (unique_constraint_violation) { Abort(); - return {*unique_constraint_violation}; + return *unique_constraint_violation; } } is_transaction_active_ = false; @@ -1302,29 +1255,11 @@ UniqueConstraints::DeletionStatus Storage::DropUniqueConstraint( return UniqueConstraints::DeletionStatus::SUCCESS; } -const SchemaValidator &Storage::Accessor::GetSchemaValidator() const { return storage_->schema_validator_; } - ConstraintsInfo Storage::ListAllConstraints() const { std::shared_lock storage_guard_(main_lock_); return {ListExistenceConstraints(constraints_), constraints_.unique_constraints.ListConstraints()}; } -SchemasInfo Storage::ListAllSchemas() const { - std::shared_lock storage_guard_(main_lock_); - return {schemas_.ListSchemas()}; -} - -const Schemas::Schema *Storage::GetSchema(const LabelId primary_label) const { - std::shared_lock storage_guard_(main_lock_); - return schemas_.GetSchema(primary_label); -} - -bool Storage::CreateSchema(const LabelId primary_label, const std::vector &schemas_types) { - return schemas_.CreateSchema(primary_label, schemas_types); -} - -bool Storage::DropSchema(const LabelId primary_label) { return schemas_.DropSchema(primary_label); } - StorageInfo Storage::GetInfo() const { auto vertex_count = vertices_.size(); auto edge_count = edge_count_.load(std::memory_order_acquire); @@ -1341,22 +1276,21 @@ VerticesIterable Storage::Accessor::Vertices(LabelId label, View view) { } VerticesIterable Storage::Accessor::Vertices(LabelId label, PropertyId property, View view) { - return VerticesIterable(storage_->indices_.label_property_index.Vertices( - label, property, std::nullopt, std::nullopt, view, &transaction_, storage_->schema_validator_)); + return VerticesIterable(storage_->indices_.label_property_index.Vertices(label, property, std::nullopt, std::nullopt, + view, &transaction_)); } VerticesIterable Storage::Accessor::Vertices(LabelId label, PropertyId property, const PropertyValue &value, View view) { return VerticesIterable(storage_->indices_.label_property_index.Vertices( - label, property, utils::MakeBoundInclusive(value), utils::MakeBoundInclusive(value), view, &transaction_, - storage_->schema_validator_)); + label, property, utils::MakeBoundInclusive(value), utils::MakeBoundInclusive(value), view, &transaction_)); } VerticesIterable Storage::Accessor::Vertices(LabelId label, PropertyId property, const std::optional> &lower_bound, const std::optional> &upper_bound, View view) { - return VerticesIterable(storage_->indices_.label_property_index.Vertices( - label, property, lower_bound, upper_bound, view, &transaction_, storage_->schema_validator_)); + return VerticesIterable( + storage_->indices_.label_property_index.Vertices(label, property, lower_bound, upper_bound, view, &transaction_)); } Transaction Storage::CreateTransaction(IsolationLevel isolation_level) { @@ -1884,8 +1818,8 @@ utils::BasicResult Storage::CreateSnapshot() { // Create snapshot. durability::CreateSnapshot(&transaction, snapshot_directory_, wal_directory_, config_.durability.snapshot_retention_count, &vertices_, &edges_, &name_id_mapper_, - &indices_, &constraints_, config_.items, schema_validator_, uuid_, epoch_id_, - epoch_history_, &file_retainer_); + &indices_, &constraints_, config_.items, uuid_, epoch_id_, epoch_history_, + &file_retainer_); // Finalize snapshot transaction. commit_log_->MarkFinished(transaction.start_timestamp); diff --git a/src/storage/v2/storage.hpp b/src/storage/v2/storage.hpp index 4addce561..6aab1977f 100644 --- a/src/storage/v2/storage.hpp +++ b/src/storage/v2/storage.hpp @@ -16,7 +16,6 @@ #include #include #include -#include #include "io/network/endpoint.hpp" #include "kvstore/kvstore.hpp" @@ -27,19 +26,14 @@ #include "storage/v2/durability/wal.hpp" #include "storage/v2/edge.hpp" #include "storage/v2/edge_accessor.hpp" -#include "storage/v2/id_types.hpp" #include "storage/v2/indices.hpp" #include "storage/v2/isolation_level.hpp" #include "storage/v2/mvcc.hpp" #include "storage/v2/name_id_mapper.hpp" -#include "storage/v2/property_value.hpp" #include "storage/v2/result.hpp" -#include "storage/v2/schema_validator.hpp" -#include "storage/v2/schemas.hpp" #include "storage/v2/transaction.hpp" #include "storage/v2/vertex.hpp" #include "storage/v2/vertex_accessor.hpp" -#include "utils/exceptions.hpp" #include "utils/file_locker.hpp" #include "utils/on_scope_exit.hpp" #include "utils/rw_lock.hpp" @@ -73,7 +67,6 @@ class AllVerticesIterable final { Indices *indices_; Constraints *constraints_; Config::Items config_; - const SchemaValidator *schema_validator_; std::optional vertex_; public: @@ -94,15 +87,13 @@ class AllVerticesIterable final { }; AllVerticesIterable(utils::SkipList::Accessor vertices_accessor, Transaction *transaction, View view, - Indices *indices, Constraints *constraints, Config::Items config, - SchemaValidator *schema_validator) + Indices *indices, Constraints *constraints, Config::Items config) : vertices_accessor_(std::move(vertices_accessor)), transaction_(transaction), view_(view), indices_(indices), constraints_(constraints), - config_(config), - schema_validator_(schema_validator) {} + config_(config) {} Iterator begin() { return Iterator(this, vertices_accessor_.begin()); } Iterator end() { return Iterator(this, vertices_accessor_.end()); } @@ -183,11 +174,6 @@ struct ConstraintsInfo { std::vector>> unique; }; -/// Structure used to return information about existing schemas in the storage -struct SchemasInfo { - Schemas::SchemasList schemas; -}; - /// Structure used to return information about the storage. struct StorageInfo { uint64_t vertex_count; @@ -224,21 +210,15 @@ class Storage final { ~Accessor(); - VertexAccessor CreateVertex(); - - VertexAccessor CreateVertex(storage::Gid gid); - /// @throw std::bad_alloc - ResultSchema CreateVertexAndValidate( - storage::LabelId primary_label, const std::vector &labels, - const std::vector> &properties); + VertexAccessor CreateVertex(); std::optional FindVertex(Gid gid, View view); VerticesIterable Vertices(View view) { return VerticesIterable(AllVerticesIterable(storage_->vertices_.access(), &transaction_, view, - &storage_->indices_, &storage_->constraints_, storage_->config_.items, - &storage_->schema_validator_)); + &storage_->indices_, &storage_->constraints_, + storage_->config_.items)); } VerticesIterable Vertices(LabelId label, View view); @@ -327,10 +307,6 @@ class Storage final { storage_->constraints_.unique_constraints.ListConstraints()}; } - const SchemaValidator &GetSchemaValidator() const; - - SchemasInfo ListAllSchemas() const { return {storage_->schemas_.ListSchemas()}; } - void AdvanceCommand(); /// Commit returns `ConstraintViolation` if the changes made by this @@ -346,7 +322,7 @@ class Storage final { private: /// @throw std::bad_alloc - VertexAccessor CreateVertex(storage::Gid gid, storage::LabelId primary_label); + VertexAccessor CreateVertex(storage::Gid gid); /// @throw std::bad_alloc Result CreateEdge(VertexAccessor *from, VertexAccessor *to, EdgeTypeId edge_type, storage::Gid gid); @@ -389,7 +365,7 @@ class Storage final { IndicesInfo ListAllIndices() const; /// Creates an existence constraint. Returns true if the constraint was - /// successfully added, false if it already exists and a `ConstraintViolation` + /// successfuly added, false if it already exists and a `ConstraintViolation` /// if there is an existing vertex violating the constraint. /// /// @throw std::bad_alloc @@ -427,14 +403,6 @@ class Storage final { ConstraintsInfo ListAllConstraints() const; - SchemasInfo ListAllSchemas() const; - - const Schemas::Schema *GetSchema(LabelId primary_label) const; - - bool CreateSchema(LabelId primary_label, const std::vector &schemas_types); - - bool DropSchema(LabelId primary_label); - StorageInfo GetInfo() const; bool LockPath(); @@ -536,10 +504,8 @@ class Storage final { NameIdMapper name_id_mapper_; - SchemaValidator schema_validator_; Constraints constraints_; Indices indices_; - Schemas schemas_; // Transaction engine utils::SpinLock engine_lock_; diff --git a/src/storage/v2/vertex.hpp b/src/storage/v2/vertex.hpp index c2a63144f..83f517c46 100644 --- a/src/storage/v2/vertex.hpp +++ b/src/storage/v2/vertex.hpp @@ -19,39 +19,18 @@ #include "storage/v2/edge_ref.hpp" #include "storage/v2/id_types.hpp" #include "storage/v2/property_store.hpp" -#include "utils/algorithm.hpp" #include "utils/spin_lock.hpp" namespace memgraph::storage { struct Vertex { - Vertex(Gid gid, Delta *delta, LabelId primary_label) - : gid(gid), primary_label{primary_label}, deleted(false), delta(delta) { - MG_ASSERT(delta == nullptr || delta->action == Delta::Action::DELETE_OBJECT, - "Vertex must be created with an initial DELETE_OBJECT delta!"); - } - - // TODO remove this when import replication is solved - Vertex(Gid gid, LabelId primary_label) : gid(gid), primary_label{primary_label}, deleted(false) { - MG_ASSERT(delta == nullptr || delta->action == Delta::Action::DELETE_OBJECT, - "Vertex must be created with an initial DELETE_OBJECT delta!"); - } - - // TODO remove this when import csv is solved - [[deprecated]] Vertex(Gid gid, Delta *delta) : gid(gid), deleted(false), delta(delta) { - MG_ASSERT(delta == nullptr || delta->action == Delta::Action::DELETE_OBJECT, - "Vertex must be created with an initial DELETE_OBJECT delta!"); - } - - // TODO remove this when import replication is solved - [[deprecated]] explicit Vertex(Gid gid) : gid(gid), deleted(false) { + Vertex(Gid gid, Delta *delta) : gid(gid), deleted(false), delta(delta) { MG_ASSERT(delta == nullptr || delta->action == Delta::Action::DELETE_OBJECT, "Vertex must be created with an initial DELETE_OBJECT delta!"); } Gid gid; - LabelId primary_label; std::vector labels; PropertyStore properties; @@ -73,8 +52,4 @@ inline bool operator<(const Vertex &first, const Vertex &second) { return first. inline bool operator==(const Vertex &first, const Gid &second) { return first.gid == second; } inline bool operator<(const Vertex &first, const Gid &second) { return first.gid < second; } -inline bool VertexHasLabel(const Vertex &vertex, const LabelId label) { - return vertex.primary_label == label || utils::Contains(vertex.labels, label); -} - } // namespace memgraph::storage diff --git a/src/storage/v2/vertex_accessor.cpp b/src/storage/v2/vertex_accessor.cpp index dacdedb1b..05ba1ebcc 100644 --- a/src/storage/v2/vertex_accessor.cpp +++ b/src/storage/v2/vertex_accessor.cpp @@ -18,8 +18,6 @@ #include "storage/v2/indices.hpp" #include "storage/v2/mvcc.hpp" #include "storage/v2/property_value.hpp" -#include "storage/v2/schema_validator.hpp" -#include "storage/v2/vertex.hpp" #include "utils/logging.hpp" #include "utils/memory_tracker.hpp" @@ -63,13 +61,12 @@ std::pair IsVisible(Vertex *vertex, Transaction *transaction, View v } // namespace detail std::optional VertexAccessor::Create(Vertex *vertex, Transaction *transaction, Indices *indices, - Constraints *constraints, Config::Items config, - const SchemaValidator &schema_validator, View view) { + Constraints *constraints, Config::Items config, View view) { if (const auto [exists, deleted] = detail::IsVisible(vertex, transaction, view); !exists || deleted) { return std::nullopt; } - return VertexAccessor{vertex, transaction, indices, constraints, config, schema_validator}; + return VertexAccessor{vertex, transaction, indices, constraints, config}; } bool VertexAccessor::IsVisible(View view) const { @@ -96,28 +93,6 @@ Result VertexAccessor::AddLabel(LabelId label) { return true; } -storage::ResultSchema VertexAccessor::AddLabelAndValidate(LabelId label) { - if (const auto maybe_violation_error = vertex_validator_.ValidateAddLabel(label); maybe_violation_error) { - return {*maybe_violation_error}; - } - utils::MemoryTracker::OutOfMemoryExceptionEnabler oom_exception; - std::lock_guard guard(vertex_->lock); - - if (!PrepareForWrite(transaction_, vertex_)) return {Error::SERIALIZATION_ERROR}; - - if (vertex_->deleted) return {Error::DELETED_OBJECT}; - - if (std::find(vertex_->labels.begin(), vertex_->labels.end(), label) != vertex_->labels.end()) return false; - - CreateAndLinkDelta(transaction_, vertex_, Delta::RemoveLabelTag(), label); - - vertex_->labels.push_back(label); - - UpdateOnAddLabel(indices_, label, vertex_, *transaction_); - - return true; -} - Result VertexAccessor::RemoveLabel(LabelId label) { std::lock_guard guard(vertex_->lock); @@ -135,26 +110,6 @@ Result VertexAccessor::RemoveLabel(LabelId label) { return true; } -ResultSchema VertexAccessor::RemoveLabelAndValidate(LabelId label) { - if (const auto maybe_violation_error = vertex_validator_.ValidateRemoveLabel(label); maybe_violation_error) { - return {*maybe_violation_error}; - } - std::lock_guard guard(vertex_->lock); - - if (!PrepareForWrite(transaction_, vertex_)) return {Error::SERIALIZATION_ERROR}; - - if (vertex_->deleted) return {Error::DELETED_OBJECT}; - - auto it = std::find(vertex_->labels.begin(), vertex_->labels.end(), label); - if (it == vertex_->labels.end()) return false; - - CreateAndLinkDelta(transaction_, vertex_, Delta::AddLabelTag(), label); - - std::swap(*it, *vertex_->labels.rbegin()); - vertex_->labels.pop_back(); - return true; -} - Result VertexAccessor::HasLabel(LabelId label, View view) const { bool exists = true; bool deleted = false; @@ -163,7 +118,7 @@ Result VertexAccessor::HasLabel(LabelId label, View view) const { { std::lock_guard guard(vertex_->lock); deleted = vertex_->deleted; - has_label = VertexHasLabel(*vertex_, label); + has_label = std::find(vertex_->labels.begin(), vertex_->labels.end(), label) != vertex_->labels.end(); delta = vertex_->delta; } ApplyDeltasForRead(transaction_, delta, view, [&exists, &deleted, &has_label, label](const Delta &delta) { @@ -203,40 +158,6 @@ Result VertexAccessor::HasLabel(LabelId label, View view) const { return has_label; } -Result VertexAccessor::PrimaryLabel(const View view) const { - bool exists = true; - bool deleted = false; - Delta *delta = nullptr; - { - std::lock_guard guard(vertex_->lock); - deleted = vertex_->deleted; - delta = vertex_->delta; - } - ApplyDeltasForRead(transaction_, delta, view, [&exists, &deleted](const Delta &delta) { - switch (delta.action) { - case Delta::Action::DELETE_OBJECT: { - exists = false; - break; - } - case Delta::Action::RECREATE_OBJECT: { - deleted = false; - break; - } - case Delta::Action::ADD_LABEL: - case Delta::Action::REMOVE_LABEL: - case Delta::Action::SET_PROPERTY: - case Delta::Action::ADD_IN_EDGE: - case Delta::Action::ADD_OUT_EDGE: - case Delta::Action::REMOVE_IN_EDGE: - case Delta::Action::REMOVE_OUT_EDGE: - break; - } - }); - if (!exists) return Error::NONEXISTENT_OBJECT; - if (!for_deleted_ && deleted) return Error::DELETED_OBJECT; - return vertex_->primary_label; -} - Result> VertexAccessor::Labels(View view) const { bool exists = true; bool deleted = false; @@ -309,36 +230,6 @@ Result VertexAccessor::SetProperty(PropertyId property, const Pro return std::move(current_value); } -ResultSchema VertexAccessor::SetPropertyAndValidate(PropertyId property, const PropertyValue &value) { - if (auto maybe_violation_error = vertex_validator_.ValidatePropertyUpdate(property); maybe_violation_error) { - return {*maybe_violation_error}; - } - utils::MemoryTracker::OutOfMemoryExceptionEnabler oom_exception; - std::lock_guard guard(vertex_->lock); - - if (!PrepareForWrite(transaction_, vertex_)) { - return {Error::SERIALIZATION_ERROR}; - } - - if (vertex_->deleted) { - return {Error::DELETED_OBJECT}; - } - - auto current_value = vertex_->properties.GetProperty(property); - // We could skip setting the value if the previous one is the same to the new - // one. This would save some memory as a delta would not be created as well as - // avoid copying the value. The reason we are not doing that is because the - // current code always follows the logical pattern of "create a delta" and - // "modify in-place". Additionally, the created delta will make other - // transactions get a SERIALIZATION_ERROR. - CreateAndLinkDelta(transaction_, vertex_, Delta::SetPropertyTag(), property, current_value); - vertex_->properties.SetProperty(property, value); - - UpdateOnSetProperty(indices_, property, value, vertex_, *transaction_); - - return std::move(current_value); -} - Result> VertexAccessor::ClearProperties() { std::lock_guard guard(vertex_->lock); @@ -523,8 +414,7 @@ Result> VertexAccessor::InEdges(View view, const std:: ret.reserve(in_edges.size()); for (const auto &item : in_edges) { const auto &[edge_type, from_vertex, edge] = item; - ret.emplace_back(edge, edge_type, from_vertex, vertex_, transaction_, indices_, constraints_, config_, - *vertex_validator_.schema_validator); + ret.emplace_back(edge, edge_type, from_vertex, vertex_, transaction_, indices_, constraints_, config_); } return std::move(ret); } @@ -604,8 +494,7 @@ Result> VertexAccessor::OutEdges(View view, const std: ret.reserve(out_edges.size()); for (const auto &item : out_edges) { const auto &[edge_type, to_vertex, edge] = item; - ret.emplace_back(edge, edge_type, vertex_, to_vertex, transaction_, indices_, constraints_, config_, - *vertex_validator_.schema_validator); + ret.emplace_back(edge, edge_type, vertex_, to_vertex, transaction_, indices_, constraints_, config_); } return std::move(ret); } @@ -686,21 +575,4 @@ Result VertexAccessor::OutDegree(View view) const { return degree; } -VertexAccessor::VertexValidator::VertexValidator(const SchemaValidator &schema_validator, const Vertex *vertex) - : schema_validator{&schema_validator}, vertex_{vertex} {} - -[[nodiscard]] std::optional VertexAccessor::VertexValidator::ValidatePropertyUpdate( - PropertyId property_id) const { - MG_ASSERT(vertex_ != nullptr, "Cannot validate vertex which is nullptr"); - return schema_validator->ValidatePropertyUpdate(vertex_->primary_label, property_id); -}; - -[[nodiscard]] std::optional VertexAccessor::VertexValidator::ValidateAddLabel(LabelId label) const { - return schema_validator->ValidateLabelUpdate(label); -} - -[[nodiscard]] std::optional VertexAccessor::VertexValidator::ValidateRemoveLabel(LabelId label) const { - return schema_validator->ValidateLabelUpdate(label); -} - } // namespace memgraph::storage diff --git a/src/storage/v2/vertex_accessor.hpp b/src/storage/v2/vertex_accessor.hpp index eed4cb7e5..840eec910 100644 --- a/src/storage/v2/vertex_accessor.hpp +++ b/src/storage/v2/vertex_accessor.hpp @@ -13,8 +13,6 @@ #include -#include "storage/v2/id_types.hpp" -#include "storage/v2/schema_validator.hpp" #include "storage/v2/vertex.hpp" #include "storage/v2/config.hpp" @@ -31,39 +29,20 @@ struct Constraints; class VertexAccessor final { private: - struct VertexValidator { - // TODO(jbajic) Beware since vertex is pointer it will be accessed even as nullptr - explicit VertexValidator(const SchemaValidator &schema_validator, const Vertex *vertex); - - [[nodiscard]] std::optional ValidatePropertyUpdate(PropertyId property_id) const; - - [[nodiscard]] std::optional ValidateAddLabel(LabelId label) const; - - [[nodiscard]] std::optional ValidateRemoveLabel(LabelId label) const; - - const SchemaValidator *schema_validator; - - private: - const Vertex *vertex_; - }; friend class Storage; public: - // Be careful when using VertexAccessor since it can be instantiated with - // nullptr values VertexAccessor(Vertex *vertex, Transaction *transaction, Indices *indices, Constraints *constraints, - Config::Items config, const SchemaValidator &schema_validator, bool for_deleted = false) + Config::Items config, bool for_deleted = false) : vertex_(vertex), transaction_(transaction), indices_(indices), constraints_(constraints), config_(config), - vertex_validator_{schema_validator, vertex}, for_deleted_(for_deleted) {} static std::optional Create(Vertex *vertex, Transaction *transaction, Indices *indices, - Constraints *constraints, Config::Items config, - const SchemaValidator &schema_validator, View view); + Constraints *constraints, Config::Items config, View view); /// @return true if the object is visible from the current transaction bool IsVisible(View view) const; @@ -73,23 +52,11 @@ class VertexAccessor final { /// @throw std::bad_alloc Result AddLabel(LabelId label); - /// Add a label and return `true` if insertion took place. - /// `false` is returned if the label already existed, or SchemaViolation - /// if adding the label has violated one of the schema constraints. - /// @throw std::bad_alloc - storage::ResultSchema AddLabelAndValidate(LabelId label); - /// Remove a label and return `true` if deletion took place. /// `false` is returned if the vertex did not have a label already. /// @throw std::bad_alloc Result RemoveLabel(LabelId label); - /// Remove a label and return `true` if deletion took place. - /// `false` is returned if the vertex did not have a label already. or SchemaViolation - /// if adding the label has violated one of the schema constraints. - /// @throw std::bad_alloc - ResultSchema RemoveLabelAndValidate(LabelId label); - Result HasLabel(LabelId label, View view) const; /// @throw std::bad_alloc @@ -97,16 +64,10 @@ class VertexAccessor final { /// std::vector::max_size(). Result> Labels(View view) const; - Result PrimaryLabel(View view) const; - /// Set a property value and return the old value. /// @throw std::bad_alloc Result SetProperty(PropertyId property, const PropertyValue &value); - /// Set a property value and return the old value or error. - /// @throw std::bad_alloc - ResultSchema SetPropertyAndValidate(PropertyId property, const PropertyValue &value); - /// Remove all properties and return the values of the removed properties. /// @throw std::bad_alloc Result> ClearProperties(); @@ -135,8 +96,6 @@ class VertexAccessor final { Gid Gid() const noexcept { return vertex_->gid; } - const SchemaValidator *GetSchemaValidator() const; - bool operator==(const VertexAccessor &other) const noexcept { return vertex_ == other.vertex_ && transaction_ == other.transaction_; } @@ -148,7 +107,6 @@ class VertexAccessor final { Indices *indices_; Constraints *constraints_; Config::Items config_; - VertexValidator vertex_validator_; // if the accessor was created for a deleted vertex. // Accessor behaves differently for some methods based on this diff --git a/src/storage/v3/CMakeLists.txt b/src/storage/v3/CMakeLists.txt index e9322b45a..9bc8e1f10 100644 --- a/src/storage/v3/CMakeLists.txt +++ b/src/storage/v3/CMakeLists.txt @@ -10,6 +10,8 @@ set(storage_v3_src_files indices.cpp property_store.cpp vertex_accessor.cpp + schemas.cpp + schema_validator.cpp storage.cpp) # #### Replication ##### diff --git a/src/storage/v3/constraints.cpp b/src/storage/v3/constraints.cpp index e04fa4070..a00b28aa1 100644 --- a/src/storage/v3/constraints.cpp +++ b/src/storage/v3/constraints.cpp @@ -16,6 +16,7 @@ #include #include "storage/v3/mvcc.hpp" +#include "storage/v3/vertex.hpp" #include "utils/logging.hpp" namespace memgraph::storage::v3 { @@ -59,7 +60,7 @@ bool LastCommittedVersionHasLabelProperty(const Vertex &vertex, LabelId label, c std::lock_guard guard(vertex.lock); delta = vertex.delta; deleted = vertex.deleted; - has_label = utils::Contains(vertex.labels, label); + has_label = VertexHasLabel(vertex, label); size_t i = 0; for (const auto &property : properties) { @@ -142,7 +143,7 @@ bool AnyVersionHasLabelProperty(const Vertex &vertex, LabelId label, const std:: Delta *delta{nullptr}; { std::lock_guard guard(vertex.lock); - has_label = utils::Contains(vertex.labels, label); + has_label = VertexHasLabel(vertex, label); deleted = vertex.deleted; delta = vertex.delta; @@ -267,7 +268,7 @@ bool UniqueConstraints::Entry::operator==(const std::vector &rhs) void UniqueConstraints::UpdateBeforeCommit(const Vertex *vertex, const Transaction &tx) { for (auto &[label_props, storage] : constraints_) { - if (!utils::Contains(vertex->labels, label_props.first)) { + if (!VertexHasLabel(*vertex, label_props.first)) { continue; } auto values = ExtractPropertyValues(*vertex, label_props.second); @@ -301,7 +302,7 @@ utils::BasicResult Uniqu auto acc = constraint->second.access(); for (const Vertex &vertex : vertices) { - if (vertex.deleted || !utils::Contains(vertex.labels, label)) { + if (vertex.deleted || !VertexHasLabel(vertex, label)) { continue; } auto values = ExtractPropertyValues(vertex, properties); @@ -352,7 +353,7 @@ std::optional UniqueConstraints::Validate(const Vertex &ver for (const auto &[label_props, storage] : constraints_) { const auto &label = label_props.first; const auto &properties = label_props.second; - if (!utils::Contains(vertex.labels, label)) { + if (!VertexHasLabel(vertex, label)) { continue; } diff --git a/src/storage/v3/constraints.hpp b/src/storage/v3/constraints.hpp index 3c0715c4d..25d329ad9 100644 --- a/src/storage/v3/constraints.hpp +++ b/src/storage/v3/constraints.hpp @@ -158,7 +158,7 @@ inline utils::BasicResult CreateExistenceConstraint( return false; } for (const auto &vertex : vertices) { - if (!vertex.deleted && utils::Contains(vertex.labels, label) && !vertex.properties.HasProperty(property)) { + if (!vertex.deleted && VertexHasLabel(vertex, label) && !vertex.properties.HasProperty(property)) { return ConstraintViolation{ConstraintViolation::Type::EXISTENCE, label, std::set{property}}; } } @@ -184,7 +184,7 @@ inline bool DropExistenceConstraint(Constraints *constraints, LabelId label, Pro [[nodiscard]] inline std::optional ValidateExistenceConstraints(const Vertex &vertex, const Constraints &constraints) { for (const auto &[label, property] : constraints.existence_constraints) { - if (!vertex.deleted && utils::Contains(vertex.labels, label) && !vertex.properties.HasProperty(property)) { + if (!vertex.deleted && VertexHasLabel(vertex, label) && !vertex.properties.HasProperty(property)) { return ConstraintViolation{ConstraintViolation::Type::EXISTENCE, label, std::set{property}}; } } diff --git a/src/storage/v3/durability/snapshot.cpp b/src/storage/v3/durability/snapshot.cpp index 48b4184ad..1942778e2 100644 --- a/src/storage/v3/durability/snapshot.cpp +++ b/src/storage/v3/durability/snapshot.cpp @@ -628,8 +628,9 @@ RecoveredSnapshot LoadSnapshot(const std::filesystem::path &path, utils::SkipLis void CreateSnapshot(Transaction *transaction, const std::filesystem::path &snapshot_directory, const std::filesystem::path &wal_directory, uint64_t snapshot_retention_count, utils::SkipList *vertices, utils::SkipList *edges, NameIdMapper *name_id_mapper, - Indices *indices, Constraints *constraints, Config::Items items, const std::string &uuid, - const std::string_view epoch_id, const std::deque> &epoch_history, + Indices *indices, Constraints *constraints, Config::Items items, + const SchemaValidator &schema_validator, const std::string &uuid, const std::string_view epoch_id, + const std::deque> &epoch_history, utils::FileRetainer *file_retainer) { // Ensure that the storage directory exists. utils::EnsureDirOrDie(snapshot_directory); @@ -713,8 +714,9 @@ void CreateSnapshot(Transaction *transaction, const std::filesystem::path &snaps // type and invalid from/to pointers because we don't know them here, // but that isn't an issue because we won't use that part of the API // here. - auto ea = - EdgeAccessor{edge_ref, EdgeTypeId::FromUint(0UL), nullptr, nullptr, transaction, indices, constraints, items}; + // TODO(jbajic) Fix snapshot with new schema rules + auto ea = EdgeAccessor{edge_ref, EdgeTypeId::FromUint(0UL), nullptr, nullptr, transaction, indices, constraints, + items, schema_validator}; // Get edge data. auto maybe_props = ea.Properties(View::OLD); @@ -742,7 +744,7 @@ void CreateSnapshot(Transaction *transaction, const std::filesystem::path &snaps auto acc = vertices->access(); for (auto &vertex : acc) { // The visibility check is implemented for vertices so we use it here. - auto va = VertexAccessor::Create(&vertex, transaction, indices, constraints, items, View::OLD); + auto va = VertexAccessor::Create(&vertex, transaction, indices, constraints, items, schema_validator, View::OLD); if (!va) continue; // Get vertex data. diff --git a/src/storage/v3/durability/snapshot.hpp b/src/storage/v3/durability/snapshot.hpp index 785ca7ed2..28c0edf08 100644 --- a/src/storage/v3/durability/snapshot.hpp +++ b/src/storage/v3/durability/snapshot.hpp @@ -21,6 +21,7 @@ #include "storage/v3/edge.hpp" #include "storage/v3/indices.hpp" #include "storage/v3/name_id_mapper.hpp" +#include "storage/v3/schema_validator.hpp" #include "storage/v3/transaction.hpp" #include "storage/v3/vertex.hpp" #include "utils/file_locker.hpp" @@ -68,8 +69,9 @@ RecoveredSnapshot LoadSnapshot(const std::filesystem::path &path, utils::SkipLis void CreateSnapshot(Transaction *transaction, const std::filesystem::path &snapshot_directory, const std::filesystem::path &wal_directory, uint64_t snapshot_retention_count, utils::SkipList *vertices, utils::SkipList *edges, NameIdMapper *name_id_mapper, - Indices *indices, Constraints *constraints, Config::Items items, const std::string &uuid, - std::string_view epoch_id, const std::deque> &epoch_history, + Indices *indices, Constraints *constraints, Config::Items items, + const SchemaValidator &schema_validator, const std::string &uuid, std::string_view epoch_id, + const std::deque> &epoch_history, utils::FileRetainer *file_retainer); } // namespace memgraph::storage::v3::durability diff --git a/src/storage/v3/edge_accessor.cpp b/src/storage/v3/edge_accessor.cpp index 2a1294d8e..abb5597e5 100644 --- a/src/storage/v3/edge_accessor.cpp +++ b/src/storage/v3/edge_accessor.cpp @@ -15,6 +15,7 @@ #include "storage/v3/mvcc.hpp" #include "storage/v3/property_value.hpp" +#include "storage/v3/schema_validator.hpp" #include "storage/v3/vertex_accessor.hpp" #include "utils/memory_tracker.hpp" @@ -54,11 +55,11 @@ bool EdgeAccessor::IsVisible(const View view) const { } VertexAccessor EdgeAccessor::FromVertex() const { - return VertexAccessor{from_vertex_, transaction_, indices_, constraints_, config_}; + return {from_vertex_, transaction_, indices_, constraints_, config_, *schema_validator_}; } VertexAccessor EdgeAccessor::ToVertex() const { - return VertexAccessor{to_vertex_, transaction_, indices_, constraints_, config_}; + return {to_vertex_, transaction_, indices_, constraints_, config_, *schema_validator_}; } Result EdgeAccessor::SetProperty(PropertyId property, const PropertyValue &value) { diff --git a/src/storage/v3/edge_accessor.hpp b/src/storage/v3/edge_accessor.hpp index 60497c80b..cf8b658d8 100644 --- a/src/storage/v3/edge_accessor.hpp +++ b/src/storage/v3/edge_accessor.hpp @@ -18,6 +18,7 @@ #include "storage/v3/config.hpp" #include "storage/v3/result.hpp" +#include "storage/v3/schema_validator.hpp" #include "storage/v3/transaction.hpp" #include "storage/v3/view.hpp" @@ -34,7 +35,8 @@ class EdgeAccessor final { public: EdgeAccessor(EdgeRef edge, EdgeTypeId edge_type, Vertex *from_vertex, Vertex *to_vertex, Transaction *transaction, - Indices *indices, Constraints *constraints, Config::Items config, bool for_deleted = false) + Indices *indices, Constraints *constraints, Config::Items config, + const SchemaValidator &schema_validator, bool for_deleted = false) : edge_(edge), edge_type_(edge_type), from_vertex_(from_vertex), @@ -43,6 +45,7 @@ class EdgeAccessor final { indices_(indices), constraints_(constraints), config_(config), + schema_validator_{&schema_validator}, for_deleted_(for_deleted) {} /// @return true if the object is visible from the current transaction @@ -91,6 +94,7 @@ class EdgeAccessor final { Indices *indices_; Constraints *constraints_; Config::Items config_; + const SchemaValidator *schema_validator_; // if the accessor was created for a deleted edge. // Accessor behaves differently for some methods based on this diff --git a/src/storage/v3/indices.cpp b/src/storage/v3/indices.cpp index 340484228..aab1c0fe6 100644 --- a/src/storage/v3/indices.cpp +++ b/src/storage/v3/indices.cpp @@ -10,10 +10,12 @@ // licenses/APL.txt. #include "indices.hpp" + #include #include "storage/v3/mvcc.hpp" #include "storage/v3/property_value.hpp" +#include "storage/v3/schema_validator.hpp" #include "utils/bound.hpp" #include "utils/logging.hpp" #include "utils/memory_tracker.hpp" @@ -327,7 +329,7 @@ void LabelIndex::RemoveObsoleteEntries(uint64_t oldest_active_start_timestamp) { LabelIndex::Iterable::Iterator::Iterator(Iterable *self, utils::SkipList::Iterator index_iterator) : self_(self), index_iterator_(index_iterator), - current_vertex_accessor_(nullptr, nullptr, nullptr, nullptr, self_->config_), + current_vertex_accessor_(nullptr, nullptr, nullptr, nullptr, self_->config_, *self_->schema_validator_), current_vertex_(nullptr) { AdvanceUntilValid(); } @@ -345,8 +347,8 @@ void LabelIndex::Iterable::Iterator::AdvanceUntilValid() { } if (CurrentVersionHasLabel(*index_iterator_->vertex, self_->label_, self_->transaction_, self_->view_)) { current_vertex_ = index_iterator_->vertex; - current_vertex_accessor_ = - VertexAccessor{current_vertex_, self_->transaction_, self_->indices_, self_->constraints_, self_->config_}; + current_vertex_accessor_ = VertexAccessor{current_vertex_, self_->transaction_, self_->indices_, + self_->constraints_, self_->config_, *self_->schema_validator_}; break; } } @@ -354,14 +356,15 @@ void LabelIndex::Iterable::Iterator::AdvanceUntilValid() { LabelIndex::Iterable::Iterable(utils::SkipList::Accessor index_accessor, LabelId label, View view, Transaction *transaction, Indices *indices, Constraints *constraints, - Config::Items config) + Config::Items config, const SchemaValidator &schema_validator) : index_accessor_(std::move(index_accessor)), label_(label), view_(view), transaction_(transaction), indices_(indices), constraints_(constraints), - config_(config) {} + config_(config), + schema_validator_(&schema_validator) {} void LabelIndex::RunGC() { for (auto &index_entry : index_) { @@ -478,7 +481,7 @@ void LabelPropertyIndex::RemoveObsoleteEntries(uint64_t oldest_active_start_time LabelPropertyIndex::Iterable::Iterator::Iterator(Iterable *self, utils::SkipList::Iterator index_iterator) : self_(self), index_iterator_(index_iterator), - current_vertex_accessor_(nullptr, nullptr, nullptr, nullptr, self_->config_), + current_vertex_accessor_(nullptr, nullptr, nullptr, nullptr, self_->config_, *self_->schema_validator_), current_vertex_(nullptr) { AdvanceUntilValid(); } @@ -517,8 +520,8 @@ void LabelPropertyIndex::Iterable::Iterator::AdvanceUntilValid() { if (CurrentVersionHasLabelProperty(*index_iterator_->vertex, self_->label_, self_->property_, index_iterator_->value, self_->transaction_, self_->view_)) { current_vertex_ = index_iterator_->vertex; - current_vertex_accessor_ = - VertexAccessor(current_vertex_, self_->transaction_, self_->indices_, self_->constraints_, self_->config_); + current_vertex_accessor_ = VertexAccessor(current_vertex_, self_->transaction_, self_->indices_, + self_->constraints_, self_->config_, *self_->schema_validator_); break; } } @@ -541,7 +544,7 @@ LabelPropertyIndex::Iterable::Iterable(utils::SkipList::Accessor index_ac const std::optional> &lower_bound, const std::optional> &upper_bound, View view, Transaction *transaction, Indices *indices, Constraints *constraints, - Config::Items config) + Config::Items config, const SchemaValidator &schema_validator) : index_accessor_(std::move(index_accessor)), label_(label), property_(property), @@ -551,7 +554,8 @@ LabelPropertyIndex::Iterable::Iterable(utils::SkipList::Accessor index_ac transaction_(transaction), indices_(indices), constraints_(constraints), - config_(config) { + config_(config), + schema_validator_(&schema_validator) { // We have to fix the bounds that the user provided to us. If the user // provided only one bound we should make sure that only values of that type // are returned by the iterator. We ensure this by supplying either an diff --git a/src/storage/v3/indices.hpp b/src/storage/v3/indices.hpp index fe9e83b88..cdc5df630 100644 --- a/src/storage/v3/indices.hpp +++ b/src/storage/v3/indices.hpp @@ -11,12 +11,14 @@ #pragma once +#include #include #include #include #include "storage/v3/config.hpp" #include "storage/v3/property_value.hpp" +#include "storage/v3/schema_validator.hpp" #include "storage/v3/transaction.hpp" #include "storage/v3/vertex_accessor.hpp" #include "utils/bound.hpp" @@ -51,8 +53,8 @@ class LabelIndex { }; public: - LabelIndex(Indices *indices, Constraints *constraints, Config::Items config) - : indices_(indices), constraints_(constraints), config_(config) {} + LabelIndex(Indices *indices, Constraints *constraints, Config::Items config, const SchemaValidator &schema_validator) + : indices_(indices), constraints_(constraints), config_(config), schema_validator_{&schema_validator} {} /// @throw std::bad_alloc void UpdateOnAddLabel(LabelId label, Vertex *vertex, const Transaction &tx); @@ -72,7 +74,7 @@ class LabelIndex { class Iterable { public: Iterable(utils::SkipList::Accessor index_accessor, LabelId label, View view, Transaction *transaction, - Indices *indices, Constraints *constraints, Config::Items config); + Indices *indices, Constraints *constraints, Config::Items config, const SchemaValidator &schema_validator); class Iterator { public: @@ -105,13 +107,14 @@ class LabelIndex { Indices *indices_; Constraints *constraints_; Config::Items config_; + const SchemaValidator *schema_validator_; }; /// Returns an self with vertices visible from the given transaction. Iterable Vertices(LabelId label, View view, Transaction *transaction) { auto it = index_.find(label); MG_ASSERT(it != index_.end(), "Index for label {} doesn't exist", label.AsUint()); - return {it->second.access(), label, view, transaction, indices_, constraints_, config_}; + return {it->second.access(), label, view, transaction, indices_, constraints_, config_, *schema_validator_}; } int64_t ApproximateVertexCount(LabelId label) { @@ -129,6 +132,7 @@ class LabelIndex { Indices *indices_; Constraints *constraints_; Config::Items config_; + const SchemaValidator *schema_validator_; }; class LabelPropertyIndex { @@ -146,8 +150,9 @@ class LabelPropertyIndex { }; public: - LabelPropertyIndex(Indices *indices, Constraints *constraints, Config::Items config) - : indices_(indices), constraints_(constraints), config_(config) {} + LabelPropertyIndex(Indices *indices, Constraints *constraints, Config::Items config, + const SchemaValidator &schema_validator) + : indices_(indices), constraints_(constraints), config_(config), schema_validator_{&schema_validator} {} /// @throw std::bad_alloc void UpdateOnAddLabel(LabelId label, Vertex *vertex, const Transaction &tx); @@ -171,7 +176,7 @@ class LabelPropertyIndex { Iterable(utils::SkipList::Accessor index_accessor, LabelId label, PropertyId property, const std::optional> &lower_bound, const std::optional> &upper_bound, View view, Transaction *transaction, - Indices *indices, Constraints *constraints, Config::Items config); + Indices *indices, Constraints *constraints, Config::Items config, const SchemaValidator &schema_validator); class Iterator { public: @@ -208,16 +213,17 @@ class LabelPropertyIndex { Indices *indices_; Constraints *constraints_; Config::Items config_; + const SchemaValidator *schema_validator_; }; Iterable Vertices(LabelId label, PropertyId property, const std::optional> &lower_bound, - const std::optional> &upper_bound, View view, - Transaction *transaction) { + const std::optional> &upper_bound, View view, Transaction *transaction, + const SchemaValidator &schema_validator_) { auto it = index_.find({label, property}); MG_ASSERT(it != index_.end(), "Index for label {} and property {} doesn't exist", label.AsUint(), property.AsUint()); - return {it->second.access(), label, property, lower_bound, upper_bound, view, - transaction, indices_, constraints_, config_}; + return {it->second.access(), label, property, lower_bound, upper_bound, view, + transaction, indices_, constraints_, config_, schema_validator_}; } int64_t ApproximateVertexCount(LabelId label, PropertyId property) const { @@ -246,11 +252,13 @@ class LabelPropertyIndex { Indices *indices_; Constraints *constraints_; Config::Items config_; + const SchemaValidator *schema_validator_; }; struct Indices { - Indices(Constraints *constraints, Config::Items config) - : label_index(this, constraints, config), label_property_index(this, constraints, config) {} + Indices(Constraints *constraints, Config::Items config, const SchemaValidator &schema_validator) + : label_index(this, constraints, config, schema_validator), + label_property_index(this, constraints, config, schema_validator) {} // Disable copy and move because members hold pointer to `this`. Indices(const Indices &) = delete; diff --git a/src/storage/v3/replication/replication_server.cpp b/src/storage/v3/replication/replication_server.cpp index 0568598e1..4b9b38d77 100644 --- a/src/storage/v3/replication/replication_server.cpp +++ b/src/storage/v3/replication/replication_server.cpp @@ -10,6 +10,7 @@ // licenses/APL.txt. #include "storage/v3/replication/replication_server.hpp" + #include #include @@ -162,9 +163,10 @@ void Storage::ReplicationServer::SnapshotHandler(slk::Reader *req_reader, slk::B storage_->edges_.clear(); storage_->constraints_ = Constraints(); - storage_->indices_.label_index = LabelIndex(&storage_->indices_, &storage_->constraints_, storage_->config_.items); - storage_->indices_.label_property_index = - LabelPropertyIndex(&storage_->indices_, &storage_->constraints_, storage_->config_.items); + storage_->indices_.label_index = + LabelIndex(&storage_->indices_, &storage_->constraints_, storage_->config_.items, storage_->schema_validator_); + storage_->indices_.label_property_index = LabelPropertyIndex(&storage_->indices_, &storage_->constraints_, + storage_->config_.items, storage_->schema_validator_); try { spdlog::debug("Loading snapshot"); auto recovered_snapshot = durability::LoadSnapshot(*maybe_snapshot_path, &storage_->vertices_, &storage_->edges_, @@ -461,7 +463,8 @@ uint64_t Storage::ReplicationServer::ReadAndApplyDelta(durability::BaseDecoder * &transaction->transaction_, &storage_->indices_, &storage_->constraints_, - storage_->config_.items}; + storage_->config_.items, + storage_->schema_validator_}; auto ret = ea.SetProperty(transaction->NameToProperty(delta.vertex_edge_set_property.property), delta.vertex_edge_set_property.value); diff --git a/src/storage/v2/schema_validator.cpp b/src/storage/v3/schema_validator.cpp similarity index 96% rename from src/storage/v2/schema_validator.cpp rename to src/storage/v3/schema_validator.cpp index 4c3689a9f..4aa466f7f 100644 --- a/src/storage/v2/schema_validator.cpp +++ b/src/storage/v3/schema_validator.cpp @@ -9,15 +9,15 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. -#include "storage/v2/schema_validator.hpp" +#include "storage/v3/schema_validator.hpp" #include #include #include -#include "storage/v2/schemas.hpp" +#include "storage/v3/schemas.hpp" -namespace memgraph::storage { +namespace memgraph::storage::v3 { bool operator==(const SchemaViolation &lhs, const SchemaViolation &rhs) { return lhs.status == rhs.status && lhs.label == rhs.label && @@ -103,4 +103,4 @@ SchemaValidator::SchemaValidator(Schemas &schemas) : schemas_{schemas} {} return std::nullopt; } -} // namespace memgraph::storage +} // namespace memgraph::storage::v3 diff --git a/src/storage/v2/schema_validator.hpp b/src/storage/v3/schema_validator.hpp similarity index 90% rename from src/storage/v2/schema_validator.hpp rename to src/storage/v3/schema_validator.hpp index 6ad260138..a2da4609c 100644 --- a/src/storage/v2/schema_validator.hpp +++ b/src/storage/v3/schema_validator.hpp @@ -14,12 +14,12 @@ #include #include -#include "storage/v2/id_types.hpp" -#include "storage/v2/property_value.hpp" -#include "storage/v2/result.hpp" -#include "storage/v2/schemas.hpp" +#include "storage/v3/id_types.hpp" +#include "storage/v3/property_value.hpp" +#include "storage/v3/result.hpp" +#include "storage/v3/schemas.hpp" -namespace memgraph::storage { +namespace memgraph::storage::v3 { struct SchemaViolation { enum class ValidationStatus : uint8_t { @@ -60,10 +60,10 @@ class SchemaValidator { [[nodiscard]] std::optional ValidateLabelUpdate(LabelId label) const; private: - storage::Schemas &schemas_; + Schemas &schemas_; }; template using ResultSchema = utils::BasicResult, TValue>; -} // namespace memgraph::storage +} // namespace memgraph::storage::v3 diff --git a/src/storage/v2/schemas.cpp b/src/storage/v3/schemas.cpp similarity index 95% rename from src/storage/v2/schemas.cpp rename to src/storage/v3/schemas.cpp index 167d2946f..2f89c80c0 100644 --- a/src/storage/v2/schemas.cpp +++ b/src/storage/v3/schemas.cpp @@ -9,14 +9,14 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. -#include "storage/v2/schemas.hpp" +#include "storage/v3/schemas.hpp" #include #include -#include "storage/v2/property_value.hpp" +#include "storage/v3/property_value.hpp" -namespace memgraph::storage { +namespace memgraph::storage::v3 { bool operator==(const SchemaProperty &lhs, const SchemaProperty &rhs) { return lhs.property_id == rhs.property_id && lhs.type == rhs.type; @@ -109,4 +109,4 @@ std::string SchemaTypeToString(const common::SchemaType type) { } } -} // namespace memgraph::storage +} // namespace memgraph::storage::v3 diff --git a/src/storage/v2/schemas.hpp b/src/storage/v3/schemas.hpp similarity index 91% rename from src/storage/v2/schemas.hpp rename to src/storage/v3/schemas.hpp index c248c5b12..157ee7a35 100644 --- a/src/storage/v2/schemas.hpp +++ b/src/storage/v3/schemas.hpp @@ -18,12 +18,12 @@ #include #include "common/types.hpp" -#include "storage/v2/id_types.hpp" -#include "storage/v2/property_value.hpp" -#include "storage/v2/temporal.hpp" +#include "storage/v3/id_types.hpp" +#include "storage/v3/property_value.hpp" +#include "storage/v3/temporal.hpp" #include "utils/result.hpp" -namespace memgraph::storage { +namespace memgraph::storage::v3 { struct SchemaProperty { PropertyId property_id; @@ -67,4 +67,4 @@ std::optional PropertyTypeToSchemaType(const PropertyValue & std::string SchemaTypeToString(common::SchemaType type); -} // namespace memgraph::storage +} // namespace memgraph::storage::v3 diff --git a/src/storage/v3/storage.cpp b/src/storage/v3/storage.cpp index 5ab334893..28fd250fb 100644 --- a/src/storage/v3/storage.cpp +++ b/src/storage/v3/storage.cpp @@ -15,9 +15,11 @@ #include #include #include +#include #include #include +#include #include "io/network/endpoint.hpp" #include "storage/v3/constraints.hpp" @@ -27,6 +29,7 @@ #include "storage/v3/durability/snapshot.hpp" #include "storage/v3/durability/wal.hpp" #include "storage/v3/edge_accessor.hpp" +#include "storage/v3/id_types.hpp" #include "storage/v3/indices.hpp" #include "storage/v3/mvcc.hpp" #include "storage/v3/replication/config.hpp" @@ -35,10 +38,12 @@ #include "storage/v3/replication/rpc.hpp" #include "storage/v3/transaction.hpp" #include "storage/v3/vertex_accessor.hpp" +#include "utils/exceptions.hpp" #include "utils/file.hpp" #include "utils/logging.hpp" #include "utils/memory_tracker.hpp" #include "utils/message.hpp" +#include "utils/result.hpp" #include "utils/rw_lock.hpp" #include "utils/spin_lock.hpp" #include "utils/stat.hpp" @@ -54,9 +59,9 @@ inline constexpr uint16_t kEpochHistoryRetention = 1000; auto AdvanceToVisibleVertex(utils::SkipList::Iterator it, utils::SkipList::Iterator end, std::optional *vertex, Transaction *tx, View view, Indices *indices, - Constraints *constraints, Config::Items config) { + Constraints *constraints, Config::Items config, const SchemaValidator &schema_validator) { while (it != end) { - *vertex = VertexAccessor::Create(&*it, tx, indices, constraints, config, view); + *vertex = VertexAccessor::Create(&*it, tx, indices, constraints, config, schema_validator, view); if (!*vertex) { ++it; continue; @@ -69,14 +74,14 @@ auto AdvanceToVisibleVertex(utils::SkipList::Iterator it, utils::SkipLis AllVerticesIterable::Iterator::Iterator(AllVerticesIterable *self, utils::SkipList::Iterator it) : self_(self), it_(AdvanceToVisibleVertex(it, self->vertices_accessor_.end(), &self->vertex_, self->transaction_, self->view_, - self->indices_, self_->constraints_, self->config_)) {} + self->indices_, self_->constraints_, self->config_, *self->schema_validator_)) {} VertexAccessor AllVerticesIterable::Iterator::operator*() const { return *self_->vertex_; } AllVerticesIterable::Iterator &AllVerticesIterable::Iterator::operator++() { ++it_; it_ = AdvanceToVisibleVertex(it_, self_->vertices_accessor_.end(), &self_->vertex_, self_->transaction_, self_->view_, - self_->indices_, self_->constraints_, self_->config_); + self_->indices_, self_->constraints_, self_->config_, *self_->schema_validator_); return *this; } @@ -300,7 +305,8 @@ bool VerticesIterable::Iterator::operator==(const Iterator &other) const { } Storage::Storage(Config config) - : indices_(&constraints_, config.items), + : schema_validator_(schemas_), + indices_(&constraints_, config.items, schema_validator_), isolation_level_(config.transaction.isolation_level), config_(config), snapshot_directory_(config_.durability.storage_directory / durability::kSnapshotDirectory), @@ -462,7 +468,8 @@ Storage::Accessor::~Accessor() { FinalizeTransaction(); } -VertexAccessor Storage::Accessor::CreateVertex() { +// TODO Remove when import csv is fixed +[[deprecated]] VertexAccessor Storage::Accessor::CreateVertex() { OOMExceptionEnabler oom_exception; auto gid = storage_->vertex_id_.fetch_add(1, std::memory_order_acq_rel); auto acc = storage_->vertices_.access(); @@ -470,10 +477,12 @@ VertexAccessor Storage::Accessor::CreateVertex() { auto [it, inserted] = acc.insert(Vertex{Gid::FromUint(gid), delta}); MG_ASSERT(inserted, "The vertex must be inserted here!"); MG_ASSERT(it != acc.end(), "Invalid Vertex accessor!"); + delta->prev.Set(&*it); - return {&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_}; + return {&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, storage_->schema_validator_}; } +// TODO Remove when replication is fixed VertexAccessor Storage::Accessor::CreateVertex(Gid gid) { OOMExceptionEnabler oom_exception; // NOTE: When we update the next `vertex_id_` here we perform a RMW @@ -486,18 +495,53 @@ VertexAccessor Storage::Accessor::CreateVertex(Gid gid) { std::memory_order_release); auto acc = storage_->vertices_.access(); auto *delta = CreateDeleteObjectDelta(&transaction_); - auto [it, inserted] = acc.insert(Vertex{gid, delta}); + auto [it, inserted] = acc.insert(Vertex{gid}); MG_ASSERT(inserted, "The vertex must be inserted here!"); MG_ASSERT(it != acc.end(), "Invalid Vertex accessor!"); delta->prev.Set(&*it); - return {&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_}; + return {&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, storage_->schema_validator_}; +} + +ResultSchema Storage::Accessor::CreateVertexAndValidate( + LabelId primary_label, const std::vector &labels, + const std::vector> &properties) { + auto maybe_schema_violation = GetSchemaValidator().ValidateVertexCreate(primary_label, labels, properties); + if (maybe_schema_violation) { + return {std::move(*maybe_schema_violation)}; + } + OOMExceptionEnabler oom_exception; + auto gid = storage_->vertex_id_.fetch_add(1, std::memory_order_acq_rel); + auto acc = storage_->vertices_.access(); + auto *delta = CreateDeleteObjectDelta(&transaction_); + auto [it, inserted] = acc.insert(Vertex{Gid::FromUint(gid), delta, primary_label}); + MG_ASSERT(inserted, "The vertex must be inserted here!"); + MG_ASSERT(it != acc.end(), "Invalid Vertex accessor!"); + delta->prev.Set(&*it); + + auto va = VertexAccessor{ + &*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, storage_->schema_validator_}; + for (const auto label : labels) { + const auto maybe_error = va.AddLabel(label); + if (maybe_error.HasError()) { + return {maybe_error.GetError()}; + } + } + // Set properties + for (auto [property_id, property_value] : properties) { + const auto maybe_error = va.SetProperty(property_id, property_value); + if (maybe_error.HasError()) { + return {maybe_error.GetError()}; + } + } + return va; } std::optional Storage::Accessor::FindVertex(Gid gid, View view) { auto acc = storage_->vertices_.access(); auto it = acc.find(gid); if (it == acc.end()) return std::nullopt; - return VertexAccessor::Create(&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, view); + return VertexAccessor::Create(&*it, &transaction_, &storage_->indices_, &storage_->constraints_, config_, + storage_->schema_validator_, view); } Result> Storage::Accessor::DeleteVertex(VertexAccessor *vertex) { @@ -520,7 +564,7 @@ Result> Storage::Accessor::DeleteVertex(VertexAcce vertex_ptr->deleted = true; return std::make_optional(vertex_ptr, &transaction_, &storage_->indices_, &storage_->constraints_, - config_, true); + config_, storage_->schema_validator_, true); } Result>>> Storage::Accessor::DetachDeleteVertex( @@ -550,7 +594,7 @@ Result>>> Stor for (const auto &item : in_edges) { auto [edge_type, from_vertex, edge] = item; EdgeAccessor e(edge, edge_type, from_vertex, vertex_ptr, &transaction_, &storage_->indices_, - &storage_->constraints_, config_); + &storage_->constraints_, config_, storage_->schema_validator_); auto ret = DeleteEdge(&e); if (ret.HasError()) { MG_ASSERT(ret.GetError() == Error::SERIALIZATION_ERROR, "Invalid database state!"); @@ -564,7 +608,7 @@ Result>>> Stor for (const auto &item : out_edges) { auto [edge_type, to_vertex, edge] = item; EdgeAccessor e(edge, edge_type, vertex_ptr, to_vertex, &transaction_, &storage_->indices_, &storage_->constraints_, - config_); + config_, storage_->schema_validator_); auto ret = DeleteEdge(&e); if (ret.HasError()) { MG_ASSERT(ret.GetError() == Error::SERIALIZATION_ERROR, "Invalid database state!"); @@ -590,7 +634,8 @@ Result>>> Stor vertex_ptr->deleted = true; return std::make_optional( - VertexAccessor{vertex_ptr, &transaction_, &storage_->indices_, &storage_->constraints_, config_, true}, + VertexAccessor{vertex_ptr, &transaction_, &storage_->indices_, &storage_->constraints_, config_, + storage_->schema_validator_, true}, std::move(deleted_edges)); } @@ -650,7 +695,7 @@ Result Storage::Accessor::CreateEdge(VertexAccessor *from, VertexA storage_->edge_count_.fetch_add(1, std::memory_order_acq_rel); return EdgeAccessor(edge, edge_type, from_vertex, to_vertex, &transaction_, &storage_->indices_, - &storage_->constraints_, config_); + &storage_->constraints_, config_, storage_->schema_validator_); } Result Storage::Accessor::CreateEdge(VertexAccessor *from, VertexAccessor *to, EdgeTypeId edge_type, @@ -718,7 +763,7 @@ Result Storage::Accessor::CreateEdge(VertexAccessor *from, VertexA storage_->edge_count_.fetch_add(1, std::memory_order_acq_rel); return EdgeAccessor(edge, edge_type, from_vertex, to_vertex, &transaction_, &storage_->indices_, - &storage_->constraints_, config_); + &storage_->constraints_, config_, storage_->schema_validator_); } Result> Storage::Accessor::DeleteEdge(EdgeAccessor *edge) { @@ -802,7 +847,8 @@ Result> Storage::Accessor::DeleteEdge(EdgeAccessor * storage_->edge_count_.fetch_add(-1, std::memory_order_acq_rel); return std::make_optional(edge_ref, edge_type, from_vertex, to_vertex, &transaction_, - &storage_->indices_, &storage_->constraints_, config_, true); + &storage_->indices_, &storage_->constraints_, config_, + storage_->schema_validator_, true); } const std::string &Storage::Accessor::LabelToName(LabelId label) const { return storage_->LabelToName(label); } @@ -815,11 +861,11 @@ const std::string &Storage::Accessor::EdgeTypeToName(EdgeTypeId edge_type) const return storage_->EdgeTypeToName(edge_type); } -LabelId Storage::Accessor::NameToLabel(const std::string_view &name) { return storage_->NameToLabel(name); } +LabelId Storage::Accessor::NameToLabel(const std::string_view name) { return storage_->NameToLabel(name); } -PropertyId Storage::Accessor::NameToProperty(const std::string_view &name) { return storage_->NameToProperty(name); } +PropertyId Storage::Accessor::NameToProperty(const std::string_view name) { return storage_->NameToProperty(name); } -EdgeTypeId Storage::Accessor::NameToEdgeType(const std::string_view &name) { return storage_->NameToEdgeType(name); } +EdgeTypeId Storage::Accessor::NameToEdgeType(const std::string_view name) { return storage_->NameToEdgeType(name); } void Storage::Accessor::AdvanceCommand() { ++transaction_.command_id; } @@ -846,11 +892,11 @@ utils::BasicResult Storage::Accessor::Commit( auto validation_result = ValidateExistenceConstraints(*prev.vertex, storage_->constraints_); if (validation_result) { Abort(); - return *validation_result; + return {*validation_result}; } } - // Result of validating the vertex against unqiue constraints. It has to be + // Result of validating the vertex against unique constraints. It has to be // declared outside of the critical section scope because its value is // tested for Abort call which has to be done out of the scope. std::optional unique_constraint_violation; @@ -931,7 +977,7 @@ utils::BasicResult Storage::Accessor::Commit( if (unique_constraint_violation) { Abort(); - return *unique_constraint_violation; + return {*unique_constraint_violation}; } } is_transaction_active_ = false; @@ -1124,13 +1170,13 @@ const std::string &Storage::EdgeTypeToName(EdgeTypeId edge_type) const { return name_id_mapper_.IdToName(edge_type.AsUint()); } -LabelId Storage::NameToLabel(const std::string_view &name) { return LabelId::FromUint(name_id_mapper_.NameToId(name)); } +LabelId Storage::NameToLabel(const std::string_view name) { return LabelId::FromUint(name_id_mapper_.NameToId(name)); } -PropertyId Storage::NameToProperty(const std::string_view &name) { +PropertyId Storage::NameToProperty(const std::string_view name) { return PropertyId::FromUint(name_id_mapper_.NameToId(name)); } -EdgeTypeId Storage::NameToEdgeType(const std::string_view &name) { +EdgeTypeId Storage::NameToEdgeType(const std::string_view name) { return EdgeTypeId::FromUint(name_id_mapper_.NameToId(name)); } @@ -1232,11 +1278,29 @@ UniqueConstraints::DeletionStatus Storage::DropUniqueConstraint( return UniqueConstraints::DeletionStatus::SUCCESS; } +const SchemaValidator &Storage::Accessor::GetSchemaValidator() const { return storage_->schema_validator_; } + ConstraintsInfo Storage::ListAllConstraints() const { std::shared_lock storage_guard_(main_lock_); return {ListExistenceConstraints(constraints_), constraints_.unique_constraints.ListConstraints()}; } +SchemasInfo Storage::ListAllSchemas() const { + std::shared_lock storage_guard_(main_lock_); + return {schemas_.ListSchemas()}; +} + +const Schemas::Schema *Storage::GetSchema(const LabelId primary_label) const { + std::shared_lock storage_guard_(main_lock_); + return schemas_.GetSchema(primary_label); +} + +bool Storage::CreateSchema(const LabelId primary_label, const std::vector &schemas_types) { + return schemas_.CreateSchema(primary_label, schemas_types); +} + +bool Storage::DropSchema(const LabelId primary_label) { return schemas_.DropSchema(primary_label); } + StorageInfo Storage::GetInfo() const { auto vertex_count = vertices_.size(); auto edge_count = edge_count_.load(std::memory_order_acquire); @@ -1253,21 +1317,22 @@ VerticesIterable Storage::Accessor::Vertices(LabelId label, View view) { } VerticesIterable Storage::Accessor::Vertices(LabelId label, PropertyId property, View view) { - return VerticesIterable(storage_->indices_.label_property_index.Vertices(label, property, std::nullopt, std::nullopt, - view, &transaction_)); + return VerticesIterable(storage_->indices_.label_property_index.Vertices( + label, property, std::nullopt, std::nullopt, view, &transaction_, storage_->schema_validator_)); } VerticesIterable Storage::Accessor::Vertices(LabelId label, PropertyId property, const PropertyValue &value, View view) { return VerticesIterable(storage_->indices_.label_property_index.Vertices( - label, property, utils::MakeBoundInclusive(value), utils::MakeBoundInclusive(value), view, &transaction_)); + label, property, utils::MakeBoundInclusive(value), utils::MakeBoundInclusive(value), view, &transaction_, + storage_->schema_validator_)); } VerticesIterable Storage::Accessor::Vertices(LabelId label, PropertyId property, const std::optional> &lower_bound, const std::optional> &upper_bound, View view) { - return VerticesIterable( - storage_->indices_.label_property_index.Vertices(label, property, lower_bound, upper_bound, view, &transaction_)); + return VerticesIterable(storage_->indices_.label_property_index.Vertices( + label, property, lower_bound, upper_bound, view, &transaction_, storage_->schema_validator_)); } Transaction Storage::CreateTransaction(IsolationLevel isolation_level) { @@ -1795,8 +1860,8 @@ utils::BasicResult Storage::CreateSnapshot() { // Create snapshot. durability::CreateSnapshot(&transaction, snapshot_directory_, wal_directory_, config_.durability.snapshot_retention_count, &vertices_, &edges_, &name_id_mapper_, - &indices_, &constraints_, config_.items, uuid_, epoch_id_, epoch_history_, - &file_retainer_); + &indices_, &constraints_, config_.items, schema_validator_, uuid_, epoch_id_, + epoch_history_, &file_retainer_); // Finalize snapshot transaction. commit_log_->MarkFinished(transaction.start_timestamp); diff --git a/src/storage/v3/storage.hpp b/src/storage/v3/storage.hpp index 40fe708b1..93a6687ce 100644 --- a/src/storage/v3/storage.hpp +++ b/src/storage/v3/storage.hpp @@ -16,8 +16,10 @@ #include #include #include +#include #include "io/network/endpoint.hpp" +#include "kvstore/kvstore.hpp" #include "storage/v3/commit_log.hpp" #include "storage/v3/config.hpp" #include "storage/v3/constraints.hpp" @@ -25,14 +27,19 @@ #include "storage/v3/durability/wal.hpp" #include "storage/v3/edge.hpp" #include "storage/v3/edge_accessor.hpp" +#include "storage/v3/id_types.hpp" #include "storage/v3/indices.hpp" #include "storage/v3/isolation_level.hpp" #include "storage/v3/mvcc.hpp" #include "storage/v3/name_id_mapper.hpp" +#include "storage/v3/property_value.hpp" #include "storage/v3/result.hpp" +#include "storage/v3/schema_validator.hpp" +#include "storage/v3/schemas.hpp" #include "storage/v3/transaction.hpp" #include "storage/v3/vertex.hpp" #include "storage/v3/vertex_accessor.hpp" +#include "utils/exceptions.hpp" #include "utils/file_locker.hpp" #include "utils/on_scope_exit.hpp" #include "utils/rw_lock.hpp" @@ -66,6 +73,7 @@ class AllVerticesIterable final { Indices *indices_; Constraints *constraints_; Config::Items config_; + const SchemaValidator *schema_validator_; std::optional vertex_; public: @@ -86,13 +94,15 @@ class AllVerticesIterable final { }; AllVerticesIterable(utils::SkipList::Accessor vertices_accessor, Transaction *transaction, View view, - Indices *indices, Constraints *constraints, Config::Items config) + Indices *indices, Constraints *constraints, Config::Items config, + SchemaValidator *schema_validator) : vertices_accessor_(std::move(vertices_accessor)), transaction_(transaction), view_(view), indices_(indices), constraints_(constraints), - config_(config) {} + config_(config), + schema_validator_(schema_validator) {} Iterator begin() { return {this, vertices_accessor_.begin()}; } Iterator end() { return {this, vertices_accessor_.end()}; } @@ -173,6 +183,11 @@ struct ConstraintsInfo { std::vector>> unique; }; +/// Structure used to return information about existing schemas in the storage +struct SchemasInfo { + Schemas::SchemasList schemas; +}; + /// Structure used to return information about the storage. struct StorageInfo { uint64_t vertex_count; @@ -190,10 +205,6 @@ class Storage final { /// @throw std::bad_alloc explicit Storage(Config config = Config()); - Storage(const Storage &) = delete; - Storage(Storage &&) = delete; - Storage &operator=(const Storage &) = delete; - Storage &operator=(Storage &&) = delete; ~Storage(); class Accessor final { @@ -213,15 +224,21 @@ class Storage final { ~Accessor(); - /// @throw std::bad_alloc VertexAccessor CreateVertex(); + VertexAccessor CreateVertex(Gid gid); + + /// @throw std::bad_alloc + ResultSchema CreateVertexAndValidate( + LabelId primary_label, const std::vector &labels, + const std::vector> &properties); + std::optional FindVertex(Gid gid, View view); VerticesIterable Vertices(View view) { return VerticesIterable(AllVerticesIterable(storage_->vertices_.access(), &transaction_, view, - &storage_->indices_, &storage_->constraints_, - storage_->config_.items)); + &storage_->indices_, &storage_->constraints_, storage_->config_.items, + &storage_->schema_validator_)); } VerticesIterable Vertices(LabelId label, View view); @@ -287,13 +304,13 @@ class Storage final { const std::string &EdgeTypeToName(EdgeTypeId edge_type) const; /// @throw std::bad_alloc if unable to insert a new mapping - LabelId NameToLabel(const std::string_view &name); + LabelId NameToLabel(std::string_view name); /// @throw std::bad_alloc if unable to insert a new mapping - PropertyId NameToProperty(const std::string_view &name); + PropertyId NameToProperty(std::string_view name); /// @throw std::bad_alloc if unable to insert a new mapping - EdgeTypeId NameToEdgeType(const std::string_view &name); + EdgeTypeId NameToEdgeType(std::string_view name); bool LabelIndexExists(LabelId label) const { return storage_->indices_.label_index.IndexExists(label); } @@ -310,6 +327,10 @@ class Storage final { storage_->constraints_.unique_constraints.ListConstraints()}; } + const SchemaValidator &GetSchemaValidator() const; + + SchemasInfo ListAllSchemas() const { return {storage_->schemas_.ListSchemas()}; } + void AdvanceCommand(); /// Commit returns `ConstraintViolation` if the changes made by this @@ -325,7 +346,7 @@ class Storage final { private: /// @throw std::bad_alloc - VertexAccessor CreateVertex(Gid gid); + VertexAccessor CreateVertex(Gid gid, LabelId primary_label); /// @throw std::bad_alloc Result CreateEdge(VertexAccessor *from, VertexAccessor *to, EdgeTypeId edge_type, Gid gid); @@ -347,13 +368,13 @@ class Storage final { const std::string &EdgeTypeToName(EdgeTypeId edge_type) const; /// @throw std::bad_alloc if unable to insert a new mapping - LabelId NameToLabel(const std::string_view &name); + LabelId NameToLabel(std::string_view name); /// @throw std::bad_alloc if unable to insert a new mapping - PropertyId NameToProperty(const std::string_view &name); + PropertyId NameToProperty(std::string_view name); /// @throw std::bad_alloc if unable to insert a new mapping - EdgeTypeId NameToEdgeType(const std::string_view &name); + EdgeTypeId NameToEdgeType(std::string_view name); /// @throw std::bad_alloc bool CreateIndex(LabelId label, std::optional desired_commit_timestamp = {}); @@ -368,7 +389,7 @@ class Storage final { IndicesInfo ListAllIndices() const; /// Creates an existence constraint. Returns true if the constraint was - /// successfuly added, false if it already exists and a `ConstraintViolation` + /// successfully added, false if it already exists and a `ConstraintViolation` /// if there is an existing vertex violating the constraint. /// /// @throw std::bad_alloc @@ -406,6 +427,14 @@ class Storage final { ConstraintsInfo ListAllConstraints() const; + SchemasInfo ListAllSchemas() const; + + const Schemas::Schema *GetSchema(LabelId primary_label) const; + + bool CreateSchema(LabelId primary_label, const std::vector &schemas_types); + + bool DropSchema(LabelId primary_label); + StorageInfo GetInfo() const; bool LockPath(); @@ -415,7 +444,12 @@ class Storage final { bool SetMainReplicationRole(); - enum class RegisterReplicaError : uint8_t { NAME_EXISTS, END_POINT_EXISTS, CONNECTION_FAILED }; + enum class RegisterReplicaError : uint8_t { + NAME_EXISTS, + END_POINT_EXISTS, + CONNECTION_FAILED, + COULD_NOT_BE_PERSISTED + }; /// @pre The instance should have a MAIN role /// @pre Timeout can only be set for SYNC replication @@ -493,8 +527,10 @@ class Storage final { NameIdMapper name_id_mapper_; + SchemaValidator schema_validator_; Constraints constraints_; Indices indices_; + Schemas schemas_; // Transaction engine utils::SpinLock engine_lock_; diff --git a/src/storage/v3/vertex.hpp b/src/storage/v3/vertex.hpp index 3bf1bbb49..618d23849 100644 --- a/src/storage/v3/vertex.hpp +++ b/src/storage/v3/vertex.hpp @@ -19,18 +19,39 @@ #include "storage/v3/edge_ref.hpp" #include "storage/v3/id_types.hpp" #include "storage/v3/property_store.hpp" +#include "utils/algorithm.hpp" #include "utils/spin_lock.hpp" namespace memgraph::storage::v3 { struct Vertex { + Vertex(Gid gid, Delta *delta, LabelId primary_label) + : gid(gid), primary_label{primary_label}, deleted(false), delta(delta) { + MG_ASSERT(delta == nullptr || delta->action == Delta::Action::DELETE_OBJECT, + "Vertex must be created with an initial DELETE_OBJECT delta!"); + } + + // TODO remove this when import replication is solved + Vertex(Gid gid, LabelId primary_label) : gid(gid), primary_label{primary_label}, deleted(false) { + MG_ASSERT(delta == nullptr || delta->action == Delta::Action::DELETE_OBJECT, + "Vertex must be created with an initial DELETE_OBJECT delta!"); + } + + // TODO remove this when import csv is solved Vertex(Gid gid, Delta *delta) : gid(gid), deleted(false), delta(delta) { MG_ASSERT(delta == nullptr || delta->action == Delta::Action::DELETE_OBJECT, "Vertex must be created with an initial DELETE_OBJECT delta!"); } + // TODO remove this when import replication is solved + explicit Vertex(Gid gid) : gid(gid), deleted(false) { + MG_ASSERT(delta == nullptr || delta->action == Delta::Action::DELETE_OBJECT, + "Vertex must be created with an initial DELETE_OBJECT delta!"); + } + Gid gid; + LabelId primary_label; std::vector labels; PropertyStore properties; @@ -52,4 +73,8 @@ inline bool operator<(const Vertex &first, const Vertex &second) { return first. inline bool operator==(const Vertex &first, const Gid &second) { return first.gid == second; } inline bool operator<(const Vertex &first, const Gid &second) { return first.gid < second; } +inline bool VertexHasLabel(const Vertex &vertex, const LabelId label) { + return vertex.primary_label == label || utils::Contains(vertex.labels, label); +} + } // namespace memgraph::storage::v3 diff --git a/src/storage/v3/vertex_accessor.cpp b/src/storage/v3/vertex_accessor.cpp index f9f60fa5e..867537d47 100644 --- a/src/storage/v3/vertex_accessor.cpp +++ b/src/storage/v3/vertex_accessor.cpp @@ -18,6 +18,8 @@ #include "storage/v3/indices.hpp" #include "storage/v3/mvcc.hpp" #include "storage/v3/property_value.hpp" +#include "storage/v3/schema_validator.hpp" +#include "storage/v3/vertex.hpp" #include "utils/logging.hpp" #include "utils/memory_tracker.hpp" @@ -61,12 +63,13 @@ std::pair IsVisible(Vertex *vertex, Transaction *transaction, View v } // namespace detail std::optional VertexAccessor::Create(Vertex *vertex, Transaction *transaction, Indices *indices, - Constraints *constraints, Config::Items config, View view) { + Constraints *constraints, Config::Items config, + const SchemaValidator &schema_validator, View view) { if (const auto [exists, deleted] = detail::IsVisible(vertex, transaction, view); !exists || deleted) { return std::nullopt; } - return VertexAccessor{vertex, transaction, indices, constraints, config}; + return VertexAccessor{vertex, transaction, indices, constraints, config, schema_validator}; } bool VertexAccessor::IsVisible(View view) const { @@ -93,6 +96,28 @@ Result VertexAccessor::AddLabel(LabelId label) { return true; } +ResultSchema VertexAccessor::AddLabelAndValidate(LabelId label) { + if (const auto maybe_violation_error = vertex_validator_.ValidateAddLabel(label); maybe_violation_error) { + return {*maybe_violation_error}; + } + utils::MemoryTracker::OutOfMemoryExceptionEnabler oom_exception; + std::lock_guard guard(vertex_->lock); + + if (!PrepareForWrite(transaction_, vertex_)) return {Error::SERIALIZATION_ERROR}; + + if (vertex_->deleted) return {Error::DELETED_OBJECT}; + + if (std::find(vertex_->labels.begin(), vertex_->labels.end(), label) != vertex_->labels.end()) return false; + + CreateAndLinkDelta(transaction_, vertex_, Delta::RemoveLabelTag(), label); + + vertex_->labels.push_back(label); + + UpdateOnAddLabel(indices_, label, vertex_, *transaction_); + + return true; +} + Result VertexAccessor::RemoveLabel(LabelId label) { std::lock_guard guard(vertex_->lock); @@ -110,6 +135,26 @@ Result VertexAccessor::RemoveLabel(LabelId label) { return true; } +ResultSchema VertexAccessor::RemoveLabelAndValidate(LabelId label) { + if (const auto maybe_violation_error = vertex_validator_.ValidateRemoveLabel(label); maybe_violation_error) { + return {*maybe_violation_error}; + } + std::lock_guard guard(vertex_->lock); + + if (!PrepareForWrite(transaction_, vertex_)) return {Error::SERIALIZATION_ERROR}; + + if (vertex_->deleted) return {Error::DELETED_OBJECT}; + + auto it = std::find(vertex_->labels.begin(), vertex_->labels.end(), label); + if (it == vertex_->labels.end()) return false; + + CreateAndLinkDelta(transaction_, vertex_, Delta::AddLabelTag(), label); + + std::swap(*it, *vertex_->labels.rbegin()); + vertex_->labels.pop_back(); + return true; +} + Result VertexAccessor::HasLabel(LabelId label, View view) const { bool exists = true; bool deleted = false; @@ -118,7 +163,7 @@ Result VertexAccessor::HasLabel(LabelId label, View view) const { { std::lock_guard guard(vertex_->lock); deleted = vertex_->deleted; - has_label = std::find(vertex_->labels.begin(), vertex_->labels.end(), label) != vertex_->labels.end(); + has_label = VertexHasLabel(*vertex_, label); delta = vertex_->delta; } ApplyDeltasForRead(transaction_, delta, view, [&exists, &deleted, &has_label, label](const Delta &delta) { @@ -158,6 +203,40 @@ Result VertexAccessor::HasLabel(LabelId label, View view) const { return has_label; } +Result VertexAccessor::PrimaryLabel(const View view) const { + bool exists = true; + bool deleted = false; + Delta *delta = nullptr; + { + std::lock_guard guard(vertex_->lock); + deleted = vertex_->deleted; + delta = vertex_->delta; + } + ApplyDeltasForRead(transaction_, delta, view, [&exists, &deleted](const Delta &delta) { + switch (delta.action) { + case Delta::Action::DELETE_OBJECT: { + exists = false; + break; + } + case Delta::Action::RECREATE_OBJECT: { + deleted = false; + break; + } + case Delta::Action::ADD_LABEL: + case Delta::Action::REMOVE_LABEL: + case Delta::Action::SET_PROPERTY: + case Delta::Action::ADD_IN_EDGE: + case Delta::Action::ADD_OUT_EDGE: + case Delta::Action::REMOVE_IN_EDGE: + case Delta::Action::REMOVE_OUT_EDGE: + break; + } + }); + if (!exists) return Error::NONEXISTENT_OBJECT; + if (!for_deleted_ && deleted) return Error::DELETED_OBJECT; + return vertex_->primary_label; +} + Result> VertexAccessor::Labels(View view) const { bool exists = true; bool deleted = false; @@ -230,6 +309,36 @@ Result VertexAccessor::SetProperty(PropertyId property, const Pro return std::move(current_value); } +ResultSchema VertexAccessor::SetPropertyAndValidate(PropertyId property, const PropertyValue &value) { + if (auto maybe_violation_error = vertex_validator_.ValidatePropertyUpdate(property); maybe_violation_error) { + return {*maybe_violation_error}; + } + utils::MemoryTracker::OutOfMemoryExceptionEnabler oom_exception; + std::lock_guard guard(vertex_->lock); + + if (!PrepareForWrite(transaction_, vertex_)) { + return {Error::SERIALIZATION_ERROR}; + } + + if (vertex_->deleted) { + return {Error::DELETED_OBJECT}; + } + + auto current_value = vertex_->properties.GetProperty(property); + // We could skip setting the value if the previous one is the same to the new + // one. This would save some memory as a delta would not be created as well as + // avoid copying the value. The reason we are not doing that is because the + // current code always follows the logical pattern of "create a delta" and + // "modify in-place". Additionally, the created delta will make other + // transactions get a SERIALIZATION_ERROR. + CreateAndLinkDelta(transaction_, vertex_, Delta::SetPropertyTag(), property, current_value); + vertex_->properties.SetProperty(property, value); + + UpdateOnSetProperty(indices_, property, value, vertex_, *transaction_); + + return std::move(current_value); +} + Result> VertexAccessor::ClearProperties() { std::lock_guard guard(vertex_->lock); @@ -414,7 +523,8 @@ Result> VertexAccessor::InEdges(View view, const std:: ret.reserve(in_edges.size()); for (const auto &item : in_edges) { const auto &[edge_type, from_vertex, edge] = item; - ret.emplace_back(edge, edge_type, from_vertex, vertex_, transaction_, indices_, constraints_, config_); + ret.emplace_back(edge, edge_type, from_vertex, vertex_, transaction_, indices_, constraints_, config_, + *vertex_validator_.schema_validator); } return std::move(ret); } @@ -494,7 +604,8 @@ Result> VertexAccessor::OutEdges(View view, const std: ret.reserve(out_edges.size()); for (const auto &item : out_edges) { const auto &[edge_type, to_vertex, edge] = item; - ret.emplace_back(edge, edge_type, vertex_, to_vertex, transaction_, indices_, constraints_, config_); + ret.emplace_back(edge, edge_type, vertex_, to_vertex, transaction_, indices_, constraints_, config_, + *vertex_validator_.schema_validator); } return std::move(ret); } @@ -575,4 +686,21 @@ Result VertexAccessor::OutDegree(View view) const { return degree; } +VertexAccessor::VertexValidator::VertexValidator(const SchemaValidator &schema_validator, const Vertex *vertex) + : schema_validator{&schema_validator}, vertex_{vertex} {} + +[[nodiscard]] std::optional VertexAccessor::VertexValidator::ValidatePropertyUpdate( + PropertyId property_id) const { + MG_ASSERT(vertex_ != nullptr, "Cannot validate vertex which is nullptr"); + return schema_validator->ValidatePropertyUpdate(vertex_->primary_label, property_id); +}; + +[[nodiscard]] std::optional VertexAccessor::VertexValidator::ValidateAddLabel(LabelId label) const { + return schema_validator->ValidateLabelUpdate(label); +} + +[[nodiscard]] std::optional VertexAccessor::VertexValidator::ValidateRemoveLabel(LabelId label) const { + return schema_validator->ValidateLabelUpdate(label); +} + } // namespace memgraph::storage::v3 diff --git a/src/storage/v3/vertex_accessor.hpp b/src/storage/v3/vertex_accessor.hpp index 88a2828c8..6cc7d4ff3 100644 --- a/src/storage/v3/vertex_accessor.hpp +++ b/src/storage/v3/vertex_accessor.hpp @@ -13,6 +13,8 @@ #include +#include "storage/v3/id_types.hpp" +#include "storage/v3/schema_validator.hpp" #include "storage/v3/vertex.hpp" #include "storage/v3/config.hpp" @@ -29,20 +31,39 @@ struct Constraints; class VertexAccessor final { private: + struct VertexValidator { + // TODO(jbajic) Beware since vertex is pointer it will be accessed even as nullptr + explicit VertexValidator(const SchemaValidator &schema_validator, const Vertex *vertex); + + [[nodiscard]] std::optional ValidatePropertyUpdate(PropertyId property_id) const; + + [[nodiscard]] std::optional ValidateAddLabel(LabelId label) const; + + [[nodiscard]] std::optional ValidateRemoveLabel(LabelId label) const; + + const SchemaValidator *schema_validator; + + private: + const Vertex *vertex_; + }; friend class Storage; public: + // Be careful when using VertexAccessor since it can be instantiated with + // nullptr values VertexAccessor(Vertex *vertex, Transaction *transaction, Indices *indices, Constraints *constraints, - Config::Items config, bool for_deleted = false) + Config::Items config, const SchemaValidator &schema_validator, bool for_deleted = false) : vertex_(vertex), transaction_(transaction), indices_(indices), constraints_(constraints), config_(config), + vertex_validator_{schema_validator, vertex}, for_deleted_(for_deleted) {} static std::optional Create(Vertex *vertex, Transaction *transaction, Indices *indices, - Constraints *constraints, Config::Items config, View view); + Constraints *constraints, Config::Items config, + const SchemaValidator &schema_validator, View view); /// @return true if the object is visible from the current transaction bool IsVisible(View view) const; @@ -52,11 +73,23 @@ class VertexAccessor final { /// @throw std::bad_alloc Result AddLabel(LabelId label); + /// Add a label and return `true` if insertion took place. + /// `false` is returned if the label already existed, or SchemaViolation + /// if adding the label has violated one of the schema constraints. + /// @throw std::bad_alloc + ResultSchema AddLabelAndValidate(LabelId label); + /// Remove a label and return `true` if deletion took place. /// `false` is returned if the vertex did not have a label already. /// @throw std::bad_alloc Result RemoveLabel(LabelId label); + /// Remove a label and return `true` if deletion took place. + /// `false` is returned if the vertex did not have a label already. or SchemaViolation + /// if adding the label has violated one of the schema constraints. + /// @throw std::bad_alloc + ResultSchema RemoveLabelAndValidate(LabelId label); + Result HasLabel(LabelId label, View view) const; /// @throw std::bad_alloc @@ -64,10 +97,16 @@ class VertexAccessor final { /// std::vector::max_size(). Result> Labels(View view) const; + Result PrimaryLabel(View view) const; + /// Set a property value and return the old value. /// @throw std::bad_alloc Result SetProperty(PropertyId property, const PropertyValue &value); + /// Set a property value and return the old value or error. + /// @throw std::bad_alloc + ResultSchema SetPropertyAndValidate(PropertyId property, const PropertyValue &value); + /// Remove all properties and return the values of the removed properties. /// @throw std::bad_alloc Result> ClearProperties(); @@ -96,6 +135,8 @@ class VertexAccessor final { Gid Gid() const noexcept { return vertex_->gid; } + const SchemaValidator *GetSchemaValidator() const; + bool operator==(const VertexAccessor &other) const noexcept { return vertex_ == other.vertex_ && transaction_ == other.transaction_; } @@ -107,6 +148,7 @@ class VertexAccessor final { Indices *indices_; Constraints *constraints_; Config::Items config_; + VertexValidator vertex_validator_; // if the accessor was created for a deleted vertex. // Accessor behaves differently for some methods based on this diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 78e1e366d..55fb7b01b 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -75,8 +75,9 @@ target_link_libraries(${test_prefix}bfs_single_node mg-query) add_unit_test(cypher_main_visitor.cpp) target_link_libraries(${test_prefix}cypher_main_visitor mg-query) -# add_unit_test(interpreter.cpp ${CMAKE_SOURCE_DIR}/src/glue/communication.cpp) -# target_link_libraries(${test_prefix}interpreter mg-communication mg-query) +add_unit_test(interpreter.cpp ${CMAKE_SOURCE_DIR}/src/glue/communication.cpp) +target_link_libraries(${test_prefix}interpreter mg-communication mg-query) + add_unit_test(plan_pretty_print.cpp) target_link_libraries(${test_prefix}plan_pretty_print mg-query) @@ -92,25 +93,28 @@ target_link_libraries(${test_prefix}query_expression_evaluator mg-query) add_unit_test(query_plan.cpp) target_link_libraries(${test_prefix}query_plan mg-query) -# add_unit_test(query_plan_accumulate_aggregate.cpp) -# target_link_libraries(${test_prefix}query_plan_accumulate_aggregate mg-query) +add_unit_test(query_plan_accumulate_aggregate.cpp) +target_link_libraries(${test_prefix}query_plan_accumulate_aggregate mg-query) -# add_unit_test(query_plan_bag_semantics.cpp) -# target_link_libraries(${test_prefix}query_plan_bag_semantics mg-query) +add_unit_test(query_plan_bag_semantics.cpp) +target_link_libraries(${test_prefix}query_plan_bag_semantics mg-query) -# add_unit_test(query_plan_create_set_remove_delete.cpp) -# target_link_libraries(${test_prefix}query_plan_create_set_remove_delete mg-query) -# add_unit_test(query_plan_edge_cases.cpp ${CMAKE_SOURCE_DIR}/src/glue/communication.cpp) -# target_link_libraries(${test_prefix}query_plan_edge_cases mg-communication mg-query) -# add_unit_test(query_plan_match_filter_return.cpp) -# target_link_libraries(${test_prefix}query_plan_match_filter_return mg-query) +add_unit_test(query_plan_create_set_remove_delete.cpp) +target_link_libraries(${test_prefix}query_plan_create_set_remove_delete mg-query) + +add_unit_test(query_plan_edge_cases.cpp ${CMAKE_SOURCE_DIR}/src/glue/communication.cpp) +target_link_libraries(${test_prefix}query_plan_edge_cases mg-communication mg-query) + +add_unit_test(query_plan_match_filter_return.cpp) +target_link_libraries(${test_prefix}query_plan_match_filter_return mg-query) add_unit_test(query_plan_read_write_typecheck.cpp ${CMAKE_SOURCE_DIR}/src/query/plan/read_write_type_checker.cpp) target_link_libraries(${test_prefix}query_plan_read_write_typecheck mg-query) -# add_unit_test(query_plan_v2_create_set_remove_delete.cpp) -# target_link_libraries(${test_prefix}query_plan_v2_create_set_remove_delete mg-query) +add_unit_test(query_plan_v2_create_set_remove_delete.cpp) +target_link_libraries(${test_prefix}query_plan_v2_create_set_remove_delete mg-query) + add_unit_test(query_pretty_print.cpp) target_link_libraries(${test_prefix}query_pretty_print mg-query) @@ -278,66 +282,76 @@ target_link_libraries(${test_prefix}commit_log_v2 gflags mg-utils mg-storage-v2) add_unit_test(property_value_v2.cpp) target_link_libraries(${test_prefix}property_value_v2 mg-storage-v2 mg-utils) -# add_unit_test(storage_v2.cpp) -# target_link_libraries(${test_prefix}storage_v2 mg-storage-v2 storage_test_utils) +add_unit_test(storage_v2.cpp) +target_link_libraries(${test_prefix}storage_v2 mg-storage-v2 storage_test_utils) + add_unit_test(storage_v2_constraints.cpp) target_link_libraries(${test_prefix}storage_v2_constraints mg-storage-v2) add_unit_test(storage_v2_decoder_encoder.cpp) target_link_libraries(${test_prefix}storage_v2_decoder_encoder mg-storage-v2) -# add_unit_test(storage_v2_durability.cpp) -# target_link_libraries(${test_prefix}storage_v2_durability mg-storage-v2) +add_unit_test(storage_v2_durability.cpp) +target_link_libraries(${test_prefix}storage_v2_durability mg-storage-v2) + +add_unit_test(storage_v2_edge.cpp) +target_link_libraries(${test_prefix}storage_v2_edge mg-storage-v2) -# add_unit_test(storage_v2_edge.cpp) -# target_link_libraries(${test_prefix}storage_v2_edge mg-storage-v2) add_unit_test(storage_v2_gc.cpp) target_link_libraries(${test_prefix}storage_v2_gc mg-storage-v2) -# add_unit_test(storage_v2_indices.cpp) -# target_link_libraries(${test_prefix}storage_v2_indices mg-storage-v2 mg-utils) +add_unit_test(storage_v2_indices.cpp) +target_link_libraries(${test_prefix}storage_v2_indices mg-storage-v2 mg-utils) + add_unit_test(storage_v2_name_id_mapper.cpp) target_link_libraries(${test_prefix}storage_v2_name_id_mapper mg-storage-v2) add_unit_test(storage_v2_property_store.cpp) target_link_libraries(${test_prefix}storage_v2_property_store mg-storage-v2 fmt) -# add_unit_test(storage_v2_wal_file.cpp) -# target_link_libraries(${test_prefix}storage_v2_wal_file mg-storage-v2 fmt) +add_unit_test(storage_v2_wal_file.cpp) +target_link_libraries(${test_prefix}storage_v2_wal_file mg-storage-v2 fmt) + +add_unit_test(storage_v2_replication.cpp) +target_link_libraries(${test_prefix}storage_v2_replication mg-storage-v2 fmt) -# add_unit_test(storage_v2_replication.cpp) -# target_link_libraries(${test_prefix}storage_v2_replication mg-storage-v2 fmt) add_unit_test(storage_v2_isolation_level.cpp) target_link_libraries(${test_prefix}storage_v2_isolation_level mg-storage-v2) # Test mg-storage-v3 - add_unit_test(storage_v3.cpp) target_link_libraries(${test_prefix}storage_v3 mg-storage-v3) add_unit_test(storage_v3_schema.cpp) -target_link_libraries(${test_prefix}storage_v3_schema mg-storage-v2) +target_link_libraries(${test_prefix}storage_v3_schema mg-storage-v3) -add_unit_test(interpreter_v2.cpp ${CMAKE_SOURCE_DIR}/src/glue/communication.cpp) -target_link_libraries(${test_prefix}interpreter_v2 mg-storage-v2 mg-query mg-communication) +# Test mg-query-v3 +add_unit_test(query_v2_interpreter.cpp ${CMAKE_SOURCE_DIR}/src/glue/v2/communication.cpp) +target_link_libraries(${test_prefix}query_v2_interpreter mg-storage-v3 mg-query-v2 mg-communication) add_unit_test(query_v2_query_plan_accumulate_aggregate.cpp) -target_link_libraries(${test_prefix}query_v2_query_plan_accumulate_aggregate mg-query) +target_link_libraries(${test_prefix}query_v2_query_plan_accumulate_aggregate mg-query-v2) add_unit_test(query_v2_query_plan_create_set_remove_delete.cpp) -target_link_libraries(${test_prefix}query_v2_query_plan_create_set_remove_delete mg-query) +target_link_libraries(${test_prefix}query_v2_query_plan_create_set_remove_delete mg-query-v2) add_unit_test(query_v2_query_plan_bag_semantics.cpp) -target_link_libraries(${test_prefix}query_v2_query_plan_bag_semantics mg-query) +target_link_libraries(${test_prefix}query_v2_query_plan_bag_semantics mg-query-v2) -add_unit_test(query_v2_query_plan_edge_cases.cpp ${CMAKE_SOURCE_DIR}/src/glue/communication.cpp) -target_link_libraries(${test_prefix}query_v2_query_plan_edge_cases mg-communication mg-query) +add_unit_test(query_v2_query_plan_edge_cases.cpp ${CMAKE_SOURCE_DIR}/src/glue/v2/communication.cpp) +target_link_libraries(${test_prefix}query_v2_query_plan_edge_cases mg-communication mg-query-v2) add_unit_test(query_v2_query_plan_v2_create_set_remove_delete.cpp) -target_link_libraries(${test_prefix}query_v2_query_plan_v2_create_set_remove_delete mg-query) +target_link_libraries(${test_prefix}query_v2_query_plan_v2_create_set_remove_delete mg-query-v2) add_unit_test(query_v2_query_plan_match_filter_return.cpp) -target_link_libraries(${test_prefix}query_v2_query_plan_match_filter_return mg-query) +target_link_libraries(${test_prefix}query_v2_query_plan_match_filter_return mg-query-v2) + +add_unit_test(query_v2_cypher_main_visitor.cpp) +target_link_libraries(${test_prefix}query_v2_cypher_main_visitor mg-query-v2) + +add_unit_test(query_v2_query_required_privileges.cpp) +target_link_libraries(${test_prefix}query_v2_query_required_privileges mg-query-v2) add_unit_test(replication_persistence_helper.cpp) target_link_libraries(${test_prefix}replication_persistence_helper mg-storage-v2) diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index ac5af8041..40aeb0161 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -32,7 +32,6 @@ #include #include -#include "common/types.hpp" #include "query/exceptions.hpp" #include "query/frontend/ast/ast.hpp" #include "query/frontend/ast/cypher_main_visitor.hpp" @@ -2214,8 +2213,6 @@ TEST_P(CypherMainVisitorTest, GrantPrivilege) { {AuthQuery::Privilege::MODULE_READ}); check_auth_query(&ast_generator, "GRANT MODULE_WRITE TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, {AuthQuery::Privilege::MODULE_WRITE}); - check_auth_query(&ast_generator, "GRANT SCHEMA TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, - {AuthQuery::Privilege::SCHEMA}); } TEST_P(CypherMainVisitorTest, DenyPrivilege) { @@ -2256,8 +2253,6 @@ TEST_P(CypherMainVisitorTest, DenyPrivilege) { {AuthQuery::Privilege::MODULE_READ}); check_auth_query(&ast_generator, "DENY MODULE_WRITE TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, {AuthQuery::Privilege::MODULE_WRITE}); - check_auth_query(&ast_generator, "DENY SCHEMA TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, - {AuthQuery::Privilege::SCHEMA}); } TEST_P(CypherMainVisitorTest, RevokePrivilege) { @@ -2300,8 +2295,6 @@ TEST_P(CypherMainVisitorTest, RevokePrivilege) { {}, {AuthQuery::Privilege::MODULE_READ}); check_auth_query(&ast_generator, "REVOKE MODULE_WRITE FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, {AuthQuery::Privilege::MODULE_WRITE}); - check_auth_query(&ast_generator, "REVOKE SCHEMA FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, - {AuthQuery::Privilege::SCHEMA}); } TEST_P(CypherMainVisitorTest, ShowPrivileges) { @@ -4216,110 +4209,3 @@ TEST_P(CypherMainVisitorTest, Foreach) { ASSERT_TRUE(dynamic_cast(*++clauses.begin())); } } - -TEST_P(CypherMainVisitorTest, TestShowSchemas) { - auto &ast_generator = *GetParam(); - auto *query = dynamic_cast(ast_generator.ParseQuery("SHOW SCHEMAS")); - ASSERT_TRUE(query); - EXPECT_EQ(query->action_, SchemaQuery::Action::SHOW_SCHEMAS); -} - -TEST_P(CypherMainVisitorTest, TestShowSchema) { - auto &ast_generator = *GetParam(); - EXPECT_THROW(ast_generator.ParseQuery("SHOW SCHEMA ON label"), SyntaxException); - EXPECT_THROW(ast_generator.ParseQuery("SHOW SCHEMA :label"), SyntaxException); - EXPECT_THROW(ast_generator.ParseQuery("SHOW SCHEMA label"), SyntaxException); - - auto *query = dynamic_cast(ast_generator.ParseQuery("SHOW SCHEMA ON :label")); - ASSERT_TRUE(query); - EXPECT_EQ(query->action_, SchemaQuery::Action::SHOW_SCHEMA); - EXPECT_EQ(query->label_, ast_generator.Label("label")); -} - -void AssertSchemaPropertyMap(auto &schema_property_map, - std::vector> properties_type, - auto &ast_generator) { - EXPECT_EQ(schema_property_map.size(), properties_type.size()); - for (size_t i{0}; i < schema_property_map.size(); ++i) { - // Assert PropertyId - EXPECT_EQ(schema_property_map[i].first, ast_generator.Prop(properties_type[i].first)); - // Assert Property Type - EXPECT_EQ(schema_property_map[i].second, properties_type[i].second); - } -} - -TEST_P(CypherMainVisitorTest, TestCreateSchema) { - { - auto &ast_generator = *GetParam(); - EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label"), SyntaxException); - EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label()"), SyntaxException); - EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(123 INTEGER)"), SyntaxException); - EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(name TYPE)"), SyntaxException); - EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(name, age)"), SyntaxException); - EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(name, DURATION)"), SyntaxException); - EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON label(name INTEGER)"), SyntaxException); - EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(name INTEGER, name INTEGER)"), SemanticException); - EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(name INTEGER, name STRING)"), SemanticException); - } - { - auto &ast_generator = *GetParam(); - auto *query = dynamic_cast(ast_generator.ParseQuery("CREATE SCHEMA ON :label1(name STRING)")); - ASSERT_TRUE(query); - EXPECT_EQ(query->action_, SchemaQuery::Action::CREATE_SCHEMA); - EXPECT_EQ(query->label_, ast_generator.Label("label1")); - AssertSchemaPropertyMap(query->schema_type_map_, {{"name", memgraph::common::SchemaType::STRING}}, ast_generator); - } - { - auto &ast_generator = *GetParam(); - auto *query = dynamic_cast(ast_generator.ParseQuery("CREATE SCHEMA ON :label2(name string)")); - ASSERT_TRUE(query); - EXPECT_EQ(query->action_, SchemaQuery::Action::CREATE_SCHEMA); - EXPECT_EQ(query->label_, ast_generator.Label("label2")); - AssertSchemaPropertyMap(query->schema_type_map_, {{"name", memgraph::common::SchemaType::STRING}}, ast_generator); - } - { - auto &ast_generator = *GetParam(); - auto *query = dynamic_cast( - ast_generator.ParseQuery("CREATE SCHEMA ON :label3(first_name STRING, last_name STRING)")); - ASSERT_TRUE(query); - EXPECT_EQ(query->action_, SchemaQuery::Action::CREATE_SCHEMA); - EXPECT_EQ(query->label_, ast_generator.Label("label3")); - AssertSchemaPropertyMap( - query->schema_type_map_, - {{"first_name", memgraph::common::SchemaType::STRING}, {"last_name", memgraph::common::SchemaType::STRING}}, - ast_generator); - } - { - auto &ast_generator = *GetParam(); - auto *query = dynamic_cast( - ast_generator.ParseQuery("CREATE SCHEMA ON :label4(name STRING, age INTEGER, dur DURATION, birthday1 " - "LOCALDATETIME, birthday2 DATE, some_time LOCALTIME, speaks_truth BOOL)")); - ASSERT_TRUE(query); - EXPECT_EQ(query->action_, SchemaQuery::Action::CREATE_SCHEMA); - EXPECT_EQ(query->label_, ast_generator.Label("label4")); - AssertSchemaPropertyMap(query->schema_type_map_, - { - {"name", memgraph::common::SchemaType::STRING}, - {"age", memgraph::common::SchemaType::INT}, - {"dur", memgraph::common::SchemaType::DURATION}, - {"birthday1", memgraph::common::SchemaType::LOCALDATETIME}, - {"birthday2", memgraph::common::SchemaType::DATE}, - {"some_time", memgraph::common::SchemaType::LOCALTIME}, - {"speaks_truth", memgraph::common::SchemaType::BOOL}, - }, - ast_generator); - } -} - -TEST_P(CypherMainVisitorTest, TestDropSchema) { - auto &ast_generator = *GetParam(); - EXPECT_THROW(ast_generator.ParseQuery("DROP SCHEMA"), SyntaxException); - EXPECT_THROW(ast_generator.ParseQuery("DROP SCHEMA ON label"), SyntaxException); - EXPECT_THROW(ast_generator.ParseQuery("DROP SCHEMA :label"), SyntaxException); - EXPECT_THROW(ast_generator.ParseQuery("DROP SCHEMA ON :label()"), SyntaxException); - - auto *query = dynamic_cast(ast_generator.ParseQuery("DROP SCHEMA ON :label")); - ASSERT_TRUE(query); - EXPECT_EQ(query->action_, SchemaQuery::Action::DROP_SCHEMA); - EXPECT_EQ(query->label_, ast_generator.Label("label")); -} diff --git a/tests/unit/query_required_privileges.cpp b/tests/unit/query_required_privileges.cpp index 4aab492e1..ad21b10c4 100644 --- a/tests/unit/query_required_privileges.cpp +++ b/tests/unit/query_required_privileges.cpp @@ -192,11 +192,6 @@ TEST_F(TestPrivilegeExtractor, ShowVersion) { EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::STATS)); } -TEST_F(TestPrivilegeExtractor, SchemaQuery) { - auto *query = storage.Create(); - EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::SCHEMA)); -} - TEST_F(TestPrivilegeExtractor, CallProcedureQuery) { { auto *query = QUERY(SINGLE_QUERY(CALL_PROCEDURE("mg.get_module_files"))); diff --git a/tests/unit/query_v2_cypher_main_visitor.cpp b/tests/unit/query_v2_cypher_main_visitor.cpp new file mode 100644 index 000000000..308f9a049 --- /dev/null +++ b/tests/unit/query_v2_cypher_main_visitor.cpp @@ -0,0 +1,4326 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include +#include +#include +#include +#include +#include +#include +#include + +////////////////////////////////////////////////////// +// "json.hpp" should always come before "antrl4-runtime.h" +// "json.hpp" uses libc's EOF macro while +// "antrl4-runtime.h" contains a static variable of the +// same name, EOF. +// This hides the definition of the macro which causes +// the compilation to fail. +#include +////////////////////////////////////////////////////// +#include +#include +#include +#include + +#include "common/types.hpp" +#include "query/v2/exceptions.hpp" +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/frontend/ast/cypher_main_visitor.hpp" +#include "query/v2/frontend/opencypher/parser.hpp" +#include "query/v2/frontend/stripped.hpp" +#include "query/v2/procedure/cypher_types.hpp" +#include "query/v2/procedure/mg_procedure_impl.hpp" +#include "query/v2/procedure/module.hpp" +#include "query/v2/typed_value.hpp" + +#include "utils/string.hpp" +#include "utils/variant_helpers.hpp" + +using namespace memgraph::query::v2; +using namespace memgraph::query::v2::frontend; +using memgraph::query::v2::TypedValue; +using testing::ElementsAre; +using testing::Pair; +using testing::UnorderedElementsAre; + +// Base class for all test types +class Base { + public: + ParsingContext context_; + Parameters parameters_; + + virtual ~Base() {} + + virtual Query *ParseQuery(const std::string &query_string) = 0; + + virtual PropertyIx Prop(const std::string &prop_name) = 0; + + virtual LabelIx Label(const std::string &label_name) = 0; + + virtual EdgeTypeIx EdgeType(const std::string &edge_type_name) = 0; + + TypedValue LiteralValue(Expression *expression) { + if (context_.is_query_cached) { + auto *param_lookup = dynamic_cast(expression); + return TypedValue(parameters_.AtTokenPosition(param_lookup->token_position_)); + } else { + auto *literal = dynamic_cast(expression); + return TypedValue(literal->value_); + } + } + + TypedValue GetLiteral(Expression *expression, const bool use_parameter_lookup, + const std::optional &token_position = std::nullopt) const { + if (use_parameter_lookup) { + auto *param_lookup = dynamic_cast(expression); + if (param_lookup == nullptr) { + ADD_FAILURE(); + return {}; + } + if (token_position) { + EXPECT_EQ(param_lookup->token_position_, *token_position); + } + return TypedValue(parameters_.AtTokenPosition(param_lookup->token_position_)); + } + + auto *literal = dynamic_cast(expression); + if (literal == nullptr) { + ADD_FAILURE(); + return {}; + } + if (token_position) { + EXPECT_EQ(literal->token_position_, *token_position); + } + return TypedValue(literal->value_); + } + + template + void CheckLiteral(Expression *expression, const TValue &expected, + const std::optional &token_position = std::nullopt) const { + TypedValue expected_tv(expected); + const auto use_parameter_lookup = !expected_tv.IsNull() && context_.is_query_cached; + TypedValue value = GetLiteral(expression, use_parameter_lookup, token_position); + EXPECT_TRUE(TypedValue::BoolEqual{}(value, expected_tv)); + } +}; + +// This generator uses ast constructed by parsing the query. +class AstGenerator : public Base { + public: + Query *ParseQuery(const std::string &query_string) override { + ::frontend::opencypher::Parser parser(query_string); + CypherMainVisitor visitor(context_, &ast_storage_); + visitor.visit(parser.tree()); + return visitor.query(); + } + + PropertyIx Prop(const std::string &prop_name) override { return ast_storage_.GetPropertyIx(prop_name); } + + LabelIx Label(const std::string &name) override { return ast_storage_.GetLabelIx(name); } + + EdgeTypeIx EdgeType(const std::string &name) override { return ast_storage_.GetEdgeTypeIx(name); } + + AstStorage ast_storage_; +}; + +// This clones ast, but uses original one. This done just to ensure that cloning +// doesn't change original. +class OriginalAfterCloningAstGenerator : public AstGenerator { + public: + Query *ParseQuery(const std::string &query_string) override { + auto *original_query = AstGenerator::ParseQuery(query_string); + AstStorage storage; + original_query->Clone(&storage); + return original_query; + } +}; + +// This generator clones parsed ast and uses that one. +// Original ast is cleared after cloning to ensure that cloned ast doesn't reuse +// any data from original ast. +class ClonedAstGenerator : public Base { + public: + Query *ParseQuery(const std::string &query_string) override { + ::frontend::opencypher::Parser parser(query_string); + AstStorage tmp_storage; + { + // Add a label, property and edge type into temporary storage so + // indices have to change in cloned AST. + tmp_storage.GetLabelIx("jkfdklajfkdalsj"); + tmp_storage.GetPropertyIx("fdjakfjdklfjdaslk"); + tmp_storage.GetEdgeTypeIx("fdjkalfjdlkajfdkla"); + } + CypherMainVisitor visitor(context_, &tmp_storage); + visitor.visit(parser.tree()); + return visitor.query()->Clone(&ast_storage_); + } + + PropertyIx Prop(const std::string &prop_name) override { return ast_storage_.GetPropertyIx(prop_name); } + + LabelIx Label(const std::string &name) override { return ast_storage_.GetLabelIx(name); } + + EdgeTypeIx EdgeType(const std::string &name) override { return ast_storage_.GetEdgeTypeIx(name); } + + AstStorage ast_storage_; +}; + +// This generator strips ast, clones it and then plugs stripped out literals in +// the same way it is done in ast cacheing in interpreter. +class CachedAstGenerator : public Base { + public: + Query *ParseQuery(const std::string &query_string) override { + context_.is_query_cached = true; + StrippedQuery stripped(query_string); + parameters_ = stripped.literals(); + ::frontend::opencypher::Parser parser(stripped.query()); + AstStorage tmp_storage; + CypherMainVisitor visitor(context_, &tmp_storage); + visitor.visit(parser.tree()); + return visitor.query()->Clone(&ast_storage_); + } + + PropertyIx Prop(const std::string &prop_name) override { return ast_storage_.GetPropertyIx(prop_name); } + + LabelIx Label(const std::string &name) override { return ast_storage_.GetLabelIx(name); } + + EdgeTypeIx EdgeType(const std::string &name) override { return ast_storage_.GetEdgeTypeIx(name); } + + AstStorage ast_storage_; +}; + +class MockModule : public procedure::Module { + public: + MockModule(){}; + ~MockModule() override{}; + MockModule(const MockModule &) = delete; + MockModule(MockModule &&) = delete; + MockModule &operator=(const MockModule &) = delete; + MockModule &operator=(MockModule &&) = delete; + + bool Close() override { return true; }; + + const std::map> *Procedures() const override { return &procedures; } + + const std::map> *Transformations() const override { return &transformations; } + + const std::map> *Functions() const override { return &functions; } + + std::optional Path() const override { return std::nullopt; }; + + std::map> procedures{}; + std::map> transformations{}; + std::map> functions{}; +}; + +void DummyProcCallback(mgp_list * /*args*/, mgp_graph * /*graph*/, mgp_result * /*result*/, mgp_memory * /*memory*/){}; +void DummyFuncCallback(mgp_list * /*args*/, mgp_func_context * /*func_ctx*/, mgp_func_result * /*result*/, + mgp_memory * /*memory*/){}; + +enum class ProcedureType { WRITE, READ }; + +std::string ToString(const ProcedureType type) { return type == ProcedureType::WRITE ? "write" : "read"; } + +class CypherMainVisitorTest : public ::testing::TestWithParam> { + public: + void SetUp() override { + { + auto mock_module_owner = std::make_unique(); + mock_module = mock_module_owner.get(); + procedure::gModuleRegistry.RegisterModule("mock_module", std::move(mock_module_owner)); + } + { + auto mock_module_with_dots_in_name_owner = std::make_unique(); + mock_module_with_dots_in_name = mock_module_with_dots_in_name_owner.get(); + procedure::gModuleRegistry.RegisterModule("mock_module.with.dots.in.name", + std::move(mock_module_with_dots_in_name_owner)); + } + } + + void TearDown() override { + // To release any_type + procedure::gModuleRegistry.UnloadAllModules(); + } + + static void AddProc(MockModule &module, const char *name, const std::vector &args, + const std::vector &results, const ProcedureType type) { + memgraph::utils::MemoryResource *memory = memgraph::utils::NewDeleteResource(); + const bool is_write = type == ProcedureType::WRITE; + mgp_proc proc(name, DummyProcCallback, memory, {.is_write = is_write}); + for (const auto arg : args) { + proc.args.emplace_back(memgraph::utils::pmr::string{arg, memory}, &any_type); + } + for (const auto result : results) { + proc.results.emplace(memgraph::utils::pmr::string{result, memory}, std::make_pair(&any_type, false)); + } + module.procedures.emplace(name, std::move(proc)); + } + + static void AddFunc(MockModule &module, const char *name, const std::vector &args) { + memgraph::utils::MemoryResource *memory = memgraph::utils::NewDeleteResource(); + mgp_func func(name, DummyFuncCallback, memory); + for (const auto arg : args) { + func.args.emplace_back(memgraph::utils::pmr::string{arg, memory}, &any_type); + } + module.functions.emplace(name, std::move(func)); + } + + std::string CreateProcByType(const ProcedureType type, const std::vector &args) { + const auto proc_name = std::string{"proc_"} + ToString(type); + SCOPED_TRACE(proc_name); + AddProc(*mock_module, proc_name.c_str(), {}, args, type); + return std::string{"mock_module."} + proc_name; + } + + static const procedure::AnyType any_type; + MockModule *mock_module{nullptr}; + MockModule *mock_module_with_dots_in_name{nullptr}; +}; + +const procedure::AnyType CypherMainVisitorTest::any_type{}; + +std::shared_ptr gAstGeneratorTypes[] = { + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), +}; + +INSTANTIATE_TEST_CASE_P(AstGeneratorTypes, CypherMainVisitorTest, ::testing::ValuesIn(gAstGeneratorTypes)); + +// NOTE: The above used to use *Typed Tests* functionality of gtest library. +// Unfortunately, the compilation time of this test increased to full 2 minutes! +// Although using Typed Tests is the recommended way to achieve what we want, we +// are (ab)using *Value-Parameterized Tests* functionality instead. This cuts +// down the compilation time to about 20 seconds. The original code is here for +// future reference in case someone gets the idea to change to *appropriate* +// Typed Tests mechanism and ruin the compilation times. +// +// typedef ::testing::Types +// AstGeneratorTypes; +// +// TYPED_TEST_CASE(CypherMainVisitorTest, AstGeneratorTypes); + +TEST_P(CypherMainVisitorTest, SyntaxException) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("CREATE ()-[*1....2]-()"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, SyntaxExceptionOnTrailingText) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 2 + 2 mirko"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, PropertyLookup) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN n.x")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *property_lookup = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(property_lookup->expression_); + auto identifier = dynamic_cast(property_lookup->expression_); + ASSERT_TRUE(identifier); + ASSERT_EQ(identifier->name_, "n"); + ASSERT_EQ(property_lookup->property_, ast_generator.Prop("x")); +} + +TEST_P(CypherMainVisitorTest, LabelsTest) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN n:x:y")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *labels_test = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(labels_test->expression_); + auto identifier = dynamic_cast(labels_test->expression_); + ASSERT_TRUE(identifier); + ASSERT_EQ(identifier->name_, "n"); + ASSERT_THAT(labels_test->labels_, ElementsAre(ast_generator.Label("x"), ast_generator.Label("y"))); +} + +TEST_P(CypherMainVisitorTest, EscapedLabel) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN n:`l-$\"'ab``e````l`")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *labels_test = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + auto identifier = dynamic_cast(labels_test->expression_); + ASSERT_EQ(identifier->name_, "n"); + ASSERT_THAT(labels_test->labels_, ElementsAre(ast_generator.Label("l-$\"'ab`e``l"))); +} + +TEST_P(CypherMainVisitorTest, KeywordLabel) { + for (const auto &label : {"DeLete", "UsER"}) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery(fmt::format("RETURN n:{}", label))); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *labels_test = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + auto identifier = dynamic_cast(labels_test->expression_); + ASSERT_EQ(identifier->name_, "n"); + ASSERT_THAT(labels_test->labels_, ElementsAre(ast_generator.Label(label))); + } +} + +TEST_P(CypherMainVisitorTest, HexLetterLabel) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN n:a")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *labels_test = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + auto identifier = dynamic_cast(labels_test->expression_); + EXPECT_EQ(identifier->name_, "n"); + ASSERT_THAT(labels_test->labels_, ElementsAre(ast_generator.Label("a"))); +} + +TEST_P(CypherMainVisitorTest, ReturnNoDistinctNoBagSemantics) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN x")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ASSERT_FALSE(return_clause->body_.all_identifiers); + ASSERT_EQ(return_clause->body_.order_by.size(), 0U); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 1U); + ASSERT_FALSE(return_clause->body_.limit); + ASSERT_FALSE(return_clause->body_.skip); + ASSERT_FALSE(return_clause->body_.distinct); +} + +TEST_P(CypherMainVisitorTest, ReturnDistinct) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN DISTINCT x")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(return_clause->body_.distinct); +} + +TEST_P(CypherMainVisitorTest, ReturnLimit) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN x LIMIT 5")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(return_clause->body_.limit); + ast_generator.CheckLiteral(return_clause->body_.limit, 5); +} + +TEST_P(CypherMainVisitorTest, ReturnSkip) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN x SKIP 5")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(return_clause->body_.skip); + ast_generator.CheckLiteral(return_clause->body_.skip, 5); +} + +TEST_P(CypherMainVisitorTest, ReturnOrderBy) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN x, y, z ORDER BY z ASC, x, y DESC")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ASSERT_EQ(return_clause->body_.order_by.size(), 3U); + std::vector> ordering; + for (const auto &sort_item : return_clause->body_.order_by) { + auto *identifier = dynamic_cast(sort_item.expression); + ordering.emplace_back(sort_item.ordering, identifier->name_); + } + ASSERT_THAT(ordering, + UnorderedElementsAre(Pair(Ordering::ASC, "z"), Pair(Ordering::ASC, "x"), Pair(Ordering::DESC, "y"))); +} + +TEST_P(CypherMainVisitorTest, ReturnNamedIdentifier) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN var AS var5")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ASSERT_FALSE(return_clause->body_.all_identifiers); + auto *named_expr = return_clause->body_.named_expressions[0]; + ASSERT_EQ(named_expr->name_, "var5"); + auto *identifier = dynamic_cast(named_expr->expression_); + ASSERT_EQ(identifier->name_, "var"); +} + +TEST_P(CypherMainVisitorTest, ReturnAsterisk) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN *")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(return_clause->body_.all_identifiers); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 0U); +} + +TEST_P(CypherMainVisitorTest, IntegerLiteral) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN 42")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, 42, 1); +} + +TEST_P(CypherMainVisitorTest, IntegerLiteralTooLarge) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 10000000000000000000000000"), SemanticException); +} + +TEST_P(CypherMainVisitorTest, BooleanLiteralTrue) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN TrUe")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, true, 1); +} + +TEST_P(CypherMainVisitorTest, BooleanLiteralFalse) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN faLSE")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, false, 1); +} + +TEST_P(CypherMainVisitorTest, NullLiteral) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN nULl")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, TypedValue(), 1); +} + +TEST_P(CypherMainVisitorTest, ParenthesizedExpression) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN (2)")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, 2); +} + +TEST_P(CypherMainVisitorTest, OrOperator) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN true Or false oR n")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *or_operator2 = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(or_operator2); + auto *or_operator1 = dynamic_cast(or_operator2->expression1_); + ASSERT_TRUE(or_operator1); + ast_generator.CheckLiteral(or_operator1->expression1_, true); + ast_generator.CheckLiteral(or_operator1->expression2_, false); + auto *operand3 = dynamic_cast(or_operator2->expression2_); + ASSERT_TRUE(operand3); + ASSERT_EQ(operand3->name_, "n"); +} + +TEST_P(CypherMainVisitorTest, XorOperator) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN true xOr false")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *xor_operator = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ast_generator.CheckLiteral(xor_operator->expression1_, true); + ast_generator.CheckLiteral(xor_operator->expression2_, false); +} + +TEST_P(CypherMainVisitorTest, AndOperator) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN true and false")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *and_operator = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ast_generator.CheckLiteral(and_operator->expression1_, true); + ast_generator.CheckLiteral(and_operator->expression2_, false); +} + +TEST_P(CypherMainVisitorTest, AdditionSubtractionOperators) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN 1 - 2 + 3")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *addition_operator = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(addition_operator); + auto *subtraction_operator = dynamic_cast(addition_operator->expression1_); + ASSERT_TRUE(subtraction_operator); + ast_generator.CheckLiteral(subtraction_operator->expression1_, 1); + ast_generator.CheckLiteral(subtraction_operator->expression2_, 2); + ast_generator.CheckLiteral(addition_operator->expression2_, 3); +} + +TEST_P(CypherMainVisitorTest, MulitplicationOperator) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN 2 * 3")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *mult_operator = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ast_generator.CheckLiteral(mult_operator->expression1_, 2); + ast_generator.CheckLiteral(mult_operator->expression2_, 3); +} + +TEST_P(CypherMainVisitorTest, DivisionOperator) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN 2 / 3")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *div_operator = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ast_generator.CheckLiteral(div_operator->expression1_, 2); + ast_generator.CheckLiteral(div_operator->expression2_, 3); +} + +TEST_P(CypherMainVisitorTest, ModOperator) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN 2 % 3")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *mod_operator = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ast_generator.CheckLiteral(mod_operator->expression1_, 2); + ast_generator.CheckLiteral(mod_operator->expression2_, 3); +} + +#define CHECK_COMPARISON(TYPE, VALUE1, VALUE2) \ + do { \ + auto *and_operator = dynamic_cast(_operator); \ + ASSERT_TRUE(and_operator); \ + _operator = and_operator->expression1_; \ + auto *cmp_operator = dynamic_cast(and_operator->expression2_); \ + ASSERT_TRUE(cmp_operator); \ + ast_generator.CheckLiteral(cmp_operator->expression1_, VALUE1); \ + ast_generator.CheckLiteral(cmp_operator->expression2_, VALUE2); \ + } while (0) + +TEST_P(CypherMainVisitorTest, ComparisonOperators) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN 2 = 3 != 4 <> 5 < 6 > 7 <= 8 >= 9")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + Expression *_operator = return_clause->body_.named_expressions[0]->expression_; + CHECK_COMPARISON(GreaterEqualOperator, 8, 9); + CHECK_COMPARISON(LessEqualOperator, 7, 8); + CHECK_COMPARISON(GreaterOperator, 6, 7); + CHECK_COMPARISON(LessOperator, 5, 6); + CHECK_COMPARISON(NotEqualOperator, 4, 5); + CHECK_COMPARISON(NotEqualOperator, 3, 4); + auto *cmp_operator = dynamic_cast(_operator); + ASSERT_TRUE(cmp_operator); + ast_generator.CheckLiteral(cmp_operator->expression1_, 2); + ast_generator.CheckLiteral(cmp_operator->expression2_, 3); +} + +#undef CHECK_COMPARISON + +TEST_P(CypherMainVisitorTest, ListIndexing) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN [1,2,3] [ 2 ]")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *list_index_op = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(list_index_op); + auto *list = dynamic_cast(list_index_op->expression1_); + EXPECT_TRUE(list); + ast_generator.CheckLiteral(list_index_op->expression2_, 2); +} + +TEST_P(CypherMainVisitorTest, ListSlicingOperatorNoBounds) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("RETURN [1,2,3] [ .. ]"), SemanticException); +} + +TEST_P(CypherMainVisitorTest, ListSlicingOperator) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN [1,2,3] [ .. 2 ]")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *list_slicing_op = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(list_slicing_op); + auto *list = dynamic_cast(list_slicing_op->list_); + EXPECT_TRUE(list); + EXPECT_FALSE(list_slicing_op->lower_bound_); + ast_generator.CheckLiteral(list_slicing_op->upper_bound_, 2); +} + +TEST_P(CypherMainVisitorTest, InListOperator) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN 5 IN [1,2]")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *in_list_operator = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(in_list_operator); + ast_generator.CheckLiteral(in_list_operator->expression1_, 5); + auto *list = dynamic_cast(in_list_operator->expression2_); + ASSERT_TRUE(list); +} + +TEST_P(CypherMainVisitorTest, InWithListIndexing) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN 1 IN [[1,2]][0]")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *in_list_operator = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(in_list_operator); + ast_generator.CheckLiteral(in_list_operator->expression1_, 1); + auto *list_indexing = dynamic_cast(in_list_operator->expression2_); + ASSERT_TRUE(list_indexing); + auto *list = dynamic_cast(list_indexing->expression1_); + EXPECT_TRUE(list); + ast_generator.CheckLiteral(list_indexing->expression2_, 0); +} + +TEST_P(CypherMainVisitorTest, CaseGenericForm) { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast(ast_generator.ParseQuery("RETURN CASE WHEN n < 10 THEN 1 WHEN n > 10 THEN 2 END")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *if_operator = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(if_operator); + auto *condition = dynamic_cast(if_operator->condition_); + ASSERT_TRUE(condition); + ast_generator.CheckLiteral(if_operator->then_expression_, 1); + + auto *if_operator2 = dynamic_cast(if_operator->else_expression_); + ASSERT_TRUE(if_operator2); + auto *condition2 = dynamic_cast(if_operator2->condition_); + ASSERT_TRUE(condition2); + ast_generator.CheckLiteral(if_operator2->then_expression_, 2); + ast_generator.CheckLiteral(if_operator2->else_expression_, TypedValue()); +} + +TEST_P(CypherMainVisitorTest, CaseGenericFormElse) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN CASE WHEN n < 10 THEN 1 ELSE 2 END")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *if_operator = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + auto *condition = dynamic_cast(if_operator->condition_); + ASSERT_TRUE(condition); + ast_generator.CheckLiteral(if_operator->then_expression_, 1); + ast_generator.CheckLiteral(if_operator->else_expression_, 2); +} + +TEST_P(CypherMainVisitorTest, CaseSimpleForm) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN CASE 5 WHEN 10 THEN 1 END")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *if_operator = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + auto *condition = dynamic_cast(if_operator->condition_); + ASSERT_TRUE(condition); + ast_generator.CheckLiteral(condition->expression1_, 5); + ast_generator.CheckLiteral(condition->expression2_, 10); + ast_generator.CheckLiteral(if_operator->then_expression_, 1); + ast_generator.CheckLiteral(if_operator->else_expression_, TypedValue()); +} + +TEST_P(CypherMainVisitorTest, IsNull) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN 2 iS NulL")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *is_type_operator = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ast_generator.CheckLiteral(is_type_operator->expression_, 2); +} + +TEST_P(CypherMainVisitorTest, IsNotNull) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN 2 iS nOT NulL")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *not_operator = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + auto *is_type_operator = dynamic_cast(not_operator->expression_); + ast_generator.CheckLiteral(is_type_operator->expression_, 2); +} + +TEST_P(CypherMainVisitorTest, NotOperator) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN not true")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *not_operator = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ast_generator.CheckLiteral(not_operator->expression_, true); +} + +TEST_P(CypherMainVisitorTest, UnaryMinusPlusOperators) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN -+5")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *unary_minus_operator = + dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(unary_minus_operator); + auto *unary_plus_operator = dynamic_cast(unary_minus_operator->expression_); + ASSERT_TRUE(unary_plus_operator); + ast_generator.CheckLiteral(unary_plus_operator->expression_, 5); +} + +TEST_P(CypherMainVisitorTest, Aggregation) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast( + ast_generator.ParseQuery("RETURN COUNT(a), MIN(b), MAX(c), SUM(d), AVG(e), COLLECT(f), COUNT(*)")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 7U); + Aggregation::Op ops[] = {Aggregation::Op::COUNT, Aggregation::Op::MIN, Aggregation::Op::MAX, + Aggregation::Op::SUM, Aggregation::Op::AVG, Aggregation::Op::COLLECT_LIST}; + std::string ids[] = {"a", "b", "c", "d", "e", "f"}; + for (int i = 0; i < 6; ++i) { + auto *aggregation = dynamic_cast(return_clause->body_.named_expressions[i]->expression_); + ASSERT_TRUE(aggregation); + ASSERT_EQ(aggregation->op_, ops[i]); + auto *identifier = dynamic_cast(aggregation->expression1_); + ASSERT_TRUE(identifier); + ASSERT_EQ(identifier->name_, ids[i]); + } + auto *aggregation = dynamic_cast(return_clause->body_.named_expressions[6]->expression_); + ASSERT_TRUE(aggregation); + ASSERT_EQ(aggregation->op_, Aggregation::Op::COUNT); + ASSERT_FALSE(aggregation->expression1_); +} + +TEST_P(CypherMainVisitorTest, UndefinedFunction) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("RETURN " + "IHopeWeWillNeverHaveAwesomeMemgraphProcedureWithS" + "uchALongAndAwesomeNameSinceThisTestWouldFail(1)"), + SemanticException); +} + +TEST_P(CypherMainVisitorTest, MissingFunction) { + AddFunc(*mock_module, "get", {}); + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("RETURN missing_function.get()"), SemanticException); +} + +TEST_P(CypherMainVisitorTest, Function) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN abs(n, 2)")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 1); + auto *function = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(function); + ASSERT_TRUE(function->function_); +} + +TEST_P(CypherMainVisitorTest, MagicFunction) { + AddFunc(*mock_module, "get", {}); + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN mock_module.get()")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 1); + auto *function = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(function); + ASSERT_TRUE(function->function_); +} + +TEST_P(CypherMainVisitorTest, StringLiteralDoubleQuotes) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN \"mi'rko\"")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, "mi'rko", 1); +} + +TEST_P(CypherMainVisitorTest, StringLiteralSingleQuotes) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN 'mi\"rko'")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, "mi\"rko", 1); +} + +TEST_P(CypherMainVisitorTest, StringLiteralEscapedChars) { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast(ast_generator.ParseQuery("RETURN '\\\\\\'\\\"\\b\\B\\f\\F\\n\\N\\r\\R\\t\\T'")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, "\\'\"\b\b\f\f\n\n\r\r\t\t", 1); +} + +TEST_P(CypherMainVisitorTest, StringLiteralEscapedUtf16) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN '\\u221daaa\\u221daaa'")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, + "\xE2\x88\x9D" + "aaa" + "\xE2\x88\x9D" + "aaa", + 1); // u8"\u221daaa\u221daaa" +} + +TEST_P(CypherMainVisitorTest, StringLiteralEscapedUtf16Error) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("RETURN '\\U221daaa'"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, StringLiteralEscapedUtf32) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN '\\U0001F600aaaa\\U0001F600aaaaaaaa'")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, + "\xF0\x9F\x98\x80" + "aaaa" + "\xF0\x9F\x98\x80" + "aaaaaaaa", + 1); // u8"\U0001F600aaaa\U0001F600aaaaaaaa" +} + +TEST_P(CypherMainVisitorTest, DoubleLiteral) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN 3.5")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, 3.5, 1); +} + +TEST_P(CypherMainVisitorTest, DoubleLiteralExponent) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN 5e-1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ast_generator.CheckLiteral(return_clause->body_.named_expressions[0]->expression_, 0.5, 1); +} + +TEST_P(CypherMainVisitorTest, ListLiteral) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN [3, [], 'johhny']")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *list_literal = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(list_literal); + ASSERT_EQ(3, list_literal->elements_.size()); + ast_generator.CheckLiteral(list_literal->elements_[0], 3); + auto *elem_1 = dynamic_cast(list_literal->elements_[1]); + ASSERT_TRUE(elem_1); + EXPECT_EQ(0, elem_1->elements_.size()); + ast_generator.CheckLiteral(list_literal->elements_[2], "johhny"); +} + +TEST_P(CypherMainVisitorTest, MapLiteral) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN {a: 1, b: 'bla', c: [1, {a: 42}]}")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + auto *map_literal = dynamic_cast(return_clause->body_.named_expressions[0]->expression_); + ASSERT_TRUE(map_literal); + ASSERT_EQ(3, map_literal->elements_.size()); + ast_generator.CheckLiteral(map_literal->elements_[ast_generator.Prop("a")], 1); + ast_generator.CheckLiteral(map_literal->elements_[ast_generator.Prop("b")], "bla"); + auto *elem_2 = dynamic_cast(map_literal->elements_[ast_generator.Prop("c")]); + ASSERT_TRUE(elem_2); + EXPECT_EQ(2, elem_2->elements_.size()); + auto *elem_2_1 = dynamic_cast(elem_2->elements_[1]); + ASSERT_TRUE(elem_2_1); + EXPECT_EQ(1, elem_2_1->elements_.size()); +} + +TEST_P(CypherMainVisitorTest, NodePattern) { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast(ast_generator.ParseQuery("MATCH (:label1:label2:label3 {a : 5, b : 10}) RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *match = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(match); + EXPECT_FALSE(match->optional_); + EXPECT_FALSE(match->where_); + ASSERT_EQ(match->patterns_.size(), 1U); + ASSERT_TRUE(match->patterns_[0]); + ASSERT_EQ(match->patterns_[0]->atoms_.size(), 1U); + auto node = dynamic_cast(match->patterns_[0]->atoms_[0]); + ASSERT_TRUE(node); + ASSERT_TRUE(node->identifier_); + EXPECT_EQ(node->identifier_->name_, CypherMainVisitor::kAnonPrefix + std::to_string(1)); + EXPECT_FALSE(node->identifier_->user_declared_); + EXPECT_THAT(node->labels_, UnorderedElementsAre(ast_generator.Label("label1"), ast_generator.Label("label2"), + ast_generator.Label("label3"))); + std::unordered_map properties; + for (auto x : std::get<0>(node->properties_)) { + TypedValue value = ast_generator.LiteralValue(x.second); + ASSERT_TRUE(value.type() == TypedValue::Type::Int); + properties[x.first] = value.ValueInt(); + } + EXPECT_THAT(properties, UnorderedElementsAre(Pair(ast_generator.Prop("a"), 5), Pair(ast_generator.Prop("b"), 10))); +} + +TEST_P(CypherMainVisitorTest, PropertyMapSameKeyAppearsTwice) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("MATCH ({a : 1, a : 2})"), SemanticException); +} + +TEST_P(CypherMainVisitorTest, NodePatternIdentifier) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("MATCH (var) RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(match); + EXPECT_FALSE(match->optional_); + EXPECT_FALSE(match->where_); + auto node = dynamic_cast(match->patterns_[0]->atoms_[0]); + ASSERT_TRUE(node); + ASSERT_TRUE(node->identifier_); + EXPECT_EQ(node->identifier_->name_, "var"); + EXPECT_TRUE(node->identifier_->user_declared_); + EXPECT_THAT(node->labels_, UnorderedElementsAre()); + EXPECT_THAT(std::get<0>(node->properties_), UnorderedElementsAre()); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternNoDetails) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("MATCH ()--() RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(match); + EXPECT_FALSE(match->optional_); + EXPECT_FALSE(match->where_); + ASSERT_EQ(match->patterns_.size(), 1U); + ASSERT_TRUE(match->patterns_[0]); + ASSERT_EQ(match->patterns_[0]->atoms_.size(), 3U); + auto *node1 = dynamic_cast(match->patterns_[0]->atoms_[0]); + ASSERT_TRUE(node1); + auto *edge = dynamic_cast(match->patterns_[0]->atoms_[1]); + ASSERT_TRUE(edge); + auto *node2 = dynamic_cast(match->patterns_[0]->atoms_[2]); + ASSERT_TRUE(node2); + ASSERT_TRUE(node1->identifier_); + ASSERT_TRUE(edge->identifier_); + ASSERT_TRUE(node2->identifier_); + EXPECT_THAT( + std::vector({node1->identifier_->name_, edge->identifier_->name_, node2->identifier_->name_}), + UnorderedElementsAre(CypherMainVisitor::kAnonPrefix + std::to_string(1), + CypherMainVisitor::kAnonPrefix + std::to_string(2), + CypherMainVisitor::kAnonPrefix + std::to_string(3))); + EXPECT_FALSE(node1->identifier_->user_declared_); + EXPECT_FALSE(edge->identifier_->user_declared_); + EXPECT_FALSE(node2->identifier_->user_declared_); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::BOTH); +} + +// PatternPart in braces. +TEST_P(CypherMainVisitorTest, PatternPartBraces) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("MATCH ((()--())) RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(match); + EXPECT_FALSE(match->where_); + ASSERT_EQ(match->patterns_.size(), 1U); + ASSERT_TRUE(match->patterns_[0]); + ASSERT_EQ(match->patterns_[0]->atoms_.size(), 3U); + auto *node1 = dynamic_cast(match->patterns_[0]->atoms_[0]); + ASSERT_TRUE(node1); + auto *edge = dynamic_cast(match->patterns_[0]->atoms_[1]); + ASSERT_TRUE(edge); + auto *node2 = dynamic_cast(match->patterns_[0]->atoms_[2]); + ASSERT_TRUE(node2); + ASSERT_TRUE(node1->identifier_); + ASSERT_TRUE(edge->identifier_); + ASSERT_TRUE(node2->identifier_); + EXPECT_THAT( + std::vector({node1->identifier_->name_, edge->identifier_->name_, node2->identifier_->name_}), + UnorderedElementsAre(CypherMainVisitor::kAnonPrefix + std::to_string(1), + CypherMainVisitor::kAnonPrefix + std::to_string(2), + CypherMainVisitor::kAnonPrefix + std::to_string(3))); + EXPECT_FALSE(node1->identifier_->user_declared_); + EXPECT_FALSE(edge->identifier_->user_declared_); + EXPECT_FALSE(node2->identifier_->user_declared_); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::BOTH); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternDetails) { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast(ast_generator.ParseQuery("MATCH ()<-[:type1|type2 {a : 5, b : 10}]-() RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(match); + EXPECT_FALSE(match->optional_); + EXPECT_FALSE(match->where_); + auto *edge = dynamic_cast(match->patterns_[0]->atoms_[1]); + ASSERT_TRUE(edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::IN); + EXPECT_THAT(edge->edge_types_, + UnorderedElementsAre(ast_generator.EdgeType("type1"), ast_generator.EdgeType("type2"))); + std::unordered_map properties; + for (auto x : std::get<0>(edge->properties_)) { + TypedValue value = ast_generator.LiteralValue(x.second); + ASSERT_TRUE(value.type() == TypedValue::Type::Int); + properties[x.first] = value.ValueInt(); + } + EXPECT_THAT(properties, UnorderedElementsAre(Pair(ast_generator.Prop("a"), 5), Pair(ast_generator.Prop("b"), 10))); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternVariable) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("MATCH ()-[var]->() RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(match); + EXPECT_FALSE(match->optional_); + EXPECT_FALSE(match->where_); + auto *edge = dynamic_cast(match->patterns_[0]->atoms_[1]); + ASSERT_TRUE(edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::OUT); + ASSERT_TRUE(edge->identifier_); + EXPECT_THAT(edge->identifier_->name_, "var"); + EXPECT_TRUE(edge->identifier_->user_declared_); +} + +// Assert that match has a single pattern with a single edge atom and store it +// in edge parameter. +void AssertMatchSingleEdgeAtom(Match *match, EdgeAtom *&edge) { + ASSERT_TRUE(match); + ASSERT_EQ(match->patterns_.size(), 1U); + ASSERT_EQ(match->patterns_[0]->atoms_.size(), 3U); + edge = dynamic_cast(match->patterns_[0]->atoms_[1]); + ASSERT_TRUE(edge); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternUnbounded) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("MATCH ()-[r*]->() RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast(single_query->clauses_[0]); + EdgeAtom *edge = nullptr; + AssertMatchSingleEdgeAtom(match, edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::OUT); + EXPECT_EQ(edge->type_, EdgeAtom::Type::DEPTH_FIRST); + EXPECT_EQ(edge->lower_bound_, nullptr); + EXPECT_EQ(edge->upper_bound_, nullptr); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternLowerBounded) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("MATCH ()-[r*42..]->() RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast(single_query->clauses_[0]); + EdgeAtom *edge = nullptr; + AssertMatchSingleEdgeAtom(match, edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::OUT); + EXPECT_EQ(edge->type_, EdgeAtom::Type::DEPTH_FIRST); + ast_generator.CheckLiteral(edge->lower_bound_, 42); + EXPECT_EQ(edge->upper_bound_, nullptr); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternUpperBounded) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("MATCH ()-[r*..42]->() RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast(single_query->clauses_[0]); + EdgeAtom *edge = nullptr; + AssertMatchSingleEdgeAtom(match, edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::OUT); + EXPECT_EQ(edge->type_, EdgeAtom::Type::DEPTH_FIRST); + EXPECT_EQ(edge->lower_bound_, nullptr); + ast_generator.CheckLiteral(edge->upper_bound_, 42); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternLowerUpperBounded) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("MATCH ()-[r*24..42]->() RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast(single_query->clauses_[0]); + EdgeAtom *edge = nullptr; + AssertMatchSingleEdgeAtom(match, edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::OUT); + EXPECT_EQ(edge->type_, EdgeAtom::Type::DEPTH_FIRST); + ast_generator.CheckLiteral(edge->lower_bound_, 24); + ast_generator.CheckLiteral(edge->upper_bound_, 42); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternFixedRange) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("MATCH ()-[r*42]->() RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast(single_query->clauses_[0]); + EdgeAtom *edge = nullptr; + AssertMatchSingleEdgeAtom(match, edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::OUT); + EXPECT_EQ(edge->type_, EdgeAtom::Type::DEPTH_FIRST); + ast_generator.CheckLiteral(edge->lower_bound_, 42); + ast_generator.CheckLiteral(edge->upper_bound_, 42); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternFloatingUpperBound) { + // [r*1...2] should be parsed as [r*1..0.2] + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("MATCH ()-[r*1...2]->() RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast(single_query->clauses_[0]); + EdgeAtom *edge = nullptr; + AssertMatchSingleEdgeAtom(match, edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::OUT); + EXPECT_EQ(edge->type_, EdgeAtom::Type::DEPTH_FIRST); + ast_generator.CheckLiteral(edge->lower_bound_, 1); + ast_generator.CheckLiteral(edge->upper_bound_, 0.2); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternUnboundedWithProperty) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("MATCH ()-[r* {prop: 42}]->() RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast(single_query->clauses_[0]); + EdgeAtom *edge = nullptr; + AssertMatchSingleEdgeAtom(match, edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::OUT); + EXPECT_EQ(edge->type_, EdgeAtom::Type::DEPTH_FIRST); + EXPECT_EQ(edge->lower_bound_, nullptr); + EXPECT_EQ(edge->upper_bound_, nullptr); + ast_generator.CheckLiteral(std::get<0>(edge->properties_)[ast_generator.Prop("prop")], 42); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternDotsUnboundedWithEdgeTypeProperty) { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast(ast_generator.ParseQuery("MATCH ()-[r:edge_type*..{prop: 42}]->() RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast(single_query->clauses_[0]); + EdgeAtom *edge = nullptr; + AssertMatchSingleEdgeAtom(match, edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::OUT); + EXPECT_EQ(edge->type_, EdgeAtom::Type::DEPTH_FIRST); + EXPECT_EQ(edge->lower_bound_, nullptr); + EXPECT_EQ(edge->upper_bound_, nullptr); + ast_generator.CheckLiteral(std::get<0>(edge->properties_)[ast_generator.Prop("prop")], 42); + ASSERT_EQ(edge->edge_types_.size(), 1U); + auto edge_type = ast_generator.EdgeType("edge_type"); + EXPECT_EQ(edge->edge_types_[0], edge_type); +} + +TEST_P(CypherMainVisitorTest, RelationshipPatternUpperBoundedWithProperty) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("MATCH ()-[r*..2{prop: 42}]->() RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + auto *match = dynamic_cast(single_query->clauses_[0]); + EdgeAtom *edge = nullptr; + AssertMatchSingleEdgeAtom(match, edge); + EXPECT_EQ(edge->direction_, EdgeAtom::Direction::OUT); + EXPECT_EQ(edge->type_, EdgeAtom::Type::DEPTH_FIRST); + EXPECT_EQ(edge->lower_bound_, nullptr); + ast_generator.CheckLiteral(edge->upper_bound_, 2); + ast_generator.CheckLiteral(std::get<0>(edge->properties_)[ast_generator.Prop("prop")], 42); +} + +// TODO maybe uncomment +// // PatternPart with variable. +// TEST_P(CypherMainVisitorTest, PatternPartVariable) { +// ParserTables parser("CREATE var=()--()"); +// ASSERT_EQ(parser.identifiers_map_.size(), 1U); +// ASSERT_EQ(parser.pattern_parts_.size(), 1U); +// ASSERT_EQ(parser.relationships_.size(), 1U); +// ASSERT_EQ(parser.nodes_.size(), 2U); +// ASSERT_EQ(parser.pattern_parts_.begin()->second.nodes.size(), 2U); +// ASSERT_EQ(parser.pattern_parts_.begin()->second.relationships.size(), 1U); +// ASSERT_NE(parser.identifiers_map_.find("var"), +// parser.identifiers_map_.end()); +// auto output_identifier = parser.identifiers_map_["var"]; +// ASSERT_NE(parser.pattern_parts_.find(output_identifier), +// parser.pattern_parts_.end()); +// } + +TEST_P(CypherMainVisitorTest, ReturnUnanemdIdentifier) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN var")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(return_clause); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 1U); + auto *named_expr = return_clause->body_.named_expressions[0]; + ASSERT_TRUE(named_expr); + ASSERT_EQ(named_expr->name_, "var"); + auto *identifier = dynamic_cast(named_expr->expression_); + ASSERT_TRUE(identifier); + ASSERT_EQ(identifier->name_, "var"); + ASSERT_TRUE(identifier->user_declared_); +} + +TEST_P(CypherMainVisitorTest, Create) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("CREATE (n)")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *create = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(create); + ASSERT_EQ(create->patterns_.size(), 1U); + ASSERT_TRUE(create->patterns_[0]); + ASSERT_EQ(create->patterns_[0]->atoms_.size(), 1U); + auto node = dynamic_cast(create->patterns_[0]->atoms_[0]); + ASSERT_TRUE(node); + ASSERT_TRUE(node->identifier_); + ASSERT_EQ(node->identifier_->name_, "n"); +} + +TEST_P(CypherMainVisitorTest, Delete) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("DELETE n, m")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *del = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(del); + ASSERT_FALSE(del->detach_); + ASSERT_EQ(del->expressions_.size(), 2U); + auto *identifier1 = dynamic_cast(del->expressions_[0]); + ASSERT_TRUE(identifier1); + ASSERT_EQ(identifier1->name_, "n"); + auto *identifier2 = dynamic_cast(del->expressions_[1]); + ASSERT_TRUE(identifier2); + ASSERT_EQ(identifier2->name_, "m"); +} + +TEST_P(CypherMainVisitorTest, DeleteDetach) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("DETACH DELETE n")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *del = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(del); + ASSERT_TRUE(del->detach_); + ASSERT_EQ(del->expressions_.size(), 1U); + auto *identifier1 = dynamic_cast(del->expressions_[0]); + ASSERT_TRUE(identifier1); + ASSERT_EQ(identifier1->name_, "n"); +} + +TEST_P(CypherMainVisitorTest, OptionalMatchWhere) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("OPTIONAL MATCH (n) WHERE m RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *match = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(match); + EXPECT_TRUE(match->optional_); + ASSERT_TRUE(match->where_); + auto *identifier = dynamic_cast(match->where_->expression_); + ASSERT_TRUE(identifier); + ASSERT_EQ(identifier->name_, "m"); +} + +TEST_P(CypherMainVisitorTest, Set) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("SET a.x = b, c = d, e += f, g : h : i ")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 4U); + + { + auto *set_property = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(set_property); + ASSERT_TRUE(set_property->property_lookup_); + auto *identifier1 = dynamic_cast(set_property->property_lookup_->expression_); + ASSERT_TRUE(identifier1); + ASSERT_EQ(identifier1->name_, "a"); + ASSERT_EQ(set_property->property_lookup_->property_, ast_generator.Prop("x")); + auto *identifier2 = dynamic_cast(set_property->expression_); + ASSERT_EQ(identifier2->name_, "b"); + } + + { + auto *set_properties_assignment = dynamic_cast(single_query->clauses_[1]); + ASSERT_TRUE(set_properties_assignment); + ASSERT_FALSE(set_properties_assignment->update_); + ASSERT_TRUE(set_properties_assignment->identifier_); + ASSERT_EQ(set_properties_assignment->identifier_->name_, "c"); + auto *identifier = dynamic_cast(set_properties_assignment->expression_); + ASSERT_EQ(identifier->name_, "d"); + } + + { + auto *set_properties_update = dynamic_cast(single_query->clauses_[2]); + ASSERT_TRUE(set_properties_update); + ASSERT_TRUE(set_properties_update->update_); + ASSERT_TRUE(set_properties_update->identifier_); + ASSERT_EQ(set_properties_update->identifier_->name_, "e"); + auto *identifier = dynamic_cast(set_properties_update->expression_); + ASSERT_EQ(identifier->name_, "f"); + } + + { + auto *set_labels = dynamic_cast(single_query->clauses_[3]); + ASSERT_TRUE(set_labels); + ASSERT_TRUE(set_labels->identifier_); + ASSERT_EQ(set_labels->identifier_->name_, "g"); + ASSERT_THAT(set_labels->labels_, UnorderedElementsAre(ast_generator.Label("h"), ast_generator.Label("i"))); + } +} + +TEST_P(CypherMainVisitorTest, Remove) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("REMOVE a.x, g : h : i")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + { + auto *remove_property = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(remove_property); + ASSERT_TRUE(remove_property->property_lookup_); + auto *identifier1 = dynamic_cast(remove_property->property_lookup_->expression_); + ASSERT_TRUE(identifier1); + ASSERT_EQ(identifier1->name_, "a"); + ASSERT_EQ(remove_property->property_lookup_->property_, ast_generator.Prop("x")); + } + { + auto *remove_labels = dynamic_cast(single_query->clauses_[1]); + ASSERT_TRUE(remove_labels); + ASSERT_TRUE(remove_labels->identifier_); + ASSERT_EQ(remove_labels->identifier_->name_, "g"); + ASSERT_THAT(remove_labels->labels_, UnorderedElementsAre(ast_generator.Label("h"), ast_generator.Label("i"))); + } +} + +TEST_P(CypherMainVisitorTest, With) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("WITH n AS m RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *with = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(with); + ASSERT_FALSE(with->body_.distinct); + ASSERT_FALSE(with->body_.limit); + ASSERT_FALSE(with->body_.skip); + ASSERT_EQ(with->body_.order_by.size(), 0U); + ASSERT_FALSE(with->where_); + ASSERT_EQ(with->body_.named_expressions.size(), 1U); + auto *named_expr = with->body_.named_expressions[0]; + ASSERT_EQ(named_expr->name_, "m"); + auto *identifier = dynamic_cast(named_expr->expression_); + ASSERT_EQ(identifier->name_, "n"); +} + +TEST_P(CypherMainVisitorTest, WithNonAliasedExpression) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("WITH n.x RETURN 1"), SemanticException); +} + +TEST_P(CypherMainVisitorTest, WithNonAliasedVariable) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("WITH n RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *with = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(with); + ASSERT_EQ(with->body_.named_expressions.size(), 1U); + auto *named_expr = with->body_.named_expressions[0]; + ASSERT_EQ(named_expr->name_, "n"); + auto *identifier = dynamic_cast(named_expr->expression_); + ASSERT_EQ(identifier->name_, "n"); +} + +TEST_P(CypherMainVisitorTest, WithDistinct) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("WITH DISTINCT n AS m RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *with = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(with->body_.distinct); + ASSERT_FALSE(with->where_); + ASSERT_EQ(with->body_.named_expressions.size(), 1U); + auto *named_expr = with->body_.named_expressions[0]; + ASSERT_EQ(named_expr->name_, "m"); + auto *identifier = dynamic_cast(named_expr->expression_); + ASSERT_EQ(identifier->name_, "n"); +} + +TEST_P(CypherMainVisitorTest, WithBag) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("WITH n as m ORDER BY m SKIP 1 LIMIT 2 RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *with = dynamic_cast(single_query->clauses_[0]); + ASSERT_FALSE(with->body_.distinct); + ASSERT_FALSE(with->where_); + ASSERT_EQ(with->body_.named_expressions.size(), 1U); + // No need to check contents of body. That is checked in RETURN clause tests. + ASSERT_EQ(with->body_.order_by.size(), 1U); + ASSERT_TRUE(with->body_.limit); + ASSERT_TRUE(with->body_.skip); +} + +TEST_P(CypherMainVisitorTest, WithWhere) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("WITH n AS m WHERE k RETURN 1")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *with = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(with); + ASSERT_TRUE(with->where_); + auto *identifier = dynamic_cast(with->where_->expression_); + ASSERT_TRUE(identifier); + ASSERT_EQ(identifier->name_, "k"); + ASSERT_EQ(with->body_.named_expressions.size(), 1U); + auto *named_expr = with->body_.named_expressions[0]; + ASSERT_EQ(named_expr->name_, "m"); + auto *identifier2 = dynamic_cast(named_expr->expression_); + ASSERT_EQ(identifier2->name_, "n"); +} + +TEST_P(CypherMainVisitorTest, WithAnonymousVariableCapture) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("WITH 5 as anon1 MATCH () return *")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 3U); + auto *match = dynamic_cast(single_query->clauses_[1]); + ASSERT_TRUE(match); + ASSERT_EQ(match->patterns_.size(), 1U); + auto *pattern = match->patterns_[0]; + ASSERT_TRUE(pattern); + ASSERT_EQ(pattern->atoms_.size(), 1U); + auto *atom = dynamic_cast(pattern->atoms_[0]); + ASSERT_TRUE(atom); + ASSERT_NE("anon1", atom->identifier_->name_); +} + +TEST_P(CypherMainVisitorTest, ClausesOrdering) { + // Obviously some of the ridiculous combinations don't fail here, but they + // will fail in semantic analysis or they make perfect sense as a part of + // bigger query. + auto &ast_generator = *GetParam(); + ast_generator.ParseQuery("RETURN 1"); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 1 RETURN 1"), SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 1 MATCH (n) RETURN n"), SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 1 DELETE n"), SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 1 MERGE (n)"), SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 1 WITH n AS m RETURN 1"), SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 1 AS n UNWIND n AS x RETURN x"), SemanticException); + + ASSERT_THROW(ast_generator.ParseQuery("OPTIONAL MATCH (n) MATCH (m) RETURN n, m"), SemanticException); + ast_generator.ParseQuery("OPTIONAL MATCH (n) WITH n MATCH (m) RETURN n, m"); + ast_generator.ParseQuery("OPTIONAL MATCH (n) OPTIONAL MATCH (m) RETURN n, m"); + ast_generator.ParseQuery("MATCH (n) OPTIONAL MATCH (m) RETURN n, m"); + + ast_generator.ParseQuery("CREATE (n)"); + ASSERT_THROW(ast_generator.ParseQuery("SET n:x MATCH (n) RETURN n"), SemanticException); + ast_generator.ParseQuery("REMOVE n.x SET n.x = 1"); + ast_generator.ParseQuery("REMOVE n:L RETURN n"); + ast_generator.ParseQuery("SET n.x = 1 WITH n AS m RETURN m"); + + ASSERT_THROW(ast_generator.ParseQuery("MATCH (n)"), SemanticException); + ast_generator.ParseQuery("MATCH (n) MATCH (n) RETURN n"); + ast_generator.ParseQuery("MATCH (n) SET n = m"); + ast_generator.ParseQuery("MATCH (n) RETURN n"); + ast_generator.ParseQuery("MATCH (n) WITH n AS m RETURN m"); + + ASSERT_THROW(ast_generator.ParseQuery("WITH 1 AS n"), SemanticException); + ast_generator.ParseQuery("WITH 1 AS n WITH n AS m RETURN m"); + ast_generator.ParseQuery("WITH 1 AS n RETURN n"); + ast_generator.ParseQuery("WITH 1 AS n SET n += m"); + ast_generator.ParseQuery("WITH 1 AS n MATCH (n) RETURN n"); + + ASSERT_THROW(ast_generator.ParseQuery("UNWIND [1,2,3] AS x"), SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("CREATE (n) UNWIND [1,2,3] AS x RETURN x"), SemanticException); + ast_generator.ParseQuery("UNWIND [1,2,3] AS x CREATE (n) RETURN x"); + ast_generator.ParseQuery("CREATE (n) WITH n UNWIND [1,2,3] AS x RETURN x"); +} + +TEST_P(CypherMainVisitorTest, Merge) { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast(ast_generator.ParseQuery("MERGE (a) -[:r]- (b) ON MATCH SET a.x = b.x " + "ON CREATE SET b :label ON MATCH SET b = a")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *merge = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(merge); + EXPECT_TRUE(dynamic_cast(merge->pattern_)); + ASSERT_EQ(merge->on_match_.size(), 2U); + EXPECT_TRUE(dynamic_cast(merge->on_match_[0])); + EXPECT_TRUE(dynamic_cast(merge->on_match_[1])); + ASSERT_EQ(merge->on_create_.size(), 1U); + EXPECT_TRUE(dynamic_cast(merge->on_create_[0])); +} + +TEST_P(CypherMainVisitorTest, Unwind) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("UNWIND [1,2,3] AS elem RETURN elem")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *unwind = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(unwind); + auto *ret = dynamic_cast(single_query->clauses_[1]); + EXPECT_TRUE(ret); + ASSERT_TRUE(unwind->named_expression_); + EXPECT_EQ(unwind->named_expression_->name_, "elem"); + auto *expr = unwind->named_expression_->expression_; + ASSERT_TRUE(expr); + ASSERT_TRUE(dynamic_cast(expr)); +} + +TEST_P(CypherMainVisitorTest, UnwindWithoutAsError) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("UNWIND [1,2,3] RETURN 42"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, CreateIndex) { + auto &ast_generator = *GetParam(); + auto *index_query = dynamic_cast(ast_generator.ParseQuery("Create InDeX oN :mirko(slavko)")); + ASSERT_TRUE(index_query); + EXPECT_EQ(index_query->action_, IndexQuery::Action::CREATE); + EXPECT_EQ(index_query->label_, ast_generator.Label("mirko")); + std::vector expected_properties{ast_generator.Prop("slavko")}; + EXPECT_EQ(index_query->properties_, expected_properties); +} + +TEST_P(CypherMainVisitorTest, DropIndex) { + auto &ast_generator = *GetParam(); + auto *index_query = dynamic_cast(ast_generator.ParseQuery("dRoP InDeX oN :mirko(slavko)")); + ASSERT_TRUE(index_query); + EXPECT_EQ(index_query->action_, IndexQuery::Action::DROP); + EXPECT_EQ(index_query->label_, ast_generator.Label("mirko")); + std::vector expected_properties{ast_generator.Prop("slavko")}; + EXPECT_EQ(index_query->properties_, expected_properties); +} + +TEST_P(CypherMainVisitorTest, DropIndexWithoutProperties) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("dRoP InDeX oN :mirko()"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, DropIndexWithMultipleProperties) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("dRoP InDeX oN :mirko(slavko, pero)"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, ReturnAll) { + { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("RETURN all(x in [1,2,3])"), SyntaxException); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN all(x IN [1,2,3] WHERE x = 2)")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *ret = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(ret); + ASSERT_EQ(ret->body_.named_expressions.size(), 1U); + auto *all = dynamic_cast(ret->body_.named_expressions[0]->expression_); + ASSERT_TRUE(all); + EXPECT_EQ(all->identifier_->name_, "x"); + auto *list_literal = dynamic_cast(all->list_expression_); + EXPECT_TRUE(list_literal); + auto *eq = dynamic_cast(all->where_->expression_); + EXPECT_TRUE(eq); + } +} + +TEST_P(CypherMainVisitorTest, ReturnSingle) { + { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("RETURN single(x in [1,2,3])"), SyntaxException); + } + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN single(x IN [1,2,3] WHERE x = 2)")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *ret = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(ret); + ASSERT_EQ(ret->body_.named_expressions.size(), 1U); + auto *single = dynamic_cast(ret->body_.named_expressions[0]->expression_); + ASSERT_TRUE(single); + EXPECT_EQ(single->identifier_->name_, "x"); + auto *list_literal = dynamic_cast(single->list_expression_); + EXPECT_TRUE(list_literal); + auto *eq = dynamic_cast(single->where_->expression_); + EXPECT_TRUE(eq); +} + +TEST_P(CypherMainVisitorTest, ReturnReduce) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN reduce(sum = 0, x IN [1,2,3] | sum + x)")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *ret = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(ret); + ASSERT_EQ(ret->body_.named_expressions.size(), 1U); + auto *reduce = dynamic_cast(ret->body_.named_expressions[0]->expression_); + ASSERT_TRUE(reduce); + EXPECT_EQ(reduce->accumulator_->name_, "sum"); + ast_generator.CheckLiteral(reduce->initializer_, 0); + EXPECT_EQ(reduce->identifier_->name_, "x"); + auto *list_literal = dynamic_cast(reduce->list_); + EXPECT_TRUE(list_literal); + auto *add = dynamic_cast(reduce->expression_); + EXPECT_TRUE(add); +} + +TEST_P(CypherMainVisitorTest, ReturnExtract) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN extract(x IN [1,2,3] | sum + x)")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *ret = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(ret); + ASSERT_EQ(ret->body_.named_expressions.size(), 1U); + auto *extract = dynamic_cast(ret->body_.named_expressions[0]->expression_); + ASSERT_TRUE(extract); + EXPECT_EQ(extract->identifier_->name_, "x"); + auto *list_literal = dynamic_cast(extract->list_); + EXPECT_TRUE(list_literal); + auto *add = dynamic_cast(extract->expression_); + EXPECT_TRUE(add); +} + +TEST_P(CypherMainVisitorTest, MatchBfsReturn) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast( + ast_generator.ParseQuery("MATCH (n) -[r:type1|type2 *bfs..10 (e, n|e.prop = 42)]-> (m) RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *match = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(match); + ASSERT_EQ(match->patterns_.size(), 1U); + ASSERT_EQ(match->patterns_[0]->atoms_.size(), 3U); + auto *bfs = dynamic_cast(match->patterns_[0]->atoms_[1]); + ASSERT_TRUE(bfs); + EXPECT_TRUE(bfs->IsVariable()); + EXPECT_EQ(bfs->direction_, EdgeAtom::Direction::OUT); + EXPECT_THAT(bfs->edge_types_, UnorderedElementsAre(ast_generator.EdgeType("type1"), ast_generator.EdgeType("type2"))); + EXPECT_EQ(bfs->identifier_->name_, "r"); + EXPECT_EQ(bfs->filter_lambda_.inner_edge->name_, "e"); + EXPECT_TRUE(bfs->filter_lambda_.inner_edge->user_declared_); + EXPECT_EQ(bfs->filter_lambda_.inner_node->name_, "n"); + EXPECT_TRUE(bfs->filter_lambda_.inner_node->user_declared_); + ast_generator.CheckLiteral(bfs->upper_bound_, 10); + auto *eq = dynamic_cast(bfs->filter_lambda_.expression); + ASSERT_TRUE(eq); +} + +TEST_P(CypherMainVisitorTest, MatchVariableLambdaSymbols) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("MATCH () -[*]- () RETURN *")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *match = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(match); + ASSERT_EQ(match->patterns_.size(), 1U); + ASSERT_EQ(match->patterns_[0]->atoms_.size(), 3U); + auto *var_expand = dynamic_cast(match->patterns_[0]->atoms_[1]); + ASSERT_TRUE(var_expand); + ASSERT_TRUE(var_expand->IsVariable()); + EXPECT_FALSE(var_expand->filter_lambda_.inner_edge->user_declared_); + EXPECT_FALSE(var_expand->filter_lambda_.inner_node->user_declared_); +} + +TEST_P(CypherMainVisitorTest, MatchWShortestReturn) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast( + ast_generator.ParseQuery("MATCH ()-[r:type1|type2 *wShortest 10 (we, wn | 42) total_weight " + "(e, n | true)]->() RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *match = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(match); + ASSERT_EQ(match->patterns_.size(), 1U); + ASSERT_EQ(match->patterns_[0]->atoms_.size(), 3U); + auto *shortest = dynamic_cast(match->patterns_[0]->atoms_[1]); + ASSERT_TRUE(shortest); + EXPECT_TRUE(shortest->IsVariable()); + EXPECT_EQ(shortest->type_, EdgeAtom::Type::WEIGHTED_SHORTEST_PATH); + EXPECT_EQ(shortest->direction_, EdgeAtom::Direction::OUT); + EXPECT_THAT(shortest->edge_types_, + UnorderedElementsAre(ast_generator.EdgeType("type1"), ast_generator.EdgeType("type2"))); + ast_generator.CheckLiteral(shortest->upper_bound_, 10); + EXPECT_FALSE(shortest->lower_bound_); + EXPECT_EQ(shortest->identifier_->name_, "r"); + EXPECT_EQ(shortest->filter_lambda_.inner_edge->name_, "e"); + EXPECT_TRUE(shortest->filter_lambda_.inner_edge->user_declared_); + EXPECT_EQ(shortest->filter_lambda_.inner_node->name_, "n"); + EXPECT_TRUE(shortest->filter_lambda_.inner_node->user_declared_); + ast_generator.CheckLiteral(shortest->filter_lambda_.expression, true); + EXPECT_EQ(shortest->weight_lambda_.inner_edge->name_, "we"); + EXPECT_TRUE(shortest->weight_lambda_.inner_edge->user_declared_); + EXPECT_EQ(shortest->weight_lambda_.inner_node->name_, "wn"); + EXPECT_TRUE(shortest->weight_lambda_.inner_node->user_declared_); + ast_generator.CheckLiteral(shortest->weight_lambda_.expression, 42); + ASSERT_TRUE(shortest->total_weight_); + EXPECT_EQ(shortest->total_weight_->name_, "total_weight"); + EXPECT_TRUE(shortest->total_weight_->user_declared_); +} + +TEST_P(CypherMainVisitorTest, MatchWShortestNoFilterReturn) { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast(ast_generator.ParseQuery("MATCH ()-[r:type1|type2 *wShortest 10 (we, wn | 42)]->() " + "RETURN r")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *match = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(match); + ASSERT_EQ(match->patterns_.size(), 1U); + ASSERT_EQ(match->patterns_[0]->atoms_.size(), 3U); + auto *shortest = dynamic_cast(match->patterns_[0]->atoms_[1]); + ASSERT_TRUE(shortest); + EXPECT_TRUE(shortest->IsVariable()); + EXPECT_EQ(shortest->type_, EdgeAtom::Type::WEIGHTED_SHORTEST_PATH); + EXPECT_EQ(shortest->direction_, EdgeAtom::Direction::OUT); + EXPECT_THAT(shortest->edge_types_, + UnorderedElementsAre(ast_generator.EdgeType("type1"), ast_generator.EdgeType("type2"))); + ast_generator.CheckLiteral(shortest->upper_bound_, 10); + EXPECT_FALSE(shortest->lower_bound_); + EXPECT_EQ(shortest->identifier_->name_, "r"); + EXPECT_FALSE(shortest->filter_lambda_.expression); + EXPECT_FALSE(shortest->filter_lambda_.inner_edge->user_declared_); + EXPECT_FALSE(shortest->filter_lambda_.inner_node->user_declared_); + EXPECT_EQ(shortest->weight_lambda_.inner_edge->name_, "we"); + EXPECT_TRUE(shortest->weight_lambda_.inner_edge->user_declared_); + EXPECT_EQ(shortest->weight_lambda_.inner_node->name_, "wn"); + EXPECT_TRUE(shortest->weight_lambda_.inner_node->user_declared_); + ast_generator.CheckLiteral(shortest->weight_lambda_.expression, 42); + ASSERT_TRUE(shortest->total_weight_); + EXPECT_FALSE(shortest->total_weight_->user_declared_); +} + +TEST_P(CypherMainVisitorTest, SemanticExceptionOnWShortestLowerBound) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("MATCH ()-[r *wShortest 10.. (e, n | 42)]-() RETURN r"), SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("MATCH ()-[r *wShortest 10..20 (e, n | 42)]-() RETURN r"), SemanticException); +} + +TEST_P(CypherMainVisitorTest, SemanticExceptionOnWShortestWithoutLambda) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("MATCH ()-[r *wShortest]-() RETURN r"), SemanticException); +} + +TEST_P(CypherMainVisitorTest, SemanticExceptionOnUnionTypeMix) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 5 as X UNION ALL RETURN 6 AS X UNION RETURN 7 AS X"), + SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 5 as X UNION RETURN 6 AS X UNION ALL RETURN 7 AS X"), + SemanticException); +} + +TEST_P(CypherMainVisitorTest, Union) { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast(ast_generator.ParseQuery("RETURN 5 AS X, 6 AS Y UNION RETURN 6 AS X, 5 AS Y")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ASSERT_FALSE(return_clause->body_.all_identifiers); + ASSERT_EQ(return_clause->body_.order_by.size(), 0U); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 2U); + ASSERT_FALSE(return_clause->body_.limit); + ASSERT_FALSE(return_clause->body_.skip); + ASSERT_FALSE(return_clause->body_.distinct); + + ASSERT_EQ(query->cypher_unions_.size(), 1); + auto *cypher_union = query->cypher_unions_.at(0); + ASSERT_TRUE(cypher_union); + ASSERT_TRUE(cypher_union->distinct_); + ASSERT_TRUE(single_query = cypher_union->single_query_); + ASSERT_EQ(single_query->clauses_.size(), 1U); + return_clause = dynamic_cast(single_query->clauses_[0]); + ASSERT_FALSE(return_clause->body_.all_identifiers); + ASSERT_EQ(return_clause->body_.order_by.size(), 0U); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 2U); + ASSERT_FALSE(return_clause->body_.limit); + ASSERT_FALSE(return_clause->body_.skip); + ASSERT_FALSE(return_clause->body_.distinct); +} + +TEST_P(CypherMainVisitorTest, UnionAll) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast( + ast_generator.ParseQuery("RETURN 5 AS X UNION ALL RETURN 6 AS X UNION ALL RETURN 7 AS X")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ASSERT_FALSE(return_clause->body_.all_identifiers); + ASSERT_EQ(return_clause->body_.order_by.size(), 0U); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 1U); + ASSERT_FALSE(return_clause->body_.limit); + ASSERT_FALSE(return_clause->body_.skip); + ASSERT_FALSE(return_clause->body_.distinct); + + ASSERT_EQ(query->cypher_unions_.size(), 2); + + auto *cypher_union = query->cypher_unions_.at(0); + ASSERT_TRUE(cypher_union); + ASSERT_FALSE(cypher_union->distinct_); + ASSERT_TRUE(single_query = cypher_union->single_query_); + ASSERT_EQ(single_query->clauses_.size(), 1U); + return_clause = dynamic_cast(single_query->clauses_[0]); + ASSERT_FALSE(return_clause->body_.all_identifiers); + ASSERT_EQ(return_clause->body_.order_by.size(), 0U); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 1U); + ASSERT_FALSE(return_clause->body_.limit); + ASSERT_FALSE(return_clause->body_.skip); + ASSERT_FALSE(return_clause->body_.distinct); + + cypher_union = query->cypher_unions_.at(1); + ASSERT_TRUE(cypher_union); + ASSERT_FALSE(cypher_union->distinct_); + ASSERT_TRUE(single_query = cypher_union->single_query_); + ASSERT_EQ(single_query->clauses_.size(), 1U); + return_clause = dynamic_cast(single_query->clauses_[0]); + ASSERT_FALSE(return_clause->body_.all_identifiers); + ASSERT_EQ(return_clause->body_.order_by.size(), 0U); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 1U); + ASSERT_FALSE(return_clause->body_.limit); + ASSERT_FALSE(return_clause->body_.skip); + ASSERT_FALSE(return_clause->body_.distinct); +} + +void check_auth_query(Base *ast_generator, std::string input, AuthQuery::Action action, std::string user, + std::string role, std::string user_or_role, std::optional password, + std::vector privileges) { + auto *auth_query = dynamic_cast(ast_generator->ParseQuery(input)); + ASSERT_TRUE(auth_query); + EXPECT_EQ(auth_query->action_, action); + EXPECT_EQ(auth_query->user_, user); + EXPECT_EQ(auth_query->role_, role); + EXPECT_EQ(auth_query->user_or_role_, user_or_role); + ASSERT_EQ(static_cast(auth_query->password_), static_cast(password)); + if (password) { + ast_generator->CheckLiteral(auth_query->password_, *password); + } + EXPECT_EQ(auth_query->privileges_, privileges); +} + +TEST_P(CypherMainVisitorTest, UserOrRoleName) { + auto &ast_generator = *GetParam(); + check_auth_query(&ast_generator, "CREATE ROLE `user`", AuthQuery::Action::CREATE_ROLE, "", "user", "", {}, {}); + check_auth_query(&ast_generator, "CREATE ROLE us___er", AuthQuery::Action::CREATE_ROLE, "", "us___er", "", {}, {}); + check_auth_query(&ast_generator, "CREATE ROLE `us+er`", AuthQuery::Action::CREATE_ROLE, "", "us+er", "", {}, {}); + check_auth_query(&ast_generator, "CREATE ROLE `us|er`", AuthQuery::Action::CREATE_ROLE, "", "us|er", "", {}, {}); + check_auth_query(&ast_generator, "CREATE ROLE `us er`", AuthQuery::Action::CREATE_ROLE, "", "us er", "", {}, {}); +} + +TEST_P(CypherMainVisitorTest, CreateRole) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("CREATE ROLE"), SyntaxException); + check_auth_query(&ast_generator, "CREATE ROLE rola", AuthQuery::Action::CREATE_ROLE, "", "rola", "", {}, {}); + ASSERT_THROW(ast_generator.ParseQuery("CREATE ROLE lagano rolamo"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, DropRole) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("DROP ROLE"), SyntaxException); + check_auth_query(&ast_generator, "DROP ROLE rola", AuthQuery::Action::DROP_ROLE, "", "rola", "", {}, {}); + ASSERT_THROW(ast_generator.ParseQuery("DROP ROLE lagano rolamo"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, ShowRoles) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("SHOW ROLES ROLES"), SyntaxException); + check_auth_query(&ast_generator, "SHOW ROLES", AuthQuery::Action::SHOW_ROLES, "", "", "", {}, {}); +} + +TEST_P(CypherMainVisitorTest, CreateUser) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("CREATE USER"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CREATE USER 123"), SyntaxException); + check_auth_query(&ast_generator, "CREATE USER user", AuthQuery::Action::CREATE_USER, "user", "", "", {}, {}); + check_auth_query(&ast_generator, "CREATE USER user IDENTIFIED BY 'password'", AuthQuery::Action::CREATE_USER, "user", + "", "", TypedValue("password"), {}); + check_auth_query(&ast_generator, "CREATE USER user IDENTIFIED BY ''", AuthQuery::Action::CREATE_USER, "user", "", "", + TypedValue(""), {}); + check_auth_query(&ast_generator, "CREATE USER user IDENTIFIED BY null", AuthQuery::Action::CREATE_USER, "user", "", + "", TypedValue(), {}); + ASSERT_THROW(ast_generator.ParseQuery("CRATE USER user IDENTIFIED BY password"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CREATE USER user IDENTIFIED BY 5"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CREATE USER user IDENTIFIED BY "), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, SetPassword) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("SET PASSWORD FOR"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("SET PASSWORD FOR user "), SyntaxException); + check_auth_query(&ast_generator, "SET PASSWORD FOR user TO null", AuthQuery::Action::SET_PASSWORD, "user", "", "", + TypedValue(), {}); + check_auth_query(&ast_generator, "SET PASSWORD FOR user TO 'password'", AuthQuery::Action::SET_PASSWORD, "user", "", + "", TypedValue("password"), {}); + ASSERT_THROW(ast_generator.ParseQuery("SET PASSWORD FOR user To 5"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, DropUser) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("DROP USER"), SyntaxException); + check_auth_query(&ast_generator, "DROP USER user", AuthQuery::Action::DROP_USER, "user", "", "", {}, {}); + ASSERT_THROW(ast_generator.ParseQuery("DROP USER lagano rolamo"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, ShowUsers) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("SHOW USERS ROLES"), SyntaxException); + check_auth_query(&ast_generator, "SHOW USERS", AuthQuery::Action::SHOW_USERS, "", "", "", {}, {}); +} + +TEST_P(CypherMainVisitorTest, SetRole) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("SET ROLE"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("SET ROLE user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("SET ROLE FOR user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("SET ROLE FOR user TO"), SyntaxException); + check_auth_query(&ast_generator, "SET ROLE FOR user TO role", AuthQuery::Action::SET_ROLE, "user", "role", "", {}, + {}); + check_auth_query(&ast_generator, "SET ROLE FOR user TO null", AuthQuery::Action::SET_ROLE, "user", "null", "", {}, + {}); +} + +TEST_P(CypherMainVisitorTest, ClearRole) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("CLEAR ROLE"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CLEAR ROLE user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CLEAR ROLE FOR user TO"), SyntaxException); + check_auth_query(&ast_generator, "CLEAR ROLE FOR user", AuthQuery::Action::CLEAR_ROLE, "user", "", "", {}, {}); +} + +TEST_P(CypherMainVisitorTest, GrantPrivilege) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("GRANT"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("GRANT TO user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("GRANT BLABLA TO user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("GRANT MATCH, TO user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("GRANT MATCH, BLABLA TO user"), SyntaxException); + check_auth_query(&ast_generator, "GRANT MATCH TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MATCH}); + check_auth_query(&ast_generator, "GRANT MATCH, AUTH TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MATCH, AuthQuery::Privilege::AUTH}); + // Verify that all privileges are correctly visited. + check_auth_query(&ast_generator, "GRANT CREATE TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::CREATE}); + check_auth_query(&ast_generator, "GRANT DELETE TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::DELETE}); + check_auth_query(&ast_generator, "GRANT MERGE TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MERGE}); + check_auth_query(&ast_generator, "GRANT SET TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::SET}); + check_auth_query(&ast_generator, "GRANT REMOVE TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::REMOVE}); + check_auth_query(&ast_generator, "GRANT INDEX TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::INDEX}); + check_auth_query(&ast_generator, "GRANT STATS TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::STATS}); + check_auth_query(&ast_generator, "GRANT AUTH TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::AUTH}); + check_auth_query(&ast_generator, "GRANT CONSTRAINT TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::CONSTRAINT}); + check_auth_query(&ast_generator, "GRANT DUMP TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::DUMP}); + check_auth_query(&ast_generator, "GRANT REPLICATION TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::REPLICATION}); + check_auth_query(&ast_generator, "GRANT DURABILITY TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::DURABILITY}); + check_auth_query(&ast_generator, "GRANT READ_FILE TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::READ_FILE}); + check_auth_query(&ast_generator, "GRANT FREE_MEMORY TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::FREE_MEMORY}); + check_auth_query(&ast_generator, "GRANT TRIGGER TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::TRIGGER}); + check_auth_query(&ast_generator, "GRANT CONFIG TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::CONFIG}); + check_auth_query(&ast_generator, "GRANT STREAM TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::STREAM}); + check_auth_query(&ast_generator, "GRANT WEBSOCKET TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::WEBSOCKET}); + check_auth_query(&ast_generator, "GRANT MODULE_READ TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MODULE_READ}); + check_auth_query(&ast_generator, "GRANT MODULE_WRITE TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MODULE_WRITE}); + check_auth_query(&ast_generator, "GRANT SCHEMA TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::SCHEMA}); +} + +TEST_P(CypherMainVisitorTest, DenyPrivilege) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("DENY"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("DENY TO user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("DENY BLABLA TO user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("DENY MATCH, TO user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("DENY MATCH, BLABLA TO user"), SyntaxException); + check_auth_query(&ast_generator, "DENY MATCH TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MATCH}); + check_auth_query(&ast_generator, "DENY MATCH, AUTH TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MATCH, AuthQuery::Privilege::AUTH}); + // Verify that all privileges are correctly visited. + check_auth_query(&ast_generator, "DENY CREATE TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::CREATE}); + check_auth_query(&ast_generator, "DENY DELETE TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::DELETE}); + check_auth_query(&ast_generator, "DENY MERGE TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MERGE}); + check_auth_query(&ast_generator, "DENY SET TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::SET}); + check_auth_query(&ast_generator, "DENY REMOVE TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::REMOVE}); + check_auth_query(&ast_generator, "DENY INDEX TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::INDEX}); + check_auth_query(&ast_generator, "DENY STATS TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::STATS}); + check_auth_query(&ast_generator, "DENY AUTH TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::AUTH}); + check_auth_query(&ast_generator, "DENY CONSTRAINT TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::CONSTRAINT}); + check_auth_query(&ast_generator, "DENY DUMP TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::DUMP}); + check_auth_query(&ast_generator, "DENY WEBSOCKET TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::WEBSOCKET}); + check_auth_query(&ast_generator, "DENY MODULE_READ TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MODULE_READ}); + check_auth_query(&ast_generator, "DENY MODULE_WRITE TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MODULE_WRITE}); + check_auth_query(&ast_generator, "DENY SCHEMA TO user", AuthQuery::Action::DENY_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::SCHEMA}); +} + +TEST_P(CypherMainVisitorTest, RevokePrivilege) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("REVOKE"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("REVOKE FROM user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("REVOKE BLABLA FROM user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("REVOKE MATCH, FROM user"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("REVOKE MATCH, BLABLA FROM user"), SyntaxException); + check_auth_query(&ast_generator, "REVOKE MATCH FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MATCH}); + check_auth_query(&ast_generator, "REVOKE MATCH, AUTH FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", + {}, {AuthQuery::Privilege::MATCH, AuthQuery::Privilege::AUTH}); + check_auth_query(&ast_generator, "REVOKE ALL PRIVILEGES FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", + "user", {}, kPrivilegesAll); + // Verify that all privileges are correctly visited. + check_auth_query(&ast_generator, "REVOKE CREATE FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::CREATE}); + check_auth_query(&ast_generator, "REVOKE DELETE FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::DELETE}); + check_auth_query(&ast_generator, "REVOKE MERGE FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::MERGE}); + check_auth_query(&ast_generator, "REVOKE SET FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::SET}); + check_auth_query(&ast_generator, "REVOKE REMOVE FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::REMOVE}); + check_auth_query(&ast_generator, "REVOKE INDEX FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::INDEX}); + check_auth_query(&ast_generator, "REVOKE STATS FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::STATS}); + check_auth_query(&ast_generator, "REVOKE AUTH FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::AUTH}); + check_auth_query(&ast_generator, "REVOKE CONSTRAINT FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", + {}, {AuthQuery::Privilege::CONSTRAINT}); + check_auth_query(&ast_generator, "REVOKE DUMP FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::DUMP}); + check_auth_query(&ast_generator, "REVOKE WEBSOCKET FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", + {}, {AuthQuery::Privilege::WEBSOCKET}); + check_auth_query(&ast_generator, "REVOKE MODULE_READ FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", + {}, {AuthQuery::Privilege::MODULE_READ}); + check_auth_query(&ast_generator, "REVOKE MODULE_WRITE FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", + {}, {AuthQuery::Privilege::MODULE_WRITE}); + check_auth_query(&ast_generator, "REVOKE SCHEMA FROM user", AuthQuery::Action::REVOKE_PRIVILEGE, "", "", "user", {}, + {AuthQuery::Privilege::SCHEMA}); +} + +TEST_P(CypherMainVisitorTest, ShowPrivileges) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("SHOW PRIVILEGES FOR"), SyntaxException); + check_auth_query(&ast_generator, "SHOW PRIVILEGES FOR user", AuthQuery::Action::SHOW_PRIVILEGES, "", "", "user", {}, + {}); + ASSERT_THROW(ast_generator.ParseQuery("SHOW PRIVILEGES FOR user1, user2"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, ShowRoleForUser) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("SHOW ROLE FOR "), SyntaxException); + check_auth_query(&ast_generator, "SHOW ROLE FOR user", AuthQuery::Action::SHOW_ROLE_FOR_USER, "user", "", "", {}, {}); + ASSERT_THROW(ast_generator.ParseQuery("SHOW ROLE FOR user1, user2"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, ShowUsersForRole) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("SHOW USERS FOR "), SyntaxException); + check_auth_query(&ast_generator, "SHOW USERS FOR role", AuthQuery::Action::SHOW_USERS_FOR_ROLE, "", "role", "", {}, + {}); + ASSERT_THROW(ast_generator.ParseQuery("SHOW USERS FOR role1, role2"), SyntaxException); +} + +void check_replication_query(Base *ast_generator, const ReplicationQuery *query, const std::string name, + const std::optional socket_address, const ReplicationQuery::SyncMode sync_mode, + const std::optional port = {}) { + EXPECT_EQ(query->replica_name_, name); + EXPECT_EQ(query->sync_mode_, sync_mode); + ASSERT_EQ(static_cast(query->socket_address_), static_cast(socket_address)); + if (socket_address) { + ast_generator->CheckLiteral(query->socket_address_, *socket_address); + } + ASSERT_EQ(static_cast(query->port_), static_cast(port)); + if (port) { + ast_generator->CheckLiteral(query->port_, *port); + } +} + +TEST_P(CypherMainVisitorTest, TestShowReplicationMode) { + auto &ast_generator = *GetParam(); + const std::string raw_query = "SHOW REPLICATION ROLE"; + auto *parsed_query = dynamic_cast(ast_generator.ParseQuery(raw_query)); + EXPECT_EQ(parsed_query->action_, ReplicationQuery::Action::SHOW_REPLICATION_ROLE); +} + +TEST_P(CypherMainVisitorTest, TestShowReplicasQuery) { + auto &ast_generator = *GetParam(); + const std::string raw_query = "SHOW REPLICAS"; + auto *parsed_query = dynamic_cast(ast_generator.ParseQuery(raw_query)); + EXPECT_EQ(parsed_query->action_, ReplicationQuery::Action::SHOW_REPLICAS); +} + +TEST_P(CypherMainVisitorTest, TestSetReplicationMode) { + auto &ast_generator = *GetParam(); + + { + const std::string query = "SET REPLICATION ROLE"; + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = "SET REPLICATION ROLE TO BUTTERY"; + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = "SET REPLICATION ROLE TO MAIN"; + auto *parsed_query = dynamic_cast(ast_generator.ParseQuery(query)); + EXPECT_EQ(parsed_query->action_, ReplicationQuery::Action::SET_REPLICATION_ROLE); + EXPECT_EQ(parsed_query->role_, ReplicationQuery::ReplicationRole::MAIN); + } + + { + const std::string query = "SET REPLICATION ROLE TO MAIN WITH PORT 10000"; + ASSERT_THROW(ast_generator.ParseQuery(query), SemanticException); + } + + { + const std::string query = "SET REPLICATION ROLE TO REPLICA WITH PORT 10000"; + auto *parsed_query = dynamic_cast(ast_generator.ParseQuery(query)); + EXPECT_EQ(parsed_query->action_, ReplicationQuery::Action::SET_REPLICATION_ROLE); + EXPECT_EQ(parsed_query->role_, ReplicationQuery::ReplicationRole::REPLICA); + ast_generator.CheckLiteral(parsed_query->port_, TypedValue(10000)); + } +} + +TEST_P(CypherMainVisitorTest, TestRegisterReplicationQuery) { + auto &ast_generator = *GetParam(); + + const std::string faulty_query = "REGISTER REPLICA TO"; + ASSERT_THROW(ast_generator.ParseQuery(faulty_query), SyntaxException); + + const std::string faulty_query_with_timeout = R"(REGISTER REPLICA replica1 SYNC WITH TIMEOUT 1.0 TO "127.0.0.1")"; + ASSERT_THROW(ast_generator.ParseQuery(faulty_query_with_timeout), SyntaxException); + + const std::string correct_query = R"(REGISTER REPLICA replica1 SYNC TO "127.0.0.1")"; + auto *correct_query_parsed = dynamic_cast(ast_generator.ParseQuery(correct_query)); + check_replication_query(&ast_generator, correct_query_parsed, "replica1", TypedValue("127.0.0.1"), + ReplicationQuery::SyncMode::SYNC); + + std::string full_query = R"(REGISTER REPLICA replica2 SYNC TO "1.1.1.1:10000")"; + auto *full_query_parsed = dynamic_cast(ast_generator.ParseQuery(full_query)); + ASSERT_TRUE(full_query_parsed); + check_replication_query(&ast_generator, full_query_parsed, "replica2", TypedValue("1.1.1.1:10000"), + ReplicationQuery::SyncMode::SYNC); +} + +TEST_P(CypherMainVisitorTest, TestDeleteReplica) { + auto &ast_generator = *GetParam(); + + std::string missing_name_query = "DROP REPLICA"; + ASSERT_THROW(ast_generator.ParseQuery(missing_name_query), SyntaxException); + + std::string correct_query = "DROP REPLICA replica1"; + auto *correct_query_parsed = dynamic_cast(ast_generator.ParseQuery(correct_query)); + ASSERT_TRUE(correct_query_parsed); + EXPECT_EQ(correct_query_parsed->replica_name_, "replica1"); +} + +TEST_P(CypherMainVisitorTest, TestExplainRegularQuery) { + auto &ast_generator = *GetParam(); + EXPECT_TRUE(dynamic_cast(ast_generator.ParseQuery("EXPLAIN RETURN n"))); +} + +TEST_P(CypherMainVisitorTest, TestExplainExplainQuery) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("EXPLAIN EXPLAIN RETURN n"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, TestExplainAuthQuery) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("EXPLAIN SHOW ROLES"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, TestProfileRegularQuery) { + { + auto &ast_generator = *GetParam(); + EXPECT_TRUE(dynamic_cast(ast_generator.ParseQuery("PROFILE RETURN n"))); + } +} + +TEST_P(CypherMainVisitorTest, TestProfileComplicatedQuery) { + { + auto &ast_generator = *GetParam(); + EXPECT_TRUE( + dynamic_cast(ast_generator.ParseQuery("profile optional match (n) where n.hello = 5 " + "return n union optional match (n) where n.there = 10 " + "return n"))); + } +} + +TEST_P(CypherMainVisitorTest, TestProfileProfileQuery) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("PROFILE PROFILE RETURN n"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, TestProfileAuthQuery) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("PROFILE SHOW ROLES"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, TestShowStorageInfo) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("SHOW STORAGE INFO")); + ASSERT_TRUE(query); + EXPECT_EQ(query->info_type_, InfoQuery::InfoType::STORAGE); +} + +TEST_P(CypherMainVisitorTest, TestShowIndexInfo) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("SHOW INDEX INFO")); + ASSERT_TRUE(query); + EXPECT_EQ(query->info_type_, InfoQuery::InfoType::INDEX); +} + +TEST_P(CypherMainVisitorTest, TestShowConstraintInfo) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("SHOW CONSTRAINT INFO")); + ASSERT_TRUE(query); + EXPECT_EQ(query->info_type_, InfoQuery::InfoType::CONSTRAINT); +} + +TEST_P(CypherMainVisitorTest, CreateConstraintSyntaxError) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (:label) ASSERT EXISTS"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT () ASSERT EXISTS"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON () ASSERT EXISTS(prop1)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON () ASSERT EXISTS (prop1, prop2)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT " + "EXISTS (n.prop1, missing.prop2)"), + SemanticException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT " + "EXISTS (m.prop1, m.prop2)"), + SemanticException); + + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (:label) ASSERT IS UNIQUE"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT () ASSERT IS UNIQUE"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON () ASSERT prop1 IS UNIQUE"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON () ASSERT prop1, prop2 IS UNIQUE"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT " + "n.prop1, missing.prop2 IS UNIQUE"), + SemanticException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT " + "m.prop1, m.prop2 IS UNIQUE"), + SemanticException); + + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (:label) ASSERT IS NODE KEY"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT () ASSERT IS NODE KEY"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON () ASSERT (prop1) IS NODE KEY"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON () ASSERT (prop1, prop2) IS NODE KEY"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT " + "(n.prop1, missing.prop2) IS NODE KEY"), + SemanticException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT " + "(m.prop1, m.prop2) IS NODE KEY"), + SemanticException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT " + "n.prop1, n.prop2 IS NODE KEY"), + SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT " + "exists(n.prop1, n.prop2) IS NODE KEY"), + SyntaxException); +} + +TEST_P(CypherMainVisitorTest, CreateConstraint) { + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast( + ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT EXISTS(n.prop1)")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::CREATE); + EXPECT_EQ(query->constraint_.type, Constraint::Type::EXISTS); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, UnorderedElementsAre(ast_generator.Prop("prop1"))); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast( + ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT EXISTS (n.prop1, n.prop2)")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::CREATE); + EXPECT_EQ(query->constraint_.type, Constraint::Type::EXISTS); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, + UnorderedElementsAre(ast_generator.Prop("prop1"), ast_generator.Prop("prop2"))); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast( + ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT n.prop1 IS UNIQUE")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::CREATE); + EXPECT_EQ(query->constraint_.type, Constraint::Type::UNIQUE); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, UnorderedElementsAre(ast_generator.Prop("prop1"))); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast( + ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT n.prop1, n.prop2 IS UNIQUE")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::CREATE); + EXPECT_EQ(query->constraint_.type, Constraint::Type::UNIQUE); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, + UnorderedElementsAre(ast_generator.Prop("prop1"), ast_generator.Prop("prop2"))); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast( + ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT (n.prop1) IS NODE KEY")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::CREATE); + EXPECT_EQ(query->constraint_.type, Constraint::Type::NODE_KEY); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, UnorderedElementsAre(ast_generator.Prop("prop1"))); + } + { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast(ast_generator.ParseQuery("CREATE CONSTRAINT ON (n:label) ASSERT " + "(n.prop1, n.prop2) IS NODE KEY")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::CREATE); + EXPECT_EQ(query->constraint_.type, Constraint::Type::NODE_KEY); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, + UnorderedElementsAre(ast_generator.Prop("prop1"), ast_generator.Prop("prop2"))); + } +} + +TEST_P(CypherMainVisitorTest, DropConstraint) { + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast( + ast_generator.ParseQuery("DROP CONSTRAINT ON (n:label) ASSERT EXISTS(n.prop1)")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::DROP); + EXPECT_EQ(query->constraint_.type, Constraint::Type::EXISTS); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, UnorderedElementsAre(ast_generator.Prop("prop1"))); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast( + ast_generator.ParseQuery("DROP CONSTRAINT ON (n:label) ASSERT EXISTS(n.prop1, n.prop2)")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::DROP); + EXPECT_EQ(query->constraint_.type, Constraint::Type::EXISTS); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, + UnorderedElementsAre(ast_generator.Prop("prop1"), ast_generator.Prop("prop2"))); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast( + ast_generator.ParseQuery("DROP CONSTRAINT ON (n:label) ASSERT n.prop1 IS UNIQUE")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::DROP); + EXPECT_EQ(query->constraint_.type, Constraint::Type::UNIQUE); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, UnorderedElementsAre(ast_generator.Prop("prop1"))); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast( + ast_generator.ParseQuery("DROP CONSTRAINT ON (n:label) ASSERT n.prop1, n.prop2 IS UNIQUE")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::DROP); + EXPECT_EQ(query->constraint_.type, Constraint::Type::UNIQUE); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, + UnorderedElementsAre(ast_generator.Prop("prop1"), ast_generator.Prop("prop2"))); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast( + ast_generator.ParseQuery("DROP CONSTRAINT ON (n:label) ASSERT (n.prop1) IS NODE KEY")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::DROP); + EXPECT_EQ(query->constraint_.type, Constraint::Type::NODE_KEY); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, UnorderedElementsAre(ast_generator.Prop("prop1"))); + } + { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast(ast_generator.ParseQuery("DROP CONSTRAINT ON (n:label) ASSERT " + "(n.prop1, n.prop2) IS NODE KEY")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_type_, ConstraintQuery::ActionType::DROP); + EXPECT_EQ(query->constraint_.type, Constraint::Type::NODE_KEY); + EXPECT_EQ(query->constraint_.label, ast_generator.Label("label")); + EXPECT_THAT(query->constraint_.properties, + UnorderedElementsAre(ast_generator.Prop("prop1"), ast_generator.Prop("prop2"))); + } +} + +TEST_P(CypherMainVisitorTest, RegexMatch) { + { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast(ast_generator.ParseQuery("MATCH (n) WHERE n.name =~ \".*bla.*\" RETURN n.name")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *match_clause = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(match_clause); + auto *regex_match = dynamic_cast(match_clause->where_->expression_); + ASSERT_TRUE(regex_match); + ASSERT_TRUE(dynamic_cast(regex_match->string_expr_)); + ast_generator.CheckLiteral(regex_match->regex_, ".*bla.*"); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN \"text\" =~ \".*bla.*\"")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *return_clause = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(return_clause); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 1U); + auto *named_expression = return_clause->body_.named_expressions[0]; + auto *regex_match = dynamic_cast(named_expression->expression_); + ASSERT_TRUE(regex_match); + ast_generator.CheckLiteral(regex_match->string_expr_, "text"); + ast_generator.CheckLiteral(regex_match->regex_, ".*bla.*"); + } +} + +// NOLINTNEXTLINE(hicpp-special-member-functions) +TEST_P(CypherMainVisitorTest, DumpDatabase) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("DUMP DATABASE")); + ASSERT_TRUE(query); +} + +namespace { +template +void CheckCallProcedureDefaultMemoryLimit(const TAst &ast, const CallProcedure &call_proc) { + // Should be 100 MB + auto *literal = dynamic_cast(call_proc.memory_limit_); + ASSERT_TRUE(literal); + TypedValue value(literal->value_); + ASSERT_TRUE(TypedValue::BoolEqual{}(value, TypedValue(100))); + ASSERT_EQ(call_proc.memory_scale_, 1024 * 1024); +} +} // namespace + +TEST_P(CypherMainVisitorTest, CallProcedureWithDotsInName) { + AddProc(*mock_module_with_dots_in_name, "proc", {}, {"res"}, ProcedureType::WRITE); + auto &ast_generator = *GetParam(); + + auto *query = + dynamic_cast(ast_generator.ParseQuery("CALL mock_module.with.dots.in.name.proc() YIELD res")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mock_module.with.dots.in.name.proc"); + ASSERT_TRUE(call_proc->arguments_.empty()); + std::vector identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + ASSERT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + std::vector expected_names{"res"}; + ASSERT_EQ(identifier_names, expected_names); + ASSERT_EQ(identifier_names, call_proc->result_fields_); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithDashesInName) { + AddProc(*mock_module, "proc-with-dashes", {}, {"res"}, ProcedureType::READ); + auto &ast_generator = *GetParam(); + + auto *query = + dynamic_cast(ast_generator.ParseQuery("CALL `mock_module.proc-with-dashes`() YIELD res")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mock_module.proc-with-dashes"); + ASSERT_TRUE(call_proc->arguments_.empty()); + std::vector identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + ASSERT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + std::vector expected_names{"res"}; + ASSERT_EQ(identifier_names, expected_names); + ASSERT_EQ(identifier_names, call_proc->result_fields_); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithYieldSomeFields) { + auto &ast_generator = *GetParam(); + auto check_proc = [this, &ast_generator](const ProcedureType type) { + const auto proc_name = std::string{"proc_"} + ToString(type); + SCOPED_TRACE(proc_name); + const auto fully_qualified_proc_name = std::string{"mock_module."} + proc_name; + AddProc(*mock_module, proc_name.c_str(), {}, {"fst", "field-with-dashes", "last_field"}, type); + auto *query = dynamic_cast(ast_generator.ParseQuery( + fmt::format("CALL {}() YIELD fst, `field-with-dashes`, last_field", fully_qualified_proc_name))); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->is_write_, type == ProcedureType::WRITE); + ASSERT_EQ(call_proc->procedure_name_, fully_qualified_proc_name); + ASSERT_TRUE(call_proc->arguments_.empty()); + ASSERT_EQ(call_proc->result_fields_.size(), 3U); + ASSERT_EQ(call_proc->result_identifiers_.size(), call_proc->result_fields_.size()); + std::vector identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + ASSERT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + std::vector expected_names{"fst", "field-with-dashes", "last_field"}; + ASSERT_EQ(identifier_names, expected_names); + ASSERT_EQ(identifier_names, call_proc->result_fields_); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); + }; + check_proc(ProcedureType::READ); + check_proc(ProcedureType::WRITE); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithYieldAliasedFields) { + AddProc(*mock_module, "proc", {}, {"fst", "snd", "thrd"}, ProcedureType::READ); + auto &ast_generator = *GetParam(); + + auto *query = + dynamic_cast(ast_generator.ParseQuery("CALL mock_module.proc() YIELD fst AS res1, snd AS " + "`result-with-dashes`, thrd AS last_result")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mock_module.proc"); + ASSERT_TRUE(call_proc->arguments_.empty()); + ASSERT_EQ(call_proc->result_fields_.size(), 3U); + ASSERT_EQ(call_proc->result_identifiers_.size(), call_proc->result_fields_.size()); + std::vector identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + ASSERT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + std::vector aliased_names{"res1", "result-with-dashes", "last_result"}; + ASSERT_EQ(identifier_names, aliased_names); + std::vector field_names{"fst", "snd", "thrd"}; + ASSERT_EQ(call_proc->result_fields_, field_names); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithArguments) { + AddProc(*mock_module, "proc", {"arg1", "arg2", "arg3"}, {"res"}, ProcedureType::READ); + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("CALL mock_module.proc(0, 1, 2) YIELD res")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mock_module.proc"); + ASSERT_EQ(call_proc->arguments_.size(), 3U); + for (int64_t i = 0; i < 3; ++i) { + ast_generator.CheckLiteral(call_proc->arguments_[i], i); + } + std::vector identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + ASSERT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + std::vector expected_names{"res"}; + ASSERT_EQ(identifier_names, expected_names); + ASSERT_EQ(identifier_names, call_proc->result_fields_); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); +} + +TEST_P(CypherMainVisitorTest, CallProcedureYieldAsterisk) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("CALL mg.procedures() YIELD *")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mg.procedures"); + ASSERT_TRUE(call_proc->arguments_.empty()); + std::vector identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + ASSERT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + ASSERT_THAT(identifier_names, UnorderedElementsAre("name", "signature", "is_write", "path", "is_editable")); + ASSERT_EQ(identifier_names, call_proc->result_fields_); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); +} + +TEST_P(CypherMainVisitorTest, CallProcedureYieldAsteriskReturnAsterisk) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("CALL mg.procedures() YIELD * RETURN *")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *ret = dynamic_cast(single_query->clauses_[1]); + ASSERT_TRUE(ret); + ASSERT_TRUE(ret->body_.all_identifiers); + auto *call_proc = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mg.procedures"); + ASSERT_TRUE(call_proc->arguments_.empty()); + std::vector identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + ASSERT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + ASSERT_THAT(identifier_names, UnorderedElementsAre("name", "signature", "is_write", "path", "is_editable")); + ASSERT_EQ(identifier_names, call_proc->result_fields_); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithoutYield) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("CALL mg.load_all()")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mg.load_all"); + ASSERT_TRUE(call_proc->arguments_.empty()); + ASSERT_TRUE(call_proc->result_fields_.empty()); + ASSERT_TRUE(call_proc->result_identifiers_.empty()); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithMemoryLimitWithoutYield) { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast(ast_generator.ParseQuery("CALL mg.load_all() PROCEDURE MEMORY LIMIT 32 KB")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mg.load_all"); + ASSERT_TRUE(call_proc->arguments_.empty()); + ASSERT_TRUE(call_proc->result_fields_.empty()); + ASSERT_TRUE(call_proc->result_identifiers_.empty()); + ast_generator.CheckLiteral(call_proc->memory_limit_, 32); + ASSERT_EQ(call_proc->memory_scale_, 1024); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithMemoryUnlimitedWithoutYield) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("CALL mg.load_all() PROCEDURE MEMORY UNLIMITED")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mg.load_all"); + ASSERT_TRUE(call_proc->arguments_.empty()); + ASSERT_TRUE(call_proc->result_fields_.empty()); + ASSERT_TRUE(call_proc->result_identifiers_.empty()); + ASSERT_FALSE(call_proc->memory_limit_); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithMemoryLimit) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast( + ast_generator.ParseQuery("CALL mg.load_all() PROCEDURE MEMORY LIMIT 32 MB YIELD res")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mg.load_all"); + ASSERT_TRUE(call_proc->arguments_.empty()); + std::vector identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + ASSERT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + std::vector expected_names{"res"}; + ASSERT_EQ(identifier_names, expected_names); + ASSERT_EQ(identifier_names, call_proc->result_fields_); + ast_generator.CheckLiteral(call_proc->memory_limit_, 32); + ASSERT_EQ(call_proc->memory_scale_, 1024 * 1024); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithMemoryUnlimited) { + auto &ast_generator = *GetParam(); + auto *query = + dynamic_cast(ast_generator.ParseQuery("CALL mg.load_all() PROCEDURE MEMORY UNLIMITED YIELD res")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "mg.load_all"); + ASSERT_TRUE(call_proc->arguments_.empty()); + std::vector identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + ASSERT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + std::vector expected_names{"res"}; + ASSERT_EQ(identifier_names, expected_names); + ASSERT_EQ(identifier_names, call_proc->result_fields_); + ASSERT_FALSE(call_proc->memory_limit_); +} + +namespace { +template +void TestInvalidQuery(const auto &query, Base &ast_generator) { + SCOPED_TRACE(query); + EXPECT_THROW(ast_generator.ParseQuery(query), TException) << query; +} + +template +void TestInvalidQueryWithMessage(const auto &query, Base &ast_generator, const std::string_view message) { + bool exception_is_thrown = false; + try { + ast_generator.ParseQuery(query); + } catch (const TException &se) { + EXPECT_EQ(std::string_view{se.what()}, message); + exception_is_thrown = true; + } catch (...) { + FAIL() << "Unexpected exception"; + } + EXPECT_TRUE(exception_is_thrown); +} + +void CheckParsedCallProcedure(const CypherQuery &query, Base &ast_generator, + const std::string_view fully_qualified_proc_name, + const std::vector &args, const ProcedureType type, + const size_t clauses_size, const size_t call_procedure_index) { + ASSERT_NE(query.single_query_, nullptr); + auto *single_query = query.single_query_; + EXPECT_EQ(single_query->clauses_.size(), clauses_size); + ASSERT_FALSE(single_query->clauses_.empty()); + ASSERT_LT(call_procedure_index, clauses_size); + auto *call_proc = dynamic_cast(single_query->clauses_[call_procedure_index]); + ASSERT_NE(call_proc, nullptr); + EXPECT_EQ(call_proc->procedure_name_, fully_qualified_proc_name); + EXPECT_TRUE(call_proc->arguments_.empty()); + EXPECT_EQ(call_proc->result_fields_.size(), 2U); + EXPECT_EQ(call_proc->result_identifiers_.size(), call_proc->result_fields_.size()); + std::vector identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + EXPECT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + std::vector args_as_str{}; + std::transform(args.begin(), args.end(), std::back_inserter(args_as_str), + [](const std::string_view arg) { return std::string{arg}; }); + EXPECT_EQ(identifier_names, args_as_str); + EXPECT_EQ(identifier_names, call_proc->result_fields_); + ASSERT_EQ(call_proc->is_write_, type == ProcedureType::WRITE); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); +}; +} // namespace + +TEST_P(CypherMainVisitorTest, CallProcedureMultipleQueryPartsAfter) { + auto &ast_generator = *GetParam(); + static constexpr std::string_view fst{"fst"}; + static constexpr std::string_view snd{"snd"}; + const std::vector args{fst, snd}; + + const auto read_proc = CreateProcByType(ProcedureType::READ, args); + const auto write_proc = CreateProcByType(ProcedureType::WRITE, args); + + const auto check_parsed_call_proc = [&ast_generator, &args](const CypherQuery &query, + const std::string_view fully_qualified_proc_name, + const ProcedureType type, const size_t clause_size) { + CheckParsedCallProcedure(query, ast_generator, fully_qualified_proc_name, args, type, clause_size, 0); + }; + { + SCOPED_TRACE("Read query part"); + { + SCOPED_TRACE("With WITH"); + static constexpr std::string_view kQueryWithWith{"CALL {}() YIELD {},{} WITH {},{} UNWIND {} as u RETURN u"}; + static constexpr size_t kQueryParts{4}; + { + SCOPED_TRACE("Write proc"); + const auto query_str = fmt::format(kQueryWithWith, write_proc, fst, snd, fst, snd, fst); + TestInvalidQueryWithMessage( + query_str, ast_generator, + "WITH can't be put after calling a writeable procedure, only RETURN clause can be put after."); + } + { + SCOPED_TRACE("Read proc"); + const auto query_str = fmt::format(kQueryWithWith, read_proc, fst, snd, fst, snd, fst); + const auto *query = dynamic_cast(ast_generator.ParseQuery(query_str)); + ASSERT_NE(query, nullptr); + check_parsed_call_proc(*query, read_proc, ProcedureType::READ, kQueryParts); + } + } + { + SCOPED_TRACE("Without WITH"); + static constexpr std::string_view kQueryWithoutWith{"CALL {}() YIELD {},{} UNWIND {} as u RETURN u"}; + static constexpr size_t kQueryParts{3}; + { + SCOPED_TRACE("Write proc"); + const auto query_str = fmt::format(kQueryWithoutWith, write_proc, fst, snd, fst); + TestInvalidQueryWithMessage( + query_str, ast_generator, + "UNWIND can't be put after calling a writeable procedure, only RETURN clause can be put after."); + } + { + SCOPED_TRACE("Read proc"); + const auto query_str = fmt::format(kQueryWithoutWith, read_proc, fst, snd, fst); + const auto *query = dynamic_cast(ast_generator.ParseQuery(query_str)); + ASSERT_NE(query, nullptr); + check_parsed_call_proc(*query, read_proc, ProcedureType::READ, kQueryParts); + } + } + } + { + SCOPED_TRACE("Write query part"); + { + SCOPED_TRACE("With WITH"); + static constexpr std::string_view kQueryWithWith{ + "CALL {}() YIELD {},{} WITH {},{} CREATE(n {{prop : {}}}) RETURN n"}; + static constexpr size_t kQueryParts{4}; + { + SCOPED_TRACE("Write proc"); + const auto query_str = fmt::format(kQueryWithWith, write_proc, fst, snd, fst, snd, fst); + TestInvalidQueryWithMessage( + query_str, ast_generator, + "WITH can't be put after calling a writeable procedure, only RETURN clause can be put after."); + } + { + SCOPED_TRACE("Read proc"); + const auto query_str = fmt::format(kQueryWithWith, read_proc, fst, snd, fst, snd, fst); + const auto *query = dynamic_cast(ast_generator.ParseQuery(query_str)); + ASSERT_NE(query, nullptr); + check_parsed_call_proc(*query, read_proc, ProcedureType::READ, kQueryParts); + } + } + { + SCOPED_TRACE("Without WITH"); + static constexpr std::string_view kQueryWithoutWith{"CALL {}() YIELD {},{} CREATE(n {{prop : {}}}) RETURN n"}; + static constexpr size_t kQueryParts{3}; + { + SCOPED_TRACE("Write proc"); + const auto query_str = fmt::format(kQueryWithoutWith, write_proc, fst, snd, fst); + TestInvalidQueryWithMessage( + query_str, ast_generator, + "Update clause can't be put after calling a writeable procedure, only RETURN clause can be put after."); + } + { + SCOPED_TRACE("Read proc"); + const auto query_str = fmt::format(kQueryWithoutWith, read_proc, fst, snd, fst, snd, fst); + const auto *query = dynamic_cast(ast_generator.ParseQuery(query_str)); + ASSERT_NE(query, nullptr); + check_parsed_call_proc(*query, read_proc, ProcedureType::READ, kQueryParts); + } + } + } +} + +TEST_P(CypherMainVisitorTest, CallProcedureMultipleQueryPartsBefore) { + auto &ast_generator = *GetParam(); + static constexpr std::string_view fst{"fst"}; + static constexpr std::string_view snd{"snd"}; + const std::vector args{fst, snd}; + + const auto read_proc = CreateProcByType(ProcedureType::READ, args); + const auto write_proc = CreateProcByType(ProcedureType::WRITE, args); + + const auto check_parsed_call_proc = [&ast_generator, &args](const CypherQuery &query, + const std::string_view fully_qualified_proc_name, + const ProcedureType type, const size_t clause_size) { + CheckParsedCallProcedure(query, ast_generator, fully_qualified_proc_name, args, type, clause_size, clause_size - 2); + }; + { + SCOPED_TRACE("Read query part"); + static constexpr std::string_view kQueryWithReadQueryPart{"MATCH (n) CALL {}() YIELD * RETURN *"}; + static constexpr size_t kQueryParts{3}; + { + SCOPED_TRACE("Write proc"); + const auto query_str = fmt::format(kQueryWithReadQueryPart, write_proc); + const auto *query = dynamic_cast(ast_generator.ParseQuery(query_str)); + ASSERT_NE(query, nullptr); + check_parsed_call_proc(*query, write_proc, ProcedureType::WRITE, kQueryParts); + } + { + SCOPED_TRACE("Read proc"); + const auto query_str = fmt::format(kQueryWithReadQueryPart, read_proc); + const auto *query = dynamic_cast(ast_generator.ParseQuery(query_str)); + ASSERT_NE(query, nullptr); + check_parsed_call_proc(*query, read_proc, ProcedureType::READ, kQueryParts); + } + } + { + SCOPED_TRACE("Write query part"); + static constexpr std::string_view kQueryWithWriteQueryPart{"CREATE (n) WITH n CALL {}() YIELD * RETURN *"}; + static constexpr size_t kQueryParts{4}; + { + SCOPED_TRACE("Write proc"); + const auto query_str = fmt::format(kQueryWithWriteQueryPart, write_proc, fst, snd, fst); + TestInvalidQueryWithMessage( + query_str, ast_generator, "Write procedures cannot be used in queries that contains any update clauses!"); + } + { + SCOPED_TRACE("Read proc"); + const auto query_str = fmt::format(kQueryWithWriteQueryPart, read_proc, fst, snd, fst, snd, fst); + const auto *query = dynamic_cast(ast_generator.ParseQuery(query_str)); + ASSERT_NE(query, nullptr); + check_parsed_call_proc(*query, read_proc, ProcedureType::READ, kQueryParts); + } + } +} + +TEST_P(CypherMainVisitorTest, CallProcedureMultipleProcedures) { + auto &ast_generator = *GetParam(); + static constexpr std::string_view fst{"fst"}; + static constexpr std::string_view snd{"snd"}; + const std::vector args{fst, snd}; + + const auto read_proc = CreateProcByType(ProcedureType::READ, args); + const auto write_proc = CreateProcByType(ProcedureType::WRITE, args); + + { + SCOPED_TRACE("Read then write"); + const auto query_str = fmt::format("CALL {}() YIELD * CALL {}() YIELD * RETURN *", read_proc, write_proc); + static constexpr size_t kQueryParts{3}; + const auto *query = dynamic_cast(ast_generator.ParseQuery(query_str)); + ASSERT_NE(query, nullptr); + + CheckParsedCallProcedure(*query, ast_generator, read_proc, args, ProcedureType::READ, kQueryParts, 0); + CheckParsedCallProcedure(*query, ast_generator, write_proc, args, ProcedureType::WRITE, kQueryParts, 1); + } + { + SCOPED_TRACE("Write then read"); + const auto query_str = fmt::format("CALL {}() YIELD * CALL {}() YIELD * RETURN *", write_proc, read_proc); + TestInvalidQueryWithMessage( + query_str, ast_generator, + "CALL can't be put after calling a writeable procedure, only RETURN clause can be put after."); + } + { + SCOPED_TRACE("Write twice"); + const auto query_str = fmt::format("CALL {}() YIELD * CALL {}() YIELD * RETURN *", write_proc, write_proc); + TestInvalidQueryWithMessage( + query_str, ast_generator, + "CALL can't be put after calling a writeable procedure, only RETURN clause can be put after."); + } +} + +TEST_P(CypherMainVisitorTest, IncorrectCallProcedure) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc-with-dashes()"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc() yield field-with-dashes"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc() yield field.with.dots"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc() yield res AS result-with-dashes"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc() yield res AS result.with.dots"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("WITH 42 AS x CALL not_standalone(x)"), SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("CALL procedure() YIELD"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 42, CALL procedure() YIELD"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 42, CALL procedure() YIELD res"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 42 AS x CALL procedure() YIELD res"), SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc.with.dots() MEMORY YIELD res"), SyntaxException); + // mg.procedures returns something, so it needs to have a YIELD. + ASSERT_THROW(ast_generator.ParseQuery("CALL mg.procedures()"), SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("CALL mg.procedures() PROCEDURE MEMORY UNLIMITED"), SemanticException); + // TODO: Implement support for the following syntax. These are defined in + // Neo4j and accepted in openCypher CIP. + ASSERT_THROW(ast_generator.ParseQuery("CALL proc"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc RETURN 42"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc() YIELD res WHERE res > 42"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc() YIELD res WHERE res > 42 RETURN *"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, TestLockPathQuery) { + auto &ast_generator = *GetParam(); + + const auto test_lock_path_query = [&](const std::string_view command, const LockPathQuery::Action action) { + ASSERT_THROW(ast_generator.ParseQuery(command.data()), SyntaxException); + + { + const std::string query = fmt::format("{} ME", command); + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = fmt::format("{} DATA", command); + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = fmt::format("{} DATA STUFF", command); + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = fmt::format("{} DATA DIRECTORY", command); + auto *parsed_query = dynamic_cast(ast_generator.ParseQuery(query)); + ASSERT_TRUE(parsed_query); + EXPECT_EQ(parsed_query->action_, action); + } + }; + + test_lock_path_query("LOCK", LockPathQuery::Action::LOCK_PATH); + test_lock_path_query("UNLOCK", LockPathQuery::Action::UNLOCK_PATH); +} + +TEST_P(CypherMainVisitorTest, TestLoadCsvClause) { + auto &ast_generator = *GetParam(); + + { + const std::string query = R"(LOAD CSV FROM "file.csv")"; + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = R"(LOAD CSV FROM "file.csv" WITH)"; + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = R"(LOAD CSV FROM "file.csv" WITH HEADER)"; + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = R"(LOAD CSV FROM "file.csv" WITH HEADER DELIMITER ";")"; + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = R"(LOAD CSV FROM "file.csv" WITH HEADER DELIMITER ";" QUOTE "'")"; + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = R"(LOAD CSV FROM "file.csv" WITH HEADER DELIMITER ";" QUOTE "'" AS)"; + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = R"(LOAD CSV FROM file WITH HEADER IGNORE BAD DELIMITER ";" QUOTE "'" AS x)"; + ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException); + } + + { + const std::string query = R"(LOAD CSV FROM "file.csv" WITH HEADER IGNORE BAD DELIMITER 0 QUOTE "'" AS x)"; + ASSERT_THROW(ast_generator.ParseQuery(query), SemanticException); + } + + { + const std::string query = R"(LOAD CSV FROM "file.csv" WITH HEADER IGNORE BAD DELIMITER ";" QUOTE 0 AS x)"; + ASSERT_THROW(ast_generator.ParseQuery(query), SemanticException); + } + + { + // can't be a standalone clause + const std::string query = R"(LOAD CSV FROM "file.csv" WITH HEADER IGNORE BAD DELIMITER ";" QUOTE "'" AS x)"; + ASSERT_THROW(ast_generator.ParseQuery(query), SemanticException); + } + + { + const std::string query = + R"(LOAD CSV FROM "file.csv" WITH HEADER IGNORE BAD DELIMITER ";" QUOTE "'" AS x RETURN x)"; + auto *parsed_query = dynamic_cast(ast_generator.ParseQuery(query)); + ASSERT_TRUE(parsed_query); + auto *load_csv_clause = dynamic_cast(parsed_query->single_query_->clauses_[0]); + ASSERT_TRUE(load_csv_clause); + ASSERT_TRUE(load_csv_clause->with_header_); + ASSERT_TRUE(load_csv_clause->ignore_bad_); + } +} + +TEST_P(CypherMainVisitorTest, MemoryLimit) { + auto &ast_generator = *GetParam(); + + ASSERT_THROW(ast_generator.ParseQuery("RETURN x QUE"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN x QUERY"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN x QUERY MEM"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN x QUERY MEMORY"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN x QUERY MEMORY LIM"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN x QUERY MEMORY LIMIT"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN x QUERY MEMORY LIMIT KB"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN x QUERY MEMORY LIMIT 12GB"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("QUERY MEMORY LIMIT 12KB RETURN x"), SyntaxException); + + { + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN x")); + ASSERT_TRUE(query); + ASSERT_FALSE(query->memory_limit_); + } + + { + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN x QUERY MEMORY LIMIT 12KB")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->memory_limit_); + ast_generator.CheckLiteral(query->memory_limit_, 12); + ASSERT_EQ(query->memory_scale_, 1024U); + } + + { + auto *query = dynamic_cast(ast_generator.ParseQuery("RETURN x QUERY MEMORY LIMIT 12MB")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->memory_limit_); + ast_generator.CheckLiteral(query->memory_limit_, 12); + ASSERT_EQ(query->memory_scale_, 1024U * 1024U); + } + + { + auto *query = dynamic_cast( + ast_generator.ParseQuery("CALL mg.procedures() YIELD x RETURN x QUERY MEMORY LIMIT 12MB")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->memory_limit_); + ast_generator.CheckLiteral(query->memory_limit_, 12); + ASSERT_EQ(query->memory_scale_, 1024U * 1024U); + + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *call_proc = dynamic_cast(single_query->clauses_[0]); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); + } + + { + auto *query = dynamic_cast(ast_generator.ParseQuery( + "CALL mg.procedures() PROCEDURE MEMORY LIMIT 3KB YIELD x RETURN x QUERY MEMORY LIMIT 12MB")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->memory_limit_); + ast_generator.CheckLiteral(query->memory_limit_, 12); + ASSERT_EQ(query->memory_scale_, 1024U * 1024U); + + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *call_proc = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(call_proc->memory_limit_); + ast_generator.CheckLiteral(call_proc->memory_limit_, 3); + ASSERT_EQ(call_proc->memory_scale_, 1024U); + } + + { + auto *query = dynamic_cast( + ast_generator.ParseQuery("CALL mg.procedures() PROCEDURE MEMORY LIMIT 3KB YIELD x RETURN x")); + ASSERT_TRUE(query); + ASSERT_FALSE(query->memory_limit_); + + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 2U); + auto *call_proc = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(call_proc->memory_limit_); + ast_generator.CheckLiteral(call_proc->memory_limit_, 3); + ASSERT_EQ(call_proc->memory_scale_, 1024U); + } + + { + auto *query = + dynamic_cast(ast_generator.ParseQuery("CALL mg.load_all() PROCEDURE MEMORY LIMIT 3KB")); + ASSERT_TRUE(query); + ASSERT_FALSE(query->memory_limit_); + + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(call_proc->memory_limit_); + ast_generator.CheckLiteral(call_proc->memory_limit_, 3); + ASSERT_EQ(call_proc->memory_scale_, 1024U); + } + + { + auto *query = dynamic_cast(ast_generator.ParseQuery("CALL mg.load_all() QUERY MEMORY LIMIT 3KB")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->memory_limit_); + ast_generator.CheckLiteral(query->memory_limit_, 3); + ASSERT_EQ(query->memory_scale_, 1024U); + + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast(single_query->clauses_[0]); + CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); + } +} + +TEST_P(CypherMainVisitorTest, DropTrigger) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("DROP TR", ast_generator); + TestInvalidQuery("DROP TRIGGER", ast_generator); + + auto *parsed_query = dynamic_cast(ast_generator.ParseQuery("DROP TRIGGER trigger")); + EXPECT_EQ(parsed_query->action_, TriggerQuery::Action::DROP_TRIGGER); + EXPECT_EQ(parsed_query->trigger_name_, "trigger"); +} + +TEST_P(CypherMainVisitorTest, ShowTriggers) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("SHOW TR", ast_generator); + TestInvalidQuery("SHOW TRIGGER", ast_generator); + + auto *parsed_query = dynamic_cast(ast_generator.ParseQuery("SHOW TRIGGERS")); + EXPECT_EQ(parsed_query->action_, TriggerQuery::Action::SHOW_TRIGGERS); +} + +namespace { +void ValidateCreateQuery(Base &ast_generator, const auto &query, const auto &trigger_name, + const memgraph::query::v2::TriggerQuery::EventType event_type, const auto &phase, + const auto &statement) { + auto *parsed_query = dynamic_cast(ast_generator.ParseQuery(query)); + EXPECT_EQ(parsed_query->action_, TriggerQuery::Action::CREATE_TRIGGER); + EXPECT_EQ(parsed_query->trigger_name_, trigger_name); + EXPECT_EQ(parsed_query->event_type_, event_type); + EXPECT_EQ(parsed_query->before_commit_, phase == "BEFORE"); + EXPECT_EQ(parsed_query->statement_, statement); +} +} // namespace + +TEST_P(CypherMainVisitorTest, CreateTriggers) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("CREATE TRIGGER", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON ", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON ()", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON -->", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON CREATE", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON () CREATE", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON --> CREATE", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON DELETE", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON () DELETE", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON --> DELETE", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON UPDATE", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON () UPDATE", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON --> UPDATE", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON CREATE BEFORE", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON CREATE BEFORE COMMIT", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON CREATE AFTER", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON CREATE AFTER COMMIT", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON -> CREATE AFTER COMMIT EXECUTE a", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON ) CREATE AFTER COMMIT EXECUTE a", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON ( CREATE AFTER COMMIT EXECUTE a", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON CRETE AFTER COMMIT EXECUTE a", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON DELET AFTER COMMIT EXECUTE a", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON UPDTE AFTER COMMIT EXECUTE a", ast_generator); + TestInvalidQuery("CREATE TRIGGER trigger ON UPDATE COMMIT EXECUTE a", ast_generator); + + static constexpr std::string_view query_template = "CREATE TRIGGER trigger {} {} COMMIT EXECUTE {}"; + + static constexpr std::array events{ + std::pair{"", memgraph::query::v2::TriggerQuery::EventType::ANY}, + std::pair{"ON CREATE", memgraph::query::v2::TriggerQuery::EventType::CREATE}, + std::pair{"ON () CREATE", memgraph::query::v2::TriggerQuery::EventType::VERTEX_CREATE}, + std::pair{"ON --> CREATE", memgraph::query::v2::TriggerQuery::EventType::EDGE_CREATE}, + std::pair{"ON DELETE", memgraph::query::v2::TriggerQuery::EventType::DELETE}, + std::pair{"ON () DELETE", memgraph::query::v2::TriggerQuery::EventType::VERTEX_DELETE}, + std::pair{"ON --> DELETE", memgraph::query::v2::TriggerQuery::EventType::EDGE_DELETE}, + std::pair{"ON UPDATE", memgraph::query::v2::TriggerQuery::EventType::UPDATE}, + std::pair{"ON () UPDATE", memgraph::query::v2::TriggerQuery::EventType::VERTEX_UPDATE}, + std::pair{"ON --> UPDATE", memgraph::query::v2::TriggerQuery::EventType::EDGE_UPDATE}}; + + static constexpr std::array phases{"BEFORE", "AFTER"}; + + static constexpr std::array statements{ + "", "SOME SUPER\nSTATEMENT", "Statement with 12312321 3 ", " Statement with 12312321 3 " + + }; + + for (const auto &[event_string, event_type] : events) { + for (const auto &phase : phases) { + for (const auto &statement : statements) { + ValidateCreateQuery(ast_generator, fmt::format(query_template, event_string, phase, statement), "trigger", + event_type, phase, memgraph::utils::Trim(statement)); + } + } + } +} + +namespace { +void ValidateSetIsolationLevelQuery(Base &ast_generator, const auto &query, const auto scope, + const auto isolation_level) { + auto *parsed_query = dynamic_cast(ast_generator.ParseQuery(query)); + EXPECT_EQ(parsed_query->isolation_level_scope_, scope); + EXPECT_EQ(parsed_query->isolation_level_, isolation_level); +} +} // namespace + +TEST_P(CypherMainVisitorTest, SetIsolationLevelQuery) { + auto &ast_generator = *GetParam(); + TestInvalidQuery("SET ISO", ast_generator); + TestInvalidQuery("SET TRANSACTION ISOLATION", ast_generator); + TestInvalidQuery("SET TRANSACTION ISOLATION LEVEL", ast_generator); + TestInvalidQuery("SET TRANSACTION ISOLATION LEVEL READ COMMITTED", ast_generator); + TestInvalidQuery("SET NEXT TRANSACTION ISOLATION LEVEL", ast_generator); + TestInvalidQuery("SET ISOLATION LEVEL READ COMMITTED", ast_generator); + TestInvalidQuery("SET GLOBAL ISOLATION LEVEL READ COMMITTED", ast_generator); + TestInvalidQuery("SET GLOBAL TRANSACTION ISOLATION LEVEL READ COMITTED", ast_generator); + TestInvalidQuery("SET GLOBAL TRANSACTION ISOLATION LEVEL READ_COMITTED", ast_generator); + TestInvalidQuery("SET SESSION TRANSACTION ISOLATION LEVEL READCOMITTED", ast_generator); + + static constexpr std::array scopes{ + std::pair{"GLOBAL", memgraph::query::v2::IsolationLevelQuery::IsolationLevelScope::GLOBAL}, + std::pair{"SESSION", memgraph::query::v2::IsolationLevelQuery::IsolationLevelScope::SESSION}, + std::pair{"NEXT", memgraph::query::v2::IsolationLevelQuery::IsolationLevelScope::NEXT}}; + static constexpr std::array isolation_levels{ + std::pair{"READ UNCOMMITTED", memgraph::query::v2::IsolationLevelQuery::IsolationLevel::READ_UNCOMMITTED}, + std::pair{"READ COMMITTED", memgraph::query::v2::IsolationLevelQuery::IsolationLevel::READ_COMMITTED}, + std::pair{"SNAPSHOT ISOLATION", memgraph::query::v2::IsolationLevelQuery::IsolationLevel::SNAPSHOT_ISOLATION}}; + + static constexpr const auto *query_template = "SET {} TRANSACTION ISOLATION LEVEL {}"; + + for (const auto &[scope_string, scope] : scopes) { + for (const auto &[isolation_level_string, isolation_level] : isolation_levels) { + ValidateSetIsolationLevelQuery(ast_generator, fmt::format(query_template, scope_string, isolation_level_string), + scope, isolation_level); + } + } +} + +TEST_P(CypherMainVisitorTest, CreateSnapshotQuery) { + auto &ast_generator = *GetParam(); + ASSERT_TRUE(dynamic_cast(ast_generator.ParseQuery("CREATE SNAPSHOT"))); +} + +void CheckOptionalExpression(Base &ast_generator, Expression *expression, const std::optional &expected) { + EXPECT_EQ(expression != nullptr, expected.has_value()); + if (expected.has_value()) { + EXPECT_NO_FATAL_FAILURE(ast_generator.CheckLiteral(expression, *expected)); + } +}; + +void ValidateMostlyEmptyStreamQuery(Base &ast_generator, const std::string &query_string, + const StreamQuery::Action action, const std::string_view stream_name, + const std::optional &batch_limit = std::nullopt, + const std::optional &timeout = std::nullopt) { + auto *parsed_query = dynamic_cast(ast_generator.ParseQuery(query_string)); + ASSERT_NE(parsed_query, nullptr); + EXPECT_EQ(parsed_query->action_, action); + EXPECT_EQ(parsed_query->stream_name_, stream_name); + auto topic_names = std::get_if(&parsed_query->topic_names_); + EXPECT_NE(topic_names, nullptr); + EXPECT_EQ(*topic_names, nullptr); + EXPECT_TRUE(topic_names); + EXPECT_FALSE(*topic_names); + EXPECT_TRUE(parsed_query->transform_name_.empty()); + EXPECT_TRUE(parsed_query->consumer_group_.empty()); + EXPECT_EQ(parsed_query->batch_interval_, nullptr); + EXPECT_EQ(parsed_query->batch_size_, nullptr); + EXPECT_EQ(parsed_query->service_url_, nullptr); + EXPECT_EQ(parsed_query->bootstrap_servers_, nullptr); + EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->batch_limit_, batch_limit)); + EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->timeout_, timeout)); + EXPECT_TRUE(parsed_query->configs_.empty()); + EXPECT_TRUE(parsed_query->credentials_.empty()); +} + +TEST_P(CypherMainVisitorTest, DropStream) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("DROP ST", ast_generator); + TestInvalidQuery("DROP STREAM", ast_generator); + TestInvalidQuery("DROP STREAMS", ast_generator); + + ValidateMostlyEmptyStreamQuery(ast_generator, "DrOP STREAm droppedStream", StreamQuery::Action::DROP_STREAM, + "droppedStream"); +} + +TEST_P(CypherMainVisitorTest, StartStream) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("START ST", ast_generator); + TestInvalidQuery("START STREAM", ast_generator); + TestInvalidQuery("START STREAMS", ast_generator); + + ValidateMostlyEmptyStreamQuery(ast_generator, "START STREAM startedStream", StreamQuery::Action::START_STREAM, + "startedStream"); +} + +TEST_P(CypherMainVisitorTest, StartAllStreams) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("START ALL", ast_generator); + TestInvalidQuery("START ALL STREAM", ast_generator); + TestInvalidQuery("START STREAMS ALL", ast_generator); + + ValidateMostlyEmptyStreamQuery(ast_generator, "StARt AlL StrEAMS", StreamQuery::Action::START_ALL_STREAMS, ""); +} + +TEST_P(CypherMainVisitorTest, ShowStreams) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("SHOW ALL", ast_generator); + TestInvalidQuery("SHOW STREAM", ast_generator); + TestInvalidQuery("SHOW STREAMS ALL", ast_generator); + + ValidateMostlyEmptyStreamQuery(ast_generator, "SHOW STREAMS", StreamQuery::Action::SHOW_STREAMS, ""); +} + +TEST_P(CypherMainVisitorTest, StopStream) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("STOP ST", ast_generator); + TestInvalidQuery("STOP STREAM", ast_generator); + TestInvalidQuery("STOP STREAMS", ast_generator); + TestInvalidQuery("STOP STREAM invalid stream name", ast_generator); + + ValidateMostlyEmptyStreamQuery(ast_generator, "STOP stREAM stoppedStream", StreamQuery::Action::STOP_STREAM, + "stoppedStream"); +} + +TEST_P(CypherMainVisitorTest, StopAllStreams) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("STOP ALL", ast_generator); + TestInvalidQuery("STOP ALL STREAM", ast_generator); + TestInvalidQuery("STOP STREAMS ALL", ast_generator); + + ValidateMostlyEmptyStreamQuery(ast_generator, "SToP ALL STReaMS", StreamQuery::Action::STOP_ALL_STREAMS, ""); +} + +void ValidateTopicNames(const auto &topic_names, const std::vector &expected_topic_names, + Base &ast_generator) { + std::visit(memgraph::utils::Overloaded{ + [&](Expression *expression) { + ast_generator.CheckLiteral(expression, memgraph::utils::Join(expected_topic_names, ",")); + }, + [&](const std::vector &topic_names) { EXPECT_EQ(topic_names, expected_topic_names); }}, + topic_names); +} + +void ValidateCreateKafkaStreamQuery(Base &ast_generator, const std::string &query_string, + const std::string_view stream_name, const std::vector &topic_names, + const std::string_view transform_name, const std::string_view consumer_group, + const std::optional &batch_interval, + const std::optional &batch_size, + const std::string_view bootstrap_servers, + const std::unordered_map &configs, + const std::unordered_map &credentials) { + SCOPED_TRACE(query_string); + StreamQuery *parsed_query{nullptr}; + ASSERT_NO_THROW(parsed_query = dynamic_cast(ast_generator.ParseQuery(query_string))) << query_string; + ASSERT_NE(parsed_query, nullptr); + EXPECT_EQ(parsed_query->stream_name_, stream_name); + ValidateTopicNames(parsed_query->topic_names_, topic_names, ast_generator); + EXPECT_EQ(parsed_query->transform_name_, transform_name); + EXPECT_EQ(parsed_query->consumer_group_, consumer_group); + EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->batch_interval_, batch_interval)); + EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->batch_size_, batch_size)); + EXPECT_EQ(parsed_query->batch_limit_, nullptr); + if (bootstrap_servers.empty()) { + EXPECT_EQ(parsed_query->bootstrap_servers_, nullptr); + } else { + EXPECT_NE(parsed_query->bootstrap_servers_, nullptr); + } + + const auto evaluate_config_map = [&ast_generator](const std::unordered_map &config_map) { + std::unordered_map evaluated_config_map; + const auto expr_to_str = [&ast_generator](Expression *expression) { + return std::string{ast_generator.GetLiteral(expression, ast_generator.context_.is_query_cached).ValueString()}; + }; + std::transform(config_map.begin(), config_map.end(), + std::inserter(evaluated_config_map, evaluated_config_map.end()), + [&expr_to_str](const auto expr_pair) { + return std::pair{expr_to_str(expr_pair.first), expr_to_str(expr_pair.second)}; + }); + return evaluated_config_map; + }; + + using testing::UnorderedElementsAreArray; + EXPECT_THAT(evaluate_config_map(parsed_query->configs_), UnorderedElementsAreArray(configs.begin(), configs.end())); + EXPECT_THAT(evaluate_config_map(parsed_query->credentials_), + UnorderedElementsAreArray(credentials.begin(), credentials.end())); +} + +TEST_P(CypherMainVisitorTest, CreateKafkaStream) { + auto &ast_generator = *GetParam(); + TestInvalidQuery("CREATE KAFKA STREAM", ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM invalid stream name TOPICS topic1 TRANSFORM transform", ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS invalid topic name TRANSFORM transform", ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM invalid transformation name", ast_generator); + // required configs are missing + TestInvalidQuery("CREATE KAFKA STREAM stream TRANSFORM transform", ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS TRANSFORM transform", ast_generator); + // required configs are missing + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1", ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM", ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform CONSUMER_GROUP", ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform CONSUMER_GROUP invalid consumer group", + ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform BATCH_INTERVAL", ast_generator); + TestInvalidQuery( + "CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform BATCH_INTERVAL 'invalid interval'", ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform TOPICS topic2", + ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform BATCH_SIZE", ast_generator); + TestInvalidQuery( + "CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform BATCH_SIZE 'invalid size'", ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1, TRANSFORM transform BATCH_SIZE 2 CONSUMER_GROUP Gru", + ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform BOOTSTRAP_SERVERS localhost:9092", + ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform BOOTSTRAP_SERVERS", ast_generator); + // the keys must be string literals + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform CONFIGS { symbolicname : 'string' }", + ast_generator); + TestInvalidQuery( + "CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform CREDENTIALS { symbolicname : 'string' }", + ast_generator); + TestInvalidQuery("CREATE KAFKA STREAM stream TOPICS topic1 TRANSFORM transform CREDENTIALS 2", ast_generator); + + const std::vector topic_names{"topic1_name.with_dot", "topic1_name.with_multiple.dots", + "topic-name.with-multiple.dots-and-dashes"}; + + static constexpr std::string_view kStreamName{"SomeSuperStream"}; + static constexpr std::string_view kTransformName{"moreAwesomeTransform"}; + + auto check_topic_names = [&](const std::vector &topic_names) { + static constexpr std::string_view kConsumerGroup{"ConsumerGru"}; + static constexpr int kBatchInterval = 324; + const TypedValue batch_interval_value{kBatchInterval}; + static constexpr int kBatchSize = 1; + const TypedValue batch_size_value{kBatchSize}; + + const auto topic_names_as_str = memgraph::utils::Join(topic_names, ","); + + ValidateCreateKafkaStreamQuery( + ast_generator, + fmt::format("CREATE KAFKA STREAM {} TOPICS {} TRANSFORM {}", kStreamName, topic_names_as_str, kTransformName), + kStreamName, topic_names, kTransformName, "", std::nullopt, std::nullopt, {}, {}, {}); + + ValidateCreateKafkaStreamQuery(ast_generator, + fmt::format("CREATE KAFKA STREAM {} TOPICS {} TRANSFORM {} CONSUMER_GROUP {} ", + kStreamName, topic_names_as_str, kTransformName, kConsumerGroup), + kStreamName, topic_names, kTransformName, kConsumerGroup, std::nullopt, std::nullopt, + {}, {}, {}); + + ValidateCreateKafkaStreamQuery(ast_generator, + fmt::format("CREATE KAFKA STREAM {} TRANSFORM {} TOPICS {} BATCH_INTERVAL {}", + kStreamName, kTransformName, topic_names_as_str, kBatchInterval), + kStreamName, topic_names, kTransformName, "", batch_interval_value, std::nullopt, {}, + {}, {}); + + ValidateCreateKafkaStreamQuery(ast_generator, + fmt::format("CREATE KAFKA STREAM {} BATCH_SIZE {} TOPICS {} TRANSFORM {}", + kStreamName, kBatchSize, topic_names_as_str, kTransformName), + kStreamName, topic_names, kTransformName, "", std::nullopt, batch_size_value, {}, {}, + {}); + + ValidateCreateKafkaStreamQuery(ast_generator, + fmt::format("CREATE KAFKA STREAM {} TOPICS '{}' BATCH_SIZE {} TRANSFORM {}", + kStreamName, topic_names_as_str, kBatchSize, kTransformName), + kStreamName, topic_names, kTransformName, "", std::nullopt, batch_size_value, {}, {}, + {}); + + ValidateCreateKafkaStreamQuery( + ast_generator, + fmt::format("CREATE KAFKA STREAM {} TOPICS {} TRANSFORM {} CONSUMER_GROUP {} BATCH_INTERVAL {} BATCH_SIZE {}", + kStreamName, topic_names_as_str, kTransformName, kConsumerGroup, kBatchInterval, kBatchSize), + kStreamName, topic_names, kTransformName, kConsumerGroup, batch_interval_value, batch_size_value, {}, {}, {}); + using namespace std::string_literals; + const auto host1 = "localhost:9094"s; + ValidateCreateKafkaStreamQuery( + ast_generator, + fmt::format("CREATE KAFKA STREAM {} TOPICS {} CONSUMER_GROUP {} BATCH_SIZE {} BATCH_INTERVAL {} TRANSFORM {} " + "BOOTSTRAP_SERVERS '{}'", + kStreamName, topic_names_as_str, kConsumerGroup, kBatchSize, kBatchInterval, kTransformName, host1), + kStreamName, topic_names, kTransformName, kConsumerGroup, batch_interval_value, batch_size_value, host1, {}, + {}); + + ValidateCreateKafkaStreamQuery( + ast_generator, + fmt::format("CREATE KAFKA STREAM {} CONSUMER_GROUP {} TOPICS {} BATCH_INTERVAL {} TRANSFORM {} BATCH_SIZE {} " + "BOOTSTRAP_SERVERS '{}'", + kStreamName, kConsumerGroup, topic_names_as_str, kBatchInterval, kTransformName, kBatchSize, host1), + kStreamName, topic_names, kTransformName, kConsumerGroup, batch_interval_value, batch_size_value, host1, {}, + {}); + + const auto host2 = "localhost:9094,localhost:1994,168.1.1.256:345"s; + ValidateCreateKafkaStreamQuery( + ast_generator, + fmt::format("CREATE KAFKA STREAM {} TOPICS {} BOOTSTRAP_SERVERS '{}' CONSUMER_GROUP {} TRANSFORM {} " + "BATCH_INTERVAL {} BATCH_SIZE {}", + kStreamName, topic_names_as_str, host2, kConsumerGroup, kTransformName, kBatchInterval, kBatchSize), + kStreamName, topic_names, kTransformName, kConsumerGroup, batch_interval_value, batch_size_value, host2, {}, + {}); + }; + + for (const auto &topic_name : topic_names) { + EXPECT_NO_FATAL_FAILURE(check_topic_names({topic_name})); + } + + EXPECT_NO_FATAL_FAILURE(check_topic_names(topic_names)); + + auto check_consumer_group = [&](const std::string_view consumer_group) { + const std::string kTopicName{"topic1"}; + ValidateCreateKafkaStreamQuery(ast_generator, + fmt::format("CREATE KAFKA STREAM {} TOPICS {} TRANSFORM {} CONSUMER_GROUP {}", + kStreamName, kTopicName, kTransformName, consumer_group), + kStreamName, {kTopicName}, kTransformName, consumer_group, std::nullopt, + std::nullopt, {}, {}, {}); + }; + + using namespace std::literals; + static constexpr std::array consumer_groups{"consumergru"sv, "consumer-group-with-dash"sv, + "consumer_group.with.dot"sv, "consumer-group.With-Dot-and.dash"sv}; + + for (const auto consumer_group : consumer_groups) { + EXPECT_NO_FATAL_FAILURE(check_consumer_group(consumer_group)); + } + + auto check_config_map = [&](const std::unordered_map &config_map) { + const std::string kTopicName{"topic1"}; + + const auto map_as_str = std::invoke([&config_map] { + std::stringstream buffer; + buffer << '{'; + if (!config_map.empty()) { + auto it = config_map.begin(); + buffer << fmt::format("'{}': '{}'", it->first, it->second); + for (; it != config_map.end(); ++it) { + buffer << fmt::format(", '{}': '{}'", it->first, it->second); + } + } + buffer << '}'; + return std::move(buffer).str(); + }); + + ValidateCreateKafkaStreamQuery(ast_generator, + fmt::format("CREATE KAFKA STREAM {} TOPICS {} TRANSFORM {} CONFIGS {}", kStreamName, + kTopicName, kTransformName, map_as_str), + kStreamName, {kTopicName}, kTransformName, "", std::nullopt, std::nullopt, {}, + config_map, {}); + + ValidateCreateKafkaStreamQuery(ast_generator, + fmt::format("CREATE KAFKA STREAM {} TOPICS {} TRANSFORM {} CREDENTIALS {}", + kStreamName, kTopicName, kTransformName, map_as_str), + kStreamName, {kTopicName}, kTransformName, "", std::nullopt, std::nullopt, {}, {}, + config_map); + }; + + const std::array config_maps = {std::unordered_map{}, + std::unordered_map{{"key", "value"}}, + std::unordered_map{{"key.with.dot", "value.with.doth"}, + {"key with space", "value with space"}}}; + for (const auto &map_to_test : config_maps) { + EXPECT_NO_FATAL_FAILURE(check_config_map(map_to_test)); + } +} + +void ValidateCreatePulsarStreamQuery(Base &ast_generator, const std::string &query_string, + const std::string_view stream_name, const std::vector &topic_names, + const std::string_view transform_name, + const std::optional &batch_interval, + const std::optional &batch_size, const std::string_view service_url) { + SCOPED_TRACE(query_string); + + StreamQuery *parsed_query{nullptr}; + ASSERT_NO_THROW(parsed_query = dynamic_cast(ast_generator.ParseQuery(query_string))) << query_string; + ASSERT_NE(parsed_query, nullptr); + EXPECT_EQ(parsed_query->stream_name_, stream_name); + ValidateTopicNames(parsed_query->topic_names_, topic_names, ast_generator); + EXPECT_EQ(parsed_query->transform_name_, transform_name); + EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->batch_interval_, batch_interval)); + EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->batch_size_, batch_size)); + EXPECT_EQ(parsed_query->batch_limit_, nullptr); + if (service_url.empty()) { + EXPECT_EQ(parsed_query->service_url_, nullptr); + return; + } + EXPECT_NE(parsed_query->service_url_, nullptr); +} + +TEST_P(CypherMainVisitorTest, CreatePulsarStream) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("CREATE PULSAR STREAM", ast_generator); + TestInvalidQuery("CREATE PULSAR STREAM stream", ast_generator); + TestInvalidQuery("CREATE PULSAR STREAM stream TOPICS", ast_generator); + TestInvalidQuery("CREATE PULSAR STREAM stream TOPICS topic_name", ast_generator); + TestInvalidQuery("CREATE PULSAR STREAM stream TOPICS topic_name TRANSFORM", ast_generator); + TestInvalidQuery("CREATE PULSAR STREAM stream TOPICS topic_name TRANSFORM transform.name SERVICE_URL", ast_generator); + TestInvalidQuery( + "CREATE PULSAR STREAM stream TOPICS topic_name TRANSFORM transform.name SERVICE_URL 1", ast_generator); + TestInvalidQuery( + "CREATE PULSAR STREAM stream TOPICS topic_name TRANSFORM transform.name BOOTSTRAP_SERVERS 'bootstrap'", + ast_generator); + TestInvalidQuery( + "CREATE PULSAR STREAM stream TOPICS topic_name TRANSFORM transform.name SERVICE_URL 'test' TOPICS topic_name", + ast_generator); + TestInvalidQuery( + "CREATE PULSAR STREAM stream TRANSFORM transform.name TOPICS topic_name TRANSFORM transform.name SERVICE_URL " + "'test'", + ast_generator); + TestInvalidQuery( + "CREATE PULSAR STREAM stream BATCH_INTERVAL 1 TOPICS topic_name TRANSFORM transform.name SERVICE_URL 'test' " + "BATCH_INTERVAL 1000", + ast_generator); + TestInvalidQuery( + "CREATE PULSAR STREAM stream BATCH_INTERVAL 'a' TOPICS topic_name TRANSFORM transform.name SERVICE_URL 'test'", + ast_generator); + TestInvalidQuery( + "CREATE PULSAR STREAM stream BATCH_SIZE 'a' TOPICS topic_name TRANSFORM transform.name SERVICE_URL 'test'", + ast_generator); + + const std::vector topic_names{"topic1", "topic2"}; + const std::string topic_names_str = memgraph::utils::Join(topic_names, ","); + static constexpr std::string_view kStreamName{"PulsarStream"}; + static constexpr std::string_view kTransformName{"boringTransformation"}; + static constexpr std::string_view kServiceUrl{"localhost"}; + static constexpr int kBatchSize{1000}; + static constexpr int kBatchInterval{231321}; + + { + SCOPED_TRACE("single topic"); + ValidateCreatePulsarStreamQuery( + ast_generator, + fmt::format("CREATE PULSAR STREAM {} TOPICS {} TRANSFORM {}", kStreamName, topic_names[0], kTransformName), + kStreamName, {topic_names[0]}, kTransformName, std::nullopt, std::nullopt, ""); + } + { + SCOPED_TRACE("multiple topics"); + ValidateCreatePulsarStreamQuery( + ast_generator, + fmt::format("CREATE PULSAR STREAM {} TRANSFORM {} TOPICS {}", kStreamName, kTransformName, topic_names_str), + kStreamName, topic_names, kTransformName, std::nullopt, std::nullopt, ""); + } + { + SCOPED_TRACE("topic name in string"); + ValidateCreatePulsarStreamQuery( + ast_generator, + fmt::format("CREATE PULSAR STREAM {} TRANSFORM {} TOPICS '{}'", kStreamName, kTransformName, topic_names_str), + kStreamName, topic_names, kTransformName, std::nullopt, std::nullopt, ""); + } + { + SCOPED_TRACE("service url"); + ValidateCreatePulsarStreamQuery(ast_generator, + fmt::format("CREATE PULSAR STREAM {} SERVICE_URL '{}' TRANSFORM {} TOPICS {}", + kStreamName, kServiceUrl, kTransformName, topic_names_str), + kStreamName, topic_names, kTransformName, std::nullopt, std::nullopt, kServiceUrl); + ValidateCreatePulsarStreamQuery(ast_generator, + fmt::format("CREATE PULSAR STREAM {} TRANSFORM {} SERVICE_URL '{}' TOPICS {}", + kStreamName, kTransformName, kServiceUrl, topic_names_str), + kStreamName, topic_names, kTransformName, std::nullopt, std::nullopt, kServiceUrl); + } + { + SCOPED_TRACE("batch size"); + ValidateCreatePulsarStreamQuery( + ast_generator, + fmt::format("CREATE PULSAR STREAM {} SERVICE_URL '{}' BATCH_SIZE {} TRANSFORM {} TOPICS {}", kStreamName, + kServiceUrl, kBatchSize, kTransformName, topic_names_str), + kStreamName, topic_names, kTransformName, std::nullopt, TypedValue(kBatchSize), kServiceUrl); + ValidateCreatePulsarStreamQuery( + ast_generator, + fmt::format("CREATE PULSAR STREAM {} TRANSFORM {} SERVICE_URL '{}' TOPICS {} BATCH_SIZE {}", kStreamName, + kTransformName, kServiceUrl, topic_names_str, kBatchSize), + kStreamName, topic_names, kTransformName, std::nullopt, TypedValue(kBatchSize), kServiceUrl); + } + { + SCOPED_TRACE("batch interval"); + ValidateCreatePulsarStreamQuery( + ast_generator, + fmt::format("CREATE PULSAR STREAM {} BATCH_INTERVAL {} SERVICE_URL '{}' BATCH_SIZE {} TRANSFORM {} TOPICS {}", + kStreamName, kBatchInterval, kServiceUrl, kBatchSize, kTransformName, topic_names_str), + kStreamName, topic_names, kTransformName, TypedValue(kBatchInterval), TypedValue(kBatchSize), kServiceUrl); + ValidateCreatePulsarStreamQuery( + ast_generator, + fmt::format("CREATE PULSAR STREAM {} TRANSFORM {} SERVICE_URL '{}' BATCH_INTERVAL {} TOPICS {} BATCH_SIZE {}", + kStreamName, kTransformName, kServiceUrl, kBatchInterval, topic_names_str, kBatchSize), + kStreamName, topic_names, kTransformName, TypedValue(kBatchInterval), TypedValue(kBatchSize), kServiceUrl); + } +} + +TEST_P(CypherMainVisitorTest, CheckStream) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("CHECK STREAM", ast_generator); + TestInvalidQuery("CHECK STREAMS", ast_generator); + TestInvalidQuery("CHECK STREAMS something", ast_generator); + TestInvalidQuery("CHECK STREAM something,something", ast_generator); + TestInvalidQuery("CHECK STREAM something BATCH LIMIT 1", ast_generator); + TestInvalidQuery("CHECK STREAM something BATCH_LIMIT", ast_generator); + TestInvalidQuery("CHECK STREAM something TIMEOUT", ast_generator); + TestInvalidQuery("CHECK STREAM something BATCH_LIMIT 1 TIMEOUT", ast_generator); + TestInvalidQuery("CHECK STREAM something BATCH_LIMIT 'it should be an integer'", ast_generator); + TestInvalidQuery("CHECK STREAM something BATCH_LIMIT 2.5", ast_generator); + TestInvalidQuery("CHECK STREAM something TIMEOUT 'it should be an integer'", ast_generator); + + ValidateMostlyEmptyStreamQuery(ast_generator, "CHECK STREAM checkedStream", StreamQuery::Action::CHECK_STREAM, + "checkedStream"); + ValidateMostlyEmptyStreamQuery(ast_generator, "CHECK STREAM checkedStream bAtCH_LIMIT 42", + StreamQuery::Action::CHECK_STREAM, "checkedStream", TypedValue(42)); + ValidateMostlyEmptyStreamQuery(ast_generator, "CHECK STREAM checkedStream TimEOuT 666", + StreamQuery::Action::CHECK_STREAM, "checkedStream", std::nullopt, TypedValue(666)); + ValidateMostlyEmptyStreamQuery(ast_generator, "CHECK STREAM checkedStream BATCH_LIMIT 30 TIMEOUT 444", + StreamQuery::Action::CHECK_STREAM, "checkedStream", TypedValue(30), TypedValue(444)); +} + +TEST_P(CypherMainVisitorTest, SettingQuery) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("SHOW DB SETTINGS", ast_generator); + TestInvalidQuery("SHOW SETTINGS", ast_generator); + TestInvalidQuery("SHOW DATABASE SETTING", ast_generator); + TestInvalidQuery("SHOW DB SETTING 'setting'", ast_generator); + TestInvalidQuery("SHOW SETTING 'setting'", ast_generator); + TestInvalidQuery("SHOW DATABASE SETTING 1", ast_generator); + TestInvalidQuery("SET SETTING 'setting' TO 'value'", ast_generator); + TestInvalidQuery("SET DB SETTING 'setting' TO 'value'", ast_generator); + TestInvalidQuery("SET DATABASE SETTING 1 TO 'value'", ast_generator); + TestInvalidQuery("SET DATABASE SETTING 'setting' TO 2", ast_generator); + + const auto validate_setting_query = [&](const auto &query, const auto action, + const std::optional &expected_setting_name, + const std::optional &expected_setting_value) { + auto *parsed_query = dynamic_cast(ast_generator.ParseQuery(query)); + EXPECT_EQ(parsed_query->action_, action) << query; + EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->setting_name_, expected_setting_name)); + EXPECT_NO_FATAL_FAILURE( + CheckOptionalExpression(ast_generator, parsed_query->setting_value_, expected_setting_value)); + }; + + validate_setting_query("SHOW DATABASE SETTINGS", SettingQuery::Action::SHOW_ALL_SETTINGS, std::nullopt, std::nullopt); + validate_setting_query("SHOW DATABASE SETTING 'setting'", SettingQuery::Action::SHOW_SETTING, TypedValue{"setting"}, + std::nullopt); + validate_setting_query("SET DATABASE SETTING 'setting' TO 'value'", SettingQuery::Action::SET_SETTING, + TypedValue{"setting"}, TypedValue{"value"}); +} + +TEST_P(CypherMainVisitorTest, VersionQuery) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("SHOW VERION", ast_generator); + TestInvalidQuery("SHOW VER", ast_generator); + TestInvalidQuery("SHOW VERSIONS", ast_generator); + ASSERT_NO_THROW(ast_generator.ParseQuery("SHOW VERSION")); +} + +TEST_P(CypherMainVisitorTest, ForeachThrow) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("FOREACH(i IN [1, 2] | UNWIND [1,2,3] AS j CREATE (n))"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("FOREACH(i IN [1, 2] CREATE (:Foo {prop : i}))"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("FOREACH(i IN [1, 2] | MATCH (n)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("FOREACH(i IN x | MATCH (n)"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, Foreach) { + auto &ast_generator = *GetParam(); + // CREATE + { + auto *query = dynamic_cast( + ast_generator.ParseQuery("FOREACH (age IN [1, 2, 3] | CREATE (m:Age {amount: age}))")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *foreach = dynamic_cast(single_query->clauses_[0]); + ASSERT_TRUE(foreach); + ASSERT_TRUE(foreach->named_expression_); + EXPECT_EQ(foreach->named_expression_->name_, "age"); + auto *expr = foreach->named_expression_->expression_; + ASSERT_TRUE(expr); + ASSERT_TRUE(dynamic_cast(expr)); + const auto &clauses = foreach->clauses_; + ASSERT_TRUE(clauses.size() == 1); + ASSERT_TRUE(dynamic_cast(clauses.front())); + } + // SET + { + auto *query = + dynamic_cast(ast_generator.ParseQuery("FOREACH (i IN nodes(path) | SET i.checkpoint = true)")); + auto *foreach = dynamic_cast(query->single_query_->clauses_[0]); + const auto &clauses = foreach->clauses_; + ASSERT_TRUE(clauses.size() == 1); + ASSERT_TRUE(dynamic_cast(clauses.front())); + } + // REMOVE + { + auto *query = dynamic_cast(ast_generator.ParseQuery("FOREACH (i IN nodes(path) | REMOVE i.prop)")); + auto *foreach = dynamic_cast(query->single_query_->clauses_[0]); + const auto &clauses = foreach->clauses_; + ASSERT_TRUE(clauses.size() == 1); + ASSERT_TRUE(dynamic_cast(clauses.front())); + } + // MERGE + { + // merge works as create here + auto *query = + dynamic_cast(ast_generator.ParseQuery("FOREACH (i IN [1, 2, 3] | MERGE (n {no : i}))")); + auto *foreach = dynamic_cast(query->single_query_->clauses_[0]); + const auto &clauses = foreach->clauses_; + ASSERT_TRUE(clauses.size() == 1); + ASSERT_TRUE(dynamic_cast(clauses.front())); + } + // CYPHER DELETE + { + auto *query = dynamic_cast(ast_generator.ParseQuery("FOREACH (i IN nodes(path) | DETACH DELETE i)")); + auto *foreach = dynamic_cast(query->single_query_->clauses_[0]); + const auto &clauses = foreach->clauses_; + ASSERT_TRUE(clauses.size() == 1); + ASSERT_TRUE(dynamic_cast(clauses.front())); + } + // nested FOREACH + { + auto *query = dynamic_cast(ast_generator.ParseQuery( + "FOREACH (i IN nodes(path) | FOREACH (age IN i.list | CREATE (m:Age {amount: age})))")); + + auto *foreach = dynamic_cast(query->single_query_->clauses_[0]); + const auto &clauses = foreach->clauses_; + ASSERT_TRUE(clauses.size() == 1); + ASSERT_TRUE(dynamic_cast(clauses.front())); + } + // Multiple update clauses + { + auto *query = dynamic_cast( + ast_generator.ParseQuery("FOREACH (i IN nodes(path) | SET i.checkpoint = true REMOVE i.prop)")); + auto *foreach = dynamic_cast(query->single_query_->clauses_[0]); + const auto &clauses = foreach->clauses_; + ASSERT_TRUE(clauses.size() == 2); + ASSERT_TRUE(dynamic_cast(clauses.front())); + ASSERT_TRUE(dynamic_cast(*++clauses.begin())); + } +} + +TEST_P(CypherMainVisitorTest, TestShowSchemas) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("SHOW SCHEMAS")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::SHOW_SCHEMAS); +} + +TEST_P(CypherMainVisitorTest, TestShowSchema) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("SHOW SCHEMA ON label"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("SHOW SCHEMA :label"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("SHOW SCHEMA label"), SyntaxException); + + auto *query = dynamic_cast(ast_generator.ParseQuery("SHOW SCHEMA ON :label")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::SHOW_SCHEMA); + EXPECT_EQ(query->label_, ast_generator.Label("label")); +} + +void AssertSchemaPropertyMap(auto &schema_property_map, + std::vector> properties_type, + auto &ast_generator) { + EXPECT_EQ(schema_property_map.size(), properties_type.size()); + for (size_t i{0}; i < schema_property_map.size(); ++i) { + // Assert PropertyId + EXPECT_EQ(schema_property_map[i].first, ast_generator.Prop(properties_type[i].first)); + // Assert Property Type + EXPECT_EQ(schema_property_map[i].second, properties_type[i].second); + } +} + +TEST_P(CypherMainVisitorTest, TestCreateSchema) { + { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label()"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(123 INTEGER)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(name TYPE)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(name, age)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(name, DURATION)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON label(name INTEGER)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(name INTEGER, name INTEGER)"), SemanticException); + EXPECT_THROW(ast_generator.ParseQuery("CREATE SCHEMA ON :label(name INTEGER, name STRING)"), SemanticException); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("CREATE SCHEMA ON :label1(name STRING)")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::CREATE_SCHEMA); + EXPECT_EQ(query->label_, ast_generator.Label("label1")); + AssertSchemaPropertyMap(query->schema_type_map_, {{"name", memgraph::common::SchemaType::STRING}}, ast_generator); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast(ast_generator.ParseQuery("CREATE SCHEMA ON :label2(name string)")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::CREATE_SCHEMA); + EXPECT_EQ(query->label_, ast_generator.Label("label2")); + AssertSchemaPropertyMap(query->schema_type_map_, {{"name", memgraph::common::SchemaType::STRING}}, ast_generator); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast( + ast_generator.ParseQuery("CREATE SCHEMA ON :label3(first_name STRING, last_name STRING)")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::CREATE_SCHEMA); + EXPECT_EQ(query->label_, ast_generator.Label("label3")); + AssertSchemaPropertyMap( + query->schema_type_map_, + {{"first_name", memgraph::common::SchemaType::STRING}, {"last_name", memgraph::common::SchemaType::STRING}}, + ast_generator); + } + { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast( + ast_generator.ParseQuery("CREATE SCHEMA ON :label4(name STRING, age INTEGER, dur DURATION, birthday1 " + "LOCALDATETIME, birthday2 DATE, some_time LOCALTIME, speaks_truth BOOL)")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::CREATE_SCHEMA); + EXPECT_EQ(query->label_, ast_generator.Label("label4")); + AssertSchemaPropertyMap(query->schema_type_map_, + { + {"name", memgraph::common::SchemaType::STRING}, + {"age", memgraph::common::SchemaType::INT}, + {"dur", memgraph::common::SchemaType::DURATION}, + {"birthday1", memgraph::common::SchemaType::LOCALDATETIME}, + {"birthday2", memgraph::common::SchemaType::DATE}, + {"some_time", memgraph::common::SchemaType::LOCALTIME}, + {"speaks_truth", memgraph::common::SchemaType::BOOL}, + }, + ast_generator); + } +} + +TEST_P(CypherMainVisitorTest, TestDropSchema) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("DROP SCHEMA"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("DROP SCHEMA ON label"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("DROP SCHEMA :label"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("DROP SCHEMA ON :label()"), SyntaxException); + + auto *query = dynamic_cast(ast_generator.ParseQuery("DROP SCHEMA ON :label")); + ASSERT_TRUE(query); + EXPECT_EQ(query->action_, SchemaQuery::Action::DROP_SCHEMA); + EXPECT_EQ(query->label_, ast_generator.Label("label")); +} diff --git a/tests/unit/interpreter_v2.cpp b/tests/unit/query_v2_interpreter.cpp similarity index 90% rename from tests/unit/interpreter_v2.cpp rename to tests/unit/query_v2_interpreter.cpp index 1c4da6865..b73cbeb5a 100644 --- a/tests/unit/interpreter_v2.cpp +++ b/tests/unit/query_v2_interpreter.cpp @@ -15,20 +15,21 @@ #include #include -#include "communication/bolt/v1/value.hpp" -#include "communication/result_stream_faker.hpp" -#include "glue/communication.hpp" #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "query/auth_checker.hpp" -#include "query/config.hpp" -#include "query/exceptions.hpp" -#include "query/interpreter.hpp" -#include "query/stream.hpp" -#include "query/typed_value.hpp" -#include "query_common.hpp" -#include "storage/v2/isolation_level.hpp" -#include "storage/v2/property_value.hpp" + +#include "communication/bolt/v1/value.hpp" +#include "glue/v2/communication.hpp" +#include "query/v2/auth_checker.hpp" +#include "query/v2/config.hpp" +#include "query/v2/exceptions.hpp" +#include "query/v2/interpreter.hpp" +#include "query/v2/stream.hpp" +#include "query/v2/typed_value.hpp" +#include "query_v2_query_common.hpp" +#include "result_stream_faker.hpp" +#include "storage/v3/isolation_level.hpp" +#include "storage/v3/property_value.hpp" #include "utils/csv_parsing.hpp" #include "utils/logging.hpp" @@ -48,13 +49,14 @@ auto StringToUnorderedSet(const std::string &element) { }; struct InterpreterFaker { - InterpreterFaker(memgraph::storage::Storage *db, const memgraph::query::InterpreterConfig config, + InterpreterFaker(memgraph::storage::v3::Storage *db, const memgraph::query::v2::InterpreterConfig config, const std::filesystem::path &data_directory) : interpreter_context(db, config, data_directory), interpreter(&interpreter_context) { interpreter_context.auth_checker = &auth_checker; } - auto Prepare(const std::string &query, const std::map ¶ms = {}) { + auto Prepare(const std::string &query, + const std::map ¶ms = {}) { ResultStreamFaker stream(interpreter_context.db); const auto [header, _, qid] = interpreter.Prepare(query, params, nullptr); @@ -72,7 +74,8 @@ struct InterpreterFaker { * * Return the query stream. */ - auto Interpret(const std::string &query, const std::map ¶ms = {}) { + auto Interpret(const std::string &query, + const std::map ¶ms = {}) { auto prepare_result = Prepare(query, params); auto &stream = prepare_result.first; @@ -82,9 +85,9 @@ struct InterpreterFaker { return std::move(stream); } - memgraph::query::AllowEverythingAuthChecker auth_checker; - memgraph::query::InterpreterContext interpreter_context; - memgraph::query::Interpreter interpreter; + memgraph::query::v2::AllowEverythingAuthChecker auth_checker; + memgraph::query::v2::InterpreterContext interpreter_context; + memgraph::query::v2::Interpreter interpreter; }; } // namespace @@ -94,12 +97,13 @@ struct InterpreterFaker { class InterpreterTest : public ::testing::Test { protected: - memgraph::storage::Storage db_; - std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_interpreter"}; + memgraph::storage::v3::Storage db_; + std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_query_v2_interpreter"}; InterpreterFaker default_interpreter{&db_, {}, data_directory}; - auto Prepare(const std::string &query, const std::map ¶ms = {}) { + auto Prepare(const std::string &query, + const std::map ¶ms = {}) { return default_interpreter.Prepare(query, params); } @@ -107,7 +111,8 @@ class InterpreterTest : public ::testing::Test { default_interpreter.Pull(stream, n, qid); } - auto Interpret(const std::string &query, const std::map ¶ms = {}) { + auto Interpret(const std::string &query, + const std::map ¶ms = {}) { return default_interpreter.Interpret(query, params); } }; @@ -197,8 +202,8 @@ TEST_F(InterpreterTest, AstCache) { // Run query with same ast multiple times with different parameters. TEST_F(InterpreterTest, Parameters) { { - auto stream = Interpret("RETURN $2 + $`a b`", {{"2", memgraph::storage::PropertyValue(10)}, - {"a b", memgraph::storage::PropertyValue(15)}}); + auto stream = Interpret("RETURN $2 + $`a b`", {{"2", memgraph::storage::v3::PropertyValue(10)}, + {"a b", memgraph::storage::v3::PropertyValue(15)}}); ASSERT_EQ(stream.GetHeader().size(), 1U); EXPECT_EQ(stream.GetHeader()[0], "$2 + $`a b`"); ASSERT_EQ(stream.GetResults().size(), 1U); @@ -207,9 +212,9 @@ TEST_F(InterpreterTest, Parameters) { } { // Not needed parameter. - auto stream = Interpret("RETURN $2 + $`a b`", {{"2", memgraph::storage::PropertyValue(10)}, - {"a b", memgraph::storage::PropertyValue(15)}, - {"c", memgraph::storage::PropertyValue(10)}}); + auto stream = Interpret("RETURN $2 + $`a b`", {{"2", memgraph::storage::v3::PropertyValue(10)}, + {"a b", memgraph::storage::v3::PropertyValue(15)}, + {"c", memgraph::storage::v3::PropertyValue(10)}}); ASSERT_EQ(stream.GetHeader().size(), 1U); EXPECT_EQ(stream.GetHeader()[0], "$2 + $`a b`"); ASSERT_EQ(stream.GetResults().size(), 1U); @@ -218,28 +223,29 @@ TEST_F(InterpreterTest, Parameters) { } { // Cached ast, different parameters. - auto stream = Interpret("RETURN $2 + $`a b`", {{"2", memgraph::storage::PropertyValue("da")}, - {"a b", memgraph::storage::PropertyValue("ne")}}); + auto stream = Interpret("RETURN $2 + $`a b`", {{"2", memgraph::storage::v3::PropertyValue("da")}, + {"a b", memgraph::storage::v3::PropertyValue("ne")}}); ASSERT_EQ(stream.GetResults().size(), 1U); ASSERT_EQ(stream.GetResults()[0].size(), 1U); ASSERT_EQ(stream.GetResults()[0][0].ValueString(), "dane"); } { // Non-primitive literal. - auto stream = - Interpret("RETURN $2", {{"2", memgraph::storage::PropertyValue(std::vector{ - memgraph::storage::PropertyValue(5), memgraph::storage::PropertyValue(2), - memgraph::storage::PropertyValue(3)})}}); + auto stream = Interpret( + "RETURN $2", {{"2", memgraph::storage::v3::PropertyValue(std::vector{ + memgraph::storage::v3::PropertyValue(5), memgraph::storage::v3::PropertyValue(2), + memgraph::storage::v3::PropertyValue(3)})}}); ASSERT_EQ(stream.GetResults().size(), 1U); ASSERT_EQ(stream.GetResults()[0].size(), 1U); - auto result = memgraph::query::test_common::ToIntList(memgraph::glue::ToTypedValue(stream.GetResults()[0][0])); + auto result = + memgraph::query::v2::test_common::ToIntList(memgraph::glue::v2::ToTypedValue(stream.GetResults()[0][0])); ASSERT_THAT(result, testing::ElementsAre(5, 2, 3)); } { // Cached ast, unprovided parameter. - ASSERT_THROW(Interpret("RETURN $2 + $`a b`", {{"2", memgraph::storage::PropertyValue("da")}, - {"ab", memgraph::storage::PropertyValue("ne")}}), - memgraph::query::UnprovidedParameterError); + ASSERT_THROW(Interpret("RETURN $2 + $`a b`", {{"2", memgraph::storage::v3::PropertyValue("da")}, + {"ab", memgraph::storage::v3::PropertyValue("ne")}}), + memgraph::query::v2::UnprovidedParameterError); } } @@ -247,12 +253,12 @@ TEST_F(InterpreterTest, Parameters) { TEST_F(InterpreterTest, ParametersAsPropertyMap) { { EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :label(name STRING, age INTEGER)")); - std::map property_map{}; - property_map["name"] = memgraph::storage::PropertyValue("name1"); - property_map["age"] = memgraph::storage::PropertyValue(25); + std::map property_map{}; + property_map["name"] = memgraph::storage::v3::PropertyValue("name1"); + property_map["age"] = memgraph::storage::v3::PropertyValue(25); auto stream = Interpret("CREATE (n:label $prop) RETURN n", { - {"prop", memgraph::storage::PropertyValue(property_map)}, + {"prop", memgraph::storage::v3::PropertyValue(property_map)}, }); ASSERT_EQ(stream.GetHeader().size(), 1U); ASSERT_EQ(stream.GetHeader()[0], "n"); @@ -264,13 +270,13 @@ TEST_F(InterpreterTest, ParametersAsPropertyMap) { } { EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :Person(name STRING, age INTEGER)")); - std::map property_map{}; - property_map["name"] = memgraph::storage::PropertyValue("name1"); - property_map["age"] = memgraph::storage::PropertyValue(25); + std::map property_map{}; + property_map["name"] = memgraph::storage::v3::PropertyValue("name1"); + property_map["age"] = memgraph::storage::v3::PropertyValue(25); EXPECT_NO_THROW(Interpret("CREATE (:Person {name: 'test', age: 30})")); auto stream = Interpret("MATCH (m:Person) CREATE (n:Person $prop) RETURN n", { - {"prop", memgraph::storage::PropertyValue(property_map)}, + {"prop", memgraph::storage::v3::PropertyValue(property_map)}, }); ASSERT_EQ(stream.GetHeader().size(), 1U); ASSERT_EQ(stream.GetHeader()[0], "n"); @@ -282,12 +288,12 @@ TEST_F(InterpreterTest, ParametersAsPropertyMap) { } { EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :L1(name STRING)")); - std::map property_map{}; - property_map["name"] = memgraph::storage::PropertyValue("name1"); - property_map["weight"] = memgraph::storage::PropertyValue(121); + std::map property_map{}; + property_map["name"] = memgraph::storage::v3::PropertyValue("name1"); + property_map["weight"] = memgraph::storage::v3::PropertyValue(121); auto stream = Interpret("CREATE (:L1 {name: 'name1'})-[r:TO $prop]->(:L1 {name: 'name2'}) RETURN r", { - {"prop", memgraph::storage::PropertyValue(property_map)}, + {"prop", memgraph::storage::v3::PropertyValue(property_map)}, }); ASSERT_EQ(stream.GetHeader().size(), 1U); ASSERT_EQ(stream.GetHeader()[0], "r"); @@ -298,25 +304,25 @@ TEST_F(InterpreterTest, ParametersAsPropertyMap) { EXPECT_EQ(result.properties["weight"].ValueInt(), 121); } { - std::map property_map{}; - property_map["name"] = memgraph::storage::PropertyValue("name1"); - property_map["age"] = memgraph::storage::PropertyValue(15); + std::map property_map{}; + property_map["name"] = memgraph::storage::v3::PropertyValue("name1"); + property_map["age"] = memgraph::storage::v3::PropertyValue(15); ASSERT_THROW(Interpret("MATCH (n $prop) RETURN n", { - {"prop", memgraph::storage::PropertyValue(property_map)}, + {"prop", memgraph::storage::v3::PropertyValue(property_map)}, }), - memgraph::query::SemanticException); + memgraph::query::v2::SemanticException); } { EXPECT_NO_THROW(Interpret("CREATE SCHEMA ON :L2(name STRING, age INTEGER)")); - std::map property_map{}; - property_map["name"] = memgraph::storage::PropertyValue("name1"); - property_map["age"] = memgraph::storage::PropertyValue(15); + std::map property_map{}; + property_map["name"] = memgraph::storage::v3::PropertyValue("name1"); + property_map["age"] = memgraph::storage::v3::PropertyValue(15); ASSERT_THROW(Interpret("MERGE (n:L2 $prop) RETURN n", { - {"prop", memgraph::storage::PropertyValue(property_map)}, + {"prop", memgraph::storage::v3::PropertyValue(property_map)}, }), - memgraph::query::SemanticException); + memgraph::query::v2::SemanticException); } } @@ -331,26 +337,26 @@ TEST_F(InterpreterTest, Bfs) { const auto kReachable = "reachable"; const auto kId = "id"; - std::vector> levels(kNumLevels); + std::vector> levels(kNumLevels); int id = 0; // Set up. { auto storage_dba = db_.Access(); - memgraph::query::DbAccessor dba(&storage_dba); + memgraph::query::v2::DbAccessor dba(&storage_dba); auto add_node = [&](int level, bool reachable) { auto node = dba.InsertVertex(); - MG_ASSERT(node.SetProperty(dba.NameToProperty(kId), memgraph::storage::PropertyValue(id++)).HasValue()); + MG_ASSERT(node.SetProperty(dba.NameToProperty(kId), memgraph::storage::v3::PropertyValue(id++)).HasValue()); MG_ASSERT( - node.SetProperty(dba.NameToProperty(kReachable), memgraph::storage::PropertyValue(reachable)).HasValue()); + node.SetProperty(dba.NameToProperty(kReachable), memgraph::storage::v3::PropertyValue(reachable)).HasValue()); levels[level].push_back(node); return node; }; auto add_edge = [&](auto &v1, auto &v2, bool reachable) { auto edge = dba.InsertEdge(&v1, &v2, dba.NameToEdgeType("edge")); - MG_ASSERT( - edge->SetProperty(dba.NameToProperty(kReachable), memgraph::storage::PropertyValue(reachable)).HasValue()); + MG_ASSERT(edge->SetProperty(dba.NameToProperty(kReachable), memgraph::storage::v3::PropertyValue(reachable)) + .HasValue()); }; // Add source node. @@ -497,45 +503,45 @@ TEST_F(InterpreterTest, ShortestPath) { TEST_F(InterpreterTest, CreateLabelIndexInMulticommandTransaction) { Interpret("BEGIN"); - ASSERT_THROW(Interpret("CREATE INDEX ON :X"), memgraph::query::IndexInMulticommandTxException); + ASSERT_THROW(Interpret("CREATE INDEX ON :X"), memgraph::query::v2::IndexInMulticommandTxException); Interpret("ROLLBACK"); } TEST_F(InterpreterTest, CreateLabelPropertyIndexInMulticommandTransaction) { Interpret("BEGIN"); - ASSERT_THROW(Interpret("CREATE INDEX ON :X(y)"), memgraph::query::IndexInMulticommandTxException); + ASSERT_THROW(Interpret("CREATE INDEX ON :X(y)"), memgraph::query::v2::IndexInMulticommandTxException); Interpret("ROLLBACK"); } TEST_F(InterpreterTest, CreateExistenceConstraintInMulticommandTransaction) { Interpret("BEGIN"); ASSERT_THROW(Interpret("CREATE CONSTRAINT ON (n:A) ASSERT EXISTS (n.a)"), - memgraph::query::ConstraintInMulticommandTxException); + memgraph::query::v2::ConstraintInMulticommandTxException); Interpret("ROLLBACK"); } TEST_F(InterpreterTest, CreateUniqueConstraintInMulticommandTransaction) { Interpret("BEGIN"); ASSERT_THROW(Interpret("CREATE CONSTRAINT ON (n:A) ASSERT n.a, n.b IS UNIQUE"), - memgraph::query::ConstraintInMulticommandTxException); + memgraph::query::v2::ConstraintInMulticommandTxException); Interpret("ROLLBACK"); } TEST_F(InterpreterTest, ShowIndexInfoInMulticommandTransaction) { Interpret("BEGIN"); - ASSERT_THROW(Interpret("SHOW INDEX INFO"), memgraph::query::InfoInMulticommandTxException); + ASSERT_THROW(Interpret("SHOW INDEX INFO"), memgraph::query::v2::InfoInMulticommandTxException); Interpret("ROLLBACK"); } TEST_F(InterpreterTest, ShowConstraintInfoInMulticommandTransaction) { Interpret("BEGIN"); - ASSERT_THROW(Interpret("SHOW CONSTRAINT INFO"), memgraph::query::InfoInMulticommandTxException); + ASSERT_THROW(Interpret("SHOW CONSTRAINT INFO"), memgraph::query::v2::InfoInMulticommandTxException); Interpret("ROLLBACK"); } TEST_F(InterpreterTest, ShowStorageInfoInMulticommandTransaction) { Interpret("BEGIN"); - ASSERT_THROW(Interpret("SHOW STORAGE INFO"), memgraph::query::InfoInMulticommandTxException); + ASSERT_THROW(Interpret("SHOW STORAGE INFO"), memgraph::query::v2::InfoInMulticommandTxException); Interpret("ROLLBACK"); } @@ -546,20 +552,21 @@ TEST_F(InterpreterTest, ExistenceConstraintTest) { Interpret("CREATE CONSTRAINT ON (n:A) ASSERT EXISTS (n.b);"); Interpret("CREATE (:A{a: 3, b:1})"); Interpret("CREATE (:A{a: 3, b:2})"); - ASSERT_THROW(Interpret("CREATE (:A {a: 12})"), memgraph::query::QueryException); + ASSERT_THROW(Interpret("CREATE (:A {a: 12})"), memgraph::query::v2::QueryException); Interpret("MATCH (n:A{a:3, b: 2}) SET n.b=5"); Interpret("CREATE (:A{a:2, b: 3})"); Interpret("MATCH (n:A{a:3, b: 1}) DETACH DELETE n"); Interpret("CREATE (n:A{a:2, b: 3})"); - ASSERT_THROW(Interpret("CREATE CONSTRAINT ON (n:A) ASSERT EXISTS (n.c);"), memgraph::query::QueryRuntimeException); + ASSERT_THROW(Interpret("CREATE CONSTRAINT ON (n:A) ASSERT EXISTS (n.c);"), + memgraph::query::v2::QueryRuntimeException); } TEST_F(InterpreterTest, UniqueConstraintTest) { ASSERT_NO_THROW(Interpret("CREATE SCHEMA ON :A(a INTEGER);")); // Empty property list should result with syntax exception. - ASSERT_THROW(Interpret("CREATE CONSTRAINT ON (n:A) ASSERT IS UNIQUE;"), memgraph::query::SyntaxException); - ASSERT_THROW(Interpret("DROP CONSTRAINT ON (n:A) ASSERT IS UNIQUE;"), memgraph::query::SyntaxException); + ASSERT_THROW(Interpret("CREATE CONSTRAINT ON (n:A) ASSERT IS UNIQUE;"), memgraph::query::v2::SyntaxException); + ASSERT_THROW(Interpret("DROP CONSTRAINT ON (n:A) ASSERT IS UNIQUE;"), memgraph::query::v2::SyntaxException); // Too large list of properties should also result with syntax exception. { @@ -572,26 +579,27 @@ TEST_F(InterpreterTest, UniqueConstraintTest) { stream << " IS UNIQUE;"; std::string create_query = "CREATE CONSTRAINT" + stream.str(); std::string drop_query = "DROP CONSTRAINT" + stream.str(); - ASSERT_THROW(Interpret(create_query), memgraph::query::SyntaxException); - ASSERT_THROW(Interpret(drop_query), memgraph::query::SyntaxException); + ASSERT_THROW(Interpret(create_query), memgraph::query::v2::SyntaxException); + ASSERT_THROW(Interpret(drop_query), memgraph::query::v2::SyntaxException); } // Providing property list with duplicates results with syntax exception. ASSERT_THROW(Interpret("CREATE CONSTRAINT ON (n:A) ASSERT n.a, n.b, n.a IS UNIQUE;"), - memgraph::query::SyntaxException); - ASSERT_THROW(Interpret("DROP CONSTRAINT ON (n:A) ASSERT n.a, n.b, n.a IS UNIQUE;"), memgraph::query::SyntaxException); + memgraph::query::v2::SyntaxException); + ASSERT_THROW(Interpret("DROP CONSTRAINT ON (n:A) ASSERT n.a, n.b, n.a IS UNIQUE;"), + memgraph::query::v2::SyntaxException); // Commit of vertex should fail if a constraint is violated. Interpret("CREATE CONSTRAINT ON (n:A) ASSERT n.a, n.b IS UNIQUE;"); Interpret("CREATE (:A{a:1, b:2})"); Interpret("CREATE (:A{a:1, b:3})"); - ASSERT_THROW(Interpret("CREATE (:A{a:1, b:2})"), memgraph::query::QueryException); + ASSERT_THROW(Interpret("CREATE (:A{a:1, b:2})"), memgraph::query::v2::QueryException); // Attempt to create a constraint should fail if it's violated. Interpret("CREATE (:A{a:1, c:2})"); Interpret("CREATE (:A{a:1, c:2})"); ASSERT_THROW(Interpret("CREATE CONSTRAINT ON (n:A) ASSERT n.a, n.c IS UNIQUE;"), - memgraph::query::QueryRuntimeException); + memgraph::query::v2::QueryRuntimeException); Interpret("MATCH (n:A{a:2, b:2}) SET n.a=1"); Interpret("CREATE (:A{a:2})"); @@ -716,7 +724,7 @@ TEST_F(InterpreterTest, ExplainQueryWithParams) { EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); auto stream = - Interpret("EXPLAIN MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::PropertyValue(42)}}); + Interpret("EXPLAIN MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::v3::PropertyValue(42)}}); ASSERT_EQ(stream.GetHeader().size(), 1U); EXPECT_EQ(stream.GetHeader().front(), "QUERY PLAN"); std::vector expected_rows{" * Produce {n}", " * Filter", " * ScanAll (n)", " * Once"}; @@ -731,7 +739,7 @@ TEST_F(InterpreterTest, ExplainQueryWithParams) { EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); // We should have AST cache for EXPLAIN ... and for inner MATCH ... EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); - Interpret("MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::PropertyValue("something else")}}); + Interpret("MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::v3::PropertyValue("something else")}}); EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); } @@ -801,7 +809,7 @@ TEST_F(InterpreterTest, ProfileQueryMultiplePulls) { TEST_F(InterpreterTest, ProfileQueryInMulticommandTransaction) { Interpret("BEGIN"); - ASSERT_THROW(Interpret("PROFILE MATCH (n) RETURN *;"), memgraph::query::ProfileInMulticommandTxException); + ASSERT_THROW(Interpret("PROFILE MATCH (n) RETURN *;"), memgraph::query::v2::ProfileInMulticommandTxException); Interpret("ROLLBACK"); } @@ -811,7 +819,7 @@ TEST_F(InterpreterTest, ProfileQueryWithParams) { EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); auto stream = - Interpret("PROFILE MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::PropertyValue(42)}}); + Interpret("PROFILE MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::v3::PropertyValue(42)}}); std::vector expected_header{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}; EXPECT_EQ(stream.GetHeader(), expected_header); std::vector expected_rows{"* Produce", "* Filter", "* ScanAll", "* Once"}; @@ -826,7 +834,7 @@ TEST_F(InterpreterTest, ProfileQueryWithParams) { EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); // We should have AST cache for PROFILE ... and for inner MATCH ... EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); - Interpret("MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::PropertyValue("something else")}}); + Interpret("MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::v3::PropertyValue("something else")}}); EXPECT_EQ(interpreter_context.plan_cache.size(), 1U); EXPECT_EQ(interpreter_context.ast_cache.size(), 2U); } @@ -860,10 +868,10 @@ TEST_F(InterpreterTest, ProfileQueryWithLiterals) { TEST_F(InterpreterTest, Transactions) { auto &interpreter = default_interpreter.interpreter; { - ASSERT_THROW(interpreter.CommitTransaction(), memgraph::query::ExplicitTransactionUsageException); - ASSERT_THROW(interpreter.RollbackTransaction(), memgraph::query::ExplicitTransactionUsageException); + ASSERT_THROW(interpreter.CommitTransaction(), memgraph::query::v2::ExplicitTransactionUsageException); + ASSERT_THROW(interpreter.RollbackTransaction(), memgraph::query::v2::ExplicitTransactionUsageException); interpreter.BeginTransaction(); - ASSERT_THROW(interpreter.BeginTransaction(), memgraph::query::ExplicitTransactionUsageException); + ASSERT_THROW(interpreter.BeginTransaction(), memgraph::query::v2::ExplicitTransactionUsageException); auto [stream, qid] = Prepare("RETURN 2"); ASSERT_EQ(stream.GetHeader().size(), 1U); EXPECT_EQ(stream.GetHeader()[0], "2"); @@ -894,7 +902,7 @@ TEST_F(InterpreterTest, Qid) { interpreter.BeginTransaction(); auto [stream, qid] = Prepare("RETURN 2"); ASSERT_TRUE(qid); - ASSERT_THROW(Pull(&stream, {}, *qid + 1), memgraph::query::InvalidArgumentsException); + ASSERT_THROW(Pull(&stream, {}, *qid + 1), memgraph::query::v2::InvalidArgumentsException); interpreter.RollbackTransaction(); } { @@ -1496,40 +1504,41 @@ TEST_F(InterpreterTest, LoadCsvClauseNotification) { TEST_F(InterpreterTest, CreateSchemaMulticommandTransaction) { Interpret("BEGIN"); ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING, age INTEGER)"), - memgraph::query::ConstraintInMulticommandTxException); + memgraph::query::v2::ConstraintInMulticommandTxException); Interpret("ROLLBACK"); } TEST_F(InterpreterTest, ShowSchemasMulticommandTransaction) { Interpret("BEGIN"); - ASSERT_THROW(Interpret("SHOW SCHEMAS"), memgraph::query::ConstraintInMulticommandTxException); + ASSERT_THROW(Interpret("SHOW SCHEMAS"), memgraph::query::v2::ConstraintInMulticommandTxException); Interpret("ROLLBACK"); } TEST_F(InterpreterTest, ShowSchemaMulticommandTransaction) { Interpret("BEGIN"); - ASSERT_THROW(Interpret("SHOW SCHEMA ON :label"), memgraph::query::ConstraintInMulticommandTxException); + ASSERT_THROW(Interpret("SHOW SCHEMA ON :label"), memgraph::query::v2::ConstraintInMulticommandTxException); Interpret("ROLLBACK"); } TEST_F(InterpreterTest, DropSchemaMulticommandTransaction) { Interpret("BEGIN"); - ASSERT_THROW(Interpret("DROP SCHEMA ON :label"), memgraph::query::ConstraintInMulticommandTxException); + ASSERT_THROW(Interpret("DROP SCHEMA ON :label"), memgraph::query::v2::ConstraintInMulticommandTxException); Interpret("ROLLBACK"); } TEST_F(InterpreterTest, SchemaTestCreateAndShow) { // Empty schema type map should result with syntax exception. - ASSERT_THROW(Interpret("CREATE SCHEMA ON :label();"), memgraph::query::SyntaxException); + ASSERT_THROW(Interpret("CREATE SCHEMA ON :label();"), memgraph::query::v2::SyntaxException); // Duplicate properties are should also cause an exception - ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING, name STRING);"), memgraph::query::SemanticException); - ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING, name INTEGER);"), memgraph::query::SemanticException); + ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING, name STRING);"), memgraph::query::v2::SemanticException); + ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING, name INTEGER);"), + memgraph::query::v2::SemanticException); { // Cannot create same schema twice Interpret("CREATE SCHEMA ON :label(name STRING, age INTEGER)"); - ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING);"), memgraph::query::QueryException); + ASSERT_THROW(Interpret("CREATE SCHEMA ON :label(name STRING);"), memgraph::query::v2::QueryException); } // Show schema { @@ -1588,9 +1597,9 @@ TEST_F(InterpreterTest, SchemaTestCreateAndShow) { TEST_F(InterpreterTest, SchemaTestCreateDropAndShow) { Interpret("CREATE SCHEMA ON :label(name STRING, age INTEGER)"); // Wrong syntax for dropping schema. - ASSERT_THROW(Interpret("DROP SCHEMA ON :label();"), memgraph::query::SyntaxException); + ASSERT_THROW(Interpret("DROP SCHEMA ON :label();"), memgraph::query::v2::SyntaxException); // Cannot drop non existant schema. - ASSERT_THROW(Interpret("DROP SCHEMA ON :label1;"), memgraph::query::QueryException); + ASSERT_THROW(Interpret("DROP SCHEMA ON :label1;"), memgraph::query::v2::QueryException); // Create Schema and Drop auto get_number_of_schemas = [this]() { diff --git a/tests/unit/query_v2_query_common.hpp b/tests/unit/query_v2_query_common.hpp new file mode 100644 index 000000000..5471ff42a --- /dev/null +++ b/tests/unit/query_v2_query_common.hpp @@ -0,0 +1,594 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +/// @file +/// This file provides macros for easier construction of openCypher query AST. +/// The usage of macros is very similar to how one would write openCypher. For +/// example: +/// +/// AstStorage storage; // Macros rely on storage being in scope. +/// // PROPERTY_LOOKUP and PROPERTY_PAIR macros +/// // rely on a DbAccessor *reference* named dba. +/// database::GraphDb db; +/// auto dba_ptr = db.Access(); +/// auto &dba = *dba_ptr; +/// +/// QUERY(MATCH(PATTERN(NODE("n"), EDGE("e"), NODE("m"))), +/// WHERE(LESS(PROPERTY_LOOKUP("e", edge_prop), LITERAL(3))), +/// RETURN(SUM(PROPERTY_LOOKUP("m", prop)), AS("sum"), +/// ORDER_BY(IDENT("sum")), +/// SKIP(ADD(LITERAL(1), LITERAL(2))))); +/// +/// Each of the macros is accompanied by a function. The functions use overload +/// resolution and template magic to provide a type safe way of constructing +/// queries. Although the functions can be used by themselves, it is more +/// convenient to use the macros. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/frontend/ast/pretty_print.hpp" +#include "storage/v3/id_types.hpp" +#include "utils/string.hpp" + +namespace memgraph::query::v2 { + +namespace test_common { + +auto ToIntList(const TypedValue &t) { + std::vector list; + for (auto x : t.ValueList()) { + list.push_back(x.ValueInt()); + } + return list; +}; + +auto ToIntMap(const TypedValue &t) { + std::map map; + for (const auto &kv : t.ValueMap()) map.emplace(kv.first, kv.second.ValueInt()); + return map; +}; + +std::string ToString(Expression *expr) { + std::ostringstream ss; + PrintExpression(expr, &ss); + return ss.str(); +} + +std::string ToString(NamedExpression *expr) { + std::ostringstream ss; + PrintExpression(expr, &ss); + return ss.str(); +} + +// Custom types for ORDER BY, SKIP, LIMIT, ON MATCH and ON CREATE expressions, +// so that they can be used to resolve function calls. +struct OrderBy { + std::vector expressions; +}; +struct Skip { + Expression *expression = nullptr; +}; +struct Limit { + Expression *expression = nullptr; +}; +struct OnMatch { + std::vector set; +}; +struct OnCreate { + std::vector set; +}; + +// Helper functions for filling the OrderBy with expressions. +auto FillOrderBy(OrderBy &order_by, Expression *expression, Ordering ordering = Ordering::ASC) { + order_by.expressions.push_back({ordering, expression}); +} +template +auto FillOrderBy(OrderBy &order_by, Expression *expression, Ordering ordering, T... rest) { + FillOrderBy(order_by, expression, ordering); + FillOrderBy(order_by, rest...); +} +template +auto FillOrderBy(OrderBy &order_by, Expression *expression, T... rest) { + FillOrderBy(order_by, expression); + FillOrderBy(order_by, rest...); +} + +/// Create OrderBy expressions. +/// +/// The supported combination of arguments is: (Expression, [Ordering])+ +/// Since the Ordering is optional, by default it is ascending. +template +auto GetOrderBy(T... exprs) { + OrderBy order_by; + FillOrderBy(order_by, exprs...); + return order_by; +} + +/// Create PropertyLookup with given name and property. +/// +/// Name is used to create the Identifier which is used for property lookup. +template +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &dba, const std::string &name, + memgraph::storage::v3::PropertyId property) { + return storage.Create(storage.Create(name), + storage.GetPropertyIx(dba.PropertyToName(property))); +} + +template +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &dba, Expression *expr, + memgraph::storage::v3::PropertyId property) { + return storage.Create(expr, storage.GetPropertyIx(dba.PropertyToName(property))); +} + +template +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &dba, Expression *expr, const std::string &property) { + return storage.Create(expr, storage.GetPropertyIx(property)); +} + +template +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &, const std::string &name, + const std::pair &prop_pair) { + return storage.Create(storage.Create(name), storage.GetPropertyIx(prop_pair.first)); +} + +template +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &, Expression *expr, + const std::pair &prop_pair) { + return storage.Create(expr, storage.GetPropertyIx(prop_pair.first)); +} + +/// Create an EdgeAtom with given name, direction and edge_type. +/// +/// Name is used to create the Identifier which is assigned to the edge. +auto GetEdge(AstStorage &storage, const std::string &name, EdgeAtom::Direction dir = EdgeAtom::Direction::BOTH, + const std::vector &edge_types = {}) { + std::vector types; + types.reserve(edge_types.size()); + for (const auto &type : edge_types) { + types.push_back(storage.GetEdgeTypeIx(type)); + } + return storage.Create(storage.Create(name), EdgeAtom::Type::SINGLE, dir, types); +} + +/// Create a variable length expansion EdgeAtom with given name, direction and +/// edge_type. +/// +/// Name is used to create the Identifier which is assigned to the edge. +auto GetEdgeVariable(AstStorage &storage, const std::string &name, EdgeAtom::Type type = EdgeAtom::Type::DEPTH_FIRST, + EdgeAtom::Direction dir = EdgeAtom::Direction::BOTH, + const std::vector &edge_types = {}, Identifier *flambda_inner_edge = nullptr, + Identifier *flambda_inner_node = nullptr, Identifier *wlambda_inner_edge = nullptr, + Identifier *wlambda_inner_node = nullptr, Expression *wlambda_expression = nullptr, + Identifier *total_weight = nullptr) { + std::vector types; + types.reserve(edge_types.size()); + for (const auto &type : edge_types) { + types.push_back(storage.GetEdgeTypeIx(type)); + } + auto r_val = storage.Create(storage.Create(name), type, dir, types); + + r_val->filter_lambda_.inner_edge = + flambda_inner_edge ? flambda_inner_edge : storage.Create(memgraph::utils::RandomString(20)); + r_val->filter_lambda_.inner_node = + flambda_inner_node ? flambda_inner_node : storage.Create(memgraph::utils::RandomString(20)); + + if (type == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) { + r_val->weight_lambda_.inner_edge = + wlambda_inner_edge ? wlambda_inner_edge : storage.Create(memgraph::utils::RandomString(20)); + r_val->weight_lambda_.inner_node = + wlambda_inner_node ? wlambda_inner_node : storage.Create(memgraph::utils::RandomString(20)); + r_val->weight_lambda_.expression = + wlambda_expression ? wlambda_expression : storage.Create(1); + + r_val->total_weight_ = total_weight; + } + + return r_val; +} + +/// Create a NodeAtom with given name and label. +/// +/// Name is used to create the Identifier which is assigned to the node. +auto GetNode(AstStorage &storage, const std::string &name, std::optional label = std::nullopt) { + auto node = storage.Create(storage.Create(name)); + if (label) node->labels_.emplace_back(storage.GetLabelIx(*label)); + return node; +} + +/// Create a Pattern with given atoms. +auto GetPattern(AstStorage &storage, std::vector atoms) { + auto pattern = storage.Create(); + pattern->identifier_ = storage.Create(memgraph::utils::RandomString(20), false); + pattern->atoms_.insert(pattern->atoms_.begin(), atoms.begin(), atoms.end()); + return pattern; +} + +/// Create a Pattern with given name and atoms. +auto GetPattern(AstStorage &storage, const std::string &name, std::vector atoms) { + auto pattern = storage.Create(); + pattern->identifier_ = storage.Create(name, true); + pattern->atoms_.insert(pattern->atoms_.begin(), atoms.begin(), atoms.end()); + return pattern; +} + +/// This function fills an AST node which with given patterns. +/// +/// The function is most commonly used to create Match and Create clauses. +template +auto GetWithPatterns(TWithPatterns *with_patterns, std::vector patterns) { + with_patterns->patterns_.insert(with_patterns->patterns_.begin(), patterns.begin(), patterns.end()); + return with_patterns; +} + +/// Create a query with given clauses. + +auto GetSingleQuery(SingleQuery *single_query, Clause *clause) { + single_query->clauses_.emplace_back(clause); + return single_query; +} +auto GetSingleQuery(SingleQuery *single_query, Match *match, Where *where) { + match->where_ = where; + single_query->clauses_.emplace_back(match); + return single_query; +} +auto GetSingleQuery(SingleQuery *single_query, With *with, Where *where) { + with->where_ = where; + single_query->clauses_.emplace_back(with); + return single_query; +} +template +auto GetSingleQuery(SingleQuery *single_query, Match *match, Where *where, T *...clauses) { + match->where_ = where; + single_query->clauses_.emplace_back(match); + return GetSingleQuery(single_query, clauses...); +} +template +auto GetSingleQuery(SingleQuery *single_query, With *with, Where *where, T *...clauses) { + with->where_ = where; + single_query->clauses_.emplace_back(with); + return GetSingleQuery(single_query, clauses...); +} + +template +auto GetSingleQuery(SingleQuery *single_query, Clause *clause, T *...clauses) { + single_query->clauses_.emplace_back(clause); + return GetSingleQuery(single_query, clauses...); +} + +auto GetCypherUnion(CypherUnion *cypher_union, SingleQuery *single_query) { + cypher_union->single_query_ = single_query; + return cypher_union; +} + +auto GetQuery(AstStorage &storage, SingleQuery *single_query) { + auto *query = storage.Create(); + query->single_query_ = single_query; + return query; +} + +template +auto GetQuery(AstStorage &storage, SingleQuery *single_query, T *...cypher_unions) { + auto *query = storage.Create(); + query->single_query_ = single_query; + query->cypher_unions_ = std::vector{cypher_unions...}; + return query; +} + +// Helper functions for constructing RETURN and WITH clauses. +void FillReturnBody(AstStorage &, ReturnBody &body, NamedExpression *named_expr) { + body.named_expressions.emplace_back(named_expr); +} +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name) { + if (name == "*") { + body.all_identifiers = true; + } else { + auto *ident = storage.Create(name); + auto *named_expr = storage.Create(name, ident); + body.named_expressions.emplace_back(named_expr); + } +} +void FillReturnBody(AstStorage &, ReturnBody &body, Limit limit) { body.limit = limit.expression; } +void FillReturnBody(AstStorage &, ReturnBody &body, Skip skip, Limit limit = Limit{}) { + body.skip = skip.expression; + body.limit = limit.expression; +} +void FillReturnBody(AstStorage &, ReturnBody &body, OrderBy order_by, Limit limit = Limit{}) { + body.order_by = order_by.expressions; + body.limit = limit.expression; +} +void FillReturnBody(AstStorage &, ReturnBody &body, OrderBy order_by, Skip skip, Limit limit = Limit{}) { + body.order_by = order_by.expressions; + body.skip = skip.expression; + body.limit = limit.expression; +} +void FillReturnBody(AstStorage &, ReturnBody &body, Expression *expr, NamedExpression *named_expr) { + // This overload supports `RETURN(expr, AS(name))` construct, since + // NamedExpression does not inherit Expression. + named_expr->expression_ = expr; + body.named_expressions.emplace_back(named_expr); +} +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name, NamedExpression *named_expr) { + named_expr->expression_ = storage.Create(name); + body.named_expressions.emplace_back(named_expr); +} +template +void FillReturnBody(AstStorage &storage, ReturnBody &body, Expression *expr, NamedExpression *named_expr, T... rest) { + named_expr->expression_ = expr; + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} +template +void FillReturnBody(AstStorage &storage, ReturnBody &body, NamedExpression *named_expr, T... rest) { + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} +template +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name, NamedExpression *named_expr, + T... rest) { + named_expr->expression_ = storage.Create(name); + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} +template +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name, T... rest) { + auto *ident = storage.Create(name); + auto *named_expr = storage.Create(name, ident); + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} + +/// Create the return clause with given expressions. +/// +/// The supported expression combination of arguments is: +/// +/// (String | NamedExpression | (Expression NamedExpression))+ +/// [OrderBy] [Skip] [Limit] +/// +/// When the pair (Expression NamedExpression) is given, the Expression will be +/// moved inside the NamedExpression. This is done, so that the constructs like +/// RETURN(expr, AS("name"), ...) are supported. Taking a String is a shorthand +/// for RETURN(IDENT(string), AS(string), ....). +/// +/// @sa GetWith +template +auto GetReturn(AstStorage &storage, bool distinct, T... exprs) { + auto ret = storage.Create(); + ret->body_.distinct = distinct; + FillReturnBody(storage, ret->body_, exprs...); + return ret; +} + +/// Create the with clause with given expressions. +/// +/// The supported expression combination is the same as for @c GetReturn. +/// +/// @sa GetReturn +template +auto GetWith(AstStorage &storage, bool distinct, T... exprs) { + auto with = storage.Create(); + with->body_.distinct = distinct; + FillReturnBody(storage, with->body_, exprs...); + return with; +} + +/// Create the UNWIND clause with given named expression. +auto GetUnwind(AstStorage &storage, NamedExpression *named_expr) { + return storage.Create(named_expr); +} +auto GetUnwind(AstStorage &storage, Expression *expr, NamedExpression *as) { + as->expression_ = expr; + return GetUnwind(storage, as); +} + +/// Create the delete clause with given named expressions. +auto GetDelete(AstStorage &storage, std::vector exprs, bool detach = false) { + auto del = storage.Create(); + del->expressions_.insert(del->expressions_.begin(), exprs.begin(), exprs.end()); + del->detach_ = detach; + return del; +} + +/// Create a set property clause for given property lookup and the right hand +/// side expression. +auto GetSet(AstStorage &storage, PropertyLookup *prop_lookup, Expression *expr) { + return storage.Create(prop_lookup, expr); +} + +/// Create a set properties clause for given identifier name and the right hand +/// side expression. +auto GetSet(AstStorage &storage, const std::string &name, Expression *expr, bool update = false) { + return storage.Create(storage.Create(name), expr, update); +} + +/// Create a set labels clause for given identifier name and labels. +auto GetSet(AstStorage &storage, const std::string &name, std::vector label_names) { + std::vector labels; + labels.reserve(label_names.size()); + for (const auto &label : label_names) { + labels.push_back(storage.GetLabelIx(label)); + } + return storage.Create(storage.Create(name), labels); +} + +/// Create a remove property clause for given property lookup +auto GetRemove(AstStorage &storage, PropertyLookup *prop_lookup) { return storage.Create(prop_lookup); } + +/// Create a remove labels clause for given identifier name and labels. +auto GetRemove(AstStorage &storage, const std::string &name, std::vector label_names) { + std::vector labels; + labels.reserve(label_names.size()); + for (const auto &label : label_names) { + labels.push_back(storage.GetLabelIx(label)); + } + return storage.Create(storage.Create(name), labels); +} + +/// Create a Merge clause for given Pattern with optional OnMatch and OnCreate +/// parts. +auto GetMerge(AstStorage &storage, Pattern *pattern, OnCreate on_create = OnCreate{}) { + auto *merge = storage.Create(); + merge->pattern_ = pattern; + merge->on_create_ = on_create.set; + return merge; +} +auto GetMerge(AstStorage &storage, Pattern *pattern, OnMatch on_match, OnCreate on_create = OnCreate{}) { + auto *merge = storage.Create(); + merge->pattern_ = pattern; + merge->on_match_ = on_match.set; + merge->on_create_ = on_create.set; + return merge; +} + +auto GetCallProcedure(AstStorage &storage, std::string procedure_name, + std::vector arguments = {}) { + auto *call_procedure = storage.Create(); + call_procedure->procedure_name_ = std::move(procedure_name); + call_procedure->arguments_ = std::move(arguments); + return call_procedure; +} + +/// Create the FOREACH clause with given named expression. +auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vector &clauses) { + return storage.Create(named_expr, clauses); +} + +} // namespace test_common + +} // namespace memgraph::query::v2 + +/// All the following macros implicitly pass `storage` variable to functions. +/// You need to have `AstStorage storage;` somewhere in scope to use them. +/// Refer to function documentation to see what the macro does. +/// +/// Example usage: +/// +/// // Create MATCH (n) -[r]- (m) RETURN m AS new_name +/// AstStorage storage; +/// auto query = QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), +/// RETURN(NEXPR("new_name"), IDENT("m"))); +#define NODE(...) memgraph::query::v2::test_common::GetNode(storage, __VA_ARGS__) +#define EDGE(...) memgraph::query::v2::test_common::GetEdge(storage, __VA_ARGS__) +#define EDGE_VARIABLE(...) memgraph::query::v2::test_common::GetEdgeVariable(storage, __VA_ARGS__) +#define PATTERN(...) memgraph::query::v2::test_common::GetPattern(storage, {__VA_ARGS__}) +#define NAMED_PATTERN(name, ...) memgraph::query::v2::test_common::GetPattern(storage, name, {__VA_ARGS__}) +#define OPTIONAL_MATCH(...) \ + memgraph::query::v2::test_common::GetWithPatterns(storage.Create(true), {__VA_ARGS__}) +#define MATCH(...) \ + memgraph::query::v2::test_common::GetWithPatterns(storage.Create(), {__VA_ARGS__}) +#define WHERE(expr) storage.Create((expr)) +#define CREATE(...) \ + memgraph::query::v2::test_common::GetWithPatterns(storage.Create(), {__VA_ARGS__}) +#define IDENT(...) storage.Create(__VA_ARGS__) +#define LITERAL(val) storage.Create((val)) +#define LIST(...) \ + storage.Create(std::vector{__VA_ARGS__}) +#define MAP(...) \ + storage.Create( \ + std::unordered_map{__VA_ARGS__}) +#define PROPERTY_PAIR(property_name) std::make_pair(property_name, dba.NameToProperty(property_name)) +#define PROPERTY_LOOKUP(...) memgraph::query::v2::test_common::GetPropertyLookup(storage, dba, __VA_ARGS__) +#define PARAMETER_LOOKUP(token_position) storage.Create((token_position)) +#define NEXPR(name, expr) storage.Create((name), (expr)) +// AS is alternative to NEXPR which does not initialize NamedExpression with +// Expression. It should be used with RETURN or WITH. For example: +// RETURN(IDENT("n"), AS("n")) vs. RETURN(NEXPR("n", IDENT("n"))). +#define AS(name) storage.Create((name)) +#define RETURN(...) memgraph::query::v2::test_common::GetReturn(storage, false, __VA_ARGS__) +#define WITH(...) memgraph::query::v2::test_common::GetWith(storage, false, __VA_ARGS__) +#define RETURN_DISTINCT(...) memgraph::query::v2::test_common::GetReturn(storage, true, __VA_ARGS__) +#define WITH_DISTINCT(...) memgraph::query::v2::test_common::GetWith(storage, true, __VA_ARGS__) +#define UNWIND(...) memgraph::query::v2::test_common::GetUnwind(storage, __VA_ARGS__) +#define ORDER_BY(...) memgraph::query::v2::test_common::GetOrderBy(__VA_ARGS__) +#define SKIP(expr) \ + memgraph::query::v2::test_common::Skip { (expr) } +#define LIMIT(expr) \ + memgraph::query::v2::test_common::Limit { (expr) } +#define DELETE(...) memgraph::query::v2::test_common::GetDelete(storage, {__VA_ARGS__}) +#define DETACH_DELETE(...) memgraph::query::v2::test_common::GetDelete(storage, {__VA_ARGS__}, true) +#define SET(...) memgraph::query::v2::test_common::GetSet(storage, __VA_ARGS__) +#define REMOVE(...) memgraph::query::v2::test_common::GetRemove(storage, __VA_ARGS__) +#define MERGE(...) memgraph::query::v2::test_common::GetMerge(storage, __VA_ARGS__) +#define ON_MATCH(...) \ + memgraph::query::v2::test_common::OnMatch { \ + std::vector { __VA_ARGS__ } \ + } +#define ON_CREATE(...) \ + memgraph::query::v2::test_common::OnCreate { \ + std::vector { __VA_ARGS__ } \ + } +#define CREATE_INDEX_ON(label, property) \ + storage.Create(memgraph::query::v2::IndexQuery::Action::CREATE, (label), \ + std::vector{(property)}) +#define QUERY(...) memgraph::query::v2::test_common::GetQuery(storage, __VA_ARGS__) +#define SINGLE_QUERY(...) memgraph::query::v2::test_common::GetSingleQuery(storage.Create(), __VA_ARGS__) +#define UNION(...) memgraph::query::v2::test_common::GetCypherUnion(storage.Create(true), __VA_ARGS__) +#define UNION_ALL(...) memgraph::query::v2::test_common::GetCypherUnion(storage.Create(false), __VA_ARGS__) +#define FOREACH(...) memgraph::query::v2::test_common::GetForeach(storage, __VA_ARGS__) +// Various operators +#define NOT(expr) storage.Create((expr)) +#define UPLUS(expr) storage.Create((expr)) +#define UMINUS(expr) storage.Create((expr)) +#define IS_NULL(expr) storage.Create((expr)) +#define ADD(expr1, expr2) storage.Create((expr1), (expr2)) +#define LESS(expr1, expr2) storage.Create((expr1), (expr2)) +#define LESS_EQ(expr1, expr2) storage.Create((expr1), (expr2)) +#define GREATER(expr1, expr2) storage.Create((expr1), (expr2)) +#define GREATER_EQ(expr1, expr2) storage.Create((expr1), (expr2)) +#define SUM(expr) \ + storage.Create((expr), nullptr, memgraph::query::v2::Aggregation::Op::SUM) +#define COUNT(expr) \ + storage.Create((expr), nullptr, memgraph::query::v2::Aggregation::Op::COUNT) +#define AVG(expr) \ + storage.Create((expr), nullptr, memgraph::query::v2::Aggregation::Op::AVG) +#define COLLECT_LIST(expr) \ + storage.Create((expr), nullptr, memgraph::query::v2::Aggregation::Op::COLLECT_LIST) +#define EQ(expr1, expr2) storage.Create((expr1), (expr2)) +#define NEQ(expr1, expr2) storage.Create((expr1), (expr2)) +#define AND(expr1, expr2) storage.Create((expr1), (expr2)) +#define OR(expr1, expr2) storage.Create((expr1), (expr2)) +#define IN_LIST(expr1, expr2) storage.Create((expr1), (expr2)) +#define IF(cond, then, else) storage.Create((cond), (then), (else)) +// Function call +#define FN(function_name, ...) \ + storage.Create(memgraph::utils::ToUpperCase(function_name), \ + std::vector{__VA_ARGS__}) +// List slicing +#define SLICE(list, lower_bound, upper_bound) \ + storage.Create(list, lower_bound, upper_bound) +// all(variable IN list WHERE predicate) +#define ALL(variable, list, where) \ + storage.Create(storage.Create(variable), list, where) +#define SINGLE(variable, list, where) \ + storage.Create(storage.Create(variable), list, where) +#define ANY(variable, list, where) \ + storage.Create(storage.Create(variable), list, where) +#define NONE(variable, list, where) \ + storage.Create(storage.Create(variable), list, where) +#define REDUCE(accumulator, initializer, variable, list, expr) \ + storage.Create(storage.Create(accumulator), \ + initializer, storage.Create(variable), \ + list, expr) +#define COALESCE(...) \ + storage.Create(std::vector{__VA_ARGS__}) +#define EXTRACT(variable, list, expr) \ + storage.Create(storage.Create(variable), list, expr) +#define AUTH_QUERY(action, user, role, user_or_role, password, privileges) \ + storage.Create((action), (user), (role), (user_or_role), password, (privileges)) +#define DROP_USER(usernames) storage.Create((usernames)) +#define CALL_PROCEDURE(...) memgraph::query::v2::test_common::GetCallProcedure(storage, __VA_ARGS__) diff --git a/tests/unit/query_v2_query_plan_accumulate_aggregate.cpp b/tests/unit/query_v2_query_plan_accumulate_aggregate.cpp index 224a3cc69..784b45a8d 100644 --- a/tests/unit/query_v2_query_plan_accumulate_aggregate.cpp +++ b/tests/unit/query_v2_query_plan_accumulate_aggregate.cpp @@ -18,16 +18,17 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "query/context.hpp" -#include "query/exceptions.hpp" -#include "query/plan/operator.hpp" -#include "query_plan_common.hpp" -#include "storage/v2/property_value.hpp" +#include "query/v2/context.hpp" +#include "query/v2/exceptions.hpp" +#include "query/v2/plan/operator.hpp" +#include "query_v2_query_plan_common.hpp" +#include "storage/v3/property_value.hpp" +#include "storage/v3/schemas.hpp" -using namespace memgraph::query; -using namespace memgraph::query::plan; -using memgraph::query::test_common::ToIntList; -using memgraph::query::test_common::ToIntMap; +using namespace memgraph::query::v2; +using namespace memgraph::query::v2::plan; +using test_common::ToIntList; +using test_common::ToIntMap; using testing::UnorderedElementsAre; namespace memgraph::query::v2::tests { @@ -35,12 +36,12 @@ namespace memgraph::query::v2::tests { class QueryPlanAccumulateAggregateTest : public ::testing::Test { protected: void SetUp() override { - ASSERT_TRUE(db.CreateSchema(label, {storage::SchemaProperty{property, common::SchemaType::INT}})); + ASSERT_TRUE(db.CreateSchema(label, {storage::v3::SchemaProperty{property, common::SchemaType::INT}})); } - storage::Storage db; - const storage::LabelId label{db.NameToLabel("label")}; - const storage::PropertyId property{db.NameToProperty("property")}; + storage::v3::Storage db; + const storage::v3::LabelId label{db.NameToLabel("label")}; + const storage::v3::PropertyId property{db.NameToProperty("property")}; }; TEST_F(QueryPlanAccumulateAggregateTest, Accumulate) { @@ -55,10 +56,10 @@ TEST_F(QueryPlanAccumulateAggregateTest, Accumulate) { DbAccessor dba(&storage_dba); auto prop = dba.NameToProperty("x"); - auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); - ASSERT_TRUE(v1.SetProperty(prop, storage::PropertyValue(0)).HasValue()); - auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); - ASSERT_TRUE(v2.SetProperty(prop, storage::PropertyValue(0)).HasValue()); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(v1.SetProperty(prop, storage::v3::PropertyValue(0)).HasValue()); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + ASSERT_TRUE(v2.SetProperty(prop, storage::v3::PropertyValue(0)).HasValue()); ASSERT_TRUE(dba.InsertEdge(&v1, &v2, dba.NameToEdgeType("T")).HasValue()); dba.AdvanceCommand(); @@ -67,7 +68,7 @@ TEST_F(QueryPlanAccumulateAggregateTest, Accumulate) { auto n = MakeScanAll(storage, symbol_table, "n"); auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::BOTH, {}, "m", false, - storage::View::OLD); + storage::v3::View::OLD); auto one = LITERAL(1); auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); @@ -109,7 +110,7 @@ TEST_F(QueryPlanAccumulateAggregateTest, AccumulateAdvance) { NodeCreationInfo node; node.symbol = symbol_table.CreateSymbol("n", true); node.labels = {label}; - std::get>>(node.properties) + std::get>>(node.properties) .emplace_back(property, LITERAL(1)); auto create = std::make_shared(nullptr, node); auto accumulate = std::make_shared(create, std::vector{node.symbol}, advance); @@ -157,14 +158,14 @@ std::shared_ptr MakeAggregationProduce(std::shared_ptr class QueryPlanAggregateOps : public ::testing::Test { protected: void SetUp() override { - ASSERT_TRUE(db.CreateSchema(label, {storage::SchemaProperty{property, common::SchemaType::INT}})); + ASSERT_TRUE(db.CreateSchema(label, {storage::v3::SchemaProperty{property, common::SchemaType::INT}})); } - storage::Storage db; - storage::Storage::Accessor storage_dba{db.Access()}; + storage::v3::Storage db; + storage::v3::Storage::Accessor storage_dba{db.Access()}; DbAccessor dba{&storage_dba}; - storage::LabelId label = db.NameToLabel("label"); - storage::PropertyId property = db.NameToProperty("property"); - storage::PropertyId prop = db.NameToProperty("prop"); + storage::v3::LabelId label = db.NameToLabel("label"); + storage::v3::PropertyId property = db.NameToProperty("property"); + storage::v3::PropertyId prop = db.NameToProperty("prop"); AstStorage storage; SymbolTable symbol_table; @@ -173,18 +174,18 @@ class QueryPlanAggregateOps : public ::testing::Test { // setup is several nodes most of which have an int property set // we will take the sum, avg, min, max and count // we won't group by anything - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}) - ->SetProperty(prop, storage::PropertyValue(5)) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}) + ->SetProperty(prop, storage::v3::PropertyValue(5)) .HasValue()); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}) - ->SetProperty(prop, storage::PropertyValue(7)) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}) + ->SetProperty(prop, storage::v3::PropertyValue(7)) .HasValue()); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}) - ->SetProperty(prop, storage::PropertyValue(12)) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}) + ->SetProperty(prop, storage::v3::PropertyValue(12)) .HasValue()); // a missing property (null) gets ignored by all aggregations except // COUNT(*) - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(4)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(4)}}).HasValue()); dba.AdvanceCommand(); } @@ -305,9 +306,9 @@ TEST_F(QueryPlanAccumulateAggregateTest, AggregateGroupByValues) { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - // a vector of storage::PropertyValue to be set as property values on vertices + // a vector of storage::v3::PropertyValue to be set as property values on vertices // most of them should result in a distinct group (commented where not) - std::vector group_by_vals; + std::vector group_by_vals; group_by_vals.emplace_back(4); group_by_vals.emplace_back(7); group_by_vals.emplace_back(7.3); @@ -317,20 +318,22 @@ TEST_F(QueryPlanAccumulateAggregateTest, AggregateGroupByValues) { group_by_vals.emplace_back("1"); group_by_vals.emplace_back(true); group_by_vals.emplace_back(false); - group_by_vals.emplace_back(std::vector{storage::PropertyValue(1)}); - group_by_vals.emplace_back(std::vector{storage::PropertyValue(1), storage::PropertyValue(2)}); - group_by_vals.emplace_back(std::vector{storage::PropertyValue(2), storage::PropertyValue(1)}); - group_by_vals.emplace_back(storage::PropertyValue()); + group_by_vals.emplace_back(std::vector{storage::v3::PropertyValue(1)}); + group_by_vals.emplace_back( + std::vector{storage::v3::PropertyValue(1), storage::v3::PropertyValue(2)}); + group_by_vals.emplace_back( + std::vector{storage::v3::PropertyValue(2), storage::v3::PropertyValue(1)}); + group_by_vals.emplace_back(storage::v3::PropertyValue()); // should NOT result in another group because 7.0 == 7 group_by_vals.emplace_back(7.0); // should NOT result in another group group_by_vals.emplace_back( - std::vector{storage::PropertyValue(1), storage::PropertyValue(2.0)}); + std::vector{storage::v3::PropertyValue(1), storage::v3::PropertyValue(2.0)}); // generate a lot of vertices and set props on them auto prop = dba.NameToProperty("prop"); for (int i = 0; i < 1000; ++i) - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}) ->SetProperty(prop, group_by_vals[i % group_by_vals.size()]) .HasValue()); dba.AdvanceCommand(); @@ -371,10 +374,10 @@ TEST_F(QueryPlanAccumulateAggregateTest, AggregateMultipleGroupBy) { auto prop2 = dba.NameToProperty("prop2"); auto prop3 = dba.NameToProperty("prop3"); for (int i = 0; i < 2 * 3 * 5; ++i) { - auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(i)}}); - ASSERT_TRUE(v.SetProperty(prop1, storage::PropertyValue(static_cast(i % 2))).HasValue()); - ASSERT_TRUE(v.SetProperty(prop2, storage::PropertyValue(i % 3)).HasValue()); - ASSERT_TRUE(v.SetProperty(prop3, storage::PropertyValue("value" + std::to_string(i % 5))).HasValue()); + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i)}}); + ASSERT_TRUE(v.SetProperty(prop1, storage::v3::PropertyValue(static_cast(i % 2))).HasValue()); + ASSERT_TRUE(v.SetProperty(prop2, storage::v3::PropertyValue(i % 3)).HasValue()); + ASSERT_TRUE(v.SetProperty(prop3, storage::v3::PropertyValue("value" + std::to_string(i % 5))).HasValue()); } dba.AdvanceCommand(); @@ -396,7 +399,7 @@ TEST_F(QueryPlanAccumulateAggregateTest, AggregateMultipleGroupBy) { } TEST(QueryPlan, AggregateNoInput) { - storage::Storage db; + storage::v3::Storage db; auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); AstStorage storage; @@ -448,24 +451,24 @@ TEST_F(QueryPlanAccumulateAggregateTest, AggregateCountEdgeCases) { EXPECT_EQ(0, count()); // one vertex, no property set - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); dba.AdvanceCommand(); EXPECT_EQ(0, count()); // one vertex, property set - for (auto va : dba.Vertices(storage::View::OLD)) - ASSERT_TRUE(va.SetProperty(prop, storage::PropertyValue(42)).HasValue()); + for (auto va : dba.Vertices(storage::v3::View::OLD)) + ASSERT_TRUE(va.SetProperty(prop, storage::v3::PropertyValue(42)).HasValue()); dba.AdvanceCommand(); EXPECT_EQ(1, count()); // two vertices, one with property set - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}).HasValue()); dba.AdvanceCommand(); EXPECT_EQ(1, count()); // two vertices, both with property set - for (auto va : dba.Vertices(storage::View::OLD)) - ASSERT_TRUE(va.SetProperty(prop, storage::PropertyValue(42)).HasValue()); + for (auto va : dba.Vertices(storage::v3::View::OLD)) + ASSERT_TRUE(va.SetProperty(prop, storage::v3::PropertyValue(42)).HasValue()); dba.AdvanceCommand(); EXPECT_EQ(2, count()); } @@ -477,11 +480,11 @@ TEST_F(QueryPlanAccumulateAggregateTest, AggregateFirstValueTypes) { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); auto prop_string = dba.NameToProperty("string"); - ASSERT_TRUE(v1.SetProperty(prop_string, storage::PropertyValue("johhny")).HasValue()); + ASSERT_TRUE(v1.SetProperty(prop_string, storage::v3::PropertyValue("johhny")).HasValue()); auto prop_int = dba.NameToProperty("int"); - ASSERT_TRUE(v1.SetProperty(prop_int, storage::PropertyValue(12)).HasValue()); + ASSERT_TRUE(v1.SetProperty(prop_int, storage::v3::PropertyValue(12)).HasValue()); dba.AdvanceCommand(); AstStorage storage; @@ -531,18 +534,18 @@ TEST_F(QueryPlanAccumulateAggregateTest, AggregateTypes) { DbAccessor dba(&storage_dba); auto p1 = dba.NameToProperty("p1"); // has only string props - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}) - ->SetProperty(p1, storage::PropertyValue("string")) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}) + ->SetProperty(p1, storage::v3::PropertyValue("string")) .HasValue()); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}) - ->SetProperty(p1, storage::PropertyValue("str2")) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}) + ->SetProperty(p1, storage::v3::PropertyValue("str2")) .HasValue()); auto p2 = dba.NameToProperty("p2"); // combines int and bool - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}) - ->SetProperty(p2, storage::PropertyValue(42)) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}) + ->SetProperty(p2, storage::v3::PropertyValue(42)) .HasValue()); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}) - ->SetProperty(p2, storage::PropertyValue(true)) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}) + ->SetProperty(p2, storage::v3::PropertyValue(true)) .HasValue()); dba.AdvanceCommand(); @@ -589,18 +592,18 @@ TEST_F(QueryPlanAccumulateAggregateTest, AggregateTypes) { } TEST(QueryPlan, Unwind) { - storage::Storage db; + storage::v3::Storage db; auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); AstStorage storage; SymbolTable symbol_table; // UNWIND [ [1, true, "x"], [], ["bla"] ] AS x UNWIND x as y RETURN x, y - auto input_expr = storage.Create(std::vector{ - storage::PropertyValue(std::vector{ - storage::PropertyValue(1), storage::PropertyValue(true), storage::PropertyValue("x")}), - storage::PropertyValue(std::vector{}), - storage::PropertyValue(std::vector{storage::PropertyValue("bla")})}); + auto input_expr = storage.Create(std::vector{ + storage::v3::PropertyValue(std::vector{ + storage::v3::PropertyValue(1), storage::v3::PropertyValue(true), storage::v3::PropertyValue("x")}), + storage::v3::PropertyValue(std::vector{}), + storage::v3::PropertyValue(std::vector{storage::v3::PropertyValue("bla")})}); auto x = symbol_table.CreateSymbol("x", true); auto unwind_0 = std::make_shared(nullptr, input_expr, x); diff --git a/tests/unit/query_v2_query_plan_bag_semantics.cpp b/tests/unit/query_v2_query_plan_bag_semantics.cpp index 496f1dc9b..a07ced087 100644 --- a/tests/unit/query_v2_query_plan_bag_semantics.cpp +++ b/tests/unit/query_v2_query_plan_bag_semantics.cpp @@ -17,28 +17,28 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "query/context.hpp" -#include "query/exceptions.hpp" -#include "query/frontend/ast/ast.hpp" -#include "query/plan/operator.hpp" +#include "query/v2/context.hpp" +#include "query/v2/exceptions.hpp" +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/plan/operator.hpp" -#include "query_plan_common.hpp" -#include "storage/v2/property_value.hpp" +#include "query_v2_query_plan_common.hpp" +#include "storage/v3/property_value.hpp" -using namespace memgraph::query; -using namespace memgraph::query::plan; +using namespace memgraph::query::v2; +using namespace memgraph::query::v2::plan; -namespace memgraph::query::tests { +namespace memgraph::query::v2::tests { class QueryPlanBagSemanticsTest : public testing::Test { protected: void SetUp() override { - ASSERT_TRUE(db.CreateSchema(label, {storage::SchemaProperty{property, common::SchemaType::INT}})); + ASSERT_TRUE(db.CreateSchema(label, {storage::v3::SchemaProperty{property, common::SchemaType::INT}})); } - storage::Storage db; - const storage::LabelId label{db.NameToLabel("label")}; - const storage::PropertyId property{db.NameToProperty("property")}; + storage::v3::Storage db; + const storage::v3::LabelId label{db.NameToLabel("label")}; + const storage::v3::PropertyId property{db.NameToProperty("property")}; }; TEST_F(QueryPlanBagSemanticsTest, Skip) { @@ -54,20 +54,20 @@ TEST_F(QueryPlanBagSemanticsTest, Skip) { auto context = MakeContext(storage, symbol_table, &dba); EXPECT_EQ(0, PullAll(*skip, &context)); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); dba.AdvanceCommand(); EXPECT_EQ(0, PullAll(*skip, &context)); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}).HasValue()); dba.AdvanceCommand(); EXPECT_EQ(0, PullAll(*skip, &context)); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}).HasValue()); dba.AdvanceCommand(); EXPECT_EQ(1, PullAll(*skip, &context)); for (int i = 0; i < 10; ++i) { - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(i + 3)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i + 3)}}).HasValue()); } dba.AdvanceCommand(); EXPECT_EQ(11, PullAll(*skip, &context)); @@ -86,20 +86,20 @@ TEST_F(QueryPlanBagSemanticsTest, Limit) { auto context = MakeContext(storage, symbol_table, &dba); EXPECT_EQ(0, PullAll(*skip, &context)); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); dba.AdvanceCommand(); EXPECT_EQ(1, PullAll(*skip, &context)); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}).HasValue()); dba.AdvanceCommand(); EXPECT_EQ(2, PullAll(*skip, &context)); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}).HasValue()); dba.AdvanceCommand(); EXPECT_EQ(2, PullAll(*skip, &context)); for (int i = 0; i < 10; ++i) { - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(i + 3)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i + 3)}}).HasValue()); } dba.AdvanceCommand(); EXPECT_EQ(2, PullAll(*skip, &context)); @@ -111,8 +111,8 @@ TEST_F(QueryPlanBagSemanticsTest, CreateLimit) { // in the end we need to have 3 vertices in the db auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}).HasValue()); dba.AdvanceCommand(); AstStorage storage; @@ -122,14 +122,15 @@ TEST_F(QueryPlanBagSemanticsTest, CreateLimit) { NodeCreationInfo m; m.symbol = symbol_table.CreateSymbol("m", true); m.labels = {label}; - std::get>>(m.properties).emplace_back(property, LITERAL(3)); + std::get>>(m.properties) + .emplace_back(property, LITERAL(3)); auto c = std::make_shared(n.op_, m); auto skip = std::make_shared(c, LITERAL(1)); auto context = MakeContext(storage, symbol_table, &dba); EXPECT_EQ(1, PullAll(*skip, &context)); dba.AdvanceCommand(); - EXPECT_EQ(3, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(3, CountIterable(dba.Vertices(storage::v3::View::OLD))); } TEST_F(QueryPlanBagSemanticsTest, OrderBy) { @@ -141,33 +142,34 @@ TEST_F(QueryPlanBagSemanticsTest, OrderBy) { // contains a series of tests // each test defines the ordering a vector of values in the desired order - auto Null = storage::PropertyValue(); - std::vector>> orderable{ + auto Null = storage::v3::PropertyValue(); + std::vector>> orderable{ {Ordering::ASC, - {storage::PropertyValue(0), storage::PropertyValue(0), storage::PropertyValue(0.5), storage::PropertyValue(1), - storage::PropertyValue(2), storage::PropertyValue(12.6), storage::PropertyValue(42), Null, Null}}, + {storage::v3::PropertyValue(0), storage::v3::PropertyValue(0), storage::v3::PropertyValue(0.5), + storage::v3::PropertyValue(1), storage::v3::PropertyValue(2), storage::v3::PropertyValue(12.6), + storage::v3::PropertyValue(42), Null, Null}}, {Ordering::ASC, - {storage::PropertyValue(false), storage::PropertyValue(false), storage::PropertyValue(true), - storage::PropertyValue(true), Null, Null}}, + {storage::v3::PropertyValue(false), storage::v3::PropertyValue(false), storage::v3::PropertyValue(true), + storage::v3::PropertyValue(true), Null, Null}}, {Ordering::ASC, - {storage::PropertyValue("A"), storage::PropertyValue("B"), storage::PropertyValue("a"), - storage::PropertyValue("a"), storage::PropertyValue("aa"), storage::PropertyValue("ab"), - storage::PropertyValue("aba"), Null, Null}}, + {storage::v3::PropertyValue("A"), storage::v3::PropertyValue("B"), storage::v3::PropertyValue("a"), + storage::v3::PropertyValue("a"), storage::v3::PropertyValue("aa"), storage::v3::PropertyValue("ab"), + storage::v3::PropertyValue("aba"), Null, Null}}, {Ordering::DESC, - {Null, Null, storage::PropertyValue(33), storage::PropertyValue(33), storage::PropertyValue(32.5), - storage::PropertyValue(32), storage::PropertyValue(2.2), storage::PropertyValue(2.1), - storage::PropertyValue(0)}}, - {Ordering::DESC, {Null, storage::PropertyValue(true), storage::PropertyValue(false)}}, - {Ordering::DESC, {Null, storage::PropertyValue("zorro"), storage::PropertyValue("borro")}}}; + {Null, Null, storage::v3::PropertyValue(33), storage::v3::PropertyValue(33), storage::v3::PropertyValue(32.5), + storage::v3::PropertyValue(32), storage::v3::PropertyValue(2.2), storage::v3::PropertyValue(2.1), + storage::v3::PropertyValue(0)}}, + {Ordering::DESC, {Null, storage::v3::PropertyValue(true), storage::v3::PropertyValue(false)}}, + {Ordering::DESC, {Null, storage::v3::PropertyValue("zorro"), storage::v3::PropertyValue("borro")}}}; for (const auto &order_value_pair : orderable) { std::vector values; values.reserve(order_value_pair.second.size()); for (const auto &v : order_value_pair.second) values.emplace_back(v); // empty database - for (auto vertex : dba.Vertices(storage::View::OLD)) ASSERT_TRUE(dba.DetachRemoveVertex(&vertex).HasValue()); + for (auto vertex : dba.Vertices(storage::v3::View::OLD)) ASSERT_TRUE(dba.DetachRemoveVertex(&vertex).HasValue()); dba.AdvanceCommand(); - ASSERT_EQ(0, CountIterable(dba.Vertices(storage::View::OLD))); + ASSERT_EQ(0, CountIterable(dba.Vertices(storage::v3::View::OLD))); // take some effort to shuffle the values // because we are testing that something not ordered gets ordered @@ -183,8 +185,8 @@ TEST_F(QueryPlanBagSemanticsTest, OrderBy) { // create the vertices for (const auto &value : shuffled) { - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}) - ->SetProperty(prop, storage::PropertyValue(value)) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}) + ->SetProperty(prop, storage::v3::PropertyValue(value)) .HasValue()); } dba.AdvanceCommand(); @@ -221,9 +223,9 @@ TEST_F(QueryPlanBagSemanticsTest, OrderByMultiple) { for (int i = 0; i < N * N; ++i) prop_values.emplace_back(i % N, i / N); std::random_shuffle(prop_values.begin(), prop_values.end()); for (const auto &pair : prop_values) { - auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); - ASSERT_TRUE(v.SetProperty(p1, storage::PropertyValue(pair.first)).HasValue()); - ASSERT_TRUE(v.SetProperty(p2, storage::PropertyValue(pair.second)).HasValue()); + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(v.SetProperty(p1, storage::v3::PropertyValue(pair.first)).HasValue()); + ASSERT_TRUE(v.SetProperty(p2, storage::v3::PropertyValue(pair.second)).HasValue()); } dba.AdvanceCommand(); @@ -265,37 +267,37 @@ TEST_F(QueryPlanBagSemanticsTest, OrderByExceptions) { // a vector of pairs of typed values that should result // in an exception when trying to order on them - std::vector> exception_pairs{ - {storage::PropertyValue(42), storage::PropertyValue(true)}, - {storage::PropertyValue(42), storage::PropertyValue("bla")}, - {storage::PropertyValue(42), - storage::PropertyValue(std::vector{storage::PropertyValue(42)})}, - {storage::PropertyValue(true), storage::PropertyValue("bla")}, - {storage::PropertyValue(true), - storage::PropertyValue(std::vector{storage::PropertyValue(true)})}, - {storage::PropertyValue("bla"), - storage::PropertyValue(std::vector{storage::PropertyValue("bla")})}, + std::vector> exception_pairs{ + {storage::v3::PropertyValue(42), storage::v3::PropertyValue(true)}, + {storage::v3::PropertyValue(42), storage::v3::PropertyValue("bla")}, + {storage::v3::PropertyValue(42), + storage::v3::PropertyValue(std::vector{storage::v3::PropertyValue(42)})}, + {storage::v3::PropertyValue(true), storage::v3::PropertyValue("bla")}, + {storage::v3::PropertyValue(true), + storage::v3::PropertyValue(std::vector{storage::v3::PropertyValue(true)})}, + {storage::v3::PropertyValue("bla"), + storage::v3::PropertyValue(std::vector{storage::v3::PropertyValue("bla")})}, // illegal comparisons of same-type values - {storage::PropertyValue(std::vector{storage::PropertyValue(42)}), - storage::PropertyValue(std::vector{storage::PropertyValue(42)})}}; + {storage::v3::PropertyValue(std::vector{storage::v3::PropertyValue(42)}), + storage::v3::PropertyValue(std::vector{storage::v3::PropertyValue(42)})}}; for (const auto &pair : exception_pairs) { // empty database - for (auto vertex : dba.Vertices(storage::View::OLD)) ASSERT_TRUE(dba.DetachRemoveVertex(&vertex).HasValue()); + for (auto vertex : dba.Vertices(storage::v3::View::OLD)) ASSERT_TRUE(dba.DetachRemoveVertex(&vertex).HasValue()); dba.AdvanceCommand(); - ASSERT_EQ(0, CountIterable(dba.Vertices(storage::View::OLD))); + ASSERT_EQ(0, CountIterable(dba.Vertices(storage::v3::View::OLD))); // make two vertices, and set values - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}) ->SetProperty(prop, pair.first) .HasValue()); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}) ->SetProperty(prop, pair.second) .HasValue()); dba.AdvanceCommand(); - ASSERT_EQ(2, CountIterable(dba.Vertices(storage::View::OLD))); - for (const auto &va : dba.Vertices(storage::View::OLD)) - ASSERT_NE(va.GetProperty(storage::View::OLD, prop).GetValue().type(), storage::PropertyValue::Type::Null); + ASSERT_EQ(2, CountIterable(dba.Vertices(storage::v3::View::OLD))); + for (const auto &va : dba.Vertices(storage::v3::View::OLD)) + ASSERT_NE(va.GetProperty(storage::v3::View::OLD, prop).GetValue().type(), storage::v3::PropertyValue::Type::Null); // order by and expect an exception auto n = MakeScanAll(storage, symbol_table, "n"); @@ -306,4 +308,4 @@ TEST_F(QueryPlanBagSemanticsTest, OrderByExceptions) { EXPECT_THROW(PullAll(*order_by, &context), QueryRuntimeException); } } -} // namespace memgraph::query::tests +} // namespace memgraph::query::v2::tests diff --git a/tests/unit/query_v2_query_plan_common.hpp b/tests/unit/query_v2_query_plan_common.hpp new file mode 100644 index 000000000..7e535b1d5 --- /dev/null +++ b/tests/unit/query_v2_query_plan_common.hpp @@ -0,0 +1,225 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include +#include + +#include "query/v2/common.hpp" +#include "query/v2/context.hpp" +#include "query/v2/db_accessor.hpp" +#include "query/v2/frontend/semantic/symbol_table.hpp" +#include "query/v2/interpret/frame.hpp" +#include "query/v2/plan/operator.hpp" +#include "storage/v3/storage.hpp" +#include "utils/logging.hpp" + +#include "query_v2_query_common.hpp" + +using namespace memgraph::query::v2; +using namespace memgraph::query::v2::plan; + +using Bound = ScanAllByLabelPropertyRange::Bound; + +ExecutionContext MakeContext(const AstStorage &storage, const SymbolTable &symbol_table, + memgraph::query::v2::DbAccessor *dba) { + ExecutionContext context{dba}; + context.symbol_table = symbol_table; + context.evaluation_context.properties = NamesToProperties(storage.properties_, dba); + context.evaluation_context.labels = NamesToLabels(storage.labels_, dba); + return context; +} + +/** Helper function that collects all the results from the given Produce. */ +std::vector> CollectProduce(const Produce &produce, ExecutionContext *context) { + Frame frame(context->symbol_table.max_position()); + + // top level node in the operator tree is a produce (return) + // so stream out results + + // collect the symbols from the return clause + std::vector symbols; + for (auto named_expression : produce.named_expressions_) + symbols.emplace_back(context->symbol_table.at(*named_expression)); + + // stream out results + auto cursor = produce.MakeCursor(memgraph::utils::NewDeleteResource()); + std::vector> results; + while (cursor->Pull(frame, *context)) { + std::vector values; + for (auto &symbol : symbols) values.emplace_back(frame[symbol]); + results.emplace_back(values); + } + + return results; +} + +int PullAll(const LogicalOperator &logical_op, ExecutionContext *context) { + Frame frame(context->symbol_table.max_position()); + auto cursor = logical_op.MakeCursor(memgraph::utils::NewDeleteResource()); + int count = 0; + while (cursor->Pull(frame, *context)) count++; + return count; +} + +template +auto MakeProduce(std::shared_ptr input, TNamedExpressions... named_expressions) { + return std::make_shared(input, std::vector{named_expressions...}); +} + +struct ScanAllTuple { + NodeAtom *node_; + std::shared_ptr op_; + Symbol sym_; +}; + +/** + * Creates and returns a tuple of stuff for a scan-all starting + * from the node with the given name. + * + * Returns ScanAllTuple(node_atom, scan_all_logical_op, symbol). + */ +ScanAllTuple MakeScanAll(AstStorage &storage, SymbolTable &symbol_table, const std::string &identifier, + std::shared_ptr input = {nullptr}, + memgraph::storage::v3::View view = memgraph::storage::v3::View::OLD) { + auto node = NODE(identifier); + auto symbol = symbol_table.CreateSymbol(identifier, true); + node->identifier_->MapTo(symbol); + auto logical_op = std::make_shared(input, symbol, view); + return ScanAllTuple{node, logical_op, symbol}; +} + +ScanAllTuple MakeScanAllNew(AstStorage &storage, SymbolTable &symbol_table, const std::string &identifier, + std::shared_ptr input = {nullptr}, + memgraph::storage::v3::View view = memgraph::storage::v3::View::OLD) { + auto *node = NODE(identifier, "label"); + auto symbol = symbol_table.CreateSymbol(identifier, true); + node->identifier_->MapTo(symbol); + auto logical_op = std::make_shared(input, symbol, view); + return ScanAllTuple{node, logical_op, symbol}; +} + +/** + * Creates and returns a tuple of stuff for a scan-all starting + * from the node with the given name and label. + * + * Returns ScanAllTuple(node_atom, scan_all_logical_op, symbol). + */ +ScanAllTuple MakeScanAllByLabel(AstStorage &storage, SymbolTable &symbol_table, const std::string &identifier, + memgraph::storage::v3::LabelId label, + std::shared_ptr input = {nullptr}, + memgraph::storage::v3::View view = memgraph::storage::v3::View::OLD) { + auto node = NODE(identifier); + auto symbol = symbol_table.CreateSymbol(identifier, true); + node->identifier_->MapTo(symbol); + auto logical_op = std::make_shared(input, symbol, label, view); + return ScanAllTuple{node, logical_op, symbol}; +} + +/** + * Creates and returns a tuple of stuff for a scan-all starting from the node + * with the given name and label whose property values are in range. + * + * Returns ScanAllTuple(node_atom, scan_all_logical_op, symbol). + */ +ScanAllTuple MakeScanAllByLabelPropertyRange(AstStorage &storage, SymbolTable &symbol_table, std::string identifier, + memgraph::storage::v3::LabelId label, + memgraph::storage::v3::PropertyId property, + const std::string &property_name, std::optional lower_bound, + std::optional upper_bound, + std::shared_ptr input = {nullptr}, + memgraph::storage::v3::View view = memgraph::storage::v3::View::OLD) { + auto node = NODE(identifier); + auto symbol = symbol_table.CreateSymbol(identifier, true); + node->identifier_->MapTo(symbol); + auto logical_op = std::make_shared(input, symbol, label, property, property_name, + lower_bound, upper_bound, view); + return ScanAllTuple{node, logical_op, symbol}; +} + +/** + * Creates and returns a tuple of stuff for a scan-all starting from the node + * with the given name and label whose property value is equal to given value. + * + * Returns ScanAllTuple(node_atom, scan_all_logical_op, symbol). + */ +ScanAllTuple MakeScanAllByLabelPropertyValue(AstStorage &storage, SymbolTable &symbol_table, std::string identifier, + memgraph::storage::v3::LabelId label, + memgraph::storage::v3::PropertyId property, + const std::string &property_name, Expression *value, + std::shared_ptr input = {nullptr}, + memgraph::storage::v3::View view = memgraph::storage::v3::View::OLD) { + auto node = NODE(identifier); + auto symbol = symbol_table.CreateSymbol(identifier, true); + node->identifier_->MapTo(symbol); + auto logical_op = + std::make_shared(input, symbol, label, property, property_name, value, view); + return ScanAllTuple{node, logical_op, symbol}; +} + +struct ExpandTuple { + EdgeAtom *edge_; + Symbol edge_sym_; + NodeAtom *node_; + Symbol node_sym_; + std::shared_ptr op_; +}; + +ExpandTuple MakeExpand(AstStorage &storage, SymbolTable &symbol_table, std::shared_ptr input, + Symbol input_symbol, const std::string &edge_identifier, EdgeAtom::Direction direction, + const std::vector &edge_types, + const std::string &node_identifier, bool existing_node, memgraph::storage::v3::View view) { + auto edge = EDGE(edge_identifier, direction); + auto edge_sym = symbol_table.CreateSymbol(edge_identifier, true); + edge->identifier_->MapTo(edge_sym); + + auto node = NODE(node_identifier); + auto node_sym = symbol_table.CreateSymbol(node_identifier, true); + node->identifier_->MapTo(node_sym); + + auto op = + std::make_shared(input, input_symbol, node_sym, edge_sym, direction, edge_types, existing_node, view); + + return ExpandTuple{edge, edge_sym, node, node_sym, op}; +} + +struct UnwindTuple { + Symbol sym_; + std::shared_ptr op_; +}; + +UnwindTuple MakeUnwind(SymbolTable &symbol_table, const std::string &symbol_name, + std::shared_ptr input, Expression *input_expression) { + auto sym = symbol_table.CreateSymbol(symbol_name, true); + auto op = std::make_shared(input, input_expression, sym); + return UnwindTuple{sym, op}; +} + +template +auto CountIterable(TIterable &&iterable) { + uint64_t count = 0; + for (auto it = iterable.begin(); it != iterable.end(); ++it) { + ++count; + } + return count; +} + +inline uint64_t CountEdges(memgraph::query::v2::DbAccessor *dba, memgraph::storage::v3::View view) { + uint64_t count = 0; + for (auto vertex : dba->Vertices(view)) { + auto maybe_edges = vertex.OutEdges(view); + MG_ASSERT(maybe_edges.HasValue()); + count += CountIterable(*maybe_edges); + } + return count; +} diff --git a/tests/unit/query_v2_query_plan_create_set_remove_delete.cpp b/tests/unit/query_v2_query_plan_create_set_remove_delete.cpp index 81904c17e..723aec472 100644 --- a/tests/unit/query_v2_query_plan_create_set_remove_delete.cpp +++ b/tests/unit/query_v2_query_plan_create_set_remove_delete.cpp @@ -18,34 +18,34 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "query/context.hpp" -#include "query/db_accessor.hpp" -#include "query/exceptions.hpp" -#include "query/interpret/frame.hpp" -#include "query/plan/operator.hpp" +#include "query/v2/context.hpp" +#include "query/v2/db_accessor.hpp" +#include "query/v2/exceptions.hpp" +#include "query/v2/interpret/frame.hpp" +#include "query/v2/plan/operator.hpp" -#include "query_plan_common.hpp" -#include "storage/v2/id_types.hpp" -#include "storage/v2/property_value.hpp" -#include "storage/v2/schemas.hpp" -#include "storage/v2/storage.hpp" -#include "storage/v2/vertex.hpp" -#include "storage/v2/view.hpp" +#include "query_v2_query_plan_common.hpp" +#include "storage/v3/id_types.hpp" +#include "storage/v3/property_value.hpp" +#include "storage/v3/schemas.hpp" +#include "storage/v3/storage.hpp" +#include "storage/v3/vertex.hpp" +#include "storage/v3/view.hpp" -using namespace memgraph::query; -using namespace memgraph::query::plan; +using namespace memgraph::query::v2; +using namespace memgraph::query::v2::plan; namespace memgraph::query::tests { class QueryPlanCRUDTest : public testing::Test { protected: void SetUp() override { - ASSERT_TRUE(db.CreateSchema(label, {storage::SchemaProperty{property, common::SchemaType::INT}})); + ASSERT_TRUE(db.CreateSchema(label, {storage::v3::SchemaProperty{property, common::SchemaType::INT}})); } - storage::Storage db; - const storage::LabelId label{db.NameToLabel("label")}; - const storage::PropertyId property{db.NameToProperty("property")}; + storage::v3::Storage db; + const storage::v3::LabelId label{db.NameToLabel("label")}; + const storage::v3::PropertyId property{db.NameToProperty("property")}; }; TEST_F(QueryPlanCRUDTest, CreateNodeWithAttributes) { @@ -58,7 +58,7 @@ TEST_F(QueryPlanCRUDTest, CreateNodeWithAttributes) { NodeCreationInfo node; node.symbol = symbol_table.CreateSymbol("n", true); node.labels.emplace_back(label); - std::get>>(node.properties) + std::get>>(node.properties) .emplace_back(property, LITERAL(42)); auto create = std::make_shared(nullptr, node); @@ -68,18 +68,18 @@ TEST_F(QueryPlanCRUDTest, CreateNodeWithAttributes) { // count the number of vertices int vertex_count = 0; - for (auto vertex : dba.Vertices(storage::View::OLD)) { + for (auto vertex : dba.Vertices(storage::v3::View::OLD)) { vertex_count++; - auto maybe_labels = vertex.Labels(storage::View::OLD); + auto maybe_labels = vertex.Labels(storage::v3::View::OLD); ASSERT_TRUE(maybe_labels.HasValue()); const auto &labels = *maybe_labels; EXPECT_EQ(labels.size(), 0); - auto maybe_properties = vertex.Properties(storage::View::OLD); + auto maybe_properties = vertex.Properties(storage::v3::View::OLD); ASSERT_TRUE(maybe_properties.HasValue()); const auto &properties = *maybe_properties; EXPECT_EQ(properties.size(), 1); - auto maybe_prop = vertex.GetProperty(storage::View::OLD, property); + auto maybe_prop = vertex.GetProperty(storage::v3::View::OLD, property); ASSERT_TRUE(maybe_prop.HasValue()); auto prop_eq = TypedValue(*maybe_prop) == TypedValue(42); ASSERT_EQ(prop_eq.type(), TypedValue::Type::Bool); @@ -90,13 +90,13 @@ TEST_F(QueryPlanCRUDTest, CreateNodeWithAttributes) { TEST(QueryPlan, CreateReturn) { // test CREATE (n:Person {age: 42}) RETURN n, n.age - storage::Storage db; + storage::v3::Storage db; auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - storage::LabelId label = dba.NameToLabel("Person"); + storage::v3::LabelId label = dba.NameToLabel("Person"); auto property = PROPERTY_PAIR("property"); - db.CreateSchema(label, {storage::SchemaProperty{property.second, common::SchemaType::INT}}); + db.CreateSchema(label, {storage::v3::SchemaProperty{property.second, common::SchemaType::INT}}); AstStorage storage; SymbolTable symbol_table; @@ -104,7 +104,7 @@ TEST(QueryPlan, CreateReturn) { NodeCreationInfo node; node.symbol = symbol_table.CreateSymbol("n", true); node.labels.emplace_back(label); - std::get>>(node.properties) + std::get>>(node.properties) .emplace_back(property.second, LITERAL(42)); auto create = std::make_shared(nullptr, node); @@ -119,47 +119,47 @@ TEST(QueryPlan, CreateReturn) { EXPECT_EQ(1, results.size()); EXPECT_EQ(2, results[0].size()); EXPECT_EQ(TypedValue::Type::Vertex, results[0][0].type()); - auto maybe_labels = results[0][0].ValueVertex().Labels(storage::View::NEW); + auto maybe_labels = results[0][0].ValueVertex().Labels(storage::v3::View::NEW); EXPECT_EQ(maybe_labels->size(), 0); EXPECT_EQ(TypedValue::Type::Int, results[0][1].type()); EXPECT_EQ(42, results[0][1].ValueInt()); dba.AdvanceCommand(); - EXPECT_EQ(1, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::v3::View::OLD))); } TEST(QueryPlan, CreateExpand) { - storage::Storage db; + storage::v3::Storage db; auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - storage::LabelId label_node_1 = dba.NameToLabel("Node1"); - storage::LabelId label_node_2 = dba.NameToLabel("Node2"); + storage::v3::LabelId label_node_1 = dba.NameToLabel("Node1"); + storage::v3::LabelId label_node_2 = dba.NameToLabel("Node2"); auto property = PROPERTY_PAIR("property"); - storage::EdgeTypeId edge_type = dba.NameToEdgeType("edge_type"); - db.CreateSchema(label_node_1, {storage::SchemaProperty{property.second, common::SchemaType::INT}}); - db.CreateSchema(label_node_2, {storage::SchemaProperty{property.second, common::SchemaType::INT}}); + storage::v3::EdgeTypeId edge_type = dba.NameToEdgeType("edge_type"); + db.CreateSchema(label_node_1, {storage::v3::SchemaProperty{property.second, common::SchemaType::INT}}); + db.CreateSchema(label_node_2, {storage::v3::SchemaProperty{property.second, common::SchemaType::INT}}); SymbolTable symbol_table; AstStorage storage; auto test_create_path = [&](bool cycle, int expected_nodes_created, int expected_edges_created) { - int before_v = CountIterable(dba.Vertices(storage::View::OLD)); - int before_e = CountEdges(&dba, storage::View::OLD); + int before_v = CountIterable(dba.Vertices(storage::v3::View::OLD)); + int before_e = CountEdges(&dba, storage::v3::View::OLD); // data for the first node NodeCreationInfo n; n.symbol = symbol_table.CreateSymbol("n", true); n.labels.emplace_back(label_node_1); - std::get>>(n.properties) + std::get>>(n.properties) .emplace_back(property.second, LITERAL(1)); // data for the second node NodeCreationInfo m; m.symbol = cycle ? n.symbol : symbol_table.CreateSymbol("m", true); m.labels.emplace_back(label_node_2); - std::get>>(m.properties) + std::get>>(m.properties) .emplace_back(property.second, LITERAL(2)); EdgeCreationInfo r; @@ -173,37 +173,37 @@ TEST(QueryPlan, CreateExpand) { PullAll(*create_expand, &context); dba.AdvanceCommand(); - EXPECT_EQ(CountIterable(dba.Vertices(storage::View::OLD)) - before_v, expected_nodes_created); - EXPECT_EQ(CountEdges(&dba, storage::View::OLD) - before_e, expected_edges_created); + EXPECT_EQ(CountIterable(dba.Vertices(storage::v3::View::OLD)) - before_v, expected_nodes_created); + EXPECT_EQ(CountEdges(&dba, storage::v3::View::OLD) - before_e, expected_edges_created); }; test_create_path(false, 2, 1); test_create_path(true, 1, 1); - for (auto vertex : dba.Vertices(storage::View::OLD)) { - auto maybe_labels = vertex.Labels(storage::View::OLD); + for (auto vertex : dba.Vertices(storage::v3::View::OLD)) { + auto maybe_labels = vertex.Labels(storage::v3::View::OLD); MG_ASSERT(maybe_labels.HasValue()); const auto &labels = *maybe_labels; EXPECT_EQ(labels.size(), 0); - auto maybe_primary_label = vertex.PrimaryLabel(storage::View::OLD); + auto maybe_primary_label = vertex.PrimaryLabel(storage::v3::View::OLD); ASSERT_TRUE(maybe_primary_label.HasValue()); if (*maybe_primary_label == label_node_1) { // node created by first op - EXPECT_EQ(vertex.GetProperty(storage::View::OLD, property.second)->ValueInt(), 1); + EXPECT_EQ(vertex.GetProperty(storage::v3::View::OLD, property.second)->ValueInt(), 1); } else if (*maybe_primary_label == label_node_2) { // node create by expansion - EXPECT_EQ(vertex.GetProperty(storage::View::OLD, property.second)->ValueInt(), 2); + EXPECT_EQ(vertex.GetProperty(storage::v3::View::OLD, property.second)->ValueInt(), 2); } else { // should not happen FAIL(); } - for (auto vertex : dba.Vertices(storage::View::OLD)) { - auto maybe_edges = vertex.OutEdges(storage::View::OLD); + for (auto vertex : dba.Vertices(storage::v3::View::OLD)) { + auto maybe_edges = vertex.OutEdges(storage::v3::View::OLD); MG_ASSERT(maybe_edges.HasValue()); for (auto edge : *maybe_edges) { EXPECT_EQ(edge.EdgeType(), edge_type); - EXPECT_EQ(edge.GetProperty(storage::View::OLD, property.second)->ValueInt(), 3); + EXPECT_EQ(edge.GetProperty(storage::v3::View::OLD, property.second)->ValueInt(), 3); } } } @@ -213,9 +213,9 @@ TEST_F(QueryPlanCRUDTest, MatchCreateNode) { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}).HasValue()); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}).HasValue()); dba.AdvanceCommand(); SymbolTable symbol_table; @@ -227,38 +227,39 @@ TEST_F(QueryPlanCRUDTest, MatchCreateNode) { NodeCreationInfo m; m.symbol = symbol_table.CreateSymbol("m", true); m.labels = {label}; - std::get>>(m.properties).emplace_back(property, LITERAL(1)); + std::get>>(m.properties) + .emplace_back(property, LITERAL(1)); // creation op auto create_node = std::make_shared(n_scan_all.op_, m); - EXPECT_EQ(CountIterable(dba.Vertices(storage::View::OLD)), 3); + EXPECT_EQ(CountIterable(dba.Vertices(storage::v3::View::OLD)), 3); auto context = MakeContext(storage, symbol_table, &dba); PullAll(*create_node, &context); dba.AdvanceCommand(); - EXPECT_EQ(CountIterable(dba.Vertices(storage::View::OLD)), 6); + EXPECT_EQ(CountIterable(dba.Vertices(storage::v3::View::OLD)), 6); } TEST_F(QueryPlanCRUDTest, MatchCreateExpand) { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}).HasValue()); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}).HasValue()); dba.AdvanceCommand(); - // storage::LabelId label_node_1 = dba.NameToLabel("Node1"); - // storage::LabelId label_node_2 = dba.NameToLabel("Node2"); - // storage::PropertyId property = dba.NameToLabel("prop"); - storage::EdgeTypeId edge_type = dba.NameToEdgeType("edge_type"); + // storage::v3::LabelId label_node_1 = dba.NameToLabel("Node1"); + // storage::v3::LabelId label_node_2 = dba.NameToLabel("Node2"); + // storage::v3::PropertyId property = dba.NameToLabel("prop"); + storage::v3::EdgeTypeId edge_type = dba.NameToEdgeType("edge_type"); SymbolTable symbol_table; AstStorage storage; auto test_create_path = [&](bool cycle, int expected_nodes_created, int expected_edges_created) { - int before_v = CountIterable(dba.Vertices(storage::View::OLD)); - int before_e = CountEdges(&dba, storage::View::OLD); + int before_v = CountIterable(dba.Vertices(storage::v3::View::OLD)); + int before_e = CountEdges(&dba, storage::v3::View::OLD); // data for the first node auto n_scan_all = MakeScanAll(storage, symbol_table, "n"); @@ -267,7 +268,7 @@ TEST_F(QueryPlanCRUDTest, MatchCreateExpand) { NodeCreationInfo m; m.symbol = cycle ? n_scan_all.sym_ : symbol_table.CreateSymbol("m", true); m.labels = {label}; - std::get>>(m.properties) + std::get>>(m.properties) .emplace_back(property, LITERAL(1)); EdgeCreationInfo r; @@ -280,8 +281,8 @@ TEST_F(QueryPlanCRUDTest, MatchCreateExpand) { PullAll(*create_expand, &context); dba.AdvanceCommand(); - EXPECT_EQ(CountIterable(dba.Vertices(storage::View::OLD)) - before_v, expected_nodes_created); - EXPECT_EQ(CountEdges(&dba, storage::View::OLD) - before_e, expected_edges_created); + EXPECT_EQ(CountIterable(dba.Vertices(storage::v3::View::OLD)) - before_v, expected_nodes_created); + EXPECT_EQ(CountEdges(&dba, storage::v3::View::OLD) - before_e, expected_edges_created); }; test_create_path(false, 3, 3); @@ -295,15 +296,15 @@ TEST_F(QueryPlanCRUDTest, Delete) { // make a fully-connected (one-direction, no cycles) with 4 nodes std::vector vertices; for (int i = 0; i < 4; ++i) { - vertices.push_back(*dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(i)}})); + vertices.push_back(*dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i)}})); } auto type = dba.NameToEdgeType("type"); for (int j = 0; j < 4; ++j) for (int k = j + 1; k < 4; ++k) ASSERT_TRUE(dba.InsertEdge(&vertices[j], &vertices[k], type).HasValue()); dba.AdvanceCommand(); - EXPECT_EQ(4, CountIterable(dba.Vertices(storage::View::OLD))); - EXPECT_EQ(6, CountEdges(&dba, storage::View::OLD)); + EXPECT_EQ(4, CountIterable(dba.Vertices(storage::v3::View::OLD))); + EXPECT_EQ(6, CountEdges(&dba, storage::v3::View::OLD)); AstStorage storage; SymbolTable symbol_table; @@ -316,8 +317,8 @@ TEST_F(QueryPlanCRUDTest, Delete) { auto context = MakeContext(storage, symbol_table, &dba); EXPECT_THROW(PullAll(*delete_op, &context), QueryRuntimeException); dba.AdvanceCommand(); - EXPECT_EQ(4, CountIterable(dba.Vertices(storage::View::OLD))); - EXPECT_EQ(6, CountEdges(&dba, storage::View::OLD)); + EXPECT_EQ(4, CountIterable(dba.Vertices(storage::v3::View::OLD))); + EXPECT_EQ(6, CountEdges(&dba, storage::v3::View::OLD)); } // detach delete a single vertex @@ -329,22 +330,22 @@ TEST_F(QueryPlanCRUDTest, Delete) { auto context = MakeContext(storage, symbol_table, &dba); delete_op->MakeCursor(utils::NewDeleteResource())->Pull(frame, context); dba.AdvanceCommand(); - EXPECT_EQ(3, CountIterable(dba.Vertices(storage::View::OLD))); - EXPECT_EQ(3, CountEdges(&dba, storage::View::OLD)); + EXPECT_EQ(3, CountIterable(dba.Vertices(storage::v3::View::OLD))); + EXPECT_EQ(3, CountEdges(&dba, storage::v3::View::OLD)); } // delete all remaining edges { auto n = MakeScanAll(storage, symbol_table, "n"); auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "m", false, - storage::View::NEW); + storage::v3::View::NEW); auto r_get = storage.Create("r")->MapTo(r_m.edge_sym_); auto delete_op = std::make_shared(r_m.op_, std::vector{r_get}, false); auto context = MakeContext(storage, symbol_table, &dba); PullAll(*delete_op, &context); dba.AdvanceCommand(); - EXPECT_EQ(3, CountIterable(dba.Vertices(storage::View::OLD))); - EXPECT_EQ(0, CountEdges(&dba, storage::View::OLD)); + EXPECT_EQ(3, CountIterable(dba.Vertices(storage::v3::View::OLD))); + EXPECT_EQ(0, CountEdges(&dba, storage::v3::View::OLD)); } // delete all remaining vertices @@ -355,8 +356,8 @@ TEST_F(QueryPlanCRUDTest, Delete) { auto context = MakeContext(storage, symbol_table, &dba); PullAll(*delete_op, &context); dba.AdvanceCommand(); - EXPECT_EQ(0, CountIterable(dba.Vertices(storage::View::OLD))); - EXPECT_EQ(0, CountEdges(&dba, storage::View::OLD)); + EXPECT_EQ(0, CountIterable(dba.Vertices(storage::v3::View::OLD))); + EXPECT_EQ(0, CountEdges(&dba, storage::v3::View::OLD)); } } @@ -376,19 +377,19 @@ TEST_F(QueryPlanCRUDTest, DeleteTwiceDeleteBlockingEdge) { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); - auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); ASSERT_TRUE(dba.InsertEdge(&v1, &v2, dba.NameToEdgeType("T")).HasValue()); dba.AdvanceCommand(); - EXPECT_EQ(2, CountIterable(dba.Vertices(storage::View::OLD))); - EXPECT_EQ(1, CountEdges(&dba, storage::View::OLD)); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::v3::View::OLD))); + EXPECT_EQ(1, CountEdges(&dba, storage::v3::View::OLD)); AstStorage storage; SymbolTable symbol_table; auto n = MakeScanAll(storage, symbol_table, "n"); auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::BOTH, {}, "m", false, - storage::View::OLD); + storage::v3::View::OLD); // getter expressions for deletion auto n_get = storage.Create("n")->MapTo(n.sym_); @@ -399,8 +400,8 @@ TEST_F(QueryPlanCRUDTest, DeleteTwiceDeleteBlockingEdge) { auto context = MakeContext(storage, symbol_table, &dba); EXPECT_EQ(2, PullAll(*delete_op, &context)); dba.AdvanceCommand(); - EXPECT_EQ(0, CountIterable(dba.Vertices(storage::View::OLD))); - EXPECT_EQ(0, CountEdges(&dba, storage::View::OLD)); + EXPECT_EQ(0, CountIterable(dba.Vertices(storage::v3::View::OLD))); + EXPECT_EQ(0, CountEdges(&dba, storage::v3::View::OLD)); }; test_delete(true); @@ -413,14 +414,14 @@ TEST_F(QueryPlanCRUDTest, DeleteReturn) { // make a fully-connected (one-direction, no cycles) with 4 nodes for (int i = 0; i < 4; ++i) { - const auto property_value = storage::PropertyValue(i); + const auto property_value = storage::v3::PropertyValue(i); auto va = *dba.InsertVertexAndValidate(label, {}, {{property, property_value}}); - EXPECT_EQ(*va.GetProperty(storage::View::NEW, property), property_value); + EXPECT_EQ(*va.GetProperty(storage::v3::View::NEW, property), property_value); } dba.AdvanceCommand(); - EXPECT_EQ(4, CountIterable(dba.Vertices(storage::View::OLD))); - EXPECT_EQ(0, CountEdges(&dba, storage::View::OLD)); + EXPECT_EQ(4, CountIterable(dba.Vertices(storage::v3::View::OLD))); + EXPECT_EQ(0, CountEdges(&dba, storage::v3::View::OLD)); AstStorage storage; SymbolTable symbol_table; @@ -440,7 +441,7 @@ TEST_F(QueryPlanCRUDTest, DeleteReturn) { TEST(QueryPlan, DeleteNull) { // test (simplified) WITH Null as x delete x - storage::Storage db; + storage::v3::Storage db; auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); AstStorage storage; @@ -469,7 +470,7 @@ TEST_F(QueryPlanCRUDTest, DeleteAdvance) { { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); dba.AdvanceCommand(); auto produce = MakeProduce(advance, NEXPR("res", LITERAL(42))->MapTo(res_sym)); auto context = MakeContext(storage, symbol_table, &dba); @@ -478,7 +479,7 @@ TEST_F(QueryPlanCRUDTest, DeleteAdvance) { { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}).HasValue()); dba.AdvanceCommand(); auto n_prop = PROPERTY_LOOKUP(n_get, dba.NameToProperty("prop")); auto produce = MakeProduce(advance, NEXPR("res", n_prop)->MapTo(res_sym)); @@ -495,10 +496,10 @@ TEST_F(QueryPlanCRUDTest, SetProperty) { // the origin vertex in each par and both edges // have a property set - auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); - auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); - auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}); - auto v4 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(4)}}); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}); + auto v4 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(4)}}); auto edge_type = dba.NameToEdgeType("edge_type"); ASSERT_TRUE(dba.InsertEdge(&v1, &v3, edge_type).HasValue()); ASSERT_TRUE(dba.InsertEdge(&v2, &v4, edge_type).HasValue()); @@ -510,7 +511,7 @@ TEST_F(QueryPlanCRUDTest, SetProperty) { // scan (n)-[r]->(m) auto n = MakeScanAll(storage, symbol_table, "n"); auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "m", false, - storage::View::OLD); + storage::v3::View::OLD); // set prop1 to 42 on n and r auto prop1 = dba.NameToProperty("prop1"); @@ -525,18 +526,18 @@ TEST_F(QueryPlanCRUDTest, SetProperty) { EXPECT_EQ(2, PullAll(*set_r_p, &context)); dba.AdvanceCommand(); - EXPECT_EQ(CountEdges(&dba, storage::View::OLD), 2); - for (auto vertex : dba.Vertices(storage::View::OLD)) { - auto maybe_edges = vertex.OutEdges(storage::View::OLD); + EXPECT_EQ(CountEdges(&dba, storage::v3::View::OLD), 2); + for (auto vertex : dba.Vertices(storage::v3::View::OLD)) { + auto maybe_edges = vertex.OutEdges(storage::v3::View::OLD); ASSERT_TRUE(maybe_edges.HasValue()); for (auto edge : *maybe_edges) { - ASSERT_EQ(edge.GetProperty(storage::View::OLD, prop1)->type(), storage::PropertyValue::Type::Int); - EXPECT_EQ(edge.GetProperty(storage::View::OLD, prop1)->ValueInt(), 42); + ASSERT_EQ(edge.GetProperty(storage::v3::View::OLD, prop1)->type(), storage::v3::PropertyValue::Type::Int); + EXPECT_EQ(edge.GetProperty(storage::v3::View::OLD, prop1)->ValueInt(), 42); auto from = edge.From(); auto to = edge.To(); - ASSERT_EQ(from.GetProperty(storage::View::OLD, prop1)->type(), storage::PropertyValue::Type::Int); - EXPECT_EQ(from.GetProperty(storage::View::OLD, prop1)->ValueInt(), 42); - ASSERT_EQ(to.GetProperty(storage::View::OLD, prop1)->type(), storage::PropertyValue::Type::Null); + ASSERT_EQ(from.GetProperty(storage::v3::View::OLD, prop1)->type(), storage::v3::PropertyValue::Type::Int); + EXPECT_EQ(from.GetProperty(storage::v3::View::OLD, prop1)->ValueInt(), 42); + ASSERT_EQ(to.GetProperty(storage::v3::View::OLD, prop1)->type(), storage::v3::PropertyValue::Type::Null); } } } @@ -550,14 +551,14 @@ TEST_F(QueryPlanCRUDTest, SetProperties) { auto prop_a = dba.NameToProperty("a"); auto prop_b = dba.NameToProperty("b"); auto prop_c = dba.NameToProperty("c"); - auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); - auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); dba.AdvanceCommand(); auto e = dba.InsertEdge(&v1, &v2, dba.NameToEdgeType("R")); - ASSERT_TRUE(v1.SetPropertyAndValidate(prop_a, storage::PropertyValue(0)).HasValue()); - ASSERT_TRUE(e->SetProperty(prop_b, storage::PropertyValue(1)).HasValue()); - ASSERT_TRUE(v2.SetPropertyAndValidate(prop_c, storage::PropertyValue(2)).HasValue()); + ASSERT_TRUE(v1.SetPropertyAndValidate(prop_a, storage::v3::PropertyValue(0)).HasValue()); + ASSERT_TRUE(e->SetProperty(prop_b, storage::v3::PropertyValue(1)).HasValue()); + ASSERT_TRUE(v2.SetPropertyAndValidate(prop_c, storage::v3::PropertyValue(2)).HasValue()); dba.AdvanceCommand(); AstStorage storage; @@ -566,7 +567,7 @@ TEST_F(QueryPlanCRUDTest, SetProperties) { // scan (n)-[r]->(m) auto n = MakeScanAll(storage, symbol_table, "n"); auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "m", false, - storage::View::OLD); + storage::v3::View::OLD); auto op = update ? plan::SetProperties::Op::UPDATE : plan::SetProperties::Op::REPLACE; @@ -579,32 +580,32 @@ TEST_F(QueryPlanCRUDTest, SetProperties) { EXPECT_EQ(1, PullAll(*set_m_to_r, &context)); dba.AdvanceCommand(); - EXPECT_EQ(CountEdges(&dba, storage::View::OLD), 1); - for (auto vertex : dba.Vertices(storage::View::OLD)) { - auto maybe_edges = vertex.OutEdges(storage::View::OLD); + EXPECT_EQ(CountEdges(&dba, storage::v3::View::OLD), 1); + for (auto vertex : dba.Vertices(storage::v3::View::OLD)) { + auto maybe_edges = vertex.OutEdges(storage::v3::View::OLD); ASSERT_TRUE(maybe_edges.HasValue()); for (auto edge : *maybe_edges) { auto from = edge.From(); - EXPECT_EQ(from.Properties(storage::View::OLD)->size(), update ? 3 : 1); + EXPECT_EQ(from.Properties(storage::v3::View::OLD)->size(), update ? 3 : 1); if (update) { - ASSERT_EQ(from.GetProperty(storage::View::OLD, prop_a)->type(), storage::PropertyValue::Type::Int); - EXPECT_EQ(from.GetProperty(storage::View::OLD, prop_a)->ValueInt(), 0); + ASSERT_EQ(from.GetProperty(storage::v3::View::OLD, prop_a)->type(), storage::v3::PropertyValue::Type::Int); + EXPECT_EQ(from.GetProperty(storage::v3::View::OLD, prop_a)->ValueInt(), 0); } - ASSERT_EQ(from.GetProperty(storage::View::OLD, prop_b)->type(), storage::PropertyValue::Type::Int); - EXPECT_EQ(from.GetProperty(storage::View::OLD, prop_b)->ValueInt(), 1); + ASSERT_EQ(from.GetProperty(storage::v3::View::OLD, prop_b)->type(), storage::v3::PropertyValue::Type::Int); + EXPECT_EQ(from.GetProperty(storage::v3::View::OLD, prop_b)->ValueInt(), 1); - EXPECT_EQ(edge.Properties(storage::View::OLD)->size(), update ? 3 : 2); + EXPECT_EQ(edge.Properties(storage::v3::View::OLD)->size(), update ? 3 : 2); if (update) { - ASSERT_EQ(edge.GetProperty(storage::View::OLD, prop_b)->type(), storage::PropertyValue::Type::Int); - EXPECT_EQ(edge.GetProperty(storage::View::OLD, prop_b)->ValueInt(), 1); + ASSERT_EQ(edge.GetProperty(storage::v3::View::OLD, prop_b)->type(), storage::v3::PropertyValue::Type::Int); + EXPECT_EQ(edge.GetProperty(storage::v3::View::OLD, prop_b)->ValueInt(), 1); } - ASSERT_EQ(edge.GetProperty(storage::View::OLD, prop_c)->type(), storage::PropertyValue::Type::Int); - EXPECT_EQ(edge.GetProperty(storage::View::OLD, prop_c)->ValueInt(), 2); + ASSERT_EQ(edge.GetProperty(storage::v3::View::OLD, prop_c)->type(), storage::v3::PropertyValue::Type::Int); + EXPECT_EQ(edge.GetProperty(storage::v3::View::OLD, prop_c)->ValueInt(), 2); auto to = edge.To(); - EXPECT_EQ(to.Properties(storage::View::OLD)->size(), 2); - ASSERT_EQ(to.GetProperty(storage::View::OLD, prop_c)->type(), storage::PropertyValue::Type::Int); - EXPECT_EQ(to.GetProperty(storage::View::OLD, prop_c)->ValueInt(), 2); + EXPECT_EQ(to.Properties(storage::v3::View::OLD)->size(), 2); + ASSERT_EQ(to.GetProperty(storage::v3::View::OLD, prop_c)->type(), storage::v3::PropertyValue::Type::Int); + EXPECT_EQ(to.GetProperty(storage::v3::View::OLD, prop_c)->ValueInt(), 2); } } }; @@ -617,8 +618,8 @@ TEST_F(QueryPlanCRUDTest, SetSecondaryLabels) { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); - auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); auto label1 = dba.NameToLabel("label1"); auto label2 = dba.NameToLabel("label2"); @@ -631,14 +632,14 @@ TEST_F(QueryPlanCRUDTest, SetSecondaryLabels) { SymbolTable symbol_table; auto n = MakeScanAll(storage, symbol_table, "n"); - auto label_set = std::make_shared(n.op_, n.sym_, std::vector{label2, label3}); + auto label_set = std::make_shared(n.op_, n.sym_, std::vector{label2, label3}); auto context = MakeContext(storage, symbol_table, &dba); EXPECT_EQ(2, PullAll(*label_set, &context)); - for (auto vertex : dba.Vertices(storage::View::OLD)) { - EXPECT_EQ(3, vertex.Labels(storage::View::NEW)->size()); - EXPECT_TRUE(*vertex.HasLabel(storage::View::NEW, label2)); - EXPECT_TRUE(*vertex.HasLabel(storage::View::NEW, label3)); + for (auto vertex : dba.Vertices(storage::v3::View::OLD)) { + EXPECT_EQ(3, vertex.Labels(storage::v3::View::NEW)->size()); + EXPECT_TRUE(*vertex.HasLabel(storage::v3::View::NEW, label2)); + EXPECT_TRUE(*vertex.HasLabel(storage::v3::View::NEW, label3)); } } @@ -650,23 +651,23 @@ TEST_F(QueryPlanCRUDTest, RemoveProperty) { // the origin vertex in each par and both edges // have a property set auto prop1 = dba.NameToProperty("prop1"); - auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); - auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); - auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}); - auto v4 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(4)}}); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}); + auto v4 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(4)}}); auto edge_type = dba.NameToEdgeType("edge_type"); { auto e = dba.InsertEdge(&v1, &v3, edge_type); ASSERT_TRUE(e.HasValue()); - ASSERT_TRUE(e->SetProperty(prop1, storage::PropertyValue(42)).HasValue()); + ASSERT_TRUE(e->SetProperty(prop1, storage::v3::PropertyValue(42)).HasValue()); } ASSERT_TRUE(dba.InsertEdge(&v2, &v4, edge_type).HasValue()); - ASSERT_TRUE(v2.SetProperty(prop1, storage::PropertyValue(42)).HasValue()); - ASSERT_TRUE(v3.SetProperty(prop1, storage::PropertyValue(42)).HasValue()); - ASSERT_TRUE(v4.SetProperty(prop1, storage::PropertyValue(42)).HasValue()); + ASSERT_TRUE(v2.SetProperty(prop1, storage::v3::PropertyValue(42)).HasValue()); + ASSERT_TRUE(v3.SetProperty(prop1, storage::v3::PropertyValue(42)).HasValue()); + ASSERT_TRUE(v4.SetProperty(prop1, storage::v3::PropertyValue(42)).HasValue()); auto prop2 = dba.NameToProperty("prop2"); - ASSERT_TRUE(v1.SetProperty(prop2, storage::PropertyValue(0)).HasValue()); - ASSERT_TRUE(v2.SetProperty(prop2, storage::PropertyValue(0)).HasValue()); + ASSERT_TRUE(v1.SetProperty(prop2, storage::v3::PropertyValue(0)).HasValue()); + ASSERT_TRUE(v2.SetProperty(prop2, storage::v3::PropertyValue(0)).HasValue()); dba.AdvanceCommand(); AstStorage storage; @@ -675,7 +676,7 @@ TEST_F(QueryPlanCRUDTest, RemoveProperty) { // scan (n)-[r]->(m) auto n = MakeScanAll(storage, symbol_table, "n"); auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "m", false, - storage::View::OLD); + storage::v3::View::OLD); auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop1); auto set_n_p = std::make_shared(r_m.op_, prop1, n_p); @@ -686,17 +687,17 @@ TEST_F(QueryPlanCRUDTest, RemoveProperty) { EXPECT_EQ(2, PullAll(*set_r_p, &context)); dba.AdvanceCommand(); - EXPECT_EQ(CountEdges(&dba, storage::View::OLD), 2); - for (auto vertex : dba.Vertices(storage::View::OLD)) { - auto maybe_edges = vertex.OutEdges(storage::View::OLD); + EXPECT_EQ(CountEdges(&dba, storage::v3::View::OLD), 2); + for (auto vertex : dba.Vertices(storage::v3::View::OLD)) { + auto maybe_edges = vertex.OutEdges(storage::v3::View::OLD); ASSERT_TRUE(maybe_edges.HasValue()); for (auto edge : *maybe_edges) { - EXPECT_EQ(edge.GetProperty(storage::View::OLD, prop1)->type(), storage::PropertyValue::Type::Null); + EXPECT_EQ(edge.GetProperty(storage::v3::View::OLD, prop1)->type(), storage::v3::PropertyValue::Type::Null); auto from = edge.From(); auto to = edge.To(); - EXPECT_EQ(from.GetProperty(storage::View::OLD, prop1)->type(), storage::PropertyValue::Type::Null); - EXPECT_EQ(from.GetProperty(storage::View::OLD, prop2)->type(), storage::PropertyValue::Type::Int); - EXPECT_EQ(to.GetProperty(storage::View::OLD, prop1)->type(), storage::PropertyValue::Type::Int); + EXPECT_EQ(from.GetProperty(storage::v3::View::OLD, prop1)->type(), storage::v3::PropertyValue::Type::Null); + EXPECT_EQ(from.GetProperty(storage::v3::View::OLD, prop2)->type(), storage::v3::PropertyValue::Type::Int); + EXPECT_EQ(to.GetProperty(storage::v3::View::OLD, prop1)->type(), storage::v3::PropertyValue::Type::Int); } } } @@ -708,11 +709,11 @@ TEST_F(QueryPlanCRUDTest, RemoveLabels) { auto label1 = dba.NameToLabel("label1"); auto label2 = dba.NameToLabel("label2"); auto label3 = dba.NameToLabel("label3"); - auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); ASSERT_TRUE(v1.AddLabel(label1).HasValue()); ASSERT_TRUE(v1.AddLabel(label2).HasValue()); ASSERT_TRUE(v1.AddLabel(label3).HasValue()); - auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); ASSERT_TRUE(v2.AddLabel(label1).HasValue()); ASSERT_TRUE(v2.AddLabel(label3).HasValue()); dba.AdvanceCommand(); @@ -722,14 +723,14 @@ TEST_F(QueryPlanCRUDTest, RemoveLabels) { auto n = MakeScanAll(storage, symbol_table, "n"); auto label_remove = - std::make_shared(n.op_, n.sym_, std::vector{label1, label2}); + std::make_shared(n.op_, n.sym_, std::vector{label1, label2}); auto context = MakeContext(storage, symbol_table, &dba); EXPECT_EQ(2, PullAll(*label_remove, &context)); - for (auto vertex : dba.Vertices(storage::View::OLD)) { - EXPECT_EQ(1, vertex.Labels(storage::View::NEW)->size()); - EXPECT_FALSE(*vertex.HasLabel(storage::View::NEW, label1)); - EXPECT_FALSE(*vertex.HasLabel(storage::View::NEW, label2)); + for (auto vertex : dba.Vertices(storage::v3::View::OLD)) { + EXPECT_EQ(1, vertex.Labels(storage::v3::View::NEW)->size()); + EXPECT_FALSE(*vertex.HasLabel(storage::v3::View::NEW, label1)); + EXPECT_FALSE(*vertex.HasLabel(storage::v3::View::NEW, label2)); } } @@ -738,11 +739,11 @@ TEST_F(QueryPlanCRUDTest, NodeFilterSet) { DbAccessor dba(&storage_dba); // Create a graph such that (v1 {prop: 42}) is connected to v2 and v3. - auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); auto prop = PROPERTY_PAIR("prop"); - ASSERT_TRUE(v1.SetProperty(prop.second, storage::PropertyValue(42)).HasValue()); - auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); - auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}); + ASSERT_TRUE(v1.SetProperty(prop.second, storage::v3::PropertyValue(42)).HasValue()); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}); auto edge_type = dba.NameToEdgeType("Edge"); ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); ASSERT_TRUE(dba.InsertEdge(&v1, &v3, edge_type).HasValue()); @@ -756,7 +757,7 @@ TEST_F(QueryPlanCRUDTest, NodeFilterSet) { auto scan_all = MakeScanAll(storage, symbol_table, "n"); std::get<0>(scan_all.node_->properties_)[storage.GetPropertyIx(prop.first)] = LITERAL(42); auto expand = MakeExpand(storage, symbol_table, scan_all.op_, scan_all.sym_, "r", EdgeAtom::Direction::BOTH, {}, "m", - false, storage::View::OLD); + false, storage::v3::View::OLD); auto *filter_expr = EQ(storage.Create(scan_all.node_->identifier_, storage.GetPropertyIx(prop.first)), LITERAL(42)); auto node_filter = std::make_shared(expand.op_, filter_expr); @@ -767,7 +768,7 @@ TEST_F(QueryPlanCRUDTest, NodeFilterSet) { auto context = MakeContext(storage, symbol_table, &dba); EXPECT_EQ(2, PullAll(*set, &context)); dba.AdvanceCommand(); - auto prop_eq = TypedValue(*v1.GetProperty(storage::View::OLD, prop.second)) == TypedValue(42 + 2); + auto prop_eq = TypedValue(*v1.GetProperty(storage::v3::View::OLD, prop.second)) == TypedValue(42 + 2); ASSERT_EQ(prop_eq.type(), TypedValue::Type::Bool); EXPECT_TRUE(prop_eq.ValueBool()); } @@ -777,11 +778,11 @@ TEST_F(QueryPlanCRUDTest, FilterRemove) { DbAccessor dba(&storage_dba); // Create a graph such that (v1 {prop: 42}) is connected to v2 and v3. - auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); auto prop = PROPERTY_PAIR("prop"); - ASSERT_TRUE(v1.SetProperty(prop.second, storage::PropertyValue(42)).HasValue()); - auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); - auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}); + ASSERT_TRUE(v1.SetProperty(prop.second, storage::v3::PropertyValue(42)).HasValue()); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}); auto edge_type = dba.NameToEdgeType("Edge"); ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); ASSERT_TRUE(dba.InsertEdge(&v1, &v3, edge_type).HasValue()); @@ -794,7 +795,7 @@ TEST_F(QueryPlanCRUDTest, FilterRemove) { auto scan_all = MakeScanAll(storage, symbol_table, "n"); std::get<0>(scan_all.node_->properties_)[storage.GetPropertyIx(prop.first)] = LITERAL(42); auto expand = MakeExpand(storage, symbol_table, scan_all.op_, scan_all.sym_, "r", EdgeAtom::Direction::BOTH, {}, "m", - false, storage::View::OLD); + false, storage::v3::View::OLD); auto filter_prop = PROPERTY_LOOKUP(IDENT("n")->MapTo(scan_all.sym_), prop); auto filter = std::make_shared(expand.op_, LESS(filter_prop, LITERAL(43))); // REMOVE n.prop @@ -803,14 +804,14 @@ TEST_F(QueryPlanCRUDTest, FilterRemove) { auto context = MakeContext(storage, symbol_table, &dba); EXPECT_EQ(2, PullAll(*rem, &context)); dba.AdvanceCommand(); - EXPECT_EQ(v1.GetProperty(storage::View::OLD, prop.second)->type(), storage::PropertyValue::Type::Null); + EXPECT_EQ(v1.GetProperty(storage::v3::View::OLD, prop.second)->type(), storage::v3::PropertyValue::Type::Null); } TEST_F(QueryPlanCRUDTest, SetRemove) { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); auto label1 = dba.NameToLabel("label1"); auto label2 = dba.NameToLabel("label2"); dba.AdvanceCommand(); @@ -821,13 +822,14 @@ TEST_F(QueryPlanCRUDTest, SetRemove) { // MATCH (n) SET n :label1 :label2 REMOVE n :label1 :label2 auto scan_all = MakeScanAll(storage, symbol_table, "n"); auto set = - std::make_shared(scan_all.op_, scan_all.sym_, std::vector{label1, label2}); - auto rem = std::make_shared(set, scan_all.sym_, std::vector{label1, label2}); + std::make_shared(scan_all.op_, scan_all.sym_, std::vector{label1, label2}); + auto rem = + std::make_shared(set, scan_all.sym_, std::vector{label1, label2}); auto context = MakeContext(storage, symbol_table, &dba); EXPECT_EQ(1, PullAll(*rem, &context)); dba.AdvanceCommand(); - EXPECT_FALSE(*v.HasLabel(storage::View::OLD, label1)); - EXPECT_FALSE(*v.HasLabel(storage::View::OLD, label2)); + EXPECT_FALSE(*v.HasLabel(storage::v3::View::OLD, label1)); + EXPECT_FALSE(*v.HasLabel(storage::v3::View::OLD, label2)); } TEST_F(QueryPlanCRUDTest, Merge) { @@ -840,11 +842,11 @@ TEST_F(QueryPlanCRUDTest, Merge) { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); - auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); ASSERT_TRUE(dba.InsertEdge(&v1, &v2, dba.NameToEdgeType("Type")).HasValue()); - auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}); dba.AdvanceCommand(); AstStorage storage; @@ -855,7 +857,7 @@ TEST_F(QueryPlanCRUDTest, Merge) { // merge_match branch auto r_m = MakeExpand(storage, symbol_table, std::make_shared(), n.sym_, "r", EdgeAtom::Direction::BOTH, {}, - "m", false, storage::View::OLD); + "m", false, storage::v3::View::OLD); auto m_p = PROPERTY_LOOKUP(IDENT("m")->MapTo(r_m.node_sym_), prop); auto m_set = std::make_shared(r_m.op_, prop.second, m_p, LITERAL(1)); @@ -868,12 +870,12 @@ TEST_F(QueryPlanCRUDTest, Merge) { ASSERT_EQ(3, PullAll(*merge, &context)); dba.AdvanceCommand(); - ASSERT_EQ(v1.GetProperty(storage::View::OLD, prop.second)->type(), storage::PropertyValue::Type::Int); - ASSERT_EQ(v1.GetProperty(storage::View::OLD, prop.second)->ValueInt(), 1); - ASSERT_EQ(v2.GetProperty(storage::View::OLD, prop.second)->type(), storage::PropertyValue::Type::Int); - ASSERT_EQ(v2.GetProperty(storage::View::OLD, prop.second)->ValueInt(), 1); - ASSERT_EQ(v3.GetProperty(storage::View::OLD, prop.second)->type(), storage::PropertyValue::Type::Int); - ASSERT_EQ(v3.GetProperty(storage::View::OLD, prop.second)->ValueInt(), 2); + ASSERT_EQ(v1.GetProperty(storage::v3::View::OLD, prop.second)->type(), storage::v3::PropertyValue::Type::Int); + ASSERT_EQ(v1.GetProperty(storage::v3::View::OLD, prop.second)->ValueInt(), 1); + ASSERT_EQ(v2.GetProperty(storage::v3::View::OLD, prop.second)->type(), storage::v3::PropertyValue::Type::Int); + ASSERT_EQ(v2.GetProperty(storage::v3::View::OLD, prop.second)->ValueInt(), 1); + ASSERT_EQ(v3.GetProperty(storage::v3::View::OLD, prop.second)->type(), storage::v3::PropertyValue::Type::Int); + ASSERT_EQ(v3.GetProperty(storage::v3::View::OLD, prop.second)->ValueInt(), 2); } TEST_F(QueryPlanCRUDTest, MergeNoInput) { @@ -887,21 +889,21 @@ TEST_F(QueryPlanCRUDTest, MergeNoInput) { NodeCreationInfo node; node.symbol = symbol_table.CreateSymbol("n", true); node.labels = {label}; - std::get>>(node.properties) + std::get>>(node.properties) .emplace_back(property, LITERAL(1)); auto create = std::make_shared(nullptr, node); auto merge = std::make_shared(nullptr, create, create); - EXPECT_EQ(0, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(0, CountIterable(dba.Vertices(storage::v3::View::OLD))); auto context = MakeContext(storage, symbol_table, &dba); EXPECT_EQ(1, PullAll(*merge, &context)); dba.AdvanceCommand(); - EXPECT_EQ(1, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::v3::View::OLD))); } TEST(QueryPlan, SetPropertyOnNull) { // SET (Null).prop = 42 - storage::Storage db; + storage::v3::Storage db; auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); AstStorage storage; @@ -918,7 +920,7 @@ TEST(QueryPlan, SetPropertyOnNull) { TEST(QueryPlan, SetPropertiesOnNull) { // OPTIONAL MATCH (n) SET n = n - storage::Storage db; + storage::v3::Storage db; auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); AstStorage storage; @@ -927,14 +929,14 @@ TEST(QueryPlan, SetPropertiesOnNull) { auto n_ident = IDENT("n")->MapTo(n.sym_); auto optional = std::make_shared(nullptr, n.op_, std::vector{n.sym_}); auto set_op = std::make_shared(optional, n.sym_, n_ident, plan::SetProperties::Op::REPLACE); - EXPECT_EQ(0, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(0, CountIterable(dba.Vertices(storage::v3::View::OLD))); auto context = MakeContext(storage, symbol_table, &dba); EXPECT_EQ(1, PullAll(*set_op, &context)); } TEST(QueryPlan, SetLabelsOnNull) { // OPTIONAL MATCH (n) SET n :label - storage::Storage db; + storage::v3::Storage db; auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); auto label = dba.NameToLabel("label"); @@ -942,15 +944,15 @@ TEST(QueryPlan, SetLabelsOnNull) { SymbolTable symbol_table; auto n = MakeScanAll(storage, symbol_table, "n"); auto optional = std::make_shared(nullptr, n.op_, std::vector{n.sym_}); - auto set_op = std::make_shared(optional, n.sym_, std::vector{label}); - EXPECT_EQ(0, CountIterable(dba.Vertices(storage::View::OLD))); + auto set_op = std::make_shared(optional, n.sym_, std::vector{label}); + EXPECT_EQ(0, CountIterable(dba.Vertices(storage::v3::View::OLD))); auto context = MakeContext(storage, symbol_table, &dba); EXPECT_EQ(1, PullAll(*set_op, &context)); } TEST(QueryPlan, RemovePropertyOnNull) { // REMOVE (Null).prop - storage::Storage db; + storage::v3::Storage db; auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); AstStorage storage; @@ -966,7 +968,7 @@ TEST(QueryPlan, RemovePropertyOnNull) { TEST(QueryPlan, RemoveLabelsOnNull) { // OPTIONAL MATCH (n) REMOVE n :label - storage::Storage db; + storage::v3::Storage db; auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); auto label = dba.NameToLabel("label"); @@ -974,8 +976,8 @@ TEST(QueryPlan, RemoveLabelsOnNull) { SymbolTable symbol_table; auto n = MakeScanAll(storage, symbol_table, "n"); auto optional = std::make_shared(nullptr, n.op_, std::vector{n.sym_}); - auto remove_op = std::make_shared(optional, n.sym_, std::vector{label}); - EXPECT_EQ(0, CountIterable(dba.Vertices(storage::View::OLD))); + auto remove_op = std::make_shared(optional, n.sym_, std::vector{label}); + EXPECT_EQ(0, CountIterable(dba.Vertices(storage::v3::View::OLD))); auto context = MakeContext(storage, symbol_table, &dba); EXPECT_EQ(1, PullAll(*remove_op, &context)); } @@ -985,9 +987,9 @@ TEST_F(QueryPlanCRUDTest, DeleteSetProperty) { DbAccessor dba(&storage_dba); // Add a single vertex. - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); dba.AdvanceCommand(); - EXPECT_EQ(1, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::v3::View::OLD))); AstStorage storage; SymbolTable symbol_table; // MATCH (n) DELETE n SET n.prop = 42 @@ -1007,9 +1009,9 @@ TEST_F(QueryPlanCRUDTest, DeleteSetPropertiesFromMap) { DbAccessor dba(&storage_dba); // Add a single vertex. - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); dba.AdvanceCommand(); - EXPECT_EQ(1, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::v3::View::OLD))); AstStorage storage; SymbolTable symbol_table; // MATCH (n) DELETE n SET n = {prop: 42} @@ -1033,11 +1035,11 @@ TEST_F(QueryPlanCRUDTest, DeleteSetPropertiesFrom) { // Add a single vertex. { - auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); - ASSERT_TRUE(v.SetProperty(dba.NameToProperty("prop"), storage::PropertyValue(1)).HasValue()); + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(v.SetProperty(dba.NameToProperty("prop"), storage::v3::PropertyValue(1)).HasValue()); } dba.AdvanceCommand(); - EXPECT_EQ(1, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::v3::View::OLD))); AstStorage storage; SymbolTable symbol_table; // MATCH (n) DELETE n SET n = n @@ -1057,16 +1059,16 @@ TEST_F(QueryPlanCRUDTest, DeleteRemoveLabels) { DbAccessor dba(&storage_dba); // Add a single vertex. - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); dba.AdvanceCommand(); - EXPECT_EQ(1, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::v3::View::OLD))); AstStorage storage; SymbolTable symbol_table; // MATCH (n) DELETE n REMOVE n :label auto n = MakeScanAll(storage, symbol_table, "n"); auto n_get = storage.Create("n")->MapTo(n.sym_); auto delete_op = std::make_shared(n.op_, std::vector{n_get}, false); - std::vector labels{dba.NameToLabel("label1")}; + std::vector labels{dba.NameToLabel("label1")}; auto rem_op = std::make_shared(delete_op, n.sym_, labels); auto context = MakeContext(storage, symbol_table, &dba); EXPECT_THROW(PullAll(*rem_op, &context), QueryRuntimeException); @@ -1077,9 +1079,9 @@ TEST_F(QueryPlanCRUDTest, DeleteRemoveProperty) { DbAccessor dba(&storage_dba); // Add a single vertex. - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); dba.AdvanceCommand(); - EXPECT_EQ(1, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::v3::View::OLD))); AstStorage storage; SymbolTable symbol_table; // MATCH (n) DELETE n REMOVE n.prop diff --git a/tests/unit/query_v2_query_plan_edge_cases.cpp b/tests/unit/query_v2_query_plan_edge_cases.cpp index 2abeab496..cf268bc27 100644 --- a/tests/unit/query_v2_query_plan_edge_cases.cpp +++ b/tests/unit/query_v2_query_plan_edge_cases.cpp @@ -19,22 +19,23 @@ #include #include -#include "communication/result_stream_faker.hpp" -#include "query/interpreter.hpp" -#include "storage/v2/storage.hpp" +#include "query/v2/interpreter.hpp" +#include "result_stream_faker.hpp" +#include "storage/v3/storage.hpp" DECLARE_bool(query_cost_planner); -namespace memgraph::query::tests { +namespace memgraph::query::v2::tests { class QueryExecution : public testing::Test { protected: - storage::Storage db; - std::optional db_; + storage::v3::Storage db; + std::optional db_; std::optional interpreter_context_; std::optional interpreter_; - std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_query_plan_edge_cases"}; + std::filesystem::path data_directory{std::filesystem::temp_directory_path() / + "MG_tests_unit_query_v2_query_plan_edge_cases"}; void SetUp() { db_.emplace(); @@ -112,4 +113,4 @@ TEST_F(QueryExecution, EdgeUniquenessInOptional) { .size(), 3); } -} // namespace memgraph::query::tests +} // namespace memgraph::query::v2::tests diff --git a/tests/unit/query_v2_query_plan_match_filter_return.cpp b/tests/unit/query_v2_query_plan_match_filter_return.cpp index 079b3f5f4..8be276772 100644 --- a/tests/unit/query_v2_query_plan_match_filter_return.cpp +++ b/tests/unit/query_v2_query_plan_match_filter_return.cpp @@ -24,14 +24,16 @@ #include #include -#include "query/context.hpp" -#include "query/exceptions.hpp" -#include "query/plan/operator.hpp" -#include "query_plan_common.hpp" -#include "storage/v2/property_value.hpp" +#include "query/v2/context.hpp" +#include "query/v2/exceptions.hpp" +#include "query/v2/plan/operator.hpp" +#include "query_v2_query_common.hpp" +#include "storage/v3/property_value.hpp" -using namespace memgraph::query; -using namespace memgraph::query::plan; +#include "query_v2_query_plan_common.hpp" + +using namespace memgraph::query::v2; +using namespace memgraph::query::v2::plan; namespace std { template <> @@ -45,20 +47,20 @@ namespace memgraph::query::tests { class MatchReturnFixture : public testing::Test { protected: void SetUp() override { - ASSERT_TRUE(db.CreateSchema(label, {storage::SchemaProperty{property, common::SchemaType::INT}})); + ASSERT_TRUE(db.CreateSchema(label, {storage::v3::SchemaProperty{property, common::SchemaType::INT}})); } - storage::Storage db; - storage::Storage::Accessor storage_dba{db.Access()}; + storage::v3::Storage db; + storage::v3::Storage::Accessor storage_dba{db.Access()}; DbAccessor dba{&storage_dba}; - const storage::LabelId label{db.NameToLabel("label")}; - const storage::PropertyId property{db.NameToProperty("property")}; + const storage::v3::LabelId label{db.NameToLabel("label")}; + const storage::v3::PropertyId property{db.NameToProperty("property")}; AstStorage storage; SymbolTable symbol_table; void AddVertices(int count) { for (int i = 0; i < count; i++) { - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(i)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i)}}).HasValue()); } } @@ -74,7 +76,7 @@ TEST_F(MatchReturnFixture, MatchReturn) { AddVertices(2); dba.AdvanceCommand(); - auto test_pull_count = [&](storage::View view) { + auto test_pull_count = [&](storage::v3::View view) { auto scan_all = MakeScanAll(storage, symbol_table, "n", nullptr, view); auto output = NEXPR("n", IDENT("n")->MapTo(scan_all.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); @@ -83,13 +85,13 @@ TEST_F(MatchReturnFixture, MatchReturn) { return PullAll(*produce, &context); }; - EXPECT_EQ(2, test_pull_count(storage::View::NEW)); - EXPECT_EQ(2, test_pull_count(storage::View::OLD)); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); - EXPECT_EQ(3, test_pull_count(storage::View::NEW)); - EXPECT_EQ(2, test_pull_count(storage::View::OLD)); + EXPECT_EQ(2, test_pull_count(storage::v3::View::NEW)); + EXPECT_EQ(2, test_pull_count(storage::v3::View::OLD)); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); + EXPECT_EQ(3, test_pull_count(storage::v3::View::NEW)); + EXPECT_EQ(2, test_pull_count(storage::v3::View::OLD)); dba.AdvanceCommand(); - EXPECT_EQ(3, test_pull_count(storage::View::OLD)); + EXPECT_EQ(3, test_pull_count(storage::v3::View::OLD)); } TEST_F(MatchReturnFixture, MatchReturnPath) { @@ -105,7 +107,7 @@ TEST_F(MatchReturnFixture, MatchReturnPath) { auto results = PathResults(produce); ASSERT_EQ(results.size(), 2); std::vector expected_paths; - for (const auto &v : dba.Vertices(storage::View::OLD)) expected_paths.emplace_back(v); + for (const auto &v : dba.Vertices(storage::v3::View::OLD)) expected_paths.emplace_back(v); ASSERT_EQ(expected_paths.size(), 2); EXPECT_TRUE(std::is_permutation(expected_paths.begin(), expected_paths.end(), results.begin())); } @@ -113,22 +115,22 @@ TEST_F(MatchReturnFixture, MatchReturnPath) { class QueryPlanMatchFilterTest : public testing::Test { protected: QueryPlanMatchFilterTest() { - EXPECT_TRUE(db.CreateSchema(label, {storage::SchemaProperty{property, common::SchemaType::INT}})); + EXPECT_TRUE(db.CreateSchema(label, {storage::v3::SchemaProperty{property, common::SchemaType::INT}})); } - storage::Storage db; - storage::LabelId label = db.NameToLabel("label"); - storage::PropertyId property = db.NameToProperty("property"); + storage::v3::Storage db; + storage::v3::LabelId label = db.NameToLabel("label"); + storage::v3::PropertyId property = db.NameToProperty("property"); }; TEST_F(QueryPlanMatchFilterTest, MatchReturnCartesian) { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}) ->AddLabel(dba.NameToLabel("l1")) .HasValue()); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}) ->AddLabel(dba.NameToLabel("l2")) .HasValue()); @@ -156,8 +158,8 @@ TEST_F(QueryPlanMatchFilterTest, StandaloneReturn) { DbAccessor dba(&storage_dba); // add a few nodes to the database - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}).HasValue()); dba.AdvanceCommand(); AstStorage storage; @@ -179,14 +181,14 @@ TEST_F(QueryPlanMatchFilterTest, NodeFilterLabelsAndProperties) { DbAccessor dba(&storage_dba); // add a few nodes to the database - storage::LabelId label1 = dba.NameToLabel("Label1"); + storage::v3::LabelId label1 = dba.NameToLabel("Label1"); auto property1 = PROPERTY_PAIR("Property1"); - auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); - auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); - auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}); - auto v4 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(4)}}); - auto v5 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(5)}}); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(6)}}).HasValue()); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}); + auto v4 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(4)}}); + auto v5 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(5)}}); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(6)}}).HasValue()); // test all combination of (label | no_label) * (no_prop | wrong_prop | // right_prop) @@ -195,10 +197,10 @@ TEST_F(QueryPlanMatchFilterTest, NodeFilterLabelsAndProperties) { ASSERT_TRUE(v2.AddLabel(label1).HasValue()); ASSERT_TRUE(v3.AddLabel(label1).HasValue()); // v1 and v4 will have the right properties - ASSERT_TRUE(v1.SetProperty(property1.second, storage::PropertyValue(42)).HasValue()); - ASSERT_TRUE(v2.SetProperty(property1.second, storage::PropertyValue(1)).HasValue()); - ASSERT_TRUE(v4.SetProperty(property1.second, storage::PropertyValue(42)).HasValue()); - ASSERT_TRUE(v5.SetProperty(property1.second, storage::PropertyValue(1)).HasValue()); + ASSERT_TRUE(v1.SetProperty(property1.second, storage::v3::PropertyValue(42)).HasValue()); + ASSERT_TRUE(v2.SetProperty(property1.second, storage::v3::PropertyValue(1)).HasValue()); + ASSERT_TRUE(v4.SetProperty(property1.second, storage::v3::PropertyValue(42)).HasValue()); + ASSERT_TRUE(v5.SetProperty(property1.second, storage::v3::PropertyValue(1)).HasValue()); dba.AdvanceCommand(); AstStorage storage; @@ -233,28 +235,28 @@ TEST_F(QueryPlanMatchFilterTest, NodeFilterMultipleLabels) { DbAccessor dba(&storage_dba); // add a few nodes to the database - storage::LabelId label1 = dba.NameToLabel("label1"); - storage::LabelId label2 = dba.NameToLabel("label2"); - storage::LabelId label3 = dba.NameToLabel("label3"); + storage::v3::LabelId label1 = dba.NameToLabel("label1"); + storage::v3::LabelId label2 = dba.NameToLabel("label2"); + storage::v3::LabelId label3 = dba.NameToLabel("label3"); // the test will look for nodes that have label1 and label2 ASSERT_TRUE( - dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); // NOT accepted - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}) + dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); // NOT accepted + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}) ->AddLabel(label1) .HasValue()); // NOT accepted - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}) ->AddLabel(label2) .HasValue()); // NOT accepted - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(4)}}) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(4)}}) ->AddLabel(label3) - .HasValue()); // NOT accepted - auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(5)}}); // YES accepted + .HasValue()); // NOT accepted + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(5)}}); // YES accepted ASSERT_TRUE(v1.AddLabel(label1).HasValue()); ASSERT_TRUE(v1.AddLabel(label2).HasValue()); - auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(6)}}); // NOT accepted + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(6)}}); // NOT accepted ASSERT_TRUE(v2.AddLabel(label1).HasValue()); ASSERT_TRUE(v2.AddLabel(label3).HasValue()); - auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(7)}}); // YES accepted + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(7)}}); // YES accepted ASSERT_TRUE(v3.AddLabel(label1).HasValue()); ASSERT_TRUE(v3.AddLabel(label2).HasValue()); ASSERT_TRUE(v3.AddLabel(label3).HasValue()); @@ -286,7 +288,7 @@ TEST_F(QueryPlanMatchFilterTest, Cartesian) { DbAccessor dba(&storage_dba); auto add_vertex = [&dba, this](std::string label1) { - auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); MG_ASSERT(vertex.AddLabel(dba.NameToLabel(label1)).HasValue()); return vertex; }; @@ -345,7 +347,7 @@ TEST_F(QueryPlanMatchFilterTest, CartesianThreeWay) { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); auto add_vertex = [&dba, this](std::string label1) { - auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); MG_ASSERT(vertex.AddLabel(dba.NameToLabel(label1)).HasValue()); return vertex; }; @@ -390,16 +392,16 @@ TEST_F(QueryPlanMatchFilterTest, CartesianThreeWay) { class ExpandFixture : public QueryPlanMatchFilterTest { protected: - storage::Storage::Accessor storage_dba{db.Access()}; + storage::v3::Storage::Accessor storage_dba{db.Access()}; DbAccessor dba{&storage_dba}; AstStorage storage; SymbolTable symbol_table; // make a V-graph (v3)<-[r2]-(v1)-[r1]->(v2) - VertexAccessor v1{*dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}})}; - VertexAccessor v2{*dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}})}; - VertexAccessor v3{*dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}})}; - storage::EdgeTypeId edge_type{db.NameToEdgeType("Edge")}; + VertexAccessor v1{*dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}})}; + VertexAccessor v2{*dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}})}; + VertexAccessor v3{*dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}})}; + storage::v3::EdgeTypeId edge_type{db.NameToEdgeType("Edge")}; EdgeAccessor r1{*dba.InsertEdge(&v1, &v2, edge_type)}; EdgeAccessor r2{*dba.InsertEdge(&v1, &v3, edge_type)}; @@ -412,7 +414,7 @@ class ExpandFixture : public QueryPlanMatchFilterTest { }; TEST_F(ExpandFixture, Expand) { - auto test_expand = [&](EdgeAtom::Direction direction, storage::View view) { + auto test_expand = [&](EdgeAtom::Direction direction, storage::v3::View view) { auto n = MakeScanAll(storage, symbol_table, "n"); auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", direction, {}, "m", false, view); @@ -426,22 +428,22 @@ TEST_F(ExpandFixture, Expand) { // test that expand works well for both old and new graph state ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); ASSERT_TRUE(dba.InsertEdge(&v1, &v3, edge_type).HasValue()); - EXPECT_EQ(2, test_expand(EdgeAtom::Direction::OUT, storage::View::OLD)); - EXPECT_EQ(2, test_expand(EdgeAtom::Direction::IN, storage::View::OLD)); - EXPECT_EQ(4, test_expand(EdgeAtom::Direction::BOTH, storage::View::OLD)); - EXPECT_EQ(4, test_expand(EdgeAtom::Direction::OUT, storage::View::NEW)); - EXPECT_EQ(4, test_expand(EdgeAtom::Direction::IN, storage::View::NEW)); - EXPECT_EQ(8, test_expand(EdgeAtom::Direction::BOTH, storage::View::NEW)); + EXPECT_EQ(2, test_expand(EdgeAtom::Direction::OUT, storage::v3::View::OLD)); + EXPECT_EQ(2, test_expand(EdgeAtom::Direction::IN, storage::v3::View::OLD)); + EXPECT_EQ(4, test_expand(EdgeAtom::Direction::BOTH, storage::v3::View::OLD)); + EXPECT_EQ(4, test_expand(EdgeAtom::Direction::OUT, storage::v3::View::NEW)); + EXPECT_EQ(4, test_expand(EdgeAtom::Direction::IN, storage::v3::View::NEW)); + EXPECT_EQ(8, test_expand(EdgeAtom::Direction::BOTH, storage::v3::View::NEW)); dba.AdvanceCommand(); - EXPECT_EQ(4, test_expand(EdgeAtom::Direction::OUT, storage::View::OLD)); - EXPECT_EQ(4, test_expand(EdgeAtom::Direction::IN, storage::View::OLD)); - EXPECT_EQ(8, test_expand(EdgeAtom::Direction::BOTH, storage::View::OLD)); + EXPECT_EQ(4, test_expand(EdgeAtom::Direction::OUT, storage::v3::View::OLD)); + EXPECT_EQ(4, test_expand(EdgeAtom::Direction::IN, storage::v3::View::OLD)); + EXPECT_EQ(8, test_expand(EdgeAtom::Direction::BOTH, storage::v3::View::OLD)); } TEST_F(ExpandFixture, ExpandPath) { auto n = MakeScanAll(storage, symbol_table, "n"); auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "m", false, - storage::View::OLD); + storage::v3::View::OLD); Symbol path_sym = symbol_table.CreateSymbol("path", true); auto path = std::make_shared(r_m.op_, path_sym, std::vector{n.sym_, r_m.edge_sym_, r_m.node_sym_}); @@ -476,12 +478,12 @@ class QueryPlanExpandVariable : public QueryPlanMatchFilterTest { // a lot below in test declaration using map_int = std::unordered_map; - storage::Storage::Accessor storage_dba{db.Access()}; + storage::v3::Storage::Accessor storage_dba{db.Access()}; DbAccessor dba{&storage_dba}; // labels for layers in the double chain - std::vector labels; + std::vector labels; // for all the edges - storage::EdgeTypeId edge_type = dba.NameToEdgeType("edge_type"); + storage::v3::EdgeTypeId edge_type = dba.NameToEdgeType("edge_type"); AstStorage storage; SymbolTable symbol_table; @@ -495,8 +497,8 @@ class QueryPlanExpandVariable : public QueryPlanMatchFilterTest { std::vector layer; for (int from_layer_ind = -1; from_layer_ind < chain_length - 1; from_layer_ind++) { std::vector new_layer{ - *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}), - *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}})}; + *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}), + *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}})}; auto label = dba.NameToLabel(std::to_string(from_layer_ind + 1)); labels.push_back(label); for (size_t v_to_ind = 0; v_to_ind < new_layer.size(); v_to_ind++) { @@ -506,16 +508,16 @@ class QueryPlanExpandVariable : public QueryPlanMatchFilterTest { auto &v_from = layer[v_from_ind]; auto edge = dba.InsertEdge(&v_from, &v_to, edge_type); ASSERT_TRUE(edge->SetProperty(dba.NameToProperty("p"), - storage::PropertyValue(fmt::format("V{}{}->V{}{}", from_layer_ind, v_from_ind, - from_layer_ind + 1, v_to_ind))) + storage::v3::PropertyValue(fmt::format( + "V{}{}->V{}{}", from_layer_ind, v_from_ind, from_layer_ind + 1, v_to_ind))) .HasValue()); } } layer = new_layer; } dba.AdvanceCommand(); - ASSERT_EQ(CountIterable(dba.Vertices(storage::View::OLD)), 2 * chain_length); - ASSERT_EQ(CountEdges(&dba, storage::View::OLD), 4 * (chain_length - 1)); + ASSERT_EQ(CountIterable(dba.Vertices(storage::v3::View::OLD)), 2 * chain_length); + ASSERT_EQ(CountEdges(&dba, storage::v3::View::OLD), 4 * (chain_length - 1)); } /** @@ -532,9 +534,10 @@ class QueryPlanExpandVariable : public QueryPlanMatchFilterTest { template std::shared_ptr AddMatch(std::shared_ptr input_op, const std::string &node_from, int layer, EdgeAtom::Direction direction, - const std::vector &edge_types, + const std::vector &edge_types, std::optional lower, std::optional upper, Symbol edge_sym, - const std::string &node_to, storage::View view, bool is_reverse = false) { + const std::string &node_to, storage::v3::View view, + bool is_reverse = false) { auto n_from = MakeScanAll(storage, symbol_table, node_from, input_op); auto filter_op = std::make_shared( n_from.op_, @@ -550,7 +553,7 @@ class QueryPlanExpandVariable : public QueryPlanMatchFilterTest { auto convert = [this](std::optional bound) { return bound ? LITERAL(static_cast(bound.value())) : nullptr; }; - MG_ASSERT(view == storage::View::OLD, "ExpandVariable should only be planned with storage::View::OLD"); + MG_ASSERT(view == storage::v3::View::OLD, "ExpandVariable should only be planned with storage::v3::View::OLD"); return std::make_shared(filter_op, n_from.sym_, n_to_sym, edge_sym, EdgeAtom::Type::DEPTH_FIRST, direction, edge_types, is_reverse, convert(lower), convert(upper), false, @@ -617,9 +620,9 @@ TEST_F(QueryPlanExpandVariable, OneVariableExpansion) { auto test_expand = [&](int layer, EdgeAtom::Direction direction, std::optional lower, std::optional upper, bool reverse) { auto e = Edge("r", direction); - return GetEdgeListSizes( - AddMatch(nullptr, "n", layer, direction, {}, lower, upper, e, "m", storage::View::OLD, reverse), - e); + return GetEdgeListSizes(AddMatch(nullptr, "n", layer, direction, {}, lower, upper, e, "m", + storage::v3::View::OLD, reverse), + e); }; for (int reverse = 0; reverse < 2; ++reverse) { @@ -657,19 +660,19 @@ TEST_F(QueryPlanExpandVariable, EdgeUniquenessSingleAndVariableExpansion) { if (single_expansion_before) { symbols.push_back(Edge("r0", direction)); - last_op = - AddMatch(last_op, "n0", layer, direction, {}, lower, upper, symbols.back(), "m0", storage::View::OLD); + last_op = AddMatch(last_op, "n0", layer, direction, {}, lower, upper, symbols.back(), "m0", + storage::v3::View::OLD); } auto var_length_sym = Edge("r1", direction); symbols.push_back(var_length_sym); last_op = AddMatch(last_op, "n1", layer, direction, {}, lower, upper, var_length_sym, "m1", - storage::View::OLD); + storage::v3::View::OLD); if (!single_expansion_before) { symbols.push_back(Edge("r2", direction)); - last_op = - AddMatch(last_op, "n2", layer, direction, {}, lower, upper, symbols.back(), "m2", storage::View::OLD); + last_op = AddMatch(last_op, "n2", layer, direction, {}, lower, upper, symbols.back(), "m2", + storage::v3::View::OLD); } if (add_uniqueness_check) { @@ -693,10 +696,10 @@ TEST_F(QueryPlanExpandVariable, EdgeUniquenessTwoVariableExpansions) { std::optional upper, bool add_uniqueness_check) { auto e1 = Edge("r1", direction); auto first = - AddMatch(nullptr, "n1", layer, direction, {}, lower, upper, e1, "m1", storage::View::OLD); + AddMatch(nullptr, "n1", layer, direction, {}, lower, upper, e1, "m1", storage::v3::View::OLD); auto e2 = Edge("r2", direction); auto last_op = - AddMatch(first, "n2", layer, direction, {}, lower, upper, e2, "m2", storage::View::OLD); + AddMatch(first, "n2", layer, direction, {}, lower, upper, e2, "m2", storage::v3::View::OLD); if (add_uniqueness_check) { last_op = std::make_shared(last_op, e2, std::vector{e1}); } @@ -711,7 +714,7 @@ TEST_F(QueryPlanExpandVariable, EdgeUniquenessTwoVariableExpansions) { TEST_F(QueryPlanExpandVariable, NamedPath) { auto e = Edge("r", EdgeAtom::Direction::OUT); auto expand = - AddMatch(nullptr, "n", 0, EdgeAtom::Direction::OUT, {}, 2, 2, e, "m", storage::View::OLD); + AddMatch(nullptr, "n", 0, EdgeAtom::Direction::OUT, {}, 2, 2, e, "m", storage::v3::View::OLD); auto find_symbol = [this](const std::string &name) { for (const auto &sym : symbol_table.table()) if (sym.second.name() == name) return sym.second; @@ -723,11 +726,11 @@ TEST_F(QueryPlanExpandVariable, NamedPath) { std::vector{find_symbol("n"), e, find_symbol("m")}); std::vector expected_paths; - for (const auto &v : dba.Vertices(storage::View::OLD)) { - if (!*v.HasLabel(storage::View::OLD, labels[0])) continue; - auto maybe_edges1 = v.OutEdges(storage::View::OLD); + for (const auto &v : dba.Vertices(storage::v3::View::OLD)) { + if (!*v.HasLabel(storage::v3::View::OLD, labels[0])) continue; + auto maybe_edges1 = v.OutEdges(storage::v3::View::OLD); for (const auto &e1 : *maybe_edges1) { - auto maybe_edges2 = e1.To().OutEdges(storage::View::OLD); + auto maybe_edges2 = e1.To().OutEdges(storage::v3::View::OLD); for (const auto &e2 : *maybe_edges2) { expected_paths.emplace_back(v, e1, e1.To(), e2, e2.To()); } @@ -748,7 +751,7 @@ TEST_F(QueryPlanExpandVariable, ExpandToSameSymbol) { auto node = NODE("n"); auto symbol = symbol_table.CreateSymbol("n", true); node->identifier_->MapTo(symbol); - auto logical_op = std::make_shared(nullptr, symbol, storage::View::OLD); + auto logical_op = std::make_shared(nullptr, symbol, storage::v3::View::OLD); auto n_from = ScanAllTuple{node, logical_op, symbol}; auto filter_op = std::make_shared( @@ -761,14 +764,14 @@ TEST_F(QueryPlanExpandVariable, ExpandToSameSymbol) { return bound ? LITERAL(static_cast(bound.value())) : nullptr; }; - return GetEdgeListSizes( - std::make_shared(filter_op, symbol, symbol, e, EdgeAtom::Type::DEPTH_FIRST, direction, - std::vector{}, reverse, convert(lower), convert(upper), - /* existing = */ true, - ExpansionLambda{symbol_table.CreateSymbol("inner_edge", false), - symbol_table.CreateSymbol("inner_node", false), nullptr}, - std::nullopt, std::nullopt), - e); + return GetEdgeListSizes(std::make_shared( + filter_op, symbol, symbol, e, EdgeAtom::Type::DEPTH_FIRST, direction, + std::vector{}, reverse, convert(lower), convert(upper), + /* existing = */ true, + ExpansionLambda{symbol_table.CreateSymbol("inner_edge", false), + symbol_table.CreateSymbol("inner_node", false), nullptr}, + std::nullopt, std::nullopt), + e); }; // The graph is a double chain: @@ -940,10 +943,10 @@ class QueryPlanExpandWeightedShortestPath : public QueryPlanMatchFilterTest { }; protected: - storage::Storage::Accessor storage_dba{db.Access()}; + storage::v3::Storage::Accessor storage_dba{db.Access()}; DbAccessor dba{&storage_dba}; - std::pair prop = PROPERTY_PAIR("property1"); - storage::EdgeTypeId edge_type = dba.NameToEdgeType("edge_type"); + std::pair prop = PROPERTY_PAIR("property1"); + storage::v3::EdgeTypeId edge_type = dba.NameToEdgeType("edge_type"); // make 5 vertices because we'll need to compare against them exactly // v[0] has `prop` with the value 0 @@ -966,13 +969,13 @@ class QueryPlanExpandWeightedShortestPath : public QueryPlanMatchFilterTest { void SetUp() { for (int i = 0; i < 5; i++) { - v.push_back(*dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(i)}})); - ASSERT_TRUE(v.back().SetProperty(prop.second, storage::PropertyValue(i)).HasValue()); + v.push_back(*dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i)}})); + ASSERT_TRUE(v.back().SetProperty(prop.second, storage::v3::PropertyValue(i)).HasValue()); } auto add_edge = [&](int from, int to, double weight) { auto edge = dba.InsertEdge(&v[from], &v[to], edge_type); - ASSERT_TRUE(edge->SetProperty(prop.second, storage::PropertyValue(weight)).HasValue()); + ASSERT_TRUE(edge->SetProperty(prop.second, storage::v3::PropertyValue(weight)).HasValue()); e.emplace(std::make_pair(from, to), *edge); }; @@ -1006,7 +1009,7 @@ class QueryPlanExpandWeightedShortestPath : public QueryPlanMatchFilterTest { auto edge_list_sym = symbol_table.CreateSymbol("edgelist_", true); auto filter_lambda = last_op = std::make_shared( last_op, n.sym_, node_sym, edge_list_sym, EdgeAtom::Type::WEIGHTED_SHORTEST_PATH, direction, - std::vector{}, false, nullptr, max_depth ? LITERAL(max_depth.value()) : nullptr, + std::vector{}, false, nullptr, max_depth ? LITERAL(max_depth.value()) : nullptr, existing_node_input != nullptr, ExpansionLambda{filter_edge, filter_node, where}, ExpansionLambda{weight_edge, weight_node, PROPERTY_LOOKUP(ident_e, prop)}, total_weight); @@ -1026,12 +1029,12 @@ class QueryPlanExpandWeightedShortestPath : public QueryPlanMatchFilterTest { template auto GetProp(const TAccessor &accessor) { - return accessor.GetProperty(storage::View::OLD, prop.second)->ValueInt(); + return accessor.GetProperty(storage::v3::View::OLD, prop.second)->ValueInt(); } template auto GetDoubleProp(const TAccessor &accessor) { - return accessor.GetProperty(storage::View::OLD, prop.second)->ValueDouble(); + return accessor.GetProperty(storage::v3::View::OLD, prop.second)->ValueDouble(); } Expression *PropNe(Symbol symbol, int value) { @@ -1195,11 +1198,11 @@ TEST_F(QueryPlanExpandWeightedShortestPath, UpperBound) { EXPECT_EQ(results[2].total_weight, 12); } { - auto new_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); - ASSERT_TRUE(new_vertex.SetProperty(prop.second, storage::PropertyValue(5)).HasValue()); + auto new_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(new_vertex.SetProperty(prop.second, storage::v3::PropertyValue(5)).HasValue()); auto edge = dba.InsertEdge(&v[4], &new_vertex, edge_type); ASSERT_TRUE(edge.HasValue()); - ASSERT_TRUE(edge->SetProperty(prop.second, storage::PropertyValue(2)).HasValue()); + ASSERT_TRUE(edge->SetProperty(prop.second, storage::v3::PropertyValue(2)).HasValue()); dba.AdvanceCommand(); auto results = ExpandWShortest(EdgeAtom::Direction::BOTH, 3, LITERAL(true)); @@ -1219,21 +1222,21 @@ TEST_F(QueryPlanExpandWeightedShortestPath, UpperBound) { } TEST_F(QueryPlanExpandWeightedShortestPath, NonNumericWeight) { - auto new_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); - ASSERT_TRUE(new_vertex.SetProperty(prop.second, storage::PropertyValue(5)).HasValue()); + auto new_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(new_vertex.SetProperty(prop.second, storage::v3::PropertyValue(5)).HasValue()); auto edge = dba.InsertEdge(&v[4], &new_vertex, edge_type); ASSERT_TRUE(edge.HasValue()); - ASSERT_TRUE(edge->SetProperty(prop.second, storage::PropertyValue("not a number")).HasValue()); + ASSERT_TRUE(edge->SetProperty(prop.second, storage::v3::PropertyValue("not a number")).HasValue()); dba.AdvanceCommand(); EXPECT_THROW(ExpandWShortest(EdgeAtom::Direction::BOTH, 1000, LITERAL(true)), QueryRuntimeException); } TEST_F(QueryPlanExpandWeightedShortestPath, NegativeWeight) { - auto new_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); - ASSERT_TRUE(new_vertex.SetProperty(prop.second, storage::PropertyValue(5)).HasValue()); + auto new_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(new_vertex.SetProperty(prop.second, storage::v3::PropertyValue(5)).HasValue()); auto edge = dba.InsertEdge(&v[4], &new_vertex, edge_type); ASSERT_TRUE(edge.HasValue()); - ASSERT_TRUE(edge->SetProperty(prop.second, storage::PropertyValue(-10)).HasValue()); // negative weight + ASSERT_TRUE(edge->SetProperty(prop.second, storage::v3::PropertyValue(-10)).HasValue()); // negative weight dba.AdvanceCommand(); EXPECT_THROW(ExpandWShortest(EdgeAtom::Direction::BOTH, 1000, LITERAL(true)), QueryRuntimeException); } @@ -1252,20 +1255,20 @@ TEST_F(QueryPlanMatchFilterTest, ExpandOptional) { // graph (v2 {p: 2})<-[:T]-(v1 {p: 1})-[:T]->(v3 {p: 2}) auto prop = dba.NameToProperty("p"); auto edge_type = dba.NameToEdgeType("T"); - auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); - ASSERT_TRUE(v1.SetProperty(prop, storage::PropertyValue(1)).HasValue()); - auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); - ASSERT_TRUE(v2.SetProperty(prop, storage::PropertyValue(2)).HasValue()); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + ASSERT_TRUE(v1.SetProperty(prop, storage::v3::PropertyValue(1)).HasValue()); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); + ASSERT_TRUE(v2.SetProperty(prop, storage::v3::PropertyValue(2)).HasValue()); ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); - auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(3)}}); - ASSERT_TRUE(v3.SetProperty(prop, storage::PropertyValue(2)).HasValue()); + auto v3 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(3)}}); + ASSERT_TRUE(v3.SetProperty(prop, storage::v3::PropertyValue(2)).HasValue()); ASSERT_TRUE(dba.InsertEdge(&v1, &v3, edge_type).HasValue()); dba.AdvanceCommand(); // MATCH (n) OPTIONAL MATCH (n)-[r]->(m) auto n = MakeScanAll(storage, symbol_table, "n"); auto r_m = MakeExpand(storage, symbol_table, nullptr, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "m", false, - storage::View::OLD); + storage::v3::View::OLD); auto optional = std::make_shared(n.op_, r_m.op_, std::vector{r_m.edge_sym_, r_m.node_sym_}); // RETURN n, r, m @@ -1280,8 +1283,8 @@ TEST_F(QueryPlanMatchFilterTest, ExpandOptional) { for (auto &row : results) { ASSERT_EQ(row[0].type(), TypedValue::Type::Vertex); auto &va = row[0].ValueVertex(); - auto va_p = *va.GetProperty(storage::View::OLD, prop); - ASSERT_EQ(va_p.type(), storage::PropertyValue::Type::Int); + auto va_p = *va.GetProperty(storage::v3::View::OLD, prop); + ASSERT_EQ(va_p.type(), storage::v3::PropertyValue::Type::Int); if (va_p.ValueInt() == 1) { v1_is_n_count++; EXPECT_EQ(row[1].type(), TypedValue::Type::Edge); @@ -1295,7 +1298,7 @@ TEST_F(QueryPlanMatchFilterTest, ExpandOptional) { } TEST(QueryPlan, OptionalMatchEmptyDB) { - storage::Storage db; + storage::v3::Storage db; auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); @@ -1315,7 +1318,7 @@ TEST(QueryPlan, OptionalMatchEmptyDB) { } TEST(QueryPlan, OptionalMatchEmptyDBExpandFromNode) { - storage::Storage db; + storage::v3::Storage db; auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); AstStorage storage; @@ -1330,7 +1333,7 @@ TEST(QueryPlan, OptionalMatchEmptyDBExpandFromNode) { auto with = MakeProduce(optional, n_ne); // MATCH (n) -[r]-> (m) auto r_m = MakeExpand(storage, symbol_table, with, with_n_sym, "r", EdgeAtom::Direction::OUT, {}, "m", false, - storage::View::OLD); + storage::v3::View::OLD); // RETURN m auto m_ne = NEXPR("m", IDENT("m")->MapTo(r_m.node_sym_))->MapTo(symbol_table.CreateSymbol("m", true)); auto produce = MakeProduce(r_m.op_, m_ne); @@ -1343,13 +1346,13 @@ TEST_F(QueryPlanMatchFilterTest, OptionalMatchThenExpandToMissingNode) { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); // Make a graph with 2 connected, unlabeled nodes. - auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); - auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); auto edge_type = dba.NameToEdgeType("edge_type"); ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); dba.AdvanceCommand(); - EXPECT_EQ(2, CountIterable(dba.Vertices(storage::View::OLD))); - EXPECT_EQ(1, CountEdges(&dba, storage::View::OLD)); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::v3::View::OLD))); + EXPECT_EQ(1, CountEdges(&dba, storage::v3::View::OLD)); AstStorage storage; SymbolTable symbol_table; // OPTIONAL MATCH (n :missing) @@ -1374,7 +1377,7 @@ TEST_F(QueryPlanMatchFilterTest, OptionalMatchThenExpandToMissingNode) { auto node = NODE("n"); node->identifier_->MapTo(with_n_sym); auto expand = std::make_shared(m.op_, m.sym_, with_n_sym, edge_sym, edge_direction, - std::vector{}, true, storage::View::OLD); + std::vector{}, true, storage::v3::View::OLD); // RETURN m auto m_ne = NEXPR("m", IDENT("m")->MapTo(m.sym_))->MapTo(symbol_table.CreateSymbol("m", true)); auto produce = MakeProduce(expand, m_ne); @@ -1389,8 +1392,8 @@ TEST_F(QueryPlanMatchFilterTest, ExpandExistingNode) { // make a graph (v1)->(v2) that // has a recursive edge (v1)->(v1) - auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); - auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); auto edge_type = dba.NameToEdgeType("Edge"); ASSERT_TRUE(dba.InsertEdge(&v1, &v1, edge_type).HasValue()); ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); @@ -1402,10 +1405,10 @@ TEST_F(QueryPlanMatchFilterTest, ExpandExistingNode) { auto test_existing = [&](bool with_existing, int expected_result_count) { auto n = MakeScanAll(storage, symbol_table, "n"); auto r_n = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {}, "n", with_existing, - storage::View::OLD); + storage::v3::View::OLD); if (with_existing) r_n.op_ = std::make_shared(n.op_, n.sym_, n.sym_, r_n.edge_sym_, r_n.edge_->direction_, - std::vector{}, with_existing, storage::View::OLD); + std::vector{}, with_existing, storage::v3::View::OLD); // make a named expression and a produce auto output = NEXPR("n", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true)); @@ -1425,7 +1428,7 @@ TEST_F(QueryPlanMatchFilterTest, ExpandBothCycleEdgeCase) { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); ASSERT_TRUE(dba.InsertEdge(&v, &v, dba.NameToEdgeType("et")).HasValue()); dba.AdvanceCommand(); @@ -1434,7 +1437,7 @@ TEST_F(QueryPlanMatchFilterTest, ExpandBothCycleEdgeCase) { auto n = MakeScanAll(storage, symbol_table, "n"); auto r_ = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::BOTH, {}, "_", false, - storage::View::OLD); + storage::v3::View::OLD); auto context = MakeContext(storage, symbol_table, &dba); EXPECT_EQ(1, PullAll(*r_.op_, &context)); } @@ -1447,11 +1450,11 @@ TEST_F(QueryPlanMatchFilterTest, EdgeFilter) { // where only one edge will qualify // and there are all combinations of // (edge_type yes|no) * (property yes|absent|no) - std::vector edge_types; + std::vector edge_types; for (int j = 0; j < 2; ++j) edge_types.push_back(dba.NameToEdgeType("et" + std::to_string(j))); std::vector vertices; for (int i = 0; i < 7; ++i) { - vertices.push_back(*dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(i)}})); + vertices.push_back(*dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i)}})); } auto prop = PROPERTY_PAIR("property1"); std::vector edges; @@ -1459,10 +1462,10 @@ TEST_F(QueryPlanMatchFilterTest, EdgeFilter) { edges.push_back(*dba.InsertEdge(&vertices[0], &vertices[i + 1], edge_types[i % 2])); switch (i % 3) { case 0: - ASSERT_TRUE(edges.back().SetProperty(prop.second, storage::PropertyValue(42)).HasValue()); + ASSERT_TRUE(edges.back().SetProperty(prop.second, storage::v3::PropertyValue(42)).HasValue()); break; case 1: - ASSERT_TRUE(edges.back().SetProperty(prop.second, storage::PropertyValue(100)).HasValue()); + ASSERT_TRUE(edges.back().SetProperty(prop.second, storage::v3::PropertyValue(100)).HasValue()); break; default: break; @@ -1480,7 +1483,7 @@ TEST_F(QueryPlanMatchFilterTest, EdgeFilter) { auto n = MakeScanAll(storage, symbol_table, "n"); const auto &edge_type = edge_types[0]; auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {edge_type}, "m", false, - storage::View::OLD); + storage::v3::View::OLD); r_m.edge_->edge_types_.push_back(storage.GetEdgeTypeIx(dba.EdgeTypeToName(edge_type))); std::get<0>(r_m.edge_->properties_)[storage.GetPropertyIx(prop.first)] = LITERAL(42); auto *filter_expr = EQ(PROPERTY_LOOKUP(r_m.edge_->identifier_, prop), LITERAL(42)); @@ -1496,7 +1499,7 @@ TEST_F(QueryPlanMatchFilterTest, EdgeFilter) { EXPECT_EQ(1, test_filter()); // test that edge filtering always filters on old state - for (auto &edge : edges) ASSERT_TRUE(edge.SetProperty(prop.second, storage::PropertyValue(42)).HasValue()); + for (auto &edge : edges) ASSERT_TRUE(edge.SetProperty(prop.second, storage::v3::PropertyValue(42)).HasValue()); EXPECT_EQ(1, test_filter()); dba.AdvanceCommand(); EXPECT_EQ(3, test_filter()); @@ -1506,8 +1509,8 @@ TEST_F(QueryPlanMatchFilterTest, EdgeFilterMultipleTypes) { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); - auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); auto type_1 = dba.NameToEdgeType("type_1"); auto type_2 = dba.NameToEdgeType("type_2"); auto type_3 = dba.NameToEdgeType("type_3"); @@ -1522,7 +1525,7 @@ TEST_F(QueryPlanMatchFilterTest, EdgeFilterMultipleTypes) { // make a scan all auto n = MakeScanAll(storage, symbol_table, "n"); auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::OUT, {type_1, type_2}, "m", - false, storage::View::OLD); + false, storage::v3::View::OLD); // make a named expression and a produce auto output = @@ -1540,11 +1543,11 @@ TEST_F(QueryPlanMatchFilterTest, Filter) { // add a 6 nodes with property 'prop', 2 have true as value auto property1 = PROPERTY_PAIR("property1"); for (int i = 0; i < 6; ++i) { - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(i)}}) - ->SetProperty(property1.second, storage::PropertyValue(i % 3 == 0)) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i)}}) + ->SetProperty(property1.second, storage::v3::PropertyValue(i % 3 == 0)) .HasValue()); } - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}) + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}) .HasValue()); // prop not set, gives NULL dba.AdvanceCommand(); @@ -1566,8 +1569,8 @@ TEST_F(QueryPlanMatchFilterTest, EdgeUniquenessFilter) { DbAccessor dba(&storage_dba); // make a graph that has (v1)->(v2) and a recursive edge (v1)->(v1) - auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); - auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + auto v1 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); + auto v2 = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); auto edge_type = dba.NameToEdgeType("edge_type"); ASSERT_TRUE(dba.InsertEdge(&v1, &v2, edge_type).HasValue()); ASSERT_TRUE(dba.InsertEdge(&v1, &v1, edge_type).HasValue()); @@ -1579,10 +1582,10 @@ TEST_F(QueryPlanMatchFilterTest, EdgeUniquenessFilter) { auto n1 = MakeScanAll(storage, symbol_table, "n1"); auto r1_n2 = MakeExpand(storage, symbol_table, n1.op_, n1.sym_, "r1", EdgeAtom::Direction::OUT, {}, "n2", false, - storage::View::OLD); + storage::v3::View::OLD); std::shared_ptr last_op = r1_n2.op_; auto r2_n3 = MakeExpand(storage, symbol_table, last_op, r1_n2.node_sym_, "r2", EdgeAtom::Direction::OUT, {}, "n3", - false, storage::View::OLD); + false, storage::v3::View::OLD); last_op = r2_n3.op_; if (edge_uniqueness) last_op = std::make_shared(last_op, r2_n3.edge_sym_, std::vector{r1_n2.edge_sym_}); @@ -1598,7 +1601,7 @@ TEST(QueryPlan, Distinct) { // test queries like // UNWIND [1, 2, 3, 3] AS x RETURN DISTINCT x - storage::Storage db; + storage::v3::Storage db; auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); AstStorage storage; @@ -1647,12 +1650,12 @@ TEST_F(QueryPlanMatchFilterTest, ScanAllByLabel) { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); // Add a vertex with a label and one without. - auto labeled_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto labeled_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); ASSERT_TRUE(labeled_vertex.AddLabel(label1).HasValue()); - ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}).HasValue()); + ASSERT_TRUE(dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}).HasValue()); dba.AdvanceCommand(); - EXPECT_EQ(2, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::v3::View::OLD))); // MATCH (n :label) AstStorage storage; SymbolTable symbol_table; @@ -1673,26 +1676,26 @@ TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelProperty) { auto label1 = db.NameToLabel("label1"); auto prop = db.NameToProperty("prop"); // vertex property values that will be stored into the DB - std::vector values{ - storage::PropertyValue(true), - storage::PropertyValue(false), - storage::PropertyValue("a"), - storage::PropertyValue("b"), - storage::PropertyValue("c"), - storage::PropertyValue(0), - storage::PropertyValue(1), - storage::PropertyValue(2), - storage::PropertyValue(0.5), - storage::PropertyValue(1.5), - storage::PropertyValue(2.5), - storage::PropertyValue(std::vector{storage::PropertyValue(0)}), - storage::PropertyValue(std::vector{storage::PropertyValue(1)}), - storage::PropertyValue(std::vector{storage::PropertyValue(2)})}; + std::vector values{ + storage::v3::PropertyValue(true), + storage::v3::PropertyValue(false), + storage::v3::PropertyValue("a"), + storage::v3::PropertyValue("b"), + storage::v3::PropertyValue("c"), + storage::v3::PropertyValue(0), + storage::v3::PropertyValue(1), + storage::v3::PropertyValue(2), + storage::v3::PropertyValue(0.5), + storage::v3::PropertyValue(1.5), + storage::v3::PropertyValue(2.5), + storage::v3::PropertyValue(std::vector{storage::v3::PropertyValue(0)}), + storage::v3::PropertyValue(std::vector{storage::v3::PropertyValue(1)}), + storage::v3::PropertyValue(std::vector{storage::v3::PropertyValue(2)})}; { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); for (const auto &value : values) { - auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); ASSERT_TRUE(vertex.AddLabel(label1).HasValue()); ASSERT_TRUE(vertex.SetProperty(prop, value).HasValue()); } @@ -1702,7 +1705,7 @@ TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelProperty) { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - ASSERT_EQ(14, CountIterable(dba.Vertices(storage::View::OLD))); + ASSERT_EQ(14, CountIterable(dba.Vertices(storage::v3::View::OLD))); auto run_scan_all = [&](const TypedValue &lower, Bound::Type lower_type, const TypedValue &upper, Bound::Type upper_type) { @@ -1723,7 +1726,8 @@ TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelProperty) { auto results = run_scan_all(lower, lower_type, upper, upper_type); ASSERT_EQ(results.size(), expected.size()); for (size_t i = 0; i < expected.size(); i++) { - TypedValue equal = TypedValue(*results[i][0].ValueVertex().GetProperty(storage::View::OLD, prop)) == expected[i]; + TypedValue equal = + TypedValue(*results[i][0].ValueVertex().GetProperty(storage::v3::View::OLD, prop)) == expected[i]; ASSERT_EQ(equal.type(), TypedValue::Type::Bool); EXPECT_TRUE(equal.ValueBool()); } @@ -1736,23 +1740,23 @@ TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelProperty) { check(TypedValue(1.5), Bound::Type::EXCLUSIVE, TypedValue(2.5), Bound::Type::INCLUSIVE, {TypedValue(2), TypedValue(2.5)}); - auto are_comparable = [](storage::PropertyValue::Type a, storage::PropertyValue::Type b) { - auto is_numeric = [](const storage::PropertyValue::Type t) { - return t == storage::PropertyValue::Type::Int || t == storage::PropertyValue::Type::Double; + auto are_comparable = [](storage::v3::PropertyValue::Type a, storage::v3::PropertyValue::Type b) { + auto is_numeric = [](const storage::v3::PropertyValue::Type t) { + return t == storage::v3::PropertyValue::Type::Int || t == storage::v3::PropertyValue::Type::Double; }; return a == b || (is_numeric(a) && is_numeric(b)); }; - auto is_orderable = [](const storage::PropertyValue &t) { + auto is_orderable = [](const storage::v3::PropertyValue &t) { return t.IsNull() || t.IsInt() || t.IsDouble() || t.IsString(); }; // when a range contains different types, nothing should get returned for (const auto &value_a : values) { for (const auto &value_b : values) { - if (are_comparable(static_cast(value_a).type(), - static_cast(value_b).type())) + if (are_comparable(static_cast(value_a).type(), + static_cast(value_b).type())) continue; if (is_orderable(value_a) && is_orderable(value_b)) { check(TypedValue(value_a), Bound::Type::INCLUSIVE, TypedValue(value_b), Bound::Type::INCLUSIVE, {}); @@ -1782,19 +1786,19 @@ TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyEqualityNoError) { { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - auto number_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto number_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); ASSERT_TRUE(number_vertex.AddLabel(label1).HasValue()); - ASSERT_TRUE(number_vertex.SetProperty(prop, storage::PropertyValue(42)).HasValue()); - auto string_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + ASSERT_TRUE(number_vertex.SetProperty(prop, storage::v3::PropertyValue(42)).HasValue()); + auto string_vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); ASSERT_TRUE(string_vertex.AddLabel(label1).HasValue()); - ASSERT_TRUE(string_vertex.SetProperty(prop, storage::PropertyValue("string")).HasValue()); + ASSERT_TRUE(string_vertex.SetProperty(prop, storage::v3::PropertyValue("string")).HasValue()); ASSERT_FALSE(dba.Commit().HasError()); } db.CreateIndex(label1, prop); auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - EXPECT_EQ(2, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::v3::View::OLD))); // MATCH (n :label {prop: 42}) AstStorage storage; SymbolTable symbol_table; @@ -1808,7 +1812,7 @@ TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyEqualityNoError) { const auto &row = results[0]; ASSERT_EQ(row.size(), 1); auto vertex = row[0].ValueVertex(); - TypedValue value(*vertex.GetProperty(storage::View::OLD, prop)); + TypedValue value(*vertex.GetProperty(storage::v3::View::OLD, prop)); TypedValue::BoolEqual eq; EXPECT_TRUE(eq(value, TypedValue(42))); } @@ -1820,9 +1824,9 @@ TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyValueError) { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); for (int i = 0; i < 2; ++i) { - auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); ASSERT_TRUE(vertex.AddLabel(label1).HasValue()); - ASSERT_TRUE(vertex.SetProperty(prop, storage::PropertyValue(i)).HasValue()); + ASSERT_TRUE(vertex.SetProperty(prop, storage::v3::PropertyValue(i)).HasValue()); } ASSERT_FALSE(dba.Commit().HasError()); } @@ -1830,7 +1834,7 @@ TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyValueError) { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - EXPECT_EQ(2, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::v3::View::OLD))); // MATCH (m), (n :label1 {prop: m}) AstStorage storage; SymbolTable symbol_table; @@ -1850,9 +1854,9 @@ TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyRangeError) { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); for (int i = 0; i < 2; ++i) { - auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); ASSERT_TRUE(vertex.AddLabel(label1).HasValue()); - ASSERT_TRUE(vertex.SetProperty(prop, storage::PropertyValue(i)).HasValue()); + ASSERT_TRUE(vertex.SetProperty(prop, storage::v3::PropertyValue(i)).HasValue()); } ASSERT_FALSE(dba.Commit().HasError()); } @@ -1860,7 +1864,7 @@ TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyRangeError) { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - EXPECT_EQ(2, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::v3::View::OLD))); // MATCH (m), (n :label1 {prop: m}) AstStorage storage; SymbolTable symbol_table; @@ -1901,18 +1905,18 @@ TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyEqualNull) { { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); ASSERT_TRUE(vertex.AddLabel(label1).HasValue()); - auto vertex_with_prop = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto vertex_with_prop = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); ASSERT_TRUE(vertex_with_prop.AddLabel(label1).HasValue()); - ASSERT_TRUE(vertex_with_prop.SetProperty(prop, storage::PropertyValue(42)).HasValue()); + ASSERT_TRUE(vertex_with_prop.SetProperty(prop, storage::v3::PropertyValue(42)).HasValue()); ASSERT_FALSE(dba.Commit().HasError()); } db.CreateIndex(label1, prop); auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - EXPECT_EQ(2, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::v3::View::OLD))); // MATCH (n :label1 {prop: 42}) AstStorage storage; SymbolTable symbol_table; @@ -1935,18 +1939,18 @@ TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyRangeNull) { { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto vertex = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); ASSERT_TRUE(vertex.AddLabel(label).HasValue()); - auto vertex_with_prop = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(2)}}); + auto vertex_with_prop = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(2)}}); ASSERT_TRUE(vertex_with_prop.AddLabel(label).HasValue()); - ASSERT_TRUE(vertex_with_prop.SetProperty(prop, storage::PropertyValue(42)).HasValue()); + ASSERT_TRUE(vertex_with_prop.SetProperty(prop, storage::v3::PropertyValue(42)).HasValue()); ASSERT_FALSE(dba.Commit().HasError()); } db.CreateIndex(label1, prop); auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - EXPECT_EQ(2, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(2, CountIterable(dba.Vertices(storage::v3::View::OLD))); // MATCH (n :label1) WHERE null <= n.prop < null AstStorage storage; SymbolTable symbol_table; @@ -1967,16 +1971,16 @@ TEST_F(QueryPlanMatchFilterTest, ScanAllByLabelPropertyNoValueInIndexContinuatio { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); ASSERT_TRUE(v.AddLabel(label1).HasValue()); - ASSERT_TRUE(v.SetProperty(prop, storage::PropertyValue(2)).HasValue()); + ASSERT_TRUE(v.SetProperty(prop, storage::v3::PropertyValue(2)).HasValue()); ASSERT_FALSE(dba.Commit().HasError()); } db.CreateIndex(label1, prop); auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - EXPECT_EQ(1, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(1, CountIterable(dba.Vertices(storage::v3::View::OLD))); AstStorage storage; SymbolTable symbol_table; @@ -2006,10 +2010,10 @@ TEST_F(QueryPlanMatchFilterTest, ScanAllEqualsScanAllByLabelProperty) { for (int i = 0; i < vertex_count; ++i) { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::PropertyValue(1)}}); + auto v = *dba.InsertVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(1)}}); ASSERT_TRUE(v.AddLabel(label1).HasValue()); ASSERT_TRUE( - v.SetProperty(prop, storage::PropertyValue(i < vertex_prop_count ? prop_value1 : prop_value2)).HasValue()); + v.SetProperty(prop, storage::v3::PropertyValue(i < vertex_prop_count ? prop_value1 : prop_value2)).HasValue()); ASSERT_FALSE(dba.Commit().HasError()); } @@ -2019,7 +2023,7 @@ TEST_F(QueryPlanMatchFilterTest, ScanAllEqualsScanAllByLabelProperty) { { auto storage_dba = db.Access(); DbAccessor dba(&storage_dba); - EXPECT_EQ(vertex_count, CountIterable(dba.Vertices(storage::View::OLD))); + EXPECT_EQ(vertex_count, CountIterable(dba.Vertices(storage::v3::View::OLD))); } // Make sure there are `vertex_prop_count` results when using index diff --git a/tests/unit/query_v2_query_plan_v2_create_set_remove_delete.cpp b/tests/unit/query_v2_query_plan_v2_create_set_remove_delete.cpp index 6f298e2ad..ed119722f 100644 --- a/tests/unit/query_v2_query_plan_v2_create_set_remove_delete.cpp +++ b/tests/unit/query_v2_query_plan_v2_create_set_remove_delete.cpp @@ -11,23 +11,23 @@ #include -#include "query/frontend/semantic/symbol_table.hpp" -#include "query/plan/operator.hpp" -#include "query_plan_common.hpp" -#include "storage/v2/property_value.hpp" -#include "storage/v2/storage.hpp" +#include "query/v2/frontend/semantic/symbol_table.hpp" +#include "query/v2/plan/operator.hpp" +#include "query_v2_query_plan_common.hpp" +#include "storage/v3/property_value.hpp" +#include "storage/v3/storage.hpp" -namespace memgraph::query::tests { +namespace memgraph::query::v2::tests { class QueryPlanCRUDTest : public testing::Test { protected: void SetUp() override { - ASSERT_TRUE(db.CreateSchema(label, {storage::SchemaProperty{property, common::SchemaType::INT}})); + ASSERT_TRUE(db.CreateSchema(label, {storage::v3::SchemaProperty{property, common::SchemaType::INT}})); } - storage::Storage db; - const storage::LabelId label{db.NameToLabel("label")}; - const storage::PropertyId property{db.NameToProperty("property")}; + storage::v3::Storage db; + const storage::v3::LabelId label{db.NameToLabel("label")}; + const storage::v3::PropertyId property{db.NameToProperty("property")}; }; TEST_F(QueryPlanCRUDTest, CreateNodeWithAttributes) { @@ -39,7 +39,7 @@ TEST_F(QueryPlanCRUDTest, CreateNodeWithAttributes) { plan::NodeCreationInfo node; node.symbol = symbol_table.CreateSymbol("n", true); node.labels.emplace_back(label); - std::get>>(node.properties) + std::get>>(node.properties) .emplace_back(property, ast.Create(42)); plan::CreateNode create_node(nullptr, node); @@ -53,12 +53,12 @@ TEST_F(QueryPlanCRUDTest, CreateNodeWithAttributes) { const auto &node_value = frame[node.symbol]; EXPECT_EQ(node_value.type(), TypedValue::Type::Vertex); const auto &v = node_value.ValueVertex(); - EXPECT_TRUE(*v.HasLabel(storage::View::NEW, label)); - EXPECT_EQ(v.GetProperty(storage::View::NEW, property)->ValueInt(), 42); - EXPECT_EQ(CountIterable(*v.InEdges(storage::View::NEW)), 0); - EXPECT_EQ(CountIterable(*v.OutEdges(storage::View::NEW)), 0); + EXPECT_TRUE(*v.HasLabel(storage::v3::View::NEW, label)); + EXPECT_EQ(v.GetProperty(storage::v3::View::NEW, property)->ValueInt(), 42); + EXPECT_EQ(CountIterable(*v.InEdges(storage::v3::View::NEW)), 0); + EXPECT_EQ(CountIterable(*v.OutEdges(storage::v3::View::NEW)), 0); // Invokes LOG(FATAL) instead of erroring out. - // EXPECT_TRUE(v.HasLabel(label, storage::View::OLD).IsError()); + // EXPECT_TRUE(v.HasLabel(label, storage::v3::View::OLD).IsError()); } EXPECT_EQ(count, 1); } @@ -70,7 +70,7 @@ TEST_F(QueryPlanCRUDTest, ScanAllEmpty) { DbAccessor execution_dba(&dba); auto node_symbol = symbol_table.CreateSymbol("n", true); { - plan::ScanAll scan_all(nullptr, node_symbol, storage::View::OLD); + plan::ScanAll scan_all(nullptr, node_symbol, storage::v3::View::OLD); auto context = MakeContext(ast, symbol_table, &execution_dba); Frame frame(context.symbol_table.max_position()); auto cursor = scan_all.MakeCursor(utils::NewDeleteResource()); @@ -79,7 +79,7 @@ TEST_F(QueryPlanCRUDTest, ScanAllEmpty) { EXPECT_EQ(count, 0); } { - plan::ScanAll scan_all(nullptr, node_symbol, storage::View::NEW); + plan::ScanAll scan_all(nullptr, node_symbol, storage::v3::View::NEW); auto context = MakeContext(ast, symbol_table, &execution_dba); Frame frame(context.symbol_table.max_position()); auto cursor = scan_all.MakeCursor(utils::NewDeleteResource()); @@ -93,8 +93,8 @@ TEST_F(QueryPlanCRUDTest, ScanAll) { { auto dba = db.Access(); for (int i = 0; i < 42; ++i) { - auto v = *dba.CreateVertexAndValidate(label, {}, {{property, storage::PropertyValue(i)}}); - ASSERT_TRUE(v.SetProperty(property, storage::PropertyValue(i)).HasValue()); + auto v = *dba.CreateVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i)}}); + ASSERT_TRUE(v.SetProperty(property, storage::v3::PropertyValue(i)).HasValue()); } EXPECT_FALSE(dba.Commit().HasError()); } @@ -119,13 +119,13 @@ TEST_F(QueryPlanCRUDTest, ScanAllByLabel) { auto dba = db.Access(); // Add some unlabeled vertices for (int i = 0; i < 12; ++i) { - auto v = *dba.CreateVertexAndValidate(label, {}, {{property, storage::PropertyValue(i)}}); - ASSERT_TRUE(v.SetProperty(property, storage::PropertyValue(i)).HasValue()); + auto v = *dba.CreateVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i)}}); + ASSERT_TRUE(v.SetProperty(property, storage::v3::PropertyValue(i)).HasValue()); } // Add labeled vertices for (int i = 0; i < 42; ++i) { - auto v = *dba.CreateVertexAndValidate(label, {}, {{property, storage::PropertyValue(i)}}); - ASSERT_TRUE(v.SetProperty(property, storage::PropertyValue(i)).HasValue()); + auto v = *dba.CreateVertexAndValidate(label, {}, {{property, storage::v3::PropertyValue(i)}}); + ASSERT_TRUE(v.SetProperty(property, storage::v3::PropertyValue(i)).HasValue()); ASSERT_TRUE(v.AddLabel(label2).HasValue()); } EXPECT_FALSE(dba.Commit().HasError()); @@ -143,4 +143,4 @@ TEST_F(QueryPlanCRUDTest, ScanAllByLabel) { while (cursor->Pull(frame, context)) ++count; EXPECT_EQ(count, 42); } -} // namespace memgraph::query::tests +} // namespace memgraph::query::v2::tests diff --git a/tests/unit/query_v2_query_required_privileges.cpp b/tests/unit/query_v2_query_required_privileges.cpp new file mode 100644 index 000000000..c59fc604e --- /dev/null +++ b/tests/unit/query_v2_query_required_privileges.cpp @@ -0,0 +1,222 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include +#include + +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/frontend/ast/ast_visitor.hpp" +#include "query/v2/frontend/semantic/required_privileges.hpp" +#include "storage/v3/id_types.hpp" + +#include "query_v2_query_common.hpp" + +using namespace memgraph::query::v2; + +class FakeDbAccessor {}; + +const std::string EDGE_TYPE = "0"; +const std::string LABEL_0 = "label0"; +const std::string LABEL_1 = "label1"; +const std::string PROP_0 = "prop0"; + +using ::testing::UnorderedElementsAre; + +class TestPrivilegeExtractor : public ::testing::Test { + protected: + AstStorage storage; + FakeDbAccessor dba; +}; + +TEST_F(TestPrivilegeExtractor, CreateNode) { + auto *query = QUERY(SINGLE_QUERY(CREATE(PATTERN(NODE("n"))))); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::CREATE)); +} + +TEST_F(TestPrivilegeExtractor, MatchNodeDelete) { + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), DELETE(IDENT("n")))); + EXPECT_THAT(GetRequiredPrivileges(query), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, AuthQuery::Privilege::DELETE)); +} + +TEST_F(TestPrivilegeExtractor, MatchNodeReturn) { + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN("n"))); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::MATCH)); +} + +TEST_F(TestPrivilegeExtractor, MatchCreateExpand) { + auto *query = + QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), + CREATE(PATTERN(NODE("n"), EDGE("r", EdgeAtom::Direction::OUT, {EDGE_TYPE}), NODE("m"))))); + EXPECT_THAT(GetRequiredPrivileges(query), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, AuthQuery::Privilege::CREATE)); +} + +TEST_F(TestPrivilegeExtractor, MatchNodeSetLabels) { + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), SET("n", {LABEL_0, LABEL_1}))); + EXPECT_THAT(GetRequiredPrivileges(query), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, AuthQuery::Privilege::SET)); +} + +TEST_F(TestPrivilegeExtractor, MatchNodeSetProperty) { + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), + SET(PROPERTY_LOOKUP(storage.Create("n"), PROP_0), LITERAL(42)))); + EXPECT_THAT(GetRequiredPrivileges(query), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, AuthQuery::Privilege::SET)); +} + +TEST_F(TestPrivilegeExtractor, MatchNodeSetProperties) { + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), SET("n", LIST()))); + EXPECT_THAT(GetRequiredPrivileges(query), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, AuthQuery::Privilege::SET)); +} + +TEST_F(TestPrivilegeExtractor, MatchNodeRemoveLabels) { + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), REMOVE("n", {LABEL_0, LABEL_1}))); + EXPECT_THAT(GetRequiredPrivileges(query), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, AuthQuery::Privilege::REMOVE)); +} + +TEST_F(TestPrivilegeExtractor, MatchNodeRemoveProperty) { + auto *query = + QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), REMOVE(PROPERTY_LOOKUP(storage.Create("n"), PROP_0)))); + EXPECT_THAT(GetRequiredPrivileges(query), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, AuthQuery::Privilege::REMOVE)); +} + +TEST_F(TestPrivilegeExtractor, CreateIndex) { + auto *query = CREATE_INDEX_ON(storage.GetLabelIx(LABEL_0), storage.GetPropertyIx(PROP_0)); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::INDEX)); +} + +TEST_F(TestPrivilegeExtractor, AuthQuery) { + auto *query = + AUTH_QUERY(AuthQuery::Action::CREATE_ROLE, "", "role", "", nullptr, std::vector{}); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::AUTH)); +} + +TEST_F(TestPrivilegeExtractor, ShowIndexInfo) { + auto *query = storage.Create(); + query->info_type_ = InfoQuery::InfoType::INDEX; + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::INDEX)); +} + +TEST_F(TestPrivilegeExtractor, ShowStatsInfo) { + auto *query = storage.Create(); + query->info_type_ = InfoQuery::InfoType::STORAGE; + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::STATS)); +} + +TEST_F(TestPrivilegeExtractor, ShowConstraintInfo) { + auto *query = storage.Create(); + query->info_type_ = InfoQuery::InfoType::CONSTRAINT; + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::CONSTRAINT)); +} + +TEST_F(TestPrivilegeExtractor, CreateConstraint) { + auto *query = storage.Create(); + query->action_type_ = ConstraintQuery::ActionType::CREATE; + query->constraint_.label = storage.GetLabelIx("label"); + query->constraint_.properties.push_back(storage.GetPropertyIx("prop0")); + query->constraint_.properties.push_back(storage.GetPropertyIx("prop1")); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::CONSTRAINT)); +} + +TEST_F(TestPrivilegeExtractor, DropConstraint) { + auto *query = storage.Create(); + query->action_type_ = ConstraintQuery::ActionType::DROP; + query->constraint_.label = storage.GetLabelIx("label"); + query->constraint_.properties.push_back(storage.GetPropertyIx("prop0")); + query->constraint_.properties.push_back(storage.GetPropertyIx("prop1")); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::CONSTRAINT)); +} + +// NOLINTNEXTLINE(hicpp-special-member-functions) +TEST_F(TestPrivilegeExtractor, DumpDatabase) { + auto *query = storage.Create(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::DUMP)); +} + +TEST_F(TestPrivilegeExtractor, ReadFile) { + auto load_csv = storage.Create(); + load_csv->row_var_ = IDENT("row"); + auto *query = QUERY(SINGLE_QUERY(load_csv)); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::READ_FILE)); +} + +TEST_F(TestPrivilegeExtractor, LockPathQuery) { + auto *query = storage.Create(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::DURABILITY)); +} + +TEST_F(TestPrivilegeExtractor, FreeMemoryQuery) { + auto *query = storage.Create(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::FREE_MEMORY)); +} + +TEST_F(TestPrivilegeExtractor, TriggerQuery) { + auto *query = storage.Create(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::TRIGGER)); +} + +TEST_F(TestPrivilegeExtractor, SetIsolationLevelQuery) { + auto *query = storage.Create(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::CONFIG)); +} + +TEST_F(TestPrivilegeExtractor, CreateSnapshotQuery) { + auto *query = storage.Create(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::DURABILITY)); +} + +TEST_F(TestPrivilegeExtractor, StreamQuery) { + auto *query = storage.Create(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::STREAM)); +} + +TEST_F(TestPrivilegeExtractor, SettingQuery) { + auto *query = storage.Create(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::CONFIG)); +} + +TEST_F(TestPrivilegeExtractor, ShowVersion) { + auto *query = storage.Create(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::STATS)); +} + +TEST_F(TestPrivilegeExtractor, SchemaQuery) { + auto *query = storage.Create(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::SCHEMA)); +} + +TEST_F(TestPrivilegeExtractor, CallProcedureQuery) { + { + auto *query = QUERY(SINGLE_QUERY(CALL_PROCEDURE("mg.get_module_files"))); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::MODULE_READ)); + } + { + auto *query = QUERY(SINGLE_QUERY(CALL_PROCEDURE("mg.create_module_file", {LITERAL("some_name.py")}))); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::MODULE_WRITE)); + } + { + auto *query = QUERY( + SINGLE_QUERY(CALL_PROCEDURE("mg.update_module_file", {LITERAL("some_name.py"), LITERAL("some content")}))); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::MODULE_WRITE)); + } + { + auto *query = QUERY(SINGLE_QUERY(CALL_PROCEDURE("mg.get_module_file", {LITERAL("some_name.py")}))); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::MODULE_READ)); + } + { + auto *query = QUERY(SINGLE_QUERY(CALL_PROCEDURE("mg.delete_module_file", {LITERAL("some_name.py")}))); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::MODULE_WRITE)); + } +} diff --git a/tests/unit/result_stream_faker.hpp b/tests/unit/result_stream_faker.hpp new file mode 100644 index 000000000..60c5884e7 --- /dev/null +++ b/tests/unit/result_stream_faker.hpp @@ -0,0 +1,132 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include + +#include "glue/v2/communication.hpp" +#include "query/v2/typed_value.hpp" +#include "storage/v3/storage.hpp" +#include "utils/algorithm.hpp" + +/** + * A mocker for the data output record stream. + * This implementation checks that messages are + * sent to it in an acceptable order, and tracks + * the content of those messages. + */ +class ResultStreamFaker { + public: + explicit ResultStreamFaker(memgraph::storage::v3::Storage *store) : store_(store) {} + + ResultStreamFaker(const ResultStreamFaker &) = delete; + ResultStreamFaker &operator=(const ResultStreamFaker &) = delete; + ResultStreamFaker(ResultStreamFaker &&) = default; + ResultStreamFaker &operator=(ResultStreamFaker &&) = default; + + void Header(const std::vector &fields) { header_ = fields; } + + void Result(const std::vector &values) { results_.push_back(values); } + + void Result(const std::vector &values) { + std::vector bvalues; + bvalues.reserve(values.size()); + for (const auto &value : values) { + auto maybe_value = memgraph::glue::v2::ToBoltValue(value, *store_, memgraph::storage::v3::View::NEW); + MG_ASSERT(maybe_value.HasValue()); + bvalues.push_back(std::move(*maybe_value)); + } + results_.push_back(std::move(bvalues)); + } + + void Summary(const std::map &summary) { summary_ = summary; } + + void Summary(const std::map &summary) { + std::map bsummary; + for (const auto &item : summary) { + auto maybe_value = memgraph::glue::v2::ToBoltValue(item.second, *store_, memgraph::storage::v3::View::NEW); + MG_ASSERT(maybe_value.HasValue()); + bsummary.insert({item.first, std::move(*maybe_value)}); + } + summary_ = std::move(bsummary); + } + + const auto &GetHeader() const { return header_; } + + const auto &GetResults() const { return results_; } + + const auto &GetSummary() const { return summary_; } + + friend std::ostream &operator<<(std::ostream &os, const ResultStreamFaker &results) { + auto decoded_value_to_string = [](const auto &value) { + std::stringstream ss; + ss << value; + return ss.str(); + }; + const std::vector &header = results.GetHeader(); + std::vector column_widths(header.size()); + std::transform(header.begin(), header.end(), column_widths.begin(), [](const auto &s) { return s.size(); }); + + // convert all the results into strings, and track max column width + auto &results_data = results.GetResults(); + std::vector> result_strings(results_data.size(), + std::vector(column_widths.size())); + for (int row_ind = 0; row_ind < static_cast(results_data.size()); ++row_ind) { + for (int col_ind = 0; col_ind < static_cast(column_widths.size()); ++col_ind) { + std::string string_val = decoded_value_to_string(results_data[row_ind][col_ind]); + column_widths[col_ind] = std::max(column_widths[col_ind], (int)string_val.size()); + result_strings[row_ind][col_ind] = string_val; + } + } + + // output a results table + // first define some helper functions + auto emit_horizontal_line = [&]() { + os << "+"; + for (auto col_width : column_widths) os << std::string((unsigned long)col_width + 2, '-') << "+"; + os << std::endl; + }; + + auto emit_result_vec = [&](const std::vector result_vec) { + os << "| "; + for (int col_ind = 0; col_ind < static_cast(column_widths.size()); ++col_ind) { + const std::string &res = result_vec[col_ind]; + os << res << std::string(column_widths[col_ind] - res.size(), ' '); + os << " | "; + } + os << std::endl; + }; + + // final output of results + emit_horizontal_line(); + emit_result_vec(results.GetHeader()); + emit_horizontal_line(); + for (const auto &result_vec : result_strings) emit_result_vec(result_vec); + emit_horizontal_line(); + os << "Found " << results_data.size() << " matching results" << std::endl; + + // output the summary + os << "Query summary: {"; + memgraph::utils::PrintIterable(os, results.GetSummary(), ", ", + [&](auto &stream, const auto &kv) { stream << kv.first << ": " << kv.second; }); + os << "}" << std::endl; + + return os; + } + + private: + memgraph::storage::v3::Storage *store_; + // the data that the record stream can accept + std::vector header_; + std::vector> results_; + std::map summary_; +}; diff --git a/tests/unit/storage_v3_schema.cpp b/tests/unit/storage_v3_schema.cpp index bd9a82209..3d090fc2f 100644 --- a/tests/unit/storage_v3_schema.cpp +++ b/tests/unit/storage_v3_schema.cpp @@ -19,23 +19,23 @@ #include #include "common/types.hpp" -#include "storage/v2/id_types.hpp" -#include "storage/v2/property_value.hpp" -#include "storage/v2/schema_validator.hpp" -#include "storage/v2/schemas.hpp" -#include "storage/v2/storage.hpp" -#include "storage/v2/temporal.hpp" +#include "storage/v3/id_types.hpp" +#include "storage/v3/property_value.hpp" +#include "storage/v3/schema_validator.hpp" +#include "storage/v3/schemas.hpp" +#include "storage/v3/storage.hpp" +#include "storage/v3/temporal.hpp" using testing::Pair; using testing::UnorderedElementsAre; using SchemaType = memgraph::common::SchemaType; -namespace memgraph::storage::tests { +namespace memgraph::storage::v3::tests { class SchemaTest : public testing::Test { private: - memgraph::storage::NameIdMapper label_mapper_; - memgraph::storage::NameIdMapper property_mapper_; + NameIdMapper label_mapper_; + NameIdMapper property_mapper_; protected: LabelId NameToLabel(const std::string &name) { return LabelId::FromUint(label_mapper_.NameToId(name)); } @@ -161,8 +161,8 @@ class SchemaValidatorTest : public testing::Test { PropertyId NameToProperty(const std::string &name) { return PropertyId::FromUint(property_mapper_.NameToId(name)); } private: - memgraph::storage::NameIdMapper label_mapper_; - memgraph::storage::NameIdMapper property_mapper_; + NameIdMapper label_mapper_; + NameIdMapper property_mapper_; protected: Schemas schemas; @@ -291,4 +291,4 @@ TEST_F(SchemaValidatorTest, TestSchemaValidatePropertyUpdateLabel) { } EXPECT_EQ(schema_validator.ValidateLabelUpdate(NameToLabel("test")), std::nullopt); } -} // namespace memgraph::storage::tests +} // namespace memgraph::storage::v3::tests From 5012824e053a882ab5e020ba4ef7ff0abc029a7c Mon Sep 17 00:00:00 2001 From: jbajic Date: Thu, 4 Aug 2022 11:45:16 +0200 Subject: [PATCH 5/5] Address review comments --- src/query/v2/metadata.cpp | 4 ---- src/query/v2/metadata.hpp | 2 -- src/query/v2/plan/operator.cpp | 6 ++++-- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/query/v2/metadata.cpp b/src/query/v2/metadata.cpp index 7735edacf..f8a14d4a0 100644 --- a/src/query/v2/metadata.cpp +++ b/src/query/v2/metadata.cpp @@ -72,10 +72,6 @@ constexpr std::string_view GetCodeString(const NotificationCode code) { return "ReplicaPortWarning"sv; case NotificationCode::SET_REPLICA: return "SetReplica"sv; - case NotificationCode::SHOW_SCHEMA: - return "ShowSchema"sv; - case NotificationCode::SHOW_SCHEMAS: - return "ShowSchemas"sv; case NotificationCode::START_STREAM: return "StartStream"sv; case NotificationCode::START_ALL_STREAMS: diff --git a/src/query/v2/metadata.hpp b/src/query/v2/metadata.hpp index 4a5e7b9c9..c5211b1c1 100644 --- a/src/query/v2/metadata.hpp +++ b/src/query/v2/metadata.hpp @@ -44,8 +44,6 @@ enum class NotificationCode : uint8_t { REPLICA_PORT_WARNING, REGISTER_REPLICA, SET_REPLICA, - SHOW_SCHEMA, - SHOW_SCHEMAS, START_STREAM, START_ALL_STREAMS, STOP_STREAM, diff --git a/src/query/v2/plan/operator.cpp b/src/query/v2/plan/operator.cpp index 432a8c7e8..4e3005d75 100644 --- a/src/query/v2/plan/operator.cpp +++ b/src/query/v2/plan/operator.cpp @@ -198,8 +198,10 @@ VertexAccessor &CreateLocalVertexAtomically(const NodeCreationInfo &node_info, F properties.emplace_back(property_id, value); } } - // TODO Remove later on since that will be enforced from grammar side - MG_ASSERT(!node_info.labels.empty(), "There must be at least one label!"); + + if (node_info.labels.empty()) { + throw QueryRuntimeException("Primary label must be defined!"); + } const auto primary_label = node_info.labels[0]; std::vector secondary_labels(node_info.labels.begin() + 1, node_info.labels.end()); auto maybe_new_node = dba.InsertVertexAndValidate(primary_label, secondary_labels, properties);