diff --git a/release/arch-pkg/PKGBUILD.proto b/release/arch-pkg/PKGBUILD.proto
index d5c10eefd..6ea748740 100644
--- a/release/arch-pkg/PKGBUILD.proto
+++ b/release/arch-pkg/PKGBUILD.proto
@@ -16,7 +16,7 @@ optdepends=()
 provides=()
 conflicts=()
 replaces=()
-backup=("etc/memgraph/memgraph.conf" "etc/logrotate.d/memgraph")
+backup=("etc/memgraph/memgraph.conf" "etc/logrotate.d/memgraph" "etc/logrotate.d/memgraph_audit")
 options=()
 install=memgraph.install
 changelog=
diff --git a/release/debian/conffiles b/release/debian/conffiles
index e20b1a1c4..27ad25f2c 100644
--- a/release/debian/conffiles
+++ b/release/debian/conffiles
@@ -1,2 +1,3 @@
 /etc/memgraph/memgraph.conf
 /etc/logrotate.d/memgraph
+/etc/logrotate.d/memgraph_audit
diff --git a/release/logrotate_audit.conf b/release/logrotate_audit.conf
new file mode 100644
index 000000000..aa0126c83
--- /dev/null
+++ b/release/logrotate_audit.conf
@@ -0,0 +1,13 @@
+# logrotate configuration for Memgraph Audit logs
+# see "man logrotate" for details
+
+/var/lib/memgraph/durability/audit/audit.log {
+  # rotate log files daily
+  daily
+  # keep one year's worth of audit logs
+  rotate 365
+  # send SIGUSR2 to notify memgraph to recreate the logfile
+  postrotate
+    /usr/bin/killall -s SIGUSR2 memgraph
+  endscript
+}
diff --git a/release/rpm/memgraph.spec.in b/release/rpm/memgraph.spec.in
index 50b9ee153..7cdbf9636 100644
--- a/release/rpm/memgraph.spec.in
+++ b/release/rpm/memgraph.spec.in
@@ -120,6 +120,7 @@ chattr -i -R /usr/share/memgraph/examples || true
 # uses plain %config.
 %config(noreplace) "/etc/memgraph/memgraph.conf"
 %config(noreplace) "/etc/logrotate.d/memgraph"
+%config(noreplace) "/etc/logrotate.d/memgraph_audit"
 
 @CPACK_RPM_USER_INSTALL_FILES@
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index b8a32b69c..cef7d41b2 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -16,6 +16,7 @@ add_subdirectory(rpc)
 # Memgraph Single Node
 # ----------------------------------------------------------------------------
 set(mg_single_node_sources
+    audit/log.cpp
     data_structures/concurrent/skiplist_gc.cpp
     database/single_node/config.cpp
     database/single_node/graph_db.cpp
@@ -95,6 +96,7 @@ target_compile_definitions(mg-single-node PUBLIC MG_SINGLE_NODE)
 # ----------------------------------------------------------------------------
 set(mg_distributed_sources
+    audit/log.cpp
     database/distributed/distributed_counters.cpp
     database/distributed/distributed_graph_db.cpp
     distributed/bfs_rpc_clients.cpp
@@ -256,6 +258,7 @@ target_compile_definitions(mg-distributed PUBLIC MG_DISTRIBUTED)
 # Memgraph Single Node High Availability
 # ----------------------------------------------------------------------------
 set(mg_single_node_ha_sources
+    audit/log.cpp
     data_structures/concurrent/skiplist_gc.cpp
     database/single_node_ha/config.cpp
     database/single_node_ha/graph_db.cpp
@@ -404,6 +407,8 @@ install(FILES ${CMAKE_SOURCE_DIR}/config/community.conf
 # Install logrotate configuration (must use absolute path).
 install(FILES ${CMAKE_SOURCE_DIR}/release/logrotate.conf
         DESTINATION /etc/logrotate.d RENAME memgraph)
+install(FILES ${CMAKE_SOURCE_DIR}/release/logrotate_audit.conf
+        DESTINATION /etc/logrotate.d RENAME memgraph_audit)
 
 # Create empty directories for default location of lib and log.
 install(CODE "file(MAKE_DIRECTORY \$ENV{DESTDIR}/var/log/memgraph \$ENV{DESTDIR}/var/lib/memgraph)")
diff --git a/src/audit/log.cpp b/src/audit/log.cpp
new file mode 100644
index 000000000..2c9eb475c
--- /dev/null
+++ b/src/audit/log.cpp
@@ -0,0 +1,112 @@
+#include "audit/log.hpp"
+
+#include <chrono>
+
+#include <fmt/format.h>
+#include <glog/logging.h>
+#include <json/json.hpp>
+
+#include "utils/string.hpp"
+
+namespace audit {
+
+// Helper function that converts a `PropertyValue` to `nlohmann::json`.
+inline nlohmann::json PropertyValueToJson(const PropertyValue &pv) {
+  nlohmann::json ret;
+  switch (pv.type()) {
+    case PropertyValue::Type::Null:
+      break;
+    case PropertyValue::Type::Bool:
+      ret = pv.Value<bool>();
+      break;
+    case PropertyValue::Type::Int:
+      ret = pv.Value<int64_t>();
+      break;
+    case PropertyValue::Type::Double:
+      ret = pv.Value<double>();
+      break;
+    case PropertyValue::Type::String:
+      ret = pv.Value<std::string>();
+      break;
+    case PropertyValue::Type::List: {
+      ret = nlohmann::json::array();
+      for (const auto &item : pv.Value<std::vector<PropertyValue>>()) {
+        ret.push_back(PropertyValueToJson(item));
+      }
+      break;
+    }
+    case PropertyValue::Type::Map: {
+      ret = nlohmann::json::object();
+      for (const auto &item :
+           pv.Value<std::map<std::string, PropertyValue>>()) {
+        ret.push_back(nlohmann::json::object_t::value_type(
+            item.first, PropertyValueToJson(item.second)));
+      }
+      break;
+    }
+  }
+  return ret;
+}
+
+Log::Log(const std::experimental::filesystem::path &storage_directory,
+         int32_t buffer_size, int32_t buffer_flush_interval_millis)
+    : storage_directory_(storage_directory),
+      buffer_size_(buffer_size),
+      buffer_flush_interval_millis_(buffer_flush_interval_millis),
+      started_(false) {}
+
+void Log::Start() {
+  CHECK(!started_) << "Trying to start an already started audit log!";
+
+  utils::EnsureDirOrDie(storage_directory_);
+
+  buffer_.emplace(buffer_size_);
+  started_ = true;
+
+  ReopenLog();
+  scheduler_.Run("Audit",
+                 std::chrono::milliseconds(buffer_flush_interval_millis_),
+                 [&] { Flush(); });
+}
+
+Log::~Log() {
+  if (!started_) return;
+
+  started_ = false;
+  std::this_thread::sleep_for(std::chrono::milliseconds(1));
+
+  scheduler_.Stop();
+  Flush();
+}
+
+void Log::Record(const std::string &address, const std::string &username,
+                 const std::string &query, const PropertyValue &params) {
+  if (!started_.load(std::memory_order_relaxed)) return;
+  auto timestamp = std::chrono::duration_cast<std::chrono::microseconds>(
+                       std::chrono::system_clock::now().time_since_epoch())
+                       .count();
+  buffer_->emplace(Item{timestamp, address, username, query, params});
+}
+
+void Log::ReopenLog() {
+  if (!started_.load(std::memory_order_relaxed)) return;
+  std::lock_guard<std::mutex> guard(lock_);
+  if (log_.IsOpen()) log_.Close();
+  log_.Open(storage_directory_ / "audit.log");
+}
+
+void Log::Flush() {
+  std::lock_guard<std::mutex> guard(lock_);
+  for (uint64_t i = 0; i < buffer_size_; ++i) {
+    auto item = buffer_->pop();
+    if (!item) break;
+    log_.Write(
+        fmt::format("{}.{:06d},{},{},{},{}\n", item->timestamp / 1000000,
+                    item->timestamp % 1000000, item->address, item->username,
+                    utils::Escape(item->query),
+                    utils::Escape(PropertyValueToJson(item->params).dump())));
+  }
+  log_.Sync();
+}
+
+}  // namespace audit
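
A note on the record format produced by `Log::Flush` above: each buffered item becomes one CSV line, with the microsecond timestamp split into a `seconds.microseconds` pair, and the query plus JSON-encoded parameters escaped before writing. A self-contained sketch of the same layout (plain `printf` standing in for `fmt::format`, sample values standing in for real buffer items):

```cpp
#include <cinttypes>
#include <cstdio>

int main() {
  // Hypothetical sample entry; real values come from the ring buffer.
  int64_t timestamp_us = 1549532102123456;  // microseconds since epoch
  const char *address = "127.0.0.1";
  const char *username = "admin";
  // Query and params are pre-escaped (quoted, backslash-escaped), the way
  // utils::Escape is assumed to behave.
  const char *query = "\"MATCH (n {name: $name}) RETURN n\"";
  const char *params = "\"{\\\"name\\\":\\\"nandare\\\"}\"";

  // Mirrors fmt::format("{}.{:06d},{},{},{},{}\n", ...) in Log::Flush.
  std::printf("%" PRId64 ".%06" PRId64 ",%s,%s,%s,%s\n",
              timestamp_us / 1000000, timestamp_us % 1000000,
              address, username, query, params);
  return 0;
}
```

The resulting line is exactly what the integration runner further below parses back with Python's csv module.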
diff --git a/src/audit/log.hpp b/src/audit/log.hpp
new file mode 100644
index 000000000..21d7044e0
--- /dev/null
+++ b/src/audit/log.hpp
@@ -0,0 +1,67 @@
+#pragma once
+
+#include <atomic>
+#include <experimental/filesystem>
+#include <experimental/optional>
+
+#include "data_structures/ring_buffer.hpp"
+#include "storage/common/types/property_value.hpp"
+#include "utils/file.hpp"
+#include "utils/scheduler.hpp"
+
+namespace audit {
+
+const uint64_t kBufferSizeDefault = 100000;
+const uint64_t kBufferFlushIntervalMillisDefault = 200;
+
+/// This class implements an audit log. Functions used for logging are
+/// thread-safe, functions used for setup aren't thread-safe.
+class Log {
+ private:
+  struct Item {
+    int64_t timestamp;
+    std::string address;
+    std::string username;
+    std::string query;
+    PropertyValue params;
+  };
+
+ public:
+  Log(const std::experimental::filesystem::path &storage_directory,
+      int32_t buffer_size, int32_t buffer_flush_interval_millis);
+
+  ~Log();
+
+  Log(const Log &) = delete;
+  Log(Log &&) = delete;
+  Log &operator=(const Log &) = delete;
+  Log &operator=(Log &&) = delete;
+
+  /// Starts the audit log. If you don't want to use the audit log just don't
+  /// start it. All functions can still be used when the log isn't started and
+  /// they won't do anything. Isn't thread-safe.
+  void Start();
+
+  /// Adds an entry to the audit log. Thread-safe.
+  void Record(const std::string &address, const std::string &username,
+              const std::string &query, const PropertyValue &params);
+
+  /// Reopens the log file. Used for log file rotation. Thread-safe.
+  void ReopenLog();
+
+ private:
+  void Flush();
+
+  std::experimental::filesystem::path storage_directory_;
+  int32_t buffer_size_;
+  int32_t buffer_flush_interval_millis_;
+  std::atomic<bool> started_;
+
+  std::experimental::optional<RingBuffer<Item>> buffer_;
+  utils::Scheduler scheduler_;
+
+  utils::LogFile log_;
+  std::mutex lock_;
+};
+
+}  // namespace audit
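
Taken together, the header implies a simple lifecycle: construct, optionally `Start()`, call `Record()` concurrently from sessions, and let the destructor drain the buffer. A usage sketch against the declared API (the directory is made up, and this only builds inside the memgraph tree):

```cpp
#include <map>
#include <string>

#include "audit/log.hpp"

int main() {
  // Constants are the library defaults; the daemons take them from gflags.
  audit::Log log("/tmp/audit-demo", audit::kBufferSizeDefault,
                 audit::kBufferFlushIntervalMillisDefault);
  log.Start();  // without Start(), Record() is deliberately a no-op

  // Thread-safe; BoltSession calls this once per interpreted query.
  log.Record("127.0.0.1", "admin", "MATCH (n) RETURN n",
             std::map<std::string, PropertyValue>{});

  // On SIGUSR2 the daemons call log.ReopenLog() for logrotate.
  return 0;  // ~Log() stops the scheduler and flushes pending items
}
```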
diff --git a/src/memgraph.cpp b/src/memgraph.cpp
index e8c4c2d39..7ec5a5bd4 100644
--- a/src/memgraph.cpp
+++ b/src/memgraph.cpp
@@ -39,18 +39,62 @@ DEFINE_bool(telemetry_enabled, false,
             "the database runtime (vertex and edge counts and resource usage) "
             "to allow for easier improvement of the product.");
 
+// Audit logging flags.
+DEFINE_bool(audit_enabled, false, "Set to true to enable audit logging.");
+DEFINE_VALIDATED_int32(audit_buffer_size, audit::kBufferSizeDefault,
+                       "Maximum number of items in the audit log buffer.",
+                       FLAG_IN_RANGE(1, INT32_MAX));
+DEFINE_VALIDATED_int32(
+    audit_buffer_flush_interval_ms, audit::kBufferFlushIntervalMillisDefault,
+    "Interval (in milliseconds) used for flushing the audit log buffer.",
+    FLAG_IN_RANGE(10, INT32_MAX));
+
 using ServerT = communication::Server<BoltSession, SessionData>;
 using communication::ServerContext;
 
 void SingleNodeMain() {
   google::SetUsageMessage("Memgraph single-node database server");
+
+  // All enterprise features should be constructed before the main database
+  // storage. This will cause them to be destructed *after* the main database
+  // storage. That way any errors that happen during enterprise features
+  // destruction won't have an impact on the storage engine.
+  // Example: When the main storage is destructed it makes a snapshot. When
+  // audit logging is destructed it syncs all pending data to disk and that
+  // can fail. That is why it must be destructed *after* the main database
+  // storage to minimise the impact of their failure on the main storage.
+
+  // Begin enterprise features initialization
+
+  auto durability_directory =
+      std::experimental::filesystem::path(FLAGS_durability_directory);
+
+  // Auth
+  auth::Auth auth{durability_directory / "auth"};
+
+  // Audit log
+  audit::Log audit_log{durability_directory / "audit", FLAGS_audit_buffer_size,
+                       FLAGS_audit_buffer_flush_interval_ms};
+  // Start the log if enabled.
+  if (FLAGS_audit_enabled) {
+    audit_log.Start();
+  }
+  // Set up SIGUSR2 to be used for reopening audit log files, when e.g.
+  // logrotate rotates our audit logs.
+  CHECK(utils::SignalHandler::RegisterHandler(
+      utils::Signal::User2, [&audit_log]() { audit_log.ReopenLog(); }))
+      << "Unable to register SIGUSR2 handler!";
+
+  // End enterprise features initialization
+
+  // Main storage and execution engines initialization
+
   database::GraphDb db;
   query::Interpreter interpreter;
-  SessionData session_data{&db, &interpreter};
+  SessionData session_data{&db, &interpreter, &auth, &audit_log};
 
   integrations::kafka::Streams kafka_streams{
-      std::experimental::filesystem::path(FLAGS_durability_directory) /
-          "streams",
+      durability_directory / "streams",
       [&session_data](
           const std::string &query,
           const std::map<std::string, communication::bolt::Value> &params) {
@@ -64,7 +108,7 @@ void SingleNodeMain() {
     LOG(ERROR) << e.what();
   }
 
-  session_data.interpreter->auth_ = &session_data.auth;
+  session_data.interpreter->auth_ = &auth;
   session_data.interpreter->kafka_streams_ = &kafka_streams;
 
   ServerContext context;
@@ -83,9 +127,7 @@ void SingleNodeMain() {
   if (FLAGS_telemetry_enabled) {
     telemetry.emplace(
         "https://telemetry.memgraph.com/88b5e7e8-746a-11e8-9f85-538a9e9690cc/",
-        std::experimental::filesystem::path(FLAGS_durability_directory) /
-            "telemetry",
-        std::chrono::minutes(10));
+        durability_directory / "telemetry", std::chrono::minutes(10));
     telemetry->AddCollector("db", [&db]() -> nlohmann::json {
       auto dba = db.Access();
       return {{"vertices", dba->VerticesCount()}, {"edges", dba->EdgesCount()}};
@@ -104,6 +146,4 @@ void SingleNodeMain() {
   server.AwaitShutdown();
 }
 
-int main(int argc, char **argv) {
-  return WithInit(argc, argv, SingleNodeMain);
-}
+int main(int argc, char **argv) { return WithInit(argc, argv, SingleNodeMain); }
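
The SIGUSR2 handler is the other half of the logrotate contract from `release/logrotate_audit.conf`: logrotate moves `audit.log` aside, sends the signal, and the process reopens the old path to create a fresh file. A minimal self-contained illustration of that reopen-on-signal pattern with raw POSIX calls (`utils::SignalHandler` wraps the same mechanism; the file name here is hypothetical):

```cpp
#include <signal.h>
#include <unistd.h>

#include <cstdio>

static volatile sig_atomic_t reopen_requested = 0;

// Only set a flag in the handler; do the reopening in normal context.
static void OnUsr2(int) { reopen_requested = 1; }

int main() {
  struct sigaction sa {};
  sa.sa_handler = OnUsr2;
  sigemptyset(&sa.sa_mask);
  sigaction(SIGUSR2, &sa, nullptr);

  std::FILE *log = std::fopen("demo.log", "a");
  if (!log) return 1;
  for (;;) {
    if (reopen_requested) {
      reopen_requested = 0;
      // logrotate has renamed demo.log; reopening creates a new file.
      std::fclose(log);
      log = std::fopen("demo.log", "a");
      if (!log) return 1;
    }
    std::fputs("tick\n", log);
    std::fflush(log);
    sleep(1);
  }
}
```

Running `kill -USR2 <pid>` against this sketch (or letting logrotate's postrotate do it, as the config above does) makes subsequent writes land in a freshly created file.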
diff --git a/src/memgraph_distributed.cpp b/src/memgraph_distributed.cpp
index c182c9d98..eaecec94c 100644
--- a/src/memgraph_distributed.cpp
+++ b/src/memgraph_distributed.cpp
@@ -40,6 +40,16 @@ DEFINE_bool(telemetry_enabled, false,
             "the database runtime (vertex and edge counts and resource usage) "
             "to allow for easier improvement of the product.");
 
+// Audit logging flags.
+DEFINE_bool(audit_enabled, false, "Set to true to enable audit logging.");
+DEFINE_VALIDATED_int32(audit_buffer_size, audit::kBufferSizeDefault,
+                       "Maximum number of items in the audit log buffer.",
+                       FLAG_IN_RANGE(1, INT32_MAX));
+DEFINE_VALIDATED_int32(
+    audit_buffer_flush_interval_ms, audit::kBufferFlushIntervalMillisDefault,
+    "Interval (in milliseconds) used for flushing the audit log buffer.",
+    FLAG_IN_RANGE(10, INT32_MAX));
+
 using ServerT = communication::Server<BoltSession, SessionData>;
 using communication::ServerContext;
 
@@ -55,13 +65,26 @@ DECLARE_int32(worker_id);
 
 void MasterMain() {
   google::SetUsageMessage("Memgraph distributed master");
 
+  auto durability_directory =
+      std::experimental::filesystem::path(FLAGS_durability_directory);
+
+  auth::Auth auth{durability_directory / "auth"};
+
+  audit::Log audit_log{durability_directory / "audit", FLAGS_audit_buffer_size,
+                       FLAGS_audit_buffer_flush_interval_ms};
+  if (FLAGS_audit_enabled) {
+    audit_log.Start();
+  }
+  CHECK(utils::SignalHandler::RegisterHandler(
+      utils::Signal::User2, [&audit_log]() { audit_log.ReopenLog(); }))
+      << "Unable to register SIGUSR2 handler!";
+
   database::Master db;
   query::DistributedInterpreter interpreter(&db);
-  SessionData session_data{&db, &interpreter};
+  SessionData session_data{&db, &interpreter, &auth, &audit_log};
 
   integrations::kafka::Streams kafka_streams{
-      std::experimental::filesystem::path(FLAGS_durability_directory) /
-          "streams",
+      durability_directory / "streams",
       [&session_data](
          const std::string &query,
          const std::map<std::string, communication::bolt::Value> &params) {
@@ -75,7 +98,7 @@ void MasterMain() {
     LOG(ERROR) << e.what();
   }
 
-  session_data.interpreter->auth_ = &session_data.auth;
+  session_data.interpreter->auth_ = &auth;
   session_data.interpreter->kafka_streams_ = &kafka_streams;
 
   ServerContext context;
diff --git a/src/memgraph_ha.cpp b/src/memgraph_ha.cpp
index 14064ab53..8ee4d109e 100644
--- a/src/memgraph_ha.cpp
+++ b/src/memgraph_ha.cpp
@@ -32,18 +32,42 @@ DEFINE_VALIDATED_int32(session_inactivity_timeout, 1800,
 DEFINE_string(cert_file, "", "Certificate file to use.");
 DEFINE_string(key_file, "", "Key file to use.");
 
+// Audit logging flags.
+DEFINE_bool(audit_enabled, false, "Set to true to enable audit logging.");
+DEFINE_VALIDATED_int32(audit_buffer_size, audit::kBufferSizeDefault,
+                       "Maximum number of items in the audit log buffer.",
+                       FLAG_IN_RANGE(1, INT32_MAX));
+DEFINE_VALIDATED_int32(
+    audit_buffer_flush_interval_ms, audit::kBufferFlushIntervalMillisDefault,
+    "Interval (in milliseconds) used for flushing the audit log buffer.",
+    FLAG_IN_RANGE(10, INT32_MAX));
+
 using ServerT = communication::Server<BoltSession, SessionData>;
 using communication::ServerContext;
 
 void SingleNodeHAMain() {
   google::SetUsageMessage("Memgraph high availability single-node database server");
+
+  auto durability_directory =
+      std::experimental::filesystem::path(FLAGS_durability_directory);
+
+  auth::Auth auth{durability_directory / "auth"};
+
+  audit::Log audit_log{durability_directory / "audit", FLAGS_audit_buffer_size,
+                       FLAGS_audit_buffer_flush_interval_ms};
+  if (FLAGS_audit_enabled) {
+    audit_log.Start();
+  }
+  CHECK(utils::SignalHandler::RegisterHandler(
+      utils::Signal::User2, [&audit_log]() { audit_log.ReopenLog(); }))
+      << "Unable to register SIGUSR2 handler!";
+
   database::GraphDb db;
   query::Interpreter interpreter;
-  SessionData session_data{&db, &interpreter};
+  SessionData session_data{&db, &interpreter, &auth, &audit_log};
 
   integrations::kafka::Streams kafka_streams{
-      std::experimental::filesystem::path(FLAGS_durability_directory) /
-          "streams",
+      durability_directory / "streams",
       [&session_data](
           const std::string &query,
           const std::map<std::string, communication::bolt::Value> &params) {
@@ -57,7 +81,7 @@ void SingleNodeHAMain() {
     LOG(ERROR) << e.what();
   }
 
-  session_data.interpreter->auth_ = &session_data.auth;
+  session_data.interpreter->auth_ = &auth;
   session_data.interpreter->kafka_streams_ = &kafka_streams;
 
   ServerContext context;
diff --git a/src/memgraph_init.cpp b/src/memgraph_init.cpp
index e75aecd62..9a7b807af 100644
--- a/src/memgraph_init.cpp
+++ b/src/memgraph_init.cpp
@@ -21,14 +21,17 @@ DEFINE_uint64(memory_warning_threshold, 1024,
               "less available RAM it will log a warning. Set to 0 to "
               "disable.");
 
-BoltSession::BoltSession(SessionData *data, const io::network::Endpoint &,
+BoltSession::BoltSession(SessionData *data,
+                         const io::network::Endpoint &endpoint,
                          communication::InputStream *input_stream,
                          communication::OutputStream *output_stream)
     : communication::bolt::Session<communication::InputStream,
                                    communication::OutputStream>(input_stream,
                                                                 output_stream),
       transaction_engine_(data->db, data->interpreter),
-      auth_(&data->auth) {}
+      auth_(data->auth),
+      audit_log_(data->audit_log),
+      endpoint_(endpoint) {}
 
 using TEncoder = communication::bolt::Session<
     communication::InputStream, communication::OutputStream>::TEncoder;
 
 std::vector<std::string> BoltSession::Interpret(
     const std::string &query,
     const std::map<std::string, communication::bolt::Value> &params) {
   std::map<std::string, PropertyValue> params_pv;
   for (const auto &kv : params)
     params_pv.emplace(kv.first, glue::ToPropertyValue(kv.second));
+  audit_log_->Record(endpoint_.address(), user_ ? user_->username() : "",
+                     query, params_pv);
   try {
     auto result = transaction_engine_.Interpret(query, params_pv);
     if (user_) {
diff --git a/src/memgraph_init.hpp b/src/memgraph_init.hpp
index 36fed15ec..373092726 100644
--- a/src/memgraph_init.hpp
+++ b/src/memgraph_init.hpp
@@ -9,6 +9,7 @@ #include <gflags/gflags.h>
 
+#include "audit/log.hpp"
 #include "auth/auth.hpp"
 #include "communication/bolt/v1/session.hpp"
 #include "communication/init.hpp"
 
@@ -21,10 +22,18 @@ DECLARE_string(durability_directory);
 
 /// Encapsulates Dbms and Interpreter that are passed through the network server
 /// and worker to the session.
 struct SessionData {
-  database::GraphDb *db{nullptr};
-  query::Interpreter *interpreter{nullptr};
-  auth::Auth auth{
-      std::experimental::filesystem::path(FLAGS_durability_directory) / "auth"};
+  // Explicit constructor here to ensure that pointers to all objects are
+  // supplied.
+  SessionData(database::GraphDb *_db, query::Interpreter *_interpreter,
+              auth::Auth *_auth, audit::Log *_audit_log)
+      : db(_db),
+        interpreter(_interpreter),
+        auth(_auth),
+        audit_log(_audit_log) {}
+
+  database::GraphDb *db;
+  query::Interpreter *interpreter;
+  auth::Auth *auth;
+  audit::Log *audit_log;
 };
 
 class BoltSession final
@@ -66,6 +75,8 @@ class BoltSession final
   query::TransactionEngine transaction_engine_;
   auth::Auth *auth_;
   std::experimental::optional<auth::User> user_;
+  audit::Log *audit_log_;
+  io::network::Endpoint endpoint_;
 };
 
 /// Class that implements ResultStream API for Kafka.
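
The switch of `SessionData` from an aggregate to an explicit constructor deserves the comment it carries: with the old aggregate form, a call site that forgot the new pointers would still compile and leave them null. A minimal sketch of the difference (stand-in types, not the real ones):

```cpp
#include <cassert>

struct Db {};        // stand-in for database::GraphDb
struct AuditLog {};  // stand-in for audit::Log

// Aggregate: trailing members may be silently omitted at the call site.
struct SessionDataAggregate {
  Db *db{nullptr};
  AuditLog *audit_log{nullptr};
};

// Explicit constructor: every pointer must be supplied.
struct SessionDataExplicit {
  SessionDataExplicit(Db *_db, AuditLog *_audit_log)
      : db(_db), audit_log(_audit_log) {}
  Db *db;
  AuditLog *audit_log;
};

int main() {
  Db db;
  AuditLog log;
  SessionDataAggregate oops{&db};  // compiles; audit_log stays nullptr
  assert(oops.audit_log == nullptr);
  // SessionDataExplicit oops2{&db};  // would not compile
  SessionDataExplicit ok{&db, &log};
  (void)ok;
  return 0;
}
```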
diff --git a/src/utils/signals.hpp b/src/utils/signals.hpp
index 38f9a0ca9..2e9ac65d3 100644
--- a/src/utils/signals.hpp
+++ b/src/utils/signals.hpp
@@ -21,6 +21,7 @@ enum class Signal : int {
   Pipe = SIGPIPE,
   BusError = SIGBUS,
   User1 = SIGUSR1,
+  User2 = SIGUSR2,
 };
 
 /**
diff --git a/tests/feature_benchmark/kafka/benchmark.cpp b/tests/feature_benchmark/kafka/benchmark.cpp
index 8e206369f..a4c5e3fa3 100644
--- a/tests/feature_benchmark/kafka/benchmark.cpp
+++ b/tests/feature_benchmark/kafka/benchmark.cpp
@@ -29,9 +29,19 @@ DEFINE_string(output_file, "", "Output file where shold the results be.");
 
 void KafkaBenchmarkMain() {
   google::SetUsageMessage("Memgraph kafka benchmark database server");
+
+  auto durability_directory =
+      std::experimental::filesystem::path(FLAGS_durability_directory);
+
+  auth::Auth auth{durability_directory / "auth"};
+
+  audit::Log audit_log{durability_directory / "audit",
+                       audit::kBufferSizeDefault,
+                       audit::kBufferFlushIntervalMillisDefault};
+
   query::Interpreter interpreter;
   database::GraphDb db;
-  SessionData session_data{&db, &interpreter};
+  SessionData session_data{&db, &interpreter, &auth, &audit_log};
 
   std::atomic<int64_t> query_counter{0};
   std::atomic<bool> timeout_reached{false};
@@ -47,7 +57,7 @@ void KafkaBenchmarkMain() {
       query_counter++;
     }};
 
-  session_data.interpreter->auth_ = &session_data.auth;
+  session_data.interpreter->auth_ = &auth;
   session_data.interpreter->kafka_streams_ = &kafka_streams;
 
   std::string stream_name = "benchmark";
diff --git a/tests/integration/CMakeLists.txt b/tests/integration/CMakeLists.txt
index 8cd148f4e..0bb188d74 100644
--- a/tests/integration/CMakeLists.txt
+++ b/tests/integration/CMakeLists.txt
@@ -19,3 +19,5 @@ add_subdirectory(distributed)
 
 # distributed ha/basic binaries
 add_subdirectory(ha/basic)
+# audit test binaries
+add_subdirectory(audit)
diff --git a/tests/integration/apollo_runs.yaml b/tests/integration/apollo_runs.yaml
index e36d69664..eb28191b3 100644
--- a/tests/integration/apollo_runs.yaml
+++ b/tests/integration/apollo_runs.yaml
@@ -43,6 +43,14 @@
     - ../../../build_debug/tests/integration/auth/checker # checker binary
     - ../../../build_debug/tests/integration/auth/tester # tester binary
 
+- name: integration__audit
+  cd: audit
+  commands: ./runner.py
+  infiles:
+    - runner.py # runner script
+    - ../../../build_debug/memgraph # memgraph binary
+    - ../../../build_debug/tests/integration/audit/tester # tester binary
+
 - name: integration__distributed
   cd: distributed
   commands: TIMEOUT=480 ./runner.py
diff --git a/tests/integration/audit/CMakeLists.txt b/tests/integration/audit/CMakeLists.txt
new file mode 100644
index 000000000..fb1b76c15
--- /dev/null
+++ b/tests/integration/audit/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(target_name memgraph__integration__audit)
+set(tester_target_name ${target_name}__tester)
+
+add_executable(${tester_target_name} tester.cpp)
+set_target_properties(${tester_target_name} PROPERTIES OUTPUT_NAME tester)
+target_link_libraries(${tester_target_name} mg-communication json)
diff --git a/tests/integration/audit/runner.py b/tests/integration/audit/runner.py
new file mode 100755
index 000000000..fcf7ab891
--- /dev/null
+++ b/tests/integration/audit/runner.py
@@ -0,0 +1,123 @@
+#!/usr/bin/python3 -u
+import argparse
+import atexit
+import csv
+import json
+import os
+import subprocess
+import sys
+import tempfile
+import time
+
+SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
+PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
+
+QUERIES = [
+    ("MATCH (n) DELETE n", {}),
+    ("MATCH (n) DETACH DELETE n", {}),
+    ("CREATE (n)", {}),
+    ("CREATE (n {name: $name})", {"name": True}),
+    ("CREATE (n {name: $name})", {"name": 5}),
+    ("CREATE (n {name: $name})", {"name": 3.14}),
+    ("CREATE (n {name: $name})", {"name": "nandare"}),
+    ("CREATE (n {name: $name})", {"name": ["nandare", "hai hai hai"]}),
+    ("CREATE (n {name: $name})", {"name": {"test": "ho ho ho"}}),
+    ("CREATE (n {name: $name})", {"name": 5, "leftover": 42}),
+    ("MATCH (n), (m) CREATE (n)-[:e {when: $when}]->(m)", {"when": 42}),
+    ("MATCH (n) RETURN n", {}),
+    (
+        "MATCH (n), (m {type: $type}) RETURN count(n), count(m)",
+        {"type": "dadada"}
+    ),
+    (
+        "MERGE (n) ON CREATE SET n.created = timestamp() "
+        "ON MATCH SET n.lastSeen = timestamp() "
+        "RETURN n.name, n.created, n.lastSeen",
+        {}
+    ),
+    (
+        "MATCH (n {value: $value}) SET n.value = 0 RETURN n",
+        {"value": "nandare!"}
+    ),
+    ("MATCH (n), (m) SET n.value = m.value", {}),
+    ("MATCH (n {test: $test}) REMOVE n.value", {"test": 48}),
+    ("MATCH (n), (m) REMOVE n.value, m.value", {}),
+    ("CREATE INDEX ON :User (id)", {}),
+]
+
+
+def wait_for_server(port, delay=0.1):
+    cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
+    while subprocess.call(cmd) != 0:
+        time.sleep(0.01)
+    time.sleep(delay)
+
+
+def execute_test(memgraph_binary, tester_binary):
+    storage_directory = tempfile.TemporaryDirectory()
+    memgraph_args = [memgraph_binary,
+                     "--durability-directory", storage_directory.name,
+                     "--audit-enabled"]
+
+    # Start the memgraph binary
+    memgraph = subprocess.Popen(list(map(str, memgraph_args)))
+    time.sleep(0.1)
+    assert memgraph.poll() is None, "Memgraph process died prematurely!"
+    wait_for_server(7687)
+
+    # Register cleanup function
+    @atexit.register
+    def cleanup():
+        if memgraph.poll() is None:
+            memgraph.terminate()
+        assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
+
+    # Execute all queries
+    print("\033[1;36m~~ Starting query execution ~~\033[0m")
+    for query, params in QUERIES:
+        print(query, params)
+        args = [tester_binary, "--query", query,
+                "--params-json", json.dumps(params)]
+        subprocess.run(args).check_returncode()
+    print("\033[1;36m~~ Finished query execution ~~\033[0m\n")
+
+    # Shutdown the memgraph binary
+    memgraph.terminate()
+    assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
+
+    # Verify the written log
+    print("\033[1;36m~~ Starting log verification ~~\033[0m")
+    with open(os.path.join(storage_directory.name, "audit", "audit.log")) as f:
+        reader = csv.reader(f, delimiter=',', doublequote=False,
+                            escapechar='\\', lineterminator='\n',
+                            quotechar='"', quoting=csv.QUOTE_MINIMAL,
+                            skipinitialspace=False, strict=True)
+        queries = []
+        for line in reader:
+            timestamp, address, username, query, params = line
+            params = json.loads(params)
+            queries.append((query, params))
+            print(query, params)
+        assert queries == QUERIES, "Logged queries don't match " \
+            "executed queries!"
+    print("\033[1;36m~~ Finished log verification ~~\033[0m\n")
+
+
+if __name__ == "__main__":
+    memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph")
+    if not os.path.exists(memgraph_binary):
+        memgraph_binary = os.path.join(PROJECT_DIR, "build_debug", "memgraph")
+    tester_binary = os.path.join(PROJECT_DIR, "build", "tests",
+                                 "integration", "audit", "tester")
+    if not os.path.exists(tester_binary):
+        tester_binary = os.path.join(PROJECT_DIR, "build_debug", "tests",
+                                     "integration", "audit", "tester")
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--memgraph", default=memgraph_binary)
+    parser.add_argument("--tester", default=tester_binary)
+    args = parser.parse_args()
+
+    execute_test(args.memgraph, args.tester)
+
+    sys.exit(0)
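
The csv.reader settings above (`doublequote=False`, `escapechar='\\'`) are the read-side mirror of `utils::Escape` used in `Log::Flush`: each field is wrapped in double quotes with embedded quotes and backslashes backslash-escaped. A self-contained sketch of that write-side escaping (assumed behaviour, not the actual `utils::Escape` source):

```cpp
#include <iostream>
#include <string>

// Quote a value and backslash-escape '"' and '\' so that Python's
// csv.reader(..., quotechar='"', escapechar='\\', doublequote=False)
// can parse it back.
std::string Escape(const std::string &value) {
  std::string ret;
  ret.reserve(value.size() + 2);
  ret += '"';
  for (char c : value) {
    if (c == '"' || c == '\\') ret += '\\';
    ret += c;
  }
  ret += '"';
  return ret;
}

int main() {
  std::cout << Escape("CREATE (n {name: $name})") << "\n";
  std::cout << Escape("{\"name\": \"nandare\"}") << "\n";
  return 0;
}
```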
diff --git a/tests/integration/audit/tester.cpp b/tests/integration/audit/tester.cpp
new file mode 100644
index 000000000..37b3b007f
--- /dev/null
+++ b/tests/integration/audit/tester.cpp
@@ -0,0 +1,86 @@
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+#include <json/json.hpp>
+
+#include "communication/bolt/client.hpp"
+#include "io/network/endpoint.hpp"
+#include "io/network/utils.hpp"
+
+DEFINE_string(address, "127.0.0.1", "Server address");
+DEFINE_int32(port, 7687, "Server port");
+DEFINE_string(username, "", "Username for the database");
+DEFINE_string(password, "", "Password for the database");
+DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
+
+DEFINE_string(query, "", "Query to execute");
+DEFINE_string(params_json, "{}", "Params for the query");
+
+communication::bolt::Value JsonToValue(const nlohmann::json &jv) {
+  communication::bolt::Value ret;
+  switch (jv.type()) {
+    case nlohmann::json::value_t::null:
+      break;
+    case nlohmann::json::value_t::boolean:
+      ret = jv.get<bool>();
+      break;
+    case nlohmann::json::value_t::number_integer:
+      ret = jv.get<int64_t>();
+      break;
+    case nlohmann::json::value_t::number_unsigned:
+      ret = jv.get<int64_t>();
+      break;
+    case nlohmann::json::value_t::number_float:
+      ret = jv.get<double>();
+      break;
+    case nlohmann::json::value_t::string:
+      ret = jv.get<std::string>();
+      break;
+    case nlohmann::json::value_t::array: {
+      std::vector<communication::bolt::Value> vec;
+      for (const auto &item : jv) {
+        vec.push_back(JsonToValue(item));
+      }
+      ret = vec;
+      break;
+    }
+    case nlohmann::json::value_t::object: {
+      std::map<std::string, communication::bolt::Value> map;
+      for (auto it = jv.begin(); it != jv.end(); ++it) {
+        auto tmp = JsonToValue(it.key());
+        CHECK(tmp.type() == communication::bolt::Value::Type::String)
+            << "Expected a string as the map key!";
+        map.insert({tmp.ValueString(), JsonToValue(it.value())});
+      }
+      ret = map;
+      break;
+    }
+    case nlohmann::json::value_t::discarded:
+      LOG(FATAL) << "Unexpected 'discarded' type in json value!";
+      break;
+  }
+  return ret;
+}
+
+/**
+ * Executes the specified query using the specified parameters. On any errors
+ * it exits with a non-zero exit code.
+ */
+int main(int argc, char **argv) {
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+  google::InitGoogleLogging(argv[0]);
+
+  communication::Init();
+
+  io::network::Endpoint endpoint(io::network::ResolveHostname(FLAGS_address),
+                                 FLAGS_port);
+
+  communication::ClientContext context(FLAGS_use_ssl);
+  communication::bolt::Client client(&context);
+
+  client.Connect(endpoint, FLAGS_username, FLAGS_password);
+  client.Execute(
+      FLAGS_query,
+      JsonToValue(nlohmann::json::parse(FLAGS_params_json)).ValueMap());
+
+  return 0;
+}