Integrate auth checks into query execution

Reviewers: mtomic, teon.banek

Reviewed By: mtomic

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1544
This commit is contained in:
Matej Ferencevic 2018-08-22 10:59:46 +02:00
parent 0249a280f8
commit 1b643958b6
26 changed files with 689 additions and 72 deletions

View File

@ -8,7 +8,8 @@
### Major Features and Improvements
* Kafka integration
* [Enterprise Ed.] Authentication and authorization support.
* [Enterprise Ed.] Kafka integration.
## v0.12.0

View File

@ -43,7 +43,8 @@ set(memgraph_src_files
durability/recovery.cpp
durability/snapshooter.cpp
durability/wal.cpp
glue/conversion.cpp
glue/auth.cpp
glue/communication.cpp
query/common.cpp
query/frontend/ast/ast.cpp
query/frontend/ast/cypher_main_visitor.cpp

View File

@ -40,6 +40,17 @@ std::string PermissionToString(Permission permission) {
}
}
std::string PermissionLevelToString(PermissionLevel level) {
switch (level) {
case PermissionLevel::GRANT:
return "GRANT";
case PermissionLevel::NEUTRAL:
return "NEUTRAL";
case PermissionLevel::DENY:
return "DENY";
}
}
Permissions::Permissions(uint64_t grants, uint64_t denies) {
// The deny bitmask has higher priority than the grant bitmask.
denies_ = denies;
@ -205,6 +216,7 @@ const Permissions User::GetPermissions() const {
const std::string &User::username() const { return username_; }
const Permissions &User::permissions() const { return permissions_; }
Permissions &User::permissions() { return permissions_; }
std::experimental::optional<Role> User::role() const { return role_; }

View File

@ -37,6 +37,9 @@ enum class PermissionLevel {
DENY,
};
// Function that converts a permission level to its string representation.
std::string PermissionLevelToString(PermissionLevel level);
class Permissions final {
public:
Permissions(uint64_t grants = 0, uint64_t denies = 0);
@ -113,6 +116,7 @@ class User final {
const std::string &username() const;
const Permissions &permissions() const;
Permissions &permissions();
std::experimental::optional<Role> role() const;

View File

@ -4,7 +4,7 @@
#include "communication/bolt/v1/value.hpp"
#include "database/graph_db_accessor.hpp"
#include "glue/conversion.hpp"
#include "glue/communication.hpp"
namespace database {

View File

@ -12,7 +12,7 @@
#include "durability/snapshot_value.hpp"
#include "durability/version.hpp"
#include "durability/wal.hpp"
#include "glue/conversion.hpp"
#include "glue/communication.hpp"
#include "query/typed_value.hpp"
#include "storage/address_types.hpp"
#include "transactions/type.hpp"

View File

@ -2,7 +2,7 @@
#include "communication/bolt/v1/encoder/base_encoder.hpp"
#include "database/graph_db_accessor.hpp"
#include "glue/conversion.hpp"
#include "glue/communication.hpp"
#include "utils/cast.hpp"
namespace durability {

28
src/glue/auth.cpp Normal file
View File

@ -0,0 +1,28 @@
#include "glue/auth.hpp"
namespace glue {
auth::Permission PrivilegeToPermission(query::AuthQuery::Privilege privilege) {
switch (privilege) {
case query::AuthQuery::Privilege::MATCH:
return auth::Permission::MATCH;
case query::AuthQuery::Privilege::CREATE:
return auth::Permission::CREATE;
case query::AuthQuery::Privilege::MERGE:
return auth::Permission::MERGE;
case query::AuthQuery::Privilege::DELETE:
return auth::Permission::DELETE;
case query::AuthQuery::Privilege::SET:
return auth::Permission::SET;
case query::AuthQuery::Privilege::REMOVE:
return auth::Permission::REMOVE;
case query::AuthQuery::Privilege::INDEX:
return auth::Permission::INDEX;
case query::AuthQuery::Privilege::AUTH:
return auth::Permission::AUTH;
case query::AuthQuery::Privilege::STREAM:
return auth::Permission::STREAM;
}
}
}

12
src/glue/auth.hpp Normal file
View File

@ -0,0 +1,12 @@
#include "auth/models.hpp"
#include "query/frontend/ast/ast.hpp"
namespace glue {
/**
* This function converts query::AuthQuery::Privilege to its corresponding
* auth::Permission.
*/
auth::Permission PrivilegeToPermission(query::AuthQuery::Privilege privilege);
} // namespace glue

View File

@ -1,4 +1,4 @@
#include "glue/conversion.hpp"
#include "glue/communication.hpp"
#include <map>
#include <string>

View File

@ -17,7 +17,8 @@
#include "database/distributed_graph_db.hpp"
#include "database/graph_db.hpp"
#include "distributed/pull_rpc_clients.hpp"
#include "glue/conversion.hpp"
#include "glue/auth.hpp"
#include "glue/communication.hpp"
#include "integrations/kafka/exceptions.hpp"
#include "integrations/kafka/streams.hpp"
#include "query/exceptions.hpp"
@ -95,7 +96,20 @@ class BoltSession final
for (const auto &kv : params)
params_tv.emplace(kv.first, glue::ToTypedValue(kv.second));
try {
return transaction_engine_.Interpret(query, params_tv);
auto result = transaction_engine_.Interpret(query, params_tv);
if (user_) {
const auto &permissions = user_->GetPermissions();
for (const auto &privilege : result.second) {
if (permissions.Has(glue::PrivilegeToPermission(privilege)) !=
auth::PermissionLevel::GRANT) {
transaction_engine_.Abort();
throw communication::bolt::ClientError(
"You are not authorized to execute this query! Please contact "
"your database administrator.");
}
}
}
return result.first;
} catch (const query::QueryException &e) {
// Wrap QueryException into ClientError, because we want to allow the
// client to fix their query.

View File

@ -21,7 +21,8 @@
#include "distributed/pull_rpc_clients.hpp"
#include "distributed/updates_rpc_clients.hpp"
#include "distributed/updates_rpc_server.hpp"
#include "glue/conversion.hpp"
#include "glue/auth.hpp"
#include "glue/communication.hpp"
#include "integrations/kafka/exceptions.hpp"
#include "integrations/kafka/streams.hpp"
#include "query/context.hpp"
@ -3895,7 +3896,8 @@ AuthHandler::AuthHandler(AuthQuery::Action action, std::string user,
Expression *password,
std::vector<AuthQuery::Privilege> privileges,
Symbol user_symbol, Symbol role_symbol,
Symbol grants_symbol)
Symbol privilege_symbol, Symbol effective_symbol,
Symbol details_symbol)
: action_(action),
user_(user),
role_(role),
@ -3904,7 +3906,9 @@ AuthHandler::AuthHandler(AuthQuery::Action action, std::string user,
privileges_(privileges),
user_symbol_(user_symbol),
role_symbol_(role_symbol),
grants_symbol_(grants_symbol) {}
privilege_symbol_(privilege_symbol),
effective_symbol_(effective_symbol),
details_symbol_(details_symbol) {}
bool AuthHandler::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
return visitor.Visit(*this);
@ -3921,7 +3925,7 @@ std::vector<Symbol> AuthHandler::OutputSymbols(const SymbolTable &) const {
return {role_symbol_};
case AuthQuery::Action::SHOW_GRANTS:
return {grants_symbol_};
return {privilege_symbol_, effective_symbol_, details_symbol_};
case AuthQuery::Action::CREATE_USER:
case AuthQuery::Action::DROP_USER:
@ -3944,53 +3948,59 @@ class AuthHandlerCursor : public Cursor {
std::vector<auth::Permission> GetAuthPermissions() {
std::vector<auth::Permission> ret;
for (const auto &privilege : self_.privileges()) {
switch (privilege) {
case AuthQuery::Privilege::MATCH:
ret.push_back(auth::Permission::MATCH);
break;
case AuthQuery::Privilege::CREATE:
ret.push_back(auth::Permission::CREATE);
break;
case AuthQuery::Privilege::MERGE:
ret.push_back(auth::Permission::MERGE);
break;
case AuthQuery::Privilege::DELETE:
ret.push_back(auth::Permission::DELETE);
break;
case AuthQuery::Privilege::SET:
ret.push_back(auth::Permission::SET);
break;
case AuthQuery::Privilege::REMOVE:
ret.push_back(auth::Permission::REMOVE);
break;
case AuthQuery::Privilege::INDEX:
ret.push_back(auth::Permission::INDEX);
break;
case AuthQuery::Privilege::AUTH:
ret.push_back(auth::Permission::AUTH);
break;
case AuthQuery::Privilege::STREAM:
ret.push_back(auth::Permission::STREAM);
break;
ret.push_back(glue::PrivilegeToPermission(privilege));
}
return ret;
}
std::vector<std::tuple<std::string, std::string, std::string>>
GetGrantsForAuthUser(const auth::User &user) {
std::vector<std::tuple<std::string, std::string, std::string>> ret;
const auto &permissions = user.GetPermissions();
for (const auto &privilege : kPrivilegesAll) {
auto permission = glue::PrivilegeToPermission(privilege);
auto effective = permissions.Has(permission);
if (permissions.Has(permission) != auth::PermissionLevel::NEUTRAL) {
std::vector<std::string> description;
auto user_level = user.permissions().Has(permission);
if (user_level == auth::PermissionLevel::GRANT) {
description.push_back("GRANTED TO USER");
} else if (user_level == auth::PermissionLevel::DENY) {
description.push_back("DENIED TO USER");
}
if (user.role()) {
auto role_level = user.role()->permissions().Has(permission);
if (role_level == auth::PermissionLevel::GRANT) {
description.push_back("GRANTED TO ROLE");
} else if (role_level == auth::PermissionLevel::DENY) {
description.push_back("DENIED TO ROLE");
}
}
ret.push_back({auth::PermissionToString(permission),
auth::PermissionLevelToString(effective),
utils::Join(description, ", ")});
}
}
return ret;
}
std::vector<std::string> GetGrantsFromAuthPermissions(
auth::Permissions &permissions) {
std::vector<std::string> grants, denies, ret;
for (const auto &permission : permissions.GetGrants()) {
grants.push_back(auth::PermissionToString(permission));
}
for (const auto &permission : permissions.GetDenies()) {
denies.push_back(auth::PermissionToString(permission));
}
if (grants.size() > 0) {
ret.push_back(fmt::format("GRANT {}", utils::Join(grants, ", ")));
}
if (denies.size() > 0) {
ret.push_back(fmt::format("DENY {}", utils::Join(denies, ", ")));
std::vector<std::tuple<std::string, std::string, std::string>>
GetGrantsForAuthRole(const auth::Role &role) {
std::vector<std::tuple<std::string, std::string, std::string>> ret;
const auto &permissions = role.permissions();
for (const auto &privilege : kPrivilegesAll) {
auto permission = glue::PrivilegeToPermission(privilege);
auto effective = permissions.Has(permission);
if (effective != auth::PermissionLevel::NEUTRAL) {
std::string description;
if (effective == auth::PermissionLevel::GRANT) {
description = "GRANTED TO ROLE";
} else if (effective == auth::PermissionLevel::DENY) {
description = "DENIED TO ROLE";
}
ret.push_back({auth::PermissionToString(permission),
auth::PermissionLevelToString(effective), description});
}
}
return ret;
}
@ -4195,16 +4205,18 @@ class AuthHandlerCursor : public Cursor {
self_.user_or_role());
}
if (user) {
grants_.emplace(GetGrantsFromAuthPermissions(user->permissions()));
grants_.emplace(GetGrantsForAuthUser(*user));
} else {
grants_.emplace(GetGrantsFromAuthPermissions(role->permissions()));
grants_.emplace(GetGrantsForAuthRole(*role));
}
grants_it_ = grants_->begin();
}
if (grants_it_ == grants_->end()) return false;
frame[self_.grants_symbol()] = *grants_it_;
frame[self_.privilege_symbol()] = std::get<0>(*grants_it_);
frame[self_.effective_symbol()] = std::get<1>(*grants_it_);
frame[self_.details_symbol()] = std::get<2>(*grants_it_);
grants_it_++;
return true;
@ -4258,8 +4270,11 @@ class AuthHandlerCursor : public Cursor {
std::vector<auth::User>::iterator users_it_;
std::experimental::optional<std::vector<auth::Role>> roles_;
std::vector<auth::Role>::iterator roles_it_;
std::experimental::optional<std::vector<std::string>> grants_;
std::vector<std::string>::iterator grants_it_;
std::experimental::optional<
std::vector<std::tuple<std::string, std::string, std::string>>>
grants_;
std::vector<std::tuple<std::string, std::string, std::string>>::iterator
grants_it_;
bool returned_role_for_user_{false};
};

View File

@ -2071,13 +2071,17 @@ and returns true, once.")
cpp<#))
(user-symbol "Symbol" :reader t)
(role-symbol "Symbol" :reader t)
(grants-symbol "Symbol" :reader t))
(privilege-symbol "Symbol" :reader t)
(effective-symbol "Symbol" :reader t)
(details-symbol "Symbol" :reader t))
(:public
#>cpp
AuthHandler(AuthQuery::Action action, std::string user, std::string role,
std::string user_or_role, Expression * password,
std::vector<AuthQuery::Privilege> privileges,
Symbol user_symbol, Symbol role_symbol, Symbol grants_symbol);
Symbol user_symbol, Symbol role_symbol,
Symbol privilege_symbol, Symbol effective_symbol,
Symbol details_symbol);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(database::GraphDbAccessor & db)

View File

@ -192,7 +192,9 @@ class RuleBasedPlanner {
auth_query->privileges_,
symbol_table.CreateSymbol("user", false),
symbol_table.CreateSymbol("role", false),
symbol_table.CreateSymbol("grants", false));
symbol_table.CreateSymbol("privilege", false),
symbol_table.CreateSymbol("effective", false),
symbol_table.CreateSymbol("details", false));
} else if (auto *create_stream =
dynamic_cast<query::CreateStream *>(clause)) {
DCHECK(!input_op) << "Unexpected operator before CreateStream";

View File

@ -17,9 +17,9 @@ class TransactionEngine final {
~TransactionEngine() { Abort(); }
std::vector<std::string> Interpret(
const std::string &query,
const std::map<std::string, TypedValue> &params) {
std::pair<std::vector<std::string>, std::vector<query::AuthQuery::Privilege>>
Interpret(const std::string &query,
const std::map<std::string, TypedValue> &params) {
// Clear pending results.
results_ = std::experimental::nullopt;
@ -59,14 +59,13 @@ class TransactionEngine final {
if (in_explicit_transaction_ && db_accessor_) AdvanceCommand();
// Create a DB accessor if we don't yet have one.
if (!db_accessor_)
db_accessor_ = db_.Access();
if (!db_accessor_) db_accessor_ = db_.Access();
// Interpret the query and return the headers.
try {
results_.emplace(
interpreter_(query, *db_accessor_, params, in_explicit_transaction_));
return results_->header();
return {results_->header(), results_->privileges()};
} catch (const utils::BasicException &) {
AbortCommand();
throw;

View File

@ -5,7 +5,7 @@
#include "communication/bolt/v1/decoder/decoder.hpp"
#include "communication/bolt/v1/encoder/base_encoder.hpp"
#include "glue/conversion.hpp"
#include "glue/communication.hpp"
#include "storage/pod_buffer.hpp"
#include "storage/property_value_store.hpp"

View File

@ -9,3 +9,6 @@ add_subdirectory(transactions)
# kafka test binaries
add_subdirectory(kafka)
# auth test binaries
add_subdirectory(auth)

View File

@ -33,3 +33,12 @@
- ../../../build_debug/kafka.py # kafka script
- ../../../build_debug/tests/integration/kafka/tester # tester binary
enable_network: true
- name: integration__auth
cd: auth
commands: TIMEOUT=600 ./runner.py
infiles:
- runner.py # runner script
- ../../../build_debug/memgraph # memgraph binary
- ../../../build_debug/tests/integration/auth/checker # checker binary
- ../../../build_debug/tests/integration/auth/tester # tester binary

View File

@ -0,0 +1,11 @@
set(target_name memgraph__integration__auth)
set(checker_target_name ${target_name}__checker)
set(tester_target_name ${target_name}__tester)
add_executable(${checker_target_name} checker.cpp)
set_target_properties(${checker_target_name} PROPERTIES OUTPUT_NAME checker)
target_link_libraries(${checker_target_name} mg-communication)
add_executable(${tester_target_name} tester.cpp)
set_target_properties(${tester_target_name} PROPERTIES OUTPUT_NAME tester)
target_link_libraries(${tester_target_name} mg-communication)

View File

@ -0,0 +1,63 @@
#include <gflags/gflags.h>
#include <glog/logging.h>
#include "communication/bolt/client.hpp"
#include "io/network/endpoint.hpp"
#include "io/network/utils.hpp"
DEFINE_string(address, "127.0.0.1", "Server address");
DEFINE_int32(port, 7687, "Server port");
DEFINE_string(username, "admin", "Username for the database");
DEFINE_string(password, "admin", "Password for the database");
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
/**
* Verifies that user 'user' has privileges that are given as positional
* arguments.
*/
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
io::network::Endpoint endpoint(io::network::ResolveHostname(FLAGS_address),
FLAGS_port);
communication::ClientContext context(FLAGS_use_ssl);
communication::bolt::Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
LOG(FATAL) << "Couldn't connect to server " << FLAGS_address << ":"
<< FLAGS_port;
}
try {
auto ret = client.Execute("SHOW GRANTS FOR user", {});
const auto &records = ret.records;
uint64_t count_got = 0;
for (const auto &record : records) {
count_got += record.size();
}
if (count_got != argc - 1) {
LOG(FATAL) << "Expected the grants to have " << argc - 1
<< " entries but they had " << count_got << " entries!";
}
uint64_t pos = 1;
for (const auto &record : records) {
for (const auto &value : record) {
std::string expected(argv[pos++]);
if (value.ValueString() != expected) {
LOG(FATAL) << "Expected to get the value '" << expected
<< " but got the value '" << value.ValueString() << "'";
}
}
}
} catch (const communication::bolt::ClientQueryException &e) {
LOG(FATAL) << "The query shoudn't have failed but it failed with an "
"error message '"
<< e.what() << "'";
}
return 0;
}

360
tests/integration/auth/runner.py Executable file
View File

@ -0,0 +1,360 @@
#!/usr/bin/python3 -u
import argparse
import atexit
import os
import subprocess
import sys
import tempfile
import time
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
# When you create a new permission just add a testcase to this list (a tuple
# of query, touple of required permissions) and the test will automatically
# detect the new permission (from the query required permissions) and test all
# queries against all combinations of permissions.
QUERIES = [
# CREATE
(
"CREATE (n)",
("CREATE",)
),
(
"MATCH (n), (m) CREATE (n)-[:e]->(m)",
("CREATE", "MATCH")
),
# DELETE
(
"MATCH (n) DELETE n",
("DELETE", "MATCH"),
),
(
"MATCH (n) DETACH DELETE n",
("DELETE", "MATCH"),
),
# MATCH
(
"MATCH (n) RETURN n",
("MATCH",)
),
(
"MATCH (n), (m) RETURN count(n), count(m)",
("MATCH",)
),
# MERGE
(
"MERGE (n) ON CREATE SET n.created = timestamp() "
"ON MATCH SET n.lastSeen = timestamp() "
"RETURN n.name, n.created, n.lastSeen",
("MERGE",)
),
# SET
(
"MATCH (n) SET n.value = 0 RETURN n",
("SET", "MATCH")
),
(
"MATCH (n), (m) SET n.value = m.value",
("SET", "MATCH")
),
# REMOVE
(
"MATCH (n) REMOVE n.value",
("REMOVE", "MATCH")
),
(
"MATCH (n), (m) REMOVE n.value, m.value",
("REMOVE", "MATCH")
),
# INDEX
(
"CREATE INDEX ON :User (id)",
("INDEX",)
),
# AUTH
(
"CREATE ROLE test_role",
("AUTH",)
),
(
"DROP ROLE test_role",
("AUTH",)
),
(
"SHOW ROLES",
("AUTH",)
),
(
"CREATE USER test_user",
("AUTH",)
),
(
"SET PASSWORD FOR test_user TO '1234'",
("AUTH",)
),
(
"DROP USER test_user",
("AUTH",)
),
(
"SHOW USERS",
("AUTH",)
),
(
"GRANT ROLE test_role TO test_user",
("AUTH",)
),
(
"REVOKE ROLE test_role FROM test_user",
("AUTH",)
),
(
"GRANT ALL PRIVILEGES TO test_user",
("AUTH",)
),
(
"DENY ALL PRIVILEGES TO test_user",
("AUTH",)
),
(
"REVOKE ALL PRIVILEGES FROM test_user",
("AUTH",)
),
(
"SHOW GRANTS FOR test_user",
("AUTH",)
),
(
"SHOW ROLE FOR USER test_user",
("AUTH",)
),
(
"SHOW USERS FOR ROLE test_role",
("AUTH",)
),
# STREAM
(
"CREATE STREAM strim AS LOAD DATA KAFKA '127.0.0.1:9092' WITH TOPIC "
"'test' WITH TRANSFORM 'http://127.0.0.1/transform.py'",
("STREAM",)
),
(
"DROP STREAM strim",
("STREAM",)
),
(
"SHOW STREAMS",
("STREAM",)
),
(
"START STREAM strim",
("STREAM",)
),
(
"STOP STREAM strim",
("STREAM",)
),
(
"START ALL STREAMS",
("STREAM",)
),
(
"STOP ALL STREAMS",
("STREAM",)
),
(
"TEST STREAM strim",
("STREAM",)
),
]
UNAUTHORIZED_ERROR = "You are not authorized to execute this query! Please " \
"contact your database administrator."
def wait_for_server(port, delay=0.1):
cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
while subprocess.call(cmd) != 0:
time.sleep(0.01)
time.sleep(delay)
def execute_tester(binary, queries, should_fail=False, failure_message="",
username="", password="", check_failure=True):
args = [binary, "--username", username, "--password", password]
if should_fail:
args.append("--should-fail")
if failure_message:
args.extend(["--failure-message", failure_message])
if check_failure:
args.append("--check-failure")
args.extend(queries)
subprocess.run(args).check_returncode()
def execute_checker(binary, grants):
args = [binary] + grants
subprocess.run(args).check_returncode()
def get_permissions(permissions, mask):
ret, pos = [], 0
while mask > 0:
if mask & 1:
ret.append(permissions[pos])
mask >>= 1
pos += 1
return ret
def check_permissions(query_perms, user_perms):
return set(query_perms).issubset(user_perms)
def execute_test(memgraph_binary, tester_binary, checker_binary):
storage_directory = tempfile.TemporaryDirectory()
memgraph_args = [memgraph_binary,
"--durability-directory", storage_directory.name]
def execute_admin_queries(queries):
return execute_tester(tester_binary, queries, should_fail=False,
check_failure=True, username="admin",
password="admin")
def execute_user_queries(queries, should_fail=False, failure_message="",
check_failure=True):
return execute_tester(tester_binary, queries, should_fail,
failure_message, "user", "user", check_failure)
# Start the memgraph binary
memgraph = subprocess.Popen(list(map(str, memgraph_args)))
time.sleep(0.1)
assert memgraph.poll() is None, "Memgraph process died prematurely!"
wait_for_server(7687)
# Register cleanup function
@atexit.register
def cleanup():
if memgraph.poll() is None:
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
# Prepare all users
execute_admin_queries([
"CREATE USER admin IDENTIFIED BY 'admin'",
"GRANT ALL PRIVILEGES TO admin",
"CREATE USER user IDENTIFIED BY 'user'"
])
# Find all existing permissions
permissions = set()
for query, perms in QUERIES:
permissions.update(perms)
permissions = list(sorted(permissions))
# Run the test with all combinations of permissions
print("\033[1;36m~~ Starting query test ~~\033[0m")
for mask in range(0, 2 ** len(permissions)):
user_perms = get_permissions(permissions, mask)
print("\033[1;34m~~ Checking queries with privileges: ",
", ".join(user_perms), " ~~\033[0m")
admin_queries = ["REVOKE ALL PRIVILEGES FROM user"]
if len(user_perms) > 0:
admin_queries.append(
"GRANT {} TO user".format(", ".join(user_perms)))
execute_admin_queries(admin_queries)
authorized, unauthorized = [], []
for query, query_perms in QUERIES:
if check_permissions(query_perms, user_perms):
authorized.append(query)
else:
unauthorized.append(query)
execute_user_queries(authorized, check_failure=False,
failure_message=UNAUTHORIZED_ERROR)
execute_user_queries(unauthorized, should_fail=True,
failure_message=UNAUTHORIZED_ERROR)
print("\033[1;36m~~ Finished query test ~~\033[0m\n")
# Run the user/role permissions test
print("\033[1;36m~~ Starting permissions test ~~\033[0m")
execute_admin_queries([
"CREATE ROLE role",
"REVOKE ALL PRIVILEGES FROM user",
])
execute_checker(checker_binary, [])
for user_perm in ["GRANT", "DENY", "REVOKE"]:
for role_perm in ["GRANT", "DENY", "REVOKE"]:
for mapped in [True, False]:
print("\033[1;34m~~ Checking permissions with user ",
user_perm, ", role ", role_perm,
"user mapped to role:", mapped, " ~~\033[0m")
if mapped:
execute_admin_queries(["GRANT ROLE role TO user"])
else:
execute_admin_queries(["REVOKE ROLE role FROM user"])
user_prep = "FROM" if user_perm == "REVOKE" else "TO"
role_prep = "FROM" if role_perm == "REVOKE" else "TO"
execute_admin_queries([
"{} MATCH {} user".format(user_perm, user_prep),
"{} MATCH {} role".format(role_perm, role_prep)
])
expected = []
perms = [user_perm, role_perm] if mapped else [user_perm]
if "DENY" in perms:
expected = ["MATCH", "DENY"]
elif "GRANT" in perms:
expected = ["MATCH", "GRANT"]
if len(expected) > 0:
details = []
if user_perm == "GRANT":
details.append("GRANTED TO USER")
elif user_perm == "DENY":
details.append("DENIED TO USER")
if mapped:
if role_perm == "GRANT":
details.append("GRANTED TO ROLE")
elif role_perm == "DENY":
details.append("DENIED TO ROLE")
expected.append(", ".join(details))
execute_checker(checker_binary, expected)
print("\033[1;36m~~ Finished permissions test ~~\033[0m\n")
# Shutdown the memgraph binary
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
if __name__ == "__main__":
memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph")
if not os.path.exists(memgraph_binary):
memgraph_binary = os.path.join(PROJECT_DIR, "build_debug", "memgraph")
tester_binary = os.path.join(PROJECT_DIR, "build", "tests",
"integration", "auth", "tester")
if not os.path.exists(tester_binary):
tester_binary = os.path.join(PROJECT_DIR, "build_debug", "tests",
"integration", "auth", "tester")
checker_binary = os.path.join(PROJECT_DIR, "build", "tests",
"integration", "auth", "checker")
if not os.path.exists(checker_binary):
checker_binary = os.path.join(PROJECT_DIR, "build_debug", "tests",
"integration", "auth", "checker")
parser = argparse.ArgumentParser()
parser.add_argument("--memgraph", default=memgraph_binary)
parser.add_argument("--tester", default=tester_binary)
parser.add_argument("--checker", default=checker_binary)
args = parser.parse_args()
execute_test(args.memgraph, args.tester, args.checker)
sys.exit(0)

View File

@ -0,0 +1,79 @@
#include <gflags/gflags.h>
#include <glog/logging.h>
#include "communication/bolt/client.hpp"
#include "io/network/endpoint.hpp"
#include "io/network/utils.hpp"
DEFINE_string(address, "127.0.0.1", "Server address");
DEFINE_int32(port, 7687, "Server port");
DEFINE_string(username, "", "Username for the database");
DEFINE_string(password, "", "Password for the database");
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
DEFINE_bool(check_failure, false, "Set to true to enable failure checking.");
DEFINE_bool(should_fail, false, "Set to true to expect a failure.");
DEFINE_string(failure_message, "", "Set to the expected failure message.");
/**
* Executes queries passed as positional arguments and verifies whether they
* succeeded, failed, failed with a specific error message or executed without a
* specific error occurring.
*/
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
io::network::Endpoint endpoint(io::network::ResolveHostname(FLAGS_address),
FLAGS_port);
communication::ClientContext context(FLAGS_use_ssl);
communication::bolt::Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
LOG(FATAL) << "Couldn't connect to server " << FLAGS_address << ":"
<< FLAGS_port;
}
for (int i = 1; i < argc; ++i) {
std::string query(argv[i]);
try {
client.Execute(query, {});
} catch (const communication::bolt::ClientQueryException &e) {
if (!FLAGS_check_failure) {
if (!FLAGS_failure_message.empty() &&
e.what() == FLAGS_failure_message) {
LOG(FATAL)
<< "The query should have succeeded or failed with an error "
"message that isn't equal to '"
<< FLAGS_failure_message
<< "' but it failed with that error message";
}
continue;
}
if (FLAGS_should_fail) {
if (!FLAGS_failure_message.empty() &&
e.what() != FLAGS_failure_message) {
LOG(FATAL)
<< "The query should have failed with an error message of '"
<< FLAGS_failure_message << "' but instead it failed with '"
<< e.what() << "'";
}
return 0;
} else {
LOG(FATAL) << "The query shoudn't have failed but it failed with an "
"error message '"
<< e.what() << "'";
}
}
if (!FLAGS_check_failure) continue;
if (FLAGS_should_fail) {
LOG(FATAL) << "The query should have failed but instead it executed "
"successfully!";
}
}
return 0;
}

View File

@ -6,7 +6,7 @@
#include "communication/result_stream_faker.hpp"
#include "database/distributed_graph_db.hpp"
#include "database/graph_db_accessor.hpp"
#include "glue/conversion.hpp"
#include "glue/communication.hpp"
#include "query/interpreter.hpp"
#include "query/typed_value.hpp"

View File

@ -7,7 +7,7 @@
#include "durability/hashed_file_writer.hpp"
#include "durability/paths.hpp"
#include "durability/version.hpp"
#include "glue/conversion.hpp"
#include "glue/communication.hpp"
#include "query/typed_value.hpp"
#include "utils/file.hpp"

View File

@ -4,7 +4,7 @@
#include "communication/bolt/v1/encoder/encoder.hpp"
#include "database/graph_db.hpp"
#include "database/graph_db_accessor.hpp"
#include "glue/conversion.hpp"
#include "glue/communication.hpp"
using communication::bolt::Value;