memgraph/src/glue/SessionHL.cpp

390 lines
16 KiB
C++
Raw Normal View History

// Copyright 2024 Memgraph Ltd.
2023-08-19 00:23:15 +08:00
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <optional>
#include <utility>
#include "auth/auth.hpp"
#include "gflags/gflags.h"
2023-08-19 00:23:15 +08:00
#include "audit/log.hpp"
#include "dbms/constants.hpp"
#include "flags/run_time_configurable.hpp"
#include "glue/SessionHL.hpp"
2023-08-19 00:23:15 +08:00
#include "glue/auth_checker.hpp"
#include "glue/communication.hpp"
#include "glue/query_user.hpp"
#include "glue/run_id.hpp"
2023-08-19 00:23:15 +08:00
#include "license/license.hpp"
#include "query/auth_checker.hpp"
2023-08-19 00:23:15 +08:00
#include "query/discard_value_stream.hpp"
#include "query/interpreter_context.hpp"
#include "query/query_user.hpp"
2023-10-16 20:16:00 +08:00
#include "utils/event_map.hpp"
#include "utils/spin_lock.hpp"
#include "utils/variant_helpers.hpp"
2023-08-19 00:23:15 +08:00
namespace memgraph::metrics {
extern const Event ActiveBoltSessions;
} // namespace memgraph::metrics
namespace {
2023-08-19 00:23:15 +08:00
/// Builds the query extras from the `extra` map the client sent with a query or BEGIN.
///
/// Recognised entries:
///  - "tx_metadata": a map; each entry is converted to a storage PropertyValue;
///  - "tx_timeout":  an integer, forwarded unchanged.
/// Entries that are absent or hold a value of another Bolt type are silently ignored.
auto ToQueryExtras(const memgraph::communication::bolt::Value &extra) -> memgraph::query::QueryExtras {
  const auto &fields = extra.ValueMap();

  std::map<std::string, memgraph::storage::PropertyValue> metadata;
  const auto metadata_it = fields.find("tx_metadata");
  if (metadata_it != fields.cend() && metadata_it->second.IsMap()) {
    for (const auto &entry : metadata_it->second.ValueMap()) {
      metadata.emplace(entry.first, memgraph::glue::ToPropertyValue(entry.second));
    }
  }

  std::optional<int64_t> timeout;
  const auto timeout_it = fields.find("tx_timeout");
  if (timeout_it != fields.cend() && timeout_it->second.IsInt()) {
    timeout = timeout_it->second.ValueInt();
  }

  return memgraph::query::QueryExtras{std::move(metadata), timeout};
}
class TypedValueResultStreamBase {
public:
explicit TypedValueResultStreamBase(memgraph::storage::Storage *storage);
2023-08-19 00:23:15 +08:00
std::vector<memgraph::communication::bolt::Value> DecodeValues(
const std::vector<memgraph::query::TypedValue> &values) const;
protected:
2023-08-19 00:23:15 +08:00
// NOTE: Needed only for ToBoltValue conversions
memgraph::storage::Storage *storage_;
2023-08-19 00:23:15 +08:00
};
/// Wrapper around TEncoder which converts TypedValue to Value
/// before forwarding the calls to original TEncoder.
template <typename TEncoder>
class TypedValueResultStream : public TypedValueResultStreamBase {
public:
TypedValueResultStream(TEncoder *encoder, memgraph::storage::Storage *storage)
: TypedValueResultStreamBase{storage}, encoder_(encoder) {}
2023-08-19 00:23:15 +08:00
void Result(const std::vector<memgraph::query::TypedValue> &values) { encoder_->MessageRecord(DecodeValues(values)); }
private:
TEncoder *encoder_;
};
/// Converts a row of query results into Bolt values, resolving graph elements
/// against `storage_` with the NEW view.
///
/// @param values Row produced by the interpreter.
/// @return Bolt values in the same order as `values`.
/// @throws memgraph::communication::bolt::ClientError when any element cannot be converted.
std::vector<memgraph::communication::bolt::Value> TypedValueResultStreamBase::DecodeValues(
    const std::vector<memgraph::query::TypedValue> &values) const {
  std::vector<memgraph::communication::bolt::Value> decoded_values;
  decoded_values.reserve(values.size());
  for (const auto &v : values) {
    auto maybe_value = memgraph::glue::ToBoltValue(v, storage_, memgraph::storage::View::NEW);
    if (maybe_value.HasError()) {
      // No `default:` on purpose: with -Wswitch enabled the compiler keeps flagging
      // storage errors that are added later and not handled here.
      switch (maybe_value.GetError()) {
        case memgraph::storage::Error::DELETED_OBJECT:
          throw memgraph::communication::bolt::ClientError("Returning a deleted object as a result.");
        case memgraph::storage::Error::NONEXISTENT_OBJECT:
          throw memgraph::communication::bolt::ClientError("Returning a nonexistent object as a result.");
        case memgraph::storage::Error::VERTEX_HAS_EDGES:
        case memgraph::storage::Error::SERIALIZATION_ERROR:
        case memgraph::storage::Error::PROPERTIES_DISABLED:
          throw memgraph::communication::bolt::ClientError("Unexpected storage error when streaming results.");
      }
      // Never fall through to `*maybe_value` while it holds an error: an error value
      // not listed above must still be reported instead of being dereferenced.
      throw memgraph::communication::bolt::ClientError("Unexpected storage error when streaming results.");
    }
    decoded_values.emplace_back(std::move(*maybe_value));
  }
  return decoded_values;
}
TypedValueResultStreamBase::TypedValueResultStreamBase(memgraph::storage::Storage *storage) : storage_(storage) {}
2023-08-19 00:23:15 +08:00
#ifdef MG_ENTERPRISE
/// Verifies that `user` may access database `db` under the session-long policy.
/// A null `user` is let through without any check.
/// @throws memgraph::communication::bolt::ClientError when the user is not authorized on `db`.
void MultiDatabaseAuth(memgraph::query::QueryUserOrRole *user, std::string_view db) {
  if (user == nullptr) {
    return;
  }
  const std::string db_name{db};
  if (user->IsAuthorized({}, db_name, &memgraph::query::session_long_policy)) {
    return;
  }
  throw memgraph::communication::bolt::ClientError(
      "You are not authorized on the database \"{}\"! Please contact your database administrator.", db);
}
#endif
} // namespace
namespace memgraph::glue {
#ifdef MG_ENTERPRISE
/// Database the session starts in: the default database of the session's user/role,
/// or the server-wide default when no user/role is attached.
std::string SessionHL::GetDefaultDB() {
  return user_or_role_ ? user_or_role_->GetDefaultDB() : std::string{memgraph::dbms::kDefaultDB};
}
#endif
/// Name of the database the interpreter currently targets; empty when none is selected.
std::string SessionHL::GetCurrentDB() const {
  const auto &db_acc = interpreter_.current_db_.db_acc_;
  if (db_acc) {
    return db_acc->get()->name();
  }
  return "";
}
/// Server name to report during session initialization, taken from the run-time
/// configurable flag; std::nullopt when that flag is empty.
std::optional<std::string> SessionHL::GetServerNameForInit() {
  auto server_name = flags::run_time::GetServerName();
  if (server_name.empty()) {
    return std::nullopt;
  }
  return server_name;
}
/// Authenticates the client and (re)initializes the session's database context.
///
/// When access control is enabled the credentials are checked by the auth store and, on
/// success, a query user is attached both to the session (`user_or_role_`) and to the
/// interpreter. When access control is disabled both receive an empty query user.
/// Afterwards the session is moved to its default database (enterprise builds) and the
/// resulting current database is remembered as the implicit database.
///
/// @param username Username supplied by the client.
/// @param password Password supplied by the client.
/// @return false if access control is enabled and the credentials were rejected; true otherwise.
bool SessionHL::Authenticate(const std::string &username, const std::string &password) {
  bool res = true;
  // Drop any previous principal from BOTH the interpreter and the session. Resetting only the
  // interpreter would leave a stale `user_or_role_` after a failed attempt, and that pointer is
  // what Interpret/Configure use for their authorization checks.
  interpreter_.ResetUser();
  user_or_role_.reset();
  {
    auto locked_auth = auth_->Lock();
    if (locked_auth->AccessControlled()) {
      const auto user_or_role = locked_auth->Authenticate(username, password);
      if (user_or_role.has_value()) {
        // The session and the interpreter each receive their own generated query user.
        user_or_role_ = AuthChecker::GenQueryUser(auth_, *user_or_role);
        interpreter_.SetUser(AuthChecker::GenQueryUser(auth_, *user_or_role));
      } else {
        res = false;
      }
    } else {
      // No access control -> give empty user
      user_or_role_ = AuthChecker::GenQueryUser(auth_, std::nullopt);
      interpreter_.SetUser(AuthChecker::GenQueryUser(auth_, std::nullopt));
    }
  }
#ifdef MG_ENTERPRISE
  // Start off with the default database
  interpreter_.SetCurrentDB(GetDefaultDB(), false);
#endif
  implicit_db_.emplace(GetCurrentDB());
  return res;
}
/// Forwards an abort request to this session's interpreter.
void SessionHL::Abort() {
  interpreter_.Abort();
}
std::map<std::string, memgraph::communication::bolt::Value> SessionHL::Discard(std::optional<int> n,
std::optional<int> qid) {
try {
memgraph::query::DiscardValueResultStream stream;
return DecodeSummary(interpreter_.Pull(&stream, n, qid));
2023-08-19 00:23:15 +08:00
} catch (const memgraph::query::QueryException &e) {
2023-10-16 20:16:00 +08:00
// Count the number of specific exceptions thrown
metrics::IncrementCounter(GetExceptionName(e));
2023-08-19 00:23:15 +08:00
// Wrap QueryException into ClientError, because we want to allow the
// client to fix their query.
throw memgraph::communication::bolt::ClientError(e.what());
}
}
/// Streams query results to `encoder` as Bolt records and returns the decoded summary.
/// `n` and `qid` are forwarded unchanged to Interpreter::Pull.
///
/// A QueryException is counted in metrics and rethrown as ClientError so the client can fix
/// its query. Any other utils::BasicException propagates untouched, which results in a
/// TransientError (the client is encouraged to retry).
std::map<std::string, memgraph::communication::bolt::Value> SessionHL::Pull(SessionHL::TEncoder *encoder,
                                                                             std::optional<int> n,
                                                                             std::optional<int> qid) {
  try {
    auto &db_acc = interpreter_.current_db_.db_acc_;
    memgraph::storage::Storage *storage = nullptr;
    if (db_acc) {
      storage = db_acc->get()->storage();
    }
    TypedValueResultStream<TEncoder> stream(encoder, storage);
    auto summary = interpreter_.Pull(&stream, n, qid);
    return DecodeSummary(summary);
  } catch (const memgraph::query::QueryException &ex) {
    // Record which specific exception occurred, then let the client fix its query.
    metrics::IncrementCounter(GetExceptionName(ex));
    throw memgraph::communication::bolt::ClientError(ex.what());
  } catch (const utils::BasicException &) {
    // Propagate unchanged on purpose: BasicException becomes a TransientError, i.e. the
    // client is told the same request may succeed if retried.
    throw;
  }
}
/// Prepares `query` and returns its result headers together with the query id.
///
/// In enterprise builds with a valid license the query is first recorded in the audit log.
/// Once prepared, the session's user/role (if any) must be authorized for the query's
/// privileges on its target database; otherwise the interpreter is aborted and a
/// ClientError is thrown. QueryException and ReplicationException are counted in metrics
/// and rethrown as ClientError.
std::pair<std::vector<std::string>, std::optional<int>> SessionHL::Interpret(
    const std::string &query, const std::map<std::string, memgraph::communication::bolt::Value> &params,
    const std::map<std::string, memgraph::communication::bolt::Value> &extra) {
  std::map<std::string, memgraph::storage::PropertyValue> params_pv;
  for (const auto &param : params) {
    params_pv.emplace(param.first, ToPropertyValue(param.second));
  }

#ifdef MG_ENTERPRISE
  if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
    std::string username;
    if (user_or_role_ && user_or_role_->username()) {
      username = *user_or_role_->username();
    }
    auto &db_acc = interpreter_.current_db_.db_acc_;
    const auto audited_db = db_acc ? db_acc->get()->name() : "no known database";
    audit_log_->Record(endpoint_.address().to_string(), username, query, memgraph::storage::PropertyValue(params_pv),
                       audited_db);
  }
#endif

  try {
    auto prepared = interpreter_.Prepare(query, params_pv, ToQueryExtras(extra));
    const std::string db_name = prepared.db ? *prepared.db : "";

    // No attached user/role means there is nothing to check.
    const bool authorized =
        !user_or_role_ || user_or_role_->IsAuthorized(prepared.privileges, db_name, &query::session_long_policy);
    if (!authorized) {
      interpreter_.Abort();
      if (db_name.empty()) {
        throw memgraph::communication::bolt::ClientError(
            "You are not authorized to execute this query! Please contact your database administrator.");
      }
      throw memgraph::communication::bolt::ClientError(
          "You are not authorized to execute this query on database \"{}\"! Please contact your database "
          "administrator.",
          db_name);
    }

    return {std::move(prepared.headers), prepared.qid};
  } catch (const memgraph::query::QueryException &ex) {
    // Record which specific exception occurred, then let the client fix its query.
    metrics::IncrementCounter(GetExceptionName(ex));
    throw memgraph::communication::bolt::ClientError(ex.what());
  } catch (const memgraph::query::ReplicationException &ex) {
    // Record which specific exception occurred and report it to the client.
    metrics::IncrementCounter(GetExceptionName(ex));
    throw memgraph::communication::bolt::ClientError(ex.what());
  }
}
void SessionHL::RollbackTransaction() {
try {
interpreter_.RollbackTransaction();
} catch (const memgraph::query::QueryException &e) {
// Count the number of specific exceptions thrown
metrics::IncrementCounter(GetExceptionName(e));
// Wrap QueryException into ClientError, because we want to allow the
// client to fix their query.
throw memgraph::communication::bolt::ClientError(e.what());
} catch (const memgraph::query::ReplicationException &e) {
// Count the number of specific exceptions thrown
metrics::IncrementCounter(GetExceptionName(e));
throw memgraph::communication::bolt::ClientError(e.what());
}
}
void SessionHL::CommitTransaction() {
try {
interpreter_.CommitTransaction();
} catch (const memgraph::query::QueryException &e) {
// Count the number of specific exceptions thrown
metrics::IncrementCounter(GetExceptionName(e));
// Wrap QueryException into ClientError, because we want to allow the
// client to fix their query.
throw memgraph::communication::bolt::ClientError(e.what());
} catch (const memgraph::query::ReplicationException &e) {
// Count the number of specific exceptions thrown
metrics::IncrementCounter(GetExceptionName(e));
throw memgraph::communication::bolt::ClientError(e.what());
}
}
2023-08-19 00:23:15 +08:00
void SessionHL::BeginTransaction(const std::map<std::string, memgraph::communication::bolt::Value> &extra) {
try {
interpreter_.BeginTransaction(ToQueryExtras(extra));
} catch (const memgraph::query::QueryException &e) {
// Count the number of specific exceptions thrown
metrics::IncrementCounter(GetExceptionName(e));
// Wrap QueryException into ClientError, because we want to allow the
// client to fix their query.
throw memgraph::communication::bolt::ClientError(e.what());
} catch (const memgraph::query::ReplicationException &e) {
// Count the number of specific exceptions thrown
metrics::IncrementCounter(GetExceptionName(e));
throw memgraph::communication::bolt::ClientError(e.what());
}
2023-08-19 00:23:15 +08:00
}
2023-08-19 00:23:15 +08:00
/// Applies per-request run-time info sent by the client (enterprise builds only; a no-op otherwise).
///
/// - A "db" entry selects an explicit database. On the first such request the database the
///   session was in is remembered as the implicit database.
/// - Without a "db" entry, while in an explicit database and outside an explicit transaction,
///   the session returns to the remembered implicit database (or the default one).
/// The interpreter is switched only when the target differs from the current database, and
/// only after the session's user/role is checked against the target.
///
/// @throws memgraph::communication::bolt::ClientError for a non-string "db" entry or when the
///         user is not authorized on the target database.
void SessionHL::Configure(const std::map<std::string, memgraph::communication::bolt::Value> &run_time_info) {
#ifdef MG_ENTERPRISE
  std::string target_db;
  bool needs_switch = false;

  if (const auto db_it = run_time_info.find("db"); db_it != run_time_info.end()) {
    // The client explicitly named the database to use.
    const auto &requested = db_it->second;
    if (!requested.IsString()) {
      throw memgraph::communication::bolt::ClientError("Malformed database name.");
    }
    target_db = requested.ValueString();
    const auto current_db = GetCurrentDB();
    needs_switch = target_db != current_db;
    // Still not in an explicit database: save the current one for recovery.
    if (!in_explicit_db_) implicit_db_.emplace(current_db);
    in_explicit_db_ = true;
    // NOTE: Once in a transaction, the drivers stop explicitly sending the db and count on using it until commit
  } else if (in_explicit_db_ && !interpreter_.in_explicit_transaction_) {
    // Leaving the explicit database: go back to the implicit one, or the default.
    target_db = implicit_db_ ? *implicit_db_ : GetDefaultDB();
    needs_switch = target_db != GetCurrentDB();
    in_explicit_db_ = false;
  }

  if (needs_switch) {
    MultiDatabaseAuth(user_or_role_.get(), target_db);
    interpreter_.SetCurrentDB(target_db, in_explicit_db_);
  }
#endif
}
/// Creates a Bolt session backed by an interpreter from `interpreter_context`.
///
/// Besides member initialization this increments the ActiveBoltSessions metric,
/// (enterprise only) installs a callback on the interpreter that runs the multi-database
/// authorization check for the session's user/role against the reported database name,
/// and registers this session's interpreter in the context's interpreter set.
SessionHL::SessionHL(memgraph::query::InterpreterContext *interpreter_context,
                     memgraph::communication::v2::ServerEndpoint endpoint,
                     memgraph::communication::v2::InputStream *input_stream,
                     memgraph::communication::v2::OutputStream *output_stream, memgraph::auth::SynchedAuth *auth
#ifdef MG_ENTERPRISE
                     ,
                     memgraph::audit::Log *audit_log
#endif
                     )
    : Session<memgraph::communication::v2::InputStream, memgraph::communication::v2::OutputStream>(input_stream,
                                                                                                  output_stream),
      interpreter_context_(interpreter_context),
      interpreter_(interpreter_context_),
#ifdef MG_ENTERPRISE
      audit_log_(audit_log),
#endif
      auth_(auth),
      endpoint_(std::move(endpoint)),
      implicit_db_(dbms::kDefaultDB) {
  // One more live Bolt session.
  memgraph::metrics::IncrementCounter(memgraph::metrics::ActiveBoltSessions);
#ifdef MG_ENTERPRISE
  // Presumably invoked by the interpreter when it changes database; rejects databases the
  // session's user/role is not authorized on.
  interpreter_.OnChangeCB([this](std::string_view db_name) { MultiDatabaseAuth(user_or_role_.get(), db_name); });
#endif
  interpreter_context_->interpreters.WithLock([this](auto &registry) { registry.insert(&interpreter_); });
}
/// Reverses the constructor's bookkeeping: one fewer live Bolt session, and this
/// session's interpreter is removed from the context's interpreter set.
SessionHL::~SessionHL() {
  memgraph::metrics::DecrementCounter(memgraph::metrics::ActiveBoltSessions);
  interpreter_context_->interpreters.WithLock([this](auto &registry) { registry.erase(&interpreter_); });
}
/// Converts the interpreter's summary into Bolt values and adds this instance's run_id.
///
/// @param summary Summary map returned by Interpreter::Pull.
/// @return The Bolt-encoded summary; "run_id" is added unless `summary` already has that key
///         (std::map::emplace does not overwrite).
/// @throws memgraph::communication::bolt::ClientError when a summary value cannot be converted.
std::map<std::string, memgraph::communication::bolt::Value> SessionHL::DecodeSummary(
    const std::map<std::string, memgraph::query::TypedValue> &summary) {
  auto &db_acc = interpreter_.current_db_.db_acc_;
  auto *storage = db_acc ? db_acc->get()->storage() : nullptr;

  std::map<std::string, memgraph::communication::bolt::Value> decoded_summary;
  for (const auto &[key, value] : summary) {
    auto maybe_value = ToBoltValue(value, storage, memgraph::storage::View::NEW);
    if (maybe_value.HasError()) {
      // No `default:` on purpose: with -Wswitch enabled the compiler keeps flagging
      // storage errors that are added later and not handled here.
      switch (maybe_value.GetError()) {
        case memgraph::storage::Error::DELETED_OBJECT:
        case memgraph::storage::Error::SERIALIZATION_ERROR:
        case memgraph::storage::Error::VERTEX_HAS_EDGES:
        case memgraph::storage::Error::PROPERTIES_DISABLED:
        case memgraph::storage::Error::NONEXISTENT_OBJECT:
          throw memgraph::communication::bolt::ClientError("Unexpected storage error when streaming summary.");
      }
      // Never fall through to `*maybe_value` while it holds an error: an error value
      // not listed above must still be reported instead of being dereferenced.
      throw memgraph::communication::bolt::ClientError("Unexpected storage error when streaming summary.");
    }
    decoded_summary.emplace(key, std::move(*maybe_value));
  }
  // Add this memgraph instance run_id, received from telemetry
  // This is sent with every query, instead of only on bolt init inside
  // communication/bolt/v1/states/init.hpp because neo4jdriver does not
  // read the init message.
  decoded_summary.emplace("run_id", memgraph::glue::run_id_);
  return decoded_summary;
}
} // namespace memgraph::glue