diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 78ab2eba2..b64d287a6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -47,6 +47,7 @@ set(memgraph_src_files query/common.cpp query/frontend/ast/ast.cpp query/frontend/ast/cypher_main_visitor.cpp + query/frontend/semantic/required_privileges.cpp query/frontend/semantic/symbol_generator.cpp query/frontend/stripped.cpp query/interpret/awesome_memgraph_functions.cpp diff --git a/src/query/frontend/semantic/required_privileges.cpp b/src/query/frontend/semantic/required_privileges.cpp new file mode 100644 index 000000000..d8a375b3f --- /dev/null +++ b/src/query/frontend/semantic/required_privileges.cpp @@ -0,0 +1,103 @@ +#include "query/frontend/ast/ast.hpp" + +namespace query { + +class PrivilegeExtractor : public HierarchicalTreeVisitor { + public: + using HierarchicalTreeVisitor::PostVisit; + using HierarchicalTreeVisitor::PreVisit; + using HierarchicalTreeVisitor::Visit; + + std::vector<AuthQuery::Privilege> privileges() { return privileges_; } + + bool PreVisit(Create &) override { + AddPrivilege(AuthQuery::Privilege::CREATE); + return false; + } + bool PreVisit(Delete &) override { + AddPrivilege(AuthQuery::Privilege::DELETE); + return false; + } + bool PreVisit(Match &) override { + AddPrivilege(AuthQuery::Privilege::MATCH); + return false; + } + bool PreVisit(Merge &) override { + AddPrivilege(AuthQuery::Privilege::MERGE); + return false; + } + bool PreVisit(SetProperty &) override { + AddPrivilege(AuthQuery::Privilege::SET); + return false; + } + bool PreVisit(SetProperties &) override { + AddPrivilege(AuthQuery::Privilege::SET); + return false; + } + bool PreVisit(SetLabels &) override { + AddPrivilege(AuthQuery::Privilege::SET); + return false; + } + bool PreVisit(RemoveProperty &) override { + AddPrivilege(AuthQuery::Privilege::REMOVE); + return false; + } + bool PreVisit(RemoveLabels &) override { + AddPrivilege(AuthQuery::Privilege::REMOVE); + return false; + } + bool Visit(Identifier &) override { return true; } + bool Visit(PrimitiveLiteral &) override { return true; } + bool Visit(ParameterLookup &) override { return true; } + + bool Visit(CreateIndex &) override { + AddPrivilege(AuthQuery::Privilege::INDEX); + return true; + } + bool Visit(AuthQuery &) override { + AddPrivilege(AuthQuery::Privilege::AUTH); + return true; + } + bool Visit(CreateStream &) override { + AddPrivilege(AuthQuery::Privilege::STREAM); + return true; + } + bool Visit(DropStream &) override { + AddPrivilege(AuthQuery::Privilege::STREAM); + return true; + } + bool Visit(ShowStreams &) override { + AddPrivilege(AuthQuery::Privilege::STREAM); + return true; + } + bool Visit(StartStopStream &) override { + AddPrivilege(AuthQuery::Privilege::STREAM); + return true; + } + bool Visit(StartStopAllStreams &) override { + AddPrivilege(AuthQuery::Privilege::STREAM); + return true; + } + bool Visit(TestStream &) override { + AddPrivilege(AuthQuery::Privilege::STREAM); + return true; + } + + private: + void AddPrivilege(AuthQuery::Privilege privilege) { + if (!utils::Contains(privileges_, privilege)) { + privileges_.push_back(privilege); + } + } + + std::vector<AuthQuery::Privilege> privileges_; +}; + +std::vector<AuthQuery::Privilege> GetRequiredPrivileges( + const AstStorage &ast_storage) { + PrivilegeExtractor extractor; + ast_storage.query()->Accept(extractor); + return extractor.privileges(); +} + +} // namespace query diff --git a/src/query/frontend/semantic/required_privileges.hpp b/src/query/frontend/semantic/required_privileges.hpp new file mode 100644 index 000000000..9b74d55c6 --- /dev/null +++ b/src/query/frontend/semantic/required_privileges.hpp @@ -0,0 +1,8 @@ +#pragma once + +#include "query/frontend/ast/ast.hpp" + +namespace query { +std::vector<AuthQuery::Privilege> GetRequiredPrivileges( + const AstStorage &ast_storage); +} diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 586a4c97e..41c414535 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -8,6 +8,7 @@ #include "query/exceptions.hpp" #include "query/frontend/ast/cypher_main_visitor.hpp" #include "query/frontend/opencypher/parser.hpp" +#include "query/frontend/semantic/required_privileges.hpp" #include "query/frontend/semantic/symbol_generator.hpp" #include "query/plan/planner.hpp" #include "query/plan/vertex_count_cache.hpp" @@ -80,6 +81,10 @@ Interpreter::Results Interpreter::operator()( } ctx.parameters_.Add(param_pair.first, param_it->second); } + AstStorage ast_storage = QueryToAst(stripped, ctx); + // TODO: Maybe cache required privileges to improve performance on very simple + // queries. + auto required_privileges = query::GetRequiredPrivileges(ast_storage); auto frontend_time = frontend_timer.Elapsed(); // Try to get a cached plan. Note that this local shared_ptr might be the only @@ -96,8 +101,9 @@ Interpreter::Results Interpreter::operator()( } utils::Timer planning_timer; if (!plan) { - plan = plan_cache_access.insert(stripped.hash(), QueryToPlan(stripped, ctx)) - .first->second; + plan = + plan_cache_access.insert(stripped.hash(), AstToPlan(ast_storage, ctx)) + .first->second; } auto planning_time = planning_timer.Elapsed(); @@ -128,12 +134,11 @@ Interpreter::Results Interpreter::operator()( } return Results(std::move(ctx), plan, std::move(cursor), output_symbols, - header, summary, plan_cache_); + header, summary, plan_cache_, required_privileges); } -std::shared_ptr<Interpreter::CachedPlan> Interpreter::QueryToPlan( - const StrippedQuery &stripped, Context &ctx) { - AstStorage ast_storage = QueryToAst(stripped, ctx); +std::shared_ptr<Interpreter::CachedPlan> Interpreter::AstToPlan( + AstStorage &ast_storage, Context &ctx) { SymbolGenerator symbol_generator(ctx.symbol_table_); ast_storage.query()->Accept(symbol_generator); diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index 474664676..771f44ec6 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -74,7 +74,8 @@ class Interpreter { Results(Context ctx, std::shared_ptr<CachedPlan> plan, std::unique_ptr<query::plan::Cursor> cursor, std::vector<Symbol> output_symbols, std::vector<std::string> header, - std::map<std::string, TypedValue> summary, PlanCacheT &plan_cache) + std::map<std::string, TypedValue> summary, PlanCacheT &plan_cache, + std::vector<AuthQuery::Privilege> privileges) : ctx_(std::move(ctx)), plan_(plan), cursor_(std::move(cursor)), @@ -82,7 +83,8 @@ class Interpreter { output_symbols_(output_symbols), header_(header), summary_(summary), - plan_cache_(plan_cache) {} + plan_cache_(plan_cache), + privileges_(std::move(privileges)) {} public: Results(const Results &) = delete; @@ -137,6 +139,10 @@ class Interpreter { const std::vector<std::string> &header() { return header_; } const std::map<std::string, TypedValue> &summary() { return summary_; } + const std::vector<AuthQuery::Privilege> &privileges() { + return privileges_; + } + private: Context ctx_; std::shared_ptr<CachedPlan> plan_; @@ -150,6 +156,8 @@ class Interpreter { double execution_time_{0}; // Gets invalidated after if an index has been built. PlanCacheT &plan_cache_; + + std::vector<AuthQuery::Privilege> privileges_; }; explicit Interpreter(database::GraphDb &db); @@ -185,9 +193,8 @@ class Interpreter { // Optional, not null only in a distributed master. distributed::PlanDispatcher *plan_dispatcher_{nullptr}; - // stripped query -> CachedPlan - std::shared_ptr<CachedPlan> QueryToPlan(const StrippedQuery &stripped, - Context &ctx); + // high level tree -> CachedPlan + std::shared_ptr<CachedPlan> AstToPlan(AstStorage &ast_storage, Context &ctx); // stripped query -> high level tree AstStorage QueryToAst(const StrippedQuery &stripped, Context &ctx); diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index b84b98de7..4bc7f57c5 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -169,6 +169,9 @@ target_link_libraries(${test_prefix}query_plan_match_filter_return memgraph_lib add_unit_test(query_planner.cpp) target_link_libraries(${test_prefix}query_planner memgraph_lib kvstore_dummy_lib) +add_unit_test(query_required_privileges.cpp) +target_link_libraries(${test_prefix}query_required_privileges memgraph_lib kvstore_dummy_lib) + add_unit_test(query_semantic.cpp) target_link_libraries(${test_prefix}query_semantic memgraph_lib kvstore_dummy_lib) diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index 7de256d5c..7e36033b2 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -576,7 +576,7 @@ auto GetMerge(AstStorage &storage, Pattern *pattern, OnMatch on_match, list, expr) #define AUTH_QUERY(action, user, role, user_or_role, password, privileges) \ storage.Create<query::AuthQuery>((action), (user), (role), (user_or_role), \ - LITERAL(password), (privileges)) + password, (privileges)) #define DROP_USER(usernames) storage.Create<query::DropUser>((usernames)) #define CREATE_STREAM(stream_name, stream_uri, stream_topic, transform_uri, \ batch_interval, batch_size) \ diff --git a/tests/unit/query_required_privileges.cpp b/tests/unit/query_required_privileges.cpp new file mode 100644 index 000000000..52cc0ec27 --- /dev/null +++ b/tests/unit/query_required_privileges.cpp @@ -0,0 +1,128 @@ +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "query/frontend/semantic/required_privileges.hpp" +#include "storage/types.hpp" + +#include "query_common.hpp" + +using namespace query; + +class FakeDbAccessor {}; + +storage::EdgeType EDGE_TYPE(0); +storage::Label LABEL_0(0); +storage::Label LABEL_1(1); +storage::Property PROP_0(0); + +using ::testing::UnorderedElementsAre; + +class TestPrivilegeExtractor : public ::testing::Test { + protected: + AstStorage storage; + FakeDbAccessor dba; +}; + +TEST_F(TestPrivilegeExtractor, CreateNode) { + QUERY(SINGLE_QUERY(CREATE(PATTERN(NODE("n"))))); + EXPECT_THAT(GetRequiredPrivileges(storage), + UnorderedElementsAre(AuthQuery::Privilege::CREATE)); +} + +TEST_F(TestPrivilegeExtractor, MatchNodeDelete) { + QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), DELETE(IDENT("n")))); + EXPECT_THAT(GetRequiredPrivileges(storage), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, + AuthQuery::Privilege::DELETE)); +} + +TEST_F(TestPrivilegeExtractor, MatchNodeReturn) { + QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN("n"))); + EXPECT_THAT(GetRequiredPrivileges(storage), + UnorderedElementsAre(AuthQuery::Privilege::MATCH)); +} + +TEST_F(TestPrivilegeExtractor, MatchCreateExpand) { + QUERY(SINGLE_QUERY( + MATCH(PATTERN(NODE("n"))), + CREATE(PATTERN(NODE("n"), + EDGE("r", EdgeAtom::Direction::OUT, {EDGE_TYPE}), + NODE("m"))))); + EXPECT_THAT(GetRequiredPrivileges(storage), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, + AuthQuery::Privilege::CREATE)); +} + +TEST_F(TestPrivilegeExtractor, MatchNodeSetLabels) { + QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), SET("n", {LABEL_0, LABEL_1}))); + EXPECT_THAT(GetRequiredPrivileges(storage), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, + AuthQuery::Privilege::SET)); +} + +TEST_F(TestPrivilegeExtractor, MatchNodeSetProperty) { + QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), + SET(PROPERTY_LOOKUP("n", {"prop", PROP_0}), LITERAL(42)))); + EXPECT_THAT(GetRequiredPrivileges(storage), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, + AuthQuery::Privilege::SET)); +} + +TEST_F(TestPrivilegeExtractor, MatchNodeSetProperties) { + QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), SET("n", LIST()))); + EXPECT_THAT(GetRequiredPrivileges(storage), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, + AuthQuery::Privilege::SET)); +} + +TEST_F(TestPrivilegeExtractor, MatchNodeRemoveLabels) { + QUERY( + SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), REMOVE("n", {LABEL_0, LABEL_1}))); + EXPECT_THAT(GetRequiredPrivileges(storage), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, + AuthQuery::Privilege::REMOVE)); +} + +TEST_F(TestPrivilegeExtractor, MatchNodeRemoveProperty) { + QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), + REMOVE(PROPERTY_LOOKUP("n", {"prop", PROP_0})))); + EXPECT_THAT(GetRequiredPrivileges(storage), + UnorderedElementsAre(AuthQuery::Privilege::MATCH, + AuthQuery::Privilege::REMOVE)); +} + +TEST_F(TestPrivilegeExtractor, CreateIndex) { + QUERY(SINGLE_QUERY(CREATE_INDEX_ON(LABEL_0, PROP_0))); + EXPECT_THAT(GetRequiredPrivileges(storage), + UnorderedElementsAre(AuthQuery::Privilege::INDEX)); +} + +TEST_F(TestPrivilegeExtractor, AuthQuery) { + QUERY(SINGLE_QUERY(AUTH_QUERY(AuthQuery::Action::CREATE_ROLE, "", "role", "", + nullptr, std::vector<AuthQuery::Privilege>{}))); + EXPECT_THAT(GetRequiredPrivileges(storage), + UnorderedElementsAre(AuthQuery::Privilege::AUTH)); +} + +TEST_F(TestPrivilegeExtractor, StreamQuery) { + std::string stream_name("kafka"); + std::string stream_uri("localhost:1234"); + std::string stream_topic("tropik"); + std::string transform_uri("localhost:1234/file.py"); + + std::vector<Clause *> stream_clauses = { + CREATE_STREAM(stream_name, stream_uri, stream_topic, transform_uri, + nullptr, nullptr), + DROP_STREAM(stream_name), + SHOW_STREAMS, + START_STREAM(stream_name, nullptr), + STOP_STREAM(stream_name), + START_ALL_STREAMS, + STOP_ALL_STREAMS}; + + for (auto *stream_clause : stream_clauses) { + QUERY(SINGLE_QUERY(stream_clause)); + EXPECT_THAT(GetRequiredPrivileges(storage), + UnorderedElementsAre(AuthQuery::Privilege::STREAM)); + } +}