Format all the memgraph and test source files (#97)

This commit is contained in:
antonio2368 2021-02-18 15:32:43 +01:00 committed by GitHub
parent 435af8b833
commit 3f3c55a4aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
311 changed files with 12810 additions and 22030 deletions

View File

@ -38,8 +38,7 @@ inline nlohmann::json PropertyValueToJson(const storage::PropertyValue &pv) {
case storage::PropertyValue::Type::Map: {
ret = nlohmann::json::object();
for (const auto &item : pv.ValueMap()) {
ret.push_back(nlohmann::json::object_t::value_type(
item.first, PropertyValueToJson(item.second)));
ret.push_back(nlohmann::json::object_t::value_type(item.first, PropertyValueToJson(item.second)));
}
break;
}
@ -47,8 +46,7 @@ inline nlohmann::json PropertyValueToJson(const storage::PropertyValue &pv) {
return ret;
}
Log::Log(const std::filesystem::path &storage_directory, int32_t buffer_size,
int32_t buffer_flush_interval_millis)
Log::Log(const std::filesystem::path &storage_directory, int32_t buffer_size, int32_t buffer_flush_interval_millis)
: storage_directory_(storage_directory),
buffer_size_(buffer_size),
buffer_flush_interval_millis_(buffer_flush_interval_millis),
@ -63,9 +61,7 @@ void Log::Start() {
started_ = true;
ReopenLog();
scheduler_.Run("Audit",
std::chrono::milliseconds(buffer_flush_interval_millis_),
[&] { Flush(); });
scheduler_.Run("Audit", std::chrono::milliseconds(buffer_flush_interval_millis_), [&] { Flush(); });
}
Log::~Log() {
@ -78,13 +74,12 @@ Log::~Log() {
Flush();
}
void Log::Record(const std::string &address, const std::string &username,
const std::string &query,
void Log::Record(const std::string &address, const std::string &username, const std::string &query,
const storage::PropertyValue &params) {
if (!started_.load(std::memory_order_relaxed)) return;
auto timestamp = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();
auto timestamp =
std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch())
.count();
buffer_->emplace(Item{timestamp, address, username, query, params});
}
@ -92,8 +87,7 @@ void Log::ReopenLog() {
if (!started_.load(std::memory_order_relaxed)) return;
std::lock_guard<std::mutex> guard(lock_);
if (log_.IsOpen()) log_.Close();
log_.Open(storage_directory_ / "audit.log",
utils::OutputFile::Mode::APPEND_TO_EXISTING);
log_.Open(storage_directory_ / "audit.log", utils::OutputFile::Mode::APPEND_TO_EXISTING);
}
void Log::Flush() {
@ -101,11 +95,9 @@ void Log::Flush() {
for (uint64_t i = 0; i < buffer_size_; ++i) {
auto item = buffer_->pop();
if (!item) break;
log_.Write(
fmt::format("{}.{:06d},{},{},{},{}\n", item->timestamp / 1000000,
item->timestamp % 1000000, item->address, item->username,
utils::Escape(item->query),
utils::Escape(PropertyValueToJson(item->params).dump())));
log_.Write(fmt::format("{}.{:06d},{},{},{},{}\n", item->timestamp / 1000000, item->timestamp % 1000000,
item->address, item->username, utils::Escape(item->query),
utils::Escape(PropertyValueToJson(item->params).dump())));
}
log_.Sync();
}

View File

@ -27,8 +27,7 @@ class Log {
};
public:
Log(const std::filesystem::path &storage_directory, int32_t buffer_size,
int32_t buffer_flush_interval_millis);
Log(const std::filesystem::path &storage_directory, int32_t buffer_size, int32_t buffer_flush_interval_millis);
~Log();
@ -43,8 +42,8 @@ class Log {
void Start();
/// Adds an entry to the audit log. Thread-safe.
void Record(const std::string &address, const std::string &username,
const std::string &query, const storage::PropertyValue &params);
void Record(const std::string &address, const std::string &username, const std::string &query,
const storage::PropertyValue &params);
/// Reopens the log file. Used for log file rotation. Thread-safe.
void ReopenLog();

View File

@ -12,26 +12,20 @@
#include "utils/logging.hpp"
#include "utils/string.hpp"
DEFINE_VALIDATED_string(
auth_module_executable, "",
"Absolute path to the auth module executable that should be used.", {
if (value.empty()) return true;
// Check the file status, following symlinks.
auto status = std::filesystem::status(value);
if (!std::filesystem::is_regular_file(status)) {
std::cerr << "The auth module path doesn't exist or isn't a file!"
<< std::endl;
return false;
}
return true;
});
DEFINE_bool(auth_module_create_missing_user, true,
"Set to false to disable creation of missing users.");
DEFINE_bool(auth_module_create_missing_role, true,
"Set to false to disable creation of missing roles.");
DEFINE_bool(
auth_module_manage_roles, true,
"Set to false to disable management of roles through the auth module.");
DEFINE_VALIDATED_string(auth_module_executable, "", "Absolute path to the auth module executable that should be used.",
{
if (value.empty()) return true;
// Check the file status, following symlinks.
auto status = std::filesystem::status(value);
if (!std::filesystem::is_regular_file(status)) {
std::cerr << "The auth module path doesn't exist or isn't a file!" << std::endl;
return false;
}
return true;
});
DEFINE_bool(auth_module_create_missing_user, true, "Set to false to disable creation of missing users.");
DEFINE_bool(auth_module_create_missing_role, true, "Set to false to disable creation of missing roles.");
DEFINE_bool(auth_module_manage_roles, true, "Set to false to disable management of roles through the auth module.");
DEFINE_VALIDATED_int32(auth_module_timeout_ms, 10000,
"Timeout (in milliseconds) used when waiting for a "
"response from the auth module.",
@ -60,11 +54,9 @@ const std::string kLinkPrefix = "link:";
* key="link:<username>", value="<rolename>"
*/
Auth::Auth(const std::string &storage_directory)
: storage_(storage_directory), module_(FLAGS_auth_module_executable) {}
Auth::Auth(const std::string &storage_directory) : storage_(storage_directory), module_(FLAGS_auth_module_executable) {}
std::optional<User> Auth::Authenticate(const std::string &username,
const std::string &password) {
std::optional<User> Auth::Authenticate(const std::string &username, const std::string &password) {
if (module_.IsUsed()) {
nlohmann::json params = nlohmann::json::object();
params["username"] = username;
@ -73,8 +65,7 @@ std::optional<User> Auth::Authenticate(const std::string &username,
auto ret = module_.Call(params, FLAGS_auth_module_timeout_ms);
// Verify response integrity.
if (!ret.is_object() || ret.find("authenticated") == ret.end() ||
ret.find("role") == ret.end()) {
if (!ret.is_object() || ret.find("authenticated") == ret.end() || ret.find("role") == ret.end()) {
return std::nullopt;
}
const auto &ret_authenticated = ret.at("authenticated");
@ -142,9 +133,7 @@ std::optional<User> Auth::Authenticate(const std::string &username,
} else {
auto user = GetUser(username);
if (!user) {
spdlog::warn(
"Couldn't authenticate user '{}' because the user doesn't exist",
username);
spdlog::warn("Couldn't authenticate user '{}' because the user doesn't exist", username);
return std::nullopt;
}
if (!user->CheckPassword(password)) {
@ -182,21 +171,18 @@ std::optional<User> Auth::GetUser(const std::string &username_orig) {
void Auth::SaveUser(const User &user) {
bool success = false;
if (user.role()) {
success = storage_.PutMultiple(
{{kUserPrefix + user.username(), user.Serialize().dump()},
{kLinkPrefix + user.username(), user.role()->rolename()}});
success = storage_.PutMultiple({{kUserPrefix + user.username(), user.Serialize().dump()},
{kLinkPrefix + user.username(), user.role()->rolename()}});
} else {
success = storage_.PutAndDeleteMultiple(
{{kUserPrefix + user.username(), user.Serialize().dump()}},
{kLinkPrefix + user.username()});
success = storage_.PutAndDeleteMultiple({{kUserPrefix + user.username(), user.Serialize().dump()}},
{kLinkPrefix + user.username()});
}
if (!success) {
throw AuthException("Couldn't save user '{}'!", user.username());
}
}
std::optional<User> Auth::AddUser(const std::string &username,
const std::optional<std::string> &password) {
std::optional<User> Auth::AddUser(const std::string &username, const std::optional<std::string> &password) {
auto existing_user = GetUser(username);
if (existing_user) return std::nullopt;
auto existing_role = GetRole(username);
@ -210,8 +196,7 @@ std::optional<User> Auth::AddUser(const std::string &username,
bool Auth::RemoveUser(const std::string &username_orig) {
auto username = utils::ToLowerCase(username_orig);
if (!storage_.Get(kUserPrefix + username)) return false;
std::vector<std::string> keys(
{kLinkPrefix + username, kUserPrefix + username});
std::vector<std::string> keys({kLinkPrefix + username, kUserPrefix + username});
if (!storage_.DeleteMultiple(keys)) {
throw AuthException("Couldn't remove user '{}'!", username);
}
@ -220,8 +205,7 @@ bool Auth::RemoveUser(const std::string &username_orig) {
std::vector<auth::User> Auth::AllUsers() {
std::vector<auth::User> ret;
for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix);
++it) {
for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) {
auto username = it->first.substr(kUserPrefix.size());
if (username != utils::ToLowerCase(username)) continue;
auto user = GetUser(username);
@ -232,9 +216,7 @@ std::vector<auth::User> Auth::AllUsers() {
return ret;
}
bool Auth::HasUsers() {
return storage_.begin(kUserPrefix) != storage_.end(kUserPrefix);
}
bool Auth::HasUsers() { return storage_.begin(kUserPrefix) != storage_.end(kUserPrefix); }
std::optional<Role> Auth::GetRole(const std::string &rolename_orig) {
auto rolename = utils::ToLowerCase(rolename_orig);
@ -271,8 +253,7 @@ bool Auth::RemoveRole(const std::string &rolename_orig) {
auto rolename = utils::ToLowerCase(rolename_orig);
if (!storage_.Get(kRolePrefix + rolename)) return false;
std::vector<std::string> keys;
for (auto it = storage_.begin(kLinkPrefix); it != storage_.end(kLinkPrefix);
++it) {
for (auto it = storage_.begin(kLinkPrefix); it != storage_.end(kLinkPrefix); ++it) {
if (utils::ToLowerCase(it->second) == rolename) {
keys.push_back(it->first);
}
@ -286,8 +267,7 @@ bool Auth::RemoveRole(const std::string &rolename_orig) {
std::vector<auth::Role> Auth::AllRoles() {
std::vector<auth::Role> ret;
for (auto it = storage_.begin(kRolePrefix); it != storage_.end(kRolePrefix);
++it) {
for (auto it = storage_.begin(kRolePrefix); it != storage_.end(kRolePrefix); ++it) {
auto rolename = it->first.substr(kRolePrefix.size());
if (rolename != utils::ToLowerCase(rolename)) continue;
auto role = GetRole(rolename);
@ -300,12 +280,10 @@ std::vector<auth::Role> Auth::AllRoles() {
return ret;
}
std::vector<auth::User> Auth::AllUsersForRole(
const std::string &rolename_orig) {
std::vector<auth::User> Auth::AllUsersForRole(const std::string &rolename_orig) {
auto rolename = utils::ToLowerCase(rolename_orig);
std::vector<auth::User> ret;
for (auto it = storage_.begin(kLinkPrefix); it != storage_.end(kLinkPrefix);
++it) {
for (auto it = storage_.begin(kLinkPrefix); it != storage_.end(kLinkPrefix); ++it) {
auto username = it->first.substr(kLinkPrefix.size());
if (username != utils::ToLowerCase(username)) continue;
if (it->second != utils::ToLowerCase(it->second)) continue;

View File

@ -32,8 +32,7 @@ class Auth final {
* @return a user when the username and password match, nullopt otherwise
* @throw AuthException if unable to authenticate for whatever reason.
*/
std::optional<User> Authenticate(const std::string &username,
const std::string &password);
std::optional<User> Authenticate(const std::string &username, const std::string &password);
/**
* Gets a user from the storage.
@ -63,9 +62,7 @@ class Auth final {
* @return a user when the user is created, nullopt if the user exists
* @throw AuthException if unable to save the user.
*/
std::optional<User> AddUser(
const std::string &username,
const std::optional<std::string> &password = std::nullopt);
std::optional<User> AddUser(const std::string &username, const std::optional<std::string> &password = std::nullopt);
/**
* Removes a user from the storage.

View File

@ -9,8 +9,7 @@
#include "utils/cast.hpp"
#include "utils/string.hpp"
DEFINE_bool(auth_password_permit_null, true,
"Set to false to disable null passwords.");
DEFINE_bool(auth_password_permit_null, true, "Set to false to disable null passwords.");
DEFINE_string(auth_password_strength_regex, ".+",
"The regular expression that should be used to match the entire "
@ -129,8 +128,7 @@ Permissions Permissions::Deserialize(const nlohmann::json &data) {
if (!data.is_object()) {
throw AuthException("Couldn't load permissions data!");
}
if (!data["grants"].is_number_unsigned() ||
!data["denies"].is_number_unsigned()) {
if (!data["grants"].is_number_unsigned() || !data["denies"].is_number_unsigned()) {
throw AuthException("Couldn't load permissions data!");
}
return {data["grants"], data["denies"]};
@ -143,12 +141,9 @@ bool operator==(const Permissions &first, const Permissions &second) {
return first.grants() == second.grants() && first.denies() == second.denies();
}
bool operator!=(const Permissions &first, const Permissions &second) {
return !(first == second);
}
bool operator!=(const Permissions &first, const Permissions &second) { return !(first == second); }
Role::Role(const std::string &rolename)
: rolename_(utils::ToLowerCase(rolename)) {}
Role::Role(const std::string &rolename) : rolename_(utils::ToLowerCase(rolename)) {}
Role::Role(const std::string &rolename, const Permissions &permissions)
: rolename_(utils::ToLowerCase(rolename)), permissions_(permissions) {}
@ -176,18 +171,13 @@ Role Role::Deserialize(const nlohmann::json &data) {
}
bool operator==(const Role &first, const Role &second) {
return first.rolename_ == second.rolename_ &&
first.permissions_ == second.permissions_;
return first.rolename_ == second.rolename_ && first.permissions_ == second.permissions_;
}
User::User(const std::string &username)
: username_(utils::ToLowerCase(username)) {}
User::User(const std::string &username) : username_(utils::ToLowerCase(username)) {}
User::User(const std::string &username, const std::string &password_hash,
const Permissions &permissions)
: username_(utils::ToLowerCase(username)),
password_hash_(password_hash),
permissions_(permissions) {}
User::User(const std::string &username, const std::string &password_hash, const Permissions &permissions)
: username_(utils::ToLowerCase(username)), password_hash_(password_hash), permissions_(permissions) {}
bool User::CheckPassword(const std::string &password) {
if (password_hash_ == "") return true;
@ -244,8 +234,7 @@ User User::Deserialize(const nlohmann::json &data) {
if (!data.is_object()) {
throw AuthException("Couldn't load user data!");
}
if (!data["username"].is_string() || !data["password_hash"].is_string() ||
!data["permissions"].is_object()) {
if (!data["username"].is_string() || !data["password_hash"].is_string() || !data["permissions"].is_object()) {
throw AuthException("Couldn't load user data!");
}
auto permissions = Permissions::Deserialize(data["permissions"]);
@ -253,9 +242,7 @@ User User::Deserialize(const nlohmann::json &data) {
}
bool operator==(const User &first, const User &second) {
return first.username_ == second.username_ &&
first.password_hash_ == second.password_hash_ &&
first.permissions_ == second.permissions_ &&
first.role_ == second.role_;
return first.username_ == second.username_ && first.password_hash_ == second.password_hash_ &&
first.permissions_ == second.permissions_ && first.role_ == second.role_;
}
} // namespace auth

View File

@ -29,11 +29,9 @@ enum class Permission : uint64_t {
// Constant list of all available permissions.
const std::vector<Permission> kPermissionsAll = {
Permission::MATCH, Permission::CREATE, Permission::MERGE,
Permission::DELETE, Permission::SET, Permission::REMOVE,
Permission::INDEX, Permission::STATS, Permission::CONSTRAINT,
Permission::DUMP, Permission::AUTH, Permission::REPLICATION,
Permission::LOCK_PATH};
Permission::MATCH, Permission::CREATE, Permission::MERGE, Permission::DELETE, Permission::SET,
Permission::REMOVE, Permission::INDEX, Permission::STATS, Permission::CONSTRAINT, Permission::DUMP,
Permission::AUTH, Permission::REPLICATION, Permission::LOCK_PATH};
// Function that converts a permission to its string representation.
std::string PermissionToString(Permission permission);
@ -110,15 +108,13 @@ class User final {
public:
User(const std::string &username);
User(const std::string &username, const std::string &password_hash,
const Permissions &permissions);
User(const std::string &username, const std::string &password_hash, const Permissions &permissions);
/// @throw AuthException if unable to verify the password.
bool CheckPassword(const std::string &password);
/// @throw AuthException if unable to set the password.
void UpdatePassword(
const std::optional<std::string> &password = std::nullopt);
void UpdatePassword(const std::optional<std::string> &password = std::nullopt);
void SetRole(const Role &role);

View File

@ -86,54 +86,22 @@ class CharPP final {
////////////////////////////////////
const std::vector<int> kSeccompSyscallsBlacklist = {
SCMP_SYS(mknod),
SCMP_SYS(mount),
SCMP_SYS(setuid),
SCMP_SYS(stime),
SCMP_SYS(ptrace),
SCMP_SYS(setgid),
SCMP_SYS(acct),
SCMP_SYS(umount),
SCMP_SYS(setpgid),
SCMP_SYS(chroot),
SCMP_SYS(setreuid),
SCMP_SYS(setregid),
SCMP_SYS(sethostname),
SCMP_SYS(settimeofday),
SCMP_SYS(setgroups),
SCMP_SYS(swapon),
SCMP_SYS(reboot),
SCMP_SYS(setpriority),
SCMP_SYS(ioperm),
SCMP_SYS(syslog),
SCMP_SYS(iopl),
SCMP_SYS(vhangup),
SCMP_SYS(vm86old),
SCMP_SYS(swapoff),
SCMP_SYS(setdomainname),
SCMP_SYS(adjtimex),
SCMP_SYS(init_module),
SCMP_SYS(delete_module),
SCMP_SYS(setfsuid),
SCMP_SYS(setfsgid),
SCMP_SYS(setresuid),
SCMP_SYS(vm86),
SCMP_SYS(setresgid),
SCMP_SYS(capset),
SCMP_SYS(setreuid),
SCMP_SYS(setregid),
SCMP_SYS(setgroups),
SCMP_SYS(setresuid),
SCMP_SYS(setresgid),
SCMP_SYS(setuid),
SCMP_SYS(setgid),
SCMP_SYS(setfsuid),
SCMP_SYS(setfsgid),
SCMP_SYS(pivot_root),
SCMP_SYS(sched_setaffinity),
SCMP_SYS(clock_settime),
SCMP_SYS(kexec_load),
SCMP_SYS(mknodat),
SCMP_SYS(mknod), SCMP_SYS(mount), SCMP_SYS(setuid),
SCMP_SYS(stime), SCMP_SYS(ptrace), SCMP_SYS(setgid),
SCMP_SYS(acct), SCMP_SYS(umount), SCMP_SYS(setpgid),
SCMP_SYS(chroot), SCMP_SYS(setreuid), SCMP_SYS(setregid),
SCMP_SYS(sethostname), SCMP_SYS(settimeofday), SCMP_SYS(setgroups),
SCMP_SYS(swapon), SCMP_SYS(reboot), SCMP_SYS(setpriority),
SCMP_SYS(ioperm), SCMP_SYS(syslog), SCMP_SYS(iopl),
SCMP_SYS(vhangup), SCMP_SYS(vm86old), SCMP_SYS(swapoff),
SCMP_SYS(setdomainname), SCMP_SYS(adjtimex), SCMP_SYS(init_module),
SCMP_SYS(delete_module), SCMP_SYS(setfsuid), SCMP_SYS(setfsgid),
SCMP_SYS(setresuid), SCMP_SYS(vm86), SCMP_SYS(setresgid),
SCMP_SYS(capset), SCMP_SYS(setreuid), SCMP_SYS(setregid),
SCMP_SYS(setgroups), SCMP_SYS(setresuid), SCMP_SYS(setresgid),
SCMP_SYS(setuid), SCMP_SYS(setgid), SCMP_SYS(setfsuid),
SCMP_SYS(setfsgid), SCMP_SYS(pivot_root), SCMP_SYS(sched_setaffinity),
SCMP_SYS(clock_settime), SCMP_SYS(kexec_load), SCMP_SYS(mknodat),
SCMP_SYS(unshare),
#ifdef SYS_seccomp
SCMP_SYS(seccomp),
@ -182,24 +150,20 @@ int Target(void *arg) {
// Redirect `stdin` to `/dev/null`.
int fd = open("/dev/null", O_RDONLY | O_CLOEXEC);
if (fd == -1) {
std::cerr
<< "Couldn't open \"/dev/null\" for auth module stdin because of: "
<< strerror(errno) << " (" << errno << ")!" << std::endl;
std::cerr << "Couldn't open \"/dev/null\" for auth module stdin because of: " << strerror(errno) << " (" << errno
<< ")!" << std::endl;
return EXIT_FAILURE;
}
if (dup2(fd, STDIN_FILENO) != STDIN_FILENO) {
std::cerr
<< "Couldn't attach \"/dev/null\" to auth module stdin because of: "
<< strerror(errno) << " (" << errno << ")!" << std::endl;
std::cerr << "Couldn't attach \"/dev/null\" to auth module stdin because of: " << strerror(errno) << " (" << errno
<< ")!" << std::endl;
return EXIT_FAILURE;
}
// Change the current directory to the module directory.
if (chdir(ta->module_executable_path.parent_path().c_str()) != 0) {
std::cerr << "Couldn't change directory to "
<< ta->module_executable_path.parent_path()
<< " for auth module stdin because of: " << strerror(errno)
<< " (" << errno << ")!" << std::endl;
std::cerr << "Couldn't change directory to " << ta->module_executable_path.parent_path()
<< " for auth module stdin because of: " << strerror(errno) << " (" << errno << ")!" << std::endl;
return EXIT_FAILURE;
}
@ -214,8 +178,7 @@ int Target(void *arg) {
}
// Connect the communication input pipe.
if (dup2(ta->pipe_to_module, kCommunicationToModuleFd) !=
kCommunicationToModuleFd) {
if (dup2(ta->pipe_to_module, kCommunicationToModuleFd) != kCommunicationToModuleFd) {
std::cerr << "Couldn't attach communication to module pipe to auth module "
"because of: "
<< strerror(errno) << " (" << errno << ")!" << std::endl;
@ -223,8 +186,7 @@ int Target(void *arg) {
}
// Connect the communication output pipe.
if (dup2(ta->pipe_from_module, kCommunicationFromModuleFd) !=
kCommunicationFromModuleFd) {
if (dup2(ta->pipe_from_module, kCommunicationFromModuleFd) != kCommunicationFromModuleFd) {
std::cerr << "Couldn't attach communication from module pipe to auth "
"module because of: "
<< strerror(errno) << " (" << errno << ")!" << std::endl;
@ -246,8 +208,8 @@ int Target(void *arg) {
sigemptyset(&action.sa_mask);
action.sa_flags = 0;
if (sigaction(SIGINT, &action, nullptr) != 0) {
std::cerr << "Couldn't ignore SIGINT for auth module because of: "
<< strerror(errno) << " (" << errno << ")!" << std::endl;
std::cerr << "Couldn't ignore SIGINT for auth module because of: " << strerror(errno) << " (" << errno << ")!"
<< std::endl;
return EXIT_FAILURE;
}
@ -261,8 +223,7 @@ int Target(void *arg) {
// If the `execve` call succeeded then the process will exit from that call
// and won't reach this piece of code ever.
std::cerr << "Couldn't start auth module because of: " << strerror(errno)
<< " (" << errno << ")!" << std::endl;
std::cerr << "Couldn't start auth module because of: " << strerror(errno) << " (" << errno << ")!" << std::endl;
return EXIT_FAILURE;
}
@ -408,8 +369,7 @@ bool Module::Startup() {
return true;
}
nlohmann::json Module::Call(const nlohmann::json &params,
int timeout_millisec) {
nlohmann::json Module::Call(const nlohmann::json &params, int timeout_millisec) {
std::lock_guard<std::mutex> guard(lock_);
if (!params.is_object()) return {};

View File

@ -43,16 +43,14 @@ class ClientFatalException : public utils::BasicException {
// only handle the `ClientFatalException`.
class ServerCommunicationException : public ClientFatalException {
public:
ServerCommunicationException()
: ClientFatalException("Couldn't communicate with the server!") {}
ServerCommunicationException() : ClientFatalException("Couldn't communicate with the server!") {}
};
// Internal exception used whenever a malformed data error occurs. You should
// only handle the `ClientFatalException`.
class ServerMalformedDataException : public ClientFatalException {
public:
ServerMalformedDataException()
: ClientFatalException("The server sent malformed data!") {}
ServerMalformedDataException() : ClientFatalException("The server sent malformed data!") {}
};
/// Structure that is used to return results from an executed query.
@ -79,8 +77,7 @@ class Client final {
/// connection is set-up, multiple queries may be executed through a single
/// established connection.
/// @throws ClientFatalException when we couldn't connect to the server
void Connect(const io::network::Endpoint &endpoint,
const std::string &username, const std::string &password,
void Connect(const io::network::Endpoint &endpoint, const std::string &username, const std::string &password,
const std::string &client_name = "memgraph-bolt") {
if (!client_.Connect(endpoint)) {
throw ClientFatalException("Couldn't connect to {}!", endpoint);
@ -103,14 +100,11 @@ class Client final {
}
if (memcmp(kProtocol, client_.GetData(), sizeof(kProtocol)) != 0) {
SPDLOG_ERROR("Server negotiated unsupported protocol version!");
throw ClientFatalException(
"The server negotiated an usupported protocol version!");
throw ClientFatalException("The server negotiated an usupported protocol version!");
}
client_.ShiftData(sizeof(kProtocol));
if (!encoder_.MessageInit(client_name, {{"scheme", "basic"},
{"principal", username},
{"credentials", password}})) {
if (!encoder_.MessageInit(client_name, {{"scheme", "basic"}, {"principal", username}, {"credentials", password}})) {
SPDLOG_ERROR("Couldn't send init message!");
throw ServerCommunicationException();
}
@ -135,15 +129,12 @@ class Client final {
/// executing the query (eg. mistyped query,
/// etc.)
/// @throws ClientFatalException when we couldn't communicate with the server
QueryData Execute(const std::string &query,
const std::map<std::string, Value> &parameters) {
QueryData Execute(const std::string &query, const std::map<std::string, Value> &parameters) {
if (!client_.IsConnected()) {
throw ClientFatalException(
"You must first connect to the server before using the client!");
throw ClientFatalException("You must first connect to the server before using the client!");
}
SPDLOG_INFO("Sending run message with statement: '{}'; parameters: {}",
query, parameters);
SPDLOG_INFO("Sending run message with statement: '{}'; parameters: {}", query, parameters);
encoder_.MessageRun(query, parameters);
encoder_.MessagePullAll();
@ -165,8 +156,7 @@ class Client final {
if (it != tmp.end()) {
auto it_code = tmp.find("code");
if (it_code != tmp.end()) {
throw ClientQueryException(it_code->second.ValueString(),
it->second.ValueString());
throw ClientQueryException(it_code->second.ValueString(), it->second.ValueString());
} else {
throw ClientQueryException("", it->second.ValueString());
}
@ -209,8 +199,7 @@ class Client final {
if (it != tmp.end()) {
auto it_code = tmp.find("code");
if (it_code != tmp.end()) {
throw ClientQueryException(it_code->second.ValueString(),
it->second.ValueString());
throw ClientQueryException(it_code->second.ValueString(), it->second.ValueString());
} else {
throw ClientQueryException("", it->second.ValueString());
}
@ -308,15 +297,11 @@ class Client final {
communication::ClientOutputStream output_stream_{client_};
// decoder objects
ChunkedDecoderBuffer<communication::ClientInputStream> decoder_buffer_{
input_stream_};
Decoder<ChunkedDecoderBuffer<communication::ClientInputStream>> decoder_{
decoder_buffer_};
ChunkedDecoderBuffer<communication::ClientInputStream> decoder_buffer_{input_stream_};
Decoder<ChunkedDecoderBuffer<communication::ClientInputStream>> decoder_{decoder_buffer_};
// encoder objects
ChunkedEncoderBuffer<communication::ClientOutputStream> encoder_buffer_{
output_stream_};
ClientEncoder<ChunkedEncoderBuffer<communication::ClientOutputStream>>
encoder_{encoder_buffer_};
ChunkedEncoderBuffer<communication::ClientOutputStream> encoder_buffer_{output_stream_};
ClientEncoder<ChunkedEncoderBuffer<communication::ClientOutputStream>> encoder_{encoder_buffer_};
};
} // namespace communication::bolt

View File

@ -78,12 +78,8 @@ enum class Marker : uint8_t {
};
static constexpr uint8_t MarkerString = 0, MarkerList = 1, MarkerMap = 2;
static constexpr Marker MarkerTiny[3] = {Marker::TinyString, Marker::TinyList,
Marker::TinyMap};
static constexpr Marker Marker8[3] = {Marker::String8, Marker::List8,
Marker::Map8};
static constexpr Marker Marker16[3] = {Marker::String16, Marker::List16,
Marker::Map16};
static constexpr Marker Marker32[3] = {Marker::String32, Marker::List32,
Marker::Map32};
static constexpr Marker MarkerTiny[3] = {Marker::TinyString, Marker::TinyList, Marker::TinyMap};
static constexpr Marker Marker8[3] = {Marker::String8, Marker::List8, Marker::Map8};
static constexpr Marker Marker16[3] = {Marker::String16, Marker::List16, Marker::Map16};
static constexpr Marker Marker32[3] = {Marker::String32, Marker::List32, Marker::Map32};
} // namespace communication::bolt

View File

@ -40,9 +40,7 @@ enum class ChunkState : uint8_t {
template <typename TBuffer>
class ChunkedDecoderBuffer {
public:
ChunkedDecoderBuffer(TBuffer &buffer) : buffer_(buffer) {
data_.reserve(kChunkMaxDataSize);
}
ChunkedDecoderBuffer(TBuffer &buffer) : buffer_(buffer) { data_.reserve(kChunkMaxDataSize); }
/**
* Reads data from the internal buffer.

View File

@ -158,8 +158,7 @@ class Decoder {
}
bool ReadBool(const Marker &marker, Value *data) {
DMG_ASSERT(marker == Marker::False || marker == Marker::True,
"Received invalid marker!");
DMG_ASSERT(marker == Marker::False || marker == Marker::True, "Received invalid marker!");
if (marker == Marker::False) {
*data = Value(false);
} else {

View File

@ -7,8 +7,7 @@
#include "utils/cast.hpp"
#include "utils/endian.hpp"
static_assert(std::is_same_v<std::uint8_t, char> ||
std::is_same_v<std::uint8_t, unsigned char>,
static_assert(std::is_same_v<std::uint8_t, char> || std::is_same_v<std::uint8_t, unsigned char>,
"communication::bolt::Encoder requires uint8_t to be "
"implemented as char or unsigned char.");
@ -29,9 +28,7 @@ class BaseEncoder {
void WriteRAW(const uint8_t *data, uint64_t len) { buffer_.Write(data, len); }
void WriteRAW(const char *data, uint64_t len) {
WriteRAW((const uint8_t *)data, len);
}
void WriteRAW(const char *data, uint64_t len) { WriteRAW((const uint8_t *)data, len); }
void WriteRAW(const uint8_t data) { WriteRAW(&data, 1); }
@ -126,8 +123,7 @@ class BaseEncoder {
void WriteEdge(const Edge &edge, bool unbound = false) {
WriteRAW(utils::UnderlyingCast(Marker::TinyStruct) + (unbound ? 3 : 5));
WriteRAW(utils::UnderlyingCast(unbound ? Signature::UnboundRelationship
: Signature::Relationship));
WriteRAW(utils::UnderlyingCast(unbound ? Signature::UnboundRelationship : Signature::Relationship));
WriteInt(edge.id.AsInt());
if (!unbound) {

View File

@ -37,8 +37,7 @@ namespace communication::bolt {
template <class TOutputStream>
class ChunkedEncoderBuffer {
public:
ChunkedEncoderBuffer(TOutputStream &output_stream)
: output_stream_(output_stream) {}
ChunkedEncoderBuffer(TOutputStream &output_stream) : output_stream_(output_stream) {}
/**
* Writes n values into the buffer. If n is bigger than whole chunk size
@ -53,12 +52,10 @@ class ChunkedEncoderBuffer {
while (n > 0) {
// Define the number of bytes which will be copied into the chunk because
// the internal storage is a fixed length array.
size_t size =
n < kChunkMaxDataSize - have_ ? n : kChunkMaxDataSize - have_;
size_t size = n < kChunkMaxDataSize - have_ ? n : kChunkMaxDataSize - have_;
// Copy `size` values to the chunk array.
std::memcpy(chunk_.data() + kChunkHeaderSize + have_, values + written,
size);
std::memcpy(chunk_.data() + kChunkHeaderSize + have_, values + written, size);
// Update positions. The position pointer and incoming size have to be
// updated because all incoming values have to be processed.
@ -87,8 +84,7 @@ class ChunkedEncoderBuffer {
chunk_[1] = have_ & 0xFF;
// Write the data to the stream.
auto ret = output_stream_.Write(chunk_.data(), kChunkHeaderSize + have_,
have_more);
auto ret = output_stream_.Write(chunk_.data(), kChunkHeaderSize + have_, have_more);
// Cleanup.
Clear();

View File

@ -38,8 +38,7 @@ class ClientEncoder : private BaseEncoder<Buffer> {
* @returns true if the data was successfully sent to the client
* when flushing, false otherwise
*/
bool MessageInit(const std::string client_name,
const std::map<std::string, Value> &auth_token) {
bool MessageInit(const std::string client_name, const std::map<std::string, Value> &auth_token) {
WriteRAW(utils::UnderlyingCast(Marker::TinyStruct2));
WriteRAW(utils::UnderlyingCast(Signature::Init));
WriteString(client_name);
@ -65,9 +64,7 @@ class ClientEncoder : private BaseEncoder<Buffer> {
* @returns true if the data was successfully sent to the client
* when flushing, false otherwise
*/
bool MessageRun(const std::string &statement,
const std::map<std::string, Value> &parameters,
bool have_more = true) {
bool MessageRun(const std::string &statement, const std::map<std::string, Value> &parameters, bool have_more = true) {
WriteRAW(utils::UnderlyingCast(Marker::TinyStruct2));
WriteRAW(utils::UnderlyingCast(Signature::Run));
WriteString(statement);

View File

@ -50,13 +50,10 @@ class VerboseError : public utils::BasicException {
};
template <class... Args>
VerboseError(Classification classification, const std::string &category,
const std::string &title, const std::string &format,
Args &&... args)
VerboseError(Classification classification, const std::string &category, const std::string &title,
const std::string &format, Args &&...args)
: BasicException(format, std::forward<Args>(args)...),
code_(fmt::format("Memgraph.{}.{}.{}",
ClassificationToString(classification), category,
title)) {}
code_(fmt::format("Memgraph.{}.{}.{}", ClassificationToString(classification), category, title)) {}
const std::string &code() const noexcept { return code_; }

View File

@ -62,9 +62,7 @@ class Session {
* @param q If set, defines from which query to pull the results,
* otherwise the last query is used.
*/
virtual std::map<std::string, Value> Pull(TEncoder *encoder,
std::optional<int> n,
std::optional<int> qid) = 0;
virtual std::map<std::string, Value> Pull(TEncoder *encoder, std::optional<int> n, std::optional<int> qid) = 0;
/**
* Discard results of the processed query.
@ -74,8 +72,7 @@ class Session {
* @param q If set, defines from which query to discard the results,
* otherwise the last query is used.
*/
virtual std::map<std::string, Value> Discard(std::optional<int> n,
std::optional<int> qid) = 0;
virtual std::map<std::string, Value> Discard(std::optional<int> n, std::optional<int> qid) = 0;
virtual void BeginTransaction() = 0;
virtual void CommitTransaction() = 0;
@ -85,8 +82,7 @@ class Session {
virtual void Abort() = 0;
/** Return `true` if the user was successfully authenticated. */
virtual bool Authenticate(const std::string &username,
const std::string &password) = 0;
virtual bool Authenticate(const std::string &username, const std::string &password) = 0;
/** Return the name of the server that should be used for the Bolt INIT
* message. */
@ -104,8 +100,7 @@ class Session {
// Receive the handshake.
if (input_stream_.size() < kHandshakeSize) {
spdlog::trace("Received partial handshake of size {}",
input_stream_.size());
spdlog::trace("Received partial handshake of size {}", input_stream_.size());
return;
}
state_ = StateHandshakeRun(*this);

View File

@ -44,4 +44,4 @@ enum class State : uint8_t {
*/
Close
};
}
} // namespace communication::bolt

View File

@ -26,8 +26,7 @@ State StateErrorRun(TSession &session, State state) {
return State::Close;
}
if (UNLIKELY(signature == Signature::Noop && session.version_.major == 4 &&
session.version_.minor == 1)) {
if (UNLIKELY(signature == Signature::Noop && session.version_.major == 4 && session.version_.minor == 1)) {
spdlog::trace("Received NOOP message");
return state;
}
@ -35,8 +34,7 @@ State StateErrorRun(TSession &session, State state) {
// Clear the data buffer if it has any leftover data.
session.encoder_buffer_.Clear();
if ((session.version_.major == 1 && signature == Signature::AckFailure) ||
signature == Signature::Reset) {
if ((session.version_.major == 1 && signature == Signature::AckFailure) || signature == Signature::Reset) {
if (signature == Signature::AckFailure) {
spdlog::trace("AckFailure received");
} else {
@ -62,8 +60,7 @@ State StateErrorRun(TSession &session, State state) {
// All bolt client messages have less than 15 parameters so if we receive
// anything than a TinyStruct it's an error.
if ((value & 0xF0) != utils::UnderlyingCast(Marker::TinyStruct)) {
spdlog::trace("Expected TinyStruct marker, but received 0x{:02X}!",
value);
spdlog::trace("Expected TinyStruct marker, but received 0x{:02X}!", value);
return State::Close;
}

View File

@ -16,8 +16,7 @@
namespace communication::bolt {
// TODO (mferencevic): revise these error messages
inline std::pair<std::string, std::string> ExceptionToErrorMessage(
const std::exception &e) {
inline std::pair<std::string, std::string> ExceptionToErrorMessage(const std::exception &e) {
if (auto *verbose = dynamic_cast<const VerboseError *>(&e)) {
return {verbose->code(), verbose->what()};
}
@ -54,8 +53,7 @@ inline std::pair<std::string, std::string> ExceptionToErrorMessage(
// All exceptions used in memgraph are derived from BasicException. Since
// we caught some other exception we don't know what is going on. Return
// DatabaseError, log real message and return generic string.
spdlog::error("Unknown exception occurred during query execution {}",
e.what());
spdlog::error("Unknown exception occurred during query execution {}", e.what());
return {"Memgraph.DatabaseError.MemgraphError.MemgraphError",
"An unknown exception occurred, this is unexpected. Real message "
"should be in database logs."};
@ -69,8 +67,7 @@ inline State HandleFailure(TSession &session, const std::exception &e) {
}
session.encoder_buffer_.Clear();
auto code_message = ExceptionToErrorMessage(e);
bool fail_sent = session.encoder_.MessageFailure(
{{"code", code_message.first}, {"message", code_message.second}});
bool fail_sent = session.encoder_.MessageFailure({{"code", code_message.first}, {"message", code_message.second}});
if (!fail_sent) {
spdlog::trace("Couldn't send failure message!");
return State::Close;
@ -80,15 +77,12 @@ inline State HandleFailure(TSession &session, const std::exception &e) {
template <typename TSession>
State HandleRun(TSession &session, State state, Marker marker) {
const std::map<std::string, Value> kEmptyFields = {
{"fields", std::vector<Value>{}}};
const std::map<std::string, Value> kEmptyFields = {{"fields", std::vector<Value>{}}};
const auto expected_marker =
session.version_.major == 1 ? Marker::TinyStruct2 : Marker::TinyStruct3;
const auto expected_marker = session.version_.major == 1 ? Marker::TinyStruct2 : Marker::TinyStruct3;
if (marker != expected_marker) {
spdlog::trace("Expected {} marker, but received 0x{:02X}!",
session.version_.major == 1 ? "TinyStruct2" : "TinyStruct3",
utils::UnderlyingCast(marker));
session.version_.major == 1 ? "TinyStruct2" : "TinyStruct3", utils::UnderlyingCast(marker));
return State::Close;
}
@ -117,15 +111,13 @@ State HandleRun(TSession &session, State state, Marker marker) {
return State::Close;
}
DMG_ASSERT(!session.encoder_buffer_.HasData(),
"There should be no data to write in this state");
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
spdlog::debug("[Run] '{}'", query.ValueString());
try {
// Interpret can throw.
auto [header, qid] =
session.Interpret(query.ValueString(), params.ValueMap());
auto [header, qid] = session.Interpret(query.ValueString(), params.ValueMap());
// Convert std::string to Value
std::vector<Value> vec;
std::map<std::string, Value> data;
@ -146,12 +138,10 @@ State HandleRun(TSession &session, State state, Marker marker) {
namespace detail {
template <bool is_pull, typename TSession>
State HandlePullDiscard(TSession &session, State state, Marker marker) {
const auto expected_marker =
session.version_.major == 1 ? Marker::TinyStruct : Marker::TinyStruct1;
const auto expected_marker = session.version_.major == 1 ? Marker::TinyStruct : Marker::TinyStruct1;
if (marker != expected_marker) {
spdlog::trace("Expected {} marker, but received 0x{:02X}!",
session.version_.major == 1 ? "TinyStruct" : "TinyStruct1",
utils::UnderlyingCast(marker));
session.version_.major == 1 ? "TinyStruct" : "TinyStruct1", utils::UnderlyingCast(marker));
return State::Close;
}
@ -176,15 +166,13 @@ State HandlePullDiscard(TSession &session, State state, Marker marker) {
}
const auto &extra_map = extra.ValueMap();
if (extra_map.count("n")) {
if (const auto n_value = extra_map.at("n").ValueInt();
n_value != kPullAll) {
if (const auto n_value = extra_map.at("n").ValueInt(); n_value != kPullAll) {
n = n_value;
}
}
if (extra_map.count("qid")) {
if (const auto qid_value = extra_map.at("qid").ValueInt();
qid_value != kPullLast) {
if (const auto qid_value = extra_map.at("qid").ValueInt(); qid_value != kPullLast) {
qid = qid_value;
}
}
@ -236,8 +224,7 @@ State HandleReset(Session &session, State, Marker marker) {
// now this command only resets the session to a clean state. It
// does not IGNORE running and pending commands as it should.
if (marker != Marker::TinyStruct) {
spdlog::trace("Expected TinyStruct marker, but received 0x{:02X}!",
utils::UnderlyingCast(marker));
spdlog::trace("Expected TinyStruct marker, but received 0x{:02X}!", utils::UnderlyingCast(marker));
return State::Close;
}
@ -262,8 +249,7 @@ State HandleBegin(Session &session, State state, Marker marker) {
}
if (marker != Marker::TinyStruct1) {
spdlog::trace("Expected TinyStruct1 marker, but received 0x{:02x}!",
utils::UnderlyingCast(marker));
spdlog::trace("Expected TinyStruct1 marker, but received 0x{:02x}!", utils::UnderlyingCast(marker));
return State::Close;
}
@ -278,8 +264,7 @@ State HandleBegin(Session &session, State state, Marker marker) {
return State::Close;
}
DMG_ASSERT(!session.encoder_buffer_.HasData(),
"There should be no data to write in this state");
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
if (!session.encoder_.MessageSuccess({})) {
spdlog::trace("Couldn't send success message!");
@ -303,8 +288,7 @@ State HandleCommit(Session &session, State state, Marker marker) {
}
if (marker != Marker::TinyStruct) {
spdlog::trace("Expected TinyStruct marker, but received 0x{:02x}!",
utils::UnderlyingCast(marker));
spdlog::trace("Expected TinyStruct marker, but received 0x{:02x}!", utils::UnderlyingCast(marker));
return State::Close;
}
@ -313,8 +297,7 @@ State HandleCommit(Session &session, State state, Marker marker) {
return State::Close;
}
DMG_ASSERT(!session.encoder_buffer_.HasData(),
"There should be no data to write in this state");
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
try {
if (!session.encoder_.MessageSuccess({})) {
@ -336,8 +319,7 @@ State HandleRollback(Session &session, State state, Marker marker) {
}
if (marker != Marker::TinyStruct) {
spdlog::trace("Expected TinyStruct marker, but received 0x{:02x}!",
utils::UnderlyingCast(marker));
spdlog::trace("Expected TinyStruct marker, but received 0x{:02x}!", utils::UnderlyingCast(marker));
return State::Close;
}
@ -346,8 +328,7 @@ State HandleRollback(Session &session, State state, Marker marker) {
return State::Close;
}
DMG_ASSERT(!session.encoder_buffer_.HasData(),
"There should be no data to write in this state");
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
try {
if (!session.encoder_.MessageSuccess({})) {
@ -376,8 +357,7 @@ State StateExecutingRun(Session &session, State state) {
return State::Close;
}
if (UNLIKELY(signature == Signature::Noop && session.version_.major == 4 &&
session.version_.minor == 1)) {
if (UNLIKELY(signature == Signature::Noop && session.version_.major == 4 && session.version_.minor == 1)) {
spdlog::trace("Received NOOP message");
return state;
}
@ -399,8 +379,7 @@ State StateExecutingRun(Session &session, State state) {
} else if (signature == Signature::Goodbye && session.version_.major != 1) {
throw SessionClosedException("Closing connection.");
} else {
spdlog::trace("Unrecognized signature received (0x{:02X})!",
utils::UnderlyingCast(signature));
spdlog::trace("Unrecognized signature received (0x{:02X})!", utils::UnderlyingCast(signature));
return State::Close;
}
}

View File

@ -17,15 +17,13 @@ namespace communication::bolt {
*/
template <typename TSession>
State StateHandshakeRun(TSession &session) {
auto precmp =
std::memcmp(session.input_stream_.data(), kPreamble, sizeof(kPreamble));
auto precmp = std::memcmp(session.input_stream_.data(), kPreamble, sizeof(kPreamble));
if (UNLIKELY(precmp != 0)) {
spdlog::trace("Received a wrong preamble!");
return State::Close;
}
DMG_ASSERT(session.input_stream_.size() >= kHandshakeSize,
"Wrong size of the handshake data!");
DMG_ASSERT(session.input_stream_.size() >= kHandshakeSize, "Wrong size of the handshake data!");
auto dataPosition = session.input_stream_.data() + sizeof(kPreamble);
@ -61,8 +59,7 @@ State StateHandshakeRun(TSession &session) {
return State::Close;
}
spdlog::info("Using version {}.{} of protocol", session.version_.major,
session.version_.minor);
spdlog::info("Using version {}.{} of protocol", session.version_.major, session.version_.minor);
// Delete data from the input stream. It is guaranteed that there will more
// than, or equal to 20 bytes (kHandshakeSize) in the buffer.

View File

@ -15,8 +15,7 @@ namespace detail {
template <typename TSession>
std::optional<Value> StateInitRunV1(TSession &session, const Marker marker) {
if (UNLIKELY(marker != Marker::TinyStruct2)) {
spdlog::trace("Expected TinyStruct2 marker, but received 0x{:02X}!",
utils::UnderlyingCast(marker));
spdlog::trace("Expected TinyStruct2 marker, but received 0x{:02X}!", utils::UnderlyingCast(marker));
spdlog::trace(
"The client sent malformed data, but we are continuing "
"because the official Neo4j Java driver sends malformed "
@ -45,8 +44,7 @@ std::optional<Value> StateInitRunV1(TSession &session, const Marker marker) {
template <typename TSession>
std::optional<Value> StateInitRunV4(TSession &session, const Marker marker) {
if (UNLIKELY(marker != Marker::TinyStruct1)) {
spdlog::trace("Expected TinyStruct1 marker, but received 0x{:02X}!",
utils::UnderlyingCast(marker));
spdlog::trace("Expected TinyStruct1 marker, but received 0x{:02X}!", utils::UnderlyingCast(marker));
spdlog::trace(
"The client sent malformed data, but we are continuing "
"because the official Neo4j Java driver sends malformed "
@ -80,8 +78,7 @@ std::optional<Value> StateInitRunV4(TSession &session, const Marker marker) {
*/
template <typename Session>
State StateInitRun(Session &session) {
DMG_ASSERT(!session.encoder_buffer_.HasData(),
"There should be no data to write in this state");
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
Marker marker;
Signature signature;
@ -90,21 +87,18 @@ State StateInitRun(Session &session) {
return State::Close;
}
if (UNLIKELY(signature == Signature::Noop && session.version_.major == 4 &&
session.version_.minor == 1)) {
if (UNLIKELY(signature == Signature::Noop && session.version_.major == 4 && session.version_.minor == 1)) {
SPDLOG_DEBUG("Received NOOP message");
return State::Init;
}
if (UNLIKELY(signature != Signature::Init)) {
spdlog::trace("Expected Init signature, but received 0x{:02X}!",
utils::UnderlyingCast(signature));
spdlog::trace("Expected Init signature, but received 0x{:02X}!", utils::UnderlyingCast(signature));
return State::Close;
}
auto maybeMetadata = session.version_.major == 1
? detail::StateInitRunV1(session, marker)
: detail::StateInitRunV4(session, marker);
auto maybeMetadata =
session.version_.major == 1 ? detail::StateInitRunV1(session, marker) : detail::StateInitRunV4(session, marker);
if (!maybeMetadata) {
return State::Close;
@ -126,16 +120,14 @@ State StateInitRun(Session &session) {
username = data["principal"].ValueString();
password = data["credentials"].ValueString();
} else if (data["scheme"].ValueString() != "none") {
spdlog::warn("Unsupported authentication scheme: {}",
data["scheme"].ValueString());
spdlog::warn("Unsupported authentication scheme: {}", data["scheme"].ValueString());
return State::Close;
}
// Authenticate the user.
if (!session.Authenticate(username, password)) {
if (!session.encoder_.MessageFailure(
{{"code", "Memgraph.ClientError.Security.Unauthenticated"},
{"message", "Authentication failure"}})) {
{{"code", "Memgraph.ClientError.Security.Unauthenticated"}, {"message", "Authentication failure"}})) {
spdlog::trace("Couldn't send failure message to the client!");
}
// Throw an exception to indicate to the network stack that the session

View File

@ -198,8 +198,7 @@ Value &Value::operator=(Value &&other) noexcept {
new (&edge_v) Edge(std::move(other.edge_v));
break;
case Type::UnboundedEdge:
new (&unbounded_edge_v)
UnboundedEdge(std::move(other.unbounded_edge_v));
new (&unbounded_edge_v) UnboundedEdge(std::move(other.unbounded_edge_v));
break;
case Type::Path:
new (&path_v) Path(std::move(other.path_v));
@ -258,17 +257,14 @@ std::ostream &operator<<(std::ostream &os, const Vertex &vertex) {
if (vertex.labels.size() > 0) {
os << ":";
}
utils::PrintIterable(os, vertex.labels, ":",
[&](auto &stream, auto label) { stream << label; });
utils::PrintIterable(os, vertex.labels, ":", [&](auto &stream, auto label) { stream << label; });
if (vertex.labels.size() > 0 && vertex.properties.size() > 0) {
os << " ";
}
if (vertex.properties.size() > 0) {
os << "{";
utils::PrintIterable(os, vertex.properties, ", ",
[&](auto &stream, const auto &pair) {
stream << pair.first << ": " << pair.second;
});
[&](auto &stream, const auto &pair) { stream << pair.first << ": " << pair.second; });
os << "}";
}
return os << ")";
@ -279,9 +275,7 @@ std::ostream &operator<<(std::ostream &os, const Edge &edge) {
if (edge.properties.size() > 0) {
os << " {";
utils::PrintIterable(os, edge.properties, ", ",
[&](auto &stream, const auto &pair) {
stream << pair.first << ": " << pair.second;
});
[&](auto &stream, const auto &pair) { stream << pair.first << ": " << pair.second; });
os << "}";
}
return os << "]";
@ -292,9 +286,7 @@ std::ostream &operator<<(std::ostream &os, const UnboundedEdge &edge) {
if (edge.properties.size() > 0) {
os << " {";
utils::PrintIterable(os, edge.properties, ", ",
[&](auto &stream, const auto &pair) {
stream << pair.first << ": " << pair.second;
});
[&](auto &stream, const auto &pair) { stream << pair.first << ": " << pair.second; });
os << "}";
}
return os << "]";
@ -339,9 +331,7 @@ std::ostream &operator<<(std::ostream &os, const Value &value) {
case Value::Type::Map:
os << "{";
utils::PrintIterable(os, value.ValueMap(), ", ",
[](auto &stream, const auto &pair) {
stream << pair.first << ": " << pair.second;
});
[](auto &stream, const auto &pair) { stream << pair.first << ": " << pair.second; });
return os << "}";
case Value::Type::Vertex:
return os << value.ValueVertex();

View File

@ -33,9 +33,7 @@ class Id {
int64_t id_;
};
inline bool operator==(const Id &id1, const Id &id2) {
return id1.AsInt() == id2.AsInt();
}
inline bool operator==(const Id &id1, const Id &id2) { return id1.AsInt() == id2.AsInt(); }
inline bool operator!=(const Id &id1, const Id &id2) { return !(id1 == id2); }
@ -84,13 +82,10 @@ struct Path {
// into the collection and puts that index into `indices`. A multiplier is
// added to switch between positive and negative indices (that define edge
// direction).
auto add_element = [this](auto &collection, const auto &element,
int multiplier, int offset) {
auto add_element = [this](auto &collection, const auto &element, int multiplier, int offset) {
auto found =
std::find_if(collection.begin(), collection.end(),
[&](const auto &e) { return e.id == element.id; });
indices.emplace_back(multiplier *
(std::distance(collection.begin(), found) + offset));
std::find_if(collection.begin(), collection.end(), [&](const auto &e) { return e.id == element.id; });
indices.emplace_back(multiplier * (std::distance(collection.begin(), found) + offset));
if (found == collection.end()) collection.push_back(element);
};
@ -125,19 +120,7 @@ class Value {
Value() : type_(Type::Null) {}
/** Types that can be stored in a Value. */
enum class Type : unsigned {
Null,
Bool,
Int,
Double,
String,
List,
Map,
Vertex,
Edge,
UnboundedEdge,
Path
};
enum class Type : unsigned { Null, Bool, Int, Double, String, List, Map, Vertex, Edge, UnboundedEdge, Path };
// constructors for primitive types
Value(bool value) : type_(Type::Bool) { bool_v = value; }
@ -146,47 +129,29 @@ class Value {
Value(double value) : type_(Type::Double) { double_v = value; }
// constructors for non-primitive types
Value(const std::string &value) : type_(Type::String) {
new (&string_v) std::string(value);
}
Value(const std::string &value) : type_(Type::String) { new (&string_v) std::string(value); }
Value(const char *value) : Value(std::string(value)) {}
Value(const std::vector<Value> &value) : type_(Type::List) {
new (&list_v) std::vector<Value>(value);
}
Value(const std::vector<Value> &value) : type_(Type::List) { new (&list_v) std::vector<Value>(value); }
Value(const std::map<std::string, Value> &value) : type_(Type::Map) {
new (&map_v) std::map<std::string, Value>(value);
}
Value(const Vertex &value) : type_(Type::Vertex) {
new (&vertex_v) Vertex(value);
}
Value(const Vertex &value) : type_(Type::Vertex) { new (&vertex_v) Vertex(value); }
Value(const Edge &value) : type_(Type::Edge) { new (&edge_v) Edge(value); }
Value(const UnboundedEdge &value) : type_(Type::UnboundedEdge) {
new (&unbounded_edge_v) UnboundedEdge(value);
}
Value(const UnboundedEdge &value) : type_(Type::UnboundedEdge) { new (&unbounded_edge_v) UnboundedEdge(value); }
Value(const Path &value) : type_(Type::Path) { new (&path_v) Path(value); }
// move constructors for non-primitive values
Value(std::string &&value) noexcept : type_(Type::String) {
new (&string_v) std::string(std::move(value));
}
Value(std::vector<Value> &&value) noexcept : type_(Type::List) {
new (&list_v) std::vector<Value>(std::move(value));
}
Value(std::string &&value) noexcept : type_(Type::String) { new (&string_v) std::string(std::move(value)); }
Value(std::vector<Value> &&value) noexcept : type_(Type::List) { new (&list_v) std::vector<Value>(std::move(value)); }
Value(std::map<std::string, Value> &&value) noexcept : type_(Type::Map) {
new (&map_v) std::map<std::string, Value>(std::move(value));
}
Value(Vertex &&value) noexcept : type_(Type::Vertex) {
new (&vertex_v) Vertex(std::move(value));
}
Value(Edge &&value) noexcept : type_(Type::Edge) {
new (&edge_v) Edge(std::move(value));
}
Value(Vertex &&value) noexcept : type_(Type::Vertex) { new (&vertex_v) Vertex(std::move(value)); }
Value(Edge &&value) noexcept : type_(Type::Edge) { new (&edge_v) Edge(std::move(value)); }
Value(UnboundedEdge &&value) noexcept : type_(Type::UnboundedEdge) {
new (&unbounded_edge_v) UnboundedEdge(std::move(value));
}
Value(Path &&value) noexcept : type_(Type::Path) {
new (&path_v) Path(std::move(value));
}
Value(Path &&value) noexcept : type_(Type::Path) { new (&path_v) Path(std::move(value)); }
Value &operator=(const Value &other);
Value &operator=(Value &&other) noexcept;

View File

@ -4,8 +4,7 @@
namespace communication {
Buffer::Buffer()
: data_(kBufferInitialSize, 0), read_end_(this), write_end_(this) {}
Buffer::Buffer() : data_(kBufferInitialSize, 0), read_end_(this), write_end_(this) {}
Buffer::ReadEnd::ReadEnd(Buffer *buffer) : buffer_(buffer) {}
@ -21,9 +20,7 @@ void Buffer::ReadEnd::Clear() { buffer_->Clear(); }
Buffer::WriteEnd::WriteEnd(Buffer *buffer) : buffer_(buffer) {}
io::network::StreamBuffer Buffer::WriteEnd::Allocate() {
return buffer_->Allocate();
}
io::network::StreamBuffer Buffer::WriteEnd::Allocate() { return buffer_->Allocate(); }
void Buffer::WriteEnd::Written(size_t len) { buffer_->Written(len); }

View File

@ -195,8 +195,7 @@ bool Client::Write(const uint8_t *data, size_t len, bool have_more) {
}
bool Client::Write(const std::string &str, bool have_more) {
return Write(reinterpret_cast<const uint8_t *>(str.data()), str.size(),
have_more);
return Write(reinterpret_cast<const uint8_t *>(str.data()), str.size(), have_more);
}
const io::network::Endpoint &Client::endpoint() { return socket_.endpoint(); }
@ -224,12 +223,9 @@ void ClientInputStream::Clear() { client_.ClearData(); }
ClientOutputStream::ClientOutputStream(Client &client) : client_(client) {}
bool ClientOutputStream::Write(const uint8_t *data, size_t len,
bool have_more) {
bool ClientOutputStream::Write(const uint8_t *data, size_t len, bool have_more) {
return client_.Write(data, len, have_more);
}
bool ClientOutputStream::Write(const std::string &str, bool have_more) {
return client_.Write(str, have_more);
}
bool ClientOutputStream::Write(const std::string &str, bool have_more) { return client_.Write(str, have_more); }
} // namespace communication

View File

@ -19,21 +19,16 @@ ClientContext::ClientContext(bool use_ssl) : use_ssl_(use_ssl), ctx_(nullptr) {
}
}
ClientContext::ClientContext(const std::string &key_file,
const std::string &cert_file)
: ClientContext(true) {
ClientContext::ClientContext(const std::string &key_file, const std::string &cert_file) : ClientContext(true) {
if (key_file != "" && cert_file != "") {
MG_ASSERT(SSL_CTX_use_certificate_file(ctx_, cert_file.c_str(),
SSL_FILETYPE_PEM) == 1,
MG_ASSERT(SSL_CTX_use_certificate_file(ctx_, cert_file.c_str(), SSL_FILETYPE_PEM) == 1,
"Couldn't load client certificate from file: {}", cert_file);
MG_ASSERT(SSL_CTX_use_PrivateKey_file(ctx_, key_file.c_str(),
SSL_FILETYPE_PEM) == 1,
MG_ASSERT(SSL_CTX_use_PrivateKey_file(ctx_, key_file.c_str(), SSL_FILETYPE_PEM) == 1,
"Couldn't load client private key from file: ", key_file);
}
}
ClientContext::ClientContext(ClientContext &&other) noexcept
: use_ssl_(other.use_ssl_), ctx_(other.ctx_) {
ClientContext::ClientContext(ClientContext &&other) noexcept : use_ssl_(other.use_ssl_), ctx_(other.ctx_) {
other.use_ssl_ = false;
other.ctx_ = nullptr;
}
@ -69,9 +64,8 @@ bool ClientContext::use_ssl() { return use_ssl_; }
ServerContext::ServerContext() : use_ssl_(false), ctx_(nullptr) {}
ServerContext::ServerContext(const std::string &key_file,
const std::string &cert_file,
const std::string &ca_file, bool verify_peer)
ServerContext::ServerContext(const std::string &key_file, const std::string &cert_file, const std::string &ca_file,
bool verify_peer)
: use_ssl_(true),
#if OPENSSL_VERSION_NUMBER < 0x10100000L
ctx_(SSL_CTX_new(SSLv23_server_method()))
@ -81,11 +75,9 @@ ServerContext::ServerContext(const std::string &key_file,
{
// TODO (mferencevic): add support for encrypted private keys
// TODO (mferencevic): add certificate revocation list (CRL)
MG_ASSERT(SSL_CTX_use_certificate_file(ctx_, cert_file.c_str(),
SSL_FILETYPE_PEM) == 1,
MG_ASSERT(SSL_CTX_use_certificate_file(ctx_, cert_file.c_str(), SSL_FILETYPE_PEM) == 1,
"Couldn't load server certificate from file: {}", cert_file);
MG_ASSERT(SSL_CTX_use_PrivateKey_file(ctx_, key_file.c_str(),
SSL_FILETYPE_PEM) == 1,
MG_ASSERT(SSL_CTX_use_PrivateKey_file(ctx_, key_file.c_str(), SSL_FILETYPE_PEM) == 1,
"Couldn't load server private key from file: {}", key_file);
// Disable legacy SSL support. Other options can be seen here:
@ -94,29 +86,25 @@ ServerContext::ServerContext(const std::string &key_file,
if (ca_file != "") {
// Load the certificate authority file.
MG_ASSERT(
SSL_CTX_load_verify_locations(ctx_, ca_file.c_str(), nullptr) == 1,
"Couldn't load certificate authority from file: {}", ca_file);
MG_ASSERT(SSL_CTX_load_verify_locations(ctx_, ca_file.c_str(), nullptr) == 1,
"Couldn't load certificate authority from file: {}", ca_file);
if (verify_peer) {
// Add the CA to list of accepted CAs that is sent to the client.
STACK_OF(X509_NAME) *ca_names = SSL_load_client_CA_file(ca_file.c_str());
MG_ASSERT(ca_names != nullptr,
"Couldn't load certificate authority from file: {}", ca_file);
MG_ASSERT(ca_names != nullptr, "Couldn't load certificate authority from file: {}", ca_file);
// `ca_names` doesn' need to be free'd because we pass it to
// `SSL_CTX_set_client_CA_list`:
// https://mta.openssl.org/pipermail/openssl-users/2015-May/001363.html
SSL_CTX_set_client_CA_list(ctx_, ca_names);
// Enable verification of the client certificate.
SSL_CTX_set_verify(
ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
}
}
}
ServerContext::ServerContext(ServerContext &&other) noexcept
: use_ssl_(other.use_ssl_), ctx_(other.ctx_) {
ServerContext::ServerContext(ServerContext &&other) noexcept : use_ssl_(other.use_ssl_), ctx_(other.ctx_) {
other.use_ssl_ = false;
other.ctx_ = nullptr;
}

View File

@ -72,8 +72,8 @@ class ServerContext final {
* to check that the client certificate is valid, then you need to supply a
* valid `ca_file` as well.
*/
ServerContext(const std::string &key_file, const std::string &cert_file,
const std::string &ca_file = "", bool verify_peer = false);
ServerContext(const std::string &key_file, const std::string &cert_file, const std::string &ca_file = "",
bool verify_peer = false);
// This object can't be copied because the underlying SSL implementation is
// messy and ownership can't be handled correctly.

View File

@ -28,10 +28,7 @@ void LockingFunction(int mode, int n, const char *file, int line) {
}
}
unsigned long IdFunction() {
return (unsigned long)std::hash<std::thread::id>()(
std::this_thread::get_id());
}
unsigned long IdFunction() { return (unsigned long)std::hash<std::thread::id>()(std::this_thread::get_id()); }
void SetupThreading() {
crypto_locks.resize(CRYPTO_num_locks());
@ -58,8 +55,7 @@ SSLInit::SSLInit() {
ERR_load_crypto_strings();
// Ignore SIGPIPE.
MG_ASSERT(utils::SignalIgnore(utils::Signal::Pipe),
"Couldn't ignore SIGPIPE!");
MG_ASSERT(utils::SignalIgnore(utils::Signal::Pipe), "Couldn't ignore SIGPIPE!");
SetupThreading();
}

View File

@ -40,8 +40,7 @@ class Listener final {
using SessionHandler = Session<TSession, TSessionData>;
public:
Listener(TSessionData *data, ServerContext *context,
int inactivity_timeout_sec, const std::string &service_name,
Listener(TSessionData *data, ServerContext *context, int inactivity_timeout_sec, const std::string &service_name,
size_t workers_count)
: data_(data),
alive_(false),
@ -77,8 +76,8 @@ class Listener final {
int fd = connection.fd();
// Create a new Session for the connection.
sessions_.push_back(std::make_unique<SessionHandler>(
std::move(connection), data_, context_, inactivity_timeout_sec_));
sessions_.push_back(
std::make_unique<SessionHandler>(std::move(connection), data_, context_, inactivity_timeout_sec_));
// Register the connection in Epoll.
// We want to listen to an incoming event which is edge triggered and
@ -86,8 +85,7 @@ class Listener final {
// concurrently and that is why we use `EPOLLONESHOT`, for a detailed
// description what are the problems and why this is correct see:
// https://idea.popcount.org/2017-02-20-epoll-is-fundamentally-broken-12/
epoll_.Add(fd, EPOLLIN | EPOLLET | EPOLLRDHUP | EPOLLONESHOT,
sessions_.back().get());
epoll_.Add(fd, EPOLLIN | EPOLLET | EPOLLRDHUP | EPOLLONESHOT, sessions_.back().get());
}
/**
@ -117,8 +115,7 @@ class Listener final {
std::lock_guard<utils::SpinLock> guard(lock_);
for (auto &session : sessions_) {
if (session->TimedOut()) {
spdlog::warn("{} session associated with {} timed out",
service_name, session->socket().endpoint());
spdlog::warn("{} session associated with {} timed out", service_name, session->socket().endpoint());
// Here we shutdown the socket to terminate any leftover
// blocking `Write` calls and to signal an event that the
// session is closed. Session cleanup will be done in the event
@ -178,8 +175,7 @@ class Listener final {
// dereference it here. It is safe to dereference the pointer because
// this design guarantees that there will never be an event that has
// a stale Session pointer.
SessionHandler &session =
*reinterpret_cast<SessionHandler *>(event.data.ptr);
SessionHandler &session = *reinterpret_cast<SessionHandler *>(event.data.ptr);
// Process epoll events. We use epoll in edge-triggered mode so we process
// all events here. Only one of the `if` statements must be executed
@ -192,20 +188,16 @@ class Listener final {
;
} else if (event.events & EPOLLRDHUP) {
// The client closed the connection.
spdlog::info("{} client {} closed the connection.", service_name_,
session.socket().endpoint());
spdlog::info("{} client {} closed the connection.", service_name_, session.socket().endpoint());
CloseSession(session);
} else if (!(event.events & EPOLLIN) ||
event.events & (EPOLLHUP | EPOLLERR)) {
} else if (!(event.events & EPOLLIN) || event.events & (EPOLLHUP | EPOLLERR)) {
// There was an error on the server side.
spdlog::error("Error occured in {} session associated with {}",
service_name_, session.socket().endpoint());
spdlog::error("Error occured in {} session associated with {}", service_name_, session.socket().endpoint());
CloseSession(session);
} else {
// Unhandled epoll event.
spdlog::error(
"Unhandled event occured in {} session associated with {} events: {}",
service_name_, session.socket().endpoint(), event.events);
spdlog::error("Unhandled event occured in {} session associated with {} events: {}", service_name_,
session.socket().endpoint(), event.events);
CloseSession(session);
}
}
@ -215,13 +207,11 @@ class Listener final {
if (session.Execute()) {
// Session execution done, rearm epoll to send events for this
// socket.
epoll_.Modify(session.socket().fd(),
EPOLLIN | EPOLLET | EPOLLRDHUP | EPOLLONESHOT, &session);
epoll_.Modify(session.socket().fd(), EPOLLIN | EPOLLET | EPOLLRDHUP | EPOLLONESHOT, &session);
return false;
}
} catch (const SessionClosedException &e) {
spdlog::info("{} client {} closed the connection.", service_name_,
session.socket().endpoint());
spdlog::info("{} client {} closed the connection.", service_name_, session.socket().endpoint());
CloseSession(session);
return false;
} catch (const std::exception &e) {
@ -245,11 +235,9 @@ class Listener final {
epoll_.Delete(session.socket().fd());
std::lock_guard<utils::SpinLock> guard(lock_);
auto it = std::find_if(sessions_.begin(), sessions_.end(),
[&](const auto &l) { return l.get() == &session; });
auto it = std::find_if(sessions_.begin(), sessions_.end(), [&](const auto &l) { return l.get() == &session; });
MG_ASSERT(it != sessions_.end(),
"Trying to remove session that is not found in sessions!");
MG_ASSERT(it != sessions_.end(), "Trying to remove session that is not found in sessions!");
int i = it - sessions_.begin();
swap(sessions_[i], sessions_.back());

View File

@ -27,9 +27,7 @@ class ResultStreamFaker {
void Header(const std::vector<std::string> &fields) { header_ = fields; }
void Result(const std::vector<communication::bolt::Value> &values) {
results_.push_back(values);
}
void Result(const std::vector<communication::bolt::Value> &values) { results_.push_back(values); }
void Result(const std::vector<query::TypedValue> &values) {
std::vector<communication::bolt::Value> bvalues;
@ -42,16 +40,12 @@ class ResultStreamFaker {
results_.push_back(std::move(bvalues));
}
void Summary(
const std::map<std::string, communication::bolt::Value> &summary) {
summary_ = summary;
}
void Summary(const std::map<std::string, communication::bolt::Value> &summary) { summary_ = summary; }
void Summary(const std::map<std::string, query::TypedValue> &summary) {
std::map<std::string, communication::bolt::Value> bsummary;
for (const auto &item : summary) {
auto maybe_value =
glue::ToBoltValue(item.second, *store_, storage::View::NEW);
auto maybe_value = glue::ToBoltValue(item.second, *store_, storage::View::NEW);
MG_ASSERT(maybe_value.HasValue());
bsummary.insert({item.first, std::move(*maybe_value)});
}
@ -64,8 +58,7 @@ class ResultStreamFaker {
const auto &GetSummary() const { return summary_; }
friend std::ostream &operator<<(std::ostream &os,
const ResultStreamFaker &results) {
friend std::ostream &operator<<(std::ostream &os, const ResultStreamFaker &results) {
auto decoded_value_to_string = [](const auto &value) {
std::stringstream ss;
ss << value;
@ -73,21 +66,16 @@ class ResultStreamFaker {
};
const std::vector<std::string> &header = results.GetHeader();
std::vector<int> column_widths(header.size());
std::transform(header.begin(), header.end(), column_widths.begin(),
[](const auto &s) { return s.size(); });
std::transform(header.begin(), header.end(), column_widths.begin(), [](const auto &s) { return s.size(); });
// convert all the results into strings, and track max column width
auto &results_data = results.GetResults();
std::vector<std::vector<std::string>> result_strings(
results_data.size(), std::vector<std::string>(column_widths.size()));
for (int row_ind = 0; row_ind < static_cast<int>(results_data.size());
++row_ind) {
for (int col_ind = 0; col_ind < static_cast<int>(column_widths.size());
++col_ind) {
std::string string_val =
decoded_value_to_string(results_data[row_ind][col_ind]);
column_widths[col_ind] =
std::max(column_widths[col_ind], (int)string_val.size());
std::vector<std::vector<std::string>> result_strings(results_data.size(),
std::vector<std::string>(column_widths.size()));
for (int row_ind = 0; row_ind < static_cast<int>(results_data.size()); ++row_ind) {
for (int col_ind = 0; col_ind < static_cast<int>(column_widths.size()); ++col_ind) {
std::string string_val = decoded_value_to_string(results_data[row_ind][col_ind]);
column_widths[col_ind] = std::max(column_widths[col_ind], (int)string_val.size());
result_strings[row_ind][col_ind] = string_val;
}
}
@ -96,15 +84,13 @@ class ResultStreamFaker {
// first define some helper functions
auto emit_horizontal_line = [&]() {
os << "+";
for (auto col_width : column_widths)
os << std::string((unsigned long)col_width + 2, '-') << "+";
for (auto col_width : column_widths) os << std::string((unsigned long)col_width + 2, '-') << "+";
os << std::endl;
};
auto emit_result_vec = [&](const std::vector<std::string> result_vec) {
os << "| ";
for (int col_ind = 0; col_ind < static_cast<int>(column_widths.size());
++col_ind) {
for (int col_ind = 0; col_ind < static_cast<int>(column_widths.size()); ++col_ind) {
const std::string &res = result_vec[col_ind];
os << res << std::string(column_widths[col_ind] - res.size(), ' ');
os << " | ";
@ -123,9 +109,7 @@ class ResultStreamFaker {
// output the summary
os << "Query summary: {";
utils::PrintIterable(os, results.GetSummary(), ", ",
[&](auto &stream, const auto &kv) {
stream << kv.first << ": " << kv.second;
});
[&](auto &stream, const auto &kv) { stream << kv.first << ": " << kv.second; });
os << "}" << std::endl;
return os;

View File

@ -46,14 +46,12 @@ class Server final {
* Constructs and binds server to endpoint, operates on session data and
* invokes workers_count workers
*/
Server(const io::network::Endpoint &endpoint, TSessionData *session_data,
ServerContext *context, int inactivity_timeout_sec,
const std::string &service_name,
Server(const io::network::Endpoint &endpoint, TSessionData *session_data, ServerContext *context,
int inactivity_timeout_sec, const std::string &service_name,
size_t workers_count = std::thread::hardware_concurrency())
: alive_(false),
endpoint_(endpoint),
listener_(session_data, context, inactivity_timeout_sec, service_name,
workers_count),
listener_(session_data, context, inactivity_timeout_sec, service_name, workers_count),
service_name_(service_name) {}
~Server() {
@ -69,8 +67,7 @@ class Server final {
Server &operator=(Server &&) = delete;
const auto &endpoint() const {
MG_ASSERT(alive_,
"You can't get the server endpoint when it's not running!");
MG_ASSERT(alive_, "You can't get the server endpoint when it's not running!");
return socket_.endpoint();
}
@ -138,8 +135,7 @@ class Server final {
// Connection is not available anymore or configuration failed.
return;
}
spdlog::info("Accepted a {} connection from {}", service_name_,
s->endpoint());
spdlog::info("Accepted a {} connection from {}", service_name_, s->endpoint());
listener_.AddConnection(std::move(*s));
}

View File

@ -35,22 +35,17 @@ using InputStream = Buffer::ReadEnd;
*/
class OutputStream final {
public:
OutputStream(
std::function<bool(const uint8_t *, size_t, bool)> write_function)
: write_function_(write_function) {}
OutputStream(std::function<bool(const uint8_t *, size_t, bool)> write_function) : write_function_(write_function) {}
OutputStream(const OutputStream &) = delete;
OutputStream(OutputStream &&) = delete;
OutputStream &operator=(const OutputStream &) = delete;
OutputStream &operator=(OutputStream &&) = delete;
bool Write(const uint8_t *data, size_t len, bool have_more = false) {
return write_function_(data, len, have_more);
}
bool Write(const uint8_t *data, size_t len, bool have_more = false) { return write_function_(data, len, have_more); }
bool Write(const std::string &str, bool have_more = false) {
return Write(reinterpret_cast<const uint8_t *>(str.data()), str.size(),
have_more);
return Write(reinterpret_cast<const uint8_t *>(str.data()), str.size(), have_more);
}
private:
@ -65,14 +60,10 @@ class OutputStream final {
template <class TSession, class TSessionData>
class Session final {
public:
Session(io::network::Socket &&socket, TSessionData *data,
ServerContext *context, int inactivity_timeout_sec)
Session(io::network::Socket &&socket, TSessionData *data, ServerContext *context, int inactivity_timeout_sec)
: socket_(std::move(socket)),
output_stream_([this](const uint8_t *data, size_t len, bool have_more) {
return Write(data, len, have_more);
}),
session_(data, socket_.endpoint(), input_buffer_.read_end(),
&output_stream_),
output_stream_([this](const uint8_t *data, size_t len, bool have_more) { return Write(data, len, have_more); }),
session_(data, socket_.endpoint(), input_buffer_.read_end(), &output_stream_),
inactivity_timeout_sec_(inactivity_timeout_sec) {
// Set socket options.
// The socket is set to be a non-blocking socket. We use the socket in a
@ -243,8 +234,7 @@ class Session final {
bool TimedOut() {
std::unique_lock<utils::SpinLock> guard(lock_);
if (execution_active_) return false;
return last_event_time_ + std::chrono::seconds(inactivity_timeout_sec_) <
std::chrono::steady_clock::now();
return last_event_time_ + std::chrono::seconds(inactivity_timeout_sec_) < std::chrono::steady_clock::now();
}
/**
@ -316,8 +306,7 @@ class Session final {
TSession session_;
// Time of the last event and associated lock.
std::chrono::time_point<std::chrono::steady_clock> last_event_time_{
std::chrono::steady_clock::now()};
std::chrono::time_point<std::chrono::steady_clock> last_event_time_{std::chrono::steady_clock::now()};
bool execution_active_{false};
utils::SpinLock lock_;
const int inactivity_timeout_sec_;

View File

@ -20,9 +20,7 @@
template <typename TElement>
class RingBuffer {
public:
explicit RingBuffer(int capacity) : capacity_(capacity) {
buffer_ = std::make_unique<TElement[]>(capacity_);
}
explicit RingBuffer(int capacity) : capacity_(capacity) { buffer_ = std::make_unique<TElement[]>(capacity_); }
RingBuffer(const RingBuffer &) = delete;
RingBuffer(RingBuffer &&) = delete;

View File

@ -32,34 +32,28 @@ query::TypedValue ToTypedValue(const Value &value) {
}
case Value::Type::Map: {
std::map<std::string, query::TypedValue> map;
for (const auto &kv : value.ValueMap())
map.emplace(kv.first, ToTypedValue(kv.second));
for (const auto &kv : value.ValueMap()) map.emplace(kv.first, ToTypedValue(kv.second));
return query::TypedValue(std::move(map));
}
case Value::Type::Vertex:
case Value::Type::Edge:
case Value::Type::UnboundedEdge:
case Value::Type::Path:
throw communication::bolt::ValueException(
"Unsupported conversion from Value to TypedValue");
throw communication::bolt::ValueException("Unsupported conversion from Value to TypedValue");
}
}
storage::Result<communication::bolt::Vertex> ToBoltVertex(
const query::VertexAccessor &vertex, const storage::Storage &db,
storage::View view) {
storage::Result<communication::bolt::Vertex> ToBoltVertex(const query::VertexAccessor &vertex,
const storage::Storage &db, storage::View view) {
return ToBoltVertex(vertex.impl_, db, view);
}
storage::Result<communication::bolt::Edge> ToBoltEdge(
const query::EdgeAccessor &edge, const storage::Storage &db,
storage::View view) {
storage::Result<communication::bolt::Edge> ToBoltEdge(const query::EdgeAccessor &edge, const storage::Storage &db,
storage::View view) {
return ToBoltEdge(edge.impl_, db, view);
}
storage::Result<Value> ToBoltValue(const query::TypedValue &value,
const storage::Storage &db,
storage::View view) {
storage::Result<Value> ToBoltValue(const query::TypedValue &value, const storage::Storage &db, storage::View view) {
switch (value.type()) {
case query::TypedValue::Type::Null:
return Value();
@ -90,20 +84,17 @@ storage::Result<Value> ToBoltValue(const query::TypedValue &value,
}
return Value(std::move(map));
}
case query::TypedValue::Type::Vertex:
{
case query::TypedValue::Type::Vertex: {
auto maybe_vertex = ToBoltVertex(value.ValueVertex(), db, view);
if (maybe_vertex.HasError()) return maybe_vertex.GetError();
return Value(std::move(*maybe_vertex));
}
case query::TypedValue::Type::Edge:
{
case query::TypedValue::Type::Edge: {
auto maybe_edge = ToBoltEdge(value.ValueEdge(), db, view);
if (maybe_edge.HasError()) return maybe_edge.GetError();
return Value(std::move(*maybe_edge));
}
case query::TypedValue::Type::Path:
{
case query::TypedValue::Type::Path: {
auto maybe_path = ToBoltPath(value.ValuePath(), db, view);
if (maybe_path.HasError()) return maybe_path.GetError();
return Value(std::move(*maybe_path));
@ -111,9 +102,8 @@ storage::Result<Value> ToBoltValue(const query::TypedValue &value,
}
}
storage::Result<communication::bolt::Vertex> ToBoltVertex(
const storage::VertexAccessor &vertex, const storage::Storage &db,
storage::View view) {
storage::Result<communication::bolt::Vertex> ToBoltVertex(const storage::VertexAccessor &vertex,
const storage::Storage &db, storage::View view) {
auto id = communication::bolt::Id::FromUint(vertex.Gid().AsUint());
auto maybe_labels = vertex.Labels(view);
if (maybe_labels.HasError()) return maybe_labels.GetError();
@ -131,12 +121,10 @@ storage::Result<communication::bolt::Vertex> ToBoltVertex(
return communication::bolt::Vertex{id, labels, properties};
}
storage::Result<communication::bolt::Edge> ToBoltEdge(
const storage::EdgeAccessor &edge, const storage::Storage &db,
storage::View view) {
storage::Result<communication::bolt::Edge> ToBoltEdge(const storage::EdgeAccessor &edge, const storage::Storage &db,
storage::View view) {
auto id = communication::bolt::Id::FromUint(edge.Gid().AsUint());
auto from =
communication::bolt::Id::FromUint(edge.FromVertex().Gid().AsUint());
auto from = communication::bolt::Id::FromUint(edge.FromVertex().Gid().AsUint());
auto to = communication::bolt::Id::FromUint(edge.ToVertex().Gid().AsUint());
auto type = db.EdgeTypeToName(edge.EdgeType());
auto maybe_properties = edge.Properties(view);
@ -148,8 +136,8 @@ storage::Result<communication::bolt::Edge> ToBoltEdge(
return communication::bolt::Edge{id, from, to, type, properties};
}
storage::Result<communication::bolt::Path> ToBoltPath(
const query::Path &path, const storage::Storage &db, storage::View view) {
storage::Result<communication::bolt::Path> ToBoltPath(const query::Path &path, const storage::Storage &db,
storage::View view) {
std::vector<communication::bolt::Vertex> vertices;
vertices.reserve(path.vertices().size());
for (const auto &v : path.vertices()) {
@ -182,22 +170,19 @@ storage::PropertyValue ToPropertyValue(const Value &value) {
case Value::Type::List: {
std::vector<storage::PropertyValue> vec;
vec.reserve(value.ValueList().size());
for (const auto &value : value.ValueList())
vec.emplace_back(ToPropertyValue(value));
for (const auto &value : value.ValueList()) vec.emplace_back(ToPropertyValue(value));
return storage::PropertyValue(std::move(vec));
}
case Value::Type::Map: {
std::map<std::string, storage::PropertyValue> map;
for (const auto &kv : value.ValueMap())
map.emplace(kv.first, ToPropertyValue(kv.second));
for (const auto &kv : value.ValueMap()) map.emplace(kv.first, ToPropertyValue(kv.second));
return storage::PropertyValue(std::move(map));
}
case Value::Type::Vertex:
case Value::Type::Edge:
case Value::Type::UnboundedEdge:
case Value::Type::Path:
throw communication::bolt::ValueException(
"Unsupported conversion from Value to PropertyValue");
throw communication::bolt::ValueException("Unsupported conversion from Value to PropertyValue");
}
}

View File

@ -21,35 +21,32 @@ namespace glue {
/// @param storage::View for deciding which vertex attributes are visible.
///
/// @throw std::bad_alloc
storage::Result<communication::bolt::Vertex> ToBoltVertex(
const storage::VertexAccessor &vertex, const storage::Storage &db,
storage::View view);
storage::Result<communication::bolt::Vertex> ToBoltVertex(const storage::VertexAccessor &vertex,
const storage::Storage &db, storage::View view);
/// @param storage::EdgeAccessor for converting to communication::bolt::Edge.
/// @param storage::Storage for getting edge type and property names.
/// @param storage::View for deciding which edge attributes are visible.
///
/// @throw std::bad_alloc
storage::Result<communication::bolt::Edge> ToBoltEdge(
const storage::EdgeAccessor &edge, const storage::Storage &db,
storage::View view);
storage::Result<communication::bolt::Edge> ToBoltEdge(const storage::EdgeAccessor &edge, const storage::Storage &db,
storage::View view);
/// @param query::Path for converting to communication::bolt::Path.
/// @param storage::Storage for ToBoltVertex and ToBoltEdge.
/// @param storage::View for ToBoltVertex and ToBoltEdge.
///
/// @throw std::bad_alloc
storage::Result<communication::bolt::Path> ToBoltPath(
const query::Path &path, const storage::Storage &db, storage::View view);
storage::Result<communication::bolt::Path> ToBoltPath(const query::Path &path, const storage::Storage &db,
storage::View view);
/// @param query::TypedValue for converting to communication::bolt::Value.
/// @param storage::Storage for ToBoltVertex and ToBoltEdge.
/// @param storage::View for ToBoltVertex and ToBoltEdge.
///
/// @throw std::bad_alloc
storage::Result<communication::bolt::Value> ToBoltValue(
const query::TypedValue &value, const storage::Storage &db,
storage::View view);
storage::Result<communication::bolt::Value> ToBoltValue(const query::TypedValue &value, const storage::Storage &db,
storage::View view);
query::TypedValue ToTypedValue(const communication::bolt::Value &value);

View File

@ -17,17 +17,13 @@
inline void LoadConfig(const std::string &product_name) {
namespace fs = std::filesystem;
std::vector<fs::path> configs = {fs::path("/etc/memgraph/memgraph.conf")};
if (getenv("HOME") != nullptr)
configs.emplace_back(fs::path(getenv("HOME")) /
fs::path(".memgraph/config"));
if (getenv("HOME") != nullptr) configs.emplace_back(fs::path(getenv("HOME")) / fs::path(".memgraph/config"));
{
auto memgraph_config = getenv("MEMGRAPH_CONFIG");
if (memgraph_config != nullptr) {
auto path = fs::path(memgraph_config);
MG_ASSERT(
fs::exists(path),
"MEMGRAPH_CONFIG environment variable set to nonexisting path: {}",
path.generic_string());
MG_ASSERT(fs::exists(path), "MEMGRAPH_CONFIG environment variable set to nonexisting path: {}",
path.generic_string());
configs.emplace_back(path);
}
}
@ -35,8 +31,7 @@ inline void LoadConfig(const std::string &product_name) {
std::vector<std::string> flagfile_arguments;
for (const auto &config : configs)
if (fs::exists(config)) {
flagfile_arguments.emplace_back(
std::string("--flag-file=" + config.generic_string()));
flagfile_arguments.emplace_back(std::string("--flag-file=" + config.generic_string()));
}
int custom_argc = static_cast<int>(flagfile_arguments.size()) + 1;

View File

@ -25,10 +25,8 @@ Endpoint::IpFamily Endpoint::GetIpFamily(const std::string &ip_address) {
}
}
std::optional<std::pair<std::string, uint16_t>>
Endpoint::ParseSocketOrIpAddress(
const std::string &address,
const std::optional<uint16_t> default_port = {}) {
std::optional<std::pair<std::string, uint16_t>> Endpoint::ParseSocketOrIpAddress(
const std::string &address, const std::optional<uint16_t> default_port = {}) {
/// expected address format:
/// - "ip_address:port_number"
/// - "ip_address"
@ -80,8 +78,7 @@ std::string Endpoint::SocketAddress() const {
}
Endpoint::Endpoint() {}
Endpoint::Endpoint(std::string ip_address, uint16_t port)
: address(std::move(ip_address)), port(port) {
Endpoint::Endpoint(std::string ip_address, uint16_t port) : address(std::move(ip_address)), port(port) {
IpFamily ip_family = GetIpFamily(address);
if (ip_family == IpFamily::NONE) {
throw NetworkError("Not a valid IPv4 or IPv6 address: {}", ip_address);

View File

@ -21,13 +21,11 @@ class Epoll {
public:
using Event = struct epoll_event;
Epoll(bool set_cloexec = false)
: epoll_fd_(epoll_create1(set_cloexec ? EPOLL_CLOEXEC : 0)) {
Epoll(bool set_cloexec = false) : epoll_fd_(epoll_create1(set_cloexec ? EPOLL_CLOEXEC : 0)) {
// epoll_create1 returns an error if there is a logical error in our code
// (for example invalid flags) or if there is irrecoverable error. In both
// cases it is best to terminate.
MG_ASSERT(epoll_fd_ != -1, "Error on epoll create: ({}) {}", errno,
strerror(errno));
MG_ASSERT(epoll_fd_ != -1, "Error on epoll create: ({}) {}", errno, strerror(errno));
}
/**
@ -42,15 +40,13 @@ class Epoll {
Event event;
event.events = events;
event.data.ptr = ptr;
int status = epoll_ctl(epoll_fd_, (modify ? EPOLL_CTL_MOD : EPOLL_CTL_ADD),
fd, &event);
int status = epoll_ctl(epoll_fd_, (modify ? EPOLL_CTL_MOD : EPOLL_CTL_ADD), fd, &event);
// epoll_ctl can return an error on our logical error or on irrecoverable
// error. There is a third possibility that some system limit is reached. In
// that case we could return an erorr and close connection. Chances of
// reaching system limit in normally working memgraph is extremely unlikely,
// so it is correct to terminate even in that case.
MG_ASSERT(!status, "Error on epoll {}: ({}) {}",
(modify ? "modify" : "add"), errno, strerror(errno));
MG_ASSERT(!status, "Error on epoll {}: ({}) {}", (modify ? "modify" : "add"), errno, strerror(errno));
}
/**
@ -60,9 +56,7 @@ class Epoll {
* @param events epoll events mask
* @param ptr pointer to the associated event handler
*/
void Modify(int fd, uint32_t events, void *ptr) {
Add(fd, events, ptr, true);
}
void Modify(int fd, uint32_t events, void *ptr) { Add(fd, events, ptr, true); }
/**
* This function deletes a file descriptor that is listened for events.
@ -76,8 +70,7 @@ class Epoll {
// that case we could return an erorr and close connection. Chances of
// reaching system limit in normally working memgraph is extremely unlikely,
// so it is correct to terminate even in that case.
MG_ASSERT(!status, "Error on epoll delete: ({}) {}", errno,
strerror(errno));
MG_ASSERT(!status, "Error on epoll delete: ({}) {}", errno, strerror(errno));
}
/**
@ -91,8 +84,7 @@ class Epoll {
int Wait(Event *events, int max_events, int timeout) {
auto num_events = epoll_wait(epoll_fd_, events, max_events, timeout);
// If this check fails there was logical error in our code.
MG_ASSERT(num_events != -1 || errno == EINTR,
"Error on epoll wait: ({}) {}", errno, strerror(errno));
MG_ASSERT(num_events != -1 || errno == EINTR, "Error on epoll wait: ({}) {}", errno, strerror(errno));
// num_events can be -1 if errno was EINTR (epoll_wait interrupted by signal
// handler). We treat that as no events, so we return 0.
return num_events == -1 ? 0 : num_events;

View File

@ -8,4 +8,4 @@ class NetworkError : public utils::StacktraceException {
public:
using utils::StacktraceException::StacktraceException;
};
}
} // namespace io::network

View File

@ -59,8 +59,7 @@ bool Socket::IsOpen() const { return socket_ != -1; }
bool Socket::Connect(const Endpoint &endpoint) {
if (socket_ != -1) return false;
auto info = AddrInfo::Get(endpoint.address.c_str(),
std::to_string(endpoint.port).c_str());
auto info = AddrInfo::Get(endpoint.address.c_str(), std::to_string(endpoint.port).c_str());
for (struct addrinfo *it = info; it != nullptr; it = it->ai_next) {
int sfd = socket(it->ai_family, it->ai_socktype, it->ai_protocol);
@ -83,8 +82,7 @@ bool Socket::Connect(const Endpoint &endpoint) {
bool Socket::Bind(const Endpoint &endpoint) {
if (socket_ != -1) return false;
auto info = AddrInfo::Get(endpoint.address.c_str(),
std::to_string(endpoint.port).c_str());
auto info = AddrInfo::Get(endpoint.address.c_str(), std::to_string(endpoint.port).c_str());
for (struct addrinfo *it = info; it != nullptr; it = it->ai_next) {
int sfd = socket(it->ai_family, it->ai_socktype, it->ai_protocol);
@ -130,38 +128,30 @@ void Socket::SetNonBlocking() {
int flags = fcntl(socket_, F_GETFL, 0);
MG_ASSERT(flags != -1, "Can't get socket mode");
flags |= O_NONBLOCK;
MG_ASSERT(fcntl(socket_, F_SETFL, flags) != -1,
"Can't set socket nonblocking");
MG_ASSERT(fcntl(socket_, F_SETFL, flags) != -1, "Can't set socket nonblocking");
}
void Socket::SetKeepAlive() {
int optval = 1;
socklen_t optlen = sizeof(optval);
MG_ASSERT(!setsockopt(socket_, SOL_SOCKET, SO_KEEPALIVE, &optval, optlen),
"Can't set socket keep alive");
MG_ASSERT(!setsockopt(socket_, SOL_SOCKET, SO_KEEPALIVE, &optval, optlen), "Can't set socket keep alive");
optval = 20; // wait 20s before sending keep-alive packets
MG_ASSERT(
!setsockopt(socket_, SOL_TCP, TCP_KEEPIDLE, (void *)&optval, optlen),
"Can't set socket keep alive");
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPIDLE, (void *)&optval, optlen), "Can't set socket keep alive");
optval = 4; // 4 keep-alive packets must fail to close
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPCNT, (void *)&optval, optlen),
"Can't set socket keep alive");
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPCNT, (void *)&optval, optlen), "Can't set socket keep alive");
optval = 15; // send keep-alive packets every 15s
MG_ASSERT(
!setsockopt(socket_, SOL_TCP, TCP_KEEPINTVL, (void *)&optval, optlen),
"Can't set socket keep alive");
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPINTVL, (void *)&optval, optlen), "Can't set socket keep alive");
}
void Socket::SetNoDelay() {
int optval = 1;
socklen_t optlen = sizeof(optval);
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_NODELAY, (void *)&optval, optlen),
"Can't set socket no delay");
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_NODELAY, (void *)&optval, optlen), "Can't set socket no delay");
}
void Socket::SetTimeout(long sec, long usec) {
@ -169,11 +159,9 @@ void Socket::SetTimeout(long sec, long usec) {
tv.tv_sec = sec;
tv.tv_usec = usec;
MG_ASSERT(!setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
"Can't set socket timeout");
MG_ASSERT(!setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), "Can't set socket timeout");
MG_ASSERT(!setsockopt(socket_, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
"Can't set socket timeout");
MG_ASSERT(!setsockopt(socket_, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), "Can't set socket timeout");
}
int Socket::ErrorStatus() const {
@ -238,8 +226,7 @@ bool Socket::Write(const uint8_t *data, size_t len, bool have_more) {
}
bool Socket::Write(const std::string &s, bool have_more) {
return Write(reinterpret_cast<const uint8_t *>(s.data()), s.size(),
have_more);
return Write(reinterpret_cast<const uint8_t *>(s.data()), s.size(), have_more);
}
ssize_t Socket::Read(void *buffer, size_t len, bool nonblock) {

View File

@ -14,4 +14,4 @@ struct StreamBuffer {
uint8_t *data;
size_t len;
};
}
} // namespace io::network

View File

@ -23,9 +23,8 @@ std::string ResolveHostname(std::string hostname) {
int addr_result;
addrinfo *servinfo;
MG_ASSERT((addr_result =
getaddrinfo(hostname.c_str(), NULL, &hints, &servinfo)) == 0,
"Error with getaddrinfo: {}", gai_strerror(addr_result));
MG_ASSERT((addr_result = getaddrinfo(hostname.c_str(), NULL, &hints, &servinfo)) == 0, "Error with getaddrinfo: {}",
gai_strerror(addr_result));
MG_ASSERT(servinfo, "Could not resolve address: {}", hostname);
std::string address;

View File

@ -12,18 +12,16 @@ struct KVStore::impl {
rocksdb::Options options;
};
KVStore::KVStore(std::filesystem::path storage)
: pimpl_(std::make_unique<impl>()) {
KVStore::KVStore(std::filesystem::path storage) : pimpl_(std::make_unique<impl>()) {
pimpl_->storage = storage;
if (!utils::EnsureDir(pimpl_->storage))
throw KVStoreError("Folder for the key-value store " +
pimpl_->storage.string() + " couldn't be initialized!");
throw KVStoreError("Folder for the key-value store " + pimpl_->storage.string() + " couldn't be initialized!");
pimpl_->options.create_if_missing = true;
rocksdb::DB *db = nullptr;
auto s = rocksdb::DB::Open(pimpl_->options, storage.c_str(), &db);
if (!s.ok())
throw KVStoreError("RocksDB couldn't be initialized inside " +
storage.string() + " -- " + std::string(s.ToString()));
throw KVStoreError("RocksDB couldn't be initialized inside " + storage.string() + " -- " +
std::string(s.ToString()));
pimpl_->db.reset(db);
}
@ -72,19 +70,16 @@ bool KVStore::DeleteMultiple(const std::vector<std::string> &keys) {
}
bool KVStore::DeletePrefix(const std::string &prefix) {
std::unique_ptr<rocksdb::Iterator> iter = std::unique_ptr<rocksdb::Iterator>(
pimpl_->db->NewIterator(rocksdb::ReadOptions()));
for (iter->Seek(prefix); iter->Valid() && iter->key().starts_with(prefix);
iter->Next()) {
if (!pimpl_->db->Delete(rocksdb::WriteOptions(), iter->key()).ok())
return false;
std::unique_ptr<rocksdb::Iterator> iter =
std::unique_ptr<rocksdb::Iterator>(pimpl_->db->NewIterator(rocksdb::ReadOptions()));
for (iter->Seek(prefix); iter->Valid() && iter->key().starts_with(prefix); iter->Next()) {
if (!pimpl_->db->Delete(rocksdb::WriteOptions(), iter->key()).ok()) return false;
}
return true;
}
bool KVStore::PutAndDeleteMultiple(
const std::map<std::string, std::string> &items,
const std::vector<std::string> &keys) {
bool KVStore::PutAndDeleteMultiple(const std::map<std::string, std::string> &items,
const std::vector<std::string> &keys) {
rocksdb::WriteBatch batch;
for (const auto &item : items) {
batch.Put(item.first, item.second);
@ -105,22 +100,16 @@ struct KVStore::iterator::impl {
std::pair<std::string, std::string> disk_prop;
};
KVStore::iterator::iterator(const KVStore *kvstore, const std::string &prefix,
bool at_end)
KVStore::iterator::iterator(const KVStore *kvstore, const std::string &prefix, bool at_end)
: pimpl_(std::make_unique<impl>()) {
pimpl_->kvstore = kvstore;
pimpl_->prefix = prefix;
pimpl_->it = std::unique_ptr<rocksdb::Iterator>(
pimpl_->kvstore->pimpl_->db->NewIterator(rocksdb::ReadOptions()));
pimpl_->it = std::unique_ptr<rocksdb::Iterator>(pimpl_->kvstore->pimpl_->db->NewIterator(rocksdb::ReadOptions()));
pimpl_->it->Seek(pimpl_->prefix);
if (!pimpl_->it->Valid() || !pimpl_->it->key().starts_with(pimpl_->prefix) ||
at_end)
pimpl_->it = nullptr;
if (!pimpl_->it->Valid() || !pimpl_->it->key().starts_with(pimpl_->prefix) || at_end) pimpl_->it = nullptr;
}
KVStore::iterator::iterator(KVStore::iterator &&other) {
pimpl_ = std::move(other.pimpl_);
}
KVStore::iterator::iterator(KVStore::iterator &&other) { pimpl_ = std::move(other.pimpl_); }
KVStore::iterator::~iterator() {}
@ -131,24 +120,19 @@ KVStore::iterator &KVStore::iterator::operator=(KVStore::iterator &&other) {
KVStore::iterator &KVStore::iterator::operator++() {
pimpl_->it->Next();
if (!pimpl_->it->Valid() || !pimpl_->it->key().starts_with(pimpl_->prefix))
pimpl_->it = nullptr;
if (!pimpl_->it->Valid() || !pimpl_->it->key().starts_with(pimpl_->prefix)) pimpl_->it = nullptr;
return *this;
}
bool KVStore::iterator::operator==(const iterator &other) const {
return pimpl_->kvstore == other.pimpl_->kvstore &&
pimpl_->prefix == other.pimpl_->prefix &&
return pimpl_->kvstore == other.pimpl_->kvstore && pimpl_->prefix == other.pimpl_->prefix &&
pimpl_->it == other.pimpl_->it;
}
bool KVStore::iterator::operator!=(const iterator &other) const {
return !(*this == other);
}
bool KVStore::iterator::operator!=(const iterator &other) const { return !(*this == other); }
KVStore::iterator::reference KVStore::iterator::operator*() {
pimpl_->disk_prop = {pimpl_->it->key().ToString(),
pimpl_->it->value().ToString()};
pimpl_->disk_prop = {pimpl_->it->key().ToString(), pimpl_->it->value().ToString()};
return pimpl_->disk_prop;
}
@ -166,8 +150,7 @@ size_t KVStore::Size(const std::string &prefix) {
return size;
}
bool KVStore::CompactRange(const std::string &begin_prefix,
const std::string &end_prefix) {
bool KVStore::CompactRange(const std::string &begin_prefix, const std::string &end_prefix) {
rocksdb::CompactRangeOptions options;
rocksdb::Slice begin(begin_prefix);
rocksdb::Slice end(end_prefix);

View File

@ -114,8 +114,7 @@ class KVStore final {
* @return true if the items have been successfully stored and deleted.
* In case of any error false is going to be returned.
*/
bool PutAndDeleteMultiple(const std::map<std::string, std::string> &items,
const std::vector<std::string> &keys);
bool PutAndDeleteMultiple(const std::map<std::string, std::string> &items, const std::vector<std::string> &keys);
/**
* Returns total number of stored (key, value) pairs. The function takes an
@ -140,8 +139,7 @@ class KVStore final {
*
* @return - true if the compaction finished successfully, false otherwise.
*/
bool CompactRange(const std::string &begin_prefix,
const std::string &end_prefix);
bool CompactRange(const std::string &begin_prefix, const std::string &end_prefix);
/**
* Custom prefix-based iterator over kvstore.
@ -150,17 +148,14 @@ class KVStore final {
* and behaves as if all of those pairs are stored in a single iterable
* collection of std::pair<std::string, std::string>.
*/
class iterator final
: public std::iterator<
std::input_iterator_tag, // iterator_category
std::pair<std::string, std::string>, // value_type
long, // difference_type
const std::pair<std::string, std::string> *, // pointer
const std::pair<std::string, std::string> & // reference
> {
class iterator final : public std::iterator<std::input_iterator_tag, // iterator_category
std::pair<std::string, std::string>, // value_type
long, // difference_type
const std::pair<std::string, std::string> *, // pointer
const std::pair<std::string, std::string> & // reference
> {
public:
explicit iterator(const KVStore *kvstore, const std::string &prefix = "",
bool at_end = false);
explicit iterator(const KVStore *kvstore, const std::string &prefix = "", bool at_end = false);
iterator(const iterator &other) = delete;
@ -191,13 +186,9 @@ class KVStore final {
std::unique_ptr<impl> pimpl_;
};
iterator begin(const std::string &prefix = "") {
return iterator(this, prefix);
}
iterator begin(const std::string &prefix = "") { return iterator(this, prefix); }
iterator end(const std::string &prefix = "") {
return iterator(this, prefix, true);
}
iterator end(const std::string &prefix = "") { return iterator(this, prefix, true); }
private:
struct impl;

View File

@ -26,8 +26,7 @@ std::optional<std::string> KVStore::Get(const std::string &key) const noexcept {
}
bool KVStore::Delete(const std::string &key) {
LOG_FATAL(
"Unsupported operation (KVStore::Delete) -- this is a dummy kvstore");
LOG_FATAL("Unsupported operation (KVStore::Delete) -- this is a dummy kvstore");
}
bool KVStore::DeleteMultiple(const std::vector<std::string> &keys) {
@ -42,9 +41,8 @@ bool KVStore::DeletePrefix(const std::string &prefix) {
"dummy kvstore");
}
bool KVStore::PutAndDeleteMultiple(
const std::map<std::string, std::string> &items,
const std::vector<std::string> &keys) {
bool KVStore::PutAndDeleteMultiple(const std::map<std::string, std::string> &items,
const std::vector<std::string> &keys) {
LOG_FATAL(
"Unsupported operation (KVStore::PutAndDeleteMultiple) -- this is a "
"dummy kvstore");
@ -54,13 +52,9 @@ bool KVStore::PutAndDeleteMultiple(
struct KVStore::iterator::impl {};
KVStore::iterator::iterator(const KVStore *kvstore, const std::string &prefix,
bool at_end)
: pimpl_(new impl()) {}
KVStore::iterator::iterator(const KVStore *kvstore, const std::string &prefix, bool at_end) : pimpl_(new impl()) {}
KVStore::iterator::iterator(KVStore::iterator &&other) {
pimpl_ = std::move(other.pimpl_);
}
KVStore::iterator::iterator(KVStore::iterator &&other) { pimpl_ = std::move(other.pimpl_); }
KVStore::iterator::~iterator() {}
@ -77,9 +71,7 @@ KVStore::iterator &KVStore::iterator::operator++() {
bool KVStore::iterator::operator==(const iterator &other) const { return true; }
bool KVStore::iterator::operator!=(const iterator &other) const {
return false;
}
bool KVStore::iterator::operator!=(const iterator &other) const { return false; }
KVStore::iterator::reference KVStore::iterator::operator*() {
LOG_FATAL(
@ -99,8 +91,7 @@ bool KVStore::iterator::IsValid() { return false; }
size_t KVStore::Size(const std::string &prefix) { return 0; }
bool KVStore::CompactRange(const std::string &begin_prefix,
const std::string &end_prefix) {
bool KVStore::CompactRange(const std::string &begin_prefix, const std::string &end_prefix) {
LOG_FATAL(
"Unsupported operation (KVStore::Compact) -- this is a "
"dummy kvstore");

View File

@ -63,25 +63,19 @@
#endif
// Bolt server flags.
DEFINE_string(bolt_address, "0.0.0.0",
"IP address on which the Bolt server should listen.");
DEFINE_VALIDATED_int32(bolt_port, 7687,
"Port on which the Bolt server should listen.",
DEFINE_string(bolt_address, "0.0.0.0", "IP address on which the Bolt server should listen.");
DEFINE_VALIDATED_int32(bolt_port, 7687, "Port on which the Bolt server should listen.",
FLAG_IN_RANGE(0, std::numeric_limits<uint16_t>::max()));
DEFINE_VALIDATED_int32(
bolt_num_workers, std::max(std::thread::hardware_concurrency(), 1U),
"Number of workers used by the Bolt server. By default, this will be the "
"number of processing units available on the machine.",
FLAG_IN_RANGE(1, INT32_MAX));
DEFINE_VALIDATED_int32(
bolt_session_inactivity_timeout, 1800,
"Time in seconds after which inactive Bolt sessions will be "
"closed.",
FLAG_IN_RANGE(1, INT32_MAX));
DEFINE_string(bolt_cert_file, "",
"Certificate file which should be used for the Bolt server.");
DEFINE_string(bolt_key_file, "",
"Key file which should be used for the Bolt server.");
DEFINE_VALIDATED_int32(bolt_num_workers, std::max(std::thread::hardware_concurrency(), 1U),
"Number of workers used by the Bolt server. By default, this will be the "
"number of processing units available on the machine.",
FLAG_IN_RANGE(1, INT32_MAX));
DEFINE_VALIDATED_int32(bolt_session_inactivity_timeout, 1800,
"Time in seconds after which inactive Bolt sessions will be "
"closed.",
FLAG_IN_RANGE(1, INT32_MAX));
DEFINE_string(bolt_cert_file, "", "Certificate file which should be used for the Bolt server.");
DEFINE_string(bolt_key_file, "", "Key file which should be used for the Bolt server.");
DEFINE_string(bolt_server_name_for_init, "",
"Server name which the database should send to the client in the "
"Bolt INIT message.");
@ -89,26 +83,20 @@ DEFINE_string(bolt_server_name_for_init, "",
// General purpose flags.
// NOTE: The `data_directory` flag must be the same here and in
// `mg_import_csv`. If you change it, make sure to change it there as well.
DEFINE_string(data_directory, "mg_data",
"Path to directory in which to save all permanent data.");
DEFINE_HIDDEN_string(
log_link_basename, "",
"Basename used for symlink creation to the last log file.");
DEFINE_string(data_directory, "mg_data", "Path to directory in which to save all permanent data.");
DEFINE_HIDDEN_string(log_link_basename, "", "Basename used for symlink creation to the last log file.");
DEFINE_uint64(memory_warning_threshold, 1024,
"Memory warning threshold, in MB. If Memgraph detects there is "
"less available RAM it will log a warning. Set to 0 to "
"disable.");
// Storage flags.
DEFINE_VALIDATED_uint64(storage_gc_cycle_sec, 30,
"Storage garbage collector interval (in seconds).",
DEFINE_VALIDATED_uint64(storage_gc_cycle_sec, 30, "Storage garbage collector interval (in seconds).",
FLAG_IN_RANGE(1, 24 * 3600));
// NOTE: The `storage_properties_on_edges` flag must be the same here and in
// `mg_import_csv`. If you change it, make sure to change it there as well.
DEFINE_bool(storage_properties_on_edges, false,
"Controls whether edges have properties.");
DEFINE_bool(storage_recover_on_startup, false,
"Controls whether the storage recovers persisted data on startup.");
DEFINE_bool(storage_properties_on_edges, false, "Controls whether edges have properties.");
DEFINE_bool(storage_recover_on_startup, false, "Controls whether the storage recovers persisted data on startup.");
DEFINE_VALIDATED_uint64(storage_snapshot_interval_sec, 0,
"Storage snapshot creation interval (in seconds). Set "
"to 0 to disable periodic snapshot creation.",
@ -116,21 +104,15 @@ DEFINE_VALIDATED_uint64(storage_snapshot_interval_sec, 0,
DEFINE_bool(storage_wal_enabled, false,
"Controls whether the storage uses write-ahead-logging. To enable "
"WAL periodic snapshots must be enabled.");
DEFINE_VALIDATED_uint64(storage_snapshot_retention_count, 3,
"The number of snapshots that should always be kept.",
DEFINE_VALIDATED_uint64(storage_snapshot_retention_count, 3, "The number of snapshots that should always be kept.",
FLAG_IN_RANGE(1, 1000000));
DEFINE_VALIDATED_uint64(storage_wal_file_size_kib,
storage::Config::Durability().wal_file_size_kibibytes,
"Minimum file size of each WAL file.",
FLAG_IN_RANGE(1, 1000 * 1024));
DEFINE_VALIDATED_uint64(
storage_wal_file_flush_every_n_tx,
storage::Config::Durability().wal_file_flush_every_n_tx,
"Issue a 'fsync' call after this amount of transactions are written to the "
"WAL file. Set to 1 for fully synchronous operation.",
FLAG_IN_RANGE(1, 1000000));
DEFINE_bool(storage_snapshot_on_exit, false,
"Controls whether the storage creates another snapshot on exit.");
DEFINE_VALIDATED_uint64(storage_wal_file_size_kib, storage::Config::Durability().wal_file_size_kibibytes,
"Minimum file size of each WAL file.", FLAG_IN_RANGE(1, 1000 * 1024));
DEFINE_VALIDATED_uint64(storage_wal_file_flush_every_n_tx, storage::Config::Durability().wal_file_flush_every_n_tx,
"Issue a 'fsync' call after this amount of transactions are written to the "
"WAL file. Set to 1 for fully synchronous operation.",
FLAG_IN_RANGE(1, 1000000));
DEFINE_bool(storage_snapshot_on_exit, false, "Controls whether the storage creates another snapshot on exit.");
DEFINE_bool(telemetry_enabled, false,
"Set to true to enable telemetry. We collect information about the "
@ -141,13 +123,11 @@ DEFINE_bool(telemetry_enabled, false,
// Audit logging flags.
#ifdef MG_ENTERPRISE
DEFINE_bool(audit_enabled, false, "Set to true to enable audit logging.");
DEFINE_VALIDATED_int32(audit_buffer_size, audit::kBufferSizeDefault,
"Maximum number of items in the audit log buffer.",
DEFINE_VALIDATED_int32(audit_buffer_size, audit::kBufferSizeDefault, "Maximum number of items in the audit log buffer.",
FLAG_IN_RANGE(1, INT32_MAX));
DEFINE_VALIDATED_int32(
audit_buffer_flush_interval_ms, audit::kBufferFlushIntervalMillisDefault,
"Interval (in milliseconds) used for flushing the audit log buffer.",
FLAG_IN_RANGE(10, INT32_MAX));
DEFINE_VALIDATED_int32(audit_buffer_flush_interval_ms, audit::kBufferFlushIntervalMillisDefault,
"Interval (in milliseconds) used for flushing the audit log buffer.",
FLAG_IN_RANGE(10, INT32_MAX));
#endif
// Query flags.
@ -155,41 +135,34 @@ DEFINE_uint64(query_execution_timeout_sec, 180,
"Maximum allowed query execution time. Queries exceeding this "
"limit will be aborted. Value of 0 means no limit.");
DEFINE_VALIDATED_string(
query_modules_directory, "",
"Directory where modules with custom query procedures are stored.", {
if (value.empty()) return true;
if (utils::DirExists(value)) return true;
std::cout << "Expected --" << flagname << " to point to a directory."
<< std::endl;
return false;
});
DEFINE_VALIDATED_string(query_modules_directory, "", "Directory where modules with custom query procedures are stored.",
{
if (value.empty()) return true;
if (utils::DirExists(value)) return true;
std::cout << "Expected --" << flagname << " to point to a directory." << std::endl;
return false;
});
// Logging flags
DEFINE_bool(also_log_to_stderr, false,
"Log messages go to stderr in addition to logfiles");
DEFINE_bool(also_log_to_stderr, false, "Log messages go to stderr in addition to logfiles");
DEFINE_string(log_file, "", "Path to where the log should be stored.");
namespace {
constexpr std::array log_level_mappings{
std::pair{"TRACE", spdlog::level::trace},
std::pair{"DEBUG", spdlog::level::debug},
std::pair{"INFO", spdlog::level::info},
std::pair{"WARNING", spdlog::level::warn},
std::pair{"ERROR", spdlog::level::err},
std::pair{"CRITICAL", spdlog::level::critical}};
std::pair{"TRACE", spdlog::level::trace}, std::pair{"DEBUG", spdlog::level::debug},
std::pair{"INFO", spdlog::level::info}, std::pair{"WARNING", spdlog::level::warn},
std::pair{"ERROR", spdlog::level::err}, std::pair{"CRITICAL", spdlog::level::critical}};
std::string GetAllowedLogLevelsString() {
std::vector<std::string> allowed_log_levels;
allowed_log_levels.reserve(log_level_mappings.size());
std::transform(log_level_mappings.cbegin(), log_level_mappings.cend(),
std::back_inserter(allowed_log_levels),
std::transform(log_level_mappings.cbegin(), log_level_mappings.cend(), std::back_inserter(allowed_log_levels),
[](const auto &mapping) { return mapping.first; });
return utils::Join(allowed_log_levels, ", ");
}
const std::string log_level_help_string = fmt::format(
"Minimum log level. Allowed values: {}", GetAllowedLogLevelsString());
const std::string log_level_help_string =
fmt::format("Minimum log level. Allowed values: {}", GetAllowedLogLevelsString());
} // namespace
DEFINE_VALIDATED_string(log_level, "WARNING", log_level_help_string.c_str(), {
@ -199,11 +172,8 @@ DEFINE_VALIDATED_string(log_level, "WARNING", log_level_help_string.c_str(), {
}
if (std::find_if(log_level_mappings.cbegin(), log_level_mappings.cend(),
[&](const auto &mapping) {
return mapping.first == value;
}) == log_level_mappings.cend()) {
std::cout << "Invalid value for log level. Allowed values: "
<< GetAllowedLogLevelsString() << std::endl;
[&](const auto &mapping) { return mapping.first == value; }) == log_level_mappings.cend()) {
std::cout << "Invalid value for log level. Allowed values: " << GetAllowedLogLevelsString() << std::endl;
return false;
}
@ -212,9 +182,8 @@ DEFINE_VALIDATED_string(log_level, "WARNING", log_level_help_string.c_str(), {
namespace {
void ParseLogLevel() {
const auto mapping_iter = std::find_if(
log_level_mappings.cbegin(), log_level_mappings.cend(),
[](const auto &mapping) { return mapping.first == FLAGS_log_level; });
const auto mapping_iter = std::find_if(log_level_mappings.cbegin(), log_level_mappings.cend(),
[](const auto &mapping) { return mapping.first == FLAGS_log_level; });
MG_ASSERT(mapping_iter != log_level_mappings.cend(), "Invalid log level");
spdlog::set_level(mapping_iter->second);
@ -227,8 +196,7 @@ void ConfigureLogging() {
std::vector<spdlog::sink_ptr> loggers;
if (FLAGS_also_log_to_stderr) {
loggers.emplace_back(
std::make_shared<spdlog::sinks::stderr_color_sink_mt>());
loggers.emplace_back(std::make_shared<spdlog::sinks::stderr_color_sink_mt>());
}
if (!FLAGS_log_file.empty()) {
@ -240,12 +208,10 @@ void ConfigureLogging() {
local_time = localtime(&current_time);
loggers.emplace_back(std::make_shared<spdlog::sinks::daily_file_sink_mt>(
FLAGS_log_file, local_time->tm_hour, local_time->tm_min, false,
log_retention_count));
FLAGS_log_file, local_time->tm_hour, local_time->tm_min, false, log_retention_count));
}
spdlog::set_default_logger(std::make_shared<spdlog::logger>(
"memgraph_log", loggers.begin(), loggers.end()));
spdlog::set_default_logger(std::make_shared<spdlog::logger>("memgraph_log", loggers.begin(), loggers.end()));
spdlog::flush_on(spdlog::level::trace);
ParseLogLevel();
@ -258,13 +224,9 @@ void ConfigureLogging() {
struct SessionData {
// Explicit constructor here to ensure that pointers to all objects are
// supplied.
SessionData(storage::Storage *db,
query::InterpreterContext *interpreter_context, auth::Auth *auth,
SessionData(storage::Storage *db, query::InterpreterContext *interpreter_context, auth::Auth *auth,
audit::Log *audit_log)
: db(db),
interpreter_context(interpreter_context),
auth(auth),
audit_log(audit_log) {}
: db(db), interpreter_context(interpreter_context), auth(auth), audit_log(audit_log) {}
storage::Storage *db;
query::InterpreterContext *interpreter_context;
auth::Auth *auth;
@ -274,24 +236,19 @@ struct SessionData {
struct SessionData {
// Explicit constructor here to ensure that pointers to all objects are
// supplied.
SessionData(storage::Storage *db,
query::InterpreterContext *interpreter_context)
SessionData(storage::Storage *db, query::InterpreterContext *interpreter_context)
: db(db), interpreter_context(interpreter_context) {}
storage::Storage *db;
query::InterpreterContext *interpreter_context;
};
#endif
class BoltSession final
: public communication::bolt::Session<communication::InputStream,
communication::OutputStream> {
class BoltSession final : public communication::bolt::Session<communication::InputStream, communication::OutputStream> {
public:
BoltSession(SessionData *data, const io::network::Endpoint &endpoint,
communication::InputStream *input_stream,
BoltSession(SessionData *data, const io::network::Endpoint &endpoint, communication::InputStream *input_stream,
communication::OutputStream *output_stream)
: communication::bolt::Session<communication::InputStream,
communication::OutputStream>(
input_stream, output_stream),
: communication::bolt::Session<communication::InputStream, communication::OutputStream>(input_stream,
output_stream),
db_(data->db),
interpreter_(data->interpreter_context),
#ifdef MG_ENTERPRISE
@ -301,8 +258,7 @@ class BoltSession final
endpoint_(endpoint) {
}
using communication::bolt::Session<communication::InputStream,
communication::OutputStream>::TEncoder;
using communication::bolt::Session<communication::InputStream, communication::OutputStream>::TEncoder;
void BeginTransaction() override { interpreter_.BeginTransaction(); }
@ -311,15 +267,11 @@ class BoltSession final
void RollbackTransaction() override { interpreter_.RollbackTransaction(); }
std::pair<std::vector<std::string>, std::optional<int>> Interpret(
const std::string &query,
const std::map<std::string, communication::bolt::Value> &params)
override {
const std::string &query, const std::map<std::string, communication::bolt::Value> &params) override {
std::map<std::string, storage::PropertyValue> params_pv;
for (const auto &kv : params)
params_pv.emplace(kv.first, glue::ToPropertyValue(kv.second));
for (const auto &kv : params) params_pv.emplace(kv.first, glue::ToPropertyValue(kv.second));
#ifdef MG_ENTERPRISE
audit_log_->Record(endpoint_.address, user_ ? user_->username() : "", query,
storage::PropertyValue(params_pv));
audit_log_->Record(endpoint_.address, user_ ? user_->username() : "", query, storage::PropertyValue(params_pv));
#endif
try {
auto result = interpreter_.Prepare(query, params_pv);
@ -327,8 +279,7 @@ class BoltSession final
if (user_) {
const auto &permissions = user_->GetPermissions();
for (const auto &privilege : result.privileges) {
if (permissions.Has(glue::PrivilegeToPermission(privilege)) !=
auth::PermissionLevel::GRANT) {
if (permissions.Has(glue::PrivilegeToPermission(privilege)) != auth::PermissionLevel::GRANT) {
interpreter_.Abort();
throw communication::bolt::ClientError(
"You are not authorized to execute this query! Please contact "
@ -346,23 +297,20 @@ class BoltSession final
}
}
std::map<std::string, communication::bolt::Value> Pull(
TEncoder *encoder, std::optional<int> n,
std::optional<int> qid) override {
std::map<std::string, communication::bolt::Value> Pull(TEncoder *encoder, std::optional<int> n,
std::optional<int> qid) override {
TypedValueResultStream stream(encoder, db_);
return PullResults(stream, n, qid);
}
std::map<std::string, communication::bolt::Value> Discard(
std::optional<int> n, std::optional<int> qid) override {
std::map<std::string, communication::bolt::Value> Discard(std::optional<int> n, std::optional<int> qid) override {
DiscardValueResultStream stream;
return PullResults(stream, n, qid);
}
void Abort() override { interpreter_.Abort(); }
bool Authenticate(const std::string &username,
const std::string &password) override {
bool Authenticate(const std::string &username, const std::string &password) override {
#ifdef MG_ENTERPRISE
if (!auth_->HasUsers()) return true;
user_ = auth_->Authenticate(username, password);
@ -379,14 +327,13 @@ class BoltSession final
private:
template <typename TStream>
std::map<std::string, communication::bolt::Value> PullResults(
TStream &stream, std::optional<int> n, std::optional<int> qid) {
std::map<std::string, communication::bolt::Value> PullResults(TStream &stream, std::optional<int> n,
std::optional<int> qid) {
try {
const auto &summary = interpreter_.Pull(&stream, n, qid);
std::map<std::string, communication::bolt::Value> decoded_summary;
for (const auto &kv : summary) {
auto maybe_value =
glue::ToBoltValue(kv.second, *db_, storage::View::NEW);
auto maybe_value = glue::ToBoltValue(kv.second, *db_, storage::View::NEW);
if (maybe_value.HasError()) {
switch (maybe_value.GetError()) {
case storage::Error::DELETED_OBJECT:
@ -394,8 +341,7 @@ class BoltSession final
case storage::Error::VERTEX_HAS_EDGES:
case storage::Error::PROPERTIES_DISABLED:
case storage::Error::NONEXISTENT_OBJECT:
throw communication::bolt::ClientError(
"Unexpected storage error when streaming summary.");
throw communication::bolt::ClientError("Unexpected storage error when streaming summary.");
}
}
decoded_summary.emplace(kv.first, std::move(*maybe_value));
@ -412,8 +358,7 @@ class BoltSession final
/// before forwarding the calls to original TEncoder.
class TypedValueResultStream {
public:
TypedValueResultStream(TEncoder *encoder, const storage::Storage *db)
: encoder_(encoder), db_(db) {}
TypedValueResultStream(TEncoder *encoder, const storage::Storage *db) : encoder_(encoder), db_(db) {}
void Result(const std::vector<query::TypedValue> &values) {
std::vector<communication::bolt::Value> decoded_values;
@ -423,16 +368,13 @@ class BoltSession final
if (maybe_value.HasError()) {
switch (maybe_value.GetError()) {
case storage::Error::DELETED_OBJECT:
throw communication::bolt::ClientError(
"Returning a deleted object as a result.");
throw communication::bolt::ClientError("Returning a deleted object as a result.");
case storage::Error::NONEXISTENT_OBJECT:
throw communication::bolt::ClientError(
"Returning a nonexistent object as a result.");
throw communication::bolt::ClientError("Returning a nonexistent object as a result.");
case storage::Error::VERTEX_HAS_EDGES:
case storage::Error::SERIALIZATION_ERROR:
case storage::Error::PROPERTIES_DISABLED:
throw communication::bolt::ClientError(
"Unexpected storage error when streaming results.");
throw communication::bolt::ClientError("Unexpected storage error when streaming results.");
}
}
decoded_values.emplace_back(std::move(*maybe_value));
@ -467,20 +409,17 @@ using ServerT = communication::Server<BoltSession, SessionData>;
using communication::ServerContext;
#ifdef MG_ENTERPRISE
DEFINE_string(
auth_user_or_role_name_regex, "[a-zA-Z0-9_.+-@]+",
"Set to the regular expression that each user or role name must fulfill.");
DEFINE_string(auth_user_or_role_name_regex, "[a-zA-Z0-9_.+-@]+",
"Set to the regular expression that each user or role name must fulfill.");
class AuthQueryHandler final : public query::AuthQueryHandler {
auth::Auth *auth_;
std::regex name_regex_;
public:
AuthQueryHandler(auth::Auth *auth, const std::regex &name_regex)
: auth_(auth), name_regex_(name_regex) {}
AuthQueryHandler(auth::Auth *auth, const std::regex &name_regex) : auth_(auth), name_regex_(name_regex) {}
bool CreateUser(const std::string &username,
const std::optional<std::string> &password) override {
bool CreateUser(const std::string &username, const std::optional<std::string> &password) override {
if (!std::regex_match(username, name_regex_)) {
throw query::QueryRuntimeException("Invalid user name.");
}
@ -506,8 +445,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
}
}
void SetPassword(const std::string &username,
const std::optional<std::string> &password) override {
void SetPassword(const std::string &username, const std::optional<std::string> &password) override {
if (!std::regex_match(username, name_regex_)) {
throw query::QueryRuntimeException("Invalid user name.");
}
@ -515,8 +453,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
std::lock_guard<std::mutex> lock(auth_->WithLock());
auto user = auth_->GetUser(username);
if (!user) {
throw query::QueryRuntimeException("User '{}' doesn't exist.",
username);
throw query::QueryRuntimeException("User '{}' doesn't exist.", username);
}
user->UpdatePassword(password);
auth_->SaveUser(*user);
@ -581,8 +518,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
}
}
std::optional<std::string> GetRolenameForUser(
const std::string &username) override {
std::optional<std::string> GetRolenameForUser(const std::string &username) override {
if (!std::regex_match(username, name_regex_)) {
throw query::QueryRuntimeException("Invalid user name.");
}
@ -590,8 +526,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
std::lock_guard<std::mutex> lock(auth_->WithLock());
auto user = auth_->GetUser(username);
if (!user) {
throw query::QueryRuntimeException("User '{}' doesn't exist .",
username);
throw query::QueryRuntimeException("User '{}' doesn't exist .", username);
}
if (user->role()) return user->role()->rolename();
return std::nullopt;
@ -600,8 +535,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
}
}
std::vector<query::TypedValue> GetUsernamesForRole(
const std::string &rolename) override {
std::vector<query::TypedValue> GetUsernamesForRole(const std::string &rolename) override {
if (!std::regex_match(rolename, name_regex_)) {
throw query::QueryRuntimeException("Invalid role name.");
}
@ -609,8 +543,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
std::lock_guard<std::mutex> lock(auth_->WithLock());
auto role = auth_->GetRole(rolename);
if (!role) {
throw query::QueryRuntimeException("Role '{}' doesn't exist.",
rolename);
throw query::QueryRuntimeException("Role '{}' doesn't exist.", rolename);
}
std::vector<query::TypedValue> usernames;
const auto &users = auth_->AllUsersForRole(rolename);
@ -624,8 +557,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
}
}
void SetRole(const std::string &username,
const std::string &rolename) override {
void SetRole(const std::string &username, const std::string &rolename) override {
if (!std::regex_match(username, name_regex_)) {
throw query::QueryRuntimeException("Invalid user name.");
}
@ -636,18 +568,15 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
std::lock_guard<std::mutex> lock(auth_->WithLock());
auto user = auth_->GetUser(username);
if (!user) {
throw query::QueryRuntimeException("User '{}' doesn't exist .",
username);
throw query::QueryRuntimeException("User '{}' doesn't exist .", username);
}
auto role = auth_->GetRole(rolename);
if (!role) {
throw query::QueryRuntimeException("Role '{}' doesn't exist .",
rolename);
throw query::QueryRuntimeException("Role '{}' doesn't exist .", rolename);
}
if (user->role()) {
throw query::QueryRuntimeException(
"User '{}' is already a member of role '{}'.", username,
user->role()->rolename());
throw query::QueryRuntimeException("User '{}' is already a member of role '{}'.", username,
user->role()->rolename());
}
user->SetRole(*role);
auth_->SaveUser(*user);
@ -664,8 +593,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
std::lock_guard<std::mutex> lock(auth_->WithLock());
auto user = auth_->GetUser(username);
if (!user) {
throw query::QueryRuntimeException("User '{}' doesn't exist .",
username);
throw query::QueryRuntimeException("User '{}' doesn't exist .", username);
}
user->ClearRole();
auth_->SaveUser(*user);
@ -674,8 +602,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
}
}
std::vector<std::vector<query::TypedValue>> GetPrivileges(
const std::string &user_or_role) override {
std::vector<std::vector<query::TypedValue>> GetPrivileges(const std::string &user_or_role) override {
if (!std::regex_match(user_or_role, name_regex_)) {
throw query::QueryRuntimeException("Invalid user or role name.");
}
@ -685,8 +612,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
auto user = auth_->GetUser(user_or_role);
auto role = auth_->GetRole(user_or_role);
if (!user && !role) {
throw query::QueryRuntimeException("User or role '{}' doesn't exist.",
user_or_role);
throw query::QueryRuntimeException("User or role '{}' doesn't exist.", user_or_role);
}
if (user) {
const auto &permissions = user->GetPermissions();
@ -709,10 +635,9 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
description.emplace_back("DENIED TO ROLE");
}
}
grants.push_back(
{query::TypedValue(auth::PermissionToString(permission)),
query::TypedValue(auth::PermissionLevelToString(effective)),
query::TypedValue(utils::Join(description, ", "))});
grants.push_back({query::TypedValue(auth::PermissionToString(permission)),
query::TypedValue(auth::PermissionLevelToString(effective)),
query::TypedValue(utils::Join(description, ", "))});
}
}
} else {
@ -727,10 +652,9 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
} else if (effective == auth::PermissionLevel::DENY) {
description = "DENIED TO ROLE";
}
grants.push_back(
{query::TypedValue(auth::PermissionToString(permission)),
query::TypedValue(auth::PermissionLevelToString(effective)),
query::TypedValue(description)});
grants.push_back({query::TypedValue(auth::PermissionToString(permission)),
query::TypedValue(auth::PermissionLevelToString(effective)),
query::TypedValue(description)});
}
}
}
@ -740,48 +664,40 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
}
}
void GrantPrivilege(
const std::string &user_or_role,
const std::vector<query::AuthQuery::Privilege> &privileges) override {
EditPermissions(user_or_role, privileges,
[](auto *permissions, const auto &permission) {
// TODO (mferencevic): should we first check that the
// privilege is granted/denied/revoked before
// unconditionally granting/denying/revoking it?
permissions->Grant(permission);
});
void GrantPrivilege(const std::string &user_or_role,
const std::vector<query::AuthQuery::Privilege> &privileges) override {
EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
// TODO (mferencevic): should we first check that the
// privilege is granted/denied/revoked before
// unconditionally granting/denying/revoking it?
permissions->Grant(permission);
});
}
void DenyPrivilege(
const std::string &user_or_role,
const std::vector<query::AuthQuery::Privilege> &privileges) override {
EditPermissions(user_or_role, privileges,
[](auto *permissions, const auto &permission) {
// TODO (mferencevic): should we first check that the
// privilege is granted/denied/revoked before
// unconditionally granting/denying/revoking it?
permissions->Deny(permission);
});
void DenyPrivilege(const std::string &user_or_role,
const std::vector<query::AuthQuery::Privilege> &privileges) override {
EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
// TODO (mferencevic): should we first check that the
// privilege is granted/denied/revoked before
// unconditionally granting/denying/revoking it?
permissions->Deny(permission);
});
}
void RevokePrivilege(
const std::string &user_or_role,
const std::vector<query::AuthQuery::Privilege> &privileges) override {
EditPermissions(user_or_role, privileges,
[](auto *permissions, const auto &permission) {
// TODO (mferencevic): should we first check that the
// privilege is granted/denied/revoked before
// unconditionally granting/denying/revoking it?
permissions->Revoke(permission);
});
void RevokePrivilege(const std::string &user_or_role,
const std::vector<query::AuthQuery::Privilege> &privileges) override {
EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
// TODO (mferencevic): should we first check that the
// privilege is granted/denied/revoked before
// unconditionally granting/denying/revoking it?
permissions->Revoke(permission);
});
}
private:
template <class TEditFun>
void EditPermissions(
const std::string &user_or_role,
const std::vector<query::AuthQuery::Privilege> &privileges,
const TEditFun &edit_fun) {
void EditPermissions(const std::string &user_or_role, const std::vector<query::AuthQuery::Privilege> &privileges,
const TEditFun &edit_fun) {
if (!std::regex_match(user_or_role, name_regex_)) {
throw query::QueryRuntimeException("Invalid user or role name.");
}
@ -795,8 +711,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
auto user = auth_->GetUser(user_or_role);
auto role = auth_->GetRole(user_or_role);
if (!user && !role) {
throw query::QueryRuntimeException("User or role '{}' doesn't exist.",
user_or_role);
throw query::QueryRuntimeException("User or role '{}' doesn't exist.", user_or_role);
}
if (user) {
for (const auto &permission : permissions) {
@ -818,71 +733,44 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
class NoAuthInCommunity : public query::QueryRuntimeException {
public:
NoAuthInCommunity()
: query::QueryRuntimeException::QueryRuntimeException(
"Auth is not supported in Memgraph Community!") {}
: query::QueryRuntimeException::QueryRuntimeException("Auth is not supported in Memgraph Community!") {}
};
class AuthQueryHandler final : public query::AuthQueryHandler {
public:
bool CreateUser(const std::string &,
const std::optional<std::string> &) override {
throw NoAuthInCommunity();
}
bool CreateUser(const std::string &, const std::optional<std::string> &) override { throw NoAuthInCommunity(); }
bool DropUser(const std::string &) override { throw NoAuthInCommunity(); }
void SetPassword(const std::string &,
const std::optional<std::string> &) override {
throw NoAuthInCommunity();
}
void SetPassword(const std::string &, const std::optional<std::string> &) override { throw NoAuthInCommunity(); }
bool CreateRole(const std::string &) override { throw NoAuthInCommunity(); }
bool DropRole(const std::string &) override { throw NoAuthInCommunity(); }
std::vector<query::TypedValue> GetUsernames() override {
throw NoAuthInCommunity();
}
std::vector<query::TypedValue> GetUsernames() override { throw NoAuthInCommunity(); }
std::vector<query::TypedValue> GetRolenames() override {
throw NoAuthInCommunity();
}
std::vector<query::TypedValue> GetRolenames() override { throw NoAuthInCommunity(); }
std::optional<std::string> GetRolenameForUser(const std::string &) override {
throw NoAuthInCommunity();
}
std::optional<std::string> GetRolenameForUser(const std::string &) override { throw NoAuthInCommunity(); }
std::vector<query::TypedValue> GetUsernamesForRole(
const std::string &) override {
throw NoAuthInCommunity();
}
std::vector<query::TypedValue> GetUsernamesForRole(const std::string &) override { throw NoAuthInCommunity(); }
void SetRole(const std::string &, const std::string &) override {
throw NoAuthInCommunity();
}
void SetRole(const std::string &, const std::string &) override { throw NoAuthInCommunity(); }
void ClearRole(const std::string &) override { throw NoAuthInCommunity(); }
std::vector<std::vector<query::TypedValue>> GetPrivileges(
const std::string &) override {
std::vector<std::vector<query::TypedValue>> GetPrivileges(const std::string &) override { throw NoAuthInCommunity(); }
void GrantPrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override {
throw NoAuthInCommunity();
}
void GrantPrivilege(
const std::string &,
const std::vector<query::AuthQuery::Privilege> &) override {
void DenyPrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override {
throw NoAuthInCommunity();
}
void DenyPrivilege(
const std::string &,
const std::vector<query::AuthQuery::Privilege> &) override {
throw NoAuthInCommunity();
}
void RevokePrivilege(
const std::string &,
const std::vector<query::AuthQuery::Privilege> &) override {
void RevokePrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override {
throw NoAuthInCommunity();
}
};
@ -911,11 +799,9 @@ void InitSignalHandlers(const std::function<void()> &shutdown_fun) {
shutdown_fun();
};
MG_ASSERT(utils::SignalHandler::RegisterHandler(
utils::Signal::Terminate, shutdown, block_shutdown_signals),
MG_ASSERT(utils::SignalHandler::RegisterHandler(utils::Signal::Terminate, shutdown, block_shutdown_signals),
"Unable to register SIGTERM handler!");
MG_ASSERT(utils::SignalHandler::RegisterHandler(
utils::Signal::Interupt, shutdown, block_shutdown_signals),
MG_ASSERT(utils::SignalHandler::RegisterHandler(utils::Signal::Interupt, shutdown, block_shutdown_signals),
"Unable to register SIGINT handler!");
}
@ -952,13 +838,10 @@ int main(int argc, char **argv) {
auto gil = py::EnsureGIL();
auto maybe_exc = py::AppendToSysPath(py_support_dir.c_str());
if (maybe_exc) {
spdlog::error("Unable to load support for embedded Python: {}",
*maybe_exc);
spdlog::error("Unable to load support for embedded Python: {}", *maybe_exc);
}
} else {
spdlog::error(
"Unable to load support for embedded Python: missing directory {}",
py_support_dir);
spdlog::error("Unable to load support for embedded Python: missing directory {}", py_support_dir);
}
} catch (const std::filesystem::filesystem_error &e) {
spdlog::error("Unable to load support for embedded Python: {}", e.what());
@ -978,8 +861,7 @@ int main(int argc, char **argv) {
mem_log_scheduler.Run("Memory warning", std::chrono::seconds(3), [] {
auto free_ram = utils::sysinfo::AvailableMemoryKilobytes();
if (free_ram && *free_ram / 1024 < FLAGS_memory_warning_threshold)
spdlog::warn("Running out of available RAM, only {} MB left",
*free_ram / 1024);
spdlog::warn("Running out of available RAM, only {} MB left", *free_ram / 1024);
});
} else {
// Kernel version for the `MemAvailable` value is from: man procfs
@ -990,8 +872,7 @@ int main(int argc, char **argv) {
}
}
std::cout << "You are running Memgraph v" << gflags::VersionString()
<< std::endl;
std::cout << "You are running Memgraph v" << gflags::VersionString() << std::endl;
auto data_directory = std::filesystem::path(FLAGS_data_directory);
@ -1011,18 +892,15 @@ int main(int argc, char **argv) {
auth::Auth auth{data_directory / "auth"};
// Audit log
audit::Log audit_log{data_directory / "audit", FLAGS_audit_buffer_size,
FLAGS_audit_buffer_flush_interval_ms};
audit::Log audit_log{data_directory / "audit", FLAGS_audit_buffer_size, FLAGS_audit_buffer_flush_interval_ms};
// Start the log if enabled.
if (FLAGS_audit_enabled) {
audit_log.Start();
}
// Setup SIGUSR2 to be used for reopening audit log files, when e.g. logrotate
// rotates our audit logs.
MG_ASSERT(
utils::SignalHandler::RegisterHandler(
utils::Signal::User2, [&audit_log]() { audit_log.ReopenLog(); }),
"Unable to register SIGUSR2 handler!");
MG_ASSERT(utils::SignalHandler::RegisterHandler(utils::Signal::User2, [&audit_log]() { audit_log.ReopenLog(); }),
"Unable to register SIGUSR2 handler!");
// End enterprise features initialization
#endif
@ -1030,54 +908,45 @@ int main(int argc, char **argv) {
// Main storage and execution engines initialization
storage::Config db_config{
.gc = {.type = storage::Config::Gc::Type::PERIODIC,
.interval = std::chrono::seconds(FLAGS_storage_gc_cycle_sec)},
.gc = {.type = storage::Config::Gc::Type::PERIODIC, .interval = std::chrono::seconds(FLAGS_storage_gc_cycle_sec)},
.items = {.properties_on_edges = FLAGS_storage_properties_on_edges},
.durability = {
.storage_directory = FLAGS_data_directory,
.recover_on_startup = FLAGS_storage_recover_on_startup,
.snapshot_retention_count = FLAGS_storage_snapshot_retention_count,
.wal_file_size_kibibytes = FLAGS_storage_wal_file_size_kib,
.wal_file_flush_every_n_tx = FLAGS_storage_wal_file_flush_every_n_tx,
.snapshot_on_exit = FLAGS_storage_snapshot_on_exit}};
.durability = {.storage_directory = FLAGS_data_directory,
.recover_on_startup = FLAGS_storage_recover_on_startup,
.snapshot_retention_count = FLAGS_storage_snapshot_retention_count,
.wal_file_size_kibibytes = FLAGS_storage_wal_file_size_kib,
.wal_file_flush_every_n_tx = FLAGS_storage_wal_file_flush_every_n_tx,
.snapshot_on_exit = FLAGS_storage_snapshot_on_exit}};
if (FLAGS_storage_snapshot_interval_sec == 0) {
if (FLAGS_storage_wal_enabled) {
LOG_FATAL(
"In order to use write-ahead-logging you must enable "
"periodic snapshots by setting the snapshot interval to a "
"value larger than 0!");
db_config.durability.snapshot_wal_mode =
storage::Config::Durability::SnapshotWalMode::DISABLED;
db_config.durability.snapshot_wal_mode = storage::Config::Durability::SnapshotWalMode::DISABLED;
}
} else {
if (FLAGS_storage_wal_enabled) {
db_config.durability.snapshot_wal_mode = storage::Config::Durability::
SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL;
db_config.durability.snapshot_wal_mode = storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL;
} else {
db_config.durability.snapshot_wal_mode =
storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT;
db_config.durability.snapshot_wal_mode = storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT;
}
db_config.durability.snapshot_interval =
std::chrono::seconds(FLAGS_storage_snapshot_interval_sec);
db_config.durability.snapshot_interval = std::chrono::seconds(FLAGS_storage_snapshot_interval_sec);
}
storage::Storage db(db_config);
query::InterpreterContext interpreter_context{&db};
query::SetExecutionTimeout(&interpreter_context,
FLAGS_query_execution_timeout_sec);
query::SetExecutionTimeout(&interpreter_context, FLAGS_query_execution_timeout_sec);
#ifdef MG_ENTERPRISE
SessionData session_data{&db, &interpreter_context, &auth, &audit_log};
#else
SessionData session_data{&db, &interpreter_context};
#endif
query::procedure::gModuleRegistry.SetModulesDirectory(
FLAGS_query_modules_directory);
query::procedure::gModuleRegistry.SetModulesDirectory(FLAGS_query_modules_directory);
query::procedure::gModuleRegistry.UnloadAndLoadModulesFromDirectory();
#ifdef MG_ENTERPRISE
AuthQueryHandler auth_handler(&auth,
std::regex(FLAGS_auth_user_or_role_name_regex));
AuthQueryHandler auth_handler(&auth, std::regex(FLAGS_auth_user_or_role_name_regex));
#else
AuthQueryHandler auth_handler;
#endif
@ -1093,16 +962,14 @@ int main(int argc, char **argv) {
spdlog::warn("Using non-secure Bolt connection (without SSL)");
}
ServerT server({FLAGS_bolt_address, static_cast<uint16_t>(FLAGS_bolt_port)},
&session_data, &context, FLAGS_bolt_session_inactivity_timeout,
service_name, FLAGS_bolt_num_workers);
ServerT server({FLAGS_bolt_address, static_cast<uint16_t>(FLAGS_bolt_port)}, &session_data, &context,
FLAGS_bolt_session_inactivity_timeout, service_name, FLAGS_bolt_num_workers);
// Setup telemetry
std::optional<telemetry::Telemetry> telemetry;
if (FLAGS_telemetry_enabled) {
telemetry.emplace(
"https://telemetry.memgraph.com/88b5e7e8-746a-11e8-9f85-538a9e9690cc/",
data_directory / "telemetry", std::chrono::minutes(10));
telemetry.emplace("https://telemetry.memgraph.com/88b5e7e8-746a-11e8-9f85-538a9e9690cc/",
data_directory / "telemetry", std::chrono::minutes(10));
telemetry->AddCollector("db", [&db]() -> nlohmann::json {
auto info = db.GetInfo();
return {{"vertices", info.vertex_count}, {"edges", info.edge_count}};

View File

@ -42,30 +42,22 @@ bool ValidateIdTypeOptions(const char *flagname, const std::string &value) {
// They are used to automatically load the same configuration as the main
// Memgraph binary so that the flags don't need to be specified when importing a
// CSV file on a correctly set-up Memgraph installation.
DEFINE_string(data_directory, "mg_data",
"Path to directory in which to save all permanent data.");
DEFINE_bool(storage_properties_on_edges, false,
"Controls whether relationships have properties.");
DEFINE_string(data_directory, "mg_data", "Path to directory in which to save all permanent data.");
DEFINE_bool(storage_properties_on_edges, false, "Controls whether relationships have properties.");
// CSV import flags.
DEFINE_string(array_delimiter, ";",
"Delimiter between elements of array values.");
DEFINE_string(array_delimiter, ";", "Delimiter between elements of array values.");
DEFINE_validator(array_delimiter, &ValidateControlCharacter);
DEFINE_string(delimiter, ",", "Delimiter between each field in the CSV.");
DEFINE_validator(delimiter, &ValidateControlCharacter);
DEFINE_string(quote, "\"",
"Quotation character for data in the CSV. Cannot contain '\n'");
DEFINE_string(quote, "\"", "Quotation character for data in the CSV. Cannot contain '\n'");
DEFINE_validator(quote, &ValidateControlCharacter);
DEFINE_bool(skip_duplicate_nodes, false,
"Set to true to skip duplicate nodes instead of raising an error.");
DEFINE_bool(skip_duplicate_nodes, false, "Set to true to skip duplicate nodes instead of raising an error.");
DEFINE_bool(skip_bad_relationships, false,
"Set to true to skip relationships that connect nodes that don't "
"exist instead of raising an error.");
DEFINE_bool(ignore_empty_strings, false,
"Set to true to treat empty strings as null values.");
DEFINE_bool(
ignore_extra_columns, false,
"Set to true to ignore columns that aren't specified in the header.");
DEFINE_bool(ignore_empty_strings, false, "Set to true to treat empty strings as null values.");
DEFINE_bool(ignore_extra_columns, false, "Set to true to ignore columns that aren't specified in the header.");
DEFINE_bool(trim_strings, false,
"Set to true to trim leading/trailing whitespace from all fields "
"that are loaded from the CSV file.");
@ -75,25 +67,22 @@ DEFINE_string(id_type, "STRING",
DEFINE_validator(id_type, &ValidateIdTypeOptions);
// Arguments `--nodes` and `--relationships` can be input multiple times and are
// handled with custom parsing.
DEFINE_string(
nodes, "",
"Files that should be parsed for nodes. The CSV header will be loaded from "
"the first supplied file, all other files supplied in a single flag will "
"be treated as data files. Additional labels can be specified for the node "
"files. The flag can be specified multiple times (useful for differently "
"formatted node files). The format of this argument is: "
"[<label>[:<label>]...=]<file>[,<file>][,<file>]...");
DEFINE_string(
relationships, "",
"Files that should be parsed for relationships. The CSV header will be "
"loaded from the first supplied file, all other files supplied in a single "
"flag will be treated as data files. The relationship type can be "
"specified for the relationship files. The flag can be specified multiple "
"times (useful for differently formatted relationship files). The format "
"of this argument is: [<type>=]<file>[,<file>][,<file>]...");
DEFINE_string(nodes, "",
"Files that should be parsed for nodes. The CSV header will be loaded from "
"the first supplied file, all other files supplied in a single flag will "
"be treated as data files. Additional labels can be specified for the node "
"files. The flag can be specified multiple times (useful for differently "
"formatted node files). The format of this argument is: "
"[<label>[:<label>]...=]<file>[,<file>][,<file>]...");
DEFINE_string(relationships, "",
"Files that should be parsed for relationships. The CSV header will be "
"loaded from the first supplied file, all other files supplied in a single "
"flag will be treated as data files. The relationship type can be "
"specified for the relationship files. The flag can be specified multiple "
"times (useful for differently formatted relationship files). The format "
"of this argument is: [<type>=]<file>[,<file>][,<file>]...");
std::vector<std::string> ParseRepeatedFlag(const std::string &flagname,
int argc, char *argv[]) {
std::vector<std::string> ParseRepeatedFlag(const std::string &flagname, int argc, char *argv[]) {
std::vector<std::string> values;
for (int i = 1; i < argc; ++i) {
std::string flag(argv[i]);
@ -132,9 +121,7 @@ struct NodeId {
std::string id_space;
};
bool operator==(const NodeId &a, const NodeId &b) {
return a.id == b.id && a.id_space == b.id_space;
}
bool operator==(const NodeId &a, const NodeId &b) { return a.id == b.id && a.id_space == b.id_space; }
std::ostream &operator<<(std::ostream &stream, const NodeId &node_id) {
if (!node_id.id_space.empty()) {
@ -171,8 +158,7 @@ enum class CsvParserState {
EXPECT_DELIMITER,
};
bool SubstringStartsWith(const std::string_view &str, size_t pos,
const std::string_view &what) {
bool SubstringStartsWith(const std::string_view &str, size_t pos, const std::string_view &what) {
return utils::StartsWith(utils::Substr(str, pos), what);
}
@ -262,8 +248,7 @@ std::pair<std::vector<std::string>, uint64_t> ReadRow(std::istream &stream) {
}
case CsvParserState::QUOTING: {
auto quote_now = SubstringStartsWith(line, i, FLAGS_quote);
auto quote_next =
SubstringStartsWith(line, i + FLAGS_quote.size(), FLAGS_quote);
auto quote_next = SubstringStartsWith(line, i + FLAGS_quote.size(), FLAGS_quote);
if (quote_now && quote_next) {
// This is an escaped quote character.
column += FLAGS_quote;
@ -293,8 +278,7 @@ std::pair<std::vector<std::string>, uint64_t> ReadRow(std::istream &stream) {
state = CsvParserState::NEXT_FIELD;
i += FLAGS_delimiter.size() - 1;
} else {
throw LoadException("Expected '{}' after '{}', but got '{}'",
FLAGS_delimiter, FLAGS_quote, c);
throw LoadException("Expected '{}' after '{}', but got '{}'", FLAGS_delimiter, FLAGS_quote, c);
}
break;
}
@ -326,8 +310,7 @@ std::pair<std::vector<std::string>, uint64_t> ReadRow(std::istream &stream) {
}
if (FLAGS_trim_strings) {
std::transform(std::begin(row), std::end(row), std::begin(row),
[](const auto &item) { return utils::Trim(item); });
std::transform(std::begin(row), std::end(row), std::begin(row), [](const auto &item) { return utils::Trim(item); });
}
return {std::move(row), lines_count};
@ -373,13 +356,10 @@ double StringToDouble(const std::string &value) {
}
/// @throw LoadException
storage::PropertyValue StringToValue(const std::string &str,
const std::string &type) {
if (FLAGS_ignore_empty_strings && str.empty())
return storage::PropertyValue();
storage::PropertyValue StringToValue(const std::string &str, const std::string &type) {
if (FLAGS_ignore_empty_strings && str.empty()) return storage::PropertyValue();
auto convert = [](const auto &str, const auto &type) {
if (type == "integer" || type == "int" || type == "long" ||
type == "byte" || type == "short") {
if (type == "integer" || type == "int" || type == "long" || type == "byte" || type == "short") {
return storage::PropertyValue(StringToInt(str));
} else if (type == "float" || type == "double") {
return storage::PropertyValue(StringToDouble(str));
@ -411,8 +391,7 @@ storage::PropertyValue StringToValue(const std::string &str,
std::string GetIdSpace(const std::string &type) {
// The format of this field is as follows:
// [START_|END_]ID[(<id_space>)]
std::regex format(R"(^(START_|END_)?ID(\(([^\(\)]+)\))?$)",
std::regex::extended);
std::regex format(R"(^(START_|END_)?ID(\(([^\(\)]+)\))?$)", std::regex::extended);
std::smatch res;
if (!std::regex_match(type, res, format))
throw LoadException(
@ -424,8 +403,7 @@ std::string GetIdSpace(const std::string &type) {
}
/// @throw LoadException
void ProcessNodeRow(storage::Storage *store, const std::vector<Field> &fields,
const std::vector<std::string> &row,
void ProcessNodeRow(storage::Storage *store, const std::vector<Field> &fields, const std::vector<std::string> &row,
const std::vector<std::string> &additional_labels,
std::unordered_map<NodeId, storage::Gid> *node_id_map) {
std::optional<NodeId> id;
@ -458,45 +436,32 @@ void ProcessNodeRow(storage::Storage *store, const std::vector<Field> &fields,
} else {
pv_id = storage::PropertyValue(node_id.id);
}
auto node_property =
node.SetProperty(acc.NameToProperty(field.name), pv_id);
if (!node_property.HasValue())
throw LoadException("Couldn't add property '{}' to the node",
field.name);
if (!*node_property)
throw LoadException("The property '{}' already exists", field.name);
auto node_property = node.SetProperty(acc.NameToProperty(field.name), pv_id);
if (!node_property.HasValue()) throw LoadException("Couldn't add property '{}' to the node", field.name);
if (!*node_property) throw LoadException("The property '{}' already exists", field.name);
}
id = node_id;
} else if (field.type == "LABEL") {
for (const auto &label : utils::Split(value, FLAGS_array_delimiter)) {
auto node_label = node.AddLabel(acc.NameToLabel(label));
if (!node_label.HasValue())
throw LoadException("Couldn't add label '{}' to the node", label);
if (!*node_label)
throw LoadException("The label '{}' already exists", label);
if (!node_label.HasValue()) throw LoadException("Couldn't add label '{}' to the node", label);
if (!*node_label) throw LoadException("The label '{}' already exists", label);
}
} else if (field.type != "IGNORE") {
auto node_property = node.SetProperty(acc.NameToProperty(field.name),
StringToValue(value, field.type));
if (!node_property.HasValue())
throw LoadException("Couldn't add property '{}' to the node",
field.name);
if (!*node_property)
throw LoadException("The property '{}' already exists", field.name);
auto node_property = node.SetProperty(acc.NameToProperty(field.name), StringToValue(value, field.type));
if (!node_property.HasValue()) throw LoadException("Couldn't add property '{}' to the node", field.name);
if (!*node_property) throw LoadException("The property '{}' already exists", field.name);
}
}
for (const auto &label : additional_labels) {
auto node_label = node.AddLabel(acc.NameToLabel(label));
if (!node_label.HasValue())
throw LoadException("Couldn't add label '{}' to the node", label);
if (!*node_label)
throw LoadException("The label '{}' already exists", label);
if (!node_label.HasValue()) throw LoadException("Couldn't add label '{}' to the node", label);
if (!*node_label) throw LoadException("The label '{}' already exists", label);
}
if (acc.Commit().HasError()) throw LoadException("Couldn't store the node");
}
void ProcessNodes(storage::Storage *store, const std::string &nodes_path,
std::optional<std::vector<Field>> *header,
void ProcessNodes(storage::Storage *store, const std::string &nodes_path, std::optional<std::vector<Field>> *header,
std::unordered_map<NodeId, storage::Gid> *node_id_map,
const std::vector<std::string> &additional_labels) {
std::ifstream nodes_file(nodes_path);
@ -524,17 +489,14 @@ void ProcessNodes(storage::Storage *store, const std::string &nodes_path,
row_number += lines_count;
}
} catch (const LoadException &e) {
LOG_FATAL("Couldn't process row {} of '{}' because of: {}", row_number,
nodes_path, e.what());
LOG_FATAL("Couldn't process row {} of '{}' because of: {}", row_number, nodes_path, e.what());
}
}
/// @throw LoadException
void ProcessRelationshipsRow(
storage::Storage *store, const std::vector<Field> &fields,
const std::vector<std::string> &row,
std::optional<std::string> relationship_type,
const std::unordered_map<NodeId, storage::Gid> &node_id_map) {
void ProcessRelationshipsRow(storage::Storage *store, const std::vector<Field> &fields,
const std::vector<std::string> &row, std::optional<std::string> relationship_type,
const std::unordered_map<NodeId, storage::Gid> &node_id_map) {
std::optional<storage::Gid> start_id;
std::optional<storage::Gid> end_id;
std::map<std::string, storage::PropertyValue> properties;
@ -576,14 +538,11 @@ void ProcessRelationshipsRow(
}
end_id = it->second;
} else if (field.type == "TYPE") {
if (relationship_type)
throw LoadException("Only one relationship TYPE must be specified");
if (relationship_type) throw LoadException("Only one relationship TYPE must be specified");
relationship_type = value;
} else if (field.type != "IGNORE") {
auto [it, inserted] =
properties.emplace(field.name, StringToValue(value, field.type));
if (!inserted)
throw LoadException("The property '{}' already exists", field.name);
auto [it, inserted] = properties.emplace(field.name, StringToValue(value, field.type));
if (!inserted) throw LoadException("The property '{}' already exists", field.name);
}
}
if (!start_id) throw LoadException("START_ID must be set");
@ -596,18 +555,14 @@ void ProcessRelationshipsRow(
auto to_node = acc.FindVertex(*end_id, storage::View::NEW);
if (!to_node) throw LoadException("To node must be in the storage");
auto relationship = acc.CreateEdge(&*from_node, &*to_node,
acc.NameToEdgeType(*relationship_type));
if (!relationship.HasValue())
throw LoadException("Couldn't create the relationship");
auto relationship = acc.CreateEdge(&*from_node, &*to_node, acc.NameToEdgeType(*relationship_type));
if (!relationship.HasValue()) throw LoadException("Couldn't create the relationship");
for (const auto &property : properties) {
auto ret = relationship->SetProperty(acc.NameToProperty(property.first),
property.second);
auto ret = relationship->SetProperty(acc.NameToProperty(property.first), property.second);
if (!ret.HasValue()) {
if (ret.GetError() != storage::Error::PROPERTIES_DISABLED) {
throw LoadException("Couldn't add property '{}' to the relationship",
property.first);
throw LoadException("Couldn't add property '{}' to the relationship", property.first);
} else {
throw LoadException(
"Couldn't add property '{}' to the relationship because properties "
@ -617,15 +572,13 @@ void ProcessRelationshipsRow(
}
}
if (acc.Commit().HasError())
throw LoadException("Couldn't store the relationship");
if (acc.Commit().HasError()) throw LoadException("Couldn't store the relationship");
}
void ProcessRelationships(
storage::Storage *store, const std::string &relationships_path,
const std::optional<std::string> &relationship_type,
std::optional<std::vector<Field>> *header,
const std::unordered_map<NodeId, storage::Gid> &node_id_map) {
void ProcessRelationships(storage::Storage *store, const std::string &relationships_path,
const std::optional<std::string> &relationship_type,
std::optional<std::vector<Field>> *header,
const std::unordered_map<NodeId, storage::Gid> &node_id_map) {
std::ifstream relationships_file(relationships_path);
MG_ASSERT(relationships_file, "Unable to open '{}'", relationships_path);
uint64_t row_number = 1;
@ -647,13 +600,11 @@ void ProcessRelationships(
if (row.size() > (*header)->size()) {
row.resize((*header)->size());
}
ProcessRelationshipsRow(store, **header, row, relationship_type,
node_id_map);
ProcessRelationshipsRow(store, **header, row, relationship_type, node_id_map);
row_number += lines_count;
}
} catch (const LoadException &e) {
LOG_FATAL("Couldn't process row {} of '{}' because of: {}", row_number,
relationships_path, e.what());
LOG_FATAL("Couldn't process row {} of '{}' because of: {}", row_number, relationships_path, e.what());
}
}
@ -735,8 +686,7 @@ int main(int argc, char *argv[]) {
.items = {.properties_on_edges = FLAGS_storage_properties_on_edges},
.durability = {.storage_directory = FLAGS_data_directory,
.recover_on_startup = false,
.snapshot_wal_mode =
storage::Config::Durability::SnapshotWalMode::DISABLED,
.snapshot_wal_mode = storage::Config::Durability::SnapshotWalMode::DISABLED,
.snapshot_on_exit = true},
}};
@ -748,8 +698,7 @@ int main(int argc, char *argv[]) {
std::optional<std::vector<Field>> header;
for (const auto &nodes_file : files) {
spdlog::info("Loading {}", nodes_file);
ProcessNodes(&store, nodes_file, &header, &node_id_map,
additional_labels);
ProcessNodes(&store, nodes_file, &header, &node_id_map, additional_labels);
}
}
@ -759,8 +708,7 @@ int main(int argc, char *argv[]) {
std::optional<std::vector<Field>> header;
for (const auto &relationships_file : files) {
spdlog::info("Loading {}", relationships_file);
ProcessRelationships(&store, relationships_file, type, &header,
node_id_map);
ProcessRelationships(&store, relationships_file, type, &header, node_id_map);
}
}

View File

@ -104,34 +104,26 @@ class [[nodiscard]] Object final {
/// This function always succeeds, meaning that exceptions that occur while
/// calling __getattr__ and __getattribute__ will get suppressed. To get error
/// reporting, use GetAttr instead.
bool HasAttr(const char *attr_name) const {
return PyObject_HasAttrString(ptr_, attr_name);
}
bool HasAttr(const char *attr_name) const { return PyObject_HasAttrString(ptr_, attr_name); }
/// Equivalent to `hasattr(this, attr_name)` in Python.
///
/// This function always succeeds, meaning that exceptions that occur while
/// calling __getattr__ and __getattribute__ will get suppressed. To get error
/// reporting, use GetAttr instead.
bool HasAttr(PyObject *attr_name) const {
return PyObject_HasAttr(ptr_, attr_name);
}
bool HasAttr(PyObject *attr_name) const { return PyObject_HasAttr(ptr_, attr_name); }
/// Equivalent to `this.attr_name` in Python.
///
/// Returned Object is nullptr if an error occurred.
/// @sa FetchError
Object GetAttr(const char *attr_name) const {
return Object(PyObject_GetAttrString(ptr_, attr_name));
}
Object GetAttr(const char *attr_name) const { return Object(PyObject_GetAttrString(ptr_, attr_name)); }
/// Equivalent to `this.attr_name` in Python.
///
/// Returned Object is nullptr if an error occurred.
/// @sa FetchError
Object GetAttr(PyObject *attr_name) const {
return Object(PyObject_GetAttr(ptr_, attr_name));
}
Object GetAttr(PyObject *attr_name) const { return Object(PyObject_GetAttr(ptr_, attr_name)); }
/// Equivalent to `this.attr_name = v` in Python.
///
@ -145,9 +137,7 @@ class [[nodiscard]] Object final {
///
/// False is returned if an error occurred.
/// @sa FetchError
[[nodiscard]] bool SetAttr(PyObject *attr_name, PyObject *v) {
return PyObject_SetAttr(ptr_, attr_name, v) == 0;
}
[[nodiscard]] bool SetAttr(PyObject *attr_name, PyObject *v) { return PyObject_SetAttr(ptr_, attr_name, v) == 0; }
/// Equivalent to `callable()` in Python.
///
@ -161,8 +151,7 @@ class [[nodiscard]] Object final {
/// @sa FetchError
template <class... TArgs>
Object Call(const TArgs &...args) const {
return Object(PyObject_CallFunctionObjArgs(
ptr_, static_cast<PyObject *>(args)..., nullptr));
return Object(PyObject_CallFunctionObjArgs(ptr_, static_cast<PyObject *>(args)..., nullptr));
}
/// Equivalent to `obj.meth_name()` in Python.
@ -170,8 +159,7 @@ class [[nodiscard]] Object final {
/// Returned Object is nullptr if an error occurred.
/// @sa FetchError
Object CallMethod(std::string_view meth_name) const {
Object name(
PyUnicode_FromStringAndSize(meth_name.data(), meth_name.size()));
Object name(PyUnicode_FromStringAndSize(meth_name.data(), meth_name.size()));
return Object(PyObject_CallMethodObjArgs(ptr_, name.Ptr(), nullptr));
}
@ -181,10 +169,8 @@ class [[nodiscard]] Object final {
/// @sa FetchError
template <class... TArgs>
Object CallMethod(std::string_view meth_name, const TArgs &...args) const {
Object name(
PyUnicode_FromStringAndSize(meth_name.data(), meth_name.size()));
return Object(PyObject_CallMethodObjArgs(
ptr_, name.Ptr(), static_cast<PyObject *>(args)..., nullptr));
Object name(PyUnicode_FromStringAndSize(meth_name.data(), meth_name.size()));
return Object(PyObject_CallMethodObjArgs(ptr_, name.Ptr(), static_cast<PyObject *>(args)..., nullptr));
}
};
@ -210,8 +196,7 @@ struct [[nodiscard]] ExceptionInfo final {
/// argument `skip_first_line` allows the user to skip the first line of the
/// traceback. It is useful if the first line in the traceback always prints
/// some internal wrapper function.
[[nodiscard]] inline std::string FormatException(const ExceptionInfo &exc_info,
bool skip_first_line = false) {
[[nodiscard]] inline std::string FormatException(const ExceptionInfo &exc_info, bool skip_first_line = false) {
if (!exc_info.type) return "";
Object traceback_mod(PyImport_ImportModule("traceback"));
MG_ASSERT(traceback_mod);
@ -221,9 +206,8 @@ struct [[nodiscard]] ExceptionInfo final {
if (skip_first_line && traceback_root) {
traceback_root = traceback_root.GetAttr("tb_next");
}
auto list = format_exception_fn.Call(
exc_info.type, exc_info.value ? exc_info.value.Ptr() : Py_None,
traceback_root ? traceback_root.Ptr() : Py_None);
auto list = format_exception_fn.Call(exc_info.type, exc_info.value ? exc_info.value.Ptr() : Py_None,
traceback_root ? traceback_root.Ptr() : Py_None);
MG_ASSERT(list);
std::stringstream ss;
auto len = PyList_GET_SIZE(list.Ptr());
@ -235,8 +219,7 @@ struct [[nodiscard]] ExceptionInfo final {
}
/// Write ExceptionInfo to stream just like the Python interpreter would.
inline std::ostream &operator<<(std::ostream &os,
const ExceptionInfo &exc_info) {
inline std::ostream &operator<<(std::ostream &os, const ExceptionInfo &exc_info) {
os << FormatException(exc_info);
return os;
}
@ -259,16 +242,14 @@ inline std::ostream &operator<<(std::ostream &os,
}
inline void RestoreError(ExceptionInfo exc_info) {
PyErr_Restore(exc_info.type.Steal(), exc_info.value.Steal(),
exc_info.traceback.Steal());
PyErr_Restore(exc_info.type.Steal(), exc_info.value.Steal(), exc_info.traceback.Steal());
}
/// Append `dir` to Python's `sys.path`.
///
/// The function does not check whether the directory exists, or is readable.
/// ExceptionInfo is returned if an error occurred.
[[nodiscard]] inline std::optional<ExceptionInfo> AppendToSysPath(
const char *dir) {
[[nodiscard]] inline std::optional<ExceptionInfo> AppendToSysPath(const char *dir) {
MG_ASSERT(dir);
auto *py_path = PySys_GetObject("path");
MG_ASSERT(py_path);

View File

@ -15,9 +15,7 @@ bool TypedValueCompare(const TypedValue &a, const TypedValue &b) {
// comparisons are from this point legal only between values of
// the same type, or int+float combinations
if ((a.type() != b.type() && !(a.IsNumeric() && b.IsNumeric())))
throw QueryRuntimeException(
"Can't compare value of type {} to value of type {}.", a.type(),
b.type());
throw QueryRuntimeException("Can't compare value of type {} to value of type {}.", a.type(), b.type());
switch (a.type()) {
case TypedValue::Type::Bool:
@ -39,8 +37,7 @@ bool TypedValueCompare(const TypedValue &a, const TypedValue &b) {
case TypedValue::Type::Vertex:
case TypedValue::Type::Edge:
case TypedValue::Type::Path:
throw QueryRuntimeException(
"Comparison is not defined for values of type {}.", a.type());
throw QueryRuntimeException("Comparison is not defined for values of type {}.", a.type());
default:
LOG_FATAL("Unhandled comparison for types");
}

View File

@ -27,12 +27,10 @@ bool TypedValueCompare(const TypedValue &a, const TypedValue &b);
class TypedValueVectorCompare final {
public:
TypedValueVectorCompare() {}
explicit TypedValueVectorCompare(const std::vector<Ordering> &ordering)
: ordering_(ordering) {}
explicit TypedValueVectorCompare(const std::vector<Ordering> &ordering) : ordering_(ordering) {}
template <class TAllocator>
bool operator()(const std::vector<TypedValue, TAllocator> &c1,
const std::vector<TypedValue, TAllocator> &c2) const {
bool operator()(const std::vector<TypedValue, TAllocator> &c1, const std::vector<TypedValue, TAllocator> &c2) const {
// ordering is invalid if there are more elements in the collections
// then there are in the ordering_ vector
MG_ASSERT(c1.size() <= ordering_.size() && c2.size() <= ordering_.size(),
@ -41,12 +39,9 @@ class TypedValueVectorCompare final {
auto c1_it = c1.begin();
auto c2_it = c2.begin();
auto ordering_it = ordering_.begin();
for (; c1_it != c1.end() && c2_it != c2.end();
c1_it++, c2_it++, ordering_it++) {
if (impl::TypedValueCompare(*c1_it, *c2_it))
return *ordering_it == Ordering::ASC;
if (impl::TypedValueCompare(*c2_it, *c1_it))
return *ordering_it == Ordering::DESC;
for (; c1_it != c1.end() && c2_it != c2.end(); c1_it++, c2_it++, ordering_it++) {
if (impl::TypedValueCompare(*c1_it, *c2_it)) return *ordering_it == Ordering::ASC;
if (impl::TypedValueCompare(*c2_it, *c1_it)) return *ordering_it == Ordering::DESC;
}
// at least one collection is exhausted
@ -61,41 +56,33 @@ class TypedValueVectorCompare final {
};
/// Raise QueryRuntimeException if the value for symbol isn't of expected type.
inline void ExpectType(const Symbol &symbol, const TypedValue &value,
TypedValue::Type expected) {
inline void ExpectType(const Symbol &symbol, const TypedValue &value, TypedValue::Type expected) {
if (value.type() != expected)
throw QueryRuntimeException("Expected a {} for '{}', but got {}.", expected,
symbol.name(), value.type());
throw QueryRuntimeException("Expected a {} for '{}', but got {}.", expected, symbol.name(), value.type());
}
/// Set a property `value` mapped with given `key` on a `record`.
///
/// @throw QueryRuntimeException if value cannot be set as a property value
template <class TRecordAccessor>
void PropsSetChecked(TRecordAccessor *record, const storage::PropertyId &key,
const TypedValue &value) {
void PropsSetChecked(TRecordAccessor *record, const storage::PropertyId &key, const TypedValue &value) {
try {
auto maybe_error = record->SetProperty(key, storage::PropertyValue(value));
if (maybe_error.HasError()) {
switch (maybe_error.GetError()) {
case storage::Error::SERIALIZATION_ERROR:
throw QueryRuntimeException(
"Can't serialize due to concurrent operations.");
throw QueryRuntimeException("Can't serialize due to concurrent operations.");
case storage::Error::DELETED_OBJECT:
throw QueryRuntimeException(
"Trying to set properties on a deleted object.");
throw QueryRuntimeException("Trying to set properties on a deleted object.");
case storage::Error::PROPERTIES_DISABLED:
throw QueryRuntimeException(
"Can't set property because properties on edges are disabled.");
throw QueryRuntimeException("Can't set property because properties on edges are disabled.");
case storage::Error::VERTEX_HAS_EDGES:
case storage::Error::NONEXISTENT_OBJECT:
throw QueryRuntimeException(
"Unexpected error when setting a property.");
throw QueryRuntimeException("Unexpected error when setting a property.");
}
}
} catch (const TypedValueException &) {
throw QueryRuntimeException("'{}' cannot be used as a property value.",
value.type());
throw QueryRuntimeException("'{}' cannot be used as a property value.", value.type());
}
}

View File

@ -26,8 +26,8 @@ struct EvaluationContext {
mutable std::unordered_map<std::string, int64_t> counters;
};
inline std::vector<storage::PropertyId> NamesToProperties(
const std::vector<std::string> &property_names, DbAccessor *dba) {
inline std::vector<storage::PropertyId> NamesToProperties(const std::vector<std::string> &property_names,
DbAccessor *dba) {
std::vector<storage::PropertyId> properties;
properties.reserve(property_names.size());
for (const auto &name : property_names) {
@ -36,8 +36,7 @@ inline std::vector<storage::PropertyId> NamesToProperties(
return properties;
}
inline std::vector<storage::LabelId> NamesToLabels(
const std::vector<std::string> &label_names, DbAccessor *dba) {
inline std::vector<storage::LabelId> NamesToLabels(const std::vector<std::string> &label_names, DbAccessor *dba) {
std::vector<storage::LabelId> labels;
labels.reserve(label_names.size());
for (const auto &name : label_names) {
@ -60,11 +59,9 @@ struct ExecutionContext {
};
inline bool MustAbort(const ExecutionContext &context) {
return (context.is_shutting_down &&
context.is_shutting_down->load(std::memory_order_acquire)) ||
return (context.is_shutting_down && context.is_shutting_down->load(std::memory_order_acquire)) ||
(context.max_execution_time_sec > 0 &&
context.execution_tsc_timer.Elapsed() >=
context.max_execution_time_sec);
context.execution_tsc_timer.Elapsed() >= context.max_execution_time_sec);
}
} // namespace query

View File

@ -47,19 +47,15 @@ class EdgeAccessor final {
auto Properties(storage::View view) const { return impl_.Properties(view); }
storage::Result<storage::PropertyValue> GetProperty(storage::View view,
storage::PropertyId key) const {
storage::Result<storage::PropertyValue> GetProperty(storage::View view, storage::PropertyId key) const {
return impl_.GetProperty(key, view);
}
storage::Result<bool> SetProperty(storage::PropertyId key,
const storage::PropertyValue &value) {
storage::Result<bool> SetProperty(storage::PropertyId key, const storage::PropertyValue &value) {
return impl_.SetProperty(key, value);
}
storage::Result<bool> RemoveProperty(storage::PropertyId key) {
return SetProperty(key, storage::PropertyValue());
}
storage::Result<bool> RemoveProperty(storage::PropertyId key) { return SetProperty(key, storage::PropertyValue()); }
utils::BasicResult<storage::Error, void> ClearProperties() {
auto ret = impl_.ClearProperties();
@ -86,44 +82,32 @@ class VertexAccessor final {
public:
storage::VertexAccessor impl_;
static EdgeAccessor MakeEdgeAccessor(const storage::EdgeAccessor impl) {
return EdgeAccessor(impl);
}
static EdgeAccessor MakeEdgeAccessor(const storage::EdgeAccessor impl) { return EdgeAccessor(impl); }
public:
explicit VertexAccessor(storage::VertexAccessor impl)
: impl_(std::move(impl)) {}
explicit VertexAccessor(storage::VertexAccessor impl) : impl_(std::move(impl)) {}
auto Labels(storage::View view) const { return impl_.Labels(view); }
storage::Result<bool> AddLabel(storage::LabelId label) {
return impl_.AddLabel(label);
}
storage::Result<bool> AddLabel(storage::LabelId label) { return impl_.AddLabel(label); }
storage::Result<bool> RemoveLabel(storage::LabelId label) {
return impl_.RemoveLabel(label);
}
storage::Result<bool> RemoveLabel(storage::LabelId label) { return impl_.RemoveLabel(label); }
storage::Result<bool> HasLabel(storage::View view,
storage::LabelId label) const {
storage::Result<bool> HasLabel(storage::View view, storage::LabelId label) const {
return impl_.HasLabel(label, view);
}
auto Properties(storage::View view) const { return impl_.Properties(view); }
storage::Result<storage::PropertyValue> GetProperty(storage::View view,
storage::PropertyId key) const {
storage::Result<storage::PropertyValue> GetProperty(storage::View view, storage::PropertyId key) const {
return impl_.GetProperty(key, view);
}
storage::Result<bool> SetProperty(storage::PropertyId key,
const storage::PropertyValue &value) {
storage::Result<bool> SetProperty(storage::PropertyId key, const storage::PropertyValue &value) {
return impl_.SetProperty(key, value);
}
storage::Result<bool> RemoveProperty(storage::PropertyId key) {
return SetProperty(key, storage::PropertyValue());
}
storage::Result<bool> RemoveProperty(storage::PropertyId key) { return SetProperty(key, storage::PropertyValue()); }
utils::BasicResult<storage::Error, void> ClearProperties() {
auto ret = impl_.ClearProperties();
@ -131,10 +115,8 @@ class VertexAccessor final {
return {};
}
auto InEdges(storage::View view,
const std::vector<storage::EdgeTypeId> &edge_types) const
-> storage::Result<decltype(iter::imap(MakeEdgeAccessor,
*impl_.InEdges(view)))> {
auto InEdges(storage::View view, const std::vector<storage::EdgeTypeId> &edge_types) const
-> storage::Result<decltype(iter::imap(MakeEdgeAccessor, *impl_.InEdges(view)))> {
auto maybe_edges = impl_.InEdges(view, edge_types);
if (maybe_edges.HasError()) return maybe_edges.GetError();
return iter::imap(MakeEdgeAccessor, std::move(*maybe_edges));
@ -142,20 +124,15 @@ class VertexAccessor final {
auto InEdges(storage::View view) const { return InEdges(view, {}); }
auto InEdges(storage::View view,
const std::vector<storage::EdgeTypeId> &edge_types,
const VertexAccessor &dest) const
-> storage::Result<decltype(iter::imap(MakeEdgeAccessor,
*impl_.InEdges(view)))> {
auto InEdges(storage::View view, const std::vector<storage::EdgeTypeId> &edge_types, const VertexAccessor &dest) const
-> storage::Result<decltype(iter::imap(MakeEdgeAccessor, *impl_.InEdges(view)))> {
auto maybe_edges = impl_.InEdges(view, edge_types, &dest.impl_);
if (maybe_edges.HasError()) return maybe_edges.GetError();
return iter::imap(MakeEdgeAccessor, std::move(*maybe_edges));
}
auto OutEdges(storage::View view,
const std::vector<storage::EdgeTypeId> &edge_types) const
-> storage::Result<decltype(iter::imap(MakeEdgeAccessor,
*impl_.OutEdges(view)))> {
auto OutEdges(storage::View view, const std::vector<storage::EdgeTypeId> &edge_types) const
-> storage::Result<decltype(iter::imap(MakeEdgeAccessor, *impl_.OutEdges(view)))> {
auto maybe_edges = impl_.OutEdges(view, edge_types);
if (maybe_edges.HasError()) return maybe_edges.GetError();
return iter::imap(MakeEdgeAccessor, std::move(*maybe_edges));
@ -163,23 +140,17 @@ class VertexAccessor final {
auto OutEdges(storage::View view) const { return OutEdges(view, {}); }
auto OutEdges(storage::View view,
const std::vector<storage::EdgeTypeId> &edge_types,
auto OutEdges(storage::View view, const std::vector<storage::EdgeTypeId> &edge_types,
const VertexAccessor &dest) const
-> storage::Result<decltype(iter::imap(MakeEdgeAccessor,
*impl_.OutEdges(view)))> {
-> storage::Result<decltype(iter::imap(MakeEdgeAccessor, *impl_.OutEdges(view)))> {
auto maybe_edges = impl_.OutEdges(view, edge_types, &dest.impl_);
if (maybe_edges.HasError()) return maybe_edges.GetError();
return iter::imap(MakeEdgeAccessor, std::move(*maybe_edges));
}
storage::Result<size_t> InDegree(storage::View view) const {
return impl_.InDegree(view);
}
storage::Result<size_t> InDegree(storage::View view) const { return impl_.InDegree(view); }
storage::Result<size_t> OutDegree(storage::View view) const {
return impl_.OutDegree(view);
}
storage::Result<size_t> OutDegree(storage::View view) const { return impl_.OutDegree(view); }
int64_t CypherId() const { return impl_.Gid().AsInt(); }
@ -190,13 +161,9 @@ class VertexAccessor final {
bool operator!=(const VertexAccessor &v) const { return !(*this == v); }
};
inline VertexAccessor EdgeAccessor::To() const {
return VertexAccessor(impl_.ToVertex());
}
inline VertexAccessor EdgeAccessor::To() const { return VertexAccessor(impl_.ToVertex()); }
inline VertexAccessor EdgeAccessor::From() const {
return VertexAccessor(impl_.FromVertex());
}
inline VertexAccessor EdgeAccessor::From() const { return VertexAccessor(impl_.FromVertex()); }
inline bool EdgeAccessor::IsCycle() const { return To() == From(); }
@ -225,8 +192,7 @@ class DbAccessor final {
bool operator!=(const Iterator &other) const { return !(other == *this); }
};
explicit VerticesIterable(storage::VerticesIterable iterable)
: iterable_(std::move(iterable)) {}
explicit VerticesIterable(storage::VerticesIterable iterable) : iterable_(std::move(iterable)) {}
Iterator begin() { return Iterator(iterable_.begin()); }
@ -234,60 +200,45 @@ class DbAccessor final {
};
public:
explicit DbAccessor(storage::Storage::Accessor *accessor)
: accessor_(accessor) {}
explicit DbAccessor(storage::Storage::Accessor *accessor) : accessor_(accessor) {}
std::optional<VertexAccessor> FindVertex(storage::Gid gid,
storage::View view) {
std::optional<VertexAccessor> FindVertex(storage::Gid gid, storage::View view) {
auto maybe_vertex = accessor_->FindVertex(gid, view);
if (maybe_vertex) return VertexAccessor(*maybe_vertex);
return std::nullopt;
}
VerticesIterable Vertices(storage::View view) {
return VerticesIterable(accessor_->Vertices(view));
}
VerticesIterable Vertices(storage::View view) { return VerticesIterable(accessor_->Vertices(view)); }
VerticesIterable Vertices(storage::View view, storage::LabelId label) {
return VerticesIterable(accessor_->Vertices(label, view));
}
VerticesIterable Vertices(storage::View view, storage::LabelId label,
storage::PropertyId property) {
VerticesIterable Vertices(storage::View view, storage::LabelId label, storage::PropertyId property) {
return VerticesIterable(accessor_->Vertices(label, property, view));
}
VerticesIterable Vertices(storage::View view, storage::LabelId label,
storage::PropertyId property,
VerticesIterable Vertices(storage::View view, storage::LabelId label, storage::PropertyId property,
const storage::PropertyValue &value) {
return VerticesIterable(accessor_->Vertices(label, property, value, view));
}
VerticesIterable Vertices(
storage::View view, storage::LabelId label, storage::PropertyId property,
const std::optional<utils::Bound<storage::PropertyValue>> &lower,
const std::optional<utils::Bound<storage::PropertyValue>> &upper) {
return VerticesIterable(
accessor_->Vertices(label, property, lower, upper, view));
VerticesIterable Vertices(storage::View view, storage::LabelId label, storage::PropertyId property,
const std::optional<utils::Bound<storage::PropertyValue>> &lower,
const std::optional<utils::Bound<storage::PropertyValue>> &upper) {
return VerticesIterable(accessor_->Vertices(label, property, lower, upper, view));
}
VertexAccessor InsertVertex() {
return VertexAccessor(accessor_->CreateVertex());
}
VertexAccessor InsertVertex() { return VertexAccessor(accessor_->CreateVertex()); }
storage::Result<EdgeAccessor> InsertEdge(VertexAccessor *from,
VertexAccessor *to,
storage::Result<EdgeAccessor> InsertEdge(VertexAccessor *from, VertexAccessor *to,
const storage::EdgeTypeId &edge_type) {
auto maybe_edge =
accessor_->CreateEdge(&from->impl_, &to->impl_, edge_type);
if (maybe_edge.HasError())
return storage::Result<EdgeAccessor>(maybe_edge.GetError());
auto maybe_edge = accessor_->CreateEdge(&from->impl_, &to->impl_, edge_type);
if (maybe_edge.HasError()) return storage::Result<EdgeAccessor>(maybe_edge.GetError());
return EdgeAccessor(std::move(*maybe_edge));
}
storage::Result<bool> RemoveEdge(EdgeAccessor *edge) {
return accessor_->DeleteEdge(&edge->impl_);
}
storage::Result<bool> RemoveEdge(EdgeAccessor *edge) { return accessor_->DeleteEdge(&edge->impl_); }
storage::Result<bool> DetachRemoveVertex(VertexAccessor *vertex_accessor) {
return accessor_->DetachDeleteVertex(&vertex_accessor->impl_);
@ -297,55 +248,35 @@ class DbAccessor final {
return accessor_->DeleteVertex(&vertex_accessor->impl_);
}
storage::PropertyId NameToProperty(const std::string_view &name) {
return accessor_->NameToProperty(name);
}
storage::PropertyId NameToProperty(const std::string_view &name) { return accessor_->NameToProperty(name); }
storage::LabelId NameToLabel(const std::string_view &name) {
return accessor_->NameToLabel(name);
}
storage::LabelId NameToLabel(const std::string_view &name) { return accessor_->NameToLabel(name); }
storage::EdgeTypeId NameToEdgeType(const std::string_view &name) {
return accessor_->NameToEdgeType(name);
}
storage::EdgeTypeId NameToEdgeType(const std::string_view &name) { return accessor_->NameToEdgeType(name); }
const std::string &PropertyToName(storage::PropertyId prop) const {
return accessor_->PropertyToName(prop);
}
const std::string &PropertyToName(storage::PropertyId prop) const { return accessor_->PropertyToName(prop); }
const std::string &LabelToName(storage::LabelId label) const {
return accessor_->LabelToName(label);
}
const std::string &LabelToName(storage::LabelId label) const { return accessor_->LabelToName(label); }
const std::string &EdgeTypeToName(storage::EdgeTypeId type) const {
return accessor_->EdgeTypeToName(type);
}
const std::string &EdgeTypeToName(storage::EdgeTypeId type) const { return accessor_->EdgeTypeToName(type); }
void AdvanceCommand() { accessor_->AdvanceCommand(); }
utils::BasicResult<storage::ConstraintViolation, void> Commit() {
return accessor_->Commit();
}
utils::BasicResult<storage::ConstraintViolation, void> Commit() { return accessor_->Commit(); }
void Abort() { accessor_->Abort(); }
bool LabelIndexExists(storage::LabelId label) const {
return accessor_->LabelIndexExists(label);
}
bool LabelIndexExists(storage::LabelId label) const { return accessor_->LabelIndexExists(label); }
bool LabelPropertyIndexExists(storage::LabelId label,
storage::PropertyId prop) const {
bool LabelPropertyIndexExists(storage::LabelId label, storage::PropertyId prop) const {
return accessor_->LabelPropertyIndexExists(label, prop);
}
int64_t VerticesCount() const { return accessor_->ApproximateVertexCount(); }
int64_t VerticesCount(storage::LabelId label) const {
return accessor_->ApproximateVertexCount(label);
}
int64_t VerticesCount(storage::LabelId label) const { return accessor_->ApproximateVertexCount(label); }
int64_t VerticesCount(storage::LabelId label,
storage::PropertyId property) const {
int64_t VerticesCount(storage::LabelId label, storage::PropertyId property) const {
return accessor_->ApproximateVertexCount(label, property);
}
@ -354,20 +285,15 @@ class DbAccessor final {
return accessor_->ApproximateVertexCount(label, property, value);
}
int64_t VerticesCount(
storage::LabelId label, storage::PropertyId property,
const std::optional<utils::Bound<storage::PropertyValue>> &lower,
const std::optional<utils::Bound<storage::PropertyValue>> &upper) const {
int64_t VerticesCount(storage::LabelId label, storage::PropertyId property,
const std::optional<utils::Bound<storage::PropertyValue>> &lower,
const std::optional<utils::Bound<storage::PropertyValue>> &upper) const {
return accessor_->ApproximateVertexCount(label, property, lower, upper);
}
storage::IndicesInfo ListAllIndices() const {
return accessor_->ListAllIndices();
}
storage::IndicesInfo ListAllIndices() const { return accessor_->ListAllIndices(); }
storage::ConstraintsInfo ListAllConstraints() const {
return accessor_->ListAllConstraints();
}
storage::ConstraintsInfo ListAllConstraints() const { return accessor_->ListAllConstraints(); }
};
} // namespace query
@ -376,16 +302,12 @@ namespace std {
template <>
struct hash<query::VertexAccessor> {
size_t operator()(const query::VertexAccessor &v) const {
return std::hash<decltype(v.impl_)>{}(v.impl_);
}
size_t operator()(const query::VertexAccessor &v) const { return std::hash<decltype(v.impl_)>{}(v.impl_); }
};
template <>
struct hash<query::EdgeAccessor> {
size_t operator()(const query::EdgeAccessor &e) const {
return std::hash<decltype(e.impl_)>{}(e.impl_);
}
size_t operator()(const query::EdgeAccessor &e) const { return std::hash<decltype(e.impl_)>{}(e.impl_); }
};
} // namespace std

View File

@ -50,8 +50,7 @@ void DumpPreciseDouble(std::ostream *os, double value) {
// A temporary stream is used to keep precision of the original output
// stream unchanged.
std::ostringstream temp_oss;
temp_oss << std::setprecision(std::numeric_limits<double>::max_digits10)
<< value;
temp_oss << std::setprecision(std::numeric_limits<double>::max_digits10) << value;
*os << temp_oss.str();
}
@ -75,9 +74,7 @@ void DumpPropertyValue(std::ostream *os, const storage::PropertyValue &value) {
case storage::PropertyValue::Type::List: {
*os << "[";
const auto &list = value.ValueList();
utils::PrintIterable(*os, list, ", ", [](auto &os, const auto &item) {
DumpPropertyValue(&os, item);
});
utils::PrintIterable(*os, list, ", ", [](auto &os, const auto &item) { DumpPropertyValue(&os, item); });
*os << "]";
return;
}
@ -94,10 +91,9 @@ void DumpPropertyValue(std::ostream *os, const storage::PropertyValue &value) {
}
}
void DumpProperties(
std::ostream *os, query::DbAccessor *dba,
const std::map<storage::PropertyId, storage::PropertyValue> &store,
std::optional<int64_t> property_id = std::nullopt) {
void DumpProperties(std::ostream *os, query::DbAccessor *dba,
const std::map<storage::PropertyId, storage::PropertyValue> &store,
std::optional<int64_t> property_id = std::nullopt) {
*os << "{";
if (property_id) {
*os << kInternalPropertyId << ": " << *property_id;
@ -110,24 +106,20 @@ void DumpProperties(
*os << "}";
}
void DumpVertex(std::ostream *os, query::DbAccessor *dba,
const query::VertexAccessor &vertex) {
void DumpVertex(std::ostream *os, query::DbAccessor *dba, const query::VertexAccessor &vertex) {
*os << "CREATE (";
*os << ":" << kInternalVertexLabel;
auto maybe_labels = vertex.Labels(storage::View::OLD);
if (maybe_labels.HasError()) {
switch (maybe_labels.GetError()) {
case storage::Error::DELETED_OBJECT:
throw query::QueryRuntimeException(
"Trying to get labels from a deleted node.");
throw query::QueryRuntimeException("Trying to get labels from a deleted node.");
case storage::Error::NONEXISTENT_OBJECT:
throw query::QueryRuntimeException(
"Trying to get labels from a node that doesn't exist.");
throw query::QueryRuntimeException("Trying to get labels from a node that doesn't exist.");
case storage::Error::SERIALIZATION_ERROR:
case storage::Error::VERTEX_HAS_EDGES:
case storage::Error::PROPERTIES_DISABLED:
throw query::QueryRuntimeException(
"Unexpected error when getting labels.");
throw query::QueryRuntimeException("Unexpected error when getting labels.");
}
}
for (const auto &label : *maybe_labels) {
@ -138,24 +130,20 @@ void DumpVertex(std::ostream *os, query::DbAccessor *dba,
if (maybe_props.HasError()) {
switch (maybe_props.GetError()) {
case storage::Error::DELETED_OBJECT:
throw query::QueryRuntimeException(
"Trying to get properties from a deleted object.");
throw query::QueryRuntimeException("Trying to get properties from a deleted object.");
case storage::Error::NONEXISTENT_OBJECT:
throw query::QueryRuntimeException(
"Trying to get properties from a node that doesn't exist.");
throw query::QueryRuntimeException("Trying to get properties from a node that doesn't exist.");
case storage::Error::SERIALIZATION_ERROR:
case storage::Error::VERTEX_HAS_EDGES:
case storage::Error::PROPERTIES_DISABLED:
throw query::QueryRuntimeException(
"Unexpected error when getting properties.");
throw query::QueryRuntimeException("Unexpected error when getting properties.");
}
}
DumpProperties(os, dba, *maybe_props, vertex.CypherId());
*os << ");";
}
void DumpEdge(std::ostream *os, query::DbAccessor *dba,
const query::EdgeAccessor &edge) {
void DumpEdge(std::ostream *os, query::DbAccessor *dba, const query::EdgeAccessor &edge) {
*os << "MATCH ";
*os << "(u:" << kInternalVertexLabel << "), ";
*os << "(v:" << kInternalVertexLabel << ")";
@ -169,16 +157,13 @@ void DumpEdge(std::ostream *os, query::DbAccessor *dba,
if (maybe_props.HasError()) {
switch (maybe_props.GetError()) {
case storage::Error::DELETED_OBJECT:
throw query::QueryRuntimeException(
"Trying to get properties from a deleted object.");
throw query::QueryRuntimeException("Trying to get properties from a deleted object.");
case storage::Error::NONEXISTENT_OBJECT:
throw query::QueryRuntimeException(
"Trying to get properties from an edge that doesn't exist.");
throw query::QueryRuntimeException("Trying to get properties from an edge that doesn't exist.");
case storage::Error::SERIALIZATION_ERROR:
case storage::Error::VERTEX_HAS_EDGES:
case storage::Error::PROPERTIES_DISABLED:
throw query::QueryRuntimeException(
"Unexpected error when getting properties.");
throw query::QueryRuntimeException("Unexpected error when getting properties.");
}
}
if (maybe_props->size() > 0) {
@ -188,35 +173,28 @@ void DumpEdge(std::ostream *os, query::DbAccessor *dba,
*os << "]->(v);";
}
void DumpLabelIndex(std::ostream *os, query::DbAccessor *dba,
const storage::LabelId label) {
void DumpLabelIndex(std::ostream *os, query::DbAccessor *dba, const storage::LabelId label) {
*os << "CREATE INDEX ON :" << EscapeName(dba->LabelToName(label)) << ";";
}
void DumpLabelPropertyIndex(std::ostream *os, query::DbAccessor *dba,
storage::LabelId label,
void DumpLabelPropertyIndex(std::ostream *os, query::DbAccessor *dba, storage::LabelId label,
storage::PropertyId property) {
*os << "CREATE INDEX ON :" << EscapeName(dba->LabelToName(label)) << "("
<< EscapeName(dba->PropertyToName(property)) << ");";
}
void DumpExistenceConstraint(std::ostream *os, query::DbAccessor *dba,
storage::LabelId label,
storage::PropertyId property) {
*os << "CREATE CONSTRAINT ON (u:" << EscapeName(dba->LabelToName(label))
<< ") ASSERT EXISTS (u." << EscapeName(dba->PropertyToName(property))
*os << "CREATE INDEX ON :" << EscapeName(dba->LabelToName(label)) << "(" << EscapeName(dba->PropertyToName(property))
<< ");";
}
void DumpUniqueConstraint(std::ostream *os, query::DbAccessor *dba,
storage::LabelId label,
void DumpExistenceConstraint(std::ostream *os, query::DbAccessor *dba, storage::LabelId label,
storage::PropertyId property) {
*os << "CREATE CONSTRAINT ON (u:" << EscapeName(dba->LabelToName(label)) << ") ASSERT EXISTS (u."
<< EscapeName(dba->PropertyToName(property)) << ");";
}
void DumpUniqueConstraint(std::ostream *os, query::DbAccessor *dba, storage::LabelId label,
const std::set<storage::PropertyId> &properties) {
*os << "CREATE CONSTRAINT ON (u:" << EscapeName(dba->LabelToName(label))
<< ") ASSERT ";
utils::PrintIterable(
*os, properties, ", ", [&dba](auto &stream, const auto &property) {
stream << "u." << EscapeName(dba->PropertyToName(property));
});
*os << "CREATE CONSTRAINT ON (u:" << EscapeName(dba->LabelToName(label)) << ") ASSERT ";
utils::PrintIterable(*os, properties, ", ", [&dba](auto &stream, const auto &property) {
stream << "u." << EscapeName(dba->PropertyToName(property));
});
*os << " IS UNIQUE;";
}
@ -250,8 +228,7 @@ bool PullPlanDump::Pull(AnyStream *stream, std::optional<int> n) {
// finishes. If the function did not finish streaming all the results,
// std::nullopt should be returned because n results have already been sent.
while (current_chunk_index_ < pull_chunks_.size() && (!n || *n > 0)) {
const auto maybe_streamed_count =
pull_chunks_[current_chunk_index_](stream, n);
const auto maybe_streamed_count = pull_chunks_[current_chunk_index_](stream, n);
if (!maybe_streamed_count) {
// n wasn't large enough to stream all the results from the current chunk
@ -273,9 +250,7 @@ bool PullPlanDump::Pull(AnyStream *stream, std::optional<int> n) {
PullPlanDump::PullChunk PullPlanDump::CreateLabelIndicesPullChunk() {
// Dump all label indices
return [this, global_index = 0U](
AnyStream *stream,
std::optional<int> n) mutable -> std::optional<size_t> {
return [this, global_index = 0U](AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
// Delay the construction of indices vectors
if (!indices_info_) {
indices_info_.emplace(dba_->ListAllIndices());
@ -301,9 +276,7 @@ PullPlanDump::PullChunk PullPlanDump::CreateLabelIndicesPullChunk() {
}
PullPlanDump::PullChunk PullPlanDump::CreateLabelPropertyIndicesPullChunk() {
return [this, global_index = 0U](
AnyStream *stream,
std::optional<int> n) mutable -> std::optional<size_t> {
return [this, global_index = 0U](AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
// Delay the construction of indices vectors
if (!indices_info_) {
indices_info_.emplace(dba_->ListAllIndices());
@ -314,8 +287,7 @@ PullPlanDump::PullChunk PullPlanDump::CreateLabelPropertyIndicesPullChunk() {
while (global_index < label_property.size() && (!n || local_counter < *n)) {
std::ostringstream os;
const auto &label_property_index = label_property[global_index];
DumpLabelPropertyIndex(&os, dba_, label_property_index.first,
label_property_index.second);
DumpLabelPropertyIndex(&os, dba_, label_property_index.first, label_property_index.second);
stream->Result({TypedValue(os.str())});
++global_index;
@ -331,9 +303,7 @@ PullPlanDump::PullChunk PullPlanDump::CreateLabelPropertyIndicesPullChunk() {
}
PullPlanDump::PullChunk PullPlanDump::CreateExistenceConstraintsPullChunk() {
return [this, global_index = 0U](
AnyStream *stream,
std::optional<int> n) mutable -> std::optional<size_t> {
return [this, global_index = 0U](AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
// Delay the construction of constraint vectors
if (!constraints_info_) {
constraints_info_.emplace(dba_->ListAllConstraints());
@ -360,9 +330,7 @@ PullPlanDump::PullChunk PullPlanDump::CreateExistenceConstraintsPullChunk() {
}
PullPlanDump::PullChunk PullPlanDump::CreateUniqueConstraintsPullChunk() {
return [this, global_index = 0U](
AnyStream *stream,
std::optional<int> n) mutable -> std::optional<size_t> {
return [this, global_index = 0U](AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
// Delay the construction of constraint vectors
if (!constraints_info_) {
constraints_info_.emplace(dba_->ListAllConstraints());
@ -389,12 +357,10 @@ PullPlanDump::PullChunk PullPlanDump::CreateUniqueConstraintsPullChunk() {
}
PullPlanDump::PullChunk PullPlanDump::CreateInternalIndexPullChunk() {
return [this](AnyStream *stream,
std::optional<int>) mutable -> std::optional<size_t> {
return [this](AnyStream *stream, std::optional<int>) mutable -> std::optional<size_t> {
if (vertices_iterable_.begin() != vertices_iterable_.end()) {
std::ostringstream os;
os << "CREATE INDEX ON :" << kInternalVertexLabel << "("
<< kInternalPropertyId << ");";
os << "CREATE INDEX ON :" << kInternalVertexLabel << "(" << kInternalPropertyId << ");";
stream->Result({TypedValue(os.str())});
internal_index_created_ = true;
return 1;
@ -404,10 +370,8 @@ PullPlanDump::PullChunk PullPlanDump::CreateInternalIndexPullChunk() {
}
PullPlanDump::PullChunk PullPlanDump::CreateVertexPullChunk() {
return [this,
maybe_current_iter = std::optional<VertexAccessorIterableIterator>{}](
AnyStream *stream,
std::optional<int> n) mutable -> std::optional<size_t> {
return [this, maybe_current_iter = std::optional<VertexAccessorIterableIterator>{}](
AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
// Delay the call of begin() function
// If multiple begins are called before an iteration,
// one iteration will make the rest of iterators be in undefined
@ -419,8 +383,7 @@ PullPlanDump::PullChunk PullPlanDump::CreateVertexPullChunk() {
auto &current_iter{*maybe_current_iter};
size_t local_counter = 0;
while (current_iter != vertices_iterable_.end() &&
(!n || local_counter < *n)) {
while (current_iter != vertices_iterable_.end() && (!n || local_counter < *n)) {
std::ostringstream os;
DumpVertex(&os, dba_, *current_iter);
stream->Result({TypedValue(os.str())});
@ -436,74 +399,62 @@ PullPlanDump::PullChunk PullPlanDump::CreateVertexPullChunk() {
}
PullPlanDump::PullChunk PullPlanDump::CreateEdgePullChunk() {
return
[this,
maybe_current_vertex_iter =
std::optional<VertexAccessorIterableIterator>{},
// we need to save the iterable which contains list of accessor so
// our saved iterator is valid in the next run
maybe_edge_iterable = std::shared_ptr<EdgeAccessorIterable>{nullptr},
maybe_current_edge_iter = std::optional<EdgeAccessorIterableIterator>{}](
AnyStream *stream,
std::optional<int> n) mutable -> std::optional<size_t> {
// Delay the call of begin() function
// If multiple begins are called before an iteration,
// one iteration will make the rest of iterators be in undefined
// states.
if (!maybe_current_vertex_iter) {
maybe_current_vertex_iter.emplace(vertices_iterable_.begin());
}
return [this, maybe_current_vertex_iter = std::optional<VertexAccessorIterableIterator>{},
// we need to save the iterable which contains list of accessor so
// our saved iterator is valid in the next run
maybe_edge_iterable = std::shared_ptr<EdgeAccessorIterable>{nullptr},
maybe_current_edge_iter = std::optional<EdgeAccessorIterableIterator>{}](
AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
// Delay the call of begin() function
// If multiple begins are called before an iteration,
// one iteration will make the rest of iterators be in undefined
// states.
if (!maybe_current_vertex_iter) {
maybe_current_vertex_iter.emplace(vertices_iterable_.begin());
}
auto &current_vertex_iter{*maybe_current_vertex_iter};
size_t local_counter = 0U;
for (; current_vertex_iter != vertices_iterable_.end() &&
(!n || local_counter < *n);
++current_vertex_iter) {
const auto &vertex = *current_vertex_iter;
// If we have a saved iterable from a previous pull
// we need to use the same iterable
if (!maybe_edge_iterable) {
maybe_edge_iterable = std::make_shared<EdgeAccessorIterable>(
vertex.OutEdges(storage::View::OLD));
}
auto &maybe_edges = *maybe_edge_iterable;
MG_ASSERT(maybe_edges.HasValue(), "Invalid database state!");
auto current_edge_iter = maybe_current_edge_iter
? *maybe_current_edge_iter
: maybe_edges->begin();
for (; current_edge_iter != maybe_edges->end() &&
(!n || local_counter < *n);
++current_edge_iter) {
std::ostringstream os;
DumpEdge(&os, dba_, *current_edge_iter);
stream->Result({TypedValue(os.str())});
auto &current_vertex_iter{*maybe_current_vertex_iter};
size_t local_counter = 0U;
for (; current_vertex_iter != vertices_iterable_.end() && (!n || local_counter < *n); ++current_vertex_iter) {
const auto &vertex = *current_vertex_iter;
// If we have a saved iterable from a previous pull
// we need to use the same iterable
if (!maybe_edge_iterable) {
maybe_edge_iterable = std::make_shared<EdgeAccessorIterable>(vertex.OutEdges(storage::View::OLD));
}
auto &maybe_edges = *maybe_edge_iterable;
MG_ASSERT(maybe_edges.HasValue(), "Invalid database state!");
auto current_edge_iter = maybe_current_edge_iter ? *maybe_current_edge_iter : maybe_edges->begin();
for (; current_edge_iter != maybe_edges->end() && (!n || local_counter < *n); ++current_edge_iter) {
std::ostringstream os;
DumpEdge(&os, dba_, *current_edge_iter);
stream->Result({TypedValue(os.str())});
++local_counter;
}
if (current_edge_iter != maybe_edges->end()) {
maybe_current_edge_iter.emplace(current_edge_iter);
return std::nullopt;
}
maybe_current_edge_iter = std::nullopt;
maybe_edge_iterable = nullptr;
}
if (current_vertex_iter == vertices_iterable_.end()) {
return local_counter;
}
++local_counter;
}
if (current_edge_iter != maybe_edges->end()) {
maybe_current_edge_iter.emplace(current_edge_iter);
return std::nullopt;
};
}
maybe_current_edge_iter = std::nullopt;
maybe_edge_iterable = nullptr;
}
if (current_vertex_iter == vertices_iterable_.end()) {
return local_counter;
}
return std::nullopt;
};
}
PullPlanDump::PullChunk PullPlanDump::CreateDropInternalIndexPullChunk() {
return [this](AnyStream *stream, std::optional<int>) {
if (internal_index_created_) {
std::ostringstream os;
os << "DROP INDEX ON :" << kInternalVertexLabel << "("
<< kInternalPropertyId << ");";
os << "DROP INDEX ON :" << kInternalVertexLabel << "(" << kInternalPropertyId << ");";
stream->Result({TypedValue(os.str())});
return 1;
}
@ -515,8 +466,7 @@ PullPlanDump::PullChunk PullPlanDump::CreateInternalIndexCleanupPullChunk() {
return [this](AnyStream *stream, std::optional<int>) {
if (internal_index_created_) {
std::ostringstream os;
os << "MATCH (u) REMOVE u:" << kInternalVertexLabel << ", u."
<< kInternalPropertyId << ";";
os << "MATCH (u) REMOVE u:" << kInternalVertexLabel << ", u." << kInternalPropertyId << ";";
stream->Result({TypedValue(os.str())});
return 1;
}
@ -524,8 +474,6 @@ PullPlanDump::PullChunk PullPlanDump::CreateInternalIndexCleanupPullChunk() {
};
}
void DumpDatabaseToCypherQueries(query::DbAccessor *dba, AnyStream *stream) {
PullPlanDump(dba).Pull(stream, {});
}
void DumpDatabaseToCypherQueries(query::DbAccessor *dba, AnyStream *stream) { PullPlanDump(dba).Pull(stream, {}); }
} // namespace query

View File

@ -23,23 +23,18 @@ struct PullPlanDump {
std::optional<storage::IndicesInfo> indices_info_ = std::nullopt;
std::optional<storage::ConstraintsInfo> constraints_info_ = std::nullopt;
using VertexAccessorIterable =
decltype(std::declval<query::DbAccessor>().Vertices(storage::View::OLD));
using VertexAccessorIterableIterator =
decltype(std::declval<VertexAccessorIterable>().begin());
using VertexAccessorIterable = decltype(std::declval<query::DbAccessor>().Vertices(storage::View::OLD));
using VertexAccessorIterableIterator = decltype(std::declval<VertexAccessorIterable>().begin());
using EdgeAccessorIterable =
decltype(std::declval<VertexAccessor>().OutEdges(storage::View::OLD));
using EdgeAccessorIterableIterator =
decltype(std::declval<EdgeAccessorIterable>().GetValue().begin());
using EdgeAccessorIterable = decltype(std::declval<VertexAccessor>().OutEdges(storage::View::OLD));
using EdgeAccessorIterableIterator = decltype(std::declval<EdgeAccessorIterable>().GetValue().begin());
VertexAccessorIterable vertices_iterable_;
bool internal_index_created_ = false;
size_t current_chunk_index_ = 0;
using PullChunk = std::function<std::optional<size_t>(AnyStream *stream,
std::optional<int> n)>;
using PullChunk = std::function<std::optional<size_t>(AnyStream *stream, std::optional<int> n)>;
// We define every part of the dump query in a self contained function.
// Each functions is responsible of keeping track of its execution status.
// If a function did finish its execution, it should return number of results

View File

@ -45,23 +45,19 @@ class SemanticException : public QueryException {
class UnboundVariableError : public SemanticException {
public:
explicit UnboundVariableError(const std::string &name)
: SemanticException("Unbound variable: " + name + ".") {}
explicit UnboundVariableError(const std::string &name) : SemanticException("Unbound variable: " + name + ".") {}
};
class RedeclareVariableError : public SemanticException {
public:
explicit RedeclareVariableError(const std::string &name)
: SemanticException("Redeclaring variable: " + name + ".") {}
explicit RedeclareVariableError(const std::string &name) : SemanticException("Redeclaring variable: " + name + ".") {}
};
class TypeMismatchError : public SemanticException {
public:
TypeMismatchError(const std::string &name, const std::string &datum,
const std::string &expected)
: SemanticException(
fmt::format("Type mismatch: {} already defined as {}, expected {}.",
name, datum, expected)) {}
TypeMismatchError(const std::string &name, const std::string &datum, const std::string &expected)
: SemanticException(fmt::format("Type mismatch: {} already defined as {}, expected {}.", name, datum, expected)) {
}
};
class UnprovidedParameterError : public QueryException {
@ -72,16 +68,13 @@ class UnprovidedParameterError : public QueryException {
class ProfileInMulticommandTxException : public QueryException {
public:
using QueryException::QueryException;
ProfileInMulticommandTxException()
: QueryException("PROFILE not allowed in multicommand transactions.") {}
ProfileInMulticommandTxException() : QueryException("PROFILE not allowed in multicommand transactions.") {}
};
class IndexInMulticommandTxException : public QueryException {
public:
using QueryException::QueryException;
IndexInMulticommandTxException()
: QueryException(
"Index manipulation not allowed in multicommand transactions.") {}
IndexInMulticommandTxException() : QueryException("Index manipulation not allowed in multicommand transactions.") {}
};
class ConstraintInMulticommandTxException : public QueryException {
@ -96,9 +89,7 @@ class ConstraintInMulticommandTxException : public QueryException {
class InfoInMulticommandTxException : public QueryException {
public:
using QueryException::QueryException;
InfoInMulticommandTxException()
: QueryException(
"Info reporting not allowed in multicommand transactions.") {}
InfoInMulticommandTxException() : QueryException("Info reporting not allowed in multicommand transactions.") {}
};
/**
@ -147,38 +138,30 @@ class RemoveAttachedVertexException : public QueryRuntimeException {
class UserModificationInMulticommandTxException : public QueryException {
public:
UserModificationInMulticommandTxException()
: QueryException(
"Authentication clause not allowed in multicommand transactions.") {
}
: QueryException("Authentication clause not allowed in multicommand transactions.") {}
};
class StreamClauseInMulticommandTxException : public QueryException {
public:
StreamClauseInMulticommandTxException()
: QueryException(
"Stream clause not allowed in multicommand transactions.") {}
StreamClauseInMulticommandTxException() : QueryException("Stream clause not allowed in multicommand transactions.") {}
};
class InvalidArgumentsException : public QueryException {
public:
InvalidArgumentsException(const std::string &argument_name,
const std::string &message)
: QueryException(fmt::format("Invalid arguments sent: {} - {}",
argument_name, message)) {}
InvalidArgumentsException(const std::string &argument_name, const std::string &message)
: QueryException(fmt::format("Invalid arguments sent: {} - {}", argument_name, message)) {}
};
class ReplicationModificationInMulticommandTxException : public QueryException {
public:
ReplicationModificationInMulticommandTxException()
: QueryException(
"Replication clause not allowed in multicommand transactions.") {}
: QueryException("Replication clause not allowed in multicommand transactions.") {}
};
class LockPathModificationInMulticommandTxException : public QueryException {
public:
LockPathModificationInMulticommandTxException()
: QueryException(
"Lock path clause not allowed in multicommand transactions.") {}
: QueryException("Lock path clause not allowed in multicommand transactions.") {}
};
} // namespace query

View File

@ -76,23 +76,17 @@ class ReplicationQuery;
class LockPathQuery;
using TreeCompositeVisitor = ::utils::CompositeVisitor<
SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator,
AndOperator, NotOperator, AdditionOperator, SubtractionOperator,
MultiplicationOperator, DivisionOperator, ModOperator, NotEqualOperator,
EqualOperator, LessOperator, GreaterOperator, LessEqualOperator,
GreaterEqualOperator, InListOperator, SubscriptOperator,
ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator,
IsNullOperator, ListLiteral, MapLiteral, PropertyLookup, LabelsTest,
Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Any, None,
CallProcedure, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom,
Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty,
RemoveLabels, Merge, Unwind, RegexMatch>;
SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator,
SubtractionOperator, MultiplicationOperator, DivisionOperator, ModOperator, NotEqualOperator, EqualOperator,
LessOperator, GreaterOperator, LessEqualOperator, GreaterEqualOperator, InListOperator, SubscriptOperator,
ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral,
PropertyLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Any, None, CallProcedure,
Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels,
RemoveProperty, RemoveLabels, Merge, Unwind, RegexMatch>;
using TreeLeafVisitor =
::utils::LeafVisitor<Identifier, PrimitiveLiteral, ParameterLookup>;
using TreeLeafVisitor = ::utils::LeafVisitor<Identifier, PrimitiveLiteral, ParameterLookup>;
class HierarchicalTreeVisitor : public TreeCompositeVisitor,
public TreeLeafVisitor {
class HierarchicalTreeVisitor : public TreeCompositeVisitor, public TreeLeafVisitor {
public:
using TreeCompositeVisitor::PostVisit;
using TreeCompositeVisitor::PreVisit;
@ -103,21 +97,15 @@ class HierarchicalTreeVisitor : public TreeCompositeVisitor,
template <class TResult>
class ExpressionVisitor
: public ::utils::Visitor<
TResult, NamedExpression, OrOperator, XorOperator, AndOperator,
NotOperator, AdditionOperator, SubtractionOperator,
MultiplicationOperator, DivisionOperator, ModOperator,
NotEqualOperator, EqualOperator, LessOperator, GreaterOperator,
LessEqualOperator, GreaterEqualOperator, InListOperator,
SubscriptOperator, ListSlicingOperator, IfOperator, UnaryPlusOperator,
UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral,
PropertyLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce,
Extract, All, Single, Any, None, ParameterLookup, Identifier,
PrimitiveLiteral, RegexMatch> {};
TResult, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator,
SubtractionOperator, MultiplicationOperator, DivisionOperator, ModOperator, NotEqualOperator, EqualOperator,
LessOperator, GreaterOperator, LessEqualOperator, GreaterEqualOperator, InListOperator, SubscriptOperator,
ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral,
MapLiteral, PropertyLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Any,
None, ParameterLookup, Identifier, PrimitiveLiteral, RegexMatch> {};
template <class TResult>
class QueryVisitor
: public ::utils::Visitor<TResult, CypherQuery, ExplainQuery, ProfileQuery,
IndexQuery, AuthQuery, InfoQuery, ConstraintQuery,
DumpQuery, ReplicationQuery, LockPathQuery> {};
class QueryVisitor : public ::utils::Visitor<TResult, CypherQuery, ExplainQuery, ProfileQuery, IndexQuery, AuthQuery,
InfoQuery, ConstraintQuery, DumpQuery, ReplicationQuery, LockPathQuery> {};
} // namespace query

File diff suppressed because it is too large Load Diff

View File

@ -22,12 +22,10 @@ struct ParsingContext {
class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
public:
explicit CypherMainVisitor(ParsingContext context, AstStorage *storage)
: context_(context), storage_(storage) {}
explicit CypherMainVisitor(ParsingContext context, AstStorage *storage) : context_(context), storage_(storage) {}
private:
Expression *CreateBinaryOperatorByToken(size_t token, Expression *e1,
Expression *e2) {
Expression *CreateBinaryOperatorByToken(size_t token, Expression *e1, Expression *e2) {
switch (token) {
case MemgraphCypher::OR:
return storage_->Create<OrOperator>(e1, e2);
@ -82,8 +80,7 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
for (auto *child : all_children) {
antlr4::tree::TerminalNode *operator_node = nullptr;
if ((operator_node = dynamic_cast<antlr4::tree::TerminalNode *>(child))) {
if (std::find(allowed_operators.begin(), allowed_operators.end(),
operator_node->getSymbol()->getType()) !=
if (std::find(allowed_operators.begin(), allowed_operators.end(), operator_node->getSymbol()->getType()) !=
allowed_operators.end()) {
operators.push_back(operator_node->getSymbol()->getType());
}
@ -100,10 +97,9 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
* expression7.
*/
template <typename TExpression>
Expression *LeftAssociativeOperatorExpression(
std::vector<TExpression *> _expressions,
std::vector<antlr4::tree::ParseTree *> all_children,
const std::vector<size_t> &allowed_operators) {
Expression *LeftAssociativeOperatorExpression(std::vector<TExpression *> _expressions,
std::vector<antlr4::tree::ParseTree *> all_children,
const std::vector<size_t> &allowed_operators) {
DMG_ASSERT(_expressions.size(), "can't happen");
std::vector<Expression *> expressions;
auto operators = ExtractOperators(all_children, allowed_operators);
@ -114,17 +110,14 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
Expression *first_operand = expressions[0];
for (int i = 1; i < (int)expressions.size(); ++i) {
first_operand = CreateBinaryOperatorByToken(
operators[i - 1], first_operand, expressions[i]);
first_operand = CreateBinaryOperatorByToken(operators[i - 1], first_operand, expressions[i]);
}
return first_operand;
}
template <typename TExpression>
Expression *PrefixUnaryOperator(
TExpression *_expression,
std::vector<antlr4::tree::ParseTree *> all_children,
const std::vector<size_t> &allowed_operators) {
Expression *PrefixUnaryOperator(TExpression *_expression, std::vector<antlr4::tree::ParseTree *> all_children,
const std::vector<size_t> &allowed_operators) {
DMG_ASSERT(_expression, "can't happen");
auto operators = ExtractOperators(all_children, allowed_operators);
@ -138,26 +131,22 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
/**
* @return CypherQuery*
*/
antlrcpp::Any visitCypherQuery(
MemgraphCypher::CypherQueryContext *ctx) override;
antlrcpp::Any visitCypherQuery(MemgraphCypher::CypherQueryContext *ctx) override;
/**
* @return IndexQuery*
*/
antlrcpp::Any visitIndexQuery(
MemgraphCypher::IndexQueryContext *ctx) override;
antlrcpp::Any visitIndexQuery(MemgraphCypher::IndexQueryContext *ctx) override;
/**
* @return ExplainQuery*
*/
antlrcpp::Any visitExplainQuery(
MemgraphCypher::ExplainQueryContext *ctx) override;
antlrcpp::Any visitExplainQuery(MemgraphCypher::ExplainQueryContext *ctx) override;
/**
* @return ProfileQuery*
*/
antlrcpp::Any visitProfileQuery(
MemgraphCypher::ProfileQueryContext *ctx) override;
antlrcpp::Any visitProfileQuery(MemgraphCypher::ProfileQueryContext *ctx) override;
/**
* @return InfoQuery*
@ -167,14 +156,12 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
/**
* @return Constraint
*/
antlrcpp::Any visitConstraint(
MemgraphCypher::ConstraintContext *ctx) override;
antlrcpp::Any visitConstraint(MemgraphCypher::ConstraintContext *ctx) override;
/**
* @return ConstraintQuery*
*/
antlrcpp::Any visitConstraintQuery(
MemgraphCypher::ConstraintQueryContext *ctx) override;
antlrcpp::Any visitConstraintQuery(MemgraphCypher::ConstraintQueryContext *ctx) override;
/**
* @return AuthQuery*
@ -189,56 +176,47 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
/**
* @return ReplicationQuery*
*/
antlrcpp::Any visitReplicationQuery(
MemgraphCypher::ReplicationQueryContext *ctx) override;
antlrcpp::Any visitReplicationQuery(MemgraphCypher::ReplicationQueryContext *ctx) override;
/**
* @return ReplicationQuery*
*/
antlrcpp::Any visitSetReplicationRole(
MemgraphCypher::SetReplicationRoleContext *ctx) override;
antlrcpp::Any visitSetReplicationRole(MemgraphCypher::SetReplicationRoleContext *ctx) override;
/**
* @return ReplicationQuery*
*/
antlrcpp::Any visitShowReplicationRole(
MemgraphCypher::ShowReplicationRoleContext *ctx) override;
antlrcpp::Any visitShowReplicationRole(MemgraphCypher::ShowReplicationRoleContext *ctx) override;
/**
* @return ReplicationQuery*
*/
antlrcpp::Any visitRegisterReplica(
MemgraphCypher::RegisterReplicaContext *ctx) override;
antlrcpp::Any visitRegisterReplica(MemgraphCypher::RegisterReplicaContext *ctx) override;
/**
* @return ReplicationQuery*
*/
antlrcpp::Any visitDropReplica(
MemgraphCypher::DropReplicaContext *ctx) override;
antlrcpp::Any visitDropReplica(MemgraphCypher::DropReplicaContext *ctx) override;
/**
* @return ReplicationQuery*
*/
antlrcpp::Any visitShowReplicas(
MemgraphCypher::ShowReplicasContext *ctx) override;
antlrcpp::Any visitShowReplicas(MemgraphCypher::ShowReplicasContext *ctx) override;
/**
* @return LockPathQuery*
*/
antlrcpp::Any visitLockPathQuery(
MemgraphCypher::LockPathQueryContext *ctx) override;
antlrcpp::Any visitLockPathQuery(MemgraphCypher::LockPathQueryContext *ctx) override;
/**
* @return CypherUnion*
*/
antlrcpp::Any visitCypherUnion(
MemgraphCypher::CypherUnionContext *ctx) override;
antlrcpp::Any visitCypherUnion(MemgraphCypher::CypherUnionContext *ctx) override;
/**
* @return SingleQuery*
*/
antlrcpp::Any visitSingleQuery(
MemgraphCypher::SingleQueryContext *ctx) override;
antlrcpp::Any visitSingleQuery(MemgraphCypher::SingleQueryContext *ctx) override;
/**
* @return Clause* or vector<Clause*>!!!
@ -248,8 +226,7 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
/**
* @return Match*
*/
antlrcpp::Any visitCypherMatch(
MemgraphCypher::CypherMatchContext *ctx) override;
antlrcpp::Any visitCypherMatch(MemgraphCypher::CypherMatchContext *ctx) override;
/**
* @return Create*
@ -259,20 +236,17 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
/**
* @return CallProcedure*
*/
antlrcpp::Any visitCallProcedure(
MemgraphCypher::CallProcedureContext *ctx) override;
antlrcpp::Any visitCallProcedure(MemgraphCypher::CallProcedureContext *ctx) override;
/**
* @return std::string
*/
antlrcpp::Any visitUserOrRoleName(
MemgraphCypher::UserOrRoleNameContext *ctx) override;
antlrcpp::Any visitUserOrRoleName(MemgraphCypher::UserOrRoleNameContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitCreateRole(
MemgraphCypher::CreateRoleContext *ctx) override;
antlrcpp::Any visitCreateRole(MemgraphCypher::CreateRoleContext *ctx) override;
/**
* @return AuthQuery*
@ -287,8 +261,7 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
/**
* @return IndexQuery*
*/
antlrcpp::Any visitCreateIndex(
MemgraphCypher::CreateIndexContext *ctx) override;
antlrcpp::Any visitCreateIndex(MemgraphCypher::CreateIndexContext *ctx) override;
/**
* @return DropIndex*
@ -298,14 +271,12 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
/**
* @return AuthQuery*
*/
antlrcpp::Any visitCreateUser(
MemgraphCypher::CreateUserContext *ctx) override;
antlrcpp::Any visitCreateUser(MemgraphCypher::CreateUserContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitSetPassword(
MemgraphCypher::SetPasswordContext *ctx) override;
antlrcpp::Any visitSetPassword(MemgraphCypher::SetPasswordContext *ctx) override;
/**
* @return AuthQuery*
@ -330,20 +301,17 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
/**
* @return AuthQuery*
*/
antlrcpp::Any visitGrantPrivilege(
MemgraphCypher::GrantPrivilegeContext *ctx) override;
antlrcpp::Any visitGrantPrivilege(MemgraphCypher::GrantPrivilegeContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitDenyPrivilege(
MemgraphCypher::DenyPrivilegeContext *ctx) override;
antlrcpp::Any visitDenyPrivilege(MemgraphCypher::DenyPrivilegeContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitRevokePrivilege(
MemgraphCypher::RevokePrivilegeContext *ctx) override;
antlrcpp::Any visitRevokePrivilege(MemgraphCypher::RevokePrivilegeContext *ctx) override;
/**
* @return AuthQuery::Privilege
@ -353,46 +321,39 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
/**
* @return AuthQuery*
*/
antlrcpp::Any visitShowPrivileges(
MemgraphCypher::ShowPrivilegesContext *ctx) override;
antlrcpp::Any visitShowPrivileges(MemgraphCypher::ShowPrivilegesContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitShowRoleForUser(
MemgraphCypher::ShowRoleForUserContext *ctx) override;
antlrcpp::Any visitShowRoleForUser(MemgraphCypher::ShowRoleForUserContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitShowUsersForRole(
MemgraphCypher::ShowUsersForRoleContext *ctx) override;
antlrcpp::Any visitShowUsersForRole(MemgraphCypher::ShowUsersForRoleContext *ctx) override;
/**
* @return Return*
*/
antlrcpp::Any visitCypherReturn(
MemgraphCypher::CypherReturnContext *ctx) override;
antlrcpp::Any visitCypherReturn(MemgraphCypher::CypherReturnContext *ctx) override;
/**
* @return Return*
*/
antlrcpp::Any visitReturnBody(
MemgraphCypher::ReturnBodyContext *ctx) override;
antlrcpp::Any visitReturnBody(MemgraphCypher::ReturnBodyContext *ctx) override;
/**
* @return pair<bool, vector<NamedExpression*>> first member is true if
* asterisk was found in return
* expressions.
*/
antlrcpp::Any visitReturnItems(
MemgraphCypher::ReturnItemsContext *ctx) override;
antlrcpp::Any visitReturnItems(MemgraphCypher::ReturnItemsContext *ctx) override;
/**
* @return vector<NamedExpression*>
*/
antlrcpp::Any visitReturnItem(
MemgraphCypher::ReturnItemContext *ctx) override;
antlrcpp::Any visitReturnItem(MemgraphCypher::ReturnItemContext *ctx) override;
/**
* @return vector<SortItem>
@ -407,44 +368,37 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
/**
* @return NodeAtom*
*/
antlrcpp::Any visitNodePattern(
MemgraphCypher::NodePatternContext *ctx) override;
antlrcpp::Any visitNodePattern(MemgraphCypher::NodePatternContext *ctx) override;
/**
* @return vector<LabelIx>
*/
antlrcpp::Any visitNodeLabels(
MemgraphCypher::NodeLabelsContext *ctx) override;
antlrcpp::Any visitNodeLabels(MemgraphCypher::NodeLabelsContext *ctx) override;
/**
* @return unordered_map<PropertyIx, Expression*>
*/
antlrcpp::Any visitProperties(
MemgraphCypher::PropertiesContext *ctx) override;
antlrcpp::Any visitProperties(MemgraphCypher::PropertiesContext *ctx) override;
/**
* @return map<std::string, Expression*>
*/
antlrcpp::Any visitMapLiteral(
MemgraphCypher::MapLiteralContext *ctx) override;
antlrcpp::Any visitMapLiteral(MemgraphCypher::MapLiteralContext *ctx) override;
/**
* @return vector<Expression*>
*/
antlrcpp::Any visitListLiteral(
MemgraphCypher::ListLiteralContext *ctx) override;
antlrcpp::Any visitListLiteral(MemgraphCypher::ListLiteralContext *ctx) override;
/**
* @return PropertyIx
*/
antlrcpp::Any visitPropertyKeyName(
MemgraphCypher::PropertyKeyNameContext *ctx) override;
antlrcpp::Any visitPropertyKeyName(MemgraphCypher::PropertyKeyNameContext *ctx) override;
/**
* @return string
*/
antlrcpp::Any visitSymbolicName(
MemgraphCypher::SymbolicNameContext *ctx) override;
antlrcpp::Any visitSymbolicName(MemgraphCypher::SymbolicNameContext *ctx) override;
/**
* @return vector<Pattern*>
@ -454,185 +408,160 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
/**
* @return Pattern*
*/
antlrcpp::Any visitPatternPart(
MemgraphCypher::PatternPartContext *ctx) override;
antlrcpp::Any visitPatternPart(MemgraphCypher::PatternPartContext *ctx) override;
/**
* @return Pattern*
*/
antlrcpp::Any visitPatternElement(
MemgraphCypher::PatternElementContext *ctx) override;
antlrcpp::Any visitPatternElement(MemgraphCypher::PatternElementContext *ctx) override;
/**
* @return vector<pair<EdgeAtom*, NodeAtom*>>
*/
antlrcpp::Any visitPatternElementChain(
MemgraphCypher::PatternElementChainContext *ctx) override;
antlrcpp::Any visitPatternElementChain(MemgraphCypher::PatternElementChainContext *ctx) override;
/**
*@return EdgeAtom*
*/
antlrcpp::Any visitRelationshipPattern(
MemgraphCypher::RelationshipPatternContext *ctx) override;
antlrcpp::Any visitRelationshipPattern(MemgraphCypher::RelationshipPatternContext *ctx) override;
/**
* This should never be called. Everything is done directly in
* visitRelationshipPattern.
*/
antlrcpp::Any visitRelationshipDetail(
MemgraphCypher::RelationshipDetailContext *ctx) override;
antlrcpp::Any visitRelationshipDetail(MemgraphCypher::RelationshipDetailContext *ctx) override;
/**
* This should never be called. Everything is done directly in
* visitRelationshipPattern.
*/
antlrcpp::Any visitRelationshipLambda(
MemgraphCypher::RelationshipLambdaContext *ctx) override;
antlrcpp::Any visitRelationshipLambda(MemgraphCypher::RelationshipLambdaContext *ctx) override;
/**
* @return vector<EdgeTypeIx>
*/
antlrcpp::Any visitRelationshipTypes(
MemgraphCypher::RelationshipTypesContext *ctx) override;
antlrcpp::Any visitRelationshipTypes(MemgraphCypher::RelationshipTypesContext *ctx) override;
/**
* @return std::tuple<EdgeAtom::Type, int64_t, int64_t>.
*/
antlrcpp::Any visitVariableExpansion(
MemgraphCypher::VariableExpansionContext *ctx) override;
antlrcpp::Any visitVariableExpansion(MemgraphCypher::VariableExpansionContext *ctx) override;
/**
* Top level expression, does nothing.
*
* @return Expression*
*/
antlrcpp::Any visitExpression(
MemgraphCypher::ExpressionContext *ctx) override;
antlrcpp::Any visitExpression(MemgraphCypher::ExpressionContext *ctx) override;
/**
* OR.
*
* @return Expression*
*/
antlrcpp::Any visitExpression12(
MemgraphCypher::Expression12Context *ctx) override;
antlrcpp::Any visitExpression12(MemgraphCypher::Expression12Context *ctx) override;
/**
* XOR.
*
* @return Expression*
*/
antlrcpp::Any visitExpression11(
MemgraphCypher::Expression11Context *ctx) override;
antlrcpp::Any visitExpression11(MemgraphCypher::Expression11Context *ctx) override;
/**
* AND.
*
* @return Expression*
*/
antlrcpp::Any visitExpression10(
MemgraphCypher::Expression10Context *ctx) override;
antlrcpp::Any visitExpression10(MemgraphCypher::Expression10Context *ctx) override;
/**
* NOT.
*
* @return Expression*
*/
antlrcpp::Any visitExpression9(
MemgraphCypher::Expression9Context *ctx) override;
antlrcpp::Any visitExpression9(MemgraphCypher::Expression9Context *ctx) override;
/**
* Comparisons.
*
* @return Expression*
*/
antlrcpp::Any visitExpression8(
MemgraphCypher::Expression8Context *ctx) override;
antlrcpp::Any visitExpression8(MemgraphCypher::Expression8Context *ctx) override;
/**
* Never call this. Everything related to generating code for comparison
* operators should be done in visitExpression8.
*/
antlrcpp::Any visitPartialComparisonExpression(
MemgraphCypher::PartialComparisonExpressionContext *ctx) override;
antlrcpp::Any visitPartialComparisonExpression(MemgraphCypher::PartialComparisonExpressionContext *ctx) override;
/**
* Addition and subtraction.
*
* @return Expression*
*/
antlrcpp::Any visitExpression7(
MemgraphCypher::Expression7Context *ctx) override;
antlrcpp::Any visitExpression7(MemgraphCypher::Expression7Context *ctx) override;
/**
* Multiplication, division, modding.
*
* @return Expression*
*/
antlrcpp::Any visitExpression6(
MemgraphCypher::Expression6Context *ctx) override;
antlrcpp::Any visitExpression6(MemgraphCypher::Expression6Context *ctx) override;
/**
* Power.
*
* @return Expression*
*/
antlrcpp::Any visitExpression5(
MemgraphCypher::Expression5Context *ctx) override;
antlrcpp::Any visitExpression5(MemgraphCypher::Expression5Context *ctx) override;
/**
* Unary minus and plus.
*
* @return Expression*
*/
antlrcpp::Any visitExpression4(
MemgraphCypher::Expression4Context *ctx) override;
antlrcpp::Any visitExpression4(MemgraphCypher::Expression4Context *ctx) override;
/**
* IS NULL, IS NOT NULL, STARTS WITH, END WITH, =~, ...
*
* @return Expression*
*/
antlrcpp::Any visitExpression3a(
MemgraphCypher::Expression3aContext *ctx) override;
antlrcpp::Any visitExpression3a(MemgraphCypher::Expression3aContext *ctx) override;
/**
* Does nothing, everything is done in visitExpression3a.
*
* @return Expression*
*/
antlrcpp::Any visitStringAndNullOperators(
MemgraphCypher::StringAndNullOperatorsContext *ctx) override;
antlrcpp::Any visitStringAndNullOperators(MemgraphCypher::StringAndNullOperatorsContext *ctx) override;
/**
* List indexing and slicing.
*
* @return Expression*
*/
antlrcpp::Any visitExpression3b(
MemgraphCypher::Expression3bContext *ctx) override;
antlrcpp::Any visitExpression3b(MemgraphCypher::Expression3bContext *ctx) override;
/**
* Does nothing, everything is done in visitExpression3b.
*/
antlrcpp::Any visitListIndexingOrSlicing(
MemgraphCypher::ListIndexingOrSlicingContext *ctx) override;
antlrcpp::Any visitListIndexingOrSlicing(MemgraphCypher::ListIndexingOrSlicingContext *ctx) override;
/**
* Node labels test.
*
* @return Expression*
*/
antlrcpp::Any visitExpression2a(
MemgraphCypher::Expression2aContext *ctx) override;
antlrcpp::Any visitExpression2a(MemgraphCypher::Expression2aContext *ctx) override;
/**
* Property lookup.
*
* @return Expression*
*/
antlrcpp::Any visitExpression2b(
MemgraphCypher::Expression2bContext *ctx) override;
antlrcpp::Any visitExpression2b(MemgraphCypher::Expression2bContext *ctx) override;
/**
* Literals, params, list comprehension...
@ -649,20 +578,17 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
/**
* @return Expression*
*/
antlrcpp::Any visitParenthesizedExpression(
MemgraphCypher::ParenthesizedExpressionContext *ctx) override;
antlrcpp::Any visitParenthesizedExpression(MemgraphCypher::ParenthesizedExpressionContext *ctx) override;
/**
* @return Expression*
*/
antlrcpp::Any visitFunctionInvocation(
MemgraphCypher::FunctionInvocationContext *ctx) override;
antlrcpp::Any visitFunctionInvocation(MemgraphCypher::FunctionInvocationContext *ctx) override;
/**
* @return string - uppercased
*/
antlrcpp::Any visitFunctionName(
MemgraphCypher::FunctionNameContext *ctx) override;
antlrcpp::Any visitFunctionName(MemgraphCypher::FunctionNameContext *ctx) override;
/**
* @return Expression*
@ -679,32 +605,27 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
/**
* @return bool
*/
antlrcpp::Any visitBooleanLiteral(
MemgraphCypher::BooleanLiteralContext *ctx) override;
antlrcpp::Any visitBooleanLiteral(MemgraphCypher::BooleanLiteralContext *ctx) override;
/**
* @return TypedValue with either double or int
*/
antlrcpp::Any visitNumberLiteral(
MemgraphCypher::NumberLiteralContext *ctx) override;
antlrcpp::Any visitNumberLiteral(MemgraphCypher::NumberLiteralContext *ctx) override;
/**
* @return int64_t
*/
antlrcpp::Any visitIntegerLiteral(
MemgraphCypher::IntegerLiteralContext *ctx) override;
antlrcpp::Any visitIntegerLiteral(MemgraphCypher::IntegerLiteralContext *ctx) override;
/**
* @return double
*/
antlrcpp::Any visitDoubleLiteral(
MemgraphCypher::DoubleLiteralContext *ctx) override;
antlrcpp::Any visitDoubleLiteral(MemgraphCypher::DoubleLiteralContext *ctx) override;
/**
* @return Delete*
*/
antlrcpp::Any visitCypherDelete(
MemgraphCypher::CypherDeleteContext *ctx) override;
antlrcpp::Any visitCypherDelete(MemgraphCypher::CypherDeleteContext *ctx) override;
/**
* @return Where*
@ -729,27 +650,23 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
/**
* @return Clause*
*/
antlrcpp::Any visitRemoveItem(
MemgraphCypher::RemoveItemContext *ctx) override;
antlrcpp::Any visitRemoveItem(MemgraphCypher::RemoveItemContext *ctx) override;
/**
* @return PropertyLookup*
*/
antlrcpp::Any visitPropertyExpression(
MemgraphCypher::PropertyExpressionContext *ctx) override;
antlrcpp::Any visitPropertyExpression(MemgraphCypher::PropertyExpressionContext *ctx) override;
/**
* @return IfOperator*
*/
antlrcpp::Any visitCaseExpression(
MemgraphCypher::CaseExpressionContext *ctx) override;
antlrcpp::Any visitCaseExpression(MemgraphCypher::CaseExpressionContext *ctx) override;
/**
* Never call this. Ast generation for this production is done in
* @c visitCaseExpression.
*/
antlrcpp::Any visitCaseAlternatives(
MemgraphCypher::CaseAlternativesContext *ctx) override;
antlrcpp::Any visitCaseAlternatives(MemgraphCypher::CaseAlternativesContext *ctx) override;
/**
* @return With*
@ -770,8 +687,7 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
* Never call this. Ast generation for these expressions should be done by
* explicitly visiting the members of @c FilterExpressionContext.
*/
antlrcpp::Any visitFilterExpression(
MemgraphCypher::FilterExpressionContext *) override;
antlrcpp::Any visitFilterExpression(MemgraphCypher::FilterExpressionContext *) override;
public:
Query *query() { return query_; }

View File

@ -89,22 +89,17 @@ void PrintObject(std::ostream *out, const std::map<K, V> &map);
template <typename T>
void PrintObject(std::ostream *out, const T &arg) {
static_assert(
!std::is_convertible<T, Expression *>::value,
"This overload shouldn't be called with pointers convertible "
"to Expression *. This means your other PrintObject overloads aren't "
"being called for certain AST nodes when they should (or perhaps such "
"overloads don't exist yet).");
static_assert(!std::is_convertible<T, Expression *>::value,
"This overload shouldn't be called with pointers convertible "
"to Expression *. This means your other PrintObject overloads aren't "
"being called for certain AST nodes when they should (or perhaps such "
"overloads don't exist yet).");
*out << arg;
}
void PrintObject(std::ostream *out, const std::string &str) {
*out << utils::Escape(str);
}
void PrintObject(std::ostream *out, const std::string &str) { *out << utils::Escape(str); }
void PrintObject(std::ostream *out, Aggregation::Op op) {
*out << Aggregation::OpToString(op);
}
void PrintObject(std::ostream *out, Aggregation::Op op) { *out << Aggregation::OpToString(op); }
void PrintObject(std::ostream *out, Expression *expr) {
if (expr) {
@ -115,9 +110,7 @@ void PrintObject(std::ostream *out, Expression *expr) {
}
}
void PrintObject(std::ostream *out, Identifier *expr) {
PrintObject(out, static_cast<Expression *>(expr));
}
void PrintObject(std::ostream *out, Identifier *expr) { PrintObject(out, static_cast<Expression *>(expr)); }
void PrintObject(std::ostream *out, const storage::PropertyValue &value) {
switch (value.type()) {
@ -154,9 +147,7 @@ void PrintObject(std::ostream *out, const storage::PropertyValue &value) {
template <typename T>
void PrintObject(std::ostream *out, const std::vector<T> &vec) {
*out << "[";
utils::PrintIterable(*out, vec, ", ", [](auto &stream, const auto &item) {
PrintObject(&stream, item);
});
utils::PrintIterable(*out, vec, ", ", [](auto &stream, const auto &item) { PrintObject(&stream, item); });
*out << "]";
}
@ -179,26 +170,22 @@ void PrintOperatorArgs(std::ostream *out, const T &arg) {
}
template <typename T, typename... Ts>
void PrintOperatorArgs(std::ostream *out, const T &arg, const Ts &... args) {
void PrintOperatorArgs(std::ostream *out, const T &arg, const Ts &...args) {
*out << " ";
PrintObject(out, arg);
PrintOperatorArgs(out, args...);
}
template <typename... Ts>
void PrintOperator(std::ostream *out, const std::string &name,
const Ts &... args) {
void PrintOperator(std::ostream *out, const std::string &name, const Ts &...args) {
*out << "(" << name;
PrintOperatorArgs(out, args...);
}
ExpressionPrettyPrinter::ExpressionPrettyPrinter(std::ostream *out)
: out_(out) {}
ExpressionPrettyPrinter::ExpressionPrettyPrinter(std::ostream *out) : out_(out) {}
#define UNARY_OPERATOR_VISIT(OP_NODE, OP_STR) \
void ExpressionPrettyPrinter::Visit(OP_NODE &op) { \
PrintOperator(out_, OP_STR, op.expression_); \
}
#define UNARY_OPERATOR_VISIT(OP_NODE, OP_STR) \
void ExpressionPrettyPrinter::Visit(OP_NODE &op) { PrintOperator(out_, OP_STR, op.expression_); }
UNARY_OPERATOR_VISIT(NotOperator, "Not");
UNARY_OPERATOR_VISIT(UnaryPlusOperator, "+");
@ -207,10 +194,8 @@ UNARY_OPERATOR_VISIT(IsNullOperator, "IsNull");
#undef UNARY_OPERATOR_VISIT
#define BINARY_OPERATOR_VISIT(OP_NODE, OP_STR) \
void ExpressionPrettyPrinter::Visit(OP_NODE &op) { \
PrintOperator(out_, OP_STR, op.expression1_, op.expression2_); \
}
#define BINARY_OPERATOR_VISIT(OP_NODE, OP_STR) \
void ExpressionPrettyPrinter::Visit(OP_NODE &op) { PrintOperator(out_, OP_STR, op.expression1_, op.expression2_); }
BINARY_OPERATOR_VISIT(OrOperator, "Or");
BINARY_OPERATOR_VISIT(XorOperator, "Xor");
@ -232,18 +217,14 @@ BINARY_OPERATOR_VISIT(SubscriptOperator, "Subscript");
#undef BINARY_OPERATOR_VISIT
void ExpressionPrettyPrinter::Visit(ListSlicingOperator &op) {
PrintOperator(out_, "ListSlicing", op.list_, op.lower_bound_,
op.upper_bound_);
PrintOperator(out_, "ListSlicing", op.list_, op.lower_bound_, op.upper_bound_);
}
void ExpressionPrettyPrinter::Visit(IfOperator &op) {
PrintOperator(out_, "If", op.condition_, op.then_expression_,
op.else_expression_);
PrintOperator(out_, "If", op.condition_, op.then_expression_, op.else_expression_);
}
void ExpressionPrettyPrinter::Visit(ListLiteral &op) {
PrintOperator(out_, "ListLiteral", op.elements_);
}
void ExpressionPrettyPrinter::Visit(ListLiteral &op) { PrintOperator(out_, "ListLiteral", op.elements_); }
void ExpressionPrettyPrinter::Visit(MapLiteral &op) {
std::map<std::string, Expression *> map;
@ -253,74 +234,53 @@ void ExpressionPrettyPrinter::Visit(MapLiteral &op) {
PrintObject(out_, map);
}
void ExpressionPrettyPrinter::Visit(LabelsTest &op) {
PrintOperator(out_, "LabelsTest", op.expression_);
}
void ExpressionPrettyPrinter::Visit(LabelsTest &op) { PrintOperator(out_, "LabelsTest", op.expression_); }
void ExpressionPrettyPrinter::Visit(Aggregation &op) {
PrintOperator(out_, "Aggregation", op.op_);
}
void ExpressionPrettyPrinter::Visit(Aggregation &op) { PrintOperator(out_, "Aggregation", op.op_); }
void ExpressionPrettyPrinter::Visit(Function &op) {
PrintOperator(out_, "Function", op.function_name_, op.arguments_);
}
void ExpressionPrettyPrinter::Visit(Function &op) { PrintOperator(out_, "Function", op.function_name_, op.arguments_); }
void ExpressionPrettyPrinter::Visit(Reduce &op) {
PrintOperator(out_, "Reduce", op.accumulator_, op.initializer_,
op.identifier_, op.list_, op.expression_);
PrintOperator(out_, "Reduce", op.accumulator_, op.initializer_, op.identifier_, op.list_, op.expression_);
}
void ExpressionPrettyPrinter::Visit(Coalesce &op) {
PrintOperator(out_, "Coalesce", op.expressions_);
}
void ExpressionPrettyPrinter::Visit(Coalesce &op) { PrintOperator(out_, "Coalesce", op.expressions_); }
void ExpressionPrettyPrinter::Visit(Extract &op) {
PrintOperator(out_, "Extract", op.identifier_, op.list_, op.expression_);
}
void ExpressionPrettyPrinter::Visit(All &op) {
PrintOperator(out_, "All", op.identifier_, op.list_expression_,
op.where_->expression_);
PrintOperator(out_, "All", op.identifier_, op.list_expression_, op.where_->expression_);
}
void ExpressionPrettyPrinter::Visit(Single &op) {
PrintOperator(out_, "Single", op.identifier_, op.list_expression_,
op.where_->expression_);
PrintOperator(out_, "Single", op.identifier_, op.list_expression_, op.where_->expression_);
}
void ExpressionPrettyPrinter::Visit(Any &op) {
PrintOperator(out_, "Any", op.identifier_, op.list_expression_,
op.where_->expression_);
PrintOperator(out_, "Any", op.identifier_, op.list_expression_, op.where_->expression_);
}
void ExpressionPrettyPrinter::Visit(None &op) {
PrintOperator(out_, "None", op.identifier_, op.list_expression_,
op.where_->expression_);
PrintOperator(out_, "None", op.identifier_, op.list_expression_, op.where_->expression_);
}
void ExpressionPrettyPrinter::Visit(Identifier &op) {
PrintOperator(out_, "Identifier", op.name_);
}
void ExpressionPrettyPrinter::Visit(Identifier &op) { PrintOperator(out_, "Identifier", op.name_); }
void ExpressionPrettyPrinter::Visit(PrimitiveLiteral &op) {
PrintObject(out_, op.value_);
}
void ExpressionPrettyPrinter::Visit(PrimitiveLiteral &op) { PrintObject(out_, op.value_); }
void ExpressionPrettyPrinter::Visit(PropertyLookup &op) {
PrintOperator(out_, "PropertyLookup", op.expression_, op.property_.name);
}
void ExpressionPrettyPrinter::Visit(ParameterLookup &op) {
PrintOperator(out_, "ParameterLookup", op.token_position_);
}
void ExpressionPrettyPrinter::Visit(ParameterLookup &op) { PrintOperator(out_, "ParameterLookup", op.token_position_); }
void ExpressionPrettyPrinter::Visit(NamedExpression &op) {
PrintOperator(out_, "NamedExpression", op.name_, op.expression_);
}
void ExpressionPrettyPrinter::Visit(RegexMatch &op) {
PrintOperator(out_, "=~", op.string_expr_, op.regex_);
}
void ExpressionPrettyPrinter::Visit(RegexMatch &op) { PrintOperator(out_, "=~", op.string_expr_, op.regex_); }
} // namespace

View File

@ -35,12 +35,10 @@ class Parser {
private:
class FirstMessageErrorListener : public antlr4::BaseErrorListener {
void syntaxError(antlr4::IRecognizer *, antlr4::Token *, size_t line,
size_t position, const std::string &message,
void syntaxError(antlr4::IRecognizer *, antlr4::Token *, size_t line, size_t position, const std::string &message,
std::exception_ptr) override {
if (error_.empty()) {
error_ = "line " + std::to_string(line) + ":" +
std::to_string(position + 1) + " " + message;
error_ = "line " + std::to_string(line) + ":" + std::to_string(position + 1) + " " + message;
}
}

View File

@ -26,8 +26,7 @@ std::string ParseStringLiteral(const std::string &s) {
auto EncodeEscapedUnicodeCodepointUtf32 = [](const std::string &s, int &i) {
const int kLongUnicodeLength = 8;
int j = i + 1;
while (j < static_cast<int>(s.size()) - 1 &&
j < i + kLongUnicodeLength + 1 && isxdigit(s[j])) {
while (j < static_cast<int>(s.size()) - 1 && j < i + kLongUnicodeLength + 1 && isxdigit(s[j])) {
++j;
}
if (j - i == kLongUnicodeLength + 1) {
@ -43,8 +42,7 @@ std::string ParseStringLiteral(const std::string &s) {
auto EncodeEscapedUnicodeCodepointUtf16 = [](const std::string &s, int &i) {
const int kShortUnicodeLength = 4;
int j = i + 1;
while (j < static_cast<int>(s.size()) - 1 &&
j < i + kShortUnicodeLength + 1 && isxdigit(s[j])) {
while (j < static_cast<int>(s.size()) - 1 && j < i + kShortUnicodeLength + 1 && isxdigit(s[j])) {
++j;
}
if (j - i >= kShortUnicodeLength + 1) {
@ -56,31 +54,24 @@ std::string ParseStringLiteral(const std::string &s) {
throw SemanticException("Invalid UTF codepoint.");
}
++j;
if (j >= static_cast<int>(s.size()) - 1 ||
(s[j] != 'u' && s[j] != 'U')) {
if (j >= static_cast<int>(s.size()) - 1 || (s[j] != 'u' && s[j] != 'U')) {
throw SemanticException("Invalid UTF codepoint.");
}
++j;
int k = j;
while (k < static_cast<int>(s.size()) - 1 &&
k < j + kShortUnicodeLength && isxdigit(s[k])) {
while (k < static_cast<int>(s.size()) - 1 && k < j + kShortUnicodeLength && isxdigit(s[k])) {
++k;
}
if (k != j + kShortUnicodeLength) {
throw SemanticException("Invalid UTF codepoint.");
}
char16_t surrogates[3] = {t,
static_cast<char16_t>(stoi(
s.substr(j, kShortUnicodeLength), 0, 16)),
0};
char16_t surrogates[3] = {t, static_cast<char16_t>(stoi(s.substr(j, kShortUnicodeLength), 0, 16)), 0};
i += kShortUnicodeLength + 2 + kShortUnicodeLength;
std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t>
converter;
std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t> converter;
return converter.to_bytes(surrogates);
} else {
i += kShortUnicodeLength;
std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t>
converter;
std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t> converter;
return converter.to_bytes(t);
}
}
@ -168,8 +159,7 @@ std::string ParseParameter(const std::string &s) {
if (s[1] != '`') return s.substr(1);
// If parameter name is escaped symbolic name then symbolic name should be
// unescaped and leading and trailing backquote should be removed.
DMG_ASSERT(s.size() > 3U && s.back() == '`',
"Invalid string passed as parameter name");
DMG_ASSERT(s.size() > 3U && s.back() == '`', "Invalid string passed as parameter name");
std::string out;
for (int i = 2; i < static_cast<int>(s.size()) - 1; ++i) {
if (s[i] == '`') {

View File

@ -2,8 +2,7 @@
namespace query {
class PrivilegeExtractor : public QueryVisitor<void>,
public HierarchicalTreeVisitor {
class PrivilegeExtractor : public QueryVisitor<void>, public HierarchicalTreeVisitor {
public:
using HierarchicalTreeVisitor::PostVisit;
using HierarchicalTreeVisitor::PreVisit;
@ -12,19 +11,13 @@ class PrivilegeExtractor : public QueryVisitor<void>,
std::vector<AuthQuery::Privilege> privileges() { return privileges_; }
void Visit(IndexQuery &) override {
AddPrivilege(AuthQuery::Privilege::INDEX);
}
void Visit(IndexQuery &) override { AddPrivilege(AuthQuery::Privilege::INDEX); }
void Visit(AuthQuery &) override { AddPrivilege(AuthQuery::Privilege::AUTH); }
void Visit(ExplainQuery &query) override {
query.cypher_query_->Accept(*this);
}
void Visit(ExplainQuery &query) override { query.cypher_query_->Accept(*this); }
void Visit(ProfileQuery &query) override {
query.cypher_query_->Accept(*this);
}
void Visit(ProfileQuery &query) override { query.cypher_query_->Accept(*this); }
void Visit(InfoQuery &info_query) override {
switch (info_query.info_type_) {
@ -44,9 +37,7 @@ class PrivilegeExtractor : public QueryVisitor<void>,
}
}
void Visit(ConstraintQuery &constraint_query) override {
AddPrivilege(AuthQuery::Privilege::CONSTRAINT);
}
void Visit(ConstraintQuery &constraint_query) override { AddPrivilege(AuthQuery::Privilege::CONSTRAINT); }
void Visit(CypherQuery &query) override {
query.single_query_->Accept(*this);
@ -55,13 +46,9 @@ class PrivilegeExtractor : public QueryVisitor<void>,
}
}
void Visit(DumpQuery &dump_query) override {
AddPrivilege(AuthQuery::Privilege::DUMP);
}
void Visit(DumpQuery &dump_query) override { AddPrivilege(AuthQuery::Privilege::DUMP); }
void Visit(LockPathQuery &lock_path_query) override {
AddPrivilege(AuthQuery::Privilege::LOCK_PATH);
}
void Visit(LockPathQuery &lock_path_query) override { AddPrivilege(AuthQuery::Privilege::LOCK_PATH); }
void Visit(ReplicationQuery &replication_query) override {
switch (replication_query.action_) {

View File

@ -12,24 +12,19 @@
namespace query {
auto SymbolGenerator::CreateSymbol(const std::string &name, bool user_declared,
Symbol::Type type, int token_position) {
auto symbol =
symbol_table_.CreateSymbol(name, user_declared, type, token_position);
auto SymbolGenerator::CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type, int token_position) {
auto symbol = symbol_table_.CreateSymbol(name, user_declared, type, token_position);
scope_.symbols[name] = symbol;
return symbol;
}
auto SymbolGenerator::GetOrCreateSymbol(const std::string &name,
bool user_declared, Symbol::Type type) {
auto SymbolGenerator::GetOrCreateSymbol(const std::string &name, bool user_declared, Symbol::Type type) {
auto search = scope_.symbols.find(name);
if (search != scope_.symbols.end()) {
auto symbol = search->second;
// Unless we have `ANY` type, check that types match.
if (type != Symbol::Type::ANY && symbol.type() != Symbol::Type::ANY &&
type != symbol.type()) {
throw TypeMismatchError(name, Symbol::TypeToString(symbol.type()),
Symbol::TypeToString(type));
if (type != Symbol::Type::ANY && symbol.type() != Symbol::Type::ANY && type != symbol.type()) {
throw TypeMismatchError(name, Symbol::TypeToString(symbol.type()), Symbol::TypeToString(type));
}
return search->second;
}
@ -50,8 +45,7 @@ void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) {
user_symbols.emplace_back(sym_pair.second);
}
if (user_symbols.empty()) {
throw SemanticException(
"There are no variables in scope to use for '*'.");
throw SemanticException("There are no variables in scope to use for '*'.");
}
}
// WITH/RETURN clause removes declarations of all the previous variables and
@ -74,13 +68,11 @@ void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) {
for (auto &named_expr : body.named_expressions) {
const auto &name = named_expr->name_;
if (!new_names.insert(name).second) {
throw SemanticException(
"Multiple results with the same name '{}' are not allowed.", name);
throw SemanticException("Multiple results with the same name '{}' are not allowed.", name);
}
// An improvement would be to infer the type of the expression, so that the
// new symbol would have a more specific type.
named_expr->MapTo(CreateSymbol(name, true, Symbol::Type::ANY,
named_expr->token_position_));
named_expr->MapTo(CreateSymbol(name, true, Symbol::Type::ANY, named_expr->token_position_));
}
scope_.in_order_by = true;
for (const auto &order_pair : body.order_by) {
@ -102,8 +94,7 @@ void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) {
// We have an ORDER BY or WHERE, but no aggregation, which means we didn't
// clear the old symbols, so do it now. We cannot just call clear, because
// we've added new symbols.
for (auto sym_it = scope_.symbols.begin();
sym_it != scope_.symbols.end();) {
for (auto sym_it = scope_.symbols.begin(); sym_it != scope_.symbols.end();) {
if (new_names.find(sym_it->first) == new_names.end()) {
sym_it = scope_.symbols.erase(sym_it);
} else {
@ -131,8 +122,7 @@ bool SymbolGenerator::PreVisit(CypherUnion &) {
bool SymbolGenerator::PostVisit(CypherUnion &cypher_union) {
if (prev_return_names_ != curr_return_names_) {
throw SemanticException(
"All subqueries in an UNION must have the same column names.");
throw SemanticException("All subqueries in an UNION must have the same column names.");
}
// create new symbols for the result of the union
@ -180,8 +170,7 @@ bool SymbolGenerator::PreVisit(Return &ret) {
}
bool SymbolGenerator::PostVisit(Return &) {
for (const auto &name_symbol : scope_.symbols)
curr_return_names_.insert(name_symbol.first);
for (const auto &name_symbol : scope_.symbols) curr_return_names_.insert(name_symbol.first);
return true;
}
@ -239,15 +228,13 @@ bool SymbolGenerator::PostVisit(Match &) {
SymbolGenerator::ReturnType SymbolGenerator::Visit(Identifier &ident) {
if (scope_.in_skip || scope_.in_limit) {
throw SemanticException("Variables are not allowed in {}.",
scope_.in_skip ? "SKIP" : "LIMIT");
throw SemanticException("Variables are not allowed in {}.", scope_.in_skip ? "SKIP" : "LIMIT");
}
Symbol symbol;
if (scope_.in_pattern && !(scope_.in_node_atom || scope_.visiting_edge)) {
// If we are in the pattern, and outside of a node or an edge, the
// identifier is the pattern name.
symbol = GetOrCreateSymbol(ident.name_, ident.user_declared_,
Symbol::Type::PATH);
symbol = GetOrCreateSymbol(ident.name_, ident.user_declared_, Symbol::Type::PATH);
} else if (scope_.in_pattern && scope_.in_pattern_atom_identifier) {
// Patterns used to create nodes and edges cannot redeclare already
// established bindings. Declaration only happens in single node
@ -255,8 +242,7 @@ SymbolGenerator::ReturnType SymbolGenerator::Visit(Identifier &ident) {
// `MATCH (n) CREATE (n)` should throw an error that `n` is already
// declared. While `MATCH (n) CREATE (n) -[:R]-> (n)` is allowed,
// since `n` now references the bound node instead of declaring it.
if ((scope_.in_create_node || scope_.in_create_edge) &&
HasSymbol(ident.name_)) {
if ((scope_.in_create_node || scope_.in_create_edge) && HasSymbol(ident.name_)) {
throw RedeclareVariableError(ident.name_);
}
auto type = Symbol::Type::VERTEX;
@ -266,14 +252,11 @@ SymbolGenerator::ReturnType SymbolGenerator::Visit(Identifier &ident) {
if (HasSymbol(ident.name_)) {
throw RedeclareVariableError(ident.name_);
}
type = scope_.visiting_edge->IsVariable() ? Symbol::Type::EDGE_LIST
: Symbol::Type::EDGE;
type = scope_.visiting_edge->IsVariable() ? Symbol::Type::EDGE_LIST : Symbol::Type::EDGE;
}
symbol = GetOrCreateSymbol(ident.name_, ident.user_declared_, type);
} else if (scope_.in_pattern && !scope_.in_pattern_atom_identifier &&
scope_.in_match) {
if (scope_.in_edge_range &&
scope_.visiting_edge->identifier_->name_ == ident.name_) {
} else if (scope_.in_pattern && !scope_.in_pattern_atom_identifier && scope_.in_match) {
if (scope_.in_edge_range && scope_.visiting_edge->identifier_->name_ == ident.name_) {
// Prevent variable path bounds to reference the identifier which is bound
// by the variable path itself.
throw UnboundVariableError(ident.name_);
@ -295,10 +278,9 @@ bool SymbolGenerator::PreVisit(Aggregation &aggr) {
// Check if the aggregation can be used in this context. This check should
// probably move to a separate phase, which checks if the query is well
// formed.
if ((!scope_.in_return && !scope_.in_with) || scope_.in_order_by ||
scope_.in_skip || scope_.in_limit || scope_.in_where) {
throw SemanticException(
"Aggregation functions are only allowed in WITH and RETURN.");
if ((!scope_.in_return && !scope_.in_with) || scope_.in_order_by || scope_.in_skip || scope_.in_limit ||
scope_.in_where) {
throw SemanticException("Aggregation functions are only allowed in WITH and RETURN.");
}
if (scope_.in_aggregation) {
throw SemanticException(
@ -313,13 +295,11 @@ bool SymbolGenerator::PreVisit(Aggregation &aggr) {
// CASE count(n) WHEN 10 THEN "YES" ELSE "NO" END.
// TODO: Rethink of allowing aggregations in some parts of the CASE
// construct.
throw SemanticException(
"Using aggregation functions inside of CASE is not allowed.");
throw SemanticException("Using aggregation functions inside of CASE is not allowed.");
}
// Create a virtual symbol for aggregation result.
// Currently, we only have aggregation operators which return numbers.
auto aggr_name =
Aggregation::OpToString(aggr.op_) + std::to_string(aggr.symbol_pos_);
auto aggr_name = Aggregation::OpToString(aggr.op_) + std::to_string(aggr.symbol_pos_);
aggr.MapTo(CreateSymbol(aggr_name, false, Symbol::Type::NUMBER));
scope_.in_aggregation = true;
scope_.has_aggregation = true;
@ -368,8 +348,7 @@ bool SymbolGenerator::PreVisit(None &none) {
bool SymbolGenerator::PreVisit(Reduce &reduce) {
reduce.initializer_->Accept(*this);
reduce.list_->Accept(*this);
VisitWithIdentifiers(reduce.expression_,
{reduce.accumulator_, reduce.identifier_});
VisitWithIdentifiers(reduce.expression_, {reduce.accumulator_, reduce.identifier_});
return false;
}
@ -384,8 +363,7 @@ bool SymbolGenerator::PreVisit(Extract &extract) {
bool SymbolGenerator::PreVisit(Pattern &pattern) {
scope_.in_pattern = true;
if ((scope_.in_create || scope_.in_merge) && pattern.atoms_.size() == 1U) {
MG_ASSERT(utils::IsSubtype(*pattern.atoms_[0], NodeAtom::kType),
"Expected a single NodeAtom in Pattern");
MG_ASSERT(utils::IsSubtype(*pattern.atoms_[0], NodeAtom::kType), "Expected a single NodeAtom in Pattern");
scope_.in_create_node = true;
}
return true;
@ -399,14 +377,11 @@ bool SymbolGenerator::PostVisit(Pattern &) {
bool SymbolGenerator::PreVisit(NodeAtom &node_atom) {
scope_.in_node_atom = true;
bool props_or_labels =
!node_atom.properties_.empty() || !node_atom.labels_.empty();
bool props_or_labels = !node_atom.properties_.empty() || !node_atom.labels_.empty();
const auto &node_name = node_atom.identifier_->name_;
if ((scope_.in_create || scope_.in_merge) && props_or_labels &&
HasSymbol(node_name)) {
throw SemanticException(
"Cannot create node '" + node_name +
"' with labels or properties, because it is already declared.");
if ((scope_.in_create || scope_.in_merge) && props_or_labels && HasSymbol(node_name)) {
throw SemanticException("Cannot create node '" + node_name +
"' with labels or properties, because it is already declared.");
}
for (auto kv : node_atom.properties_) {
kv.second->Accept(*this);
@ -458,22 +433,19 @@ bool SymbolGenerator::PreVisit(EdgeAtom &edge_atom) {
scope_.in_pattern = false;
if (edge_atom.filter_lambda_.expression) {
VisitWithIdentifiers(edge_atom.filter_lambda_.expression,
{edge_atom.filter_lambda_.inner_edge,
edge_atom.filter_lambda_.inner_node});
{edge_atom.filter_lambda_.inner_edge, edge_atom.filter_lambda_.inner_node});
} else {
// Create inner symbols, but don't bind them in scope, since they are to
// be used in the missing filter expression.
auto *inner_edge = edge_atom.filter_lambda_.inner_edge;
inner_edge->MapTo(symbol_table_.CreateSymbol(
inner_edge->name_, inner_edge->user_declared_, Symbol::Type::EDGE));
inner_edge->MapTo(symbol_table_.CreateSymbol(inner_edge->name_, inner_edge->user_declared_, Symbol::Type::EDGE));
auto *inner_node = edge_atom.filter_lambda_.inner_node;
inner_node->MapTo(symbol_table_.CreateSymbol(
inner_node->name_, inner_node->user_declared_, Symbol::Type::VERTEX));
inner_node->MapTo(
symbol_table_.CreateSymbol(inner_node->name_, inner_node->user_declared_, Symbol::Type::VERTEX));
}
if (edge_atom.weight_lambda_.expression) {
VisitWithIdentifiers(edge_atom.weight_lambda_.expression,
{edge_atom.weight_lambda_.inner_edge,
edge_atom.weight_lambda_.inner_node});
{edge_atom.weight_lambda_.inner_edge, edge_atom.weight_lambda_.inner_node});
}
scope_.in_pattern = true;
}
@ -484,9 +456,8 @@ bool SymbolGenerator::PreVisit(EdgeAtom &edge_atom) {
if (HasSymbol(edge_atom.total_weight_->name_)) {
throw RedeclareVariableError(edge_atom.total_weight_->name_);
}
edge_atom.total_weight_->MapTo(GetOrCreateSymbol(
edge_atom.total_weight_->name_, edge_atom.total_weight_->user_declared_,
Symbol::Type::NUMBER));
edge_atom.total_weight_->MapTo(GetOrCreateSymbol(edge_atom.total_weight_->name_,
edge_atom.total_weight_->user_declared_, Symbol::Type::NUMBER));
}
return false;
}
@ -497,8 +468,7 @@ bool SymbolGenerator::PostVisit(EdgeAtom &) {
return true;
}
void SymbolGenerator::VisitWithIdentifiers(
Expression *expr, const std::vector<Identifier *> &identifiers) {
void SymbolGenerator::VisitWithIdentifiers(Expression *expr, const std::vector<Identifier *> &identifiers) {
std::vector<std::pair<std::optional<Symbol>, Identifier *>> prev_symbols;
// Collect previous symbols if they exist.
for (const auto &identifier : identifiers) {
@ -507,8 +477,7 @@ void SymbolGenerator::VisitWithIdentifiers(
if (prev_symbol_it != scope_.symbols.end()) {
prev_symbol = prev_symbol_it->second;
}
identifier->MapTo(
CreateSymbol(identifier->name_, identifier->user_declared_));
identifier->MapTo(CreateSymbol(identifier->name_, identifier->user_declared_));
prev_symbols.emplace_back(prev_symbol, identifier);
}
// Visit the expression with the new symbols bound.
@ -525,8 +494,6 @@ void SymbolGenerator::VisitWithIdentifiers(
}
}
bool SymbolGenerator::HasSymbol(const std::string &name) {
return scope_.symbols.find(name) != scope_.symbols.end();
}
bool SymbolGenerator::HasSymbol(const std::string &name) { return scope_.symbols.find(name) != scope_.symbols.end(); }
} // namespace query

View File

@ -17,8 +17,7 @@ namespace query {
/// variable types.
class SymbolGenerator : public HierarchicalTreeVisitor {
public:
explicit SymbolGenerator(SymbolTable &symbol_table)
: symbol_table_(symbol_table) {}
explicit SymbolGenerator(SymbolTable &symbol_table) : symbol_table_(symbol_table) {}
using HierarchicalTreeVisitor::PostVisit;
using HierarchicalTreeVisitor::PreVisit;
@ -117,14 +116,12 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
// Returns a freshly generated symbol. Previous mapping of the same name to a
// different symbol is replaced with the new one.
auto CreateSymbol(const std::string &name, bool user_declared,
Symbol::Type type = Symbol::Type::ANY,
auto CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY,
int token_position = -1);
// Returns the symbol by name. If the mapping already exists, checks if the
// types match. Otherwise, returns a new symbol.
auto GetOrCreateSymbol(const std::string &name, bool user_declared,
Symbol::Type type = Symbol::Type::ANY);
auto GetOrCreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY);
void VisitReturnBody(ReturnBody &body, Where *where = nullptr);

View File

@ -12,13 +12,11 @@ namespace query {
class SymbolTable final {
public:
SymbolTable() {}
const Symbol &CreateSymbol(const std::string &name, bool user_declared,
Symbol::Type type = Symbol::Type::ANY,
const Symbol &CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY,
int32_t token_position = -1) {
MG_ASSERT(table_.size() <= std::numeric_limits<int32_t>::max(),
"SymbolTable size doesn't fit into 32-bit integer!");
auto got = table_.emplace(position_, Symbol(name, position_, user_declared,
type, token_position));
auto got = table_.emplace(position_, Symbol(name, position_, user_declared, type, token_position));
MG_ASSERT(got.second, "Duplicate symbol ID!");
position_++;
return got.first->second;
@ -31,24 +29,17 @@ class SymbolTable final {
while (true) {
static const std::string &kAnonPrefix = "anon";
std::string name_candidate = kAnonPrefix + std::to_string(id++);
if (std::find_if(std::begin(table_), std::end(table_),
[&name_candidate](const auto &item) -> bool {
return item.second.name_ == name_candidate;
}) == std::end(table_)) {
if (std::find_if(std::begin(table_), std::end(table_), [&name_candidate](const auto &item) -> bool {
return item.second.name_ == name_candidate;
}) == std::end(table_)) {
return CreateSymbol(name_candidate, false, type);
}
}
}
const Symbol &at(const Identifier &ident) const {
return table_.at(ident.symbol_pos_);
}
const Symbol &at(const NamedExpression &nexpr) const {
return table_.at(nexpr.symbol_pos_);
}
const Symbol &at(const Aggregation &aggr) const {
return table_.at(aggr.symbol_pos_);
}
const Symbol &at(const Identifier &ident) const { return table_.at(ident.symbol_pos_); }
const Symbol &at(const NamedExpression &nexpr) const { return table_.at(nexpr.symbol_pos_); }
const Symbol &at(const Aggregation &aggr) const { return table_.at(aggr.symbol_pos_); }
// TODO: Remove these since members are public
int32_t max_position() const { return static_cast<int32_t>(table_.size()); }

View File

@ -64,9 +64,7 @@ StrippedQuery::StrippedQuery(const std::string &query) : original_(query) {
// A helper function that stores literal and its token position in a
// literals_. In stripped query text literal is replaced with a new_value.
// new_value can be any value that is lexed as a literal.
auto replace_stripped = [this, &token_strings](int position,
const auto &value,
const std::string &new_value) {
auto replace_stripped = [this, &token_strings](int position, const auto &value, const std::string &new_value) {
literals_.Add(position, storage::PropertyValue(value));
token_strings.push_back(new_value);
};
@ -101,16 +99,13 @@ StrippedQuery::StrippedQuery(const std::string &query) : original_(query) {
case Token::SPACE:
break;
case Token::STRING:
replace_stripped(token_index, ParseStringLiteral(token.second),
kStrippedStringToken);
replace_stripped(token_index, ParseStringLiteral(token.second), kStrippedStringToken);
break;
case Token::INT:
replace_stripped(token_index, ParseIntegerLiteral(token.second),
kStrippedIntToken);
replace_stripped(token_index, ParseIntegerLiteral(token.second), kStrippedIntToken);
break;
case Token::REAL:
replace_stripped(token_index, ParseDoubleLiteral(token.second),
kStrippedDoubleToken);
replace_stripped(token_index, ParseDoubleLiteral(token.second), kStrippedDoubleToken);
break;
case Token::SPECIAL:
case Token::ESCAPED_NAME:
@ -135,9 +130,7 @@ StrippedQuery::StrippedQuery(const std::string &query) : original_(query) {
while (it != tokens.end()) {
// Store nonaliased named expressions in returns in named_exprs_.
it = std::find_if(it, tokens.end(),
[](const std::pair<Token, std::string> &a) {
return utils::IEquals(a.second, "return");
});
[](const std::pair<Token, std::string> &a) { return utils::IEquals(a.second, "return"); });
// There is no RETURN so there is nothing to do here.
if (it == tokens.end()) return;
// Skip RETURN;
@ -172,13 +165,10 @@ StrippedQuery::StrippedQuery(const std::string &query) : original_(query) {
int num_open_braces = 0;
int num_open_parantheses = 0;
int num_open_brackets = 0;
for (; jt != tokens.end() &&
(jt->second != "," || num_open_braces || num_open_parantheses ||
num_open_brackets) &&
!utils::IEquals(jt->second, "order") &&
!utils::IEquals(jt->second, "skip") &&
!utils::IEquals(jt->second, "limit") &&
!utils::IEquals(jt->second, "union") && jt->second != ";";
for (;
jt != tokens.end() && (jt->second != "," || num_open_braces || num_open_parantheses || num_open_brackets) &&
!utils::IEquals(jt->second, "order") && !utils::IEquals(jt->second, "skip") &&
!utils::IEquals(jt->second, "limit") && !utils::IEquals(jt->second, "union") && jt->second != ";";
++jt) {
if (jt->second == "(") {
++num_open_parantheses;
@ -203,8 +193,7 @@ StrippedQuery::StrippedQuery(const std::string &query) : original_(query) {
// trailing whitespaces.
std::string s;
auto begin_token = it - tokens.begin() + original_tokens.begin();
auto end_token =
last_non_space - tokens.begin() + original_tokens.begin() + 1;
auto end_token = last_non_space - tokens.begin() + original_tokens.begin() + 1;
for (auto kt = begin_token; kt != end_token; ++kt) {
s += kt->second;
}
@ -280,9 +269,7 @@ std::pair<int, int> GetFirstUtf8SymbolCodepoint(const char *_s) {
if ((*s2 >> 6) != 0x02) throw LexingException("Invalid character.");
auto *s3 = s + 3;
if ((*s3 >> 6) != 0x02) throw LexingException("Invalid character.");
return {((*s & 0x07) << 18) | ((*s1 & 0x3f) << 12) | ((*s2 & 0x3f) << 6) |
(*s3 & 0x3f),
4};
return {((*s & 0x07) << 18) | ((*s1 & 0x3f) << 12) | ((*s2 & 0x3f) << 6) | (*s3 & 0x3f), 4};
}
throw LexingException("Invalid character.");
}
@ -301,13 +288,9 @@ std::pair<int, int> GetFirstUtf8SymbolCodepoint(const char *_s) {
// //\ ) , / '%%%%(%%'
// , _.'/ `\<-- \<
// `^^^` ^^ ^^
int StrippedQuery::MatchKeyword(int start) const {
return kKeywords.Match<tolower>(original_.c_str() + start);
}
int StrippedQuery::MatchKeyword(int start) const { return kKeywords.Match<tolower>(original_.c_str() + start); }
int StrippedQuery::MatchSpecial(int start) const {
return kSpecialTokens.Match(original_.c_str() + start);
}
int StrippedQuery::MatchSpecial(int start) const { return kSpecialTokens.Match(original_.c_str() + start); }
int StrippedQuery::MatchString(int start) const {
if (original_[start] != '"' && original_[start] != '\'') return 0;
@ -316,9 +299,8 @@ int StrippedQuery::MatchString(int start) const {
if (*p == start_char) return p - (original_.data() + start) + 1;
if (*p == '\\') {
++p;
if (*p == '\\' || *p == '\'' || *p == '"' || *p == 'B' || *p == 'b' ||
*p == 'F' || *p == 'f' || *p == 'N' || *p == 'n' || *p == 'R' ||
*p == 'r' || *p == 'T' || *p == 't') {
if (*p == '\\' || *p == '\'' || *p == '"' || *p == 'B' || *p == 'b' || *p == 'F' || *p == 'f' || *p == 'N' ||
*p == 'n' || *p == 'R' || *p == 'r' || *p == 'T' || *p == 't') {
// Allowed escaped characters.
continue;
} else if (*p == 'U' || *p == 'u') {
@ -356,8 +338,7 @@ int StrippedQuery::MatchDecimalInt(int start) const {
int StrippedQuery::MatchOctalInt(int start) const {
if (original_[start] != '0') return 0;
int i = start + 1;
while (i < static_cast<int>(original_.size()) && '0' <= original_[i] &&
original_[i] <= '7') {
while (i < static_cast<int>(original_.size()) && '0' <= original_[i] && original_[i] <= '7') {
++i;
}
if (i == start + 1) return 0;
@ -440,15 +421,13 @@ int StrippedQuery::MatchEscapedName(int start) const {
int StrippedQuery::MatchUnescapedName(int start) const {
auto i = start;
auto got = GetFirstUtf8SymbolCodepoint(original_.data() + i);
if (got.first >= lexer_constants::kBitsetSize ||
!kUnescapedNameAllowedStarts[got.first]) {
if (got.first >= lexer_constants::kBitsetSize || !kUnescapedNameAllowedStarts[got.first]) {
return 0;
}
i += got.second;
while (i < static_cast<int>(original_.size())) {
got = GetFirstUtf8SymbolCodepoint(original_.data() + i);
if (got.first >= lexer_constants::kBitsetSize ||
!kUnescapedNameAllowedParts[got.first]) {
if (got.first >= lexer_constants::kBitsetSize || !kUnescapedNameAllowedParts[got.first]) {
break;
}
i += got.second;
@ -469,13 +448,11 @@ int StrippedQuery::MatchWhitespaceAndComments(int start) const {
auto got = GetFirstUtf8SymbolCodepoint(original_.data() + i);
if (got.first < lexer_constants::kBitsetSize && kSpaceParts[got.first]) {
i += got.second;
} else if (i + 1 < len && original_[i] == '/' &&
original_[i + 1] == '*') {
} else if (i + 1 < len && original_[i] == '/' && original_[i + 1] == '*') {
comment_position = i;
state = State::IN_BLOCK_COMMENT;
i += 2;
} else if (i + 1 < len && original_[i] == '/' &&
original_[i + 1] == '/') {
} else if (i + 1 < len && original_[i] == '/' && original_[i + 1] == '/') {
comment_position = i;
if (i + 2 < len) {
// Special case for an empty line comment starting right at the end of
@ -490,8 +467,7 @@ int StrippedQuery::MatchWhitespaceAndComments(int start) const {
if (original_[i] == '\n') {
state = State::OUT;
++i;
} else if (i + 1 < len && original_[i] == '\r' &&
original_[i + 1] == '\n') {
} else if (i + 1 < len && original_[i] == '\r' && original_[i + 1] == '\n') {
state = State::OUT;
i += 2;
} else if (original_[i] == '\r') {

File diff suppressed because it is too large Load Diff

View File

@ -180,11 +180,9 @@ struct Or<ArgType, ArgTypes...> {
static std::string TypeNames() {
if constexpr (sizeof...(ArgTypes) > 1) {
return fmt::format("'{}', {}", ArgTypeName<ArgType>(),
Or<ArgTypes...>::TypeNames());
return fmt::format("'{}', {}", ArgTypeName<ArgType>(), Or<ArgTypes...>::TypeNames());
} else {
return fmt::format("'{}' or '{}'", ArgTypeName<ArgType>(),
Or<ArgTypes...>::TypeNames());
return fmt::format("'{}' or '{}'", ArgTypeName<ArgType>(), Or<ArgTypes...>::TypeNames());
}
}
};
@ -206,21 +204,18 @@ template <class ArgType>
struct Optional<ArgType> {
static constexpr size_t size = 1;
static void Check(const char *name, const TypedValue *args, int64_t nargs,
int64_t pos) {
static void Check(const char *name, const TypedValue *args, int64_t nargs, int64_t pos) {
if (nargs == 0) return;
const TypedValue &arg = args[0];
if constexpr (IsOrType<ArgType>::value) {
if (!ArgType::Check(arg)) {
throw QueryRuntimeException(
"Optional '{}' argument at position {} must be either {}.", name,
pos, ArgType::TypeNames());
throw QueryRuntimeException("Optional '{}' argument at position {} must be either {}.", name, pos,
ArgType::TypeNames());
}
} else {
if (!ArgIsType<ArgType>(arg))
throw QueryRuntimeException(
"Optional '{}' argument at position {} must be '{}'.", name, pos,
ArgTypeName<ArgType>());
throw QueryRuntimeException("Optional '{}' argument at position {} must be '{}'.", name, pos,
ArgTypeName<ArgType>());
}
}
};
@ -229,8 +224,7 @@ template <class ArgType, class... ArgTypes>
struct Optional<ArgType, ArgTypes...> {
static constexpr size_t size = 1 + sizeof...(ArgTypes);
static void Check(const char *name, const TypedValue *args, int64_t nargs,
int64_t pos) {
static void Check(const char *name, const TypedValue *args, int64_t nargs, int64_t pos) {
if (nargs == 0) return;
Optional<ArgType>::Check(name, args, nargs, pos);
Optional<ArgTypes...>::Check(name, args + 1, nargs - 1, pos + 1);
@ -272,8 +266,7 @@ constexpr size_t FTypeOptionalArgs() {
}
template <class ArgType, class... ArgTypes>
void FType(const char *name, const TypedValue *args, int64_t nargs,
int64_t pos = 1) {
void FType(const char *name, const TypedValue *args, int64_t nargs, int64_t pos = 1) {
if constexpr (std::is_same_v<ArgType, void>) {
if (nargs != 0) {
throw QueryRuntimeException("'{}' requires no arguments.", name);
@ -285,30 +278,25 @@ void FType(const char *name, const TypedValue *args, int64_t nargs,
constexpr int64_t total_args = required_args + optional_args;
if constexpr (optional_args > 0) {
if (nargs < required_args || nargs > total_args) {
throw QueryRuntimeException("'{}' requires between {} and {} arguments.",
name, required_args, total_args);
throw QueryRuntimeException("'{}' requires between {} and {} arguments.", name, required_args, total_args);
}
} else {
if (nargs != required_args) {
throw QueryRuntimeException(
"'{}' requires exactly {} {}.", name, required_args,
required_args == 1 ? "argument" : "arguments");
throw QueryRuntimeException("'{}' requires exactly {} {}.", name, required_args,
required_args == 1 ? "argument" : "arguments");
}
}
const TypedValue &arg = args[0];
if constexpr (IsOrType<ArgType>::value) {
if (!ArgType::Check(arg)) {
throw QueryRuntimeException(
"'{}' argument at position {} must be either {}.", name, pos,
ArgType::TypeNames());
throw QueryRuntimeException("'{}' argument at position {} must be either {}.", name, pos, ArgType::TypeNames());
}
} else if constexpr (IsOptional<ArgType>::value) {
static_assert(sizeof...(ArgTypes) == 0, "Optional arguments must be last!");
ArgType::Check(name, args, nargs, pos);
} else {
if (!ArgIsType<ArgType>(arg)) {
throw QueryRuntimeException("'{}' argument at position {} must be '{}'",
name, pos, ArgTypeName<ArgType>());
throw QueryRuntimeException("'{}' argument at position {} must be '{}'", name, pos, ArgTypeName<ArgType>());
}
}
if constexpr (sizeof...(ArgTypes) > 0) {
@ -339,15 +327,13 @@ void FType(const char *name, const TypedValue *args, int64_t nargs,
// TODO: Implement degrees, haversin, radians
// TODO: Implement spatial functions
TypedValue EndNode(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue EndNode(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Edge>>("endNode", args, nargs);
if (args[0].IsNull()) return TypedValue(ctx.memory);
return TypedValue(args[0].ValueEdge().To(), ctx.memory);
}
TypedValue Head(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Head(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, List>>("head", args, nargs);
if (args[0].IsNull()) return TypedValue(ctx.memory);
const auto &list = args[0].ValueList();
@ -355,8 +341,7 @@ TypedValue Head(const TypedValue *args, int64_t nargs,
return TypedValue(list[0], ctx.memory);
}
TypedValue Last(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Last(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, List>>("last", args, nargs);
if (args[0].IsNull()) return TypedValue(ctx.memory);
const auto &list = args[0].ValueList();
@ -364,8 +349,7 @@ TypedValue Last(const TypedValue *args, int64_t nargs,
return TypedValue(list.back(), ctx.memory);
}
TypedValue Properties(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Properties(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Vertex, Edge>>("properties", args, nargs);
auto *dba = ctx.db_accessor;
auto get_properties = [&](const auto &record_accessor) {
@ -374,16 +358,13 @@ TypedValue Properties(const TypedValue *args, int64_t nargs,
if (maybe_props.HasError()) {
switch (maybe_props.GetError()) {
case storage::Error::DELETED_OBJECT:
throw QueryRuntimeException(
"Trying to get properties from a deleted object.");
throw QueryRuntimeException("Trying to get properties from a deleted object.");
case storage::Error::NONEXISTENT_OBJECT:
throw query::QueryRuntimeException(
"Trying to get properties from an object that doesn't exist.");
throw query::QueryRuntimeException("Trying to get properties from an object that doesn't exist.");
case storage::Error::SERIALIZATION_ERROR:
case storage::Error::VERTEX_HAS_EDGES:
case storage::Error::PROPERTIES_DISABLED:
throw QueryRuntimeException(
"Unexpected error when getting properties.");
throw QueryRuntimeException("Unexpected error when getting properties.");
}
}
for (const auto &property : *maybe_props) {
@ -401,31 +382,25 @@ TypedValue Properties(const TypedValue *args, int64_t nargs,
}
}
TypedValue Size(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Size(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, List, String, Map, Path>>("size", args, nargs);
const auto &value = args[0];
if (value.IsNull()) {
return TypedValue(ctx.memory);
} else if (value.IsList()) {
return TypedValue(static_cast<int64_t>(value.ValueList().size()),
ctx.memory);
return TypedValue(static_cast<int64_t>(value.ValueList().size()), ctx.memory);
} else if (value.IsString()) {
return TypedValue(static_cast<int64_t>(value.ValueString().size()),
ctx.memory);
return TypedValue(static_cast<int64_t>(value.ValueString().size()), ctx.memory);
} else if (value.IsMap()) {
// neo4j doesn't implement size for map, but I don't see a good reason not
// to do it.
return TypedValue(static_cast<int64_t>(value.ValueMap().size()),
ctx.memory);
return TypedValue(static_cast<int64_t>(value.ValueMap().size()), ctx.memory);
} else {
return TypedValue(static_cast<int64_t>(value.ValuePath().edges().size()),
ctx.memory);
return TypedValue(static_cast<int64_t>(value.ValuePath().edges().size()), ctx.memory);
}
}
TypedValue StartNode(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue StartNode(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Edge>>("startNode", args, nargs);
if (args[0].IsNull()) return TypedValue(ctx.memory);
return TypedValue(args[0].ValueEdge().From(), ctx.memory);
@ -439,13 +414,11 @@ size_t UnwrapDegreeResult(storage::Result<size_t> maybe_degree) {
case storage::Error::DELETED_OBJECT:
throw QueryRuntimeException("Trying to get degree of a deleted node.");
case storage::Error::NONEXISTENT_OBJECT:
throw query::QueryRuntimeException(
"Trying to get degree of a node that doesn't exist.");
throw query::QueryRuntimeException("Trying to get degree of a node that doesn't exist.");
case storage::Error::SERIALIZATION_ERROR:
case storage::Error::VERTEX_HAS_EDGES:
case storage::Error::PROPERTIES_DISABLED:
throw QueryRuntimeException(
"Unexpected error when getting node degree.");
throw QueryRuntimeException("Unexpected error when getting node degree.");
}
}
return *maybe_degree;
@ -453,8 +426,7 @@ size_t UnwrapDegreeResult(storage::Result<size_t> maybe_degree) {
} // namespace
TypedValue Degree(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Degree(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Vertex>>("degree", args, nargs);
if (args[0].IsNull()) return TypedValue(ctx.memory);
const auto &vertex = args[0].ValueVertex();
@ -463,8 +435,7 @@ TypedValue Degree(const TypedValue *args, int64_t nargs,
return TypedValue(static_cast<int64_t>(out_degree + in_degree), ctx.memory);
}
TypedValue InDegree(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue InDegree(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Vertex>>("inDegree", args, nargs);
if (args[0].IsNull()) return TypedValue(ctx.memory);
const auto &vertex = args[0].ValueVertex();
@ -472,8 +443,7 @@ TypedValue InDegree(const TypedValue *args, int64_t nargs,
return TypedValue(static_cast<int64_t>(in_degree), ctx.memory);
}
TypedValue OutDegree(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue OutDegree(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Vertex>>("outDegree", args, nargs);
if (args[0].IsNull()) return TypedValue(ctx.memory);
const auto &vertex = args[0].ValueVertex();
@ -481,8 +451,7 @@ TypedValue OutDegree(const TypedValue *args, int64_t nargs,
return TypedValue(static_cast<int64_t>(out_degree), ctx.memory);
}
TypedValue ToBoolean(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue ToBoolean(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Bool, Integer, String>>("toBoolean", args, nargs);
const auto &value = args[0];
if (value.IsNull()) {
@ -501,8 +470,7 @@ TypedValue ToBoolean(const TypedValue *args, int64_t nargs,
}
}
TypedValue ToFloat(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue ToFloat(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Number, String>>("toFloat", args, nargs);
const auto &value = args[0];
if (value.IsNull()) {
@ -513,16 +481,14 @@ TypedValue ToFloat(const TypedValue *args, int64_t nargs,
return TypedValue(value, ctx.memory);
} else {
try {
return TypedValue(utils::ParseDouble(utils::Trim(value.ValueString())),
ctx.memory);
return TypedValue(utils::ParseDouble(utils::Trim(value.ValueString())), ctx.memory);
} catch (const utils::BasicException &) {
return TypedValue(ctx.memory);
}
}
}
TypedValue ToInteger(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue ToInteger(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Bool, Number, String>>("toInteger", args, nargs);
const auto &value = args[0];
if (value.IsNull()) {
@ -537,28 +503,22 @@ TypedValue ToInteger(const TypedValue *args, int64_t nargs,
try {
// Yup, this is correct. String is valid if it has floating point
// number, then it is parsed and converted to int.
return TypedValue(static_cast<int64_t>(utils::ParseDouble(
utils::Trim(value.ValueString()))),
ctx.memory);
return TypedValue(static_cast<int64_t>(utils::ParseDouble(utils::Trim(value.ValueString()))), ctx.memory);
} catch (const utils::BasicException &) {
return TypedValue(ctx.memory);
}
}
}
TypedValue Type(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Type(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Edge>>("type", args, nargs);
auto *dba = ctx.db_accessor;
if (args[0].IsNull()) return TypedValue(ctx.memory);
return TypedValue(dba->EdgeTypeToName(args[0].ValueEdge().EdgeType()),
ctx.memory);
return TypedValue(dba->EdgeTypeToName(args[0].ValueEdge().EdgeType()), ctx.memory);
}
TypedValue ValueType(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
FType<Or<Null, Bool, Integer, Double, String, List, Map, Vertex, Edge, Path>>(
"type", args, nargs);
TypedValue ValueType(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Bool, Integer, Double, String, List, Map, Vertex, Edge, Path>>("type", args, nargs);
// The type names returned should be standardized openCypher type names.
// https://github.com/opencypher/openCypher/blob/master/docs/openCypher9.pdf
switch (args[0].type()) {
@ -586,8 +546,7 @@ TypedValue ValueType(const TypedValue *args, int64_t nargs,
}
// TODO: How is Keys different from Properties function?
TypedValue Keys(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Keys(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Vertex, Edge>>("keys", args, nargs);
auto *dba = ctx.db_accessor;
auto get_keys = [&](const auto &record_accessor) {
@ -596,11 +555,9 @@ TypedValue Keys(const TypedValue *args, int64_t nargs,
if (maybe_props.HasError()) {
switch (maybe_props.GetError()) {
case storage::Error::DELETED_OBJECT:
throw QueryRuntimeException(
"Trying to get keys from a deleted object.");
throw QueryRuntimeException("Trying to get keys from a deleted object.");
case storage::Error::NONEXISTENT_OBJECT:
throw query::QueryRuntimeException(
"Trying to get keys from an object that doesn't exist.");
throw query::QueryRuntimeException("Trying to get keys from an object that doesn't exist.");
case storage::Error::SERIALIZATION_ERROR:
case storage::Error::VERTEX_HAS_EDGES:
case storage::Error::PROPERTIES_DISABLED:
@ -622,8 +579,7 @@ TypedValue Keys(const TypedValue *args, int64_t nargs,
}
}
TypedValue Labels(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Labels(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Vertex>>("labels", args, nargs);
auto *dba = ctx.db_accessor;
if (args[0].IsNull()) return TypedValue(ctx.memory);
@ -632,11 +588,9 @@ TypedValue Labels(const TypedValue *args, int64_t nargs,
if (maybe_labels.HasError()) {
switch (maybe_labels.GetError()) {
case storage::Error::DELETED_OBJECT:
throw QueryRuntimeException(
"Trying to get labels from a deleted node.");
throw QueryRuntimeException("Trying to get labels from a deleted node.");
case storage::Error::NONEXISTENT_OBJECT:
throw query::QueryRuntimeException(
"Trying to get labels from a node that doesn't exist.");
throw query::QueryRuntimeException("Trying to get labels from a node that doesn't exist.");
case storage::Error::SERIALIZATION_ERROR:
case storage::Error::VERTEX_HAS_EDGES:
case storage::Error::PROPERTIES_DISABLED:
@ -649,8 +603,7 @@ TypedValue Labels(const TypedValue *args, int64_t nargs,
return TypedValue(std::move(labels));
}
TypedValue Nodes(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Nodes(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Path>>("nodes", args, nargs);
if (args[0].IsNull()) return TypedValue(ctx.memory);
const auto &vertices = args[0].ValuePath().vertices();
@ -660,8 +613,7 @@ TypedValue Nodes(const TypedValue *args, int64_t nargs,
return TypedValue(std::move(values));
}
TypedValue Relationships(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Relationships(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Path>>("relationships", args, nargs);
if (args[0].IsNull()) return TypedValue(ctx.memory);
const auto &edges = args[0].ValuePath().edges();
@ -671,10 +623,8 @@ TypedValue Relationships(const TypedValue *args, int64_t nargs,
return TypedValue(std::move(values));
}
TypedValue Range(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
FType<Or<Null, Integer>, Or<Null, Integer>,
Optional<Or<Null, NonZeroInteger>>>("range", args, nargs);
TypedValue Range(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Integer>, Or<Null, Integer>, Optional<Or<Null, NonZeroInteger>>>("range", args, nargs);
for (int64_t i = 0; i < nargs; ++i)
if (args[i].IsNull()) return TypedValue(ctx.memory);
auto lbound = args[0].ValueInt();
@ -693,8 +643,7 @@ TypedValue Range(const TypedValue *args, int64_t nargs,
return TypedValue(std::move(list));
}
TypedValue Tail(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Tail(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, List>>("tail", args, nargs);
if (args[0].IsNull()) return TypedValue(ctx.memory);
TypedValue::TVector list(args[0].ValueList(), ctx.memory);
@ -703,10 +652,8 @@ TypedValue Tail(const TypedValue *args, int64_t nargs,
return TypedValue(std::move(list));
}
TypedValue UniformSample(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
FType<Or<Null, List>, Or<Null, NonNegativeInteger>>("uniformSample", args,
nargs);
TypedValue UniformSample(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, List>, Or<Null, NonNegativeInteger>>("uniformSample", args, nargs);
static thread_local std::mt19937 pseudo_rand_gen_{std::random_device{}()};
if (args[0].IsNull() || args[1].IsNull()) return TypedValue(ctx.memory);
const auto &population = args[0].ValueList();
@ -722,8 +669,7 @@ TypedValue UniformSample(const TypedValue *args, int64_t nargs,
return TypedValue(std::move(sampled));
}
TypedValue Abs(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Abs(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Number>>("abs", args, nargs);
const auto &value = args[0];
if (value.IsNull()) {
@ -735,18 +681,17 @@ TypedValue Abs(const TypedValue *args, int64_t nargs,
}
}
#define WRAP_CMATH_FLOAT_FUNCTION(name, lowercased_name) \
TypedValue name(const TypedValue *args, int64_t nargs, \
const FunctionContext &ctx) { \
FType<Or<Null, Number>>(#lowercased_name, args, nargs); \
const auto &value = args[0]; \
if (value.IsNull()) { \
return TypedValue(ctx.memory); \
} else if (value.IsInt()) { \
return TypedValue(lowercased_name(value.ValueInt()), ctx.memory); \
} else { \
return TypedValue(lowercased_name(value.ValueDouble()), ctx.memory); \
} \
#define WRAP_CMATH_FLOAT_FUNCTION(name, lowercased_name) \
TypedValue name(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { \
FType<Or<Null, Number>>(#lowercased_name, args, nargs); \
const auto &value = args[0]; \
if (value.IsNull()) { \
return TypedValue(ctx.memory); \
} else if (value.IsInt()) { \
return TypedValue(lowercased_name(value.ValueInt()), ctx.memory); \
} else { \
return TypedValue(lowercased_name(value.ValueDouble()), ctx.memory); \
} \
}
WRAP_CMATH_FLOAT_FUNCTION(Ceil, ceil)
@ -767,8 +712,7 @@ WRAP_CMATH_FLOAT_FUNCTION(Tan, tan)
#undef WRAP_CMATH_FLOAT_FUNCTION
TypedValue Atan2(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Atan2(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Number>, Or<Null, Number>>("atan2", args, nargs);
if (args[0].IsNull() || args[1].IsNull()) return TypedValue(ctx.memory);
auto to_double = [](const TypedValue &t) -> double {
@ -783,8 +727,7 @@ TypedValue Atan2(const TypedValue *args, int64_t nargs,
return TypedValue(atan2(y, x), ctx.memory);
}
TypedValue Sign(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Sign(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Number>>("sign", args, nargs);
auto sign = [&](auto x) { return TypedValue((0 < x) - (x < 0), ctx.memory); };
const auto &value = args[0];
@ -797,20 +740,17 @@ TypedValue Sign(const TypedValue *args, int64_t nargs,
}
}
TypedValue E(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue E(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<void>("e", args, nargs);
return TypedValue(M_E, ctx.memory);
}
TypedValue Pi(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Pi(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<void>("pi", args, nargs);
return TypedValue(M_PI, ctx.memory);
}
TypedValue Rand(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Rand(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<void>("rand", args, nargs);
static thread_local std::mt19937 pseudo_rand_gen_{std::random_device{}()};
static thread_local std::uniform_real_distribution<> rand_dist_{0, 1};
@ -818,8 +758,7 @@ TypedValue Rand(const TypedValue *args, int64_t nargs,
}
template <class TPredicate>
TypedValue StringMatchOperator(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue StringMatchOperator(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, String>, Or<Null, String>>(TPredicate::name, args, nargs);
if (args[0].IsNull() || args[1].IsNull()) return TypedValue(ctx.memory);
const auto &s1 = args[0].ValueString();
@ -830,8 +769,7 @@ TypedValue StringMatchOperator(const TypedValue *args, int64_t nargs,
// Check if s1 starts with s2.
struct StartsWithPredicate {
constexpr static const char *name = "startsWith";
bool operator()(const TypedValue::TString &s1,
const TypedValue::TString &s2) const {
bool operator()(const TypedValue::TString &s1, const TypedValue::TString &s2) const {
if (s1.size() < s2.size()) return false;
return std::equal(s2.begin(), s2.end(), s1.begin());
}
@ -841,8 +779,7 @@ auto StartsWith = StringMatchOperator<StartsWithPredicate>;
// Check if s1 ends with s2.
struct EndsWithPredicate {
constexpr static const char *name = "endsWith";
bool operator()(const TypedValue::TString &s1,
const TypedValue::TString &s2) const {
bool operator()(const TypedValue::TString &s1, const TypedValue::TString &s2) const {
if (s1.size() < s2.size()) return false;
return std::equal(s2.rbegin(), s2.rend(), s1.rbegin());
}
@ -852,16 +789,14 @@ auto EndsWith = StringMatchOperator<EndsWithPredicate>;
// Check if s1 contains s2.
struct ContainsPredicate {
constexpr static const char *name = "contains";
bool operator()(const TypedValue::TString &s1,
const TypedValue::TString &s2) const {
bool operator()(const TypedValue::TString &s1, const TypedValue::TString &s2) const {
if (s1.size() < s2.size()) return false;
return s1.find(s2) != std::string::npos;
}
};
auto Contains = StringMatchOperator<ContainsPredicate>;
TypedValue Assert(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Assert(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Bool, Optional<String>>("assert", args, nargs);
if (!args[0].ValueBool()) {
std::string message("Assertion failed");
@ -875,24 +810,21 @@ TypedValue Assert(const TypedValue *args, int64_t nargs,
return TypedValue(args[0], ctx.memory);
}
TypedValue Counter(const TypedValue *args, int64_t nargs,
const FunctionContext &context) {
TypedValue Counter(const TypedValue *args, int64_t nargs, const FunctionContext &context) {
FType<String, Integer, Optional<NonZeroInteger>>("counter", args, nargs);
int64_t step = 1;
if (nargs == 3) {
step = args[2].ValueInt();
}
auto [it, inserted] =
context.counters->emplace(args[0].ValueString(), args[1].ValueInt());
auto [it, inserted] = context.counters->emplace(args[0].ValueString(), args[1].ValueInt());
auto value = it->second;
it->second += step;
return TypedValue(value, context.memory);
}
TypedValue Id(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Id(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Vertex, Edge>>("id", args, nargs);
const auto &arg = args[0];
if (arg.IsNull()) {
@ -904,8 +836,7 @@ TypedValue Id(const TypedValue *args, int64_t nargs,
}
}
TypedValue ToString(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue ToString(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, String, Number, Bool>>("toString", args, nargs);
const auto &arg = args[0];
if (arg.IsNull()) {
@ -923,106 +854,80 @@ TypedValue ToString(const TypedValue *args, int64_t nargs,
}
}
TypedValue Timestamp(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Timestamp(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<void>("timestamp", args, nargs);
return TypedValue(ctx.timestamp, ctx.memory);
}
TypedValue Left(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Left(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, String>, Or<Null, NonNegativeInteger>>("left", args, nargs);
if (args[0].IsNull() || args[1].IsNull()) return TypedValue(ctx.memory);
return TypedValue(utils::Substr(args[0].ValueString(), 0, args[1].ValueInt()),
ctx.memory);
return TypedValue(utils::Substr(args[0].ValueString(), 0, args[1].ValueInt()), ctx.memory);
}
TypedValue Right(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Right(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, String>, Or<Null, NonNegativeInteger>>("right", args, nargs);
if (args[0].IsNull() || args[1].IsNull()) return TypedValue(ctx.memory);
const auto &str = args[0].ValueString();
auto len = args[1].ValueInt();
return len <= str.size()
? TypedValue(utils::Substr(str, str.size() - len, len), ctx.memory)
: TypedValue(str, ctx.memory);
return len <= str.size() ? TypedValue(utils::Substr(str, str.size() - len, len), ctx.memory)
: TypedValue(str, ctx.memory);
}
TypedValue CallStringFunction(
const TypedValue *args, int64_t nargs, utils::MemoryResource *memory,
const char *name,
std::function<TypedValue::TString(const TypedValue::TString &)> fun) {
TypedValue CallStringFunction(const TypedValue *args, int64_t nargs, utils::MemoryResource *memory, const char *name,
std::function<TypedValue::TString(const TypedValue::TString &)> fun) {
FType<Or<Null, String>>(name, args, nargs);
if (args[0].IsNull()) return TypedValue(memory);
return TypedValue(fun(args[0].ValueString()), memory);
}
TypedValue LTrim(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
return CallStringFunction(
args, nargs, ctx.memory, "lTrim", [&](const auto &str) {
return TypedValue::TString(utils::LTrim(str), ctx.memory);
});
TypedValue LTrim(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
return CallStringFunction(args, nargs, ctx.memory, "lTrim",
[&](const auto &str) { return TypedValue::TString(utils::LTrim(str), ctx.memory); });
}
TypedValue RTrim(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
return CallStringFunction(
args, nargs, ctx.memory, "rTrim", [&](const auto &str) {
return TypedValue::TString(utils::RTrim(str), ctx.memory);
});
TypedValue RTrim(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
return CallStringFunction(args, nargs, ctx.memory, "rTrim",
[&](const auto &str) { return TypedValue::TString(utils::RTrim(str), ctx.memory); });
}
TypedValue Trim(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
return CallStringFunction(
args, nargs, ctx.memory, "trim", [&](const auto &str) {
return TypedValue::TString(utils::Trim(str), ctx.memory);
});
TypedValue Trim(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
return CallStringFunction(args, nargs, ctx.memory, "trim",
[&](const auto &str) { return TypedValue::TString(utils::Trim(str), ctx.memory); });
}
TypedValue Reverse(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
return CallStringFunction(
args, nargs, ctx.memory, "reverse",
[&](const auto &str) { return utils::Reversed(str, ctx.memory); });
TypedValue Reverse(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
return CallStringFunction(args, nargs, ctx.memory, "reverse",
[&](const auto &str) { return utils::Reversed(str, ctx.memory); });
}
TypedValue ToLower(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
return CallStringFunction(args, nargs, ctx.memory, "toLower",
[&](const auto &str) {
TypedValue::TString res(ctx.memory);
utils::ToLowerCase(&res, str);
return res;
});
TypedValue ToLower(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
return CallStringFunction(args, nargs, ctx.memory, "toLower", [&](const auto &str) {
TypedValue::TString res(ctx.memory);
utils::ToLowerCase(&res, str);
return res;
});
}
TypedValue ToUpper(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
return CallStringFunction(args, nargs, ctx.memory, "toUpper",
[&](const auto &str) {
TypedValue::TString res(ctx.memory);
utils::ToUpperCase(&res, str);
return res;
});
TypedValue ToUpper(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
return CallStringFunction(args, nargs, ctx.memory, "toUpper", [&](const auto &str) {
TypedValue::TString res(ctx.memory);
utils::ToUpperCase(&res, str);
return res;
});
}
TypedValue Replace(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
FType<Or<Null, String>, Or<Null, String>, Or<Null, String>>("replace", args,
nargs);
TypedValue Replace(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, String>, Or<Null, String>, Or<Null, String>>("replace", args, nargs);
if (args[0].IsNull() || args[1].IsNull() || args[2].IsNull()) {
return TypedValue(ctx.memory);
}
TypedValue::TString replaced(ctx.memory);
utils::Replace(&replaced, args[0].ValueString(), args[1].ValueString(),
args[2].ValueString());
utils::Replace(&replaced, args[0].ValueString(), args[1].ValueString(), args[2].ValueString());
return TypedValue(std::move(replaced));
}
TypedValue Split(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue Split(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, String>, Or<Null, String>>("split", args, nargs);
if (args[0].IsNull() || args[1].IsNull()) {
return TypedValue(ctx.memory);
@ -1032,10 +937,8 @@ TypedValue Split(const TypedValue *args, int64_t nargs,
return TypedValue(std::move(result));
}
TypedValue Substring(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
FType<Or<Null, String>, NonNegativeInteger, Optional<NonNegativeInteger>>(
"substring", args, nargs);
TypedValue Substring(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, String>, NonNegativeInteger, Optional<NonNegativeInteger>>("substring", args, nargs);
if (args[0].IsNull()) return TypedValue(ctx.memory);
const auto &str = args[0].ValueString();
auto start = args[1].ValueInt();
@ -1044,8 +947,7 @@ TypedValue Substring(const TypedValue *args, int64_t nargs,
return TypedValue(utils::Substr(str, start, len), ctx.memory);
}
TypedValue ToByteString(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue ToByteString(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<String>("toByteString", args, nargs);
const auto &str = args[0].ValueString();
if (str.empty()) return TypedValue("", ctx.memory);
@ -1057,8 +959,7 @@ TypedValue ToByteString(const TypedValue *args, int64_t nargs,
if (ch >= '0' && ch <= '9') return ch - '0';
if (ch >= 'a' && ch <= 'f') return ch - 'a' + 10;
if (ch >= 'A' && ch <= 'F') return ch - 'A' + 10;
throw QueryRuntimeException(
"'toByteString' argument has an invalid character '{}'", ch);
throw QueryRuntimeException("'toByteString' argument has an invalid character '{}'", ch);
};
utils::pmr::string bytes(ctx.memory);
bytes.reserve((1 + hex_str.size()) / 2);
@ -1074,27 +975,23 @@ TypedValue ToByteString(const TypedValue *args, int64_t nargs,
return TypedValue(std::move(bytes));
}
TypedValue FromByteString(const TypedValue *args, int64_t nargs,
const FunctionContext &ctx) {
TypedValue FromByteString(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<String, Optional<PositiveInteger>>("fromByteString", args, nargs);
const auto &bytes = args[0].ValueString();
if (bytes.empty()) return TypedValue("", ctx.memory);
size_t min_length = bytes.size();
if (nargs == 2)
min_length = std::max(min_length, static_cast<size_t>(args[1].ValueInt()));
if (nargs == 2) min_length = std::max(min_length, static_cast<size_t>(args[1].ValueInt()));
utils::pmr::string str(ctx.memory);
str.reserve(min_length * 2 + 2);
str.append("0x");
for (size_t pad = 0; pad < min_length - bytes.size(); ++pad)
str.append(2, '0');
for (size_t pad = 0; pad < min_length - bytes.size(); ++pad) str.append(2, '0');
// Convert the bytes to a character string in hex representation.
// Unfortunately, we don't know whether the default `char` is signed or
// unsigned, so we have to work around any potential undefined behaviour when
// conversions between the 2 occur. That's why this function is more
// complicated than it should be.
auto to_hex = [](const unsigned char val) -> char {
unsigned char ch = val < 10U ? static_cast<unsigned char>('0') + val
: static_cast<unsigned char>('a') + val - 10U;
unsigned char ch = val < 10U ? static_cast<unsigned char>('0') + val : static_cast<unsigned char>('a') + val - 10U;
return utils::MemcpyCast<char>(ch);
};
for (unsigned char byte : bytes) {
@ -1106,9 +1003,8 @@ TypedValue FromByteString(const TypedValue *args, int64_t nargs,
} // namespace
std::function<TypedValue(const TypedValue *, int64_t,
const FunctionContext &ctx)>
NameToFunction(const std::string &function_name) {
std::function<TypedValue(const TypedValue *, int64_t, const FunctionContext &ctx)> NameToFunction(
const std::string &function_name) {
// Scalar functions
if (function_name == "DEGREE") return Degree;
if (function_name == "INDEGREE") return InDegree;

View File

@ -34,8 +34,7 @@ struct FunctionContext {
/// having an array stored anywhere the caller likes, as long as it is
/// contiguous in memory. Since most functions don't take many arguments, it's
/// convenient to have them stored in the calling stack frame.
std::function<TypedValue(const TypedValue *arguments, int64_t num_arguments,
const FunctionContext &context)>
std::function<TypedValue(const TypedValue *arguments, int64_t num_arguments, const FunctionContext &context)>
NameToFunction(const std::string &function_name);
} // namespace query

View File

@ -21,14 +21,9 @@ namespace query {
class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
public:
ExpressionEvaluator(Frame *frame, const SymbolTable &symbol_table,
const EvaluationContext &ctx, DbAccessor *dba,
ExpressionEvaluator(Frame *frame, const SymbolTable &symbol_table, const EvaluationContext &ctx, DbAccessor *dba,
storage::View view)
: frame_(frame),
symbol_table_(&symbol_table),
ctx_(&ctx),
dba_(dba),
view_(view) {}
: frame_(frame), symbol_table_(&symbol_table), ctx_(&ctx), dba_(dba), view_(view) {}
using ExpressionVisitor<TypedValue>::Visit;
@ -45,27 +40,25 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
return TypedValue(frame_->at(symbol_table_->at(ident)), ctx_->memory);
}
#define BINARY_OPERATOR_VISITOR(OP_NODE, CPP_OP, CYPHER_OP) \
TypedValue Visit(OP_NODE &op) override { \
auto val1 = op.expression1_->Accept(*this); \
auto val2 = op.expression2_->Accept(*this); \
try { \
return val1 CPP_OP val2; \
} catch (const TypedValueException &) { \
throw QueryRuntimeException("Invalid types: {} and {} for '{}'.", \
val1.type(), val2.type(), #CYPHER_OP); \
} \
#define BINARY_OPERATOR_VISITOR(OP_NODE, CPP_OP, CYPHER_OP) \
TypedValue Visit(OP_NODE &op) override { \
auto val1 = op.expression1_->Accept(*this); \
auto val2 = op.expression2_->Accept(*this); \
try { \
return val1 CPP_OP val2; \
} catch (const TypedValueException &) { \
throw QueryRuntimeException("Invalid types: {} and {} for '{}'.", val1.type(), val2.type(), #CYPHER_OP); \
} \
}
#define UNARY_OPERATOR_VISITOR(OP_NODE, CPP_OP, CYPHER_OP) \
TypedValue Visit(OP_NODE &op) override { \
auto val = op.expression_->Accept(*this); \
try { \
return CPP_OP val; \
} catch (const TypedValueException &) { \
throw QueryRuntimeException("Invalid type {} for '{}'.", val.type(), \
#CYPHER_OP); \
} \
#define UNARY_OPERATOR_VISITOR(OP_NODE, CPP_OP, CYPHER_OP) \
TypedValue Visit(OP_NODE &op) override { \
auto val = op.expression_->Accept(*this); \
try { \
return CPP_OP val; \
} catch (const TypedValueException &) { \
throw QueryRuntimeException("Invalid type {} for '{}'.", val.type(), #CYPHER_OP); \
} \
}
BINARY_OPERATOR_VISITOR(OrOperator, ||, OR);
@ -99,8 +92,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
try {
return value1 && value2;
} catch (const TypedValueException &) {
throw QueryRuntimeException("Invalid types: {} and {} for AND.",
value1.type(), value2.type());
throw QueryRuntimeException("Invalid types: {} and {} for AND.", value1.type(), value2.type());
}
}
@ -111,8 +103,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
}
if (condition.type() != TypedValue::Type::Bool) {
// At the moment IfOperator is used only in CASE construct.
throw QueryRuntimeException("CASE expected boolean expression, got {}.",
condition.type());
throw QueryRuntimeException("CASE expected boolean expression, got {}.", condition.type());
}
if (condition.ValueBool()) {
return if_operator.then_expression_->Accept(*this);
@ -158,17 +149,14 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
TypedValue Visit(SubscriptOperator &list_indexing) override {
auto lhs = list_indexing.expression1_->Accept(*this);
auto index = list_indexing.expression2_->Accept(*this);
if (!lhs.IsList() && !lhs.IsMap() && !lhs.IsVertex() && !lhs.IsEdge() &&
!lhs.IsNull())
if (!lhs.IsList() && !lhs.IsMap() && !lhs.IsVertex() && !lhs.IsEdge() && !lhs.IsNull())
throw QueryRuntimeException(
"Expected a list, a map, a node or an edge to index with '[]', got "
"{}.",
lhs.type());
if (lhs.IsNull() || index.IsNull()) return TypedValue(ctx_->memory);
if (lhs.IsList()) {
if (!index.IsInt())
throw QueryRuntimeException(
"Expected an integer as a list index, got {}.", index.type());
if (!index.IsInt()) throw QueryRuntimeException("Expected an integer as a list index, got {}.", index.type());
auto index_int = index.ValueInt();
// NOTE: Take non-const reference to list, so that we can move out the
// indexed element as the result.
@ -176,17 +164,14 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
if (index_int < 0) {
index_int += static_cast<int64_t>(list.size());
}
if (index_int >= static_cast<int64_t>(list.size()) || index_int < 0)
return TypedValue(ctx_->memory);
if (index_int >= static_cast<int64_t>(list.size()) || index_int < 0) return TypedValue(ctx_->memory);
// NOTE: Explicit move is needed, so that we return the move constructed
// value and preserve the correct MemoryResource.
return std::move(list[index_int]);
}
if (lhs.IsMap()) {
if (!index.IsString())
throw QueryRuntimeException("Expected a string as a map index, got {}.",
index.type());
if (!index.IsString()) throw QueryRuntimeException("Expected a string as a map index, got {}.", index.type());
// NOTE: Take non-const reference to map, so that we can move out the
// looked-up element as the result.
auto &map = lhs.ValueMap();
@ -198,19 +183,13 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
}
if (lhs.IsVertex()) {
if (!index.IsString())
throw QueryRuntimeException(
"Expected a string as a property name, got {}.", index.type());
return TypedValue(GetProperty(lhs.ValueVertex(), index.ValueString()),
ctx_->memory);
if (!index.IsString()) throw QueryRuntimeException("Expected a string as a property name, got {}.", index.type());
return TypedValue(GetProperty(lhs.ValueVertex(), index.ValueString()), ctx_->memory);
}
if (lhs.IsEdge()) {
if (!index.IsString())
throw QueryRuntimeException(
"Expected a string as a property name, got {}.", index.type());
return TypedValue(GetProperty(lhs.ValueEdge(), index.ValueString()),
ctx_->memory);
if (!index.IsString()) throw QueryRuntimeException("Expected a string as a property name, got {}.", index.type());
return TypedValue(GetProperty(lhs.ValueEdge(), index.ValueString()), ctx_->memory);
}
// lhs is Null
@ -227,24 +206,20 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
if (bound.type() == TypedValue::Type::Null) {
is_null = true;
} else if (bound.type() != TypedValue::Type::Int) {
throw QueryRuntimeException(
"Expected an integer for a bound in list slicing, got {}.",
bound.type());
throw QueryRuntimeException("Expected an integer for a bound in list slicing, got {}.", bound.type());
}
return bound;
}
return TypedValue(default_value, ctx_->memory);
};
auto _upper_bound =
get_bound(op.upper_bound_, std::numeric_limits<int64_t>::max());
auto _upper_bound = get_bound(op.upper_bound_, std::numeric_limits<int64_t>::max());
auto _lower_bound = get_bound(op.lower_bound_, 0);
auto _list = op.list_->Accept(*this);
if (_list.type() == TypedValue::Type::Null) {
is_null = true;
} else if (_list.type() != TypedValue::Type::List) {
throw QueryRuntimeException("Expected a list to slice, got {}.",
_list.type());
throw QueryRuntimeException("Expected a list to slice, got {}.", _list.type());
}
if (is_null) {
@ -255,16 +230,14 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
if (bound < 0) {
bound = static_cast<int64_t>(list.size()) + bound;
}
return std::max(static_cast<int64_t>(0),
std::min(bound, static_cast<int64_t>(list.size())));
return std::max(static_cast<int64_t>(0), std::min(bound, static_cast<int64_t>(list.size())));
};
auto lower_bound = normalise_bound(_lower_bound.ValueInt());
auto upper_bound = normalise_bound(_upper_bound.ValueInt());
if (upper_bound <= lower_bound) {
return TypedValue(TypedValue::TVector(ctx_->memory), ctx_->memory);
}
return TypedValue(TypedValue::TVector(
list.begin() + lower_bound, list.begin() + upper_bound, ctx_->memory));
return TypedValue(TypedValue::TVector(list.begin() + lower_bound, list.begin() + upper_bound, ctx_->memory));
}
TypedValue Visit(IsNullOperator &is_null) override {
@ -278,13 +251,9 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
case TypedValue::Type::Null:
return TypedValue(ctx_->memory);
case TypedValue::Type::Vertex:
return TypedValue(GetProperty(expression_result.ValueVertex(),
property_lookup.property_),
ctx_->memory);
return TypedValue(GetProperty(expression_result.ValueVertex(), property_lookup.property_), ctx_->memory);
case TypedValue::Type::Edge:
return TypedValue(GetProperty(expression_result.ValueEdge(),
property_lookup.property_),
ctx_->memory);
return TypedValue(GetProperty(expression_result.ValueEdge(), property_lookup.property_), ctx_->memory);
case TypedValue::Type::Map: {
// NOTE: Take non-const reference to map, so that we can move out the
// looked-up element as the result.
@ -296,8 +265,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
return std::move(found->second);
}
default:
throw QueryRuntimeException(
"Only nodes, edges and maps have properties to be looked-up.");
throw QueryRuntimeException("Only nodes, edges and maps have properties to be looked-up.");
}
}
@ -310,8 +278,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
const auto &vertex = expression_result.ValueVertex();
for (const auto &label : labels_test.labels_) {
auto has_label = vertex.HasLabel(view_, GetLabel(label));
if (has_label.HasError() &&
has_label.GetError() == storage::Error::NONEXISTENT_OBJECT) {
if (has_label.HasError() && has_label.GetError() == storage::Error::NONEXISTENT_OBJECT) {
// This is a very nasty and temporary hack in order to make MERGE
// work. The old storage had the following logic when returning an
// `OLD` view: `return old ? old : new`. That means that if the
@ -324,16 +291,13 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
if (has_label.HasError()) {
switch (has_label.GetError()) {
case storage::Error::DELETED_OBJECT:
throw QueryRuntimeException(
"Trying to access labels on a deleted node.");
throw QueryRuntimeException("Trying to access labels on a deleted node.");
case storage::Error::NONEXISTENT_OBJECT:
throw query::QueryRuntimeException(
"Trying to access labels from a node that doesn't exist.");
throw query::QueryRuntimeException("Trying to access labels from a node that doesn't exist.");
case storage::Error::SERIALIZATION_ERROR:
case storage::Error::VERTEX_HAS_EDGES:
case storage::Error::PROPERTIES_DISABLED:
throw QueryRuntimeException(
"Unexpected error when accessing labels.");
throw QueryRuntimeException("Unexpected error when accessing labels.");
}
}
if (!*has_label) {
@ -356,15 +320,13 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
TypedValue Visit(ListLiteral &literal) override {
TypedValue::TVector result(ctx_->memory);
result.reserve(literal.elements_.size());
for (const auto &expression : literal.elements_)
result.emplace_back(expression->Accept(*this));
for (const auto &expression : literal.elements_) result.emplace_back(expression->Accept(*this));
return TypedValue(result, ctx_->memory);
}
TypedValue Visit(MapLiteral &literal) override {
TypedValue::TMap result(ctx_->memory);
for (const auto &pair : literal.elements_)
result.emplace(pair.first.name, pair.second->Accept(*this));
for (const auto &pair : literal.elements_) result.emplace(pair.first.name, pair.second->Accept(*this));
return TypedValue(result, ctx_->memory);
}
@ -390,20 +352,16 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
}
TypedValue Visit(Function &function) override {
FunctionContext function_ctx{dba_, ctx_->memory, ctx_->timestamp,
&ctx_->counters, view_};
FunctionContext function_ctx{dba_, ctx_->memory, ctx_->timestamp, &ctx_->counters, view_};
// Stack allocate evaluated arguments when there's a small number of them.
if (function.arguments_.size() <= 8) {
TypedValue arguments[8] = {
TypedValue(ctx_->memory), TypedValue(ctx_->memory),
TypedValue(ctx_->memory), TypedValue(ctx_->memory),
TypedValue(ctx_->memory), TypedValue(ctx_->memory),
TypedValue(ctx_->memory), TypedValue(ctx_->memory)};
TypedValue arguments[8] = {TypedValue(ctx_->memory), TypedValue(ctx_->memory), TypedValue(ctx_->memory),
TypedValue(ctx_->memory), TypedValue(ctx_->memory), TypedValue(ctx_->memory),
TypedValue(ctx_->memory), TypedValue(ctx_->memory)};
for (size_t i = 0; i < function.arguments_.size(); ++i) {
arguments[i] = function.arguments_[i]->Accept(*this);
}
auto res = function.function_(arguments, function.arguments_.size(),
function_ctx);
auto res = function.function_(arguments, function.arguments_.size(), function_ctx);
MG_ASSERT(res.GetMemoryResource() == ctx_->memory);
return res;
} else {
@ -412,8 +370,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
for (const auto &argument : function.arguments_) {
arguments.emplace_back(argument->Accept(*this));
}
auto res =
function.function_(arguments.data(), arguments.size(), function_ctx);
auto res = function.function_(arguments.data(), arguments.size(), function_ctx);
MG_ASSERT(res.GetMemoryResource() == ctx_->memory);
return res;
}
@ -425,8 +382,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
return TypedValue(ctx_->memory);
}
if (list_value.type() != TypedValue::Type::List) {
throw QueryRuntimeException("REDUCE expected a list, got {}.",
list_value.type());
throw QueryRuntimeException("REDUCE expected a list, got {}.", list_value.type());
}
const auto &list = list_value.ValueList();
const auto &element_symbol = symbol_table_->at(*reduce.identifier_);
@ -446,8 +402,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
return TypedValue(ctx_->memory);
}
if (list_value.type() != TypedValue::Type::List) {
throw QueryRuntimeException("EXTRACT expected a list, got {}.",
list_value.type());
throw QueryRuntimeException("EXTRACT expected a list, got {}.", list_value.type());
}
const auto &list = list_value.ValueList();
const auto &element_symbol = symbol_table_->at(*extract.identifier_);
@ -470,8 +425,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
return TypedValue(ctx_->memory);
}
if (list_value.type() != TypedValue::Type::List) {
throw QueryRuntimeException("ALL expected a list, got {}.",
list_value.type());
throw QueryRuntimeException("ALL expected a list, got {}.", list_value.type());
}
const auto &list = list_value.ValueList();
const auto &symbol = symbol_table_->at(*all.identifier_);
@ -481,9 +435,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
frame_->at(symbol) = element;
auto result = all.where_->expression_->Accept(*this);
if (!result.IsNull() && result.type() != TypedValue::Type::Bool) {
throw QueryRuntimeException(
"Predicate of ALL must evaluate to boolean, got {}.",
result.type());
throw QueryRuntimeException("Predicate of ALL must evaluate to boolean, got {}.", result.type());
}
if (!result.IsNull()) {
has_value = true;
@ -510,8 +462,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
return TypedValue(ctx_->memory);
}
if (list_value.type() != TypedValue::Type::List) {
throw QueryRuntimeException("SINGLE expected a list, got {}.",
list_value.type());
throw QueryRuntimeException("SINGLE expected a list, got {}.", list_value.type());
}
const auto &list = list_value.ValueList();
const auto &symbol = symbol_table_->at(*single.identifier_);
@ -521,9 +472,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
frame_->at(symbol) = element;
auto result = single.where_->expression_->Accept(*this);
if (!result.IsNull() && result.type() != TypedValue::Type::Bool) {
throw QueryRuntimeException(
"Predicate of SINGLE must evaluate to boolean, got {}.",
result.type());
throw QueryRuntimeException("Predicate of SINGLE must evaluate to boolean, got {}.", result.type());
}
if (result.type() == TypedValue::Type::Bool) {
has_value = true;
@ -551,8 +500,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
return TypedValue(ctx_->memory);
}
if (list_value.type() != TypedValue::Type::List) {
throw QueryRuntimeException("ANY expected a list, got {}.",
list_value.type());
throw QueryRuntimeException("ANY expected a list, got {}.", list_value.type());
}
const auto &list = list_value.ValueList();
const auto &symbol = symbol_table_->at(*any.identifier_);
@ -561,9 +509,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
frame_->at(symbol) = element;
auto result = any.where_->expression_->Accept(*this);
if (!result.IsNull() && result.type() != TypedValue::Type::Bool) {
throw QueryRuntimeException(
"Predicate of ANY must evaluate to boolean, got {}.",
result.type());
throw QueryRuntimeException("Predicate of ANY must evaluate to boolean, got {}.", result.type());
}
if (!result.IsNull()) {
has_value = true;
@ -586,8 +532,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
return TypedValue(ctx_->memory);
}
if (list_value.type() != TypedValue::Type::List) {
throw QueryRuntimeException("NONE expected a list, got {}.",
list_value.type());
throw QueryRuntimeException("NONE expected a list, got {}.", list_value.type());
}
const auto &list = list_value.ValueList();
const auto &symbol = symbol_table_->at(*none.identifier_);
@ -596,9 +541,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
frame_->at(symbol) = element;
auto result = none.where_->expression_->Accept(*this);
if (!result.IsNull() && result.type() != TypedValue::Type::Bool) {
throw QueryRuntimeException(
"Predicate of NONE must evaluate to boolean, got {}.",
result.type());
throw QueryRuntimeException("Predicate of NONE must evaluate to boolean, got {}.", result.type());
}
if (!result.IsNull()) {
has_value = true;
@ -616,9 +559,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
}
TypedValue Visit(ParameterLookup &param_lookup) override {
return TypedValue(
ctx_->parameters.AtTokenPosition(param_lookup.token_position_),
ctx_->memory);
return TypedValue(ctx_->parameters.AtTokenPosition(param_lookup.token_position_), ctx_->memory);
}
TypedValue Visit(RegexMatch &regex_match) override {
@ -628,9 +569,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
return TypedValue(ctx_->memory);
}
if (regex_value.type() != TypedValue::Type::String) {
throw QueryRuntimeException(
"Regular expression must evaluate to a string, got {}.",
regex_value.type());
throw QueryRuntimeException("Regular expression must evaluate to a string, got {}.", regex_value.type());
}
if (target_string_value.type() != TypedValue::Type::String) {
// Instead of error, we return Null which makes it compatible in case we
@ -643,75 +582,60 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
std::regex regex(regex_value.ValueString());
return TypedValue(std::regex_match(target_string, regex), ctx_->memory);
} catch (const std::regex_error &e) {
throw QueryRuntimeException("Regex error in '{}': {}",
regex_value.ValueString(), e.what());
throw QueryRuntimeException("Regex error in '{}': {}", regex_value.ValueString(), e.what());
}
}
private:
template <class TRecordAccessor>
storage::PropertyValue GetProperty(const TRecordAccessor &record_accessor,
PropertyIx prop) {
auto maybe_prop =
record_accessor.GetProperty(view_, ctx_->properties[prop.ix]);
if (maybe_prop.HasError() &&
maybe_prop.GetError() == storage::Error::NONEXISTENT_OBJECT) {
storage::PropertyValue GetProperty(const TRecordAccessor &record_accessor, PropertyIx prop) {
auto maybe_prop = record_accessor.GetProperty(view_, ctx_->properties[prop.ix]);
if (maybe_prop.HasError() && maybe_prop.GetError() == storage::Error::NONEXISTENT_OBJECT) {
// This is a very nasty and temporary hack in order to make MERGE work.
// The old storage had the following logic when returning an `OLD` view:
// `return old ? old : new`. That means that if the `OLD` view didn't
// exist, it returned the NEW view. With this hack we simulate that
// behavior.
// TODO (mferencevic, teon.banek): Remove once MERGE is reimplemented.
maybe_prop = record_accessor.GetProperty(storage::View::NEW,
ctx_->properties[prop.ix]);
maybe_prop = record_accessor.GetProperty(storage::View::NEW, ctx_->properties[prop.ix]);
}
if (maybe_prop.HasError()) {
switch (maybe_prop.GetError()) {
case storage::Error::DELETED_OBJECT:
throw QueryRuntimeException(
"Trying to get a property from a deleted object.");
throw QueryRuntimeException("Trying to get a property from a deleted object.");
case storage::Error::NONEXISTENT_OBJECT:
throw query::QueryRuntimeException(
"Trying to get a property from an object that doesn't exist.");
throw query::QueryRuntimeException("Trying to get a property from an object that doesn't exist.");
case storage::Error::SERIALIZATION_ERROR:
case storage::Error::VERTEX_HAS_EDGES:
case storage::Error::PROPERTIES_DISABLED:
throw QueryRuntimeException(
"Unexpected error when getting a property.");
throw QueryRuntimeException("Unexpected error when getting a property.");
}
}
return *maybe_prop;
}
template <class TRecordAccessor>
storage::PropertyValue GetProperty(const TRecordAccessor &record_accessor,
const std::string_view &name) {
auto maybe_prop =
record_accessor.GetProperty(view_, dba_->NameToProperty(name));
if (maybe_prop.HasError() &&
maybe_prop.GetError() == storage::Error::NONEXISTENT_OBJECT) {
storage::PropertyValue GetProperty(const TRecordAccessor &record_accessor, const std::string_view &name) {
auto maybe_prop = record_accessor.GetProperty(view_, dba_->NameToProperty(name));
if (maybe_prop.HasError() && maybe_prop.GetError() == storage::Error::NONEXISTENT_OBJECT) {
// This is a very nasty and temporary hack in order to make MERGE work.
// The old storage had the following logic when returning an `OLD` view:
// `return old ? old : new`. That means that if the `OLD` view didn't
// exist, it returned the NEW view. With this hack we simulate that
// behavior.
// TODO (mferencevic, teon.banek): Remove once MERGE is reimplemented.
maybe_prop =
record_accessor.GetProperty(view_, dba_->NameToProperty(name));
maybe_prop = record_accessor.GetProperty(view_, dba_->NameToProperty(name));
}
if (maybe_prop.HasError()) {
switch (maybe_prop.GetError()) {
case storage::Error::DELETED_OBJECT:
throw QueryRuntimeException(
"Trying to get a property from a deleted object.");
throw QueryRuntimeException("Trying to get a property from a deleted object.");
case storage::Error::NONEXISTENT_OBJECT:
throw query::QueryRuntimeException(
"Trying to get a property from an object that doesn't exist.");
throw query::QueryRuntimeException("Trying to get a property from an object that doesn't exist.");
case storage::Error::SERIALIZATION_ERROR:
case storage::Error::VERTEX_HAS_EDGES:
case storage::Error::PROPERTIES_DISABLED:
throw QueryRuntimeException(
"Unexpected error when getting a property.");
throw QueryRuntimeException("Unexpected error when getting a property.");
}
}
return *maybe_prop;
@ -732,8 +656,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
/// @param what - Name of what's getting evaluated. Used for user feedback (via
/// exception) when the evaluated value is not an int.
/// @throw QueryRuntimeException if expression doesn't evaluate to an int.
inline int64_t EvaluateInt(ExpressionEvaluator *evaluator, Expression *expr,
const std::string &what) {
inline int64_t EvaluateInt(ExpressionEvaluator *evaluator, Expression *expr, const std::string &what) {
TypedValue value = expr->Accept(*evaluator);
try {
return value.ValueInt();

View File

@ -13,31 +13,19 @@ namespace query {
class Frame {
public:
/// Create a Frame of given size backed by a utils::NewDeleteResource()
explicit Frame(int64_t size) : elems_(size, utils::NewDeleteResource()) {
MG_ASSERT(size >= 0);
}
explicit Frame(int64_t size) : elems_(size, utils::NewDeleteResource()) { MG_ASSERT(size >= 0); }
Frame(int64_t size, utils::MemoryResource *memory) : elems_(size, memory) {
MG_ASSERT(size >= 0);
}
Frame(int64_t size, utils::MemoryResource *memory) : elems_(size, memory) { MG_ASSERT(size >= 0); }
TypedValue &operator[](const Symbol &symbol) {
return elems_[symbol.position()];
}
const TypedValue &operator[](const Symbol &symbol) const {
return elems_[symbol.position()];
}
TypedValue &operator[](const Symbol &symbol) { return elems_[symbol.position()]; }
const TypedValue &operator[](const Symbol &symbol) const { return elems_[symbol.position()]; }
TypedValue &at(const Symbol &symbol) { return elems_.at(symbol.position()); }
const TypedValue &at(const Symbol &symbol) const {
return elems_.at(symbol.position());
}
const TypedValue &at(const Symbol &symbol) const { return elems_.at(symbol.position()); }
auto &elems() { return elems_; }
utils::MemoryResource *GetMemoryResource() const {
return elems_.get_allocator().GetMemoryResource();
}
utils::MemoryResource *GetMemoryResource() const { return elems_.get_allocator().GetMemoryResource(); }
private:
utils::pmr::vector<TypedValue> elems_;

File diff suppressed because it is too large Load Diff

View File

@ -39,16 +39,14 @@ class AuthQueryHandler {
/// Return false if the user already exists.
/// @throw QueryRuntimeException if an error ocurred.
virtual bool CreateUser(const std::string &username,
const std::optional<std::string> &password) = 0;
virtual bool CreateUser(const std::string &username, const std::optional<std::string> &password) = 0;
/// Return false if the user does not exist.
/// @throw QueryRuntimeException if an error ocurred.
virtual bool DropUser(const std::string &username) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void SetPassword(const std::string &username,
const std::optional<std::string> &password) = 0;
virtual void SetPassword(const std::string &username, const std::optional<std::string> &password) = 0;
/// Return false if the role already exists.
/// @throw QueryRuntimeException if an error ocurred.
@ -65,37 +63,28 @@ class AuthQueryHandler {
virtual std::vector<TypedValue> GetRolenames() = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual std::optional<std::string> GetRolenameForUser(
const std::string &username) = 0;
virtual std::optional<std::string> GetRolenameForUser(const std::string &username) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual std::vector<TypedValue> GetUsernamesForRole(
const std::string &rolename) = 0;
virtual std::vector<TypedValue> GetUsernamesForRole(const std::string &rolename) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void SetRole(const std::string &username,
const std::string &rolename) = 0;
virtual void SetRole(const std::string &username, const std::string &rolename) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void ClearRole(const std::string &username) = 0;
virtual std::vector<std::vector<TypedValue>> GetPrivileges(
const std::string &user_or_role) = 0;
virtual std::vector<std::vector<TypedValue>> GetPrivileges(const std::string &user_or_role) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void GrantPrivilege(
const std::string &user_or_role,
const std::vector<AuthQuery::Privilege> &privileges) = 0;
virtual void GrantPrivilege(const std::string &user_or_role, const std::vector<AuthQuery::Privilege> &privileges) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void DenyPrivilege(
const std::string &user_or_role,
const std::vector<AuthQuery::Privilege> &privileges) = 0;
virtual void DenyPrivilege(const std::string &user_or_role, const std::vector<AuthQuery::Privilege> &privileges) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void RevokePrivilege(
const std::string &user_or_role,
const std::vector<AuthQuery::Privilege> &privileges) = 0;
virtual void RevokePrivilege(const std::string &user_or_role,
const std::vector<AuthQuery::Privilege> &privileges) = 0;
};
enum class QueryHandlerResult { COMMIT, ABORT, NOTHING };
@ -119,18 +108,14 @@ class ReplicationQueryHandler {
};
/// @throw QueryRuntimeException if an error ocurred.
virtual void SetReplicationRole(
ReplicationQuery::ReplicationRole replication_role,
std::optional<int64_t> port) = 0;
virtual void SetReplicationRole(ReplicationQuery::ReplicationRole replication_role, std::optional<int64_t> port) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual ReplicationQuery::ReplicationRole ShowReplicationRole() const = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void RegisterReplica(const std::string &name,
const std::string &socket_address,
const ReplicationQuery::SyncMode sync_mode,
const std::optional<double> timeout) = 0;
virtual void RegisterReplica(const std::string &name, const std::string &socket_address,
const ReplicationQuery::SyncMode sync_mode, const std::optional<double> timeout) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void DropReplica(const std::string &replica_name) = 0;
@ -145,9 +130,7 @@ class ReplicationQueryHandler {
struct PreparedQuery {
std::vector<std::string> header;
std::vector<AuthQuery::Privilege> privileges;
std::function<std::optional<QueryHandlerResult>(AnyStream *stream,
std::optional<int> n)>
query_handler;
std::function<std::optional<QueryHandlerResult>(AnyStream *stream, std::optional<int> n)> query_handler;
plan::ReadWriteTypeChecker::RWType rw_type;
};
@ -172,10 +155,7 @@ class CachedPlan {
const auto &symbol_table() const { return plan_->GetSymbolTable(); }
const auto &ast_storage() const { return plan_->GetAstStorage(); }
bool IsExpired() const {
return cache_timer_.Elapsed() >
std::chrono::seconds(FLAGS_query_plan_cache_ttl);
};
bool IsExpired() const { return cache_timer_.Elapsed() > std::chrono::seconds(FLAGS_query_plan_cache_ttl); };
private:
std::unique_ptr<LogicalPlan> plan_;
@ -189,12 +169,8 @@ struct CachedQuery {
};
struct QueryCacheEntry {
bool operator==(const QueryCacheEntry &other) const {
return first == other.first;
}
bool operator<(const QueryCacheEntry &other) const {
return first < other.first;
}
bool operator==(const QueryCacheEntry &other) const { return first == other.first; }
bool operator<(const QueryCacheEntry &other) const { return first < other.first; }
bool operator==(const uint64_t &other) const { return first == other; }
bool operator<(const uint64_t &other) const { return first < other; }
@ -205,12 +181,8 @@ struct QueryCacheEntry {
};
struct PlanCacheEntry {
bool operator==(const PlanCacheEntry &other) const {
return first == other.first;
}
bool operator<(const PlanCacheEntry &other) const {
return first < other.first;
}
bool operator==(const PlanCacheEntry &other) const { return first == other.first; }
bool operator<(const PlanCacheEntry &other) const { return first < other.first; }
bool operator==(const uint64_t &other) const { return first == other; }
bool operator<(const uint64_t &other) const { return first < other; }
@ -228,9 +200,7 @@ struct PlanCacheEntry {
* been passed to an `Interpreter` instance.
*/
struct InterpreterContext {
explicit InterpreterContext(storage::Storage *db) : db(db) {
MG_ASSERT(db, "Storage must not be NULL");
}
explicit InterpreterContext(storage::Storage *db) : db(db) { MG_ASSERT(db, "Storage must not be NULL"); }
storage::Storage *db;
@ -254,9 +224,7 @@ struct InterpreterContext {
/// Function that is used to tell all active interpreters that they should stop
/// their ongoing execution.
inline void Shutdown(InterpreterContext *context) {
context->is_shutting_down.store(true, std::memory_order_release);
}
inline void Shutdown(InterpreterContext *context) { context->is_shutting_down.store(true, std::memory_order_release); }
/// Function used to set the maximum execution timeout in seconds.
inline void SetExecutionTimeout(InterpreterContext *context, double timeout) {
@ -286,9 +254,7 @@ class Interpreter final {
*
* @throw query::QueryException
*/
PrepareResult Prepare(
const std::string &query,
const std::map<std::string, storage::PropertyValue> &params);
PrepareResult Prepare(const std::string &query, const std::map<std::string, storage::PropertyValue> &params);
/**
* Execute the last prepared query and stream *all* of the results into the
@ -329,8 +295,7 @@ class Interpreter final {
* @throw query::QueryException
*/
template <typename TStream>
std::map<std::string, TypedValue> Pull(TStream *result_stream,
std::optional<int> n = {},
std::map<std::string, TypedValue> Pull(TStream *result_stream, std::optional<int> n = {},
std::optional<int> qid = {});
void BeginTransaction();
@ -392,35 +357,27 @@ class Interpreter final {
size_t ActiveQueryExecutions() {
return std::count_if(query_executions_.begin(), query_executions_.end(),
[](const auto &execution) {
return execution && execution->prepared_query;
});
[](const auto &execution) { return execution && execution->prepared_query; });
}
};
template <typename TStream>
std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream,
std::optional<int> n,
std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream, std::optional<int> n,
std::optional<int> qid) {
MG_ASSERT(in_explicit_transaction_ || !qid,
"qid can be only used in explicit transaction!");
const int qid_value =
qid ? *qid : static_cast<int>(query_executions_.size() - 1);
MG_ASSERT(in_explicit_transaction_ || !qid, "qid can be only used in explicit transaction!");
const int qid_value = qid ? *qid : static_cast<int>(query_executions_.size() - 1);
if (qid_value < 0 || qid_value >= query_executions_.size()) {
throw InvalidArgumentsException("qid",
"Query with specified ID does not exist!");
throw InvalidArgumentsException("qid", "Query with specified ID does not exist!");
}
if (n && n < 0) {
throw InvalidArgumentsException("n",
"Cannot fetch negative number of results!");
throw InvalidArgumentsException("n", "Cannot fetch negative number of results!");
}
auto &query_execution = query_executions_[qid_value];
MG_ASSERT(query_execution && query_execution->prepared_query,
"Query already finished executing!");
MG_ASSERT(query_execution && query_execution->prepared_query, "Query already finished executing!");
// Each prepared query has its own summary so we need to somehow preserve
// it after it finishes executing because it gets destroyed alongside
@ -430,8 +387,7 @@ std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream,
// Wrap the (statically polymorphic) stream type into a common type which
// the handler knows.
AnyStream stream{result_stream, &query_execution->execution_memory};
const auto maybe_res =
query_execution->prepared_query->query_handler(&stream, n);
const auto maybe_res = query_execution->prepared_query->query_handler(&stream, n);
// Stream is using execution memory of the query_execution which
// can be deleted after its execution so the stream should be cleared
// first.

View File

@ -21,9 +21,7 @@ struct Parameters {
* @param position Token position in query of value.
* @param value
*/
void Add(int position, const storage::PropertyValue &value) {
storage_.emplace_back(position, value);
}
void Add(int position, const storage::PropertyValue &value) { storage_.emplace_back(position, value); }
/**
* Returns the value found for the given token position.
@ -32,11 +30,8 @@ struct Parameters {
* @return Value for the given token position.
*/
const storage::PropertyValue &AtTokenPosition(int position) const {
auto found =
std::find_if(storage_.begin(), storage_.end(),
[&](const auto &a) { return a.first == position; });
MG_ASSERT(found != storage_.end(),
"Token position must be present in container");
auto found = std::find_if(storage_.begin(), storage_.end(), [&](const auto &a) { return a.first == position; });
MG_ASSERT(found != storage_.end(), "Token position must be present in container");
return found->second;
}

View File

@ -24,8 +24,7 @@ class Path {
* Create the path starting with the given vertex.
* Allocations are done using the given MemoryResource.
*/
explicit Path(const VertexAccessor &vertex,
utils::MemoryResource *memory = utils::NewDeleteResource())
explicit Path(const VertexAccessor &vertex, utils::MemoryResource *memory = utils::NewDeleteResource())
: vertices_(memory), edges_(memory) {
Expand(vertex);
}
@ -37,8 +36,7 @@ class Path {
*/
template <typename... TOthers>
explicit Path(const VertexAccessor &vertex, const TOthers &...others)
: vertices_(utils::NewDeleteResource()),
edges_(utils::NewDeleteResource()) {
: vertices_(utils::NewDeleteResource()), edges_(utils::NewDeleteResource()) {
Expand(vertex);
Expand(others...);
}
@ -49,8 +47,7 @@ class Path {
* Allocations are done using the given MemoryResource.
*/
template <typename... TOthers>
Path(std::allocator_arg_t, utils::MemoryResource *memory,
const VertexAccessor &vertex, const TOthers &...others)
Path(std::allocator_arg_t, utils::MemoryResource *memory, const VertexAccessor &vertex, const TOthers &...others)
: vertices_(memory), edges_(memory) {
Expand(vertex);
Expand(others...);
@ -65,10 +62,9 @@ class Path {
* will default to utils::NewDeleteResource().
*/
Path(const Path &other)
: Path(other, std::allocator_traits<allocator_type>::
select_on_container_copy_construction(
other.GetMemoryResource())
.GetMemoryResource()) {}
: Path(other,
std::allocator_traits<allocator_type>::select_on_container_copy_construction(other.GetMemoryResource())
.GetMemoryResource()) {}
/** Construct a copy using the given utils::MemoryResource */
Path(const Path &other, utils::MemoryResource *memory)
@ -79,8 +75,7 @@ class Path {
* utils::MemoryResource is obtained from other. After the move, other will be
* empty.
*/
Path(Path &&other) noexcept
: Path(std::move(other), other.GetMemoryResource()) {}
Path(Path &&other) noexcept : Path(std::move(other), other.GetMemoryResource()) {}
/**
* Construct with the value of other, but use the given utils::MemoryResource.
@ -89,8 +84,7 @@ class Path {
* performed.
*/
Path(Path &&other, utils::MemoryResource *memory)
: vertices_(std::move(other.vertices_), memory),
edges_(std::move(other.edges_), memory) {}
: vertices_(std::move(other.vertices_), memory), edges_(std::move(other.edges_), memory) {}
/** Copy assign other, utils::MemoryResource of `this` is used */
Path &operator=(const Path &) = default;
@ -102,15 +96,13 @@ class Path {
/** Expands the path with the given vertex. */
void Expand(const VertexAccessor &vertex) {
DMG_ASSERT(vertices_.size() == edges_.size(),
"Illegal path construction order");
DMG_ASSERT(vertices_.size() == edges_.size(), "Illegal path construction order");
vertices_.emplace_back(vertex);
}
/** Expands the path with the given edge. */
void Expand(const EdgeAccessor &edge) {
DMG_ASSERT(vertices_.size() - 1 == edges_.size(),
"Illegal path construction order");
DMG_ASSERT(vertices_.size() - 1 == edges_.size(), "Illegal path construction order");
edges_.emplace_back(edge);
}
@ -129,13 +121,9 @@ class Path {
const auto &vertices() const { return vertices_; }
const auto &edges() const { return edges_; }
utils::MemoryResource *GetMemoryResource() const {
return vertices_.get_allocator().GetMemoryResource();
}
utils::MemoryResource *GetMemoryResource() const { return vertices_.get_allocator().GetMemoryResource(); }
bool operator==(const Path &other) const {
return vertices_ == other.vertices_ && edges_ == other.edges_;
}
bool operator==(const Path &other) const { return vertices_ == other.vertices_ && edges_ == other.edges_; }
private:
// Contains all the vertices in the path.

View File

@ -91,13 +91,10 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor {
double factor = 1.0;
if (property_value)
// get the exact influence based on ScanAll(label, property, value)
factor = db_accessor_->VerticesCount(
logical_op.label_, logical_op.property_, property_value.value());
factor = db_accessor_->VerticesCount(logical_op.label_, logical_op.property_, property_value.value());
else
// estimate the influence as ScanAll(label, property) * filtering
factor =
db_accessor_->VerticesCount(logical_op.label_, logical_op.property_) *
CardParam::kFilter;
factor = db_accessor_->VerticesCount(logical_op.label_, logical_op.property_) * CardParam::kFilter;
cardinality_ *= factor;
@ -115,18 +112,14 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor {
int64_t factor = 1;
if (upper || lower)
// if we have either Bound<PropertyValue>, use the value index
factor = db_accessor_->VerticesCount(logical_op.label_,
logical_op.property_, lower, upper);
factor = db_accessor_->VerticesCount(logical_op.label_, logical_op.property_, lower, upper);
else
// no values, but we still have the label
factor =
db_accessor_->VerticesCount(logical_op.label_, logical_op.property_);
factor = db_accessor_->VerticesCount(logical_op.label_, logical_op.property_);
// if we failed to take either bound from the op into account, then apply
// the filtering constant to the factor
if ((logical_op.upper_bound_ && !upper) ||
(logical_op.lower_bound_ && !lower))
factor *= CardParam::kFilter;
if ((logical_op.upper_bound_ && !upper) || (logical_op.lower_bound_ && !lower)) factor *= CardParam::kFilter;
cardinality_ *= factor;
@ -136,8 +129,7 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor {
}
bool PostVisit(ScanAllByLabelProperty &logical_op) override {
const auto factor =
db_accessor_->VerticesCount(logical_op.label_, logical_op.property_);
const auto factor = db_accessor_->VerticesCount(logical_op.label_, logical_op.property_);
cardinality_ *= factor;
IncrementCost(CostParam::MakeScanAllByLabelProperty);
return true;
@ -181,8 +173,7 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor {
// if the Unwind expression is a list literal, we can deduce cardinality
// exactly, otherwise we approximate
double unwind_value;
if (auto *literal =
utils::Downcast<query::ListLiteral>(unwind.input_expression_))
if (auto *literal = utils::Downcast<query::ListLiteral>(unwind.input_expression_))
unwind_value = literal->elements_.size();
else
unwind_value = MiscParam::kUnwindNoLiteral;
@ -218,21 +209,17 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor {
std::optional<ScanAllByLabelPropertyRange::Bound> bound) {
if (bound) {
auto property_value = ConstPropertyValue(bound->value());
if (property_value)
return utils::Bound<storage::PropertyValue>(*property_value,
bound->type());
if (property_value) return utils::Bound<storage::PropertyValue>(*property_value, bound->type());
}
return std::nullopt;
}
// If the expression is a constant property value, it is returned. Otherwise,
// return nullopt.
std::optional<storage::PropertyValue> ConstPropertyValue(
const Expression *expression) {
std::optional<storage::PropertyValue> ConstPropertyValue(const Expression *expression) {
if (auto *literal = utils::Downcast<const PrimitiveLiteral>(expression)) {
return literal->value_;
} else if (auto *param_lookup =
utils::Downcast<const ParameterLookup>(expression)) {
} else if (auto *param_lookup = utils::Downcast<const ParameterLookup>(expression)) {
return parameters.AtTokenPosition(param_lookup->token_position_);
}
return std::nullopt;
@ -241,8 +228,7 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor {
/** Returns the estimated cost of the given plan. */
template <class TDbAccessor>
double EstimatePlanCost(TDbAccessor *db, const Parameters &parameters,
LogicalOperator &plan) {
double EstimatePlanCost(TDbAccessor *db, const Parameters &parameters, LogicalOperator &plan) {
CostEstimator<TDbAccessor> estimator(db, parameters);
plan.Accept(estimator);
return estimator.cost();

File diff suppressed because it is too large Load Diff

View File

@ -28,38 +28,31 @@ class PostProcessor final {
public:
using ProcessedPlan = std::unique_ptr<LogicalOperator>;
explicit PostProcessor(const Parameters &parameters)
: parameters_(parameters) {}
explicit PostProcessor(const Parameters &parameters) : parameters_(parameters) {}
template <class TPlanningContext>
std::unique_ptr<LogicalOperator> Rewrite(
std::unique_ptr<LogicalOperator> plan, TPlanningContext *context) {
return RewriteWithIndexLookup(std::move(plan), context->symbol_table,
context->ast_storage, context->db);
std::unique_ptr<LogicalOperator> Rewrite(std::unique_ptr<LogicalOperator> plan, TPlanningContext *context) {
return RewriteWithIndexLookup(std::move(plan), context->symbol_table, context->ast_storage, context->db);
}
template <class TVertexCounts>
double EstimatePlanCost(const std::unique_ptr<LogicalOperator> &plan,
TVertexCounts *vertex_counts) {
double EstimatePlanCost(const std::unique_ptr<LogicalOperator> &plan, TVertexCounts *vertex_counts) {
return ::query::plan::EstimatePlanCost(vertex_counts, parameters_, *plan);
}
template <class TPlanningContext>
std::unique_ptr<LogicalOperator> MergeWithCombinator(
std::unique_ptr<LogicalOperator> curr_op,
std::unique_ptr<LogicalOperator> last_op, const Tree &combinator,
TPlanningContext *context) {
std::unique_ptr<LogicalOperator> MergeWithCombinator(std::unique_ptr<LogicalOperator> curr_op,
std::unique_ptr<LogicalOperator> last_op, const Tree &combinator,
TPlanningContext *context) {
if (const auto *union_ = utils::Downcast<const CypherUnion>(&combinator)) {
return std::unique_ptr<LogicalOperator>(
impl::GenUnion(*union_, std::move(last_op), std::move(curr_op),
*context->symbol_table));
impl::GenUnion(*union_, std::move(last_op), std::move(curr_op), *context->symbol_table));
}
throw utils::NotYetImplemented("query combinator");
}
template <class TPlanningContext>
std::unique_ptr<LogicalOperator> MakeDistinct(
std::unique_ptr<LogicalOperator> last_op, TPlanningContext *context) {
std::unique_ptr<LogicalOperator> MakeDistinct(std::unique_ptr<LogicalOperator> last_op, TPlanningContext *context) {
auto output_symbols = last_op->OutputSymbols(*context->symbol_table);
return std::make_unique<Distinct>(std::move(last_op), output_symbols);
}
@ -78,12 +71,10 @@ class PostProcessor final {
/// @sa RuleBasedPlanner
/// @sa VariableStartPlanner
template <template <class> class TPlanner, class TDbAccessor>
auto MakeLogicalPlanForSingleQuery(
std::vector<SingleQueryPart> single_query_parts,
PlanningContext<TDbAccessor> *context) {
auto MakeLogicalPlanForSingleQuery(std::vector<SingleQueryPart> single_query_parts,
PlanningContext<TDbAccessor> *context) {
context->bound_symbols.clear();
return TPlanner<PlanningContext<TDbAccessor>>(context).Plan(
single_query_parts);
return TPlanner<PlanningContext<TDbAccessor>>(context).Plan(single_query_parts);
}
/// Generates the LogicalOperator tree and returns the resulting plan.
@ -98,10 +89,8 @@ auto MakeLogicalPlanForSingleQuery(
/// @return pair consisting of the final `TPlanPostProcess::ProcessedPlan` and
/// the estimated cost of that plan as a `double`.
template <class TPlanningContext, class TPlanPostProcess>
auto MakeLogicalPlan(TPlanningContext *context, TPlanPostProcess *post_process,
bool use_variable_planner) {
auto query_parts = CollectQueryParts(*context->symbol_table,
*context->ast_storage, context->query);
auto MakeLogicalPlan(TPlanningContext *context, TPlanPostProcess *post_process, bool use_variable_planner) {
auto query_parts = CollectQueryParts(*context->symbol_table, *context->ast_storage, context->query);
auto &vertex_counts = *context->db;
double total_cost = 0;
@ -113,22 +102,19 @@ auto MakeLogicalPlan(TPlanningContext *context, TPlanPostProcess *post_process,
double min_cost = std::numeric_limits<double>::max();
if (use_variable_planner) {
auto plans = MakeLogicalPlanForSingleQuery<VariableStartPlanner>(
query_part.single_query_parts, context);
auto plans = MakeLogicalPlanForSingleQuery<VariableStartPlanner>(query_part.single_query_parts, context);
for (auto plan : plans) {
// Plans are generated lazily and the current plan will disappear, so
// it's ok to move it.
auto rewritten_plan = post_process->Rewrite(std::move(plan), context);
double cost =
post_process->EstimatePlanCost(rewritten_plan, &vertex_counts);
double cost = post_process->EstimatePlanCost(rewritten_plan, &vertex_counts);
if (!curr_plan || cost < min_cost) {
curr_plan.emplace(std::move(rewritten_plan));
min_cost = cost;
}
}
} else {
auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>(
query_part.single_query_parts, context);
auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>(query_part.single_query_parts, context);
auto rewritten_plan = post_process->Rewrite(std::move(plan), context);
min_cost = post_process->EstimatePlanCost(rewritten_plan, &vertex_counts);
curr_plan.emplace(std::move(rewritten_plan));
@ -136,9 +122,8 @@ auto MakeLogicalPlan(TPlanningContext *context, TPlanPostProcess *post_process,
total_cost += min_cost;
if (query_part.query_combinator) {
last_plan = post_process->MergeWithCombinator(
std::move(*curr_plan), std::move(last_plan),
*query_part.query_combinator, context);
last_plan = post_process->MergeWithCombinator(std::move(*curr_plan), std::move(last_plan),
*query_part.query_combinator, context);
} else {
last_plan = std::move(*curr_plan);
}
@ -152,8 +137,7 @@ auto MakeLogicalPlan(TPlanningContext *context, TPlanPostProcess *post_process,
}
template <class TPlanningContext>
auto MakeLogicalPlan(TPlanningContext *context, const Parameters &parameters,
bool use_variable_planner) {
auto MakeLogicalPlan(TPlanningContext *context, const Parameters &parameters, bool use_variable_planner) {
PostProcessor post_processor(parameters);
return MakeLogicalPlan(context, &post_processor, use_variable_planner);
}

View File

@ -8,9 +8,8 @@ namespace query::plan {
namespace {
void ForEachPattern(
Pattern &pattern, std::function<void(NodeAtom *)> base,
std::function<void(NodeAtom *, EdgeAtom *, NodeAtom *)> collect) {
void ForEachPattern(Pattern &pattern, std::function<void(NodeAtom *)> base,
std::function<void(NodeAtom *, EdgeAtom *, NodeAtom *)> collect) {
DMG_ASSERT(!pattern.atoms_.empty(), "Missing atoms in pattern");
auto atoms_it = pattern.atoms_.begin();
auto current_node = utils::Downcast<NodeAtom>(*atoms_it++);
@ -20,8 +19,7 @@ void ForEachPattern(
while (atoms_it != pattern.atoms_.end()) {
auto edge = utils::Downcast<EdgeAtom>(*atoms_it++);
DMG_ASSERT(edge, "Expected an edge atom in pattern.");
DMG_ASSERT(atoms_it != pattern.atoms_.end(),
"Edge atom should not end the pattern.");
DMG_ASSERT(atoms_it != pattern.atoms_.end(), "Edge atom should not end the pattern.");
auto prev_node = current_node;
current_node = utils::Downcast<NodeAtom>(*atoms_it++);
DMG_ASSERT(current_node, "Expected a node atom in pattern.");
@ -37,32 +35,24 @@ void ForEachPattern(
// (m) -[e]- (n), (n) -[f]- (o).
// This representation makes it easier to permute from which node or edge we
// want to start expanding.
std::vector<Expansion> NormalizePatterns(
const SymbolTable &symbol_table, const std::vector<Pattern *> &patterns) {
std::vector<Expansion> NormalizePatterns(const SymbolTable &symbol_table, const std::vector<Pattern *> &patterns) {
std::vector<Expansion> expansions;
auto ignore_node = [&](auto *) {};
auto collect_expansion = [&](auto *prev_node, auto *edge,
auto *current_node) {
auto collect_expansion = [&](auto *prev_node, auto *edge, auto *current_node) {
UsedSymbolsCollector collector(symbol_table);
if (edge->IsVariable()) {
if (edge->lower_bound_) edge->lower_bound_->Accept(collector);
if (edge->upper_bound_) edge->upper_bound_->Accept(collector);
if (edge->filter_lambda_.expression)
edge->filter_lambda_.expression->Accept(collector);
if (edge->filter_lambda_.expression) edge->filter_lambda_.expression->Accept(collector);
// Remove symbols which are bound by lambda arguments.
collector.symbols_.erase(
symbol_table.at(*edge->filter_lambda_.inner_edge));
collector.symbols_.erase(
symbol_table.at(*edge->filter_lambda_.inner_node));
collector.symbols_.erase(symbol_table.at(*edge->filter_lambda_.inner_edge));
collector.symbols_.erase(symbol_table.at(*edge->filter_lambda_.inner_node));
if (edge->type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) {
collector.symbols_.erase(
symbol_table.at(*edge->weight_lambda_.inner_edge));
collector.symbols_.erase(
symbol_table.at(*edge->weight_lambda_.inner_node));
collector.symbols_.erase(symbol_table.at(*edge->weight_lambda_.inner_edge));
collector.symbols_.erase(symbol_table.at(*edge->weight_lambda_.inner_node));
}
}
expansions.emplace_back(Expansion{prev_node, edge, edge->direction_, false,
collector.symbols_, current_node});
expansions.emplace_back(Expansion{prev_node, edge, edge->direction_, false, collector.symbols_, current_node});
};
for (const auto &pattern : patterns) {
if (pattern->atoms_.size() == 1U) {
@ -81,8 +71,7 @@ std::vector<Expansion> NormalizePatterns(
// as well as edge symbols which determine Cyphermorphism. Collecting filters
// will lift them out of a pattern and generate new expressions (just like they
// were in a Where clause).
void AddMatching(const std::vector<Pattern *> &patterns, Where *where,
SymbolTable &symbol_table, AstStorage &storage,
void AddMatching(const std::vector<Pattern *> &patterns, Where *where, SymbolTable &symbol_table, AstStorage &storage,
Matching &matching) {
auto expansions = NormalizePatterns(symbol_table, patterns);
std::unordered_set<Symbol> edge_symbols;
@ -116,18 +105,15 @@ void AddMatching(const std::vector<Pattern *> &patterns, Where *where,
std::vector<Symbol> path_elements;
for (auto *pattern_atom : pattern->atoms_)
path_elements.emplace_back(symbol_table.at(*pattern_atom->identifier_));
matching.named_paths.emplace(symbol_table.at(*pattern->identifier_),
std::move(path_elements));
matching.named_paths.emplace(symbol_table.at(*pattern->identifier_), std::move(path_elements));
}
}
if (where) {
matching.filters.CollectWhereFilter(*where, symbol_table);
}
}
void AddMatching(const Match &match, SymbolTable &symbol_table,
AstStorage &storage, Matching &matching) {
return AddMatching(match.patterns_, match.where_, symbol_table, storage,
matching);
void AddMatching(const Match &match, SymbolTable &symbol_table, AstStorage &storage, Matching &matching) {
return AddMatching(match.patterns_, match.where_, symbol_table, storage, matching);
}
auto SplitExpressionOnAnd(Expression *expression) {
@ -151,8 +137,7 @@ auto SplitExpressionOnAnd(Expression *expression) {
} // namespace
PropertyFilter::PropertyFilter(const SymbolTable &symbol_table,
const Symbol &symbol, PropertyIx property,
PropertyFilter::PropertyFilter(const SymbolTable &symbol_table, const Symbol &symbol, PropertyIx property,
Expression *value, Type type)
: symbol_(symbol), property_(property), type_(type), value_(value) {
MG_ASSERT(type != Type::RANGE);
@ -161,15 +146,10 @@ PropertyFilter::PropertyFilter(const SymbolTable &symbol_table,
is_symbol_in_value_ = utils::Contains(collector.symbols_, symbol);
}
PropertyFilter::PropertyFilter(
const SymbolTable &symbol_table, const Symbol &symbol, PropertyIx property,
const std::optional<PropertyFilter::Bound> &lower_bound,
const std::optional<PropertyFilter::Bound> &upper_bound)
: symbol_(symbol),
property_(property),
type_(Type::RANGE),
lower_bound_(lower_bound),
upper_bound_(upper_bound) {
PropertyFilter::PropertyFilter(const SymbolTable &symbol_table, const Symbol &symbol, PropertyIx property,
const std::optional<PropertyFilter::Bound> &lower_bound,
const std::optional<PropertyFilter::Bound> &upper_bound)
: symbol_(symbol), property_(property), type_(Type::RANGE), lower_bound_(lower_bound), upper_bound_(upper_bound) {
UsedSymbolsCollector collector(symbol_table);
if (lower_bound) {
lower_bound->value()->Accept(collector);
@ -180,8 +160,7 @@ PropertyFilter::PropertyFilter(
is_symbol_in_value_ = utils::Contains(collector.symbols_, symbol);
}
PropertyFilter::PropertyFilter(const Symbol &symbol, PropertyIx property,
Type type)
PropertyFilter::PropertyFilter(const Symbol &symbol, PropertyIx property, Type type)
: symbol_(symbol), property_(property), type_(type) {
// As this constructor is used for property filters where
// we don't have to evaluate the filter expression, we set
@ -190,8 +169,7 @@ PropertyFilter::PropertyFilter(const Symbol &symbol, PropertyIx property,
// we may be looking up.
}
IdFilter::IdFilter(const SymbolTable &symbol_table, const Symbol &symbol,
Expression *value)
IdFilter::IdFilter(const SymbolTable &symbol_table, const Symbol &symbol, Expression *value)
: symbol_(symbol), value_(value) {
MG_ASSERT(value);
UsedSymbolsCollector collector(symbol_table);
@ -203,16 +181,12 @@ void Filters::EraseFilter(const FilterInfo &filter) {
// TODO: Ideally, we want to determine the equality of both expression trees,
// instead of a simple pointer compare.
all_filters_.erase(std::remove_if(all_filters_.begin(), all_filters_.end(),
[&filter](const auto &f) {
return f.expression == filter.expression;
}),
[&filter](const auto &f) { return f.expression == filter.expression; }),
all_filters_.end());
}
void Filters::EraseLabelFilter(const Symbol &symbol, LabelIx label,
std::vector<Expression *> *removed_filters) {
for (auto filter_it = all_filters_.begin();
filter_it != all_filters_.end();) {
void Filters::EraseLabelFilter(const Symbol &symbol, LabelIx label, std::vector<Expression *> *removed_filters) {
for (auto filter_it = all_filters_.begin(); filter_it != all_filters_.end();) {
if (filter_it->type != FilterInfo::Type::Label) {
++filter_it;
continue;
@ -221,15 +195,13 @@ void Filters::EraseLabelFilter(const Symbol &symbol, LabelIx label,
++filter_it;
continue;
}
auto label_it =
std::find(filter_it->labels.begin(), filter_it->labels.end(), label);
auto label_it = std::find(filter_it->labels.begin(), filter_it->labels.end(), label);
if (label_it == filter_it->labels.end()) {
++filter_it;
continue;
}
filter_it->labels.erase(label_it);
DMG_ASSERT(!utils::Contains(filter_it->labels, label),
"Didn't expect duplicated labels");
DMG_ASSERT(!utils::Contains(filter_it->labels, label), "Didn't expect duplicated labels");
if (filter_it->labels.empty()) {
// If there are no labels to filter, then erase the whole FilterInfo.
if (removed_filters) {
@ -242,8 +214,7 @@ void Filters::EraseLabelFilter(const Symbol &symbol, LabelIx label,
}
}
void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
AstStorage &storage) {
void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table, AstStorage &storage) {
UsedSymbolsCollector collector(symbol_table);
auto add_properties_variable = [&](EdgeAtom *atom) {
const auto &symbol = symbol_table.at(*atom->identifier_);
@ -256,18 +227,13 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
{
collector.symbols_.clear();
prop_pair.second->Accept(collector);
collector.symbols_.emplace(
symbol_table.at(*atom->filter_lambda_.inner_node));
collector.symbols_.emplace(
symbol_table.at(*atom->filter_lambda_.inner_edge));
collector.symbols_.emplace(symbol_table.at(*atom->filter_lambda_.inner_node));
collector.symbols_.emplace(symbol_table.at(*atom->filter_lambda_.inner_edge));
// First handle the inline property filter.
auto *property_lookup = storage.Create<PropertyLookup>(
atom->filter_lambda_.inner_edge, prop_pair.first);
auto *prop_equal =
storage.Create<EqualOperator>(property_lookup, prop_pair.second);
auto *property_lookup = storage.Create<PropertyLookup>(atom->filter_lambda_.inner_edge, prop_pair.first);
auto *prop_equal = storage.Create<EqualOperator>(property_lookup, prop_pair.second);
// Currently, variable expand has no gains if we set PropertyFilter.
all_filters_.emplace_back(FilterInfo{FilterInfo::Type::Generic,
prop_equal, collector.symbols_});
all_filters_.emplace_back(FilterInfo{FilterInfo::Type::Generic, prop_equal, collector.symbols_});
}
{
collector.symbols_.clear();
@ -275,23 +241,15 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
collector.symbols_.insert(symbol); // PropertyLookup uses the symbol.
// Now handle the post-expansion filter.
// Create a new identifier and a symbol which will be filled in All.
auto *identifier =
storage
.Create<Identifier>(atom->identifier_->name_,
atom->identifier_->user_declared_)
->MapTo(
symbol_table.CreateSymbol(atom->identifier_->name_, false));
auto *identifier = storage.Create<Identifier>(atom->identifier_->name_, atom->identifier_->user_declared_)
->MapTo(symbol_table.CreateSymbol(atom->identifier_->name_, false));
// Create an equality expression and store it in all_filters_.
auto *property_lookup =
storage.Create<PropertyLookup>(identifier, prop_pair.first);
auto *prop_equal =
storage.Create<EqualOperator>(property_lookup, prop_pair.second);
auto *property_lookup = storage.Create<PropertyLookup>(identifier, prop_pair.first);
auto *prop_equal = storage.Create<EqualOperator>(property_lookup, prop_pair.second);
// Currently, variable expand has no gains if we set PropertyFilter.
all_filters_.emplace_back(
FilterInfo{FilterInfo::Type::Generic,
storage.Create<All>(identifier, atom->identifier_,
storage.Create<Where>(prop_equal)),
collector.symbols_});
all_filters_.emplace_back(FilterInfo{
FilterInfo::Type::Generic,
storage.Create<All>(identifier, atom->identifier_, storage.Create<Where>(prop_equal)), collector.symbols_});
}
}
};
@ -299,17 +257,13 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
const auto &symbol = symbol_table.at(*atom->identifier_);
for (auto &prop_pair : atom->properties_) {
// Create an equality expression and store it in all_filters_.
auto *property_lookup =
storage.Create<PropertyLookup>(atom->identifier_, prop_pair.first);
auto *prop_equal =
storage.Create<EqualOperator>(property_lookup, prop_pair.second);
auto *property_lookup = storage.Create<PropertyLookup>(atom->identifier_, prop_pair.first);
auto *prop_equal = storage.Create<EqualOperator>(property_lookup, prop_pair.second);
collector.symbols_.clear();
prop_equal->Accept(collector);
FilterInfo filter_info{FilterInfo::Type::Property, prop_equal,
collector.symbols_};
FilterInfo filter_info{FilterInfo::Type::Property, prop_equal, collector.symbols_};
// Store a PropertyFilter on the value of the property.
filter_info.property_filter.emplace(symbol_table, symbol, prop_pair.first,
prop_pair.second,
filter_info.property_filter.emplace(symbol_table, symbol, prop_pair.first, prop_pair.second,
PropertyFilter::Type::EQUAL);
all_filters_.emplace_back(filter_info);
}
@ -318,10 +272,8 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
const auto &node_symbol = symbol_table.at(*node->identifier_);
if (!node->labels_.empty()) {
// Create a LabelsTest and store it.
auto *labels_test =
storage.Create<LabelsTest>(node->identifier_, node->labels_);
auto label_filter = FilterInfo{FilterInfo::Type::Label, labels_test,
std::unordered_set<Symbol>{node_symbol}};
auto *labels_test = storage.Create<LabelsTest>(node->identifier_, node->labels_);
auto label_filter = FilterInfo{FilterInfo::Type::Label, labels_test, std::unordered_set<Symbol>{node_symbol}};
label_filter.labels = node->labels_;
all_filters_.emplace_back(label_filter);
}
@ -339,15 +291,13 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
// Adds the where filter expression to `all_filters_` and collects additional
// information for potential property and label indexing.
void Filters::CollectWhereFilter(Where &where,
const SymbolTable &symbol_table) {
void Filters::CollectWhereFilter(Where &where, const SymbolTable &symbol_table) {
CollectFilterExpression(where.expression_, symbol_table);
}
// Adds the expression to `all_filters_` and collects additional
// information for potential property and label indexing.
void Filters::CollectFilterExpression(Expression *expr,
const SymbolTable &symbol_table) {
void Filters::CollectFilterExpression(Expression *expr, const SymbolTable &symbol_table) {
auto filters = SplitExpressionOnAnd(expr);
for (const auto &filter : filters) {
AnalyzeAndStoreFilter(filter, symbol_table);
@ -356,16 +306,12 @@ void Filters::CollectFilterExpression(Expression *expr,
// Analyzes the filter expression by collecting information on filtering labels
// and properties to be used with indexing.
void Filters::AnalyzeAndStoreFilter(Expression *expr,
const SymbolTable &symbol_table) {
void Filters::AnalyzeAndStoreFilter(Expression *expr, const SymbolTable &symbol_table) {
using Bound = PropertyFilter::Bound;
UsedSymbolsCollector collector(symbol_table);
expr->Accept(collector);
auto make_filter = [&collector, &expr](FilterInfo::Type type) {
return FilterInfo{type, expr, collector.symbols_};
};
auto get_property_lookup = [](auto *maybe_lookup, auto *&prop_lookup,
auto *&ident) -> bool {
auto make_filter = [&collector, &expr](FilterInfo::Type type) { return FilterInfo{type, expr, collector.symbols_}; };
auto get_property_lookup = [](auto *maybe_lookup, auto *&prop_lookup, auto *&ident) -> bool {
return (prop_lookup = utils::Downcast<PropertyLookup>(maybe_lookup)) &&
(ident = utils::Downcast<Identifier>(prop_lookup->expression_));
};
@ -376,9 +322,8 @@ void Filters::AnalyzeAndStoreFilter(Expression *expr,
Identifier *ident = nullptr;
if (get_property_lookup(maybe_lookup, prop_lookup, ident)) {
auto filter = make_filter(FilterInfo::Type::Property);
filter.property_filter = PropertyFilter(
symbol_table, symbol_table.at(*ident), prop_lookup->property_,
val_expr, PropertyFilter::Type::EQUAL);
filter.property_filter = PropertyFilter(symbol_table, symbol_table.at(*ident), prop_lookup->property_, val_expr,
PropertyFilter::Type::EQUAL);
all_filters_.emplace_back(filter);
return true;
}
@ -390,9 +335,8 @@ void Filters::AnalyzeAndStoreFilter(Expression *expr,
Identifier *ident = nullptr;
if (get_property_lookup(maybe_lookup, prop_lookup, ident)) {
auto filter = make_filter(FilterInfo::Type::Property);
filter.property_filter = PropertyFilter(
symbol_table, symbol_table.at(*ident), prop_lookup->property_,
val_expr, PropertyFilter::Type::REGEX_MATCH);
filter.property_filter = PropertyFilter(symbol_table, symbol_table.at(*ident), prop_lookup->property_, val_expr,
PropertyFilter::Type::REGEX_MATCH);
all_filters_.emplace_back(filter);
return true;
}
@ -400,16 +344,14 @@ void Filters::AnalyzeAndStoreFilter(Expression *expr,
};
// Checks if either the expr1 and expr2 are property lookups, adds them as
// PropertyFilter and returns true. Otherwise, returns false.
auto add_prop_greater = [&](auto *expr1, auto *expr2,
auto bound_type) -> bool {
auto add_prop_greater = [&](auto *expr1, auto *expr2, auto bound_type) -> bool {
PropertyLookup *prop_lookup = nullptr;
Identifier *ident = nullptr;
bool is_prop_filter = false;
if (get_property_lookup(expr1, prop_lookup, ident)) {
// n.prop > value
auto filter = make_filter(FilterInfo::Type::Property);
filter.property_filter.emplace(symbol_table, symbol_table.at(*ident),
prop_lookup->property_,
filter.property_filter.emplace(symbol_table, symbol_table.at(*ident), prop_lookup->property_,
Bound(expr2, bound_type), std::nullopt);
all_filters_.emplace_back(filter);
is_prop_filter = true;
@ -417,8 +359,7 @@ void Filters::AnalyzeAndStoreFilter(Expression *expr,
if (get_property_lookup(expr2, prop_lookup, ident)) {
// value > n.prop
auto filter = make_filter(FilterInfo::Type::Property);
filter.property_filter.emplace(symbol_table, symbol_table.at(*ident),
prop_lookup->property_, std::nullopt,
filter.property_filter.emplace(symbol_table, symbol_table.at(*ident), prop_lookup->property_, std::nullopt,
Bound(expr1, bound_type));
all_filters_.emplace_back(filter);
is_prop_filter = true;
@ -447,9 +388,8 @@ void Filters::AnalyzeAndStoreFilter(Expression *expr,
Identifier *ident = nullptr;
if (get_property_lookup(maybe_lookup, prop_lookup, ident)) {
auto filter = make_filter(FilterInfo::Type::Property);
filter.property_filter = PropertyFilter(
symbol_table, symbol_table.at(*ident), prop_lookup->property_,
val_expr, PropertyFilter::Type::IN);
filter.property_filter = PropertyFilter(symbol_table, symbol_table.at(*ident), prop_lookup->property_, val_expr,
PropertyFilter::Type::IN);
all_filters_.emplace_back(filter);
return true;
}
@ -466,30 +406,26 @@ void Filters::AnalyzeAndStoreFilter(Expression *expr,
return false;
}
auto *maybe_is_null_check =
utils::Downcast<IsNullOperator>(maybe_is_not_null_check->expression_);
auto *maybe_is_null_check = utils::Downcast<IsNullOperator>(maybe_is_not_null_check->expression_);
if (!maybe_is_null_check) {
return false;
}
PropertyLookup *prop_lookup = nullptr;
Identifier *ident = nullptr;
if (!get_property_lookup(maybe_is_null_check->expression_, prop_lookup,
ident)) {
if (!get_property_lookup(maybe_is_null_check->expression_, prop_lookup, ident)) {
return false;
}
auto filter = make_filter(FilterInfo::Type::Property);
filter.property_filter =
PropertyFilter(symbol_table.at(*ident), prop_lookup->property_,
PropertyFilter::Type::IS_NOT_NULL);
PropertyFilter(symbol_table.at(*ident), prop_lookup->property_, PropertyFilter::Type::IS_NOT_NULL);
all_filters_.emplace_back(filter);
return true;
};
// We are only interested to see the insides of And, because Or prevents
// indexing since any labels and properties found there may be optional.
DMG_ASSERT(!utils::IsSubtype(*expr, AndOperator::kType),
"Expected AndOperators have been split.");
DMG_ASSERT(!utils::IsSubtype(*expr, AndOperator::kType), "Expected AndOperators have been split.");
if (auto *labels_test = utils::Downcast<LabelsTest>(expr)) {
// Since LabelsTest may contain any expression, we can only use the
// simplest test on an identifier.
@ -527,25 +463,21 @@ void Filters::AnalyzeAndStoreFilter(Expression *expr,
all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic));
}
} else if (auto *gt = utils::Downcast<GreaterOperator>(expr)) {
if (!add_prop_greater(gt->expression1_, gt->expression2_,
Bound::Type::EXCLUSIVE)) {
if (!add_prop_greater(gt->expression1_, gt->expression2_, Bound::Type::EXCLUSIVE)) {
all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic));
}
} else if (auto *ge = utils::Downcast<GreaterEqualOperator>(expr)) {
if (!add_prop_greater(ge->expression1_, ge->expression2_,
Bound::Type::INCLUSIVE)) {
if (!add_prop_greater(ge->expression1_, ge->expression2_, Bound::Type::INCLUSIVE)) {
all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic));
}
} else if (auto *lt = utils::Downcast<LessOperator>(expr)) {
// Like greater, but in reverse.
if (!add_prop_greater(lt->expression2_, lt->expression1_,
Bound::Type::EXCLUSIVE)) {
if (!add_prop_greater(lt->expression2_, lt->expression1_, Bound::Type::EXCLUSIVE)) {
all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic));
}
} else if (auto *le = utils::Downcast<LessEqualOperator>(expr)) {
// Like greater equal, but in reverse.
if (!add_prop_greater(le->expression2_, le->expression1_,
Bound::Type::INCLUSIVE)) {
if (!add_prop_greater(le->expression2_, le->expression1_, Bound::Type::INCLUSIVE)) {
all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic));
}
} else if (auto *in = utils::Downcast<InListOperator>(expr)) {
@ -571,29 +503,25 @@ void Filters::AnalyzeAndStoreFilter(Expression *expr,
// Converts a Query to multiple QueryParts. In the process new Ast nodes may be
// created, e.g. filter expressions.
std::vector<SingleQueryPart> CollectSingleQueryParts(
SymbolTable &symbol_table, AstStorage &storage, SingleQuery *single_query) {
std::vector<SingleQueryPart> CollectSingleQueryParts(SymbolTable &symbol_table, AstStorage &storage,
SingleQuery *single_query) {
std::vector<SingleQueryPart> query_parts(1);
auto *query_part = &query_parts.back();
for (auto &clause : single_query->clauses_) {
if (auto *match = utils::Downcast<Match>(clause)) {
if (match->optional_) {
query_part->optional_matching.emplace_back(Matching{});
AddMatching(*match, symbol_table, storage,
query_part->optional_matching.back());
AddMatching(*match, symbol_table, storage, query_part->optional_matching.back());
} else {
DMG_ASSERT(query_part->optional_matching.empty(),
"Match clause cannot follow optional match.");
DMG_ASSERT(query_part->optional_matching.empty(), "Match clause cannot follow optional match.");
AddMatching(*match, symbol_table, storage, query_part->matching);
}
} else {
query_part->remaining_clauses.push_back(clause);
if (auto *merge = utils::Downcast<query::Merge>(clause)) {
query_part->merge_matching.emplace_back(Matching{});
AddMatching({merge->pattern_}, nullptr, symbol_table, storage,
query_part->merge_matching.back());
} else if (utils::IsSubtype(*clause, With::kType) ||
utils::IsSubtype(*clause, query::Unwind::kType) ||
AddMatching({merge->pattern_}, nullptr, symbol_table, storage, query_part->merge_matching.back());
} else if (utils::IsSubtype(*clause, With::kType) || utils::IsSubtype(*clause, query::Unwind::kType) ||
utils::IsSubtype(*clause, query::CallProcedure::kType)) {
// This query part is done, continue with a new one.
query_parts.emplace_back(SingleQueryPart{});
@ -606,14 +534,12 @@ std::vector<SingleQueryPart> CollectSingleQueryParts(
return query_parts;
}
QueryParts CollectQueryParts(SymbolTable &symbol_table, AstStorage &storage,
CypherQuery *query) {
QueryParts CollectQueryParts(SymbolTable &symbol_table, AstStorage &storage, CypherQuery *query) {
std::vector<QueryPart> query_parts;
auto *single_query = query->single_query_;
MG_ASSERT(single_query, "Expected at least a single query");
query_parts.push_back(
QueryPart{CollectSingleQueryParts(symbol_table, storage, single_query)});
query_parts.push_back(QueryPart{CollectSingleQueryParts(symbol_table, storage, single_query)});
bool distinct = false;
for (auto *cypher_union : query->cypher_unions_) {
@ -623,9 +549,7 @@ QueryParts CollectQueryParts(SymbolTable &symbol_table, AstStorage &storage,
auto *single_query = cypher_union->single_query_;
MG_ASSERT(single_query, "Expected UNION to have a query");
query_parts.push_back(
QueryPart{CollectSingleQueryParts(symbol_table, storage, single_query),
cypher_union});
query_parts.push_back(QueryPart{CollectSingleQueryParts(symbol_table, storage, single_query), cypher_union});
}
return QueryParts{query_parts, distinct};
}

View File

@ -16,8 +16,7 @@ namespace query::plan {
/// Collects symbols from identifiers found in visited AST nodes.
class UsedSymbolsCollector : public HierarchicalTreeVisitor {
public:
explicit UsedSymbolsCollector(const SymbolTable &symbol_table)
: symbol_table_(symbol_table) {}
explicit UsedSymbolsCollector(const SymbolTable &symbol_table) : symbol_table_(symbol_table) {}
using HierarchicalTreeVisitor::PostVisit;
using HierarchicalTreeVisitor::PreVisit;
@ -99,11 +98,10 @@ class PropertyFilter {
enum class Type { EQUAL, REGEX_MATCH, RANGE, IN, IS_NOT_NULL };
/// Construct with Expression being the equality or regex match check.
PropertyFilter(const SymbolTable &, const Symbol &, PropertyIx, Expression *,
Type);
PropertyFilter(const SymbolTable &, const Symbol &, PropertyIx, Expression *, Type);
/// Construct the range based filter.
PropertyFilter(const SymbolTable &, const Symbol &, PropertyIx,
const std::optional<Bound> &, const std::optional<Bound> &);
PropertyFilter(const SymbolTable &, const Symbol &, PropertyIx, const std::optional<Bound> &,
const std::optional<Bound> &);
/// Construct a filter without an expression that produces a value.
/// Used for the "PROP IS NOT NULL" filter, and can be used for any
/// property filter that doesn't need to use an expression to produce
@ -177,20 +175,14 @@ class Filters final {
auto erase(iterator pos) { return all_filters_.erase(pos); }
auto erase(const_iterator pos) { return all_filters_.erase(pos); }
auto erase(iterator first, iterator last) {
return all_filters_.erase(first, last);
}
auto erase(const_iterator first, const_iterator last) {
return all_filters_.erase(first, last);
}
auto erase(iterator first, iterator last) { return all_filters_.erase(first, last); }
auto erase(const_iterator first, const_iterator last) { return all_filters_.erase(first, last); }
auto FilteredLabels(const Symbol &symbol) const {
std::unordered_set<LabelIx> labels;
for (const auto &filter : all_filters_) {
if (filter.type == FilterInfo::Type::Label &&
utils::Contains(filter.used_symbols, symbol)) {
MG_ASSERT(filter.used_symbols.size() == 1U,
"Expected a single used symbol for label filter");
if (filter.type == FilterInfo::Type::Label && utils::Contains(filter.used_symbols, symbol)) {
MG_ASSERT(filter.used_symbols.size() == 1U, "Expected a single used symbol for label filter");
labels.insert(filter.labels.begin(), filter.labels.end());
}
}
@ -205,15 +197,13 @@ class Filters final {
/// Remove a label filter for symbol; may invalidate iterators.
/// If removed_filters is not nullptr, fills the vector with original
/// `Expression *` which are now completely removed.
void EraseLabelFilter(const Symbol &, LabelIx,
std::vector<Expression *> *removed_filters = nullptr);
void EraseLabelFilter(const Symbol &, LabelIx, std::vector<Expression *> *removed_filters = nullptr);
/// Returns a vector of FilterInfo for properties.
auto PropertyFilters(const Symbol &symbol) const {
std::vector<FilterInfo> filters;
for (const auto &filter : all_filters_) {
if (filter.type == FilterInfo::Type::Property &&
filter.property_filter->symbol_ == symbol) {
if (filter.type == FilterInfo::Type::Property && filter.property_filter->symbol_ == symbol) {
filters.push_back(filter);
}
}
@ -224,8 +214,7 @@ class Filters final {
auto IdFilters(const Symbol &symbol) const {
std::vector<FilterInfo> filters;
for (const auto &filter : all_filters_) {
if (filter.type == FilterInfo::Type::Id &&
filter.id_filter->symbol_ == symbol) {
if (filter.type == FilterInfo::Type::Id && filter.id_filter->symbol_ == symbol) {
filters.push_back(filter);
}
}

View File

@ -6,8 +6,7 @@
namespace query::plan {
PlanPrinter::PlanPrinter(const DbAccessor *dba, std::ostream *out)
: dba_(dba), out_(out) {}
PlanPrinter::PlanPrinter(const DbAccessor *dba, std::ostream *out) : dba_(dba), out_(out) {}
#define PRE_VISIT(TOp) \
bool PlanPrinter::PreVisit(TOp &) { \
@ -20,13 +19,10 @@ PRE_VISIT(CreateNode);
bool PlanPrinter::PreVisit(CreateExpand &op) {
WithPrintLn([&](auto &out) {
out << "* CreateExpand (" << op.input_symbol_.name() << ")"
<< (op.edge_info_.direction == query::EdgeAtom::Direction::IN ? "<-"
: "-")
<< "[" << op.edge_info_.symbol.name() << ":"
<< dba_->EdgeTypeToName(op.edge_info_.edge_type) << "]"
<< (op.edge_info_.direction == query::EdgeAtom::Direction::OUT ? "->"
: "-")
<< "(" << op.node_info_.symbol.name() << ")";
<< (op.edge_info_.direction == query::EdgeAtom::Direction::IN ? "<-" : "-") << "["
<< op.edge_info_.symbol.name() << ":" << dba_->EdgeTypeToName(op.edge_info_.edge_type) << "]"
<< (op.edge_info_.direction == query::EdgeAtom::Direction::OUT ? "->" : "-") << "("
<< op.node_info_.symbol.name() << ")";
});
return true;
}
@ -44,8 +40,7 @@ bool PlanPrinter::PreVisit(query::plan::ScanAll &op) {
bool PlanPrinter::PreVisit(query::plan::ScanAllByLabel &op) {
WithPrintLn([&](auto &out) {
out << "* ScanAllByLabel"
<< " (" << op.output_symbol_.name() << " :"
<< dba_->LabelToName(op.label_) << ")";
<< " (" << op.output_symbol_.name() << " :" << dba_->LabelToName(op.label_) << ")";
});
return true;
}
@ -53,8 +48,7 @@ bool PlanPrinter::PreVisit(query::plan::ScanAllByLabel &op) {
bool PlanPrinter::PreVisit(query::plan::ScanAllByLabelPropertyValue &op) {
WithPrintLn([&](auto &out) {
out << "* ScanAllByLabelPropertyValue"
<< " (" << op.output_symbol_.name() << " :"
<< dba_->LabelToName(op.label_) << " {"
<< " (" << op.output_symbol_.name() << " :" << dba_->LabelToName(op.label_) << " {"
<< dba_->PropertyToName(op.property_) << "})";
});
return true;
@ -63,8 +57,7 @@ bool PlanPrinter::PreVisit(query::plan::ScanAllByLabelPropertyValue &op) {
bool PlanPrinter::PreVisit(query::plan::ScanAllByLabelPropertyRange &op) {
WithPrintLn([&](auto &out) {
out << "* ScanAllByLabelPropertyRange"
<< " (" << op.output_symbol_.name() << " :"
<< dba_->LabelToName(op.label_) << " {"
<< " (" << op.output_symbol_.name() << " :" << dba_->LabelToName(op.label_) << " {"
<< dba_->PropertyToName(op.property_) << "})";
});
return true;
@ -73,8 +66,7 @@ bool PlanPrinter::PreVisit(query::plan::ScanAllByLabelPropertyRange &op) {
bool PlanPrinter::PreVisit(query::plan::ScanAllByLabelProperty &op) {
WithPrintLn([&](auto &out) {
out << "* ScanAllByLabelProperty"
<< " (" << op.output_symbol_.name() << " :"
<< dba_->LabelToName(op.label_) << " {"
<< " (" << op.output_symbol_.name() << " :" << dba_->LabelToName(op.label_) << " {"
<< dba_->PropertyToName(op.property_) << "})";
});
return true;
@ -91,17 +83,13 @@ bool PlanPrinter::PreVisit(ScanAllById &op) {
bool PlanPrinter::PreVisit(query::plan::Expand &op) {
WithPrintLn([&](auto &out) {
*out_ << "* Expand (" << op.input_symbol_.name() << ")"
<< (op.common_.direction == query::EdgeAtom::Direction::IN ? "<-"
: "-")
<< "[" << op.common_.edge_symbol.name();
utils::PrintIterable(*out_, op.common_.edge_types, "|",
[this](auto &stream, const auto &edge_type) {
stream << ":" << dba_->EdgeTypeToName(edge_type);
});
*out_ << "]"
<< (op.common_.direction == query::EdgeAtom::Direction::OUT ? "->"
: "-")
<< "(" << op.common_.node_symbol.name() << ")";
<< (op.common_.direction == query::EdgeAtom::Direction::IN ? "<-" : "-") << "["
<< op.common_.edge_symbol.name();
utils::PrintIterable(*out_, op.common_.edge_types, "|", [this](auto &stream, const auto &edge_type) {
stream << ":" << dba_->EdgeTypeToName(edge_type);
});
*out_ << "]" << (op.common_.direction == query::EdgeAtom::Direction::OUT ? "->" : "-") << "("
<< op.common_.node_symbol.name() << ")";
});
return true;
}
@ -124,17 +112,13 @@ bool PlanPrinter::PreVisit(query::plan::ExpandVariable &op) {
LOG_FATAL("Unexpected ExpandVariable::type_");
}
*out_ << " (" << op.input_symbol_.name() << ")"
<< (op.common_.direction == query::EdgeAtom::Direction::IN ? "<-"
: "-")
<< "[" << op.common_.edge_symbol.name();
utils::PrintIterable(*out_, op.common_.edge_types, "|",
[this](auto &stream, const auto &edge_type) {
stream << ":" << dba_->EdgeTypeToName(edge_type);
});
*out_ << "]"
<< (op.common_.direction == query::EdgeAtom::Direction::OUT ? "->"
: "-")
<< "(" << op.common_.node_symbol.name() << ")";
<< (op.common_.direction == query::EdgeAtom::Direction::IN ? "<-" : "-") << "["
<< op.common_.edge_symbol.name();
utils::PrintIterable(*out_, op.common_.edge_types, "|", [this](auto &stream, const auto &edge_type) {
stream << ":" << dba_->EdgeTypeToName(edge_type);
});
*out_ << "]" << (op.common_.direction == query::EdgeAtom::Direction::OUT ? "->" : "-") << "("
<< op.common_.node_symbol.name() << ")";
});
return true;
}
@ -142,9 +126,7 @@ bool PlanPrinter::PreVisit(query::plan::ExpandVariable &op) {
bool PlanPrinter::PreVisit(query::plan::Produce &op) {
WithPrintLn([&](auto &out) {
out << "* Produce {";
utils::PrintIterable(
out, op.named_expressions_, ", ",
[](auto &out, const auto &nexpr) { out << nexpr->name_; });
utils::PrintIterable(out, op.named_expressions_, ", ", [](auto &out, const auto &nexpr) { out << nexpr->name_; });
out << "}";
});
return true;
@ -163,12 +145,10 @@ PRE_VISIT(Accumulate);
bool PlanPrinter::PreVisit(query::plan::Aggregate &op) {
WithPrintLn([&](auto &out) {
out << "* Aggregate {";
utils::PrintIterable(
out, op.aggregations_, ", ",
[](auto &out, const auto &aggr) { out << aggr.output_sym.name(); });
utils::PrintIterable(out, op.aggregations_, ", ",
[](auto &out, const auto &aggr) { out << aggr.output_sym.name(); });
out << "} {";
utils::PrintIterable(out, op.remember_, ", ",
[](auto &out, const auto &sym) { out << sym.name(); });
utils::PrintIterable(out, op.remember_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
out << "}";
});
return true;
@ -180,8 +160,7 @@ PRE_VISIT(Limit);
bool PlanPrinter::PreVisit(query::plan::OrderBy &op) {
WithPrintLn([&op](auto &out) {
out << "* OrderBy {";
utils::PrintIterable(out, op.output_symbols_, ", ",
[](auto &out, const auto &sym) { out << sym.name(); });
utils::PrintIterable(out, op.output_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
out << "}";
});
return true;
@ -208,11 +187,9 @@ PRE_VISIT(Distinct);
bool PlanPrinter::PreVisit(query::plan::Union &op) {
WithPrintLn([&op](auto &out) {
out << "* Union {";
utils::PrintIterable(out, op.left_symbols_, ", ",
[](auto &out, const auto &sym) { out << sym.name(); });
utils::PrintIterable(out, op.left_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
out << " : ";
utils::PrintIterable(out, op.right_symbols_, ", ",
[](auto &out, const auto &sym) { out << sym.name(); });
utils::PrintIterable(out, op.right_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
out << "}";
});
Branch(*op.right_op_);
@ -223,8 +200,7 @@ bool PlanPrinter::PreVisit(query::plan::Union &op) {
bool PlanPrinter::PreVisit(query::plan::CallProcedure &op) {
WithPrintLn([&op](auto &out) {
out << "* CallProcedure<" << op.procedure_name_ << "> {";
utils::PrintIterable(out, op.result_symbols_, ", ",
[](auto &out, const auto &sym) { out << sym.name(); });
utils::PrintIterable(out, op.result_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
out << "}";
});
return true;
@ -238,11 +214,9 @@ bool PlanPrinter::Visit(query::plan::Once &op) {
bool PlanPrinter::PreVisit(query::plan::Cartesian &op) {
WithPrintLn([&op](auto &out) {
out << "* Cartesian {";
utils::PrintIterable(out, op.left_symbols_, ", ",
[](auto &out, const auto &sym) { out << sym.name(); });
utils::PrintIterable(out, op.left_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
out << " : ";
utils::PrintIterable(out, op.right_symbols_, ", ",
[](auto &out, const auto &sym) { out << sym.name(); });
utils::PrintIterable(out, op.right_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
out << "}";
});
Branch(*op.right_op_);
@ -257,23 +231,20 @@ bool PlanPrinter::DefaultPreVisit() {
return true;
}
void PlanPrinter::Branch(query::plan::LogicalOperator &op,
const std::string &branch_name) {
void PlanPrinter::Branch(query::plan::LogicalOperator &op, const std::string &branch_name) {
WithPrintLn([&](auto &out) { out << "|\\ " << branch_name; });
++depth_;
op.Accept(*this);
--depth_;
}
void PrettyPrint(const DbAccessor &dba, const LogicalOperator *plan_root,
std::ostream *out) {
void PrettyPrint(const DbAccessor &dba, const LogicalOperator *plan_root, std::ostream *out) {
PlanPrinter printer(&dba, out);
// FIXME(mtomic): We should make visitors that take const arguments.
const_cast<LogicalOperator *>(plan_root)->Accept(printer);
}
nlohmann::json PlanToJson(const DbAccessor &dba,
const LogicalOperator *plan_root) {
nlohmann::json PlanToJson(const DbAccessor &dba, const LogicalOperator *plan_root) {
impl::PlanToJsonVisitor visitor(&dba);
// FIXME(mtomic): We should make visitors that take const arguments.
const_cast<LogicalOperator *>(plan_root)->Accept(visitor);
@ -352,17 +323,11 @@ json ToJson(const utils::Bound<Expression *> &bound) {
json ToJson(const Symbol &symbol) { return symbol.name(); }
json ToJson(storage::EdgeTypeId edge_type, const DbAccessor &dba) {
return dba.EdgeTypeToName(edge_type);
}
json ToJson(storage::EdgeTypeId edge_type, const DbAccessor &dba) { return dba.EdgeTypeToName(edge_type); }
json ToJson(storage::LabelId label, const DbAccessor &dba) {
return dba.LabelToName(label);
}
json ToJson(storage::LabelId label, const DbAccessor &dba) { return dba.LabelToName(label); }
json ToJson(storage::PropertyId property, const DbAccessor &dba) {
return dba.PropertyToName(property);
}
json ToJson(storage::PropertyId property, const DbAccessor &dba) { return dba.PropertyToName(property); }
json ToJson(NamedExpression *nexpr) {
json json;
@ -371,9 +336,7 @@ json ToJson(NamedExpression *nexpr) {
return json;
}
json ToJson(
const std::vector<std::pair<storage::PropertyId, Expression *>> &properties,
const DbAccessor &dba) {
json ToJson(const std::vector<std::pair<storage::PropertyId, Expression *>> &properties, const DbAccessor &dba) {
json json;
for (const auto &prop_pair : properties) {
json.emplace(ToJson(prop_pair.first, dba), ToJson(prop_pair.second));
@ -558,9 +521,7 @@ bool PlanToJsonVisitor::PreVisit(ExpandVariable &op) {
self["upper_bound"] = op.upper_bound_ ? ToJson(op.upper_bound_) : json();
self["existing_node"] = op.common_.existing_node;
self["filter_lambda"] = op.filter_lambda_.expression
? ToJson(op.filter_lambda_.expression)
: json();
self["filter_lambda"] = op.filter_lambda_.expression ? ToJson(op.filter_lambda_.expression) : json();
if (op.type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) {
self["weight_lambda"] = ToJson(op.weight_lambda_->expression);

View File

@ -19,19 +19,16 @@ class LogicalOperator;
/// DbAccessor is needed for resolving label and property names.
/// Note that `plan_root` isn't modified, but we can't take it as a const
/// because we don't have support for visiting a const LogicalOperator.
void PrettyPrint(const DbAccessor &dba, const LogicalOperator *plan_root,
std::ostream *out);
void PrettyPrint(const DbAccessor &dba, const LogicalOperator *plan_root, std::ostream *out);
/// Overload of `PrettyPrint` which defaults the `std::ostream` to `std::cout`.
inline void PrettyPrint(const DbAccessor &dba,
const LogicalOperator *plan_root) {
inline void PrettyPrint(const DbAccessor &dba, const LogicalOperator *plan_root) {
PrettyPrint(dba, plan_root, &std::cout);
}
/// Convert a `LogicalOperator` plan to a JSON representation.
/// DbAccessor is needed for resolving label and property names.
nlohmann::json PlanToJson(const DbAccessor &dba,
const LogicalOperator *plan_root);
nlohmann::json PlanToJson(const DbAccessor &dba, const LogicalOperator *plan_root);
class PlanPrinter : public virtual HierarchicalLogicalOperatorVisitor {
public:
@ -130,9 +127,8 @@ nlohmann::json ToJson(storage::PropertyId property, const DbAccessor &dba);
nlohmann::json ToJson(NamedExpression *nexpr);
nlohmann::json ToJson(
const std::vector<std::pair<storage::PropertyId, Expression *>> &properties,
const DbAccessor &dba);
nlohmann::json ToJson(const std::vector<std::pair<storage::PropertyId, Expression *>> &properties,
const DbAccessor &dba);
nlohmann::json ToJson(const NodeCreationInfo &node_info, const DbAccessor &dba);
@ -141,7 +137,7 @@ nlohmann::json ToJson(const EdgeCreationInfo &edge_info, const DbAccessor &dba);
nlohmann::json ToJson(const Aggregate::Element &elem);
template <class T, class... Args>
nlohmann::json ToJson(const std::vector<T> &items, Args &&... args) {
nlohmann::json ToJson(const std::vector<T> &items, Args &&...args) {
nlohmann::json json;
for (const auto &item : items) {
json.emplace_back(ToJson(item, std::forward<Args>(args)...));

View File

@ -14,23 +14,18 @@ namespace query::plan {
namespace {
unsigned long long IndividualCycles(const ProfilingStats &cumulative_stats) {
return cumulative_stats.num_cycles -
std::accumulate(
cumulative_stats.children.begin(), cumulative_stats.children.end(),
0ULL,
[](auto acc, auto &stats) { return acc + stats.num_cycles; });
return cumulative_stats.num_cycles - std::accumulate(cumulative_stats.children.begin(),
cumulative_stats.children.end(), 0ULL,
[](auto acc, auto &stats) { return acc + stats.num_cycles; });
}
double RelativeTime(unsigned long long num_cycles,
unsigned long long total_cycles) {
double RelativeTime(unsigned long long num_cycles, unsigned long long total_cycles) {
return static_cast<double>(num_cycles) / total_cycles;
}
double AbsoluteTime(unsigned long long num_cycles,
unsigned long long total_cycles,
double AbsoluteTime(unsigned long long num_cycles, unsigned long long total_cycles,
std::chrono::duration<double> total_time) {
return (RelativeTime(num_cycles, total_cycles) *
static_cast<std::chrono::duration<double, std::milli>>(total_time))
return (RelativeTime(num_cycles, total_cycles) * static_cast<std::chrono::duration<double, std::milli>>(total_time))
.count();
}
@ -44,18 +39,15 @@ namespace {
class ProfilingStatsToTableHelper {
public:
ProfilingStatsToTableHelper(unsigned long long total_cycles,
std::chrono::duration<double> total_time)
ProfilingStatsToTableHelper(unsigned long long total_cycles, std::chrono::duration<double> total_time)
: total_cycles_(total_cycles), total_time_(total_time) {}
void Output(const ProfilingStats &cumulative_stats) {
auto cycles = IndividualCycles(cumulative_stats);
rows_.emplace_back(std::vector<TypedValue>{
TypedValue(FormatOperator(cumulative_stats.name)),
TypedValue(cumulative_stats.actual_hits),
TypedValue(FormatRelativeTime(cycles)),
TypedValue(FormatAbsoluteTime(cycles))});
TypedValue(FormatOperator(cumulative_stats.name)), TypedValue(cumulative_stats.actual_hits),
TypedValue(FormatRelativeTime(cycles)), TypedValue(FormatAbsoluteTime(cycles))});
for (size_t i = 1; i < cumulative_stats.children.size(); ++i) {
Branch(cumulative_stats.children[i]);
@ -70,8 +62,7 @@ class ProfilingStatsToTableHelper {
private:
void Branch(const ProfilingStats &cumulative_stats) {
rows_.emplace_back(std::vector<TypedValue>{
TypedValue("|\\"), TypedValue(""), TypedValue(""), TypedValue("")});
rows_.emplace_back(std::vector<TypedValue>{TypedValue("|\\"), TypedValue(""), TypedValue(""), TypedValue("")});
++depth_;
Output(cumulative_stats);
@ -89,18 +80,14 @@ class ProfilingStatsToTableHelper {
std::string Format(const std::string &str) { return Format(str.c_str()); }
std::string FormatOperator(const char *str) {
return Format(std::string("* ") + str);
}
std::string FormatOperator(const char *str) { return Format(std::string("* ") + str); }
std::string FormatRelativeTime(unsigned long long num_cycles) {
return fmt::format("{: 10.6f} %",
RelativeTime(num_cycles, total_cycles_) * 100);
return fmt::format("{: 10.6f} %", RelativeTime(num_cycles, total_cycles_) * 100);
}
std::string FormatAbsoluteTime(unsigned long long num_cycles) {
return fmt::format("{: 10.6f} ms",
AbsoluteTime(num_cycles, total_cycles_, total_time_));
return fmt::format("{: 10.6f} ms", AbsoluteTime(num_cycles, total_cycles_, total_time_));
}
int64_t depth_{0};
@ -111,9 +98,8 @@ class ProfilingStatsToTableHelper {
} // namespace
std::vector<std::vector<TypedValue>> ProfilingStatsToTable(
const ProfilingStats &cumulative_stats,
std::chrono::duration<double> total_time) {
std::vector<std::vector<TypedValue>> ProfilingStatsToTable(const ProfilingStats &cumulative_stats,
std::chrono::duration<double> total_time) {
ProfilingStatsToTableHelper helper{cumulative_stats.num_cycles, total_time};
helper.Output(cumulative_stats);
return helper.rows();
@ -130,13 +116,10 @@ class ProfilingStatsToJsonHelper {
using json = nlohmann::json;
public:
ProfilingStatsToJsonHelper(unsigned long long total_cycles,
std::chrono::duration<double> total_time)
ProfilingStatsToJsonHelper(unsigned long long total_cycles, std::chrono::duration<double> total_time)
: total_cycles_(total_cycles), total_time_(total_time) {}
void Output(const ProfilingStats &cumulative_stats) {
return Output(cumulative_stats, &json_);
}
void Output(const ProfilingStats &cumulative_stats) { return Output(cumulative_stats, &json_); }
json ToJson() { return json_; }
@ -147,8 +130,7 @@ class ProfilingStatsToJsonHelper {
obj->emplace("name", cumulative_stats.name);
obj->emplace("actual_hits", cumulative_stats.actual_hits);
obj->emplace("relative_time", RelativeTime(cycles, total_cycles_));
obj->emplace("absolute_time",
AbsoluteTime(cycles, total_cycles_, total_time_));
obj->emplace("absolute_time", AbsoluteTime(cycles, total_cycles_, total_time_));
obj->emplace("children", json::array());
for (size_t i = 0; i < cumulative_stats.children.size(); ++i) {
@ -165,8 +147,7 @@ class ProfilingStatsToJsonHelper {
} // namespace
nlohmann::json ProfilingStatsToJson(const ProfilingStats &cumulative_stats,
std::chrono::duration<double> total_time) {
nlohmann::json ProfilingStatsToJson(const ProfilingStats &cumulative_stats, std::chrono::duration<double> total_time) {
ProfilingStatsToJsonHelper helper{cumulative_stats.num_cycles, total_time};
helper.Output(cumulative_stats);
return helper.ToJson();

View File

@ -23,11 +23,10 @@ struct ProfilingStats {
std::vector<ProfilingStats> children;
};
std::vector<std::vector<TypedValue>> ProfilingStatsToTable(
const ProfilingStats &cumulative_stats, std::chrono::duration<double>);
std::vector<std::vector<TypedValue>> ProfilingStatsToTable(const ProfilingStats &cumulative_stats,
std::chrono::duration<double>);
nlohmann::json ProfilingStatsToJson(const ProfilingStats &cumulative_stats,
std::chrono::duration<double>);
nlohmann::json ProfilingStatsToJson(const ProfilingStats &cumulative_stats, std::chrono::duration<double>);
} // namespace plan
} // namespace query

View File

@ -81,9 +81,7 @@ void ReadWriteTypeChecker::UpdateType(RWType op_type) {
}
}
void ReadWriteTypeChecker::InferRWType(LogicalOperator &root) {
root.Accept(*this);
}
void ReadWriteTypeChecker::InferRWType(LogicalOperator &root) { root.Accept(*this); }
std::string ReadWriteTypeChecker::TypeToString(const RWType type) {
switch (type) {

View File

@ -2,17 +2,15 @@
#include "utils/flag_validation.hpp"
DEFINE_VALIDATED_HIDDEN_int64(
query_vertex_count_to_expand_existing, 10,
"Maximum count of indexed vertices which provoke "
"indexed lookup and then expand to existing, instead of "
"a regular expand. Default is 10, to turn off use -1.",
FLAG_IN_RANGE(-1, std::numeric_limits<std::int64_t>::max()));
DEFINE_VALIDATED_HIDDEN_int64(query_vertex_count_to_expand_existing, 10,
"Maximum count of indexed vertices which provoke "
"indexed lookup and then expand to existing, instead of "
"a regular expand. Default is 10, to turn off use -1.",
FLAG_IN_RANGE(-1, std::numeric_limits<std::int64_t>::max()));
namespace query::plan::impl {
Expression *RemoveAndExpressions(
Expression *expr, const std::unordered_set<Expression *> &exprs_to_remove) {
Expression *RemoveAndExpressions(Expression *expr, const std::unordered_set<Expression *> &exprs_to_remove) {
auto *and_op = utils::Downcast<AndOperator>(expr);
if (!and_op) return expr;
if (utils::Contains(exprs_to_remove, and_op)) {
@ -24,10 +22,8 @@ Expression *RemoveAndExpressions(
if (utils::Contains(exprs_to_remove, and_op->expression2_)) {
and_op->expression2_ = nullptr;
}
and_op->expression1_ =
RemoveAndExpressions(and_op->expression1_, exprs_to_remove);
and_op->expression2_ =
RemoveAndExpressions(and_op->expression2_, exprs_to_remove);
and_op->expression1_ = RemoveAndExpressions(and_op->expression1_, exprs_to_remove);
and_op->expression2_ = RemoveAndExpressions(and_op->expression2_, exprs_to_remove);
if (!and_op->expression1_ && !and_op->expression2_) {
return nullptr;
}

View File

@ -25,14 +25,12 @@ namespace impl {
// Return the new root expression after removing the given expressions from the
// given expression tree.
Expression *RemoveAndExpressions(
Expression *expr, const std::unordered_set<Expression *> &exprs_to_remove);
Expression *RemoveAndExpressions(Expression *expr, const std::unordered_set<Expression *> &exprs_to_remove);
template <class TDbAccessor>
class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
public:
IndexLookupRewriter(SymbolTable *symbol_table, AstStorage *ast_storage,
TDbAccessor *db)
IndexLookupRewriter(SymbolTable *symbol_table, AstStorage *ast_storage, TDbAccessor *db)
: symbol_table_(symbol_table), ast_storage_(ast_storage), db_(db) {}
using HierarchicalLogicalOperatorVisitor::PostVisit;
@ -52,10 +50,8 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
// free the memory.
bool PostVisit(Filter &op) override {
prev_ops_.pop_back();
op.expression_ =
RemoveAndExpressions(op.expression_, filter_exprs_for_removal_);
if (!op.expression_ ||
utils::Contains(filter_exprs_for_removal_, op.expression_)) {
op.expression_ = RemoveAndExpressions(op.expression_, filter_exprs_for_removal_);
if (!op.expression_ || utils::Contains(filter_exprs_for_removal_, op.expression_)) {
SetOnParent(op.input());
}
return true;
@ -91,8 +87,7 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
return true;
}
ScanAll dst_scan(expand.input(), expand.common_.node_symbol, expand.view_);
auto indexed_scan =
GenScanByIndex(dst_scan, FLAGS_query_vertex_count_to_expand_existing);
auto indexed_scan = GenScanByIndex(dst_scan, FLAGS_query_vertex_count_to_expand_existing);
if (indexed_scan) {
expand.set_input(std::move(indexed_scan));
expand.common_.existing_node = true;
@ -113,8 +108,7 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
return true;
}
std::unique_ptr<ScanAll> indexed_scan;
ScanAll dst_scan(expand.input(), expand.common_.node_symbol,
storage::View::OLD);
ScanAll dst_scan(expand.input(), expand.common_.node_symbol, storage::View::OLD);
// With expand to existing we only get real gains with BFS, because we use a
// different algorithm then, so prefer expand to existing.
if (expand.type_ == EdgeAtom::Type::BREADTH_FIRST) {
@ -122,8 +116,7 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
// unconditionally creating an indexed scan.
indexed_scan = GenScanByIndex(dst_scan);
} else {
indexed_scan =
GenScanByIndex(dst_scan, FLAGS_query_vertex_count_to_expand_existing);
indexed_scan = GenScanByIndex(dst_scan, FLAGS_query_vertex_count_to_expand_existing);
}
if (indexed_scan) {
expand.set_input(std::move(indexed_scan));
@ -449,9 +442,7 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
int64_t vertex_count;
};
bool DefaultPreVisit() override {
throw utils::NotYetImplemented("optimizing index lookup");
}
bool DefaultPreVisit() override { throw utils::NotYetImplemented("optimizing index lookup"); }
void SetOnParent(const std::shared_ptr<LogicalOperator> &input) {
MG_ASSERT(input);
@ -471,18 +462,12 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
}
}
storage::LabelId GetLabel(LabelIx label) {
return db_->NameToLabel(label.name);
}
storage::LabelId GetLabel(LabelIx label) { return db_->NameToLabel(label.name); }
storage::PropertyId GetProperty(PropertyIx prop) {
return db_->NameToProperty(prop.name);
}
storage::PropertyId GetProperty(PropertyIx prop) { return db_->NameToProperty(prop.name); }
std::optional<LabelIx> FindBestLabelIndex(
const std::unordered_set<LabelIx> &labels) {
MG_ASSERT(!labels.empty(),
"Trying to find the best label without any labels.");
std::optional<LabelIx> FindBestLabelIndex(const std::unordered_set<LabelIx> &labels) {
MG_ASSERT(!labels.empty(), "Trying to find the best label without any labels.");
std::optional<LabelIx> best_label;
for (const auto &label : labels) {
if (!db_->LabelIndexExists(GetLabel(label))) continue;
@ -490,17 +475,15 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
best_label = label;
continue;
}
if (db_->VerticesCount(GetLabel(label)) <
db_->VerticesCount(GetLabel(*best_label)))
best_label = label;
if (db_->VerticesCount(GetLabel(label)) < db_->VerticesCount(GetLabel(*best_label))) best_label = label;
}
return best_label;
}
// Finds the label-property combination which has indexed the lowest amount of
// vertices. If the index cannot be found, nullopt is returned.
std::optional<LabelPropertyIndex> FindBestLabelPropertyIndex(
const Symbol &symbol, const std::unordered_set<Symbol> &bound_symbols) {
std::optional<LabelPropertyIndex> FindBestLabelPropertyIndex(const Symbol &symbol,
const std::unordered_set<Symbol> &bound_symbols) {
auto are_bound = [&bound_symbols](const auto &used_symbols) {
for (const auto &used_symbol : used_symbols) {
if (!utils::Contains(bound_symbols, used_symbol)) {
@ -512,8 +495,7 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
std::optional<LabelPropertyIndex> found;
for (const auto &label : filters_.FilteredLabels(symbol)) {
for (const auto &filter : filters_.PropertyFilters(symbol)) {
if (filter.property_filter->is_symbol_in_value_ ||
!are_bound(filter.used_symbols)) {
if (filter.property_filter->is_symbol_in_value_ || !are_bound(filter.used_symbols)) {
// Skip filter expressions which use the symbol whose property we are
// looking up or aren't bound. We cannot scan by such expressions. For
// example, in `n.a = 2 + n.b` both sides of `=` refer to `n`, so we
@ -521,27 +503,20 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
continue;
}
const auto &property = filter.property_filter->property_;
if (!db_->LabelPropertyIndexExists(GetLabel(label),
GetProperty(property))) {
if (!db_->LabelPropertyIndexExists(GetLabel(label), GetProperty(property))) {
continue;
}
int64_t vertex_count =
db_->VerticesCount(GetLabel(label), GetProperty(property));
int64_t vertex_count = db_->VerticesCount(GetLabel(label), GetProperty(property));
auto is_better_type = [&found](PropertyFilter::Type type) {
// Order the types by the most preferred index lookup type.
static const PropertyFilter::Type kFilterTypeOrder[] = {
PropertyFilter::Type::EQUAL, PropertyFilter::Type::RANGE,
PropertyFilter::Type::REGEX_MATCH};
auto *found_sort_ix =
std::find(kFilterTypeOrder, kFilterTypeOrder + 3,
found->filter.property_filter->type_);
auto *type_sort_ix =
std::find(kFilterTypeOrder, kFilterTypeOrder + 3, type);
PropertyFilter::Type::EQUAL, PropertyFilter::Type::RANGE, PropertyFilter::Type::REGEX_MATCH};
auto *found_sort_ix = std::find(kFilterTypeOrder, kFilterTypeOrder + 3, found->filter.property_filter->type_);
auto *type_sort_ix = std::find(kFilterTypeOrder, kFilterTypeOrder + 3, type);
return type_sort_ix < found_sort_ix;
};
if (!found || vertex_count < found->vertex_count ||
(vertex_count == found->vertex_count &&
is_better_type(filter.property_filter->type_))) {
(vertex_count == found->vertex_count && is_better_type(filter.property_filter->type_))) {
found = LabelPropertyIndex{label, filter, vertex_count};
}
}
@ -556,15 +531,13 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
// `max_vertex_count` controls, whether no operator should be created if the
// vertex count in the best index exceeds this number. In such a case,
// `nullptr` is returned and `input` is not chained.
std::unique_ptr<ScanAll> GenScanByIndex(
const ScanAll &scan,
const std::optional<int64_t> &max_vertex_count = std::nullopt) {
std::unique_ptr<ScanAll> GenScanByIndex(const ScanAll &scan,
const std::optional<int64_t> &max_vertex_count = std::nullopt) {
const auto &input = scan.input();
const auto &node_symbol = scan.output_symbol_;
const auto &view = scan.view_;
const auto &modified_symbols = scan.ModifiedSymbols(*symbol_table_);
std::unordered_set<Symbol> bound_symbols(modified_symbols.begin(),
modified_symbols.end());
std::unordered_set<Symbol> bound_symbols(modified_symbols.begin(), modified_symbols.end());
auto are_bound = [&bound_symbols](const auto &used_symbols) {
for (const auto &used_symbol : used_symbols) {
if (!utils::Contains(bound_symbols, used_symbol)) {
@ -576,9 +549,7 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
// First, try to see if we can find a vertex by ID.
if (!max_vertex_count || *max_vertex_count >= 1) {
for (const auto &filter : filters_.IdFilters(node_symbol)) {
if (filter.id_filter->is_symbol_in_value_ ||
!are_bound(filter.used_symbols))
continue;
if (filter.id_filter->is_symbol_in_value_ || !are_bound(filter.used_symbols)) continue;
auto *value = filter.id_filter->value_;
filter_exprs_for_removal_.insert(filter.expression);
filters_.EraseFilter(filter);
@ -606,76 +577,62 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
}
filters_.EraseFilter(found_index->filter);
std::vector<Expression *> removed_expressions;
filters_.EraseLabelFilter(node_symbol, found_index->label,
&removed_expressions);
filter_exprs_for_removal_.insert(removed_expressions.begin(),
removed_expressions.end());
filters_.EraseLabelFilter(node_symbol, found_index->label, &removed_expressions);
filter_exprs_for_removal_.insert(removed_expressions.begin(), removed_expressions.end());
if (prop_filter.lower_bound_ || prop_filter.upper_bound_) {
return std::make_unique<ScanAllByLabelPropertyRange>(
input, node_symbol, GetLabel(found_index->label),
GetProperty(prop_filter.property_), prop_filter.property_.name,
prop_filter.lower_bound_, prop_filter.upper_bound_, view);
input, node_symbol, GetLabel(found_index->label), GetProperty(prop_filter.property_),
prop_filter.property_.name, prop_filter.lower_bound_, prop_filter.upper_bound_, view);
} else if (prop_filter.type_ == PropertyFilter::Type::REGEX_MATCH) {
// Generate index scan using the empty string as a lower bound.
Expression *empty_string = ast_storage_->Create<PrimitiveLiteral>("");
auto lower_bound = utils::MakeBoundInclusive(empty_string);
return std::make_unique<ScanAllByLabelPropertyRange>(
input, node_symbol, GetLabel(found_index->label),
GetProperty(prop_filter.property_), prop_filter.property_.name,
std::make_optional(lower_bound), std::nullopt, view);
input, node_symbol, GetLabel(found_index->label), GetProperty(prop_filter.property_),
prop_filter.property_.name, std::make_optional(lower_bound), std::nullopt, view);
} else if (prop_filter.type_ == PropertyFilter::Type::IN) {
// TODO(buda): ScanAllByLabelProperty + Filter should be considered
// here once the operator and the right cardinality estimation exist.
auto const &symbol = symbol_table_->CreateAnonymousSymbol();
auto *expression = ast_storage_->Create<Identifier>(symbol.name_);
expression->MapTo(symbol);
auto unwind_operator =
std::make_unique<Unwind>(input, prop_filter.value_, symbol);
auto unwind_operator = std::make_unique<Unwind>(input, prop_filter.value_, symbol);
return std::make_unique<ScanAllByLabelPropertyValue>(
std::move(unwind_operator), node_symbol,
GetLabel(found_index->label), GetProperty(prop_filter.property_),
std::move(unwind_operator), node_symbol, GetLabel(found_index->label), GetProperty(prop_filter.property_),
prop_filter.property_.name, expression, view);
} else if (prop_filter.type_ == PropertyFilter::Type::IS_NOT_NULL) {
return std::make_unique<ScanAllByLabelProperty>(
input, node_symbol, GetLabel(found_index->label),
GetProperty(prop_filter.property_), prop_filter.property_.name,
view);
return std::make_unique<ScanAllByLabelProperty>(input, node_symbol, GetLabel(found_index->label),
GetProperty(prop_filter.property_), prop_filter.property_.name,
view);
} else {
MG_ASSERT(
prop_filter.value_,
"Property filter should either have bounds or a value expression.");
return std::make_unique<ScanAllByLabelPropertyValue>(
input, node_symbol, GetLabel(found_index->label),
GetProperty(prop_filter.property_), prop_filter.property_.name,
prop_filter.value_, view);
MG_ASSERT(prop_filter.value_, "Property filter should either have bounds or a value expression.");
return std::make_unique<ScanAllByLabelPropertyValue>(input, node_symbol, GetLabel(found_index->label),
GetProperty(prop_filter.property_),
prop_filter.property_.name, prop_filter.value_, view);
}
}
auto maybe_label = FindBestLabelIndex(labels);
if (!maybe_label) return nullptr;
const auto &label = *maybe_label;
if (max_vertex_count &&
db_->VerticesCount(GetLabel(label)) > *max_vertex_count) {
if (max_vertex_count && db_->VerticesCount(GetLabel(label)) > *max_vertex_count) {
// Don't create an indexed lookup, since we have more labeled vertices
// than the allowed count.
return nullptr;
}
std::vector<Expression *> removed_expressions;
filters_.EraseLabelFilter(node_symbol, label, &removed_expressions);
filter_exprs_for_removal_.insert(removed_expressions.begin(),
removed_expressions.end());
return std::make_unique<ScanAllByLabel>(input, node_symbol, GetLabel(label),
view);
filter_exprs_for_removal_.insert(removed_expressions.begin(), removed_expressions.end());
return std::make_unique<ScanAllByLabel>(input, node_symbol, GetLabel(label), view);
}
};
} // namespace impl
template <class TDbAccessor>
std::unique_ptr<LogicalOperator> RewriteWithIndexLookup(
std::unique_ptr<LogicalOperator> root_op, SymbolTable *symbol_table,
AstStorage *ast_storage, TDbAccessor *db) {
impl::IndexLookupRewriter<TDbAccessor> rewriter(symbol_table, ast_storage,
db);
std::unique_ptr<LogicalOperator> RewriteWithIndexLookup(std::unique_ptr<LogicalOperator> root_op,
SymbolTable *symbol_table, AstStorage *ast_storage,
TDbAccessor *db) {
impl::IndexLookupRewriter<TDbAccessor> rewriter(symbol_table, ast_storage, db);
root_op->Accept(rewriter);
if (rewriter.new_root_) {
// This shouldn't happen in real use case, because IndexLookupRewriter

View File

@ -14,8 +14,7 @@ namespace query::plan {
namespace {
bool HasBoundFilterSymbols(const std::unordered_set<Symbol> &bound_symbols,
const FilterInfo &filter) {
bool HasBoundFilterSymbols(const std::unordered_set<Symbol> &bound_symbols, const FilterInfo &filter) {
for (const auto &symbol : filter.used_symbols) {
if (bound_symbols.find(symbol) == bound_symbols.end()) {
return false;
@ -37,14 +36,9 @@ bool HasBoundFilterSymbols(const std::unordered_set<Symbol> &bound_symbols,
// aggregations and expressions used for group by.
class ReturnBodyContext : public HierarchicalTreeVisitor {
public:
ReturnBodyContext(const ReturnBody &body, SymbolTable &symbol_table,
const std::unordered_set<Symbol> &bound_symbols,
ReturnBodyContext(const ReturnBody &body, SymbolTable &symbol_table, const std::unordered_set<Symbol> &bound_symbols,
AstStorage &storage, Where *where = nullptr)
: body_(body),
symbol_table_(symbol_table),
bound_symbols_(bound_symbols),
storage_(storage),
where_(where) {
: body_(body), symbol_table_(symbol_table), bound_symbols_(bound_symbols), storage_(storage), where_(where) {
// Collect symbols from named expressions.
output_symbols_.reserve(body_.named_expressions.size());
if (body.all_identifiers) {
@ -78,8 +72,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
if (where) {
where->Accept(*this);
}
MG_ASSERT(aggregations_.empty(),
"Unexpected aggregations in ORDER BY or WHERE");
MG_ASSERT(aggregations_.empty(), "Unexpected aggregations in ORDER BY or WHERE");
}
}
@ -94,8 +87,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
private:
template <typename TLiteral, typename TIteratorToExpression>
void PostVisitCollectionLiteral(
TLiteral &literal, TIteratorToExpression iterator_to_expression) {
void PostVisitCollectionLiteral(TLiteral &literal, TIteratorToExpression iterator_to_expression) {
// If there is an aggregation in the list, and there are group-bys, then we
// need to add the group-bys manually. If there are no aggregations, the
// whole list will be added as a group-by.
@ -115,8 +107,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
}
has_aggregation_.emplace_back(has_aggr);
if (has_aggr) {
for (auto expression_ptr : literal_group_by)
group_by_.emplace_back(expression_ptr);
for (auto expression_ptr : literal_group_by) group_by_.emplace_back(expression_ptr);
}
}
@ -130,9 +121,8 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
}
bool PostVisit(MapLiteral &map_literal) override {
MG_ASSERT(
map_literal.elements_.size() <= has_aggregation_.size(),
"Expected has_aggregation_ flags as much as there are map elements.");
MG_ASSERT(map_literal.elements_.size() <= has_aggregation_.size(),
"Expected has_aggregation_ flags as much as there are map elements.");
PostVisitCollectionLiteral(map_literal, [](auto it) { return it->second; });
return true;
}
@ -141,8 +131,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
// Remove the symbol which is bound by all, because we are only interested
// in free (unbound) symbols.
used_symbols_.erase(symbol_table_.at(*all.identifier_));
MG_ASSERT(has_aggregation_.size() >= 3U,
"Expected 3 has_aggregation_ flags for ALL arguments");
MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for ALL arguments");
bool has_aggr = false;
for (int i = 0; i < 3; ++i) {
has_aggr = has_aggr || has_aggregation_.back();
@ -156,8 +145,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
// Remove the symbol which is bound by single, because we are only
// interested in free (unbound) symbols.
used_symbols_.erase(symbol_table_.at(*single.identifier_));
MG_ASSERT(has_aggregation_.size() >= 3U,
"Expected 3 has_aggregation_ flags for SINGLE arguments");
MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for SINGLE arguments");
bool has_aggr = false;
for (int i = 0; i < 3; ++i) {
has_aggr = has_aggr || has_aggregation_.back();
@ -171,8 +159,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
// Remove the symbol which is bound by any, because we are only interested
// in free (unbound) symbols.
used_symbols_.erase(symbol_table_.at(*any.identifier_));
MG_ASSERT(has_aggregation_.size() >= 3U,
"Expected 3 has_aggregation_ flags for ANY arguments");
MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for ANY arguments");
bool has_aggr = false;
for (int i = 0; i < 3; ++i) {
has_aggr = has_aggr || has_aggregation_.back();
@ -186,8 +173,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
// Remove the symbol which is bound by none, because we are only interested
// in free (unbound) symbols.
used_symbols_.erase(symbol_table_.at(*none.identifier_));
MG_ASSERT(has_aggregation_.size() >= 3U,
"Expected 3 has_aggregation_ flags for NONE arguments");
MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for NONE arguments");
bool has_aggr = false;
for (int i = 0; i < 3; ++i) {
has_aggr = has_aggr || has_aggregation_.back();
@ -202,8 +188,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
// in free (unbound) symbols.
used_symbols_.erase(symbol_table_.at(*reduce.accumulator_));
used_symbols_.erase(symbol_table_.at(*reduce.identifier_));
MG_ASSERT(has_aggregation_.size() >= 5U,
"Expected 5 has_aggregation_ flags for REDUCE arguments");
MG_ASSERT(has_aggregation_.size() >= 5U, "Expected 5 has_aggregation_ flags for REDUCE arguments");
bool has_aggr = false;
for (int i = 0; i < 5; ++i) {
has_aggr = has_aggr || has_aggregation_.back();
@ -215,8 +200,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
bool PostVisit(Coalesce &coalesce) override {
MG_ASSERT(has_aggregation_.size() >= coalesce.expressions_.size(),
"Expected >= {} has_aggregation_ flags for COALESCE arguments",
has_aggregation_.size());
"Expected >= {} has_aggregation_ flags for COALESCE arguments", has_aggregation_.size());
bool has_aggr = false;
for (size_t i = 0; i < coalesce.expressions_.size(); ++i) {
has_aggr = has_aggr || has_aggregation_.back();
@ -230,8 +214,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
// Remove the symbol bound by extract, because we are only interested
// in free (unbound) symbols.
used_symbols_.erase(symbol_table_.at(*extract.identifier_));
MG_ASSERT(has_aggregation_.size() >= 3U,
"Expected 3 has_aggregation_ flags for EXTRACT arguments");
MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for EXTRACT arguments");
bool has_aggr = false;
for (int i = 0; i < 3; ++i) {
has_aggr = has_aggr || has_aggregation_.back();
@ -308,25 +291,24 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
return true;
}
#define VISIT_BINARY_OPERATOR(BinaryOperator) \
bool PostVisit(BinaryOperator &op) override { \
MG_ASSERT(has_aggregation_.size() >= 2U, \
"Expected at least 2 has_aggregation_ flags."); \
/* has_aggregation_ stack is reversed, last result is from the 2nd */ \
/* expression. */ \
bool aggr2 = has_aggregation_.back(); \
has_aggregation_.pop_back(); \
bool aggr1 = has_aggregation_.back(); \
has_aggregation_.pop_back(); \
bool has_aggr = aggr1 || aggr2; \
if (has_aggr && !(aggr1 && aggr2)) { \
/* Group by the expression which does not contain aggregation. */ \
/* Possible optimization is to ignore constant value expressions */ \
group_by_.emplace_back(aggr1 ? op.expression2_ : op.expression1_); \
} \
/* Propagate that this whole expression may contain an aggregation. */ \
has_aggregation_.emplace_back(has_aggr); \
return true; \
#define VISIT_BINARY_OPERATOR(BinaryOperator) \
bool PostVisit(BinaryOperator &op) override { \
MG_ASSERT(has_aggregation_.size() >= 2U, "Expected at least 2 has_aggregation_ flags."); \
/* has_aggregation_ stack is reversed, last result is from the 2nd */ \
/* expression. */ \
bool aggr2 = has_aggregation_.back(); \
has_aggregation_.pop_back(); \
bool aggr1 = has_aggregation_.back(); \
has_aggregation_.pop_back(); \
bool has_aggr = aggr1 || aggr2; \
if (has_aggr && !(aggr1 && aggr2)) { \
/* Group by the expression which does not contain aggregation. */ \
/* Possible optimization is to ignore constant value expressions */ \
group_by_.emplace_back(aggr1 ? op.expression2_ : op.expression1_); \
} \
/* Propagate that this whole expression may contain an aggregation. */ \
has_aggregation_.emplace_back(has_aggr); \
return true; \
}
VISIT_BINARY_OPERATOR(OrOperator)
@ -351,8 +333,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
bool PostVisit(Aggregation &aggr) override {
// Aggregation contains a virtual symbol, where the result will be stored.
const auto &symbol = symbol_table_.at(aggr);
aggregations_.emplace_back(Aggregate::Element{
aggr.expression1_, aggr.expression2_, aggr.op_, symbol});
aggregations_.emplace_back(Aggregate::Element{aggr.expression1_, aggr.expression2_, aggr.op_, symbol});
// Aggregation expression1_ is optional in COUNT(*), and COLLECT_MAP uses
// two expressions, so we can have 0, 1 or 2 elements on the
// has_aggregation_stack for this Aggregation expression.
@ -368,8 +349,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
}
bool PostVisit(NamedExpression &named_expr) override {
MG_ASSERT(has_aggregation_.size() == 1U,
"Expected to reduce has_aggregation_ to single boolean.");
MG_ASSERT(has_aggregation_.size() == 1U, "Expected to reduce has_aggregation_ to single boolean.");
if (!has_aggregation_.back()) {
group_by_.emplace_back(named_expr.expression_);
}
@ -383,8 +363,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
}
bool PostVisit(RegexMatch &regex_match) override {
MG_ASSERT(has_aggregation_.size() >= 2U,
"Expected 2 has_aggregation_ flags for RegexMatch arguments");
MG_ASSERT(has_aggregation_.size() >= 2U, "Expected 2 has_aggregation_ flags for RegexMatch arguments");
bool has_aggr = has_aggregation_.back();
has_aggregation_.pop_back();
has_aggregation_.back() |= has_aggr;
@ -395,17 +374,14 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
// This should be used when body.all_identifiers is true, to generate
// expressions for Produce operator.
void ExpandUserSymbols() {
MG_ASSERT(named_expressions_.empty(),
"ExpandUserSymbols should be first to fill named_expressions_");
MG_ASSERT(output_symbols_.empty(),
"ExpandUserSymbols should be first to fill output_symbols_");
MG_ASSERT(named_expressions_.empty(), "ExpandUserSymbols should be first to fill named_expressions_");
MG_ASSERT(output_symbols_.empty(), "ExpandUserSymbols should be first to fill output_symbols_");
for (const auto &symbol : bound_symbols_) {
if (!symbol.user_declared()) {
continue;
}
auto *ident = storage_.Create<Identifier>(symbol.name())->MapTo(symbol);
auto *named_expr =
storage_.Create<NamedExpression>(symbol.name(), ident)->MapTo(symbol);
auto *named_expr = storage_.Create<NamedExpression>(symbol.name(), ident)->MapTo(symbol);
// Fill output expressions and symbols with expanded identifiers.
named_expressions_.emplace_back(named_expr);
output_symbols_.emplace_back(symbol);
@ -471,38 +447,30 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
std::vector<NamedExpression *> named_expressions_;
};
std::unique_ptr<LogicalOperator> GenReturnBody(
std::unique_ptr<LogicalOperator> input_op, bool advance_command,
const ReturnBodyContext &body, bool accumulate = false) {
std::vector<Symbol> used_symbols(body.used_symbols().begin(),
body.used_symbols().end());
std::unique_ptr<LogicalOperator> GenReturnBody(std::unique_ptr<LogicalOperator> input_op, bool advance_command,
const ReturnBodyContext &body, bool accumulate = false) {
std::vector<Symbol> used_symbols(body.used_symbols().begin(), body.used_symbols().end());
auto last_op = std::move(input_op);
if (accumulate) {
// We only advance the command in Accumulate. This is done for WITH clause,
// when the first part updated the database. RETURN clause may only need an
// accumulation after updates, without advancing the command.
last_op = std::make_unique<Accumulate>(std::move(last_op), used_symbols,
advance_command);
last_op = std::make_unique<Accumulate>(std::move(last_op), used_symbols, advance_command);
}
if (!body.aggregations().empty()) {
// When we have aggregation, SKIP/LIMIT should always come after it.
std::vector<Symbol> remember(body.group_by_used_symbols().begin(),
body.group_by_used_symbols().end());
last_op = std::make_unique<Aggregate>(
std::move(last_op), body.aggregations(), body.group_by(), remember);
std::vector<Symbol> remember(body.group_by_used_symbols().begin(), body.group_by_used_symbols().end());
last_op = std::make_unique<Aggregate>(std::move(last_op), body.aggregations(), body.group_by(), remember);
}
last_op =
std::make_unique<Produce>(std::move(last_op), body.named_expressions());
last_op = std::make_unique<Produce>(std::move(last_op), body.named_expressions());
// Distinct in ReturnBody only makes Produce values unique, so plan after it.
if (body.distinct()) {
last_op =
std::make_unique<Distinct>(std::move(last_op), body.output_symbols());
last_op = std::make_unique<Distinct>(std::move(last_op), body.output_symbols());
}
// Like Where, OrderBy can read from symbols established by named expressions
// in Produce, so it must come after it.
if (!body.order_by().empty()) {
last_op = std::make_unique<OrderBy>(std::move(last_op), body.order_by(),
body.output_symbols());
last_op = std::make_unique<OrderBy>(std::move(last_op), body.order_by(), body.output_symbols());
}
// Finally, Skip and Limit must come after OrderBy.
if (body.skip()) {
@ -515,8 +483,7 @@ std::unique_ptr<LogicalOperator> GenReturnBody(
// Where may see new symbols so it comes after we generate Produce and in
// general, comes after any OrderBy, Skip or Limit.
if (body.where()) {
last_op =
std::make_unique<Filter>(std::move(last_op), body.where()->expression_);
last_op = std::make_unique<Filter>(std::move(last_op), body.where()->expression_);
}
return last_op;
}
@ -525,13 +492,11 @@ std::unique_ptr<LogicalOperator> GenReturnBody(
namespace impl {
Expression *ExtractFilters(const std::unordered_set<Symbol> &bound_symbols,
Filters &filters, AstStorage &storage) {
Expression *ExtractFilters(const std::unordered_set<Symbol> &bound_symbols, Filters &filters, AstStorage &storage) {
Expression *filter_expr = nullptr;
for (auto filters_it = filters.begin(); filters_it != filters.end();) {
if (HasBoundFilterSymbols(bound_symbols, *filters_it)) {
filter_expr = impl::BoolJoin<AndOperator>(storage, filter_expr,
filters_it->expression);
filter_expr = impl::BoolJoin<AndOperator>(storage, filter_expr, filters_it->expression);
filters_it = filters.erase(filters_it);
} else {
filters_it++;
@ -540,10 +505,9 @@ Expression *ExtractFilters(const std::unordered_set<Symbol> &bound_symbols,
return filter_expr;
}
std::unique_ptr<LogicalOperator> GenFilters(
std::unique_ptr<LogicalOperator> last_op,
const std::unordered_set<Symbol> &bound_symbols, Filters &filters,
AstStorage &storage) {
std::unique_ptr<LogicalOperator> GenFilters(std::unique_ptr<LogicalOperator> last_op,
const std::unordered_set<Symbol> &bound_symbols, Filters &filters,
AstStorage &storage) {
auto *filter_expr = ExtractFilters(bound_symbols, filters, storage);
if (filter_expr) {
last_op = std::make_unique<Filter>(std::move(last_op), filter_expr);
@ -551,21 +515,18 @@ std::unique_ptr<LogicalOperator> GenFilters(
return last_op;
}
std::unique_ptr<LogicalOperator> GenNamedPaths(
std::unique_ptr<LogicalOperator> last_op,
std::unordered_set<Symbol> &bound_symbols,
std::unordered_map<Symbol, std::vector<Symbol>> &named_paths) {
std::unique_ptr<LogicalOperator> GenNamedPaths(std::unique_ptr<LogicalOperator> last_op,
std::unordered_set<Symbol> &bound_symbols,
std::unordered_map<Symbol, std::vector<Symbol>> &named_paths) {
auto all_are_bound = [&bound_symbols](const std::vector<Symbol> &syms) {
for (const auto &sym : syms)
if (bound_symbols.find(sym) == bound_symbols.end()) return false;
return true;
};
for (auto named_path_it = named_paths.begin();
named_path_it != named_paths.end();) {
for (auto named_path_it = named_paths.begin(); named_path_it != named_paths.end();) {
if (all_are_bound(named_path_it->second)) {
last_op = std::make_unique<ConstructNamedPath>(
std::move(last_op), named_path_it->first,
std::move(named_path_it->second));
last_op = std::make_unique<ConstructNamedPath>(std::move(last_op), named_path_it->first,
std::move(named_path_it->second));
bound_symbols.insert(named_path_it->first);
named_path_it = named_paths.erase(named_path_it);
} else {
@ -576,10 +537,9 @@ std::unique_ptr<LogicalOperator> GenNamedPaths(
return last_op;
}
std::unique_ptr<LogicalOperator> GenReturn(
Return &ret, std::unique_ptr<LogicalOperator> input_op,
SymbolTable &symbol_table, bool is_write,
const std::unordered_set<Symbol> &bound_symbols, AstStorage &storage) {
std::unique_ptr<LogicalOperator> GenReturn(Return &ret, std::unique_ptr<LogicalOperator> input_op,
SymbolTable &symbol_table, bool is_write,
const std::unordered_set<Symbol> &bound_symbols, AstStorage &storage) {
// Similar to WITH clause, but we want to accumulate when the query writes to
// the database. This way we handle the case when we want to return
// expressions with the latest updated results. For example, `MATCH (n) -- ()
@ -592,10 +552,9 @@ std::unique_ptr<LogicalOperator> GenReturn(
return GenReturnBody(std::move(input_op), advance_command, body, accumulate);
}
std::unique_ptr<LogicalOperator> GenWith(
With &with, std::unique_ptr<LogicalOperator> input_op,
SymbolTable &symbol_table, bool is_write,
std::unordered_set<Symbol> &bound_symbols, AstStorage &storage) {
std::unique_ptr<LogicalOperator> GenWith(With &with, std::unique_ptr<LogicalOperator> input_op,
SymbolTable &symbol_table, bool is_write,
std::unordered_set<Symbol> &bound_symbols, AstStorage &storage) {
// WITH clause is Accumulate/Aggregate (advance_command) + Produce and
// optional Filter. In case of update and aggregation, we want to accumulate
// first, so that when aggregating, we get the latest results. Similar to
@ -603,10 +562,8 @@ std::unique_ptr<LogicalOperator> GenWith(
bool accumulate = is_write;
// No need to advance the command if we only performed reads.
bool advance_command = is_write;
ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage,
with.where_);
auto last_op =
GenReturnBody(std::move(input_op), advance_command, body, accumulate);
ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage, with.where_);
auto last_op = GenReturnBody(std::move(input_op), advance_command, body, accumulate);
// Reset bound symbols, so that only those in WITH are exposed.
bound_symbols.clear();
for (const auto &symbol : body.output_symbols()) {
@ -615,11 +572,9 @@ std::unique_ptr<LogicalOperator> GenWith(
return last_op;
}
std::unique_ptr<LogicalOperator> GenUnion(
const CypherUnion &cypher_union, std::shared_ptr<LogicalOperator> left_op,
std::shared_ptr<LogicalOperator> right_op, SymbolTable &symbol_table) {
return std::make_unique<Union>(left_op, right_op, cypher_union.union_symbols_,
left_op->OutputSymbols(symbol_table),
std::unique_ptr<LogicalOperator> GenUnion(const CypherUnion &cypher_union, std::shared_ptr<LogicalOperator> left_op,
std::shared_ptr<LogicalOperator> right_op, SymbolTable &symbol_table) {
return std::make_unique<Union>(left_op, right_op, cypher_union.union_symbols_, left_op->OutputSymbols(symbol_table),
right_op->OutputSymbols(symbol_table));
}

View File

@ -38,8 +38,7 @@ struct PlanningContext {
};
template <class TDbAccessor>
auto MakePlanningContext(AstStorage *ast_storage, SymbolTable *symbol_table,
CypherQuery *query, TDbAccessor *db) {
auto MakePlanningContext(AstStorage *ast_storage, SymbolTable *symbol_table, CypherQuery *query, TDbAccessor *db) {
return PlanningContext<TDbAccessor>{symbol_table, ast_storage, query, db};
}
@ -66,11 +65,9 @@ namespace impl {
// Iterates over `Filters` joining them in one expression via
// `AndOperator` if symbols they use are bound.. All the joined filters are
// removed from `Filters`.
Expression *ExtractFilters(const std::unordered_set<Symbol> &, Filters &,
AstStorage &);
Expression *ExtractFilters(const std::unordered_set<Symbol> &, Filters &, AstStorage &);
std::unique_ptr<LogicalOperator> GenFilters(std::unique_ptr<LogicalOperator>,
const std::unordered_set<Symbol> &,
std::unique_ptr<LogicalOperator> GenFilters(std::unique_ptr<LogicalOperator>, const std::unordered_set<Symbol> &,
Filters &, AstStorage &);
/// Utility function for iterating pattern atoms and accumulating a result.
@ -92,9 +89,8 @@ std::unique_ptr<LogicalOperator> GenFilters(std::unique_ptr<LogicalOperator>,
// TODO: It might be a good idea to move this somewhere else, for easier usage
// in other files.
template <typename T>
auto ReducePattern(
Pattern &pattern, std::function<T(NodeAtom *)> base,
std::function<T(T, NodeAtom *, EdgeAtom *, NodeAtom *)> collect) {
auto ReducePattern(Pattern &pattern, std::function<T(NodeAtom *)> base,
std::function<T(T, NodeAtom *, EdgeAtom *, NodeAtom *)> collect) {
MG_ASSERT(!pattern.atoms_.empty(), "Missing atoms in pattern");
auto atoms_it = pattern.atoms_.begin();
auto current_node = utils::Downcast<NodeAtom>(*atoms_it++);
@ -104,8 +100,7 @@ auto ReducePattern(
while (atoms_it != pattern.atoms_.end()) {
auto edge = utils::Downcast<EdgeAtom>(*atoms_it++);
MG_ASSERT(edge, "Expected an edge atom in pattern.");
MG_ASSERT(atoms_it != pattern.atoms_.end(),
"Edge atom should not end the pattern.");
MG_ASSERT(atoms_it != pattern.atoms_.end(), "Edge atom should not end the pattern.");
auto prev_node = current_node;
current_node = utils::Downcast<NodeAtom>(*atoms_it++);
MG_ASSERT(current_node, "Expected a node atom in pattern.");
@ -118,28 +113,23 @@ auto ReducePattern(
// If so, it creates a logical operator for named path generation, binds its
// symbol, removes that path from the collection of unhandled ones and returns
// the new op. Otherwise, returns `last_op`.
std::unique_ptr<LogicalOperator> GenNamedPaths(
std::unique_ptr<LogicalOperator> last_op,
std::unordered_set<Symbol> &bound_symbols,
std::unordered_map<Symbol, std::vector<Symbol>> &named_paths);
std::unique_ptr<LogicalOperator> GenNamedPaths(std::unique_ptr<LogicalOperator> last_op,
std::unordered_set<Symbol> &bound_symbols,
std::unordered_map<Symbol, std::vector<Symbol>> &named_paths);
std::unique_ptr<LogicalOperator> GenReturn(
Return &ret, std::unique_ptr<LogicalOperator> input_op,
SymbolTable &symbol_table, bool is_write,
const std::unordered_set<Symbol> &bound_symbols, AstStorage &storage);
std::unique_ptr<LogicalOperator> GenReturn(Return &ret, std::unique_ptr<LogicalOperator> input_op,
SymbolTable &symbol_table, bool is_write,
const std::unordered_set<Symbol> &bound_symbols, AstStorage &storage);
std::unique_ptr<LogicalOperator> GenWith(
With &with, std::unique_ptr<LogicalOperator> input_op,
SymbolTable &symbol_table, bool is_write,
std::unordered_set<Symbol> &bound_symbols, AstStorage &storage);
std::unique_ptr<LogicalOperator> GenWith(With &with, std::unique_ptr<LogicalOperator> input_op,
SymbolTable &symbol_table, bool is_write,
std::unordered_set<Symbol> &bound_symbols, AstStorage &storage);
std::unique_ptr<LogicalOperator> GenUnion(
const CypherUnion &cypher_union, std::shared_ptr<LogicalOperator> left_op,
std::shared_ptr<LogicalOperator> right_op, SymbolTable &symbol_table);
std::unique_ptr<LogicalOperator> GenUnion(const CypherUnion &cypher_union, std::shared_ptr<LogicalOperator> left_op,
std::shared_ptr<LogicalOperator> right_op, SymbolTable &symbol_table);
template <class TBoolOperator>
Expression *BoolJoin(AstStorage &storage, Expression *expr1,
Expression *expr2) {
Expression *BoolJoin(AstStorage &storage, Expression *expr1, Expression *expr2) {
if (expr1 && expr2) {
return storage.Create<TBoolOperator>(expr1, expr2);
}
@ -166,52 +156,40 @@ class RuleBasedPlanner {
// Set to true if a query command writes to the database.
bool is_write = false;
for (const auto &query_part : query_parts) {
MatchContext match_ctx{query_part.matching, *context.symbol_table,
context.bound_symbols};
MatchContext match_ctx{query_part.matching, *context.symbol_table, context.bound_symbols};
input_op = PlanMatching(match_ctx, std::move(input_op));
for (const auto &matching : query_part.optional_matching) {
MatchContext opt_ctx{matching, *context.symbol_table,
context.bound_symbols};
MatchContext opt_ctx{matching, *context.symbol_table, context.bound_symbols};
auto match_op = PlanMatching(opt_ctx, nullptr);
if (match_op) {
input_op = std::make_unique<Optional>(
std::move(input_op), std::move(match_op), opt_ctx.new_symbols);
input_op = std::make_unique<Optional>(std::move(input_op), std::move(match_op), opt_ctx.new_symbols);
}
}
uint64_t merge_id = 0;
for (auto *clause : query_part.remaining_clauses) {
MG_ASSERT(!utils::IsSubtype(*clause, Match::kType),
"Unexpected Match in remaining clauses");
MG_ASSERT(!utils::IsSubtype(*clause, Match::kType), "Unexpected Match in remaining clauses");
if (auto *ret = utils::Downcast<Return>(clause)) {
input_op = impl::GenReturn(
*ret, std::move(input_op), *context.symbol_table, is_write,
context.bound_symbols, *context.ast_storage);
input_op = impl::GenReturn(*ret, std::move(input_op), *context.symbol_table, is_write, context.bound_symbols,
*context.ast_storage);
} else if (auto *merge = utils::Downcast<query::Merge>(clause)) {
input_op = GenMerge(*merge, std::move(input_op),
query_part.merge_matching[merge_id++]);
input_op = GenMerge(*merge, std::move(input_op), query_part.merge_matching[merge_id++]);
// Treat MERGE clause as write, because we do not know if it will
// create anything.
is_write = true;
} else if (auto *with = utils::Downcast<query::With>(clause)) {
input_op = impl::GenWith(*with, std::move(input_op),
*context.symbol_table, is_write,
context.bound_symbols, *context.ast_storage);
input_op = impl::GenWith(*with, std::move(input_op), *context.symbol_table, is_write, context.bound_symbols,
*context.ast_storage);
// WITH clause advances the command, so reset the flag.
is_write = false;
} else if (auto op = HandleWriteClause(clause, input_op,
*context.symbol_table,
context.bound_symbols)) {
} else if (auto op = HandleWriteClause(clause, input_op, *context.symbol_table, context.bound_symbols)) {
is_write = true;
input_op = std::move(op);
} else if (auto *unwind = utils::Downcast<query::Unwind>(clause)) {
const auto &symbol =
context.symbol_table->at(*unwind->named_expression_);
const auto &symbol = context.symbol_table->at(*unwind->named_expression_);
context.bound_symbols.insert(symbol);
input_op = std::make_unique<plan::Unwind>(
std::move(input_op), unwind->named_expression_->expression_,
symbol);
} else if (auto *call_proc =
utils::Downcast<query::CallProcedure>(clause)) {
input_op =
std::make_unique<plan::Unwind>(std::move(input_op), unwind->named_expression_->expression_, symbol);
} else if (auto *call_proc = utils::Downcast<query::CallProcedure>(clause)) {
std::vector<Symbol> result_symbols;
result_symbols.reserve(call_proc->result_identifiers_.size());
for (const auto *ident : call_proc->result_identifiers_) {
@ -223,13 +201,10 @@ class RuleBasedPlanner {
// need to plan this operator with Accumulate and pass in
// storage::View::NEW.
input_op = std::make_unique<plan::CallProcedure>(
std::move(input_op), call_proc->procedure_name_,
call_proc->arguments_, call_proc->result_fields_, result_symbols,
call_proc->memory_limit_, call_proc->memory_scale_);
std::move(input_op), call_proc->procedure_name_, call_proc->arguments_, call_proc->result_fields_,
result_symbols, call_proc->memory_limit_, call_proc->memory_scale_);
} else {
throw utils::NotYetImplemented(
"clause '{}' conversion to operator(s)",
clause->GetTypeInfo().name);
throw utils::NotYetImplemented("clause '{}' conversion to operator(s)", clause->GetTypeInfo().name);
}
}
}
@ -239,34 +214,25 @@ class RuleBasedPlanner {
private:
TPlanningContext *context_;
storage::LabelId GetLabel(LabelIx label) {
return context_->db->NameToLabel(label.name);
}
storage::LabelId GetLabel(LabelIx label) { return context_->db->NameToLabel(label.name); }
storage::PropertyId GetProperty(PropertyIx prop) {
return context_->db->NameToProperty(prop.name);
}
storage::PropertyId GetProperty(PropertyIx prop) { return context_->db->NameToProperty(prop.name); }
storage::EdgeTypeId GetEdgeType(EdgeTypeIx edge_type) {
return context_->db->NameToEdgeType(edge_type.name);
}
storage::EdgeTypeId GetEdgeType(EdgeTypeIx edge_type) { return context_->db->NameToEdgeType(edge_type.name); }
std::unique_ptr<LogicalOperator> GenCreate(
Create &create, std::unique_ptr<LogicalOperator> input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols) {
std::unique_ptr<LogicalOperator> GenCreate(Create &create, std::unique_ptr<LogicalOperator> input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols) {
auto last_op = std::move(input_op);
for (auto pattern : create.patterns_) {
last_op = GenCreateForPattern(*pattern, std::move(last_op), symbol_table,
bound_symbols);
last_op = GenCreateForPattern(*pattern, std::move(last_op), symbol_table, bound_symbols);
}
return last_op;
}
std::unique_ptr<LogicalOperator> GenCreateForPattern(
Pattern &pattern, std::unique_ptr<LogicalOperator> input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols) {
std::unique_ptr<LogicalOperator> GenCreateForPattern(Pattern &pattern, std::unique_ptr<LogicalOperator> input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols) {
auto node_to_creation_info = [&](const NodeAtom &node) {
const auto &node_symbol = symbol_table.at(*node.identifier_);
std::vector<storage::LabelId> labels;
@ -292,8 +258,7 @@ class RuleBasedPlanner {
}
};
auto collect = [&](std::unique_ptr<LogicalOperator> last_op,
NodeAtom *prev_node, EdgeAtom *edge, NodeAtom *node) {
auto collect = [&](std::unique_ptr<LogicalOperator> last_op, NodeAtom *prev_node, EdgeAtom *edge, NodeAtom *node) {
// Store the symbol from the first node as the input to CreateExpand.
const auto &input_symbol = symbol_table.at(*prev_node->identifier_);
// If the expand node was already bound, then we need to indicate this,
@ -312,28 +277,19 @@ class RuleBasedPlanner {
for (const auto &kv : edge->properties_) {
properties.push_back({GetProperty(kv.first), kv.second});
}
MG_ASSERT(
edge->edge_types_.size() == 1,
"Creating an edge with a single type should be required by syntax");
EdgeCreationInfo edge_info{edge_symbol, properties,
GetEdgeType(edge->edge_types_[0]),
edge->direction_};
return std::make_unique<CreateExpand>(node_info, edge_info,
std::move(last_op), input_symbol,
node_existing);
MG_ASSERT(edge->edge_types_.size() == 1, "Creating an edge with a single type should be required by syntax");
EdgeCreationInfo edge_info{edge_symbol, properties, GetEdgeType(edge->edge_types_[0]), edge->direction_};
return std::make_unique<CreateExpand>(node_info, edge_info, std::move(last_op), input_symbol, node_existing);
};
auto last_op = impl::ReducePattern<std::unique_ptr<LogicalOperator>>(
pattern, base, collect);
auto last_op = impl::ReducePattern<std::unique_ptr<LogicalOperator>>(pattern, base, collect);
// If the pattern is named, append the path constructing logical operator.
if (pattern.identifier_->user_declared_) {
std::vector<Symbol> path_elements;
for (const PatternAtom *atom : pattern.atoms_)
path_elements.emplace_back(symbol_table.at(*atom->identifier_));
last_op = std::make_unique<ConstructNamedPath>(
std::move(last_op), symbol_table.at(*pattern.identifier_),
path_elements);
for (const PatternAtom *atom : pattern.atoms_) path_elements.emplace_back(symbol_table.at(*atom->identifier_));
last_op = std::make_unique<ConstructNamedPath>(std::move(last_op), symbol_table.at(*pattern.identifier_),
path_elements);
}
return last_op;
@ -342,26 +298,20 @@ class RuleBasedPlanner {
// Generate an operator for a clause which writes to the database. Ownership
// of input_op is transferred to the newly created operator. If the clause
// isn't handled, returns nullptr and input_op is left as is.
std::unique_ptr<LogicalOperator> HandleWriteClause(
Clause *clause, std::unique_ptr<LogicalOperator> &input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols) {
std::unique_ptr<LogicalOperator> HandleWriteClause(Clause *clause, std::unique_ptr<LogicalOperator> &input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols) {
if (auto *create = utils::Downcast<Create>(clause)) {
return GenCreate(*create, std::move(input_op), symbol_table,
bound_symbols);
return GenCreate(*create, std::move(input_op), symbol_table, bound_symbols);
} else if (auto *del = utils::Downcast<query::Delete>(clause)) {
return std::make_unique<plan::Delete>(std::move(input_op),
del->expressions_, del->detach_);
return std::make_unique<plan::Delete>(std::move(input_op), del->expressions_, del->detach_);
} else if (auto *set = utils::Downcast<query::SetProperty>(clause)) {
return std::make_unique<plan::SetProperty>(
std::move(input_op), GetProperty(set->property_lookup_->property_),
set->property_lookup_, set->expression_);
return std::make_unique<plan::SetProperty>(std::move(input_op), GetProperty(set->property_lookup_->property_),
set->property_lookup_, set->expression_);
} else if (auto *set = utils::Downcast<query::SetProperties>(clause)) {
auto op = set->update_ ? plan::SetProperties::Op::UPDATE
: plan::SetProperties::Op::REPLACE;
auto op = set->update_ ? plan::SetProperties::Op::UPDATE : plan::SetProperties::Op::REPLACE;
const auto &input_symbol = symbol_table.at(*set->identifier_);
return std::make_unique<plan::SetProperties>(
std::move(input_op), input_symbol, set->expression_, op);
return std::make_unique<plan::SetProperties>(std::move(input_op), input_symbol, set->expression_, op);
} else if (auto *set = utils::Downcast<query::SetLabels>(clause)) {
const auto &input_symbol = symbol_table.at(*set->identifier_);
std::vector<storage::LabelId> labels;
@ -369,12 +319,10 @@ class RuleBasedPlanner {
for (const auto &label : set->labels_) {
labels.push_back(GetLabel(label));
}
return std::make_unique<plan::SetLabels>(std::move(input_op),
input_symbol, labels);
return std::make_unique<plan::SetLabels>(std::move(input_op), input_symbol, labels);
} else if (auto *rem = utils::Downcast<query::RemoveProperty>(clause)) {
return std::make_unique<plan::RemoveProperty>(
std::move(input_op), GetProperty(rem->property_lookup_->property_),
rem->property_lookup_);
return std::make_unique<plan::RemoveProperty>(std::move(input_op), GetProperty(rem->property_lookup_->property_),
rem->property_lookup_);
} else if (auto *rem = utils::Downcast<query::RemoveLabels>(clause)) {
const auto &input_symbol = symbol_table.at(*rem->identifier_);
std::vector<storage::LabelId> labels;
@ -382,14 +330,13 @@ class RuleBasedPlanner {
for (const auto &label : rem->labels_) {
labels.push_back(GetLabel(label));
}
return std::make_unique<plan::RemoveLabels>(std::move(input_op),
input_symbol, labels);
return std::make_unique<plan::RemoveLabels>(std::move(input_op), input_symbol, labels);
}
return nullptr;
}
std::unique_ptr<LogicalOperator> PlanMatching(
MatchContext &match_context, std::unique_ptr<LogicalOperator> input_op) {
std::unique_ptr<LogicalOperator> PlanMatching(MatchContext &match_context,
std::unique_ptr<LogicalOperator> input_op) {
auto &bound_symbols = match_context.bound_symbols;
auto &storage = *context_->ast_storage;
const auto &symbol_table = match_context.symbol_table;
@ -401,21 +348,16 @@ class RuleBasedPlanner {
// Try to generate any filters even before the 1st match operator. This
// optimizes the optional match which filters only on symbols bound in
// regular match.
auto last_op =
impl::GenFilters(std::move(input_op), bound_symbols, filters, storage);
auto last_op = impl::GenFilters(std::move(input_op), bound_symbols, filters, storage);
for (const auto &expansion : matching.expansions) {
const auto &node1_symbol = symbol_table.at(*expansion.node1->identifier_);
if (bound_symbols.insert(node1_symbol).second) {
// We have just bound this symbol, so generate ScanAll which fills it.
last_op = std::make_unique<ScanAll>(std::move(last_op), node1_symbol,
match_context.view);
last_op = std::make_unique<ScanAll>(std::move(last_op), node1_symbol, match_context.view);
match_context.new_symbols.emplace_back(node1_symbol);
last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters,
storage);
last_op =
impl::GenNamedPaths(std::move(last_op), bound_symbols, named_paths);
last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters,
storage);
last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters, storage);
last_op = impl::GenNamedPaths(std::move(last_op), bound_symbols, named_paths);
last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters, storage);
}
// We have an edge, so generate Expand.
if (expansion.edge) {
@ -423,12 +365,10 @@ class RuleBasedPlanner {
// If the expand symbols were already bound, then we need to indicate
// that they exist. The Expand will then check whether the pattern holds
// instead of writing the expansion to symbols.
const auto &node_symbol =
symbol_table.at(*expansion.node2->identifier_);
const auto &node_symbol = symbol_table.at(*expansion.node2->identifier_);
auto existing_node = utils::Contains(bound_symbols, node_symbol);
const auto &edge_symbol = symbol_table.at(*edge->identifier_);
MG_ASSERT(!utils::Contains(bound_symbols, edge_symbol),
"Existing edges are not supported");
MG_ASSERT(!utils::Contains(bound_symbols, edge_symbol), "Existing edges are not supported");
std::vector<storage::EdgeTypeId> edge_types;
edge_types.reserve(edge->edge_types_.size());
for (const auto &type : edge->edge_types_) {
@ -439,48 +379,38 @@ class RuleBasedPlanner {
std::optional<Symbol> total_weight;
if (edge->type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) {
weight_lambda.emplace(ExpansionLambda{
symbol_table.at(*edge->weight_lambda_.inner_edge),
symbol_table.at(*edge->weight_lambda_.inner_node),
edge->weight_lambda_.expression});
weight_lambda.emplace(ExpansionLambda{symbol_table.at(*edge->weight_lambda_.inner_edge),
symbol_table.at(*edge->weight_lambda_.inner_node),
edge->weight_lambda_.expression});
total_weight.emplace(symbol_table.at(*edge->total_weight_));
}
ExpansionLambda filter_lambda;
filter_lambda.inner_edge_symbol =
symbol_table.at(*edge->filter_lambda_.inner_edge);
filter_lambda.inner_node_symbol =
symbol_table.at(*edge->filter_lambda_.inner_node);
filter_lambda.inner_edge_symbol = symbol_table.at(*edge->filter_lambda_.inner_edge);
filter_lambda.inner_node_symbol = symbol_table.at(*edge->filter_lambda_.inner_node);
{
// Bind the inner edge and node symbols so they're available for
// inline filtering in ExpandVariable.
bool inner_edge_bound =
bound_symbols.insert(filter_lambda.inner_edge_symbol).second;
bool inner_node_bound =
bound_symbols.insert(filter_lambda.inner_node_symbol).second;
MG_ASSERT(inner_edge_bound && inner_node_bound,
"An inner edge and node can't be bound from before");
bool inner_edge_bound = bound_symbols.insert(filter_lambda.inner_edge_symbol).second;
bool inner_node_bound = bound_symbols.insert(filter_lambda.inner_node_symbol).second;
MG_ASSERT(inner_edge_bound && inner_node_bound, "An inner edge and node can't be bound from before");
}
// Join regular filters with lambda filter expression, so that they
// are done inline together. Semantic analysis should guarantee that
// lambda filtering uses bound symbols.
filter_lambda.expression = impl::BoolJoin<AndOperator>(
storage, impl::ExtractFilters(bound_symbols, filters, storage),
edge->filter_lambda_.expression);
storage, impl::ExtractFilters(bound_symbols, filters, storage), edge->filter_lambda_.expression);
// At this point it's possible we have leftover filters for inline
// filtering (they use the inner symbols. If they were not collected,
// we have to remove them manually because no other filter-extraction
// will ever bind them again.
filters.erase(
std::remove_if(
filters.begin(), filters.end(),
[e = filter_lambda.inner_edge_symbol,
n = filter_lambda.inner_node_symbol](FilterInfo &fi) {
return utils::Contains(fi.used_symbols, e) ||
utils::Contains(fi.used_symbols, n);
}),
filters.end());
filters.erase(std::remove_if(
filters.begin(), filters.end(),
[e = filter_lambda.inner_edge_symbol, n = filter_lambda.inner_node_symbol](FilterInfo &fi) {
return utils::Contains(fi.used_symbols, e) || utils::Contains(fi.used_symbols, n);
}),
filters.end());
// Unbind the temporarily bound inner symbols for filtering.
bound_symbols.erase(filter_lambda.inner_edge_symbol);
bound_symbols.erase(filter_lambda.inner_node_symbol);
@ -490,19 +420,15 @@ class RuleBasedPlanner {
}
// TODO: Pass weight lambda.
MG_ASSERT(
match_context.view == storage::View::OLD,
"ExpandVariable should only be planned with storage::View::OLD");
last_op = std::make_unique<ExpandVariable>(
std::move(last_op), node1_symbol, node_symbol, edge_symbol,
edge->type_, expansion.direction, edge_types,
expansion.is_flipped, edge->lower_bound_, edge->upper_bound_,
existing_node, filter_lambda, weight_lambda, total_weight);
MG_ASSERT(match_context.view == storage::View::OLD,
"ExpandVariable should only be planned with storage::View::OLD");
last_op = std::make_unique<ExpandVariable>(std::move(last_op), node1_symbol, node_symbol, edge_symbol,
edge->type_, expansion.direction, edge_types, expansion.is_flipped,
edge->lower_bound_, edge->upper_bound_, existing_node,
filter_lambda, weight_lambda, total_weight);
} else {
last_op = std::make_unique<Expand>(std::move(last_op), node1_symbol,
node_symbol, edge_symbol,
expansion.direction, edge_types,
existing_node, match_context.view);
last_op = std::make_unique<Expand>(std::move(last_op), node1_symbol, node_symbol, edge_symbol,
expansion.direction, edge_types, existing_node, match_context.view);
}
// Bind the expanded edge and node.
@ -520,60 +446,47 @@ class RuleBasedPlanner {
}
std::vector<Symbol> other_symbols;
for (const auto &symbol : edge_symbols) {
if (symbol == edge_symbol ||
bound_symbols.find(symbol) == bound_symbols.end()) {
if (symbol == edge_symbol || bound_symbols.find(symbol) == bound_symbols.end()) {
continue;
}
other_symbols.push_back(symbol);
}
if (!other_symbols.empty()) {
last_op = std::make_unique<EdgeUniquenessFilter>(
std::move(last_op), edge_symbol, other_symbols);
last_op = std::make_unique<EdgeUniquenessFilter>(std::move(last_op), edge_symbol, other_symbols);
}
}
last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters,
storage);
last_op =
impl::GenNamedPaths(std::move(last_op), bound_symbols, named_paths);
last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters,
storage);
last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters, storage);
last_op = impl::GenNamedPaths(std::move(last_op), bound_symbols, named_paths);
last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters, storage);
}
}
MG_ASSERT(named_paths.empty(), "Expected to generate all named paths");
// We bound all named path symbols, so just add them to new_symbols.
for (const auto &named_path : matching.named_paths) {
MG_ASSERT(utils::Contains(bound_symbols, named_path.first),
"Expected generated named path to have bound symbol");
MG_ASSERT(utils::Contains(bound_symbols, named_path.first), "Expected generated named path to have bound symbol");
match_context.new_symbols.emplace_back(named_path.first);
}
MG_ASSERT(filters.empty(), "Expected to generate all filters");
return last_op;
}
auto GenMerge(query::Merge &merge, std::unique_ptr<LogicalOperator> input_op,
const Matching &matching) {
auto GenMerge(query::Merge &merge, std::unique_ptr<LogicalOperator> input_op, const Matching &matching) {
// Copy the bound symbol set, because we don't want to use the updated
// version when generating the create part.
std::unordered_set<Symbol> bound_symbols_copy(context_->bound_symbols);
MatchContext match_ctx{matching, *context_->symbol_table,
bound_symbols_copy, storage::View::NEW};
MatchContext match_ctx{matching, *context_->symbol_table, bound_symbols_copy, storage::View::NEW};
auto on_match = PlanMatching(match_ctx, nullptr);
// Use the original bound_symbols, so we fill it with new symbols.
auto on_create =
GenCreateForPattern(*merge.pattern_, nullptr, *context_->symbol_table,
context_->bound_symbols);
auto on_create = GenCreateForPattern(*merge.pattern_, nullptr, *context_->symbol_table, context_->bound_symbols);
for (auto &set : merge.on_create_) {
on_create = HandleWriteClause(set, on_create, *context_->symbol_table,
context_->bound_symbols);
on_create = HandleWriteClause(set, on_create, *context_->symbol_table, context_->bound_symbols);
MG_ASSERT(on_create, "Expected SET in MERGE ... ON CREATE");
}
for (auto &set : merge.on_match_) {
on_match = HandleWriteClause(set, on_match, *context_->symbol_table,
context_->bound_symbols);
on_match = HandleWriteClause(set, on_match, *context_->symbol_table, context_->bound_symbols);
MG_ASSERT(on_match, "Expected SET in MERGE ... ON MATCH");
}
return std::make_unique<plan::Merge>(
std::move(input_op), std::move(on_match), std::move(on_create));
return std::make_unique<plan::Merge>(std::move(input_op), std::move(on_match), std::move(on_create));
}
};

View File

@ -19,9 +19,7 @@ namespace plan {
*/
class ScopedProfile {
public:
ScopedProfile(uint64_t key, const char *name,
query::ExecutionContext *context) noexcept
: context_(context) {
ScopedProfile(uint64_t key, const char *name, query::ExecutionContext *context) noexcept : context_(context) {
if (UNLIKELY(context_->is_profile_query)) {
root_ = context_->stats_root;

View File

@ -6,9 +6,8 @@
#include "utils/flag_validation.hpp"
#include "utils/logging.hpp"
DEFINE_VALIDATED_HIDDEN_uint64(
query_max_plans, 1000U, "Maximum number of generated plans for a query.",
FLAG_IN_RANGE(1, std::numeric_limits<std::uint64_t>::max()));
DEFINE_VALIDATED_HIDDEN_uint64(query_max_plans, 1000U, "Maximum number of generated plans for a query.",
FLAG_IN_RANGE(1, std::numeric_limits<std::uint64_t>::max()));
namespace query::plan::impl {
@ -17,13 +16,10 @@ namespace {
// Add applicable expansions for `node_symbol` to `next_expansions`. These
// expansions are removed from `node_symbol_to_expansions`, while
// `seen_expansions` and `expanded_symbols` are populated with new data.
void AddNextExpansions(
const Symbol &node_symbol, const Matching &matching,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &expanded_symbols,
std::unordered_map<Symbol, std::set<size_t>> &node_symbol_to_expansions,
std::unordered_set<size_t> &seen_expansions,
std::queue<Expansion> &next_expansions) {
void AddNextExpansions(const Symbol &node_symbol, const Matching &matching, const SymbolTable &symbol_table,
std::unordered_set<Symbol> &expanded_symbols,
std::unordered_map<Symbol, std::set<size_t>> &node_symbol_to_expansions,
std::unordered_set<size_t> &seen_expansions, std::queue<Expansion> &next_expansions) {
auto node_to_expansions_it = node_symbol_to_expansions.find(node_symbol);
if (node_to_expansions_it == node_symbol_to_expansions.end()) {
return;
@ -37,8 +33,7 @@ void AddNextExpansions(
// therefore bound. If the symbols are not found in the whole expansion,
// then the semantic analysis should guarantee that the symbols have been
// bound long before we expand.
if (matching.expansion_symbols.find(range_symbol) !=
matching.expansion_symbols.end() &&
if (matching.expansion_symbols.find(range_symbol) != matching.expansion_symbols.end() &&
expanded_symbols.find(range_symbol) == expanded_symbols.end()) {
return false;
}
@ -62,18 +57,15 @@ void AddNextExpansions(
}
if (symbol_table.at(*expansion.node1->identifier_) != node_symbol) {
// We are not expanding from node1, so flip the expansion.
DMG_ASSERT(
expansion.node2 &&
symbol_table.at(*expansion.node2->identifier_) == node_symbol,
"Expected node_symbol to be bound in node2");
DMG_ASSERT(expansion.node2 && symbol_table.at(*expansion.node2->identifier_) == node_symbol,
"Expected node_symbol to be bound in node2");
if (expansion.edge->type_ != EdgeAtom::Type::BREADTH_FIRST) {
// BFS must *not* be flipped. Doing that changes the BFS results.
std::swap(expansion.node1, expansion.node2);
expansion.is_flipped = true;
if (expansion.direction != EdgeAtom::Direction::BOTH) {
expansion.direction = expansion.direction == EdgeAtom::Direction::IN
? EdgeAtom::Direction::OUT
: EdgeAtom::Direction::IN;
expansion.direction =
expansion.direction == EdgeAtom::Direction::IN ? EdgeAtom::Direction::OUT : EdgeAtom::Direction::IN;
}
}
}
@ -95,20 +87,17 @@ void AddNextExpansions(
// the chain can no longer be continued, a different starting node is picked
// among remaining expansions and the process continues. This is done until all
// matching.expansions are used.
std::vector<Expansion> ExpansionsFrom(const NodeAtom *start_node,
const Matching &matching,
std::vector<Expansion> ExpansionsFrom(const NodeAtom *start_node, const Matching &matching,
const SymbolTable &symbol_table) {
// Make a copy of node_symbol_to_expansions, because we will modify it as
// expansions are chained.
auto node_symbol_to_expansions = matching.node_symbol_to_expansions;
std::unordered_set<size_t> seen_expansions;
std::queue<Expansion> next_expansions;
std::unordered_set<Symbol> expanded_symbols(
{symbol_table.at(*start_node->identifier_)});
std::unordered_set<Symbol> expanded_symbols({symbol_table.at(*start_node->identifier_)});
auto add_next_expansions = [&](const auto *node) {
AddNextExpansions(symbol_table.at(*node->identifier_), matching,
symbol_table, expanded_symbols, node_symbol_to_expansions,
seen_expansions, next_expansions);
AddNextExpansions(symbol_table.at(*node->identifier_), matching, symbol_table, expanded_symbols,
node_symbol_to_expansions, seen_expansions, next_expansions);
};
add_next_expansions(start_node);
// Potential optimization: expansions and next_expansions could be merge into
@ -141,11 +130,9 @@ std::vector<Expansion> ExpansionsFrom(const NodeAtom *start_node,
// Collect all unique nodes from expansions. Uniqueness is determined by
// symbol uniqueness.
auto ExpansionNodes(const std::vector<Expansion> &expansions,
const SymbolTable &symbol_table) {
std::unordered_set<NodeAtom *, NodeSymbolHash, NodeSymbolEqual> nodes(
expansions.size(), NodeSymbolHash(symbol_table),
NodeSymbolEqual(symbol_table));
auto ExpansionNodes(const std::vector<Expansion> &expansions, const SymbolTable &symbol_table) {
std::unordered_set<NodeAtom *, NodeSymbolHash, NodeSymbolEqual> nodes(expansions.size(), NodeSymbolHash(symbol_table),
NodeSymbolEqual(symbol_table));
for (const auto &expansion : expansions) {
// TODO: Handle labels and properties from different node atoms.
nodes.insert(expansion.node1);
@ -158,11 +145,8 @@ auto ExpansionNodes(const std::vector<Expansion> &expansions,
} // namespace
VaryMatchingStart::VaryMatchingStart(Matching matching,
const SymbolTable &symbol_table)
: matching_(matching),
symbol_table_(symbol_table),
nodes_(ExpansionNodes(matching.expansions, symbol_table)) {}
VaryMatchingStart::VaryMatchingStart(Matching matching, const SymbolTable &symbol_table)
: matching_(matching), symbol_table_(symbol_table), nodes_(ExpansionNodes(matching.expansions, symbol_table)) {}
VaryMatchingStart::iterator::iterator(VaryMatchingStart *self, bool is_done)
: self_(self),
@ -175,12 +159,10 @@ VaryMatchingStart::iterator::iterator(VaryMatchingStart *self, bool is_done)
// Overwrite the original matching expansions with the new ones by
// generating it from the first start node.
start_nodes_it_ = self_->nodes_.begin();
current_matching_.expansions = ExpansionsFrom(
**start_nodes_it_, self_->matching_, self_->symbol_table_);
current_matching_.expansions = ExpansionsFrom(**start_nodes_it_, self_->matching_, self_->symbol_table_);
}
DMG_ASSERT(
start_nodes_it_ || self_->nodes_.empty(),
"start_nodes_it_ should only be nullopt when self_->nodes_ is empty");
DMG_ASSERT(start_nodes_it_ || self_->nodes_.empty(),
"start_nodes_it_ should only be nullopt when self_->nodes_ is empty");
if (is_done) {
start_nodes_it_ = self_->nodes_.end();
}
@ -188,9 +170,7 @@ VaryMatchingStart::iterator::iterator(VaryMatchingStart *self, bool is_done)
VaryMatchingStart::iterator &VaryMatchingStart::iterator::operator++() {
if (!start_nodes_it_) {
DMG_ASSERT(
self_->nodes_.empty(),
"start_nodes_it_ should only be nullopt when self_->nodes_ is empty");
DMG_ASSERT(self_->nodes_.empty(), "start_nodes_it_ should only be nullopt when self_->nodes_ is empty");
start_nodes_it_ = self_->nodes_.end();
}
if (*start_nodes_it_ == self_->nodes_.end()) {
@ -203,13 +183,12 @@ VaryMatchingStart::iterator &VaryMatchingStart::iterator::operator++() {
return *this;
}
const auto &start_node = **start_nodes_it_;
current_matching_.expansions =
ExpansionsFrom(start_node, self_->matching_, self_->symbol_table_);
current_matching_.expansions = ExpansionsFrom(start_node, self_->matching_, self_->symbol_table_);
return *this;
}
CartesianProduct<VaryMatchingStart> VaryMultiMatchingStarts(
const std::vector<Matching> &matchings, const SymbolTable &symbol_table) {
CartesianProduct<VaryMatchingStart> VaryMultiMatchingStarts(const std::vector<Matching> &matchings,
const SymbolTable &symbol_table) {
std::vector<VaryMatchingStart> variants;
variants.reserve(matchings.size());
for (const auto &matching : matchings) {
@ -218,23 +197,19 @@ CartesianProduct<VaryMatchingStart> VaryMultiMatchingStarts(
return MakeCartesianProduct(std::move(variants));
}
VaryQueryPartMatching::VaryQueryPartMatching(SingleQueryPart query_part,
const SymbolTable &symbol_table)
VaryQueryPartMatching::VaryQueryPartMatching(SingleQueryPart query_part, const SymbolTable &symbol_table)
: query_part_(std::move(query_part)),
matchings_(VaryMatchingStart(query_part_.matching, symbol_table)),
optional_matchings_(
VaryMultiMatchingStarts(query_part_.optional_matching, symbol_table)),
merge_matchings_(
VaryMultiMatchingStarts(query_part_.merge_matching, symbol_table)) {}
optional_matchings_(VaryMultiMatchingStarts(query_part_.optional_matching, symbol_table)),
merge_matchings_(VaryMultiMatchingStarts(query_part_.merge_matching, symbol_table)) {}
VaryQueryPartMatching::iterator::iterator(
const SingleQueryPart &query_part,
VaryMatchingStart::iterator matchings_begin,
VaryMatchingStart::iterator matchings_end,
CartesianProduct<VaryMatchingStart>::iterator optional_begin,
CartesianProduct<VaryMatchingStart>::iterator optional_end,
CartesianProduct<VaryMatchingStart>::iterator merge_begin,
CartesianProduct<VaryMatchingStart>::iterator merge_end)
VaryQueryPartMatching::iterator::iterator(const SingleQueryPart &query_part,
VaryMatchingStart::iterator matchings_begin,
VaryMatchingStart::iterator matchings_end,
CartesianProduct<VaryMatchingStart>::iterator optional_begin,
CartesianProduct<VaryMatchingStart>::iterator optional_end,
CartesianProduct<VaryMatchingStart>::iterator merge_begin,
CartesianProduct<VaryMatchingStart>::iterator merge_end)
: current_query_part_(query_part),
matchings_it_(matchings_begin),
matchings_end_(matchings_end),
@ -304,8 +279,7 @@ bool VaryQueryPartMatching::iterator::operator==(const iterator &other) const {
// iterators can be at any position.
return true;
}
return matchings_it_ == other.matchings_it_ &&
optional_it_ == other.optional_it_ && merge_it_ == other.merge_it_;
return matchings_it_ == other.matchings_it_ && optional_it_ == other.optional_it_ && merge_it_ == other.merge_it_;
}
} // namespace query::plan::impl

View File

@ -39,9 +39,7 @@ class CartesianProduct {
public:
CartesianProduct(std::vector<TSet> sets)
: original_sets_(std::move(sets)),
begin_(original_sets_.begin()),
end_(original_sets_.end()) {}
: original_sets_(std::move(sets)), begin_(original_sets_.begin()), end_(original_sets_.end()) {}
class iterator {
public:
@ -51,8 +49,7 @@ class CartesianProduct {
typedef const std::vector<TElement> &reference;
typedef const std::vector<TElement> *pointer;
explicit iterator(CartesianProduct *self, bool is_done)
: self_(self), is_done_(is_done) {
explicit iterator(CartesianProduct *self, bool is_done) : self_(self), is_done_(is_done) {
if (is_done || self->begin_ == self->end_) {
is_done_ = true;
return;
@ -92,9 +89,8 @@ class CartesianProduct {
++sets_it->second;
}
// We can now collect another product from the modified set iterators.
DMG_ASSERT(
current_product_.size() == sets_.size(),
"Expected size of current_product_ to match the size of sets_");
DMG_ASSERT(current_product_.size() == sets_.size(),
"Expected size of current_product_ to match the size of sets_");
size_t i = 0;
// Change only the prefix of the product, remaining elements (after
// sets_it) should be the same.
@ -106,9 +102,7 @@ class CartesianProduct {
}
bool operator==(const iterator &other) const {
if (self_->begin_ != other.self_->begin_ ||
self_->end_ != other.self_->end_)
return false;
if (self_->begin_ != other.self_->begin_ || self_->end_ != other.self_->end_) return false;
return (is_done_ && other.is_done_) || (sets_ == other.sets_);
}
@ -126,9 +120,7 @@ class CartesianProduct {
// Vector of (original_sets_iterator, set_iterator) pairs. The
// original_sets_iterator points to the set among all the sets, while the
// set_iterator points to an element inside the pointed to set.
std::vector<
std::pair<decltype(self_->begin_), decltype(self_->begin_->begin())>>
sets_;
std::vector<std::pair<decltype(self_->begin_), decltype(self_->begin_->begin())>> sets_;
// Currently built product from pointed to elements in all sets.
std::vector<TElement> current_product_;
// Set to true when we have generated all products.
@ -153,8 +145,7 @@ namespace impl {
class NodeSymbolHash {
public:
explicit NodeSymbolHash(const SymbolTable &symbol_table)
: symbol_table_(symbol_table) {}
explicit NodeSymbolHash(const SymbolTable &symbol_table) : symbol_table_(symbol_table) {}
size_t operator()(const NodeAtom *node_atom) const {
return std::hash<Symbol>{}(symbol_table_.at(*node_atom->identifier_));
@ -166,13 +157,10 @@ class NodeSymbolHash {
class NodeSymbolEqual {
public:
explicit NodeSymbolEqual(const SymbolTable &symbol_table)
: symbol_table_(symbol_table) {}
explicit NodeSymbolEqual(const SymbolTable &symbol_table) : symbol_table_(symbol_table) {}
bool operator()(const NodeAtom *node_atom1,
const NodeAtom *node_atom2) const {
return symbol_table_.at(*node_atom1->identifier_) ==
symbol_table_.at(*node_atom2->identifier_);
bool operator()(const NodeAtom *node_atom1, const NodeAtom *node_atom2) const {
return symbol_table_.at(*node_atom1->identifier_) == symbol_table_.at(*node_atom2->identifier_);
}
private:
@ -213,9 +201,7 @@ class VaryMatchingStart {
// being at the end. When there are no nodes, this iterator needs to produce
// a single result, which is the original matching passed in. Setting
// start_nodes_it_ to end signifies the end of our iteration.
std::optional<std::unordered_set<NodeAtom *, NodeSymbolHash,
NodeSymbolEqual>::iterator>
start_nodes_it_;
std::optional<std::unordered_set<NodeAtom *, NodeSymbolHash, NodeSymbolEqual>::iterator> start_nodes_it_;
};
auto begin() { return iterator(this, false); }
@ -231,8 +217,7 @@ class VaryMatchingStart {
// Similar to VaryMatchingStart, but varies the starting nodes for all given
// matchings. After all matchings produce multiple alternative starts, the
// Cartesian product of all of them is returned.
CartesianProduct<VaryMatchingStart> VaryMultiMatchingStarts(
const std::vector<Matching> &, const SymbolTable &);
CartesianProduct<VaryMatchingStart> VaryMultiMatchingStarts(const std::vector<Matching> &, const SymbolTable &);
// Produces alternative query parts out of a single part by varying how each
// graph matching is done.
@ -248,12 +233,9 @@ class VaryQueryPartMatching {
typedef const SingleQueryPart &reference;
typedef const SingleQueryPart *pointer;
iterator(const SingleQueryPart &, VaryMatchingStart::iterator,
VaryMatchingStart::iterator,
CartesianProduct<VaryMatchingStart>::iterator,
CartesianProduct<VaryMatchingStart>::iterator,
CartesianProduct<VaryMatchingStart>::iterator,
CartesianProduct<VaryMatchingStart>::iterator);
iterator(const SingleQueryPart &, VaryMatchingStart::iterator, VaryMatchingStart::iterator,
CartesianProduct<VaryMatchingStart>::iterator, CartesianProduct<VaryMatchingStart>::iterator,
CartesianProduct<VaryMatchingStart>::iterator, CartesianProduct<VaryMatchingStart>::iterator);
iterator &operator++();
reference operator*() const { return current_query_part_; }
@ -276,14 +258,12 @@ class VaryQueryPartMatching {
};
auto begin() {
return iterator(query_part_, matchings_.begin(), matchings_.end(),
optional_matchings_.begin(), optional_matchings_.end(),
merge_matchings_.begin(), merge_matchings_.end());
return iterator(query_part_, matchings_.begin(), matchings_.end(), optional_matchings_.begin(),
optional_matchings_.end(), merge_matchings_.begin(), merge_matchings_.end());
}
auto end() {
return iterator(query_part_, matchings_.end(), matchings_.end(),
optional_matchings_.end(), optional_matchings_.end(),
merge_matchings_.end(), merge_matchings_.end());
return iterator(query_part_, matchings_.end(), matchings_.end(), optional_matchings_.end(),
optional_matchings_.end(), merge_matchings_.end(), merge_matchings_.end());
}
private:
@ -313,21 +293,17 @@ class VariableStartPlanner {
// Generates different, equivalent query parts by taking different graph
// matching routes for each query part.
auto VaryQueryMatching(const std::vector<SingleQueryPart> &query_parts,
const SymbolTable &symbol_table) {
auto VaryQueryMatching(const std::vector<SingleQueryPart> &query_parts, const SymbolTable &symbol_table) {
std::vector<impl::VaryQueryPartMatching> alternative_query_parts;
alternative_query_parts.reserve(query_parts.size());
for (const auto &query_part : query_parts) {
alternative_query_parts.emplace_back(
impl::VaryQueryPartMatching(query_part, symbol_table));
alternative_query_parts.emplace_back(impl::VaryQueryPartMatching(query_part, symbol_table));
}
return iter::slice(MakeCartesianProduct(std::move(alternative_query_parts)),
0UL, FLAGS_query_max_plans);
return iter::slice(MakeCartesianProduct(std::move(alternative_query_parts)), 0UL, FLAGS_query_max_plans);
}
public:
explicit VariableStartPlanner(TPlanningContext *context)
: context_(context) {}
explicit VariableStartPlanner(TPlanningContext *context) : context_(context) {}
/// @brief Generate multiple plans by varying the order of graph traversal.
auto Plan(const std::vector<SingleQueryPart> &query_parts) {
@ -342,10 +318,8 @@ class VariableStartPlanner {
/// @brief The result of plan generation is an iterable of roots to multiple
/// generated operator trees.
using PlanResult = typename std::result_of<decltype (
&VariableStartPlanner<TPlanningContext>::Plan)(
VariableStartPlanner<TPlanningContext>,
std::vector<SingleQueryPart> &)>::type;
using PlanResult = typename std::result_of<decltype (&VariableStartPlanner<TPlanningContext>::Plan)(
VariableStartPlanner<TPlanningContext>, std::vector<SingleQueryPart> &)>::type;
};
} // namespace query::plan

View File

@ -19,12 +19,8 @@ class VertexCountCache {
VertexCountCache(TDbAccessor *db) : db_(db) {}
auto NameToLabel(const std::string &name) { return db_->NameToLabel(name); }
auto NameToProperty(const std::string &name) {
return db_->NameToProperty(name);
}
auto NameToEdgeType(const std::string &name) {
return db_->NameToEdgeType(name);
}
auto NameToProperty(const std::string &name) { return db_->NameToProperty(name); }
auto NameToEdgeType(const std::string &name) { return db_->NameToEdgeType(name); }
int64_t VerticesCount() {
if (!vertices_count_) vertices_count_ = db_->VerticesCount();
@ -39,14 +35,12 @@ class VertexCountCache {
int64_t VerticesCount(storage::LabelId label, storage::PropertyId property) {
auto key = std::make_pair(label, property);
if (label_property_vertex_count_.find(key) ==
label_property_vertex_count_.end())
if (label_property_vertex_count_.find(key) == label_property_vertex_count_.end())
label_property_vertex_count_[key] = db_->VerticesCount(label, property);
return label_property_vertex_count_.at(key);
}
int64_t VerticesCount(storage::LabelId label, storage::PropertyId property,
const storage::PropertyValue &value) {
int64_t VerticesCount(storage::LabelId label, storage::PropertyId property, const storage::PropertyValue &value) {
auto label_prop = std::make_pair(label, property);
auto &value_vertex_count = property_value_vertex_count_[label_prop];
// TODO: Why do we even need TypedValue in this whole file?
@ -56,25 +50,20 @@ class VertexCountCache {
return value_vertex_count.at(tv_value);
}
int64_t VerticesCount(
storage::LabelId label, storage::PropertyId property,
const std::optional<utils::Bound<storage::PropertyValue>> &lower,
const std::optional<utils::Bound<storage::PropertyValue>> &upper) {
int64_t VerticesCount(storage::LabelId label, storage::PropertyId property,
const std::optional<utils::Bound<storage::PropertyValue>> &lower,
const std::optional<utils::Bound<storage::PropertyValue>> &upper) {
auto label_prop = std::make_pair(label, property);
auto &bounds_vertex_count = property_bounds_vertex_count_[label_prop];
BoundsKey bounds = std::make_pair(lower, upper);
if (bounds_vertex_count.find(bounds) == bounds_vertex_count.end())
bounds_vertex_count[bounds] =
db_->VerticesCount(label, property, lower, upper);
bounds_vertex_count[bounds] = db_->VerticesCount(label, property, lower, upper);
return bounds_vertex_count.at(bounds);
}
bool LabelIndexExists(storage::LabelId label) {
return db_->LabelIndexExists(label);
}
bool LabelIndexExists(storage::LabelId label) { return db_->LabelIndexExists(label); }
bool LabelPropertyIndexExists(storage::LabelId label,
storage::PropertyId property) {
bool LabelPropertyIndexExists(storage::LabelId label, storage::PropertyId property) {
return db_->LabelPropertyIndexExists(label, property);
}
@ -83,8 +72,7 @@ class VertexCountCache {
struct LabelPropertyHash {
size_t operator()(const LabelPropertyKey &key) const {
return utils::HashCombine<storage::LabelId, storage::PropertyId>{}(
key.first, key.second);
return utils::HashCombine<storage::LabelId, storage::PropertyId>{}(key.first, key.second);
}
};
@ -107,11 +95,8 @@ class VertexCountCache {
struct BoundsEqual {
bool operator()(const BoundsKey &a, const BoundsKey &b) const {
auto bound_equal = [](const auto &maybe_bound_a,
const auto &maybe_bound_b) {
if (maybe_bound_a && maybe_bound_b &&
maybe_bound_a->type() != maybe_bound_b->type())
return false;
auto bound_equal = [](const auto &maybe_bound_a, const auto &maybe_bound_b) {
if (maybe_bound_a && maybe_bound_b && maybe_bound_a->type() != maybe_bound_b->type()) return false;
query::TypedValue bound_a;
query::TypedValue bound_b;
if (maybe_bound_a) bound_a = TypedValue(maybe_bound_a->value());
@ -125,18 +110,14 @@ class VertexCountCache {
TDbAccessor *db_;
std::optional<int64_t> vertices_count_;
std::unordered_map<storage::LabelId, int64_t> label_vertex_count_;
std::unordered_map<LabelPropertyKey, int64_t, LabelPropertyHash>
label_property_vertex_count_;
std::unordered_map<LabelPropertyKey, int64_t, LabelPropertyHash> label_property_vertex_count_;
std::unordered_map<
LabelPropertyKey,
std::unordered_map<query::TypedValue, int64_t, query::TypedValue::Hash,
query::TypedValue::BoolEqual>,
std::unordered_map<query::TypedValue, int64_t, query::TypedValue::Hash, query::TypedValue::BoolEqual>,
LabelPropertyHash>
property_value_vertex_count_;
std::unordered_map<
LabelPropertyKey,
std::unordered_map<BoundsKey, int64_t, BoundsHash, BoundsEqual>,
LabelPropertyHash>
std::unordered_map<LabelPropertyKey, std::unordered_map<BoundsKey, int64_t, BoundsHash, BoundsEqual>,
LabelPropertyHash>
property_bounds_vertex_count_;
};

View File

@ -42,8 +42,7 @@ class CypherType {
virtual const NullableType *AsNullableType() const { return nullptr; }
};
using CypherTypePtr =
std::unique_ptr<CypherType, std::function<void(CypherType *)>>;
using CypherTypePtr = std::unique_ptr<CypherType, std::function<void(CypherType *)>>;
// Simple Types
@ -51,65 +50,45 @@ class AnyType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "ANY"; }
bool SatisfiesType(const mgp_value &value) const override {
return !mgp_value_is_null(&value);
}
bool SatisfiesType(const mgp_value &value) const override { return !mgp_value_is_null(&value); }
bool SatisfiesType(const query::TypedValue &value) const override {
return !value.IsNull();
}
bool SatisfiesType(const query::TypedValue &value) const override { return !value.IsNull(); }
};
class BoolType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "BOOLEAN"; }
bool SatisfiesType(const mgp_value &value) const override {
return mgp_value_is_bool(&value);
}
bool SatisfiesType(const mgp_value &value) const override { return mgp_value_is_bool(&value); }
bool SatisfiesType(const query::TypedValue &value) const override {
return value.IsBool();
}
bool SatisfiesType(const query::TypedValue &value) const override { return value.IsBool(); }
};
class StringType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "STRING"; }
bool SatisfiesType(const mgp_value &value) const override {
return mgp_value_is_string(&value);
}
bool SatisfiesType(const mgp_value &value) const override { return mgp_value_is_string(&value); }
bool SatisfiesType(const query::TypedValue &value) const override {
return value.IsString();
}
bool SatisfiesType(const query::TypedValue &value) const override { return value.IsString(); }
};
class IntType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "INTEGER"; }
bool SatisfiesType(const mgp_value &value) const override {
return mgp_value_is_int(&value);
}
bool SatisfiesType(const mgp_value &value) const override { return mgp_value_is_int(&value); }
bool SatisfiesType(const query::TypedValue &value) const override {
return value.IsInt();
}
bool SatisfiesType(const query::TypedValue &value) const override { return value.IsInt(); }
};
class FloatType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "FLOAT"; }
bool SatisfiesType(const mgp_value &value) const override {
return mgp_value_is_double(&value);
}
bool SatisfiesType(const mgp_value &value) const override { return mgp_value_is_double(&value); }
bool SatisfiesType(const query::TypedValue &value) const override {
return value.IsDouble();
}
bool SatisfiesType(const query::TypedValue &value) const override { return value.IsDouble(); }
};
class NumberType : public CypherType {
@ -120,50 +99,34 @@ class NumberType : public CypherType {
return mgp_value_is_int(&value) || mgp_value_is_double(&value);
}
bool SatisfiesType(const query::TypedValue &value) const override {
return value.IsInt() || value.IsDouble();
}
bool SatisfiesType(const query::TypedValue &value) const override { return value.IsInt() || value.IsDouble(); }
};
class NodeType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "NODE"; }
bool SatisfiesType(const mgp_value &value) const override {
return mgp_value_is_vertex(&value);
}
bool SatisfiesType(const mgp_value &value) const override { return mgp_value_is_vertex(&value); }
bool SatisfiesType(const query::TypedValue &value) const override {
return value.IsVertex();
}
bool SatisfiesType(const query::TypedValue &value) const override { return value.IsVertex(); }
};
class RelationshipType : public CypherType {
public:
std::string_view GetPresentableName() const override {
return "RELATIONSHIP";
}
std::string_view GetPresentableName() const override { return "RELATIONSHIP"; }
bool SatisfiesType(const mgp_value &value) const override {
return mgp_value_is_edge(&value);
}
bool SatisfiesType(const mgp_value &value) const override { return mgp_value_is_edge(&value); }
bool SatisfiesType(const query::TypedValue &value) const override {
return value.IsEdge();
}
bool SatisfiesType(const query::TypedValue &value) const override { return value.IsEdge(); }
};
class PathType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "PATH"; }
bool SatisfiesType(const mgp_value &value) const override {
return mgp_value_is_path(&value);
}
bool SatisfiesType(const mgp_value &value) const override { return mgp_value_is_path(&value); }
bool SatisfiesType(const query::TypedValue &value) const override {
return value.IsPath();
}
bool SatisfiesType(const query::TypedValue &value) const override { return value.IsPath(); }
};
// TODO: There's also Temporal Types, but we currently do not support those.
@ -178,8 +141,7 @@ class MapType : public CypherType {
std::string_view GetPresentableName() const override { return "MAP"; }
bool SatisfiesType(const mgp_value &value) const override {
return mgp_value_is_map(&value) || mgp_value_is_vertex(&value) ||
mgp_value_is_edge(&value);
return mgp_value_is_map(&value) || mgp_value_is_vertex(&value) || mgp_value_is_edge(&value);
}
bool SatisfiesType(const query::TypedValue &value) const override {
@ -197,14 +159,11 @@ class ListType : public CypherType {
/// @throw std::bad_alloc
/// @throw std::length_error
explicit ListType(CypherTypePtr element_type, utils::MemoryResource *memory)
: element_type_(std::move(element_type)),
presentable_name_("LIST OF ", memory) {
: element_type_(std::move(element_type)), presentable_name_("LIST OF ", memory) {
presentable_name_.append(element_type_->GetPresentableName());
}
std::string_view GetPresentableName() const override {
return presentable_name_;
}
std::string_view GetPresentableName() const override { return presentable_name_; }
bool SatisfiesType(const mgp_value &value) const override {
if (!mgp_value_is_list(&value)) return false;
@ -239,8 +198,7 @@ class NullableType : public CypherType {
const auto *list_type = type_->AsListType();
// ListType is specially formatted
if (list_type) {
presentable_name_.assign("LIST? OF ")
.append(list_type->element_type_->GetPresentableName());
presentable_name_.assign("LIST? OF ").append(list_type->element_type_->GetPresentableName());
} else {
presentable_name_.assign(type_->GetPresentableName()).append("?");
}
@ -252,8 +210,7 @@ class NullableType : public CypherType {
/// Otherwise, `type` is wrapped in a new instance of NullableType.
/// @throw std::bad_alloc
/// @throw std::length_error
static CypherTypePtr Create(CypherTypePtr type,
utils::MemoryResource *memory) {
static CypherTypePtr Create(CypherTypePtr type, utils::MemoryResource *memory) {
if (type->AsNullableType()) return type;
utils::Allocator<NullableType> alloc(memory);
auto *nullable = alloc.allocate(1);
@ -268,9 +225,7 @@ class NullableType : public CypherType {
});
}
std::string_view GetPresentableName() const override {
return presentable_name_;
}
std::string_view GetPresentableName() const override { return presentable_name_; }
bool SatisfiesType(const mgp_value &value) const override {
return mgp_value_is_null(&value) || type_->SatisfiesType(value);

View File

@ -20,8 +20,7 @@ void *mgp_alloc(mgp_memory *memory, size_t size_in_bytes) {
return mgp_aligned_alloc(memory, size_in_bytes, alignof(std::max_align_t));
}
void *mgp_aligned_alloc(mgp_memory *memory, const size_t size_in_bytes,
const size_t alignment) {
void *mgp_aligned_alloc(mgp_memory *memory, const size_t size_in_bytes, const size_t alignment) {
if (size_in_bytes == 0U || !utils::IsPow2(alignment)) return nullptr;
// Simplify alignment by always using values greater or equal to max_align.
const size_t alloc_align = std::max(alignment, alignof(std::max_align_t));
@ -32,8 +31,7 @@ void *mgp_aligned_alloc(mgp_memory *memory, const size_t size_in_bytes,
// just allocate an additional multiple of `alloc_align` of bytes such that
// the header fits. `data` will then be aligned after this multiple of bytes.
static_assert(std::is_same_v<size_t, uint64_t>);
const auto maybe_bytes_for_header =
utils::RoundUint64ToMultiple(header_size, alloc_align);
const auto maybe_bytes_for_header = utils::RoundUint64ToMultiple(header_size, alloc_align);
if (!maybe_bytes_for_header) return nullptr;
const size_t bytes_for_header = *maybe_bytes_for_header;
const size_t alloc_size = bytes_for_header + size_in_bytes;
@ -41,10 +39,8 @@ void *mgp_aligned_alloc(mgp_memory *memory, const size_t size_in_bytes,
try {
void *ptr = memory->impl->Allocate(alloc_size, alloc_align);
char *data = reinterpret_cast<char *>(ptr) + bytes_for_header;
std::memcpy(data - sizeof(size_in_bytes), &size_in_bytes,
sizeof(size_in_bytes));
std::memcpy(data - sizeof(size_in_bytes) - sizeof(alloc_align),
&alloc_align, sizeof(alloc_align));
std::memcpy(data - sizeof(size_in_bytes), &size_in_bytes, sizeof(size_in_bytes));
std::memcpy(data - sizeof(size_in_bytes) - sizeof(alloc_align), &alloc_align, sizeof(alloc_align));
return data;
} catch (...) {
return nullptr;
@ -56,17 +52,14 @@ void mgp_free(mgp_memory *memory, void *const p) {
char *const data = reinterpret_cast<char *>(p);
// Read the header containing size & alignment info.
size_t size_in_bytes;
std::memcpy(&size_in_bytes, data - sizeof(size_in_bytes),
sizeof(size_in_bytes));
std::memcpy(&size_in_bytes, data - sizeof(size_in_bytes), sizeof(size_in_bytes));
size_t alloc_align;
std::memcpy(&alloc_align, data - sizeof(size_in_bytes) - sizeof(alloc_align),
sizeof(alloc_align));
std::memcpy(&alloc_align, data - sizeof(size_in_bytes) - sizeof(alloc_align), sizeof(alloc_align));
// Reconstruct how many bytes we allocated on top of the original request.
// We need not check allocation request overflow, since we did so already in
// mgp_aligned_alloc.
const size_t header_size = sizeof(size_in_bytes) + sizeof(alloc_align);
const size_t bytes_for_header =
*utils::RoundUint64ToMultiple(header_size, alloc_align);
const size_t bytes_for_header = *utils::RoundUint64ToMultiple(header_size, alloc_align);
const size_t alloc_size = bytes_for_header + size_in_bytes;
// Get the original ptr we allocated.
void *const original_ptr = data - bytes_for_header;
@ -89,8 +82,7 @@ U *new_mgp_object(utils::MemoryResource *memory, TArgs &&...args) {
template <class U, class... TArgs>
U *new_mgp_object(mgp_memory *memory, TArgs &&...args) {
return new_mgp_object<U, TArgs...>(memory->impl,
std::forward<TArgs>(args)...);
return new_mgp_object<U, TArgs...>(memory->impl, std::forward<TArgs>(args)...);
}
// Assume that deallocation and object destruction never throws. If it does,
@ -127,14 +119,12 @@ mgp_value_type FromTypedValueType(query::TypedValue::Type type) {
}
}
query::TypedValue ToTypedValue(const mgp_value &val,
utils::MemoryResource *memory) {
query::TypedValue ToTypedValue(const mgp_value &val, utils::MemoryResource *memory) {
switch (mgp_value_get_type(&val)) {
case MGP_VALUE_TYPE_NULL:
return query::TypedValue(memory);
case MGP_VALUE_TYPE_BOOL:
return query::TypedValue(static_cast<bool>(mgp_value_get_bool(&val)),
memory);
return query::TypedValue(static_cast<bool>(mgp_value_get_bool(&val)), memory);
case MGP_VALUE_TYPE_INT:
return query::TypedValue(mgp_value_get_int(&val), memory);
case MGP_VALUE_TYPE_DOUBLE:
@ -178,11 +168,9 @@ query::TypedValue ToTypedValue(const mgp_value &val,
} // namespace
mgp_value::mgp_value(utils::MemoryResource *m) noexcept
: type(MGP_VALUE_TYPE_NULL), memory(m) {}
mgp_value::mgp_value(utils::MemoryResource *m) noexcept : type(MGP_VALUE_TYPE_NULL), memory(m) {}
mgp_value::mgp_value(bool val, utils::MemoryResource *m) noexcept
: type(MGP_VALUE_TYPE_BOOL), memory(m), bool_v(val) {}
mgp_value::mgp_value(bool val, utils::MemoryResource *m) noexcept : type(MGP_VALUE_TYPE_BOOL), memory(m), bool_v(val) {}
mgp_value::mgp_value(int64_t val, utils::MemoryResource *m) noexcept
: type(MGP_VALUE_TYPE_INT), memory(m), int_v(val) {}
@ -195,36 +183,30 @@ mgp_value::mgp_value(const char *val, utils::MemoryResource *m)
mgp_value::mgp_value(mgp_list *val, utils::MemoryResource *m) noexcept
: type(MGP_VALUE_TYPE_LIST), memory(m), list_v(val) {
MG_ASSERT(val->GetMemoryResource() == m,
"Unable to take ownership of a pointer with different allocator.");
MG_ASSERT(val->GetMemoryResource() == m, "Unable to take ownership of a pointer with different allocator.");
}
mgp_value::mgp_value(mgp_map *val, utils::MemoryResource *m) noexcept
: type(MGP_VALUE_TYPE_MAP), memory(m), map_v(val) {
MG_ASSERT(val->GetMemoryResource() == m,
"Unable to take ownership of a pointer with different allocator.");
MG_ASSERT(val->GetMemoryResource() == m, "Unable to take ownership of a pointer with different allocator.");
}
mgp_value::mgp_value(mgp_vertex *val, utils::MemoryResource *m) noexcept
: type(MGP_VALUE_TYPE_VERTEX), memory(m), vertex_v(val) {
MG_ASSERT(val->GetMemoryResource() == m,
"Unable to take ownership of a pointer with different allocator.");
MG_ASSERT(val->GetMemoryResource() == m, "Unable to take ownership of a pointer with different allocator.");
}
mgp_value::mgp_value(mgp_edge *val, utils::MemoryResource *m) noexcept
: type(MGP_VALUE_TYPE_EDGE), memory(m), edge_v(val) {
MG_ASSERT(val->GetMemoryResource() == m,
"Unable to take ownership of a pointer with different allocator.");
MG_ASSERT(val->GetMemoryResource() == m, "Unable to take ownership of a pointer with different allocator.");
}
mgp_value::mgp_value(mgp_path *val, utils::MemoryResource *m) noexcept
: type(MGP_VALUE_TYPE_PATH), memory(m), path_v(val) {
MG_ASSERT(val->GetMemoryResource() == m,
"Unable to take ownership of a pointer with different allocator.");
MG_ASSERT(val->GetMemoryResource() == m, "Unable to take ownership of a pointer with different allocator.");
}
mgp_value::mgp_value(const query::TypedValue &tv, const mgp_graph *graph,
utils::MemoryResource *m)
mgp_value::mgp_value(const query::TypedValue &tv, const mgp_graph *graph, utils::MemoryResource *m)
: type(FromTypedValueType(tv.type())), memory(m) {
switch (type) {
case MGP_VALUE_TYPE_NULL:
@ -299,8 +281,7 @@ mgp_value::mgp_value(const query::TypedValue &tv, const mgp_graph *graph,
}
}
mgp_value::mgp_value(const storage::PropertyValue &pv, utils::MemoryResource *m)
: memory(m) {
mgp_value::mgp_value(const storage::PropertyValue &pv, utils::MemoryResource *m) : memory(m) {
switch (pv.type()) {
case storage::PropertyValue::Type::Null:
type = MGP_VALUE_TYPE_NULL;
@ -353,8 +334,7 @@ mgp_value::mgp_value(const storage::PropertyValue &pv, utils::MemoryResource *m)
}
}
mgp_value::mgp_value(const mgp_value &other, utils::MemoryResource *m)
: type(other.type), memory(m) {
mgp_value::mgp_value(const mgp_value &other, utils::MemoryResource *m) : type(other.type), memory(m) {
switch (other.type) {
case MGP_VALUE_TYPE_NULL:
break;
@ -433,8 +413,7 @@ void DeleteValueMember(mgp_value *value) noexcept {
} // namespace
mgp_value::mgp_value(mgp_value &&other, utils::MemoryResource *m)
: type(other.type), memory(m) {
mgp_value::mgp_value(mgp_value &&other, utils::MemoryResource *m) : type(other.type), memory(m) {
switch (other.type) {
case MGP_VALUE_TYPE_NULL:
break;
@ -451,8 +430,7 @@ mgp_value::mgp_value(mgp_value &&other, utils::MemoryResource *m)
new (&string_v) utils::pmr::string(std::move(other.string_v), m);
break;
case MGP_VALUE_TYPE_LIST:
static_assert(std::is_pointer_v<decltype(list_v)>,
"Expected to move list_v by copying pointers.");
static_assert(std::is_pointer_v<decltype(list_v)>, "Expected to move list_v by copying pointers.");
if (*other.GetMemoryResource() == *m) {
list_v = other.list_v;
other.type = MGP_VALUE_TYPE_NULL;
@ -462,8 +440,7 @@ mgp_value::mgp_value(mgp_value &&other, utils::MemoryResource *m)
}
break;
case MGP_VALUE_TYPE_MAP:
static_assert(std::is_pointer_v<decltype(map_v)>,
"Expected to move map_v by copying pointers.");
static_assert(std::is_pointer_v<decltype(map_v)>, "Expected to move map_v by copying pointers.");
if (*other.GetMemoryResource() == *m) {
map_v = other.map_v;
other.type = MGP_VALUE_TYPE_NULL;
@ -473,8 +450,7 @@ mgp_value::mgp_value(mgp_value &&other, utils::MemoryResource *m)
}
break;
case MGP_VALUE_TYPE_VERTEX:
static_assert(std::is_pointer_v<decltype(vertex_v)>,
"Expected to move vertex_v by copying pointers.");
static_assert(std::is_pointer_v<decltype(vertex_v)>, "Expected to move vertex_v by copying pointers.");
if (*other.GetMemoryResource() == *m) {
vertex_v = other.vertex_v;
other.type = MGP_VALUE_TYPE_NULL;
@ -484,8 +460,7 @@ mgp_value::mgp_value(mgp_value &&other, utils::MemoryResource *m)
}
break;
case MGP_VALUE_TYPE_EDGE:
static_assert(std::is_pointer_v<decltype(edge_v)>,
"Expected to move edge_v by copying pointers.");
static_assert(std::is_pointer_v<decltype(edge_v)>, "Expected to move edge_v by copying pointers.");
if (*other.GetMemoryResource() == *m) {
edge_v = other.edge_v;
other.type = MGP_VALUE_TYPE_NULL;
@ -495,8 +470,7 @@ mgp_value::mgp_value(mgp_value &&other, utils::MemoryResource *m)
}
break;
case MGP_VALUE_TYPE_PATH:
static_assert(std::is_pointer_v<decltype(path_v)>,
"Expected to move path_v by copying pointers.");
static_assert(std::is_pointer_v<decltype(path_v)>, "Expected to move path_v by copying pointers.");
if (*other.GetMemoryResource() == *m) {
path_v = other.path_v;
other.type = MGP_VALUE_TYPE_NULL;
@ -514,21 +488,13 @@ mgp_value::~mgp_value() noexcept { DeleteValueMember(this); }
void mgp_value_destroy(mgp_value *val) { delete_mgp_object(val); }
mgp_value *mgp_value_make_null(mgp_memory *memory) {
return new_mgp_object<mgp_value>(memory);
}
mgp_value *mgp_value_make_null(mgp_memory *memory) { return new_mgp_object<mgp_value>(memory); }
mgp_value *mgp_value_make_bool(int val, mgp_memory *memory) {
return new_mgp_object<mgp_value>(memory, val != 0);
}
mgp_value *mgp_value_make_bool(int val, mgp_memory *memory) { return new_mgp_object<mgp_value>(memory, val != 0); }
mgp_value *mgp_value_make_int(int64_t val, mgp_memory *memory) {
return new_mgp_object<mgp_value>(memory, val);
}
mgp_value *mgp_value_make_int(int64_t val, mgp_memory *memory) { return new_mgp_object<mgp_value>(memory, val); }
mgp_value *mgp_value_make_double(double val, mgp_memory *memory) {
return new_mgp_object<mgp_value>(memory, val);
}
mgp_value *mgp_value_make_double(double val, mgp_memory *memory) { return new_mgp_object<mgp_value>(memory, val); }
mgp_value *mgp_value_make_string(const char *val, mgp_memory *memory) {
try {
@ -540,67 +506,37 @@ mgp_value *mgp_value_make_string(const char *val, mgp_memory *memory) {
}
}
mgp_value *mgp_value_make_list(mgp_list *val) {
return new_mgp_object<mgp_value>(val->GetMemoryResource(), val);
}
mgp_value *mgp_value_make_list(mgp_list *val) { return new_mgp_object<mgp_value>(val->GetMemoryResource(), val); }
mgp_value *mgp_value_make_map(mgp_map *val) {
return new_mgp_object<mgp_value>(val->GetMemoryResource(), val);
}
mgp_value *mgp_value_make_map(mgp_map *val) { return new_mgp_object<mgp_value>(val->GetMemoryResource(), val); }
mgp_value *mgp_value_make_vertex(mgp_vertex *val) {
return new_mgp_object<mgp_value>(val->GetMemoryResource(), val);
}
mgp_value *mgp_value_make_vertex(mgp_vertex *val) { return new_mgp_object<mgp_value>(val->GetMemoryResource(), val); }
mgp_value *mgp_value_make_edge(mgp_edge *val) {
return new_mgp_object<mgp_value>(val->GetMemoryResource(), val);
}
mgp_value *mgp_value_make_edge(mgp_edge *val) { return new_mgp_object<mgp_value>(val->GetMemoryResource(), val); }
mgp_value *mgp_value_make_path(mgp_path *val) {
return new_mgp_object<mgp_value>(val->GetMemoryResource(), val);
}
mgp_value *mgp_value_make_path(mgp_path *val) { return new_mgp_object<mgp_value>(val->GetMemoryResource(), val); }
mgp_value_type mgp_value_get_type(const mgp_value *val) { return val->type; }
int mgp_value_is_null(const mgp_value *val) {
return mgp_value_get_type(val) == MGP_VALUE_TYPE_NULL;
}
int mgp_value_is_null(const mgp_value *val) { return mgp_value_get_type(val) == MGP_VALUE_TYPE_NULL; }
int mgp_value_is_bool(const mgp_value *val) {
return mgp_value_get_type(val) == MGP_VALUE_TYPE_BOOL;
}
int mgp_value_is_bool(const mgp_value *val) { return mgp_value_get_type(val) == MGP_VALUE_TYPE_BOOL; }
int mgp_value_is_int(const mgp_value *val) {
return mgp_value_get_type(val) == MGP_VALUE_TYPE_INT;
}
int mgp_value_is_int(const mgp_value *val) { return mgp_value_get_type(val) == MGP_VALUE_TYPE_INT; }
int mgp_value_is_double(const mgp_value *val) {
return mgp_value_get_type(val) == MGP_VALUE_TYPE_DOUBLE;
}
int mgp_value_is_double(const mgp_value *val) { return mgp_value_get_type(val) == MGP_VALUE_TYPE_DOUBLE; }
int mgp_value_is_string(const mgp_value *val) {
return mgp_value_get_type(val) == MGP_VALUE_TYPE_STRING;
}
int mgp_value_is_string(const mgp_value *val) { return mgp_value_get_type(val) == MGP_VALUE_TYPE_STRING; }
int mgp_value_is_list(const mgp_value *val) {
return mgp_value_get_type(val) == MGP_VALUE_TYPE_LIST;
}
int mgp_value_is_list(const mgp_value *val) { return mgp_value_get_type(val) == MGP_VALUE_TYPE_LIST; }
int mgp_value_is_map(const mgp_value *val) {
return mgp_value_get_type(val) == MGP_VALUE_TYPE_MAP;
}
int mgp_value_is_map(const mgp_value *val) { return mgp_value_get_type(val) == MGP_VALUE_TYPE_MAP; }
int mgp_value_is_vertex(const mgp_value *val) {
return mgp_value_get_type(val) == MGP_VALUE_TYPE_VERTEX;
}
int mgp_value_is_vertex(const mgp_value *val) { return mgp_value_get_type(val) == MGP_VALUE_TYPE_VERTEX; }
int mgp_value_is_edge(const mgp_value *val) {
return mgp_value_get_type(val) == MGP_VALUE_TYPE_EDGE;
}
int mgp_value_is_edge(const mgp_value *val) { return mgp_value_get_type(val) == MGP_VALUE_TYPE_EDGE; }
int mgp_value_is_path(const mgp_value *val) {
return mgp_value_get_type(val) == MGP_VALUE_TYPE_PATH;
}
int mgp_value_is_path(const mgp_value *val) { return mgp_value_get_type(val) == MGP_VALUE_TYPE_PATH; }
int mgp_value_get_bool(const mgp_value *val) { return val->bool_v ? 1 : 0; }
@ -608,17 +544,13 @@ int64_t mgp_value_get_int(const mgp_value *val) { return val->int_v; }
double mgp_value_get_double(const mgp_value *val) { return val->double_v; }
const char *mgp_value_get_string(const mgp_value *val) {
return val->string_v.c_str();
}
const char *mgp_value_get_string(const mgp_value *val) { return val->string_v.c_str(); }
const mgp_list *mgp_value_get_list(const mgp_value *val) { return val->list_v; }
const mgp_map *mgp_value_get_map(const mgp_value *val) { return val->map_v; }
const mgp_vertex *mgp_value_get_vertex(const mgp_value *val) {
return val->vertex_v;
}
const mgp_vertex *mgp_value_get_vertex(const mgp_value *val) { return val->vertex_v; }
const mgp_edge *mgp_value_get_edge(const mgp_value *val) { return val->edge_v; }
@ -656,18 +588,14 @@ int mgp_list_append_extend(mgp_list *list, const mgp_value *val) {
size_t mgp_list_size(const mgp_list *list) { return list->elems.size(); }
size_t mgp_list_capacity(const mgp_list *list) {
return list->elems.capacity();
}
size_t mgp_list_capacity(const mgp_list *list) { return list->elems.capacity(); }
const mgp_value *mgp_list_at(const mgp_list *list, size_t i) {
if (i >= mgp_list_size(list)) return nullptr;
return &list->elems[i];
}
mgp_map *mgp_map_make_empty(mgp_memory *memory) {
return new_mgp_object<mgp_map>(memory);
}
mgp_map *mgp_map_make_empty(mgp_memory *memory) { return new_mgp_object<mgp_map>(memory); }
void mgp_map_destroy(mgp_map *map) { delete_mgp_object(map); }
@ -693,21 +621,15 @@ const mgp_value *mgp_map_at(const mgp_map *map, const char *key) {
const char *mgp_map_item_key(const mgp_map_item *item) { return item->key; }
const mgp_value *mgp_map_item_value(const mgp_map_item *item) {
return item->value;
}
const mgp_value *mgp_map_item_value(const mgp_map_item *item) { return item->value; }
mgp_map_items_iterator *mgp_map_iter_items(const mgp_map *map,
mgp_memory *memory) {
mgp_map_items_iterator *mgp_map_iter_items(const mgp_map *map, mgp_memory *memory) {
return new_mgp_object<mgp_map_items_iterator>(memory, map);
}
void mgp_map_items_iterator_destroy(mgp_map_items_iterator *it) {
delete_mgp_object(it);
}
void mgp_map_items_iterator_destroy(mgp_map_items_iterator *it) { delete_mgp_object(it); }
const mgp_map_item *mgp_map_items_iterator_get(
const mgp_map_items_iterator *it) {
const mgp_map_item *mgp_map_items_iterator_get(const mgp_map_items_iterator *it) {
if (it->current_it == it->map->items.end()) return nullptr;
return &it->current;
}
@ -720,8 +642,7 @@ const mgp_map_item *mgp_map_items_iterator_next(mgp_map_items_iterator *it) {
return &it->current;
}
mgp_path *mgp_path_make_with_start(const mgp_vertex *vertex,
mgp_memory *memory) {
mgp_path *mgp_path_make_with_start(const mgp_vertex *vertex, mgp_memory *memory) {
auto *path = new_mgp_object<mgp_path>(memory);
if (!path) return nullptr;
try {
@ -734,16 +655,14 @@ mgp_path *mgp_path_make_with_start(const mgp_vertex *vertex,
}
mgp_path *mgp_path_copy(const mgp_path *path, mgp_memory *memory) {
MG_ASSERT(mgp_path_size(path) == path->vertices.size() - 1,
"Invalid mgp_path");
MG_ASSERT(mgp_path_size(path) == path->vertices.size() - 1, "Invalid mgp_path");
return new_mgp_object<mgp_path>(memory, *path);
}
void mgp_path_destroy(mgp_path *path) { delete_mgp_object(path); }
int mgp_path_expand(mgp_path *path, const mgp_edge *edge) {
MG_ASSERT(mgp_path_size(path) == path->vertices.size() - 1,
"Invalid mgp_path");
MG_ASSERT(mgp_path_size(path) == path->vertices.size() - 1, "Invalid mgp_path");
// Check that the both the last vertex on path and dst_vertex are endpoints of
// the given edge.
const auto *src_vertex = &path->vertices.back();
@ -820,17 +739,15 @@ mgp_result_record *mgp_result_new_record(mgp_result *res) {
auto *memory = res->rows.get_allocator().GetMemoryResource();
MG_ASSERT(res->signature, "Expected to have a valid signature");
try {
res->rows.push_back(mgp_result_record{
res->signature,
utils::pmr::map<utils::pmr::string, query::TypedValue>(memory)});
res->rows.push_back(
mgp_result_record{res->signature, utils::pmr::map<utils::pmr::string, query::TypedValue>(memory)});
} catch (...) {
return nullptr;
}
return &res->rows.back();
}
int mgp_result_record_insert(mgp_result_record *record, const char *field_name,
const mgp_value *val) {
int mgp_result_record_insert(mgp_result_record *record, const char *field_name, const mgp_value *val) {
auto *memory = record->values.get_allocator().GetMemoryResource();
// Validate field_name & val satisfy the procedure's result signature.
MG_ASSERT(record->signature, "Expected to have a valid signature");
@ -848,12 +765,9 @@ int mgp_result_record_insert(mgp_result_record *record, const char *field_name,
/// Graph Constructs
void mgp_properties_iterator_destroy(mgp_properties_iterator *it) {
delete_mgp_object(it);
}
void mgp_properties_iterator_destroy(mgp_properties_iterator *it) { delete_mgp_object(it); }
const mgp_property *mgp_properties_iterator_get(
const mgp_properties_iterator *it) {
const mgp_property *mgp_properties_iterator_get(const mgp_properties_iterator *it) {
if (it->current) return &it->property;
return nullptr;
}
@ -878,9 +792,7 @@ const mgp_property *mgp_properties_iterator_next(mgp_properties_iterator *it) {
return nullptr;
}
it->current.emplace(
utils::pmr::string(
it->graph->impl->PropertyToName(it->current_it->first),
it->GetMemoryResource()),
utils::pmr::string(it->graph->impl->PropertyToName(it->current_it->first), it->GetMemoryResource()),
mgp_value(it->current_it->second, it->GetMemoryResource()));
it->property.name = it->current->first.c_str();
it->property.value = &it->current->second;
@ -891,19 +803,13 @@ const mgp_property *mgp_properties_iterator_next(mgp_properties_iterator *it) {
}
}
mgp_vertex_id mgp_vertex_get_id(const mgp_vertex *v) {
return mgp_vertex_id{.as_int = v->impl.Gid().AsInt()};
}
mgp_vertex_id mgp_vertex_get_id(const mgp_vertex *v) { return mgp_vertex_id{.as_int = v->impl.Gid().AsInt()}; }
mgp_vertex *mgp_vertex_copy(const mgp_vertex *v, mgp_memory *memory) {
return new_mgp_object<mgp_vertex>(memory, *v);
}
mgp_vertex *mgp_vertex_copy(const mgp_vertex *v, mgp_memory *memory) { return new_mgp_object<mgp_vertex>(memory, *v); }
void mgp_vertex_destroy(mgp_vertex *v) { delete_mgp_object(v); }
int mgp_vertex_equal(const mgp_vertex *a, const mgp_vertex *b) {
return a->impl == b->impl ? 1 : 0;
}
int mgp_vertex_equal(const mgp_vertex *a, const mgp_vertex *b) { return a->impl == b->impl ? 1 : 0; }
size_t mgp_vertex_labels_count(const mgp_vertex *v) {
auto maybe_labels = v->impl.Labels(v->graph->view);
@ -940,10 +846,9 @@ mgp_label mgp_vertex_label_at(const mgp_vertex *v, size_t i) {
}
if (i >= maybe_labels->size()) return mgp_label{nullptr};
const auto &label = (*maybe_labels)[i];
static_assert(
std::is_lvalue_reference_v<decltype(v->graph->impl->LabelToName(label))>,
"Expected LabelToName to return a pointer or reference, so we "
"don't have to take a copy and manage memory.");
static_assert(std::is_lvalue_reference_v<decltype(v->graph->impl->LabelToName(label))>,
"Expected LabelToName to return a pointer or reference, so we "
"don't have to take a copy and manage memory.");
const auto &name = v->graph->impl->LabelToName(label);
return mgp_label{name.c_str()};
}
@ -979,12 +884,9 @@ int mgp_vertex_has_label_named(const mgp_vertex *v, const char *name) {
return *maybe_has_label;
}
int mgp_vertex_has_label(const mgp_vertex *v, mgp_label label) {
return mgp_vertex_has_label_named(v, label.name);
}
int mgp_vertex_has_label(const mgp_vertex *v, mgp_label label) { return mgp_vertex_has_label_named(v, label.name); }
mgp_value *mgp_vertex_get_property(const mgp_vertex *v, const char *name,
mgp_memory *memory) {
mgp_value *mgp_vertex_get_property(const mgp_vertex *v, const char *name, mgp_memory *memory) {
try {
const auto &key = v->graph->impl->NameToProperty(name);
auto maybe_prop = v->impl.GetProperty(v->graph->view, key);
@ -1009,8 +911,7 @@ mgp_value *mgp_vertex_get_property(const mgp_vertex *v, const char *name,
}
}
mgp_properties_iterator *mgp_vertex_iter_properties(const mgp_vertex *v,
mgp_memory *memory) {
mgp_properties_iterator *mgp_vertex_iter_properties(const mgp_vertex *v, mgp_memory *memory) {
// NOTE: This copies the whole properties into the iterator.
// TODO: Think of a good way to avoid the copy which doesn't just rely on some
// assumption that storage may return a pointer to the property store. This
@ -1030,8 +931,7 @@ mgp_properties_iterator *mgp_vertex_iter_properties(const mgp_vertex *v,
return nullptr;
}
}
return new_mgp_object<mgp_properties_iterator>(memory, v->graph,
std::move(*maybe_props));
return new_mgp_object<mgp_properties_iterator>(memory, v->graph, std::move(*maybe_props));
} catch (...) {
// Since we are copying stuff, we may get std::bad_alloc. Hopefully, no
// other exceptions are possible, but catch them all just in case.
@ -1039,12 +939,9 @@ mgp_properties_iterator *mgp_vertex_iter_properties(const mgp_vertex *v,
}
}
void mgp_edges_iterator_destroy(mgp_edges_iterator *it) {
delete_mgp_object(it);
}
void mgp_edges_iterator_destroy(mgp_edges_iterator *it) { delete_mgp_object(it); }
mgp_edges_iterator *mgp_vertex_iter_in_edges(const mgp_vertex *v,
mgp_memory *memory) {
mgp_edges_iterator *mgp_vertex_iter_in_edges(const mgp_vertex *v, mgp_memory *memory) {
auto *it = new_mgp_object<mgp_edges_iterator>(memory, *v);
if (!it) return nullptr;
try {
@ -1076,8 +973,7 @@ mgp_edges_iterator *mgp_vertex_iter_in_edges(const mgp_vertex *v,
return it;
}
mgp_edges_iterator *mgp_vertex_iter_out_edges(const mgp_vertex *v,
mgp_memory *memory) {
mgp_edges_iterator *mgp_vertex_iter_out_edges(const mgp_vertex *v, mgp_memory *memory) {
auto *it = new_mgp_object<mgp_edges_iterator>(memory, *v);
if (!it) return nullptr;
try {
@ -1127,8 +1023,7 @@ const mgp_edge *mgp_edges_iterator_next(mgp_edges_iterator *it) {
it->current_e = std::nullopt;
return nullptr;
}
it->current_e.emplace(**impl_it, it->source_vertex.graph,
it->GetMemoryResource());
it->current_e.emplace(**impl_it, it->source_vertex.graph, it->GetMemoryResource());
return &*it->current_e;
};
try {
@ -1144,9 +1039,7 @@ const mgp_edge *mgp_edges_iterator_next(mgp_edges_iterator *it) {
}
}
mgp_edge_id mgp_edge_get_id(const mgp_edge *e) {
return mgp_edge_id{.as_int = e->impl.Gid().AsInt()};
}
mgp_edge_id mgp_edge_get_id(const mgp_edge *e) { return mgp_edge_id{.as_int = e->impl.Gid().AsInt()}; }
mgp_edge *mgp_edge_copy(const mgp_edge *v, mgp_memory *memory) {
return new_mgp_object<mgp_edge>(memory, v->impl, v->from.graph);
@ -1154,17 +1047,13 @@ mgp_edge *mgp_edge_copy(const mgp_edge *v, mgp_memory *memory) {
void mgp_edge_destroy(mgp_edge *e) { delete_mgp_object(e); }
int mgp_edge_equal(const struct mgp_edge *e1, const struct mgp_edge *e2) {
return e1->impl == e2->impl ? 1 : 0;
}
int mgp_edge_equal(const struct mgp_edge *e1, const struct mgp_edge *e2) { return e1->impl == e2->impl ? 1 : 0; }
mgp_edge_type mgp_edge_get_type(const mgp_edge *e) {
const auto &name = e->from.graph->impl->EdgeTypeToName(e->impl.EdgeType());
static_assert(
std::is_lvalue_reference_v<decltype(
e->from.graph->impl->EdgeTypeToName(e->impl.EdgeType()))>,
"Expected EdgeTypeToName to return a pointer or reference, so we "
"don't have to take a copy and manage memory.");
static_assert(std::is_lvalue_reference_v<decltype(e->from.graph->impl->EdgeTypeToName(e->impl.EdgeType()))>,
"Expected EdgeTypeToName to return a pointer or reference, so we "
"don't have to take a copy and manage memory.");
return mgp_edge_type{name.c_str()};
}
@ -1172,8 +1061,7 @@ const mgp_vertex *mgp_edge_get_from(const mgp_edge *e) { return &e->from; }
const mgp_vertex *mgp_edge_get_to(const mgp_edge *e) { return &e->to; }
mgp_value *mgp_edge_get_property(const mgp_edge *e, const char *name,
mgp_memory *memory) {
mgp_value *mgp_edge_get_property(const mgp_edge *e, const char *name, mgp_memory *memory) {
try {
const auto &key = e->from.graph->impl->NameToProperty(name);
auto view = e->from.graph->view;
@ -1199,8 +1087,7 @@ mgp_value *mgp_edge_get_property(const mgp_edge *e, const char *name,
}
}
mgp_properties_iterator *mgp_edge_iter_properties(const mgp_edge *e,
mgp_memory *memory) {
mgp_properties_iterator *mgp_edge_iter_properties(const mgp_edge *e, mgp_memory *memory) {
// NOTE: This copies the whole properties into iterator.
// TODO: Think of a good way to avoid the copy which doesn't just rely on some
// assumption that storage may return a pointer to the property store. This
@ -1221,8 +1108,7 @@ mgp_properties_iterator *mgp_edge_iter_properties(const mgp_edge *e,
return nullptr;
}
}
return new_mgp_object<mgp_properties_iterator>(memory, e->from.graph,
std::move(*maybe_props));
return new_mgp_object<mgp_properties_iterator>(memory, e->from.graph, std::move(*maybe_props));
} catch (...) {
// Since we are copying stuff, we may get std::bad_alloc. Hopefully, no
// other exceptions are possible, but catch them all just in case.
@ -1230,21 +1116,15 @@ mgp_properties_iterator *mgp_edge_iter_properties(const mgp_edge *e,
}
}
mgp_vertex *mgp_graph_get_vertex_by_id(const mgp_graph *graph, mgp_vertex_id id,
mgp_memory *memory) {
auto maybe_vertex =
graph->impl->FindVertex(storage::Gid::FromInt(id.as_int), graph->view);
if (maybe_vertex)
return new_mgp_object<mgp_vertex>(memory, *maybe_vertex, graph);
mgp_vertex *mgp_graph_get_vertex_by_id(const mgp_graph *graph, mgp_vertex_id id, mgp_memory *memory) {
auto maybe_vertex = graph->impl->FindVertex(storage::Gid::FromInt(id.as_int), graph->view);
if (maybe_vertex) return new_mgp_object<mgp_vertex>(memory, *maybe_vertex, graph);
return nullptr;
}
void mgp_vertices_iterator_destroy(mgp_vertices_iterator *it) {
delete_mgp_object(it);
}
void mgp_vertices_iterator_destroy(mgp_vertices_iterator *it) { delete_mgp_object(it); }
mgp_vertices_iterator *mgp_graph_iter_vertices(const mgp_graph *graph,
mgp_memory *memory) {
mgp_vertices_iterator *mgp_graph_iter_vertices(const mgp_graph *graph, mgp_memory *memory) {
try {
return new_mgp_object<mgp_vertices_iterator>(memory, graph);
} catch (...) {
@ -1337,8 +1217,7 @@ const mgp_type *mgp_type_node() {
const mgp_type *mgp_type_relationship() {
static RelationshipType impl;
static mgp_type relationship_type{
CypherTypePtr(&impl, NoOpCypherTypeDeleter)};
static mgp_type relationship_type{CypherTypePtr(&impl, NoOpCypherTypeDeleter)};
return &relationship_type;
}
@ -1351,8 +1230,7 @@ const mgp_type *mgp_type_path() {
const mgp_type *mgp_type_list(const mgp_type *type) {
if (!type) return nullptr;
// Maps `type` to corresponding instance of ListType.
static utils::pmr::map<const mgp_type *, mgp_type> list_types(
utils::NewDeleteResource());
static utils::pmr::map<const mgp_type *, mgp_type> list_types(utils::NewDeleteResource());
static utils::SpinLock lock;
std::lock_guard<utils::SpinLock> guard(lock);
auto found_it = list_types.find(type);
@ -1362,11 +1240,8 @@ const mgp_type *mgp_type_list(const mgp_type *type) {
CypherTypePtr impl(
alloc.new_object<ListType>(
// Just obtain the pointer to original impl, don't own it.
CypherTypePtr(type->impl.get(), NoOpCypherTypeDeleter),
alloc.GetMemoryResource()),
[alloc](CypherType *base_ptr) mutable {
alloc.delete_object(static_cast<ListType *>(base_ptr));
});
CypherTypePtr(type->impl.get(), NoOpCypherTypeDeleter), alloc.GetMemoryResource()),
[alloc](CypherType *base_ptr) mutable { alloc.delete_object(static_cast<ListType *>(base_ptr)); });
return &list_types.emplace(type, mgp_type{std::move(impl)}).first->second;
} catch (const std::bad_alloc &) {
return nullptr;
@ -1376,34 +1251,28 @@ const mgp_type *mgp_type_list(const mgp_type *type) {
const mgp_type *mgp_type_nullable(const mgp_type *type) {
if (!type) return nullptr;
// Maps `type` to corresponding instance of NullableType.
static utils::pmr::map<const mgp_type *, mgp_type> gNullableTypes(
utils::NewDeleteResource());
static utils::pmr::map<const mgp_type *, mgp_type> gNullableTypes(utils::NewDeleteResource());
static utils::SpinLock lock;
std::lock_guard<utils::SpinLock> guard(lock);
auto found_it = gNullableTypes.find(type);
if (found_it != gNullableTypes.end()) return &found_it->second;
try {
auto alloc = gNullableTypes.get_allocator();
auto impl = NullableType::Create(
CypherTypePtr(type->impl.get(), NoOpCypherTypeDeleter),
alloc.GetMemoryResource());
return &gNullableTypes.emplace(type, mgp_type{std::move(impl)})
.first->second;
auto impl = NullableType::Create(CypherTypePtr(type->impl.get(), NoOpCypherTypeDeleter), alloc.GetMemoryResource());
return &gNullableTypes.emplace(type, mgp_type{std::move(impl)}).first->second;
} catch (const std::bad_alloc &) {
return nullptr;
}
}
mgp_proc *mgp_module_add_read_procedure(mgp_module *module, const char *name,
mgp_proc_cb cb) {
mgp_proc *mgp_module_add_read_procedure(mgp_module *module, const char *name, mgp_proc_cb cb) {
if (!module || !cb) return nullptr;
if (!IsValidIdentifierName(name)) return nullptr;
if (module->procedures.find(name) != module->procedures.end()) return nullptr;
try {
auto *memory = module->procedures.get_allocator().GetMemoryResource();
// May throw std::bad_alloc, std::length_error
return &module->procedures.emplace(name, mgp_proc(name, cb, memory))
.first->second;
return &module->procedures.emplace(name, mgp_proc(name, cb, memory)).first->second;
} catch (...) {
return nullptr;
}
@ -1421,8 +1290,7 @@ int mgp_proc_add_arg(mgp_proc *proc, const char *name, const mgp_type *type) {
}
}
int mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, const mgp_type *type,
const mgp_value *default_value) {
int mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, const mgp_type *type, const mgp_value *default_value) {
if (!proc || !type || !default_value) return 0;
if (!IsValidIdentifierName(name)) return 0;
switch (mgp_value_get_type(default_value)) {
@ -1444,8 +1312,7 @@ int mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, const mgp_type *type,
if (!type->impl->SatisfiesType(*default_value)) return 0;
auto *memory = proc->opt_args.get_allocator().GetMemoryResource();
try {
proc->opt_args.emplace_back(utils::pmr::string(name, memory),
type->impl.get(),
proc->opt_args.emplace_back(utils::pmr::string(name, memory), type->impl.get(),
ToTypedValue(*default_value, memory));
return 1;
} catch (...) {
@ -1455,15 +1322,13 @@ int mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, const mgp_type *type,
namespace {
int AddResultToProc(mgp_proc *proc, const char *name, const mgp_type *type,
bool is_deprecated) {
int AddResultToProc(mgp_proc *proc, const char *name, const mgp_type *type, bool is_deprecated) {
if (!proc || !type) return 0;
if (!IsValidIdentifierName(name)) return 0;
if (proc->results.find(name) != proc->results.end()) return 0;
try {
auto *memory = proc->results.get_allocator().GetMemoryResource();
proc->results.emplace(utils::pmr::string(name, memory),
std::make_pair(type->impl.get(), is_deprecated));
proc->results.emplace(utils::pmr::string(name, memory), std::make_pair(type->impl.get(), is_deprecated));
return 1;
} catch (...) {
return 0;
@ -1472,13 +1337,11 @@ int AddResultToProc(mgp_proc *proc, const char *name, const mgp_type *type,
} // namespace
int mgp_proc_add_result(mgp_proc *proc, const char *name,
const mgp_type *type) {
int mgp_proc_add_result(mgp_proc *proc, const char *name, const mgp_type *type) {
return AddResultToProc(proc, name, type, false);
}
int mgp_proc_add_deprecated_result(mgp_proc *proc, const char *name,
const mgp_type *type) {
int mgp_proc_add_deprecated_result(mgp_proc *proc, const char *name, const mgp_type *type) {
return AddResultToProc(proc, name, type, true);
}
@ -1509,18 +1372,16 @@ std::ostream &PrintValue(const TypedValue &value, std::ostream *stream) {
return (*stream) << utils::Escape(value.ValueString());
case TypedValue::Type::List:
(*stream) << "[";
utils::PrintIterable(
*stream, value.ValueList(), ", ",
[](auto &stream, const auto &elem) { PrintValue(elem, &stream); });
utils::PrintIterable(*stream, value.ValueList(), ", ",
[](auto &stream, const auto &elem) { PrintValue(elem, &stream); });
return (*stream) << "]";
case TypedValue::Type::Map:
(*stream) << "{";
utils::PrintIterable(*stream, value.ValueMap(), ", ",
[](auto &stream, const auto &item) {
// Map keys are not escaped strings.
stream << item.first << ": ";
PrintValue(item.second, &stream);
});
utils::PrintIterable(*stream, value.ValueMap(), ", ", [](auto &stream, const auto &item) {
// Map keys are not escaped strings.
stream << item.first << ": ";
PrintValue(item.second, &stream);
});
return (*stream) << "}";
case TypedValue::Type::Vertex:
case TypedValue::Type::Edge:
@ -1533,24 +1394,20 @@ std::ostream &PrintValue(const TypedValue &value, std::ostream *stream) {
void PrintProcSignature(const mgp_proc &proc, std::ostream *stream) {
(*stream) << proc.name << "(";
utils::PrintIterable(
*stream, proc.args, ", ", [](auto &stream, const auto &arg) {
stream << arg.first << " :: " << arg.second->GetPresentableName();
});
utils::PrintIterable(*stream, proc.args, ", ", [](auto &stream, const auto &arg) {
stream << arg.first << " :: " << arg.second->GetPresentableName();
});
if (!proc.args.empty() && !proc.opt_args.empty()) (*stream) << ", ";
utils::PrintIterable(
*stream, proc.opt_args, ", ", [](auto &stream, const auto &arg) {
stream << std::get<0>(arg) << " = ";
PrintValue(std::get<2>(arg), &stream)
<< " :: " << std::get<1>(arg)->GetPresentableName();
});
utils::PrintIterable(*stream, proc.opt_args, ", ", [](auto &stream, const auto &arg) {
stream << std::get<0>(arg) << " = ";
PrintValue(std::get<2>(arg), &stream) << " :: " << std::get<1>(arg)->GetPresentableName();
});
(*stream) << ") :: (";
utils::PrintIterable(
*stream, proc.results, ", ", [](auto &stream, const auto &name_result) {
const auto &[type, is_deprecated] = name_result.second;
if (is_deprecated) stream << "DEPRECATED ";
stream << name_result.first << " :: " << type->GetPresentableName();
});
utils::PrintIterable(*stream, proc.results, ", ", [](auto &stream, const auto &name_result) {
const auto &[type, is_deprecated] = name_result.second;
if (is_deprecated) stream << "DEPRECATED ";
stream << name_result.first << " :: " << type->GetPresentableName();
});
(*stream) << ")";
}

View File

@ -56,8 +56,7 @@ struct mgp_value {
/// Construct by copying query::TypedValue using utils::MemoryResource.
/// mgp_graph is needed to construct mgp_vertex and mgp_edge.
/// @throw std::bad_alloc
mgp_value(const query::TypedValue &, const mgp_graph *,
utils::MemoryResource *);
mgp_value(const query::TypedValue &, const mgp_graph *, utils::MemoryResource *);
/// Construct by copying storage::PropertyValue using utils::MemoryResource.
/// @throw std::bad_alloc
@ -112,14 +111,11 @@ struct mgp_list {
explicit mgp_list(utils::MemoryResource *memory) : elems(memory) {}
mgp_list(utils::pmr::vector<mgp_value> &&elems, utils::MemoryResource *memory)
: elems(std::move(elems), memory) {}
mgp_list(utils::pmr::vector<mgp_value> &&elems, utils::MemoryResource *memory) : elems(std::move(elems), memory) {}
mgp_list(const mgp_list &other, utils::MemoryResource *memory)
: elems(other.elems, memory) {}
mgp_list(const mgp_list &other, utils::MemoryResource *memory) : elems(other.elems, memory) {}
mgp_list(mgp_list &&other, utils::MemoryResource *memory)
: elems(std::move(other.elems), memory) {}
mgp_list(mgp_list &&other, utils::MemoryResource *memory) : elems(std::move(other.elems), memory) {}
mgp_list(mgp_list &&other) noexcept : elems(std::move(other.elems)) {}
@ -131,9 +127,7 @@ struct mgp_list {
~mgp_list() = default;
utils::MemoryResource *GetMemoryResource() const noexcept {
return elems.get_allocator().GetMemoryResource();
}
utils::MemoryResource *GetMemoryResource() const noexcept { return elems.get_allocator().GetMemoryResource(); }
// C++17 vector can work with incomplete type.
utils::pmr::vector<mgp_value> elems;
@ -145,15 +139,12 @@ struct mgp_map {
explicit mgp_map(utils::MemoryResource *memory) : items(memory) {}
mgp_map(utils::pmr::map<utils::pmr::string, mgp_value> &&items,
utils::MemoryResource *memory)
mgp_map(utils::pmr::map<utils::pmr::string, mgp_value> &&items, utils::MemoryResource *memory)
: items(std::move(items), memory) {}
mgp_map(const mgp_map &other, utils::MemoryResource *memory)
: items(other.items, memory) {}
mgp_map(const mgp_map &other, utils::MemoryResource *memory) : items(other.items, memory) {}
mgp_map(mgp_map &&other, utils::MemoryResource *memory)
: items(std::move(other.items), memory) {}
mgp_map(mgp_map &&other, utils::MemoryResource *memory) : items(std::move(other.items), memory) {}
mgp_map(mgp_map &&other) noexcept : items(std::move(other.items)) {}
@ -165,9 +156,7 @@ struct mgp_map {
~mgp_map() = default;
utils::MemoryResource *GetMemoryResource() const noexcept {
return items.get_allocator().GetMemoryResource();
}
utils::MemoryResource *GetMemoryResource() const noexcept { return items.get_allocator().GetMemoryResource(); }
// Unfortunately using incomplete type with map is undefined, so mgp_map
// needs to be defined after mgp_value.
@ -215,8 +204,7 @@ struct mgp_vertex {
// have everything noexcept here.
static_assert(std::is_nothrow_copy_constructible_v<query::VertexAccessor>);
mgp_vertex(query::VertexAccessor v, const mgp_graph *graph,
utils::MemoryResource *memory) noexcept
mgp_vertex(query::VertexAccessor v, const mgp_graph *graph, utils::MemoryResource *memory) noexcept
: memory(memory), impl(v), graph(graph) {}
mgp_vertex(const mgp_vertex &other, utils::MemoryResource *memory) noexcept
@ -225,8 +213,7 @@ struct mgp_vertex {
mgp_vertex(mgp_vertex &&other, utils::MemoryResource *memory) noexcept
: memory(memory), impl(other.impl), graph(other.graph) {}
mgp_vertex(mgp_vertex &&other) noexcept
: memory(other.memory), impl(other.impl), graph(other.graph) {}
mgp_vertex(mgp_vertex &&other) noexcept : memory(other.memory), impl(other.impl), graph(other.graph) {}
/// Copy construction without utils::MemoryResource is not allowed.
mgp_vertex(const mgp_vertex &) = delete;
@ -253,30 +240,17 @@ struct mgp_edge {
// have everything noexcept here.
static_assert(std::is_nothrow_copy_constructible_v<query::EdgeAccessor>);
mgp_edge(const query::EdgeAccessor &impl, const mgp_graph *graph,
utils::MemoryResource *memory) noexcept
: memory(memory),
impl(impl),
from(impl.From(), graph, memory),
to(impl.To(), graph, memory) {}
mgp_edge(const query::EdgeAccessor &impl, const mgp_graph *graph, utils::MemoryResource *memory) noexcept
: memory(memory), impl(impl), from(impl.From(), graph, memory), to(impl.To(), graph, memory) {}
mgp_edge(const mgp_edge &other, utils::MemoryResource *memory) noexcept
: memory(memory),
impl(other.impl),
from(other.from, memory),
to(other.to, memory) {}
: memory(memory), impl(other.impl), from(other.from, memory), to(other.to, memory) {}
mgp_edge(mgp_edge &&other, utils::MemoryResource *memory) noexcept
: memory(other.memory),
impl(other.impl),
from(std::move(other.from), memory),
to(std::move(other.to), memory) {}
: memory(other.memory), impl(other.impl), from(std::move(other.from), memory), to(std::move(other.to), memory) {}
mgp_edge(mgp_edge &&other) noexcept
: memory(other.memory),
impl(other.impl),
from(std::move(other.from)),
to(std::move(other.to)) {}
: memory(other.memory), impl(other.impl), from(std::move(other.from)), to(std::move(other.to)) {}
/// Copy construction without utils::MemoryResource is not allowed.
mgp_edge(const mgp_edge &) = delete;
@ -298,18 +272,15 @@ struct mgp_path {
/// Allocator type so that STL containers are aware that we need one.
using allocator_type = utils::Allocator<mgp_path>;
explicit mgp_path(utils::MemoryResource *memory)
: vertices(memory), edges(memory) {}
explicit mgp_path(utils::MemoryResource *memory) : vertices(memory), edges(memory) {}
mgp_path(const mgp_path &other, utils::MemoryResource *memory)
: vertices(other.vertices, memory), edges(other.edges, memory) {}
mgp_path(mgp_path &&other, utils::MemoryResource *memory)
: vertices(std::move(other.vertices), memory),
edges(std::move(other.edges), memory) {}
: vertices(std::move(other.vertices), memory), edges(std::move(other.edges), memory) {}
mgp_path(mgp_path &&other) noexcept
: vertices(std::move(other.vertices)), edges(std::move(other.edges)) {}
mgp_path(mgp_path &&other) noexcept : vertices(std::move(other.vertices)), edges(std::move(other.edges)) {}
/// Copy construction without utils::MemoryResource is not allowed.
mgp_path(const mgp_path &) = delete;
@ -319,9 +290,7 @@ struct mgp_path {
~mgp_path() = default;
utils::MemoryResource *GetMemoryResource() const noexcept {
return vertices.get_allocator().GetMemoryResource();
}
utils::MemoryResource *GetMemoryResource() const noexcept { return vertices.get_allocator().GetMemoryResource(); }
utils::pmr::vector<mgp_vertex> vertices;
utils::pmr::vector<mgp_edge> edges;
@ -329,24 +298,18 @@ struct mgp_path {
struct mgp_result_record {
/// Result record signature as defined for mgp_proc.
const utils::pmr::map<utils::pmr::string,
std::pair<const query::procedure::CypherType *, bool>>
*signature;
const utils::pmr::map<utils::pmr::string, std::pair<const query::procedure::CypherType *, bool>> *signature;
utils::pmr::map<utils::pmr::string, query::TypedValue> values;
};
struct mgp_result {
explicit mgp_result(
const utils::pmr::map<
utils::pmr::string,
std::pair<const query::procedure::CypherType *, bool>> *signature,
const utils::pmr::map<utils::pmr::string, std::pair<const query::procedure::CypherType *, bool>> *signature,
utils::MemoryResource *mem)
: signature(signature), rows(mem) {}
/// Result record signature as defined for mgp_proc.
const utils::pmr::map<utils::pmr::string,
std::pair<const query::procedure::CypherType *, bool>>
*signature;
const utils::pmr::map<utils::pmr::string, std::pair<const query::procedure::CypherType *, bool>> *signature;
utils::pmr::vector<mgp_result_record> rows;
std::optional<utils::pmr::string> error_msg;
};
@ -367,32 +330,23 @@ struct mgp_properties_iterator {
utils::MemoryResource *memory;
const mgp_graph *graph;
std::remove_reference_t<decltype(
*std::declval<query::VertexAccessor>().Properties(graph->view))>
pvs;
std::remove_reference_t<decltype(*std::declval<query::VertexAccessor>().Properties(graph->view))> pvs;
decltype(pvs.begin()) current_it;
std::optional<std::pair<utils::pmr::string, mgp_value>> current;
mgp_property property{nullptr, nullptr};
// Construct with no properties.
explicit mgp_properties_iterator(const mgp_graph *graph,
utils::MemoryResource *memory)
explicit mgp_properties_iterator(const mgp_graph *graph, utils::MemoryResource *memory)
: memory(memory), graph(graph), current_it(pvs.begin()) {}
// May throw who the #$@! knows what because PropertyValueStore doesn't
// document what it throws, and it may surely throw some piece of !@#$
// exception because it's built on top of STL and other libraries.
mgp_properties_iterator(const mgp_graph *graph, decltype(pvs) pvs,
utils::MemoryResource *memory)
: memory(memory),
graph(graph),
pvs(std::move(pvs)),
current_it(this->pvs.begin()) {
mgp_properties_iterator(const mgp_graph *graph, decltype(pvs) pvs, utils::MemoryResource *memory)
: memory(memory), graph(graph), pvs(std::move(pvs)), current_it(this->pvs.begin()) {
if (current_it != this->pvs.end()) {
current.emplace(
utils::pmr::string(graph->impl->PropertyToName(current_it->first),
memory),
mgp_value(current_it->second, memory));
current.emplace(utils::pmr::string(graph->impl->PropertyToName(current_it->first), memory),
mgp_value(current_it->second, memory));
property.name = current->first.c_str();
property.value = &current->second;
}
@ -414,11 +368,9 @@ struct mgp_edges_iterator {
// Hopefully mgp_vertex copy constructor remains noexcept, so that we can
// have everything noexcept here.
static_assert(std::is_nothrow_constructible_v<mgp_vertex, const mgp_vertex &,
utils::MemoryResource *>);
static_assert(std::is_nothrow_constructible_v<mgp_vertex, const mgp_vertex &, utils::MemoryResource *>);
mgp_edges_iterator(const mgp_vertex &v,
utils::MemoryResource *memory) noexcept
mgp_edges_iterator(const mgp_vertex &v, utils::MemoryResource *memory) noexcept
: memory(memory), source_vertex(v, memory) {}
mgp_edges_iterator(mgp_edges_iterator &&other) noexcept
@ -440,13 +392,9 @@ struct mgp_edges_iterator {
utils::MemoryResource *memory;
mgp_vertex source_vertex;
std::optional<std::remove_reference_t<decltype(
*source_vertex.impl.InEdges(source_vertex.graph->view))>>
in;
std::optional<std::remove_reference_t<decltype(*source_vertex.impl.InEdges(source_vertex.graph->view))>> in;
std::optional<decltype(in->begin())> in_it;
std::optional<std::remove_reference_t<decltype(
*source_vertex.impl.OutEdges(source_vertex.graph->view))>>
out;
std::optional<std::remove_reference_t<decltype(*source_vertex.impl.OutEdges(source_vertex.graph->view))>> out;
std::optional<decltype(out->begin())> out_it;
std::optional<mgp_edge> current_e;
};
@ -456,10 +404,7 @@ struct mgp_vertices_iterator {
/// @throw anything VerticesIterable may throw
mgp_vertices_iterator(const mgp_graph *graph, utils::MemoryResource *memory)
: memory(memory),
graph(graph),
vertices(graph->impl->Vertices(graph->view)),
current_it(vertices.begin()) {
: memory(memory), graph(graph), vertices(graph->impl->Vertices(graph->view)), current_it(vertices.begin()) {
if (current_it != vertices.end()) {
current_v.emplace(*current_it, graph, memory);
}
@ -484,24 +429,13 @@ struct mgp_proc {
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_proc(const char *name, mgp_proc_cb cb, utils::MemoryResource *memory)
: name(name, memory),
cb(cb),
args(memory),
opt_args(memory),
results(memory) {}
: name(name, memory), cb(cb), args(memory), opt_args(memory), results(memory) {}
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_proc(const char *name,
std::function<void(const mgp_list *, const mgp_graph *, mgp_result *,
mgp_memory *)>
cb,
mgp_proc(const char *name, std::function<void(const mgp_list *, const mgp_graph *, mgp_result *, mgp_memory *)> cb,
utils::MemoryResource *memory)
: name(name, memory),
cb(cb),
args(memory),
opt_args(memory),
results(memory) {}
: name(name, memory), cb(cb), args(memory), opt_args(memory), results(memory) {}
/// @throw std::bad_alloc
/// @throw std::length_error
@ -530,22 +464,13 @@ struct mgp_proc {
/// Name of the procedure.
utils::pmr::string name;
/// Entry-point for the procedure.
std::function<void(const mgp_list *, const mgp_graph *, mgp_result *,
mgp_memory *)>
cb;
std::function<void(const mgp_list *, const mgp_graph *, mgp_result *, mgp_memory *)> cb;
/// Required, positional arguments as a (name, type) pair.
utils::pmr::vector<
std::pair<utils::pmr::string, const query::procedure::CypherType *>>
args;
utils::pmr::vector<std::pair<utils::pmr::string, const query::procedure::CypherType *>> args;
/// Optional positional arguments as a (name, type, default_value) tuple.
utils::pmr::vector<
std::tuple<utils::pmr::string, const query::procedure::CypherType *,
query::TypedValue>>
opt_args;
utils::pmr::vector<std::tuple<utils::pmr::string, const query::procedure::CypherType *, query::TypedValue>> opt_args;
/// Fields this procedure returns, as a (name -> (type, is_deprecated)) map.
utils::pmr::map<utils::pmr::string,
std::pair<const query::procedure::CypherType *, bool>>
results;
utils::pmr::map<utils::pmr::string, std::pair<const query::procedure::CypherType *, bool>> results;
};
struct mgp_module {
@ -553,11 +478,9 @@ struct mgp_module {
explicit mgp_module(utils::MemoryResource *memory) : procedures(memory) {}
mgp_module(const mgp_module &other, utils::MemoryResource *memory)
: procedures(other.procedures, memory) {}
mgp_module(const mgp_module &other, utils::MemoryResource *memory) : procedures(other.procedures, memory) {}
mgp_module(mgp_module &&other, utils::MemoryResource *memory)
: procedures(std::move(other.procedures), memory) {}
mgp_module(mgp_module &&other, utils::MemoryResource *memory) : procedures(std::move(other.procedures), memory) {}
mgp_module(const mgp_module &) = default;
mgp_module(mgp_module &&) = default;

View File

@ -30,8 +30,7 @@ class BuiltinModule final : public Module {
bool Close() override;
const std::map<std::string, mgp_proc, std::less<>> *Procedures()
const override;
const std::map<std::string, mgp_proc, std::less<>> *Procedures() const override;
void AddProcedure(std::string_view name, mgp_proc proc);
@ -46,19 +45,13 @@ BuiltinModule::~BuiltinModule() {}
bool BuiltinModule::Close() { return true; }
const std::map<std::string, mgp_proc, std::less<>> *BuiltinModule::Procedures()
const {
return &procedures_;
}
const std::map<std::string, mgp_proc, std::less<>> *BuiltinModule::Procedures() const { return &procedures_; }
void BuiltinModule::AddProcedure(std::string_view name, mgp_proc proc) {
procedures_.emplace(name, std::move(proc));
}
void BuiltinModule::AddProcedure(std::string_view name, mgp_proc proc) { procedures_.emplace(name, std::move(proc)); }
namespace {
void RegisterMgLoad(ModuleRegistry *module_registry, utils::RWLock *lock,
BuiltinModule *module) {
void RegisterMgLoad(ModuleRegistry *module_registry, utils::RWLock *lock, BuiltinModule *module) {
// Loading relies on the fact that regular procedure invocation through
// CallProcedureCursor::Pull takes ModuleRegistry::lock_ with READ access. To
// load modules we have to upgrade our READ access to WRITE access,
@ -83,27 +76,19 @@ void RegisterMgLoad(ModuleRegistry *module_registry, utils::RWLock *lock,
}
lock->lock_shared();
};
auto load_all_cb = [module_registry, with_unlock_shared](
const mgp_list *, const mgp_graph *, mgp_result *,
mgp_memory *) {
with_unlock_shared(
[&]() { module_registry->UnloadAndLoadModulesFromDirectory(); });
auto load_all_cb = [module_registry, with_unlock_shared](const mgp_list *, const mgp_graph *, mgp_result *,
mgp_memory *) {
with_unlock_shared([&]() { module_registry->UnloadAndLoadModulesFromDirectory(); });
};
mgp_proc load_all("load_all", load_all_cb, utils::NewDeleteResource());
module->AddProcedure("load_all", std::move(load_all));
auto load_cb = [module_registry, with_unlock_shared](
const mgp_list *args, const mgp_graph *, mgp_result *res,
mgp_memory *) {
MG_ASSERT(mgp_list_size(args) == 1U,
"Should have been type checked already");
auto load_cb = [module_registry, with_unlock_shared](const mgp_list *args, const mgp_graph *, mgp_result *res,
mgp_memory *) {
MG_ASSERT(mgp_list_size(args) == 1U, "Should have been type checked already");
const mgp_value *arg = mgp_list_at(args, 0);
MG_ASSERT(mgp_value_is_string(arg),
"Should have been type checked already");
MG_ASSERT(mgp_value_is_string(arg), "Should have been type checked already");
bool succ = false;
with_unlock_shared([&]() {
succ = module_registry->LoadOrReloadModuleFromName(
mgp_value_get_string(arg));
});
with_unlock_shared([&]() { succ = module_registry->LoadOrReloadModuleFromName(mgp_value_get_string(arg)); });
if (!succ) mgp_result_set_error_msg(res, "Failed to (re)load the module.");
};
mgp_proc load("load", load_cb, utils::NewDeleteResource());
@ -113,11 +98,8 @@ void RegisterMgLoad(ModuleRegistry *module_registry, utils::RWLock *lock,
void RegisterMgProcedures(
// We expect modules to be sorted by name.
const std::map<std::string, std::unique_ptr<Module>, std::less<>>
*all_modules,
BuiltinModule *module) {
auto procedures_cb = [all_modules](const mgp_list *, const mgp_graph *,
mgp_result *result, mgp_memory *memory) {
const std::map<std::string, std::unique_ptr<Module>, std::less<>> *all_modules, BuiltinModule *module) {
auto procedures_cb = [all_modules](const mgp_list *, const mgp_graph *, mgp_result *result, mgp_memory *memory) {
// Iterating over all_modules assumes that the standard mechanism of custom
// procedure invocations takes the ModuleRegistry::lock_ with READ access.
// For details on how the invocation is done, take a look at the
@ -125,8 +107,7 @@ void RegisterMgProcedures(
for (const auto &[module_name, module] : *all_modules) {
// Return the results in sorted order by module and by procedure.
static_assert(
std::is_same_v<decltype(module->Procedures()),
const std::map<std::string, mgp_proc, std::less<>> *>,
std::is_same_v<decltype(module->Procedures()), const std::map<std::string, mgp_proc, std::less<>> *>,
"Expected module procedures to be sorted by name");
for (const auto &[proc_name, proc] : *module->Procedures()) {
auto *record = mgp_result_new_record(result);
@ -146,16 +127,14 @@ void RegisterMgProcedures(
ss << module_name << ".";
PrintProcSignature(proc, &ss);
const auto signature = ss.str();
auto *signature_value =
mgp_value_make_string(signature.c_str(), memory);
auto *signature_value = mgp_value_make_string(signature.c_str(), memory);
if (!signature_value) {
mgp_value_destroy(name_value);
mgp_result_set_error_msg(result, "Not enough memory!");
return;
}
int succ1 = mgp_result_record_insert(record, "name", name_value);
int succ2 =
mgp_result_record_insert(record, "signature", signature_value);
int succ2 = mgp_result_record_insert(record, "signature", signature_value);
mgp_value_destroy(name_value);
mgp_value_destroy(signature_value);
if (!succ1 || !succ2) {
@ -206,8 +185,7 @@ class SharedLibraryModule final : public Module {
bool Close() override;
const std::map<std::string, mgp_proc, std::less<>> *Procedures()
const override;
const std::map<std::string, mgp_proc, std::less<>> *Procedures() const override;
private:
/// Path as requested for loading the module from a library.
@ -239,8 +217,7 @@ bool SharedLibraryModule::Load(const std::filesystem::path &file_path) {
return false;
}
// Get required mgp_init_module
init_fn_ = reinterpret_cast<int (*)(mgp_module *, mgp_memory *)>(
dlsym(handle_, "mgp_init_module"));
init_fn_ = reinterpret_cast<int (*)(mgp_module *, mgp_memory *)>(dlsym(handle_, "mgp_init_module"));
const char *error = dlerror();
if (!init_fn_ || error) {
spdlog::error("Unable to load module {}; {}", file_path, error);
@ -248,13 +225,11 @@ bool SharedLibraryModule::Load(const std::filesystem::path &file_path) {
handle_ = nullptr;
return false;
}
if (!WithModuleRegistration(&procedures_, [&](auto *module_def,
auto *memory) {
if (!WithModuleRegistration(&procedures_, [&](auto *module_def, auto *memory) {
// Run mgp_init_module which must succeed.
int init_res = init_fn_(module_def, memory);
if (init_res != 0) {
spdlog::error("Unable to load module {}; mgp_init_module_returned {}",
file_path, init_res);
spdlog::error("Unable to load module {}; mgp_init_module_returned {}", file_path, init_res);
dlclose(handle_);
handle_ = nullptr;
return false;
@ -264,8 +239,7 @@ bool SharedLibraryModule::Load(const std::filesystem::path &file_path) {
return false;
}
// Get optional mgp_shutdown_module
shutdown_fn_ =
reinterpret_cast<int (*)()>(dlsym(handle_, "mgp_shutdown_module"));
shutdown_fn_ = reinterpret_cast<int (*)()>(dlsym(handle_, "mgp_shutdown_module"));
error = dlerror();
if (error) spdlog::warn("When loading module {}; {}", file_path, error);
spdlog::info("Loaded module {}", file_path);
@ -273,16 +247,14 @@ bool SharedLibraryModule::Load(const std::filesystem::path &file_path) {
}
bool SharedLibraryModule::Close() {
MG_ASSERT(handle_,
"Attempting to close a module that has not been loaded...");
MG_ASSERT(handle_, "Attempting to close a module that has not been loaded...");
spdlog::info("Closing module {}...", file_path_);
// non-existent shutdown function is semantically the same as a shutdown
// function that does nothing.
int shutdown_res = 0;
if (shutdown_fn_) shutdown_res = shutdown_fn_();
if (shutdown_res != 0) {
spdlog::warn("When closing module {}; mgp_shutdown_module returned {}",
file_path_, shutdown_res);
spdlog::warn("When closing module {}; mgp_shutdown_module returned {}", file_path_, shutdown_res);
}
if (dlclose(handle_) != 0) {
spdlog::error("Failed to close module {}; {}", file_path_, dlerror());
@ -294,8 +266,7 @@ bool SharedLibraryModule::Close() {
return true;
}
const std::map<std::string, mgp_proc, std::less<>>
*SharedLibraryModule::Procedures() const {
const std::map<std::string, mgp_proc, std::less<>> *SharedLibraryModule::Procedures() const {
MG_ASSERT(handle_,
"Attempting to access procedures of a module that has not "
"been loaded...");
@ -315,8 +286,7 @@ class PythonModule final : public Module {
bool Close() override;
const std::map<std::string, mgp_proc, std::less<>> *Procedures()
const override;
const std::map<std::string, mgp_proc, std::less<>> *Procedures() const override;
private:
std::filesystem::path file_path_;
@ -340,10 +310,9 @@ bool PythonModule::Load(const std::filesystem::path &file_path) {
spdlog::error("Unable to load module {}; {}", file_path, *maybe_exc);
return false;
}
py_module_ =
WithModuleRegistration(&procedures_, [&](auto *module_def, auto *memory) {
return ImportPyModule(file_path.stem().c_str(), module_def);
});
py_module_ = WithModuleRegistration(&procedures_, [&](auto *module_def, auto *memory) {
return ImportPyModule(file_path.stem().c_str(), module_def);
});
if (py_module_) {
spdlog::info("Loaded module {}", file_path);
return true;
@ -354,8 +323,7 @@ bool PythonModule::Load(const std::filesystem::path &file_path) {
}
bool PythonModule::Close() {
MG_ASSERT(py_module_,
"Attempting to close a module that has not been loaded...");
MG_ASSERT(py_module_, "Attempting to close a module that has not been loaded...");
spdlog::info("Closing module {}...", file_path_);
// The procedures are closures which hold references to the Python callbacks.
// Releasing these references might result in deallocations so we need to take
@ -365,8 +333,7 @@ bool PythonModule::Close() {
// Delete the module from the `sys.modules` directory so that the module will
// be properly imported if imported again.
py::Object sys(PyImport_ImportModule("sys"));
if (PyDict_DelItemString(sys.GetAttr("modules").Ptr(),
file_path_.stem().c_str()) != 0) {
if (PyDict_DelItemString(sys.GetAttr("modules").Ptr(), file_path_.stem().c_str()) != 0) {
spdlog::warn("Failed to remove the module from sys.modules");
py_module_ = py::Object(nullptr);
return false;
@ -376,8 +343,7 @@ bool PythonModule::Close() {
return true;
}
const std::map<std::string, mgp_proc, std::less<>> *PythonModule::Procedures()
const {
const std::map<std::string, mgp_proc, std::less<>> *PythonModule::Procedures() const {
MG_ASSERT(py_module_,
"Attempting to access procedures of a module that has "
"not been loaded...");
@ -407,8 +373,7 @@ std::unique_ptr<Module> LoadModuleFromFile(const std::filesystem::path &path) {
} // namespace
bool ModuleRegistry::RegisterModule(const std::string_view &name,
std::unique_ptr<Module> module) {
bool ModuleRegistry::RegisterModule(const std::string_view &name, std::unique_ptr<Module> module) {
MG_ASSERT(!name.empty(), "Module name cannot be empty");
MG_ASSERT(module, "Tried to register an invalid module");
if (modules_.find(name) != modules_.end()) {
@ -420,8 +385,7 @@ bool ModuleRegistry::RegisterModule(const std::string_view &name,
}
void ModuleRegistry::DoUnloadAllModules() {
MG_ASSERT(modules_.find("mg") != modules_.end(),
"Expected the builtin \"mg\" module to be present.");
MG_ASSERT(modules_.find("mg") != modules_.end(), "Expected the builtin \"mg\" module to be present.");
// This is correct because the destructor will close each module. However,
// we don't want to unload the builtin "mg" module.
auto module = std::move(modules_["mg"]);
@ -436,10 +400,7 @@ ModuleRegistry::ModuleRegistry() {
modules_.emplace("mg", std::move(module));
}
void ModuleRegistry::SetModulesDirectory(
const std::filesystem::path &modules_dir) {
modules_dir_ = modules_dir;
}
void ModuleRegistry::SetModulesDirectory(const std::filesystem::path &modules_dir) { modules_dir_ = modules_dir; }
bool ModuleRegistry::LoadOrReloadModuleFromName(const std::string_view &name) {
if (modules_dir_.empty()) return false;
@ -500,16 +461,14 @@ void ModuleRegistry::UnloadAllModules() {
}
std::optional<std::pair<procedure::ModulePtr, const mgp_proc *>> FindProcedure(
const ModuleRegistry &module_registry,
const std::string_view &fully_qualified_procedure_name,
const ModuleRegistry &module_registry, const std::string_view &fully_qualified_procedure_name,
utils::MemoryResource *memory) {
utils::pmr::vector<std::string_view> name_parts(memory);
utils::Split(&name_parts, fully_qualified_procedure_name, ".");
if (name_parts.size() == 1U) return std::nullopt;
auto last_dot_pos = fully_qualified_procedure_name.find_last_of('.');
MG_ASSERT(last_dot_pos != std::string_view::npos);
const auto &module_name =
fully_qualified_procedure_name.substr(0, last_dot_pos);
const auto &module_name = fully_qualified_procedure_name.substr(0, last_dot_pos);
const auto &proc_name = name_parts.back();
auto module = module_registry.GetModuleNamed(module_name);
if (!module) return std::nullopt;

View File

@ -29,8 +29,7 @@ class Module {
virtual bool Close() = 0;
/// Returns registered procedures of this module
virtual const std::map<std::string, mgp_proc, std::less<>> *Procedures()
const = 0;
virtual const std::map<std::string, mgp_proc, std::less<>> *Procedures() const = 0;
};
/// Proxy for a registered Module, acquires a read lock from ModuleRegistry.
@ -41,8 +40,7 @@ class ModulePtr final {
public:
ModulePtr() = default;
ModulePtr(std::nullptr_t) {}
ModulePtr(const Module *module, std::shared_lock<utils::RWLock> lock)
: module_(module), lock_(std::move(lock)) {}
ModulePtr(const Module *module, std::shared_lock<utils::RWLock> lock) : module_(module), lock_(std::move(lock)) {}
explicit operator bool() const { return static_cast<bool>(module_); }
@ -55,8 +53,7 @@ class ModuleRegistry final {
std::map<std::string, std::unique_ptr<Module>, std::less<>> modules_;
mutable utils::RWLock lock_{utils::RWLock::Priority::WRITE};
bool RegisterModule(const std::string_view &name,
std::unique_ptr<Module> module);
bool RegisterModule(const std::string_view &name, std::unique_ptr<Module> module);
void DoUnloadAllModules();
@ -105,8 +102,7 @@ extern ModuleRegistry gModuleRegistry;
/// inside this function. ModulePtr must be kept alive to make sure it won't be
/// unloaded.
std::optional<std::pair<procedure::ModulePtr, const mgp_proc *>> FindProcedure(
const ModuleRegistry &module_registry,
const std::string_view &fully_qualified_procedure_name,
const ModuleRegistry &module_registry, const std::string_view &fully_qualified_procedure_name,
utils::MemoryResource *memory);
} // namespace query::procedure

View File

@ -65,8 +65,7 @@ void PyVerticesIteratorDealloc(PyVerticesIterator *self) {
Py_TYPE(self)->tp_free(self);
}
PyObject *PyVerticesIteratorGet(PyVerticesIterator *self,
PyObject *Py_UNUSED(ignored)) {
PyObject *PyVerticesIteratorGet(PyVerticesIterator *self, PyObject *Py_UNUSED(ignored)) {
MG_ASSERT(self->it);
MG_ASSERT(self->py_graph);
MG_ASSERT(self->py_graph->graph);
@ -75,8 +74,7 @@ PyObject *PyVerticesIteratorGet(PyVerticesIterator *self,
return MakePyVertex(*vertex, self->py_graph);
}
PyObject *PyVerticesIteratorNext(PyVerticesIterator *self,
PyObject *Py_UNUSED(ignored)) {
PyObject *PyVerticesIteratorNext(PyVerticesIterator *self, PyObject *Py_UNUSED(ignored)) {
MG_ASSERT(self->it);
MG_ASSERT(self->py_graph);
MG_ASSERT(self->py_graph->graph);
@ -86,8 +84,7 @@ PyObject *PyVerticesIteratorNext(PyVerticesIterator *self,
}
static PyMethodDef PyVerticesIteratorMethods[] = {
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy),
METH_NOARGS, "__reduce__ is not supported"},
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
{"get", reinterpret_cast<PyCFunction>(PyVerticesIteratorGet), METH_NOARGS,
"Get the current vertex pointed to by the iterator or return None."},
{"next", reinterpret_cast<PyCFunction>(PyVerticesIteratorNext), METH_NOARGS,
@ -128,8 +125,7 @@ void PyEdgesIteratorDealloc(PyEdgesIterator *self) {
Py_TYPE(self)->tp_free(self);
}
PyObject *PyEdgesIteratorGet(PyEdgesIterator *self,
PyObject *Py_UNUSED(ignored)) {
PyObject *PyEdgesIteratorGet(PyEdgesIterator *self, PyObject *Py_UNUSED(ignored)) {
MG_ASSERT(self->it);
MG_ASSERT(self->py_graph);
MG_ASSERT(self->py_graph->graph);
@ -138,8 +134,7 @@ PyObject *PyEdgesIteratorGet(PyEdgesIterator *self,
return MakePyEdge(*edge, self->py_graph);
}
PyObject *PyEdgesIteratorNext(PyEdgesIterator *self,
PyObject *Py_UNUSED(ignored)) {
PyObject *PyEdgesIteratorNext(PyEdgesIterator *self, PyObject *Py_UNUSED(ignored)) {
MG_ASSERT(self->it);
MG_ASSERT(self->py_graph);
MG_ASSERT(self->py_graph->graph);
@ -149,8 +144,7 @@ PyObject *PyEdgesIteratorNext(PyEdgesIterator *self,
}
static PyMethodDef PyEdgesIteratorMethods[] = {
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy),
METH_NOARGS, "__reduce__ is not supported"},
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
{"get", reinterpret_cast<PyCFunction>(PyEdgesIteratorGet), METH_NOARGS,
"Get the current edge pointed to by the iterator or return None."},
{"next", reinterpret_cast<PyCFunction>(PyEdgesIteratorNext), METH_NOARGS,
@ -176,9 +170,7 @@ PyObject *PyGraphInvalidate(PyGraph *self, PyObject *Py_UNUSED(ignored)) {
Py_RETURN_NONE;
}
PyObject *PyGraphIsValid(PyGraph *self, PyObject *Py_UNUSED(ignored)) {
return PyBool_FromLong(!!self->graph);
}
PyObject *PyGraphIsValid(PyGraph *self, PyObject *Py_UNUSED(ignored)) { return PyBool_FromLong(!!self->graph); }
PyObject *MakePyVertex(mgp_vertex *vertex, PyGraph *py_graph);
@ -188,11 +180,9 @@ PyObject *PyGraphGetVertexById(PyGraph *self, PyObject *args) {
static_assert(std::is_same_v<int64_t, long>);
int64_t id;
if (!PyArg_ParseTuple(args, "l", &id)) return nullptr;
auto *vertex =
mgp_graph_get_vertex_by_id(self->graph, mgp_vertex_id{id}, self->memory);
auto *vertex = mgp_graph_get_vertex_by_id(self->graph, mgp_vertex_id{id}, self->memory);
if (!vertex) {
PyErr_SetString(PyExc_IndexError,
"Unable to find the vertex with given ID.");
PyErr_SetString(PyExc_IndexError, "Unable to find the vertex with given ID.");
return nullptr;
}
auto *py_vertex = MakePyVertex(vertex, self);
@ -205,12 +195,10 @@ PyObject *PyGraphIterVertices(PyGraph *self, PyObject *Py_UNUSED(ignored)) {
MG_ASSERT(self->memory);
auto *vertices_it = mgp_graph_iter_vertices(self->graph, self->memory);
if (!vertices_it) {
PyErr_SetString(PyExc_MemoryError,
"Unable to allocate mgp_vertices_iterator.");
PyErr_SetString(PyExc_MemoryError, "Unable to allocate mgp_vertices_iterator.");
return nullptr;
}
auto *py_vertices_it =
PyObject_New(PyVerticesIterator, &PyVerticesIteratorType);
auto *py_vertices_it = PyObject_New(PyVerticesIterator, &PyVerticesIteratorType);
if (!py_vertices_it) {
mgp_vertices_iterator_destroy(vertices_it);
return nullptr;
@ -227,17 +215,14 @@ PyObject *PyGraphMustAbort(PyGraph *self, PyObject *Py_UNUSED(ignored)) {
}
static PyMethodDef PyGraphMethods[] = {
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy),
METH_NOARGS, "__reduce__ is not supported"},
{"invalidate", reinterpret_cast<PyCFunction>(PyGraphInvalidate),
METH_NOARGS,
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
{"invalidate", reinterpret_cast<PyCFunction>(PyGraphInvalidate), METH_NOARGS,
"Invalidate the Graph context thus preventing the Graph from being used."},
{"is_valid", reinterpret_cast<PyCFunction>(PyGraphIsValid), METH_NOARGS,
"Return True if Graph is in valid context and may be used."},
{"get_vertex_by_id", reinterpret_cast<PyCFunction>(PyGraphGetVertexById),
METH_VARARGS, "Get the vertex or raise IndexError."},
{"iter_vertices", reinterpret_cast<PyCFunction>(PyGraphIterVertices),
METH_NOARGS, "Return _mgp.VerticesIterator."},
{"get_vertex_by_id", reinterpret_cast<PyCFunction>(PyGraphGetVertexById), METH_VARARGS,
"Get the vertex or raise IndexError."},
{"iter_vertices", reinterpret_cast<PyCFunction>(PyGraphIterVertices), METH_NOARGS, "Return _mgp.VerticesIterator."},
{"must_abort", reinterpret_cast<PyCFunction>(PyGraphMustAbort), METH_NOARGS,
"Check whether the running procedure should abort"},
{nullptr},
@ -317,8 +302,7 @@ PyObject *PyQueryProcAddOptArg(PyQueryProc *self, PyObject *args) {
const char *name = nullptr;
PyObject *py_type = nullptr;
PyObject *py_value = nullptr;
if (!PyArg_ParseTuple(args, "sOO", &name, &py_type, &py_value))
return nullptr;
if (!PyArg_ParseTuple(args, "sOO", &name, &py_type, &py_value)) return nullptr;
if (Py_TYPE(py_type) != &PyCypherTypeType) {
PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type.");
return nullptr;
@ -379,26 +363,21 @@ PyObject *PyQueryProcAddDeprecatedResult(PyQueryProc *self, PyObject *args) {
}
const auto *type = reinterpret_cast<PyCypherType *>(py_type)->type;
if (!mgp_proc_add_deprecated_result(self->proc, name, type)) {
PyErr_SetString(PyExc_ValueError,
"Invalid call to mgp_proc_add_deprecated_result.");
PyErr_SetString(PyExc_ValueError, "Invalid call to mgp_proc_add_deprecated_result.");
return nullptr;
}
Py_RETURN_NONE;
}
static PyMethodDef PyQueryProcMethods[] = {
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy),
METH_NOARGS, "__reduce__ is not supported"},
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
{"add_arg", reinterpret_cast<PyCFunction>(PyQueryProcAddArg), METH_VARARGS,
"Add a required argument to a procedure."},
{"add_opt_arg", reinterpret_cast<PyCFunction>(PyQueryProcAddOptArg),
METH_VARARGS,
{"add_opt_arg", reinterpret_cast<PyCFunction>(PyQueryProcAddOptArg), METH_VARARGS,
"Add an optional argument with a default value to a procedure."},
{"add_result", reinterpret_cast<PyCFunction>(PyQueryProcAddResult),
METH_VARARGS, "Add a result field to a procedure."},
{"add_deprecated_result",
reinterpret_cast<PyCFunction>(PyQueryProcAddDeprecatedResult),
METH_VARARGS,
{"add_result", reinterpret_cast<PyCFunction>(PyQueryProcAddResult), METH_VARARGS,
"Add a result field to a procedure."},
{"add_deprecated_result", reinterpret_cast<PyCFunction>(PyQueryProcAddDeprecatedResult), METH_VARARGS,
"Add a result field to a procedure and mark it as deprecated."},
{nullptr},
};
@ -447,8 +426,7 @@ py::Object MgpListToPyTuple(const mgp_list *list, PyObject *py_graph) {
namespace {
std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result,
py::Object py_record) {
std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result, py::Object py_record) {
py::Object py_mgp(PyImport_ImportModule("mgp"));
if (!py_mgp) return py::FetchError();
auto record_cls = py_mgp.GetAttr("Record");
@ -463,8 +441,7 @@ std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result,
py::Object fields(py_record.GetAttr("fields"));
if (!fields) return py::FetchError();
if (!PyDict_Check(fields)) {
PyErr_SetString(PyExc_TypeError,
"Expected 'mgp.Record.fields' to be a 'dict'");
PyErr_SetString(PyExc_TypeError, "Expected 'mgp.Record.fields' to be a 'dict'");
return py::FetchError();
}
py::Object items(PyDict_Items(fields.Ptr()));
@ -483,8 +460,7 @@ std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result,
if (!key) return py::FetchError();
if (!PyUnicode_Check(key)) {
std::stringstream ss;
ss << "Field name '" << py::Object::FromBorrow(key)
<< "' is not an instance of 'str'";
ss << "Field name '" << py::Object::FromBorrow(key) << "' is not an instance of 'str'";
const auto &msg = ss.str();
PyErr_SetString(PyExc_TypeError, msg.c_str());
return py::FetchError();
@ -505,9 +481,8 @@ std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result,
MG_ASSERT(field_val);
if (!mgp_result_record_insert(record, field_name, field_val)) {
std::stringstream ss;
ss << "Unable to insert field '" << py::Object::FromBorrow(key)
<< "' with value: '" << py::Object::FromBorrow(val)
<< "'; did you set the correct field type?";
ss << "Unable to insert field '" << py::Object::FromBorrow(key) << "' with value: '"
<< py::Object::FromBorrow(val) << "'; did you set the correct field type?";
const auto &msg = ss.str();
PyErr_SetString(PyExc_ValueError, msg.c_str());
mgp_value_destroy(field_val);
@ -518,8 +493,7 @@ std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result,
return std::nullopt;
}
std::optional<py::ExceptionInfo> AddMultipleRecordsFromPython(
mgp_result *result, py::Object py_seq) {
std::optional<py::ExceptionInfo> AddMultipleRecordsFromPython(mgp_result *result, py::Object py_seq) {
Py_ssize_t len = PySequence_Size(py_seq.Ptr());
if (len == -1) return py::FetchError();
for (Py_ssize_t i = 0; i < len; ++i) {
@ -531,13 +505,11 @@ std::optional<py::ExceptionInfo> AddMultipleRecordsFromPython(
return std::nullopt;
}
void CallPythonProcedure(py::Object py_cb, const mgp_list *args,
const mgp_graph *graph, mgp_result *result,
void CallPythonProcedure(py::Object py_cb, const mgp_list *args, const mgp_graph *graph, mgp_result *result,
mgp_memory *memory) {
auto gil = py::EnsureGIL();
auto error_to_msg = [](const std::optional<py::ExceptionInfo> &exc_info)
-> std::optional<std::string> {
auto error_to_msg = [](const std::optional<py::ExceptionInfo> &exc_info) -> std::optional<std::string> {
if (!exc_info) return std::nullopt;
// Here we tell the traceback formatter to skip the first line of the
// traceback because that line will always be our wrapper function in our
@ -626,23 +598,19 @@ PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) {
const auto *name = PyUnicode_AsUTF8(py_name.Ptr());
if (!name) return nullptr;
if (!IsValidIdentifierName(name)) {
PyErr_SetString(PyExc_ValueError,
"Procedure name is not a valid identifier");
PyErr_SetString(PyExc_ValueError, "Procedure name is not a valid identifier");
return nullptr;
}
auto *memory = self->module->procedures.get_allocator().GetMemoryResource();
mgp_proc proc(
name,
[py_cb](const mgp_list *args, const mgp_graph *graph, mgp_result *result,
mgp_memory *memory) {
[py_cb](const mgp_list *args, const mgp_graph *graph, mgp_result *result, mgp_memory *memory) {
CallPythonProcedure(py_cb, args, graph, result, memory);
},
memory);
const auto &[proc_it, did_insert] =
self->module->procedures.emplace(name, std::move(proc));
const auto &[proc_it, did_insert] = self->module->procedures.emplace(name, std::move(proc));
if (!did_insert) {
PyErr_SetString(PyExc_ValueError,
"Already registered a procedure with the same name.");
PyErr_SetString(PyExc_ValueError, "Already registered a procedure with the same name.");
return nullptr;
}
auto *py_proc = PyObject_New(PyQueryProc, &PyQueryProcType);
@ -652,10 +620,8 @@ PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) {
}
static PyMethodDef PyQueryModuleMethods[] = {
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy),
METH_NOARGS, "__reduce__ is not supported"},
{"add_read_procedure",
reinterpret_cast<PyCFunction>(PyQueryModuleAddReadProcedure), METH_O,
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
{"add_read_procedure", reinterpret_cast<PyCFunction>(PyQueryModuleAddReadProcedure), METH_O,
"Register a read-only procedure with this module."},
{nullptr},
};
@ -697,21 +663,15 @@ PyObject *PyMgpModuleTypeList(PyObject *mod, PyObject *obj) {
return MakePyCypherType(mgp_type_list(py_type->type));
}
PyObject *PyMgpModuleTypeAny(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
return MakePyCypherType(mgp_type_any());
}
PyObject *PyMgpModuleTypeAny(PyObject *mod, PyObject *Py_UNUSED(ignored)) { return MakePyCypherType(mgp_type_any()); }
PyObject *PyMgpModuleTypeBool(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
return MakePyCypherType(mgp_type_bool());
}
PyObject *PyMgpModuleTypeBool(PyObject *mod, PyObject *Py_UNUSED(ignored)) { return MakePyCypherType(mgp_type_bool()); }
PyObject *PyMgpModuleTypeString(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
return MakePyCypherType(mgp_type_string());
}
PyObject *PyMgpModuleTypeInt(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
return MakePyCypherType(mgp_type_int());
}
PyObject *PyMgpModuleTypeInt(PyObject *mod, PyObject *Py_UNUSED(ignored)) { return MakePyCypherType(mgp_type_int()); }
PyObject *PyMgpModuleTypeFloat(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
return MakePyCypherType(mgp_type_float());
@ -721,45 +681,29 @@ PyObject *PyMgpModuleTypeNumber(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
return MakePyCypherType(mgp_type_number());
}
PyObject *PyMgpModuleTypeMap(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
return MakePyCypherType(mgp_type_map());
}
PyObject *PyMgpModuleTypeMap(PyObject *mod, PyObject *Py_UNUSED(ignored)) { return MakePyCypherType(mgp_type_map()); }
PyObject *PyMgpModuleTypeNode(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
return MakePyCypherType(mgp_type_node());
}
PyObject *PyMgpModuleTypeNode(PyObject *mod, PyObject *Py_UNUSED(ignored)) { return MakePyCypherType(mgp_type_node()); }
PyObject *PyMgpModuleTypeRelationship(PyObject *mod,
PyObject *Py_UNUSED(ignored)) {
PyObject *PyMgpModuleTypeRelationship(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
return MakePyCypherType(mgp_type_relationship());
}
PyObject *PyMgpModuleTypePath(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
return MakePyCypherType(mgp_type_path());
}
PyObject *PyMgpModuleTypePath(PyObject *mod, PyObject *Py_UNUSED(ignored)) { return MakePyCypherType(mgp_type_path()); }
static PyMethodDef PyMgpModuleMethods[] = {
{"type_nullable", PyMgpModuleTypeNullable, METH_O,
"Build a type representing either a `null` value or a value of given "
"type."},
{"type_list", PyMgpModuleTypeList, METH_O,
"Build a type representing a list of values of given type."},
{"type_any", PyMgpModuleTypeAny, METH_NOARGS,
"Get the type representing any value that isn't `null`."},
{"type_bool", PyMgpModuleTypeBool, METH_NOARGS,
"Get the type representing boolean values."},
{"type_string", PyMgpModuleTypeString, METH_NOARGS,
"Get the type representing string values."},
{"type_int", PyMgpModuleTypeInt, METH_NOARGS,
"Get the type representing integer values."},
{"type_float", PyMgpModuleTypeFloat, METH_NOARGS,
"Get the type representing floating-point values."},
{"type_number", PyMgpModuleTypeNumber, METH_NOARGS,
"Get the type representing any number value."},
{"type_map", PyMgpModuleTypeMap, METH_NOARGS,
"Get the type representing map values."},
{"type_node", PyMgpModuleTypeNode, METH_NOARGS,
"Get the type representing graph node values."},
{"type_list", PyMgpModuleTypeList, METH_O, "Build a type representing a list of values of given type."},
{"type_any", PyMgpModuleTypeAny, METH_NOARGS, "Get the type representing any value that isn't `null`."},
{"type_bool", PyMgpModuleTypeBool, METH_NOARGS, "Get the type representing boolean values."},
{"type_string", PyMgpModuleTypeString, METH_NOARGS, "Get the type representing string values."},
{"type_int", PyMgpModuleTypeInt, METH_NOARGS, "Get the type representing integer values."},
{"type_float", PyMgpModuleTypeFloat, METH_NOARGS, "Get the type representing floating-point values."},
{"type_number", PyMgpModuleTypeNumber, METH_NOARGS, "Get the type representing any number value."},
{"type_map", PyMgpModuleTypeMap, METH_NOARGS, "Get the type representing map values."},
{"type_node", PyMgpModuleTypeNode, METH_NOARGS, "Get the type representing graph node values."},
{"type_relationship", PyMgpModuleTypeRelationship, METH_NOARGS,
"Get the type representing graph relationship values."},
{"type_path", PyMgpModuleTypePath, METH_NOARGS,
@ -796,8 +740,7 @@ void PyPropertiesIteratorDealloc(PyPropertiesIterator *self) {
Py_TYPE(self)->tp_free(self);
}
PyObject *PyPropertiesIteratorGet(PyPropertiesIterator *self,
PyObject *Py_UNUSED(ignored)) {
PyObject *PyPropertiesIteratorGet(PyPropertiesIterator *self, PyObject *Py_UNUSED(ignored)) {
MG_ASSERT(self->it);
MG_ASSERT(self->py_graph);
MG_ASSERT(self->py_graph->graph);
@ -810,8 +753,7 @@ PyObject *PyPropertiesIteratorGet(PyPropertiesIterator *self,
return PyTuple_Pack(2, py_name.Ptr(), py_value.Ptr());
}
PyObject *PyPropertiesIteratorNext(PyPropertiesIterator *self,
PyObject *Py_UNUSED(ignored)) {
PyObject *PyPropertiesIteratorNext(PyPropertiesIterator *self, PyObject *Py_UNUSED(ignored)) {
MG_ASSERT(self->it);
MG_ASSERT(self->py_graph);
MG_ASSERT(self->py_graph->graph);
@ -827,8 +769,8 @@ PyObject *PyPropertiesIteratorNext(PyPropertiesIterator *self,
static PyMethodDef PyPropertiesIteratorMethods[] = {
{"get", reinterpret_cast<PyCFunction>(PyPropertiesIteratorGet), METH_NOARGS,
"Get the current proprety pointed to by the iterator or return None."},
{"next", reinterpret_cast<PyCFunction>(PyPropertiesIteratorNext),
METH_NOARGS, "Advance the iterator to the next property and return it."},
{"next", reinterpret_cast<PyCFunction>(PyPropertiesIteratorNext), METH_NOARGS,
"Advance the iterator to the next property and return it."},
{nullptr},
};
@ -908,15 +850,12 @@ PyObject *PyEdgeIterProperties(PyEdge *self, PyObject *Py_UNUSED(ignored)) {
MG_ASSERT(self->edge);
MG_ASSERT(self->py_graph);
MG_ASSERT(self->py_graph->graph);
auto *properties_it =
mgp_edge_iter_properties(self->edge, self->py_graph->memory);
auto *properties_it = mgp_edge_iter_properties(self->edge, self->py_graph->memory);
if (!properties_it) {
PyErr_SetString(PyExc_MemoryError,
"Unable to allocate mgp_properties_iterator.");
PyErr_SetString(PyExc_MemoryError, "Unable to allocate mgp_properties_iterator.");
return nullptr;
}
auto *py_properties_it =
PyObject_New(PyPropertiesIterator, &PyPropertiesIteratorType);
auto *py_properties_it = PyObject_New(PyPropertiesIterator, &PyPropertiesIteratorType);
if (!py_properties_it) {
mgp_properties_iterator_destroy(properties_it);
return nullptr;
@ -934,11 +873,9 @@ PyObject *PyEdgeGetProperty(PyEdge *self, PyObject *args) {
MG_ASSERT(self->py_graph->graph);
const char *prop_name = nullptr;
if (!PyArg_ParseTuple(args, "s", &prop_name)) return nullptr;
auto *prop_value =
mgp_edge_get_property(self->edge, prop_name, self->py_graph->memory);
auto *prop_value = mgp_edge_get_property(self->edge, prop_name, self->py_graph->memory);
if (!prop_value) {
PyErr_SetString(PyExc_MemoryError,
"Unable to allocate mgp_value for edge property value.");
PyErr_SetString(PyExc_MemoryError, "Unable to allocate mgp_value for edge property value.");
return nullptr;
}
auto py_prop_value = MgpValueToPyObject(*prop_value, self->py_graph);
@ -947,22 +884,17 @@ PyObject *PyEdgeGetProperty(PyEdge *self, PyObject *args) {
}
static PyMethodDef PyEdgeMethods[] = {
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy),
METH_NOARGS, "__reduce__ is not supported."},
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported."},
{"is_valid", reinterpret_cast<PyCFunction>(PyEdgeIsValid), METH_NOARGS,
"Return True if Edge is in valid context and may be used."},
{"get_id", reinterpret_cast<PyCFunction>(PyEdgeGetId), METH_NOARGS,
"Return edge id."},
{"get_type_name", reinterpret_cast<PyCFunction>(PyEdgeGetTypeName),
METH_NOARGS, "Return the edge's type name."},
{"from_vertex", reinterpret_cast<PyCFunction>(PyEdgeFromVertex),
METH_NOARGS, "Return the edge's source vertex."},
{"to_vertex", reinterpret_cast<PyCFunction>(PyEdgeToVertex), METH_NOARGS,
"Return the edge's destination vertex."},
{"iter_properties", reinterpret_cast<PyCFunction>(PyEdgeIterProperties),
METH_NOARGS, "Return _mgp.PropertiesIterator for this edge."},
{"get_property", reinterpret_cast<PyCFunction>(PyEdgeGetProperty),
METH_VARARGS, "Return edge property with given name."},
{"get_id", reinterpret_cast<PyCFunction>(PyEdgeGetId), METH_NOARGS, "Return edge id."},
{"get_type_name", reinterpret_cast<PyCFunction>(PyEdgeGetTypeName), METH_NOARGS, "Return the edge's type name."},
{"from_vertex", reinterpret_cast<PyCFunction>(PyEdgeFromVertex), METH_NOARGS, "Return the edge's source vertex."},
{"to_vertex", reinterpret_cast<PyCFunction>(PyEdgeToVertex), METH_NOARGS, "Return the edge's destination vertex."},
{"iter_properties", reinterpret_cast<PyCFunction>(PyEdgeIterProperties), METH_NOARGS,
"Return _mgp.PropertiesIterator for this edge."},
{"get_property", reinterpret_cast<PyCFunction>(PyEdgeGetProperty), METH_VARARGS,
"Return edge property with given name."},
{nullptr},
};
@ -1008,8 +940,7 @@ PyObject *PyEdgeRichCompare(PyObject *self, PyObject *other, int op) {
MG_ASSERT(self);
MG_ASSERT(other);
if (Py_TYPE(self) != &PyEdgeType || Py_TYPE(other) != &PyEdgeType ||
op != Py_EQ) {
if (Py_TYPE(self) != &PyEdgeType || Py_TYPE(other) != &PyEdgeType || op != Py_EQ) {
Py_RETURN_NOTIMPLEMENTED;
}
@ -1065,14 +996,12 @@ PyObject *PyVertexLabelAt(PyVertex *self, PyObject *args) {
MG_ASSERT(self->vertex);
MG_ASSERT(self->py_graph);
MG_ASSERT(self->py_graph->graph);
static_assert(std::numeric_limits<Py_ssize_t>::max() <=
std::numeric_limits<size_t>::max());
static_assert(std::numeric_limits<Py_ssize_t>::max() <= std::numeric_limits<size_t>::max());
Py_ssize_t id;
if (!PyArg_ParseTuple(args, "n", &id)) return nullptr;
auto label = mgp_vertex_label_at(self->vertex, id);
if (label.name == nullptr || id < 0) {
PyErr_SetString(PyExc_IndexError,
"Unable to find the label with given ID.");
PyErr_SetString(PyExc_IndexError, "Unable to find the label with given ID.");
return nullptr;
}
return PyUnicode_FromString(label.name);
@ -1083,11 +1012,9 @@ PyObject *PyVertexIterInEdges(PyVertex *self, PyObject *Py_UNUSED(ignored)) {
MG_ASSERT(self->vertex);
MG_ASSERT(self->py_graph);
MG_ASSERT(self->py_graph->graph);
auto *edges_it =
mgp_vertex_iter_in_edges(self->vertex, self->py_graph->memory);
auto *edges_it = mgp_vertex_iter_in_edges(self->vertex, self->py_graph->memory);
if (!edges_it) {
PyErr_SetString(PyExc_MemoryError,
"Unable to allocate mgp_edges_iterator for in edges.");
PyErr_SetString(PyExc_MemoryError, "Unable to allocate mgp_edges_iterator for in edges.");
return nullptr;
}
auto *py_edges_it = PyObject_New(PyEdgesIterator, &PyEdgesIteratorType);
@ -1106,11 +1033,9 @@ PyObject *PyVertexIterOutEdges(PyVertex *self, PyObject *Py_UNUSED(ignored)) {
MG_ASSERT(self->vertex);
MG_ASSERT(self->py_graph);
MG_ASSERT(self->py_graph->graph);
auto *edges_it =
mgp_vertex_iter_out_edges(self->vertex, self->py_graph->memory);
auto *edges_it = mgp_vertex_iter_out_edges(self->vertex, self->py_graph->memory);
if (!edges_it) {
PyErr_SetString(PyExc_MemoryError,
"Unable to allocate mgp_edges_iterator for out edges.");
PyErr_SetString(PyExc_MemoryError, "Unable to allocate mgp_edges_iterator for out edges.");
return nullptr;
}
auto *py_edges_it = PyObject_New(PyEdgesIterator, &PyEdgesIteratorType);
@ -1129,15 +1054,12 @@ PyObject *PyVertexIterProperties(PyVertex *self, PyObject *Py_UNUSED(ignored)) {
MG_ASSERT(self->vertex);
MG_ASSERT(self->py_graph);
MG_ASSERT(self->py_graph->graph);
auto *properties_it =
mgp_vertex_iter_properties(self->vertex, self->py_graph->memory);
auto *properties_it = mgp_vertex_iter_properties(self->vertex, self->py_graph->memory);
if (!properties_it) {
PyErr_SetString(PyExc_MemoryError,
"Unable to allocate mgp_properties_iterator.");
PyErr_SetString(PyExc_MemoryError, "Unable to allocate mgp_properties_iterator.");
return nullptr;
}
auto *py_properties_it =
PyObject_New(PyPropertiesIterator, &PyPropertiesIteratorType);
auto *py_properties_it = PyObject_New(PyPropertiesIterator, &PyPropertiesIteratorType);
if (!py_properties_it) {
mgp_properties_iterator_destroy(properties_it);
return nullptr;
@ -1155,11 +1077,9 @@ PyObject *PyVertexGetProperty(PyVertex *self, PyObject *args) {
MG_ASSERT(self->py_graph->graph);
const char *prop_name = nullptr;
if (!PyArg_ParseTuple(args, "s", &prop_name)) return nullptr;
auto *prop_value =
mgp_vertex_get_property(self->vertex, prop_name, self->py_graph->memory);
auto *prop_value = mgp_vertex_get_property(self->vertex, prop_name, self->py_graph->memory);
if (!prop_value) {
PyErr_SetString(PyExc_MemoryError,
"Unable to allocate mgp_value for vertex property value.");
PyErr_SetString(PyExc_MemoryError, "Unable to allocate mgp_value for vertex property value.");
return nullptr;
}
auto py_prop_value = MgpValueToPyObject(*prop_value, self->py_graph);
@ -1168,24 +1088,22 @@ PyObject *PyVertexGetProperty(PyVertex *self, PyObject *args) {
}
static PyMethodDef PyVertexMethods[] = {
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy),
METH_NOARGS, "__reduce__ is not supported."},
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported."},
{"is_valid", reinterpret_cast<PyCFunction>(PyVertexIsValid), METH_NOARGS,
"Return True if Vertex is in valid context and may be used."},
{"get_id", reinterpret_cast<PyCFunction>(PyVertexGetId), METH_NOARGS,
"Return vertex id."},
{"labels_count", reinterpret_cast<PyCFunction>(PyVertexLabelsCount),
METH_NOARGS, "Return number of lables of a vertex."},
{"get_id", reinterpret_cast<PyCFunction>(PyVertexGetId), METH_NOARGS, "Return vertex id."},
{"labels_count", reinterpret_cast<PyCFunction>(PyVertexLabelsCount), METH_NOARGS,
"Return number of lables of a vertex."},
{"label_at", reinterpret_cast<PyCFunction>(PyVertexLabelAt), METH_VARARGS,
"Return label of a vertex on a given index."},
{"iter_in_edges", reinterpret_cast<PyCFunction>(PyVertexIterInEdges),
METH_NOARGS, "Return _mgp.EdgesIterator for in edges."},
{"iter_out_edges", reinterpret_cast<PyCFunction>(PyVertexIterOutEdges),
METH_NOARGS, "Return _mgp.EdgesIterator for out edges."},
{"iter_properties", reinterpret_cast<PyCFunction>(PyVertexIterProperties),
METH_NOARGS, "Return _mgp.PropertiesIterator for this vertex."},
{"get_property", reinterpret_cast<PyCFunction>(PyVertexGetProperty),
METH_VARARGS, "Return vertex property with given name."},
{"iter_in_edges", reinterpret_cast<PyCFunction>(PyVertexIterInEdges), METH_NOARGS,
"Return _mgp.EdgesIterator for in edges."},
{"iter_out_edges", reinterpret_cast<PyCFunction>(PyVertexIterOutEdges), METH_NOARGS,
"Return _mgp.EdgesIterator for out edges."},
{"iter_properties", reinterpret_cast<PyCFunction>(PyVertexIterProperties), METH_NOARGS,
"Return _mgp.PropertiesIterator for this vertex."},
{"get_property", reinterpret_cast<PyCFunction>(PyVertexGetProperty), METH_VARARGS,
"Return vertex property with given name."},
{nullptr},
};
@ -1234,8 +1152,7 @@ PyObject *PyVertexRichCompare(PyObject *self, PyObject *other, int op) {
MG_ASSERT(self);
MG_ASSERT(other);
if (Py_TYPE(self) != &PyVertexType || Py_TYPE(other) != &PyVertexType ||
op != Py_EQ) {
if (Py_TYPE(self) != &PyVertexType || Py_TYPE(other) != &PyVertexType || op != Py_EQ) {
Py_RETURN_NOTIMPLEMENTED;
}
@ -1283,12 +1200,9 @@ PyObject *PyPathExpand(PyPath *self, PyObject *edge) {
auto *py_edge = reinterpret_cast<PyEdge *>(edge);
const auto *to = mgp_edge_get_to(py_edge->edge);
const auto *from = mgp_edge_get_from(py_edge->edge);
const auto *last_vertex =
mgp_path_vertex_at(self->path, mgp_path_size(self->path));
if (!mgp_vertex_equal(last_vertex, to) &&
!mgp_vertex_equal(last_vertex, from)) {
PyErr_SetString(PyExc_ValueError,
"Edge is not a continuation of the path.");
const auto *last_vertex = mgp_path_vertex_at(self->path, mgp_path_size(self->path));
if (!mgp_vertex_equal(last_vertex, to) && !mgp_vertex_equal(last_vertex, from)) {
PyErr_SetString(PyExc_ValueError, "Edge is not a continuation of the path.");
return nullptr;
}
if (!mgp_path_expand(self->path, py_edge->edge)) {
@ -1309,8 +1223,7 @@ PyObject *PyPathVertexAt(PyPath *self, PyObject *args) {
MG_ASSERT(self->path);
MG_ASSERT(self->py_graph);
MG_ASSERT(self->py_graph->graph);
static_assert(std::numeric_limits<Py_ssize_t>::max() <=
std::numeric_limits<size_t>::max());
static_assert(std::numeric_limits<Py_ssize_t>::max() <= std::numeric_limits<size_t>::max());
Py_ssize_t i;
if (!PyArg_ParseTuple(args, "n", &i)) return nullptr;
const auto *vertex = mgp_path_vertex_at(self->path, i);
@ -1325,8 +1238,7 @@ PyObject *PyPathEdgeAt(PyPath *self, PyObject *args) {
MG_ASSERT(self->path);
MG_ASSERT(self->py_graph);
MG_ASSERT(self->py_graph->graph);
static_assert(std::numeric_limits<Py_ssize_t>::max() <=
std::numeric_limits<size_t>::max());
static_assert(std::numeric_limits<Py_ssize_t>::max() <= std::numeric_limits<size_t>::max());
Py_ssize_t i;
if (!PyArg_ParseTuple(args, "n", &i)) return nullptr;
const auto *edge = mgp_path_edge_at(self->path, i);
@ -1338,16 +1250,14 @@ PyObject *PyPathEdgeAt(PyPath *self, PyObject *args) {
}
static PyMethodDef PyPathMethods[] = {
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy),
METH_NOARGS, "__reduce__ is not supported"},
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
{"is_valid", reinterpret_cast<PyCFunction>(PyPathIsValid), METH_NOARGS,
"Return True if Path is in valid context and may be used."},
{"make_with_start", reinterpret_cast<PyCFunction>(PyPathMakeWithStart),
METH_O | METH_CLASS, "Create a path with a starting vertex."},
{"make_with_start", reinterpret_cast<PyCFunction>(PyPathMakeWithStart), METH_O | METH_CLASS,
"Create a path with a starting vertex."},
{"expand", reinterpret_cast<PyCFunction>(PyPathExpand), METH_O,
"Append an edge continuing from the last vertex on the path."},
{"size", reinterpret_cast<PyCFunction>(PyPathSize), METH_NOARGS,
"Return the number of edges in a mgp_path."},
{"size", reinterpret_cast<PyCFunction>(PyPathSize), METH_NOARGS, "Return the number of edges in a mgp_path."},
{"vertex_at", reinterpret_cast<PyCFunction>(PyPathVertexAt), METH_VARARGS,
"Return the vertex from a path at given index."},
{"edge_at", reinterpret_cast<PyCFunction>(PyPathEdgeAt), METH_VARARGS,
@ -1394,18 +1304,15 @@ PyObject *MakePyPath(const mgp_path &path, PyGraph *py_graph) {
PyObject *PyPathMakeWithStart(PyTypeObject *type, PyObject *vertex) {
if (type != &PyPathType) {
PyErr_SetString(PyExc_TypeError,
"Expected '<class _mgp.Path>' as the first argument.");
PyErr_SetString(PyExc_TypeError, "Expected '<class _mgp.Path>' as the first argument.");
return nullptr;
}
if (Py_TYPE(vertex) != &PyVertexType) {
PyErr_SetString(PyExc_TypeError,
"Expected a '_mgp.Vertex' as the second argument.");
PyErr_SetString(PyExc_TypeError, "Expected a '_mgp.Vertex' as the second argument.");
return nullptr;
}
auto *py_vertex = reinterpret_cast<PyVertex *>(vertex);
auto *path =
mgp_path_make_with_start(py_vertex->vertex, py_vertex->py_graph->memory);
auto *path = mgp_path_make_with_start(py_vertex->vertex, py_vertex->py_graph->memory);
if (!path) {
PyErr_SetString(PyExc_MemoryError, "Unable to allocate mgp_path.");
return nullptr;
@ -1431,10 +1338,8 @@ PyObject *PyInitMgpModule() {
}
return true;
};
if (!register_type(&PyPropertiesIteratorType, "PropertiesIterator"))
return nullptr;
if (!register_type(&PyVerticesIteratorType, "VerticesIterator"))
return nullptr;
if (!register_type(&PyPropertiesIteratorType, "PropertiesIterator")) return nullptr;
if (!register_type(&PyVerticesIteratorType, "VerticesIterator")) return nullptr;
if (!register_type(&PyEdgesIteratorType, "EdgesIterator")) return nullptr;
if (!register_type(&PyGraphType, "Graph")) return nullptr;
if (!register_type(&PyEdgeType, "Edge")) return nullptr;
@ -1481,14 +1386,11 @@ auto WithMgpModule(mgp_module *module_def, const TFun &fun) {
} // namespace
py::Object ImportPyModule(const char *name, mgp_module *module_def) {
return WithMgpModule(
module_def, [name]() { return py::Object(PyImport_ImportModule(name)); });
return WithMgpModule(module_def, [name]() { return py::Object(PyImport_ImportModule(name)); });
}
py::Object ReloadPyModule(PyObject *py_module, mgp_module *module_def) {
return WithMgpModule(module_def, [py_module]() {
return py::Object(PyImport_ReloadModule(py_module));
});
return WithMgpModule(module_def, [py_module]() { return py::Object(PyImport_ReloadModule(py_module)); });
}
py::Object MgpValueToPyObject(const mgp_value &value, PyObject *py_graph) {
@ -1522,8 +1424,7 @@ py::Object MgpValueToPyObject(const mgp_value &value, PyGraph *py_graph) {
auto py_val = MgpValueToPyObject(val, py_graph);
if (!py_val) return nullptr;
// Unlike PyList_SET_ITEM, PyDict_SetItem does not steal the value.
if (PyDict_SetItemString(py_dict.Ptr(), key.c_str(), py_val.Ptr()) != 0)
return nullptr;
if (PyDict_SetItemString(py_dict.Ptr(), key.c_str(), py_val.Ptr()) != 0) return nullptr;
}
return py_dict;
}
@ -1531,34 +1432,29 @@ py::Object MgpValueToPyObject(const mgp_value &value, PyGraph *py_graph) {
py::Object py_mgp(PyImport_ImportModule("mgp"));
if (!py_mgp) return nullptr;
const auto *v = mgp_value_get_vertex(&value);
py::Object py_vertex(
reinterpret_cast<PyObject *>(MakePyVertex(*v, py_graph)));
py::Object py_vertex(reinterpret_cast<PyObject *>(MakePyVertex(*v, py_graph)));
return py_mgp.CallMethod("Vertex", py_vertex);
}
case MGP_VALUE_TYPE_EDGE: {
py::Object py_mgp(PyImport_ImportModule("mgp"));
if (!py_mgp) return nullptr;
const auto *e = mgp_value_get_edge(&value);
py::Object py_edge(
reinterpret_cast<PyObject *>(MakePyEdge(*e, py_graph)));
py::Object py_edge(reinterpret_cast<PyObject *>(MakePyEdge(*e, py_graph)));
return py_mgp.CallMethod("Edge", py_edge);
}
case MGP_VALUE_TYPE_PATH: {
py::Object py_mgp(PyImport_ImportModule("mgp"));
if (!py_mgp) return nullptr;
const auto *p = mgp_value_get_path(&value);
py::Object py_path(
reinterpret_cast<PyObject *>(MakePyPath(*p, py_graph)));
py::Object py_path(reinterpret_cast<PyObject *>(MakePyPath(*p, py_graph)));
return py_mgp.CallMethod("Path", py_path);
}
}
}
mgp_value *PyObjectToMgpValue(PyObject *o, mgp_memory *memory) {
auto py_seq_to_list = [memory](PyObject *seq, Py_ssize_t len,
const auto &py_seq_get_item) {
static_assert(std::numeric_limits<Py_ssize_t>::max() <=
std::numeric_limits<size_t>::max());
auto py_seq_to_list = [memory](PyObject *seq, Py_ssize_t len, const auto &py_seq_get_item) {
static_assert(std::numeric_limits<Py_ssize_t>::max() <= std::numeric_limits<size_t>::max());
mgp_list *list = mgp_list_make_empty(len, memory);
if (!list) throw std::bad_alloc();
for (Py_ssize_t i = 0; i < len; ++i) {
@ -1603,8 +1499,7 @@ mgp_value *PyObjectToMgpValue(PyObject *o, mgp_memory *memory) {
if (res == -1) {
PyErr_Clear();
std::stringstream ss;
ss << "Error when checking object is instance of 'mgp." << mgp_type_name
<< "' type";
ss << "Error when checking object is instance of 'mgp." << mgp_type_name << "' type";
throw std::invalid_argument(ss.str());
}
return static_cast<bool>(res);
@ -1628,13 +1523,9 @@ mgp_value *PyObjectToMgpValue(PyObject *o, mgp_memory *memory) {
} else if (PyUnicode_Check(o)) {
mgp_v = mgp_value_make_string(PyUnicode_AsUTF8(o), memory);
} else if (PyList_Check(o)) {
mgp_v = py_seq_to_list(o, PyList_Size(o), [](auto *list, const auto i) {
return PyList_GET_ITEM(list, i);
});
mgp_v = py_seq_to_list(o, PyList_Size(o), [](auto *list, const auto i) { return PyList_GET_ITEM(list, i); });
} else if (PyTuple_Check(o)) {
mgp_v = py_seq_to_list(o, PyTuple_Size(o), [](auto *tuple, const auto i) {
return PyTuple_GET_ITEM(tuple, i);
});
mgp_v = py_seq_to_list(o, PyTuple_Size(o), [](auto *tuple, const auto i) { return PyTuple_GET_ITEM(tuple, i); });
} else if (PyDict_Check(o)) {
mgp_map *map = mgp_map_make_empty(memory);
@ -1725,8 +1616,7 @@ mgp_value *PyObjectToMgpValue(PyObject *o, mgp_memory *memory) {
py::Object vertex(PyObject_GetAttrString(o, "_vertex"));
if (!vertex) {
PyErr_Clear();
throw std::invalid_argument(
"'mgp.Vertex' is missing '_vertex' attribute");
throw std::invalid_argument("'mgp.Vertex' is missing '_vertex' attribute");
}
return PyObjectToMgpValue(vertex.Ptr(), memory);
} else if (is_mgp_instance(o, "Path")) {

View File

@ -22,17 +22,15 @@ class AnyStream final {
public:
template <class TStream>
AnyStream(TStream *stream, utils::MemoryResource *memory_resource)
: content_{utils::Allocator<GenericWrapper<TStream>>{memory_resource}
.template new_object<GenericWrapper<TStream>>(stream),
[memory_resource](Wrapper *ptr) {
utils::Allocator<GenericWrapper<TStream>>{memory_resource}
.template delete_object<GenericWrapper<TStream>>(
static_cast<GenericWrapper<TStream> *>(ptr));
}} {}
: content_{
utils::Allocator<GenericWrapper<TStream>>{memory_resource}.template new_object<GenericWrapper<TStream>>(
stream),
[memory_resource](Wrapper *ptr) {
utils::Allocator<GenericWrapper<TStream>>{memory_resource}
.template delete_object<GenericWrapper<TStream>>(static_cast<GenericWrapper<TStream> *>(ptr));
}} {}
void Result(const std::vector<TypedValue> &values) {
content_->Result(values);
}
void Result(const std::vector<TypedValue> &values) { content_->Result(values); }
private:
struct Wrapper {
@ -43,9 +41,7 @@ class AnyStream final {
struct GenericWrapper final : public Wrapper {
explicit GenericWrapper(TStream *stream) : stream_{stream} {}
void Result(const std::vector<TypedValue> &values) override {
stream_->Result(values);
}
void Result(const std::vector<TypedValue> &values) override { stream_->Result(values); }
TStream *stream_;
};

Some files were not shown because too many files have changed in this diff Show More