Format all the memgraph and test source files (#97)
This commit is contained in:
parent
435af8b833
commit
3f3c55a4aa
@ -38,8 +38,7 @@ inline nlohmann::json PropertyValueToJson(const storage::PropertyValue &pv) {
|
||||
case storage::PropertyValue::Type::Map: {
|
||||
ret = nlohmann::json::object();
|
||||
for (const auto &item : pv.ValueMap()) {
|
||||
ret.push_back(nlohmann::json::object_t::value_type(
|
||||
item.first, PropertyValueToJson(item.second)));
|
||||
ret.push_back(nlohmann::json::object_t::value_type(item.first, PropertyValueToJson(item.second)));
|
||||
}
|
||||
break;
|
||||
}
|
||||
@ -47,8 +46,7 @@ inline nlohmann::json PropertyValueToJson(const storage::PropertyValue &pv) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
Log::Log(const std::filesystem::path &storage_directory, int32_t buffer_size,
|
||||
int32_t buffer_flush_interval_millis)
|
||||
Log::Log(const std::filesystem::path &storage_directory, int32_t buffer_size, int32_t buffer_flush_interval_millis)
|
||||
: storage_directory_(storage_directory),
|
||||
buffer_size_(buffer_size),
|
||||
buffer_flush_interval_millis_(buffer_flush_interval_millis),
|
||||
@ -63,9 +61,7 @@ void Log::Start() {
|
||||
started_ = true;
|
||||
|
||||
ReopenLog();
|
||||
scheduler_.Run("Audit",
|
||||
std::chrono::milliseconds(buffer_flush_interval_millis_),
|
||||
[&] { Flush(); });
|
||||
scheduler_.Run("Audit", std::chrono::milliseconds(buffer_flush_interval_millis_), [&] { Flush(); });
|
||||
}
|
||||
|
||||
Log::~Log() {
|
||||
@ -78,12 +74,11 @@ Log::~Log() {
|
||||
Flush();
|
||||
}
|
||||
|
||||
void Log::Record(const std::string &address, const std::string &username,
|
||||
const std::string &query,
|
||||
void Log::Record(const std::string &address, const std::string &username, const std::string &query,
|
||||
const storage::PropertyValue ¶ms) {
|
||||
if (!started_.load(std::memory_order_relaxed)) return;
|
||||
auto timestamp = std::chrono::duration_cast<std::chrono::microseconds>(
|
||||
std::chrono::system_clock::now().time_since_epoch())
|
||||
auto timestamp =
|
||||
std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch())
|
||||
.count();
|
||||
buffer_->emplace(Item{timestamp, address, username, query, params});
|
||||
}
|
||||
@ -92,8 +87,7 @@ void Log::ReopenLog() {
|
||||
if (!started_.load(std::memory_order_relaxed)) return;
|
||||
std::lock_guard<std::mutex> guard(lock_);
|
||||
if (log_.IsOpen()) log_.Close();
|
||||
log_.Open(storage_directory_ / "audit.log",
|
||||
utils::OutputFile::Mode::APPEND_TO_EXISTING);
|
||||
log_.Open(storage_directory_ / "audit.log", utils::OutputFile::Mode::APPEND_TO_EXISTING);
|
||||
}
|
||||
|
||||
void Log::Flush() {
|
||||
@ -101,10 +95,8 @@ void Log::Flush() {
|
||||
for (uint64_t i = 0; i < buffer_size_; ++i) {
|
||||
auto item = buffer_->pop();
|
||||
if (!item) break;
|
||||
log_.Write(
|
||||
fmt::format("{}.{:06d},{},{},{},{}\n", item->timestamp / 1000000,
|
||||
item->timestamp % 1000000, item->address, item->username,
|
||||
utils::Escape(item->query),
|
||||
log_.Write(fmt::format("{}.{:06d},{},{},{},{}\n", item->timestamp / 1000000, item->timestamp % 1000000,
|
||||
item->address, item->username, utils::Escape(item->query),
|
||||
utils::Escape(PropertyValueToJson(item->params).dump())));
|
||||
}
|
||||
log_.Sync();
|
||||
|
@ -27,8 +27,7 @@ class Log {
|
||||
};
|
||||
|
||||
public:
|
||||
Log(const std::filesystem::path &storage_directory, int32_t buffer_size,
|
||||
int32_t buffer_flush_interval_millis);
|
||||
Log(const std::filesystem::path &storage_directory, int32_t buffer_size, int32_t buffer_flush_interval_millis);
|
||||
|
||||
~Log();
|
||||
|
||||
@ -43,8 +42,8 @@ class Log {
|
||||
void Start();
|
||||
|
||||
/// Adds an entry to the audit log. Thread-safe.
|
||||
void Record(const std::string &address, const std::string &username,
|
||||
const std::string &query, const storage::PropertyValue ¶ms);
|
||||
void Record(const std::string &address, const std::string &username, const std::string &query,
|
||||
const storage::PropertyValue ¶ms);
|
||||
|
||||
/// Reopens the log file. Used for log file rotation. Thread-safe.
|
||||
void ReopenLog();
|
||||
|
@ -12,26 +12,20 @@
|
||||
#include "utils/logging.hpp"
|
||||
#include "utils/string.hpp"
|
||||
|
||||
DEFINE_VALIDATED_string(
|
||||
auth_module_executable, "",
|
||||
"Absolute path to the auth module executable that should be used.", {
|
||||
DEFINE_VALIDATED_string(auth_module_executable, "", "Absolute path to the auth module executable that should be used.",
|
||||
{
|
||||
if (value.empty()) return true;
|
||||
// Check the file status, following symlinks.
|
||||
auto status = std::filesystem::status(value);
|
||||
if (!std::filesystem::is_regular_file(status)) {
|
||||
std::cerr << "The auth module path doesn't exist or isn't a file!"
|
||||
<< std::endl;
|
||||
std::cerr << "The auth module path doesn't exist or isn't a file!" << std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
DEFINE_bool(auth_module_create_missing_user, true,
|
||||
"Set to false to disable creation of missing users.");
|
||||
DEFINE_bool(auth_module_create_missing_role, true,
|
||||
"Set to false to disable creation of missing roles.");
|
||||
DEFINE_bool(
|
||||
auth_module_manage_roles, true,
|
||||
"Set to false to disable management of roles through the auth module.");
|
||||
DEFINE_bool(auth_module_create_missing_user, true, "Set to false to disable creation of missing users.");
|
||||
DEFINE_bool(auth_module_create_missing_role, true, "Set to false to disable creation of missing roles.");
|
||||
DEFINE_bool(auth_module_manage_roles, true, "Set to false to disable management of roles through the auth module.");
|
||||
DEFINE_VALIDATED_int32(auth_module_timeout_ms, 10000,
|
||||
"Timeout (in milliseconds) used when waiting for a "
|
||||
"response from the auth module.",
|
||||
@ -60,11 +54,9 @@ const std::string kLinkPrefix = "link:";
|
||||
* key="link:<username>", value="<rolename>"
|
||||
*/
|
||||
|
||||
Auth::Auth(const std::string &storage_directory)
|
||||
: storage_(storage_directory), module_(FLAGS_auth_module_executable) {}
|
||||
Auth::Auth(const std::string &storage_directory) : storage_(storage_directory), module_(FLAGS_auth_module_executable) {}
|
||||
|
||||
std::optional<User> Auth::Authenticate(const std::string &username,
|
||||
const std::string &password) {
|
||||
std::optional<User> Auth::Authenticate(const std::string &username, const std::string &password) {
|
||||
if (module_.IsUsed()) {
|
||||
nlohmann::json params = nlohmann::json::object();
|
||||
params["username"] = username;
|
||||
@ -73,8 +65,7 @@ std::optional<User> Auth::Authenticate(const std::string &username,
|
||||
auto ret = module_.Call(params, FLAGS_auth_module_timeout_ms);
|
||||
|
||||
// Verify response integrity.
|
||||
if (!ret.is_object() || ret.find("authenticated") == ret.end() ||
|
||||
ret.find("role") == ret.end()) {
|
||||
if (!ret.is_object() || ret.find("authenticated") == ret.end() || ret.find("role") == ret.end()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
const auto &ret_authenticated = ret.at("authenticated");
|
||||
@ -142,9 +133,7 @@ std::optional<User> Auth::Authenticate(const std::string &username,
|
||||
} else {
|
||||
auto user = GetUser(username);
|
||||
if (!user) {
|
||||
spdlog::warn(
|
||||
"Couldn't authenticate user '{}' because the user doesn't exist",
|
||||
username);
|
||||
spdlog::warn("Couldn't authenticate user '{}' because the user doesn't exist", username);
|
||||
return std::nullopt;
|
||||
}
|
||||
if (!user->CheckPassword(password)) {
|
||||
@ -182,12 +171,10 @@ std::optional<User> Auth::GetUser(const std::string &username_orig) {
|
||||
void Auth::SaveUser(const User &user) {
|
||||
bool success = false;
|
||||
if (user.role()) {
|
||||
success = storage_.PutMultiple(
|
||||
{{kUserPrefix + user.username(), user.Serialize().dump()},
|
||||
success = storage_.PutMultiple({{kUserPrefix + user.username(), user.Serialize().dump()},
|
||||
{kLinkPrefix + user.username(), user.role()->rolename()}});
|
||||
} else {
|
||||
success = storage_.PutAndDeleteMultiple(
|
||||
{{kUserPrefix + user.username(), user.Serialize().dump()}},
|
||||
success = storage_.PutAndDeleteMultiple({{kUserPrefix + user.username(), user.Serialize().dump()}},
|
||||
{kLinkPrefix + user.username()});
|
||||
}
|
||||
if (!success) {
|
||||
@ -195,8 +182,7 @@ void Auth::SaveUser(const User &user) {
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<User> Auth::AddUser(const std::string &username,
|
||||
const std::optional<std::string> &password) {
|
||||
std::optional<User> Auth::AddUser(const std::string &username, const std::optional<std::string> &password) {
|
||||
auto existing_user = GetUser(username);
|
||||
if (existing_user) return std::nullopt;
|
||||
auto existing_role = GetRole(username);
|
||||
@ -210,8 +196,7 @@ std::optional<User> Auth::AddUser(const std::string &username,
|
||||
bool Auth::RemoveUser(const std::string &username_orig) {
|
||||
auto username = utils::ToLowerCase(username_orig);
|
||||
if (!storage_.Get(kUserPrefix + username)) return false;
|
||||
std::vector<std::string> keys(
|
||||
{kLinkPrefix + username, kUserPrefix + username});
|
||||
std::vector<std::string> keys({kLinkPrefix + username, kUserPrefix + username});
|
||||
if (!storage_.DeleteMultiple(keys)) {
|
||||
throw AuthException("Couldn't remove user '{}'!", username);
|
||||
}
|
||||
@ -220,8 +205,7 @@ bool Auth::RemoveUser(const std::string &username_orig) {
|
||||
|
||||
std::vector<auth::User> Auth::AllUsers() {
|
||||
std::vector<auth::User> ret;
|
||||
for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix);
|
||||
++it) {
|
||||
for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) {
|
||||
auto username = it->first.substr(kUserPrefix.size());
|
||||
if (username != utils::ToLowerCase(username)) continue;
|
||||
auto user = GetUser(username);
|
||||
@ -232,9 +216,7 @@ std::vector<auth::User> Auth::AllUsers() {
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool Auth::HasUsers() {
|
||||
return storage_.begin(kUserPrefix) != storage_.end(kUserPrefix);
|
||||
}
|
||||
bool Auth::HasUsers() { return storage_.begin(kUserPrefix) != storage_.end(kUserPrefix); }
|
||||
|
||||
std::optional<Role> Auth::GetRole(const std::string &rolename_orig) {
|
||||
auto rolename = utils::ToLowerCase(rolename_orig);
|
||||
@ -271,8 +253,7 @@ bool Auth::RemoveRole(const std::string &rolename_orig) {
|
||||
auto rolename = utils::ToLowerCase(rolename_orig);
|
||||
if (!storage_.Get(kRolePrefix + rolename)) return false;
|
||||
std::vector<std::string> keys;
|
||||
for (auto it = storage_.begin(kLinkPrefix); it != storage_.end(kLinkPrefix);
|
||||
++it) {
|
||||
for (auto it = storage_.begin(kLinkPrefix); it != storage_.end(kLinkPrefix); ++it) {
|
||||
if (utils::ToLowerCase(it->second) == rolename) {
|
||||
keys.push_back(it->first);
|
||||
}
|
||||
@ -286,8 +267,7 @@ bool Auth::RemoveRole(const std::string &rolename_orig) {
|
||||
|
||||
std::vector<auth::Role> Auth::AllRoles() {
|
||||
std::vector<auth::Role> ret;
|
||||
for (auto it = storage_.begin(kRolePrefix); it != storage_.end(kRolePrefix);
|
||||
++it) {
|
||||
for (auto it = storage_.begin(kRolePrefix); it != storage_.end(kRolePrefix); ++it) {
|
||||
auto rolename = it->first.substr(kRolePrefix.size());
|
||||
if (rolename != utils::ToLowerCase(rolename)) continue;
|
||||
auto role = GetRole(rolename);
|
||||
@ -300,12 +280,10 @@ std::vector<auth::Role> Auth::AllRoles() {
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<auth::User> Auth::AllUsersForRole(
|
||||
const std::string &rolename_orig) {
|
||||
std::vector<auth::User> Auth::AllUsersForRole(const std::string &rolename_orig) {
|
||||
auto rolename = utils::ToLowerCase(rolename_orig);
|
||||
std::vector<auth::User> ret;
|
||||
for (auto it = storage_.begin(kLinkPrefix); it != storage_.end(kLinkPrefix);
|
||||
++it) {
|
||||
for (auto it = storage_.begin(kLinkPrefix); it != storage_.end(kLinkPrefix); ++it) {
|
||||
auto username = it->first.substr(kLinkPrefix.size());
|
||||
if (username != utils::ToLowerCase(username)) continue;
|
||||
if (it->second != utils::ToLowerCase(it->second)) continue;
|
||||
|
@ -32,8 +32,7 @@ class Auth final {
|
||||
* @return a user when the username and password match, nullopt otherwise
|
||||
* @throw AuthException if unable to authenticate for whatever reason.
|
||||
*/
|
||||
std::optional<User> Authenticate(const std::string &username,
|
||||
const std::string &password);
|
||||
std::optional<User> Authenticate(const std::string &username, const std::string &password);
|
||||
|
||||
/**
|
||||
* Gets a user from the storage.
|
||||
@ -63,9 +62,7 @@ class Auth final {
|
||||
* @return a user when the user is created, nullopt if the user exists
|
||||
* @throw AuthException if unable to save the user.
|
||||
*/
|
||||
std::optional<User> AddUser(
|
||||
const std::string &username,
|
||||
const std::optional<std::string> &password = std::nullopt);
|
||||
std::optional<User> AddUser(const std::string &username, const std::optional<std::string> &password = std::nullopt);
|
||||
|
||||
/**
|
||||
* Removes a user from the storage.
|
||||
|
@ -9,8 +9,7 @@
|
||||
#include "utils/cast.hpp"
|
||||
#include "utils/string.hpp"
|
||||
|
||||
DEFINE_bool(auth_password_permit_null, true,
|
||||
"Set to false to disable null passwords.");
|
||||
DEFINE_bool(auth_password_permit_null, true, "Set to false to disable null passwords.");
|
||||
|
||||
DEFINE_string(auth_password_strength_regex, ".+",
|
||||
"The regular expression that should be used to match the entire "
|
||||
@ -129,8 +128,7 @@ Permissions Permissions::Deserialize(const nlohmann::json &data) {
|
||||
if (!data.is_object()) {
|
||||
throw AuthException("Couldn't load permissions data!");
|
||||
}
|
||||
if (!data["grants"].is_number_unsigned() ||
|
||||
!data["denies"].is_number_unsigned()) {
|
||||
if (!data["grants"].is_number_unsigned() || !data["denies"].is_number_unsigned()) {
|
||||
throw AuthException("Couldn't load permissions data!");
|
||||
}
|
||||
return {data["grants"], data["denies"]};
|
||||
@ -143,12 +141,9 @@ bool operator==(const Permissions &first, const Permissions &second) {
|
||||
return first.grants() == second.grants() && first.denies() == second.denies();
|
||||
}
|
||||
|
||||
bool operator!=(const Permissions &first, const Permissions &second) {
|
||||
return !(first == second);
|
||||
}
|
||||
bool operator!=(const Permissions &first, const Permissions &second) { return !(first == second); }
|
||||
|
||||
Role::Role(const std::string &rolename)
|
||||
: rolename_(utils::ToLowerCase(rolename)) {}
|
||||
Role::Role(const std::string &rolename) : rolename_(utils::ToLowerCase(rolename)) {}
|
||||
|
||||
Role::Role(const std::string &rolename, const Permissions &permissions)
|
||||
: rolename_(utils::ToLowerCase(rolename)), permissions_(permissions) {}
|
||||
@ -176,18 +171,13 @@ Role Role::Deserialize(const nlohmann::json &data) {
|
||||
}
|
||||
|
||||
bool operator==(const Role &first, const Role &second) {
|
||||
return first.rolename_ == second.rolename_ &&
|
||||
first.permissions_ == second.permissions_;
|
||||
return first.rolename_ == second.rolename_ && first.permissions_ == second.permissions_;
|
||||
}
|
||||
|
||||
User::User(const std::string &username)
|
||||
: username_(utils::ToLowerCase(username)) {}
|
||||
User::User(const std::string &username) : username_(utils::ToLowerCase(username)) {}
|
||||
|
||||
User::User(const std::string &username, const std::string &password_hash,
|
||||
const Permissions &permissions)
|
||||
: username_(utils::ToLowerCase(username)),
|
||||
password_hash_(password_hash),
|
||||
permissions_(permissions) {}
|
||||
User::User(const std::string &username, const std::string &password_hash, const Permissions &permissions)
|
||||
: username_(utils::ToLowerCase(username)), password_hash_(password_hash), permissions_(permissions) {}
|
||||
|
||||
bool User::CheckPassword(const std::string &password) {
|
||||
if (password_hash_ == "") return true;
|
||||
@ -244,8 +234,7 @@ User User::Deserialize(const nlohmann::json &data) {
|
||||
if (!data.is_object()) {
|
||||
throw AuthException("Couldn't load user data!");
|
||||
}
|
||||
if (!data["username"].is_string() || !data["password_hash"].is_string() ||
|
||||
!data["permissions"].is_object()) {
|
||||
if (!data["username"].is_string() || !data["password_hash"].is_string() || !data["permissions"].is_object()) {
|
||||
throw AuthException("Couldn't load user data!");
|
||||
}
|
||||
auto permissions = Permissions::Deserialize(data["permissions"]);
|
||||
@ -253,9 +242,7 @@ User User::Deserialize(const nlohmann::json &data) {
|
||||
}
|
||||
|
||||
bool operator==(const User &first, const User &second) {
|
||||
return first.username_ == second.username_ &&
|
||||
first.password_hash_ == second.password_hash_ &&
|
||||
first.permissions_ == second.permissions_ &&
|
||||
first.role_ == second.role_;
|
||||
return first.username_ == second.username_ && first.password_hash_ == second.password_hash_ &&
|
||||
first.permissions_ == second.permissions_ && first.role_ == second.role_;
|
||||
}
|
||||
} // namespace auth
|
||||
|
@ -29,11 +29,9 @@ enum class Permission : uint64_t {
|
||||
|
||||
// Constant list of all available permissions.
|
||||
const std::vector<Permission> kPermissionsAll = {
|
||||
Permission::MATCH, Permission::CREATE, Permission::MERGE,
|
||||
Permission::DELETE, Permission::SET, Permission::REMOVE,
|
||||
Permission::INDEX, Permission::STATS, Permission::CONSTRAINT,
|
||||
Permission::DUMP, Permission::AUTH, Permission::REPLICATION,
|
||||
Permission::LOCK_PATH};
|
||||
Permission::MATCH, Permission::CREATE, Permission::MERGE, Permission::DELETE, Permission::SET,
|
||||
Permission::REMOVE, Permission::INDEX, Permission::STATS, Permission::CONSTRAINT, Permission::DUMP,
|
||||
Permission::AUTH, Permission::REPLICATION, Permission::LOCK_PATH};
|
||||
|
||||
// Function that converts a permission to its string representation.
|
||||
std::string PermissionToString(Permission permission);
|
||||
@ -110,15 +108,13 @@ class User final {
|
||||
public:
|
||||
User(const std::string &username);
|
||||
|
||||
User(const std::string &username, const std::string &password_hash,
|
||||
const Permissions &permissions);
|
||||
User(const std::string &username, const std::string &password_hash, const Permissions &permissions);
|
||||
|
||||
/// @throw AuthException if unable to verify the password.
|
||||
bool CheckPassword(const std::string &password);
|
||||
|
||||
/// @throw AuthException if unable to set the password.
|
||||
void UpdatePassword(
|
||||
const std::optional<std::string> &password = std::nullopt);
|
||||
void UpdatePassword(const std::optional<std::string> &password = std::nullopt);
|
||||
|
||||
void SetRole(const Role &role);
|
||||
|
||||
|
@ -86,54 +86,22 @@ class CharPP final {
|
||||
////////////////////////////////////
|
||||
|
||||
const std::vector<int> kSeccompSyscallsBlacklist = {
|
||||
SCMP_SYS(mknod),
|
||||
SCMP_SYS(mount),
|
||||
SCMP_SYS(setuid),
|
||||
SCMP_SYS(stime),
|
||||
SCMP_SYS(ptrace),
|
||||
SCMP_SYS(setgid),
|
||||
SCMP_SYS(acct),
|
||||
SCMP_SYS(umount),
|
||||
SCMP_SYS(setpgid),
|
||||
SCMP_SYS(chroot),
|
||||
SCMP_SYS(setreuid),
|
||||
SCMP_SYS(setregid),
|
||||
SCMP_SYS(sethostname),
|
||||
SCMP_SYS(settimeofday),
|
||||
SCMP_SYS(setgroups),
|
||||
SCMP_SYS(swapon),
|
||||
SCMP_SYS(reboot),
|
||||
SCMP_SYS(setpriority),
|
||||
SCMP_SYS(ioperm),
|
||||
SCMP_SYS(syslog),
|
||||
SCMP_SYS(iopl),
|
||||
SCMP_SYS(vhangup),
|
||||
SCMP_SYS(vm86old),
|
||||
SCMP_SYS(swapoff),
|
||||
SCMP_SYS(setdomainname),
|
||||
SCMP_SYS(adjtimex),
|
||||
SCMP_SYS(init_module),
|
||||
SCMP_SYS(delete_module),
|
||||
SCMP_SYS(setfsuid),
|
||||
SCMP_SYS(setfsgid),
|
||||
SCMP_SYS(setresuid),
|
||||
SCMP_SYS(vm86),
|
||||
SCMP_SYS(setresgid),
|
||||
SCMP_SYS(capset),
|
||||
SCMP_SYS(setreuid),
|
||||
SCMP_SYS(setregid),
|
||||
SCMP_SYS(setgroups),
|
||||
SCMP_SYS(setresuid),
|
||||
SCMP_SYS(setresgid),
|
||||
SCMP_SYS(setuid),
|
||||
SCMP_SYS(setgid),
|
||||
SCMP_SYS(setfsuid),
|
||||
SCMP_SYS(setfsgid),
|
||||
SCMP_SYS(pivot_root),
|
||||
SCMP_SYS(sched_setaffinity),
|
||||
SCMP_SYS(clock_settime),
|
||||
SCMP_SYS(kexec_load),
|
||||
SCMP_SYS(mknodat),
|
||||
SCMP_SYS(mknod), SCMP_SYS(mount), SCMP_SYS(setuid),
|
||||
SCMP_SYS(stime), SCMP_SYS(ptrace), SCMP_SYS(setgid),
|
||||
SCMP_SYS(acct), SCMP_SYS(umount), SCMP_SYS(setpgid),
|
||||
SCMP_SYS(chroot), SCMP_SYS(setreuid), SCMP_SYS(setregid),
|
||||
SCMP_SYS(sethostname), SCMP_SYS(settimeofday), SCMP_SYS(setgroups),
|
||||
SCMP_SYS(swapon), SCMP_SYS(reboot), SCMP_SYS(setpriority),
|
||||
SCMP_SYS(ioperm), SCMP_SYS(syslog), SCMP_SYS(iopl),
|
||||
SCMP_SYS(vhangup), SCMP_SYS(vm86old), SCMP_SYS(swapoff),
|
||||
SCMP_SYS(setdomainname), SCMP_SYS(adjtimex), SCMP_SYS(init_module),
|
||||
SCMP_SYS(delete_module), SCMP_SYS(setfsuid), SCMP_SYS(setfsgid),
|
||||
SCMP_SYS(setresuid), SCMP_SYS(vm86), SCMP_SYS(setresgid),
|
||||
SCMP_SYS(capset), SCMP_SYS(setreuid), SCMP_SYS(setregid),
|
||||
SCMP_SYS(setgroups), SCMP_SYS(setresuid), SCMP_SYS(setresgid),
|
||||
SCMP_SYS(setuid), SCMP_SYS(setgid), SCMP_SYS(setfsuid),
|
||||
SCMP_SYS(setfsgid), SCMP_SYS(pivot_root), SCMP_SYS(sched_setaffinity),
|
||||
SCMP_SYS(clock_settime), SCMP_SYS(kexec_load), SCMP_SYS(mknodat),
|
||||
SCMP_SYS(unshare),
|
||||
#ifdef SYS_seccomp
|
||||
SCMP_SYS(seccomp),
|
||||
@ -182,24 +150,20 @@ int Target(void *arg) {
|
||||
// Redirect `stdin` to `/dev/null`.
|
||||
int fd = open("/dev/null", O_RDONLY | O_CLOEXEC);
|
||||
if (fd == -1) {
|
||||
std::cerr
|
||||
<< "Couldn't open \"/dev/null\" for auth module stdin because of: "
|
||||
<< strerror(errno) << " (" << errno << ")!" << std::endl;
|
||||
std::cerr << "Couldn't open \"/dev/null\" for auth module stdin because of: " << strerror(errno) << " (" << errno
|
||||
<< ")!" << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
if (dup2(fd, STDIN_FILENO) != STDIN_FILENO) {
|
||||
std::cerr
|
||||
<< "Couldn't attach \"/dev/null\" to auth module stdin because of: "
|
||||
<< strerror(errno) << " (" << errno << ")!" << std::endl;
|
||||
std::cerr << "Couldn't attach \"/dev/null\" to auth module stdin because of: " << strerror(errno) << " (" << errno
|
||||
<< ")!" << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
// Change the current directory to the module directory.
|
||||
if (chdir(ta->module_executable_path.parent_path().c_str()) != 0) {
|
||||
std::cerr << "Couldn't change directory to "
|
||||
<< ta->module_executable_path.parent_path()
|
||||
<< " for auth module stdin because of: " << strerror(errno)
|
||||
<< " (" << errno << ")!" << std::endl;
|
||||
std::cerr << "Couldn't change directory to " << ta->module_executable_path.parent_path()
|
||||
<< " for auth module stdin because of: " << strerror(errno) << " (" << errno << ")!" << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
@ -214,8 +178,7 @@ int Target(void *arg) {
|
||||
}
|
||||
|
||||
// Connect the communication input pipe.
|
||||
if (dup2(ta->pipe_to_module, kCommunicationToModuleFd) !=
|
||||
kCommunicationToModuleFd) {
|
||||
if (dup2(ta->pipe_to_module, kCommunicationToModuleFd) != kCommunicationToModuleFd) {
|
||||
std::cerr << "Couldn't attach communication to module pipe to auth module "
|
||||
"because of: "
|
||||
<< strerror(errno) << " (" << errno << ")!" << std::endl;
|
||||
@ -223,8 +186,7 @@ int Target(void *arg) {
|
||||
}
|
||||
|
||||
// Connect the communication output pipe.
|
||||
if (dup2(ta->pipe_from_module, kCommunicationFromModuleFd) !=
|
||||
kCommunicationFromModuleFd) {
|
||||
if (dup2(ta->pipe_from_module, kCommunicationFromModuleFd) != kCommunicationFromModuleFd) {
|
||||
std::cerr << "Couldn't attach communication from module pipe to auth "
|
||||
"module because of: "
|
||||
<< strerror(errno) << " (" << errno << ")!" << std::endl;
|
||||
@ -246,8 +208,8 @@ int Target(void *arg) {
|
||||
sigemptyset(&action.sa_mask);
|
||||
action.sa_flags = 0;
|
||||
if (sigaction(SIGINT, &action, nullptr) != 0) {
|
||||
std::cerr << "Couldn't ignore SIGINT for auth module because of: "
|
||||
<< strerror(errno) << " (" << errno << ")!" << std::endl;
|
||||
std::cerr << "Couldn't ignore SIGINT for auth module because of: " << strerror(errno) << " (" << errno << ")!"
|
||||
<< std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
@ -261,8 +223,7 @@ int Target(void *arg) {
|
||||
|
||||
// If the `execve` call succeeded then the process will exit from that call
|
||||
// and won't reach this piece of code ever.
|
||||
std::cerr << "Couldn't start auth module because of: " << strerror(errno)
|
||||
<< " (" << errno << ")!" << std::endl;
|
||||
std::cerr << "Couldn't start auth module because of: " << strerror(errno) << " (" << errno << ")!" << std::endl;
|
||||
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
@ -408,8 +369,7 @@ bool Module::Startup() {
|
||||
return true;
|
||||
}
|
||||
|
||||
nlohmann::json Module::Call(const nlohmann::json ¶ms,
|
||||
int timeout_millisec) {
|
||||
nlohmann::json Module::Call(const nlohmann::json ¶ms, int timeout_millisec) {
|
||||
std::lock_guard<std::mutex> guard(lock_);
|
||||
|
||||
if (!params.is_object()) return {};
|
||||
|
@ -43,16 +43,14 @@ class ClientFatalException : public utils::BasicException {
|
||||
// only handle the `ClientFatalException`.
|
||||
class ServerCommunicationException : public ClientFatalException {
|
||||
public:
|
||||
ServerCommunicationException()
|
||||
: ClientFatalException("Couldn't communicate with the server!") {}
|
||||
ServerCommunicationException() : ClientFatalException("Couldn't communicate with the server!") {}
|
||||
};
|
||||
|
||||
// Internal exception used whenever a malformed data error occurs. You should
|
||||
// only handle the `ClientFatalException`.
|
||||
class ServerMalformedDataException : public ClientFatalException {
|
||||
public:
|
||||
ServerMalformedDataException()
|
||||
: ClientFatalException("The server sent malformed data!") {}
|
||||
ServerMalformedDataException() : ClientFatalException("The server sent malformed data!") {}
|
||||
};
|
||||
|
||||
/// Structure that is used to return results from an executed query.
|
||||
@ -79,8 +77,7 @@ class Client final {
|
||||
/// connection is set-up, multiple queries may be executed through a single
|
||||
/// established connection.
|
||||
/// @throws ClientFatalException when we couldn't connect to the server
|
||||
void Connect(const io::network::Endpoint &endpoint,
|
||||
const std::string &username, const std::string &password,
|
||||
void Connect(const io::network::Endpoint &endpoint, const std::string &username, const std::string &password,
|
||||
const std::string &client_name = "memgraph-bolt") {
|
||||
if (!client_.Connect(endpoint)) {
|
||||
throw ClientFatalException("Couldn't connect to {}!", endpoint);
|
||||
@ -103,14 +100,11 @@ class Client final {
|
||||
}
|
||||
if (memcmp(kProtocol, client_.GetData(), sizeof(kProtocol)) != 0) {
|
||||
SPDLOG_ERROR("Server negotiated unsupported protocol version!");
|
||||
throw ClientFatalException(
|
||||
"The server negotiated an usupported protocol version!");
|
||||
throw ClientFatalException("The server negotiated an usupported protocol version!");
|
||||
}
|
||||
client_.ShiftData(sizeof(kProtocol));
|
||||
|
||||
if (!encoder_.MessageInit(client_name, {{"scheme", "basic"},
|
||||
{"principal", username},
|
||||
{"credentials", password}})) {
|
||||
if (!encoder_.MessageInit(client_name, {{"scheme", "basic"}, {"principal", username}, {"credentials", password}})) {
|
||||
SPDLOG_ERROR("Couldn't send init message!");
|
||||
throw ServerCommunicationException();
|
||||
}
|
||||
@ -135,15 +129,12 @@ class Client final {
|
||||
/// executing the query (eg. mistyped query,
|
||||
/// etc.)
|
||||
/// @throws ClientFatalException when we couldn't communicate with the server
|
||||
QueryData Execute(const std::string &query,
|
||||
const std::map<std::string, Value> ¶meters) {
|
||||
QueryData Execute(const std::string &query, const std::map<std::string, Value> ¶meters) {
|
||||
if (!client_.IsConnected()) {
|
||||
throw ClientFatalException(
|
||||
"You must first connect to the server before using the client!");
|
||||
throw ClientFatalException("You must first connect to the server before using the client!");
|
||||
}
|
||||
|
||||
SPDLOG_INFO("Sending run message with statement: '{}'; parameters: {}",
|
||||
query, parameters);
|
||||
SPDLOG_INFO("Sending run message with statement: '{}'; parameters: {}", query, parameters);
|
||||
|
||||
encoder_.MessageRun(query, parameters);
|
||||
encoder_.MessagePullAll();
|
||||
@ -165,8 +156,7 @@ class Client final {
|
||||
if (it != tmp.end()) {
|
||||
auto it_code = tmp.find("code");
|
||||
if (it_code != tmp.end()) {
|
||||
throw ClientQueryException(it_code->second.ValueString(),
|
||||
it->second.ValueString());
|
||||
throw ClientQueryException(it_code->second.ValueString(), it->second.ValueString());
|
||||
} else {
|
||||
throw ClientQueryException("", it->second.ValueString());
|
||||
}
|
||||
@ -209,8 +199,7 @@ class Client final {
|
||||
if (it != tmp.end()) {
|
||||
auto it_code = tmp.find("code");
|
||||
if (it_code != tmp.end()) {
|
||||
throw ClientQueryException(it_code->second.ValueString(),
|
||||
it->second.ValueString());
|
||||
throw ClientQueryException(it_code->second.ValueString(), it->second.ValueString());
|
||||
} else {
|
||||
throw ClientQueryException("", it->second.ValueString());
|
||||
}
|
||||
@ -308,15 +297,11 @@ class Client final {
|
||||
communication::ClientOutputStream output_stream_{client_};
|
||||
|
||||
// decoder objects
|
||||
ChunkedDecoderBuffer<communication::ClientInputStream> decoder_buffer_{
|
||||
input_stream_};
|
||||
Decoder<ChunkedDecoderBuffer<communication::ClientInputStream>> decoder_{
|
||||
decoder_buffer_};
|
||||
ChunkedDecoderBuffer<communication::ClientInputStream> decoder_buffer_{input_stream_};
|
||||
Decoder<ChunkedDecoderBuffer<communication::ClientInputStream>> decoder_{decoder_buffer_};
|
||||
|
||||
// encoder objects
|
||||
ChunkedEncoderBuffer<communication::ClientOutputStream> encoder_buffer_{
|
||||
output_stream_};
|
||||
ClientEncoder<ChunkedEncoderBuffer<communication::ClientOutputStream>>
|
||||
encoder_{encoder_buffer_};
|
||||
ChunkedEncoderBuffer<communication::ClientOutputStream> encoder_buffer_{output_stream_};
|
||||
ClientEncoder<ChunkedEncoderBuffer<communication::ClientOutputStream>> encoder_{encoder_buffer_};
|
||||
};
|
||||
} // namespace communication::bolt
|
||||
|
@ -78,12 +78,8 @@ enum class Marker : uint8_t {
|
||||
};
|
||||
|
||||
static constexpr uint8_t MarkerString = 0, MarkerList = 1, MarkerMap = 2;
|
||||
static constexpr Marker MarkerTiny[3] = {Marker::TinyString, Marker::TinyList,
|
||||
Marker::TinyMap};
|
||||
static constexpr Marker Marker8[3] = {Marker::String8, Marker::List8,
|
||||
Marker::Map8};
|
||||
static constexpr Marker Marker16[3] = {Marker::String16, Marker::List16,
|
||||
Marker::Map16};
|
||||
static constexpr Marker Marker32[3] = {Marker::String32, Marker::List32,
|
||||
Marker::Map32};
|
||||
static constexpr Marker MarkerTiny[3] = {Marker::TinyString, Marker::TinyList, Marker::TinyMap};
|
||||
static constexpr Marker Marker8[3] = {Marker::String8, Marker::List8, Marker::Map8};
|
||||
static constexpr Marker Marker16[3] = {Marker::String16, Marker::List16, Marker::Map16};
|
||||
static constexpr Marker Marker32[3] = {Marker::String32, Marker::List32, Marker::Map32};
|
||||
} // namespace communication::bolt
|
||||
|
@ -40,9 +40,7 @@ enum class ChunkState : uint8_t {
|
||||
template <typename TBuffer>
|
||||
class ChunkedDecoderBuffer {
|
||||
public:
|
||||
ChunkedDecoderBuffer(TBuffer &buffer) : buffer_(buffer) {
|
||||
data_.reserve(kChunkMaxDataSize);
|
||||
}
|
||||
ChunkedDecoderBuffer(TBuffer &buffer) : buffer_(buffer) { data_.reserve(kChunkMaxDataSize); }
|
||||
|
||||
/**
|
||||
* Reads data from the internal buffer.
|
||||
|
@ -158,8 +158,7 @@ class Decoder {
|
||||
}
|
||||
|
||||
bool ReadBool(const Marker &marker, Value *data) {
|
||||
DMG_ASSERT(marker == Marker::False || marker == Marker::True,
|
||||
"Received invalid marker!");
|
||||
DMG_ASSERT(marker == Marker::False || marker == Marker::True, "Received invalid marker!");
|
||||
if (marker == Marker::False) {
|
||||
*data = Value(false);
|
||||
} else {
|
||||
|
@ -7,8 +7,7 @@
|
||||
#include "utils/cast.hpp"
|
||||
#include "utils/endian.hpp"
|
||||
|
||||
static_assert(std::is_same_v<std::uint8_t, char> ||
|
||||
std::is_same_v<std::uint8_t, unsigned char>,
|
||||
static_assert(std::is_same_v<std::uint8_t, char> || std::is_same_v<std::uint8_t, unsigned char>,
|
||||
"communication::bolt::Encoder requires uint8_t to be "
|
||||
"implemented as char or unsigned char.");
|
||||
|
||||
@ -29,9 +28,7 @@ class BaseEncoder {
|
||||
|
||||
void WriteRAW(const uint8_t *data, uint64_t len) { buffer_.Write(data, len); }
|
||||
|
||||
void WriteRAW(const char *data, uint64_t len) {
|
||||
WriteRAW((const uint8_t *)data, len);
|
||||
}
|
||||
void WriteRAW(const char *data, uint64_t len) { WriteRAW((const uint8_t *)data, len); }
|
||||
|
||||
void WriteRAW(const uint8_t data) { WriteRAW(&data, 1); }
|
||||
|
||||
@ -126,8 +123,7 @@ class BaseEncoder {
|
||||
|
||||
void WriteEdge(const Edge &edge, bool unbound = false) {
|
||||
WriteRAW(utils::UnderlyingCast(Marker::TinyStruct) + (unbound ? 3 : 5));
|
||||
WriteRAW(utils::UnderlyingCast(unbound ? Signature::UnboundRelationship
|
||||
: Signature::Relationship));
|
||||
WriteRAW(utils::UnderlyingCast(unbound ? Signature::UnboundRelationship : Signature::Relationship));
|
||||
|
||||
WriteInt(edge.id.AsInt());
|
||||
if (!unbound) {
|
||||
|
@ -37,8 +37,7 @@ namespace communication::bolt {
|
||||
template <class TOutputStream>
|
||||
class ChunkedEncoderBuffer {
|
||||
public:
|
||||
ChunkedEncoderBuffer(TOutputStream &output_stream)
|
||||
: output_stream_(output_stream) {}
|
||||
ChunkedEncoderBuffer(TOutputStream &output_stream) : output_stream_(output_stream) {}
|
||||
|
||||
/**
|
||||
* Writes n values into the buffer. If n is bigger than whole chunk size
|
||||
@ -53,12 +52,10 @@ class ChunkedEncoderBuffer {
|
||||
while (n > 0) {
|
||||
// Define the number of bytes which will be copied into the chunk because
|
||||
// the internal storage is a fixed length array.
|
||||
size_t size =
|
||||
n < kChunkMaxDataSize - have_ ? n : kChunkMaxDataSize - have_;
|
||||
size_t size = n < kChunkMaxDataSize - have_ ? n : kChunkMaxDataSize - have_;
|
||||
|
||||
// Copy `size` values to the chunk array.
|
||||
std::memcpy(chunk_.data() + kChunkHeaderSize + have_, values + written,
|
||||
size);
|
||||
std::memcpy(chunk_.data() + kChunkHeaderSize + have_, values + written, size);
|
||||
|
||||
// Update positions. The position pointer and incoming size have to be
|
||||
// updated because all incoming values have to be processed.
|
||||
@ -87,8 +84,7 @@ class ChunkedEncoderBuffer {
|
||||
chunk_[1] = have_ & 0xFF;
|
||||
|
||||
// Write the data to the stream.
|
||||
auto ret = output_stream_.Write(chunk_.data(), kChunkHeaderSize + have_,
|
||||
have_more);
|
||||
auto ret = output_stream_.Write(chunk_.data(), kChunkHeaderSize + have_, have_more);
|
||||
|
||||
// Cleanup.
|
||||
Clear();
|
||||
|
@ -38,8 +38,7 @@ class ClientEncoder : private BaseEncoder<Buffer> {
|
||||
* @returns true if the data was successfully sent to the client
|
||||
* when flushing, false otherwise
|
||||
*/
|
||||
bool MessageInit(const std::string client_name,
|
||||
const std::map<std::string, Value> &auth_token) {
|
||||
bool MessageInit(const std::string client_name, const std::map<std::string, Value> &auth_token) {
|
||||
WriteRAW(utils::UnderlyingCast(Marker::TinyStruct2));
|
||||
WriteRAW(utils::UnderlyingCast(Signature::Init));
|
||||
WriteString(client_name);
|
||||
@ -65,9 +64,7 @@ class ClientEncoder : private BaseEncoder<Buffer> {
|
||||
* @returns true if the data was successfully sent to the client
|
||||
* when flushing, false otherwise
|
||||
*/
|
||||
bool MessageRun(const std::string &statement,
|
||||
const std::map<std::string, Value> ¶meters,
|
||||
bool have_more = true) {
|
||||
bool MessageRun(const std::string &statement, const std::map<std::string, Value> ¶meters, bool have_more = true) {
|
||||
WriteRAW(utils::UnderlyingCast(Marker::TinyStruct2));
|
||||
WriteRAW(utils::UnderlyingCast(Signature::Run));
|
||||
WriteString(statement);
|
||||
|
@ -50,13 +50,10 @@ class VerboseError : public utils::BasicException {
|
||||
};
|
||||
|
||||
template <class... Args>
|
||||
VerboseError(Classification classification, const std::string &category,
|
||||
const std::string &title, const std::string &format,
|
||||
Args &&... args)
|
||||
VerboseError(Classification classification, const std::string &category, const std::string &title,
|
||||
const std::string &format, Args &&...args)
|
||||
: BasicException(format, std::forward<Args>(args)...),
|
||||
code_(fmt::format("Memgraph.{}.{}.{}",
|
||||
ClassificationToString(classification), category,
|
||||
title)) {}
|
||||
code_(fmt::format("Memgraph.{}.{}.{}", ClassificationToString(classification), category, title)) {}
|
||||
|
||||
const std::string &code() const noexcept { return code_; }
|
||||
|
||||
|
@ -62,9 +62,7 @@ class Session {
|
||||
* @param q If set, defines from which query to pull the results,
|
||||
* otherwise the last query is used.
|
||||
*/
|
||||
virtual std::map<std::string, Value> Pull(TEncoder *encoder,
|
||||
std::optional<int> n,
|
||||
std::optional<int> qid) = 0;
|
||||
virtual std::map<std::string, Value> Pull(TEncoder *encoder, std::optional<int> n, std::optional<int> qid) = 0;
|
||||
|
||||
/**
|
||||
* Discard results of the processed query.
|
||||
@ -74,8 +72,7 @@ class Session {
|
||||
* @param q If set, defines from which query to discard the results,
|
||||
* otherwise the last query is used.
|
||||
*/
|
||||
virtual std::map<std::string, Value> Discard(std::optional<int> n,
|
||||
std::optional<int> qid) = 0;
|
||||
virtual std::map<std::string, Value> Discard(std::optional<int> n, std::optional<int> qid) = 0;
|
||||
|
||||
virtual void BeginTransaction() = 0;
|
||||
virtual void CommitTransaction() = 0;
|
||||
@ -85,8 +82,7 @@ class Session {
|
||||
virtual void Abort() = 0;
|
||||
|
||||
/** Return `true` if the user was successfully authenticated. */
|
||||
virtual bool Authenticate(const std::string &username,
|
||||
const std::string &password) = 0;
|
||||
virtual bool Authenticate(const std::string &username, const std::string &password) = 0;
|
||||
|
||||
/** Return the name of the server that should be used for the Bolt INIT
|
||||
* message. */
|
||||
@ -104,8 +100,7 @@ class Session {
|
||||
|
||||
// Receive the handshake.
|
||||
if (input_stream_.size() < kHandshakeSize) {
|
||||
spdlog::trace("Received partial handshake of size {}",
|
||||
input_stream_.size());
|
||||
spdlog::trace("Received partial handshake of size {}", input_stream_.size());
|
||||
return;
|
||||
}
|
||||
state_ = StateHandshakeRun(*this);
|
||||
|
@ -44,4 +44,4 @@ enum class State : uint8_t {
|
||||
*/
|
||||
Close
|
||||
};
|
||||
}
|
||||
} // namespace communication::bolt
|
||||
|
@ -26,8 +26,7 @@ State StateErrorRun(TSession &session, State state) {
|
||||
return State::Close;
|
||||
}
|
||||
|
||||
if (UNLIKELY(signature == Signature::Noop && session.version_.major == 4 &&
|
||||
session.version_.minor == 1)) {
|
||||
if (UNLIKELY(signature == Signature::Noop && session.version_.major == 4 && session.version_.minor == 1)) {
|
||||
spdlog::trace("Received NOOP message");
|
||||
return state;
|
||||
}
|
||||
@ -35,8 +34,7 @@ State StateErrorRun(TSession &session, State state) {
|
||||
// Clear the data buffer if it has any leftover data.
|
||||
session.encoder_buffer_.Clear();
|
||||
|
||||
if ((session.version_.major == 1 && signature == Signature::AckFailure) ||
|
||||
signature == Signature::Reset) {
|
||||
if ((session.version_.major == 1 && signature == Signature::AckFailure) || signature == Signature::Reset) {
|
||||
if (signature == Signature::AckFailure) {
|
||||
spdlog::trace("AckFailure received");
|
||||
} else {
|
||||
@ -62,8 +60,7 @@ State StateErrorRun(TSession &session, State state) {
|
||||
// All bolt client messages have less than 15 parameters so if we receive
|
||||
// anything than a TinyStruct it's an error.
|
||||
if ((value & 0xF0) != utils::UnderlyingCast(Marker::TinyStruct)) {
|
||||
spdlog::trace("Expected TinyStruct marker, but received 0x{:02X}!",
|
||||
value);
|
||||
spdlog::trace("Expected TinyStruct marker, but received 0x{:02X}!", value);
|
||||
return State::Close;
|
||||
}
|
||||
|
||||
|
@ -16,8 +16,7 @@
|
||||
namespace communication::bolt {
|
||||
|
||||
// TODO (mferencevic): revise these error messages
|
||||
inline std::pair<std::string, std::string> ExceptionToErrorMessage(
|
||||
const std::exception &e) {
|
||||
inline std::pair<std::string, std::string> ExceptionToErrorMessage(const std::exception &e) {
|
||||
if (auto *verbose = dynamic_cast<const VerboseError *>(&e)) {
|
||||
return {verbose->code(), verbose->what()};
|
||||
}
|
||||
@ -54,8 +53,7 @@ inline std::pair<std::string, std::string> ExceptionToErrorMessage(
|
||||
// All exceptions used in memgraph are derived from BasicException. Since
|
||||
// we caught some other exception we don't know what is going on. Return
|
||||
// DatabaseError, log real message and return generic string.
|
||||
spdlog::error("Unknown exception occurred during query execution {}",
|
||||
e.what());
|
||||
spdlog::error("Unknown exception occurred during query execution {}", e.what());
|
||||
return {"Memgraph.DatabaseError.MemgraphError.MemgraphError",
|
||||
"An unknown exception occurred, this is unexpected. Real message "
|
||||
"should be in database logs."};
|
||||
@ -69,8 +67,7 @@ inline State HandleFailure(TSession &session, const std::exception &e) {
|
||||
}
|
||||
session.encoder_buffer_.Clear();
|
||||
auto code_message = ExceptionToErrorMessage(e);
|
||||
bool fail_sent = session.encoder_.MessageFailure(
|
||||
{{"code", code_message.first}, {"message", code_message.second}});
|
||||
bool fail_sent = session.encoder_.MessageFailure({{"code", code_message.first}, {"message", code_message.second}});
|
||||
if (!fail_sent) {
|
||||
spdlog::trace("Couldn't send failure message!");
|
||||
return State::Close;
|
||||
@ -80,15 +77,12 @@ inline State HandleFailure(TSession &session, const std::exception &e) {
|
||||
|
||||
template <typename TSession>
|
||||
State HandleRun(TSession &session, State state, Marker marker) {
|
||||
const std::map<std::string, Value> kEmptyFields = {
|
||||
{"fields", std::vector<Value>{}}};
|
||||
const std::map<std::string, Value> kEmptyFields = {{"fields", std::vector<Value>{}}};
|
||||
|
||||
const auto expected_marker =
|
||||
session.version_.major == 1 ? Marker::TinyStruct2 : Marker::TinyStruct3;
|
||||
const auto expected_marker = session.version_.major == 1 ? Marker::TinyStruct2 : Marker::TinyStruct3;
|
||||
if (marker != expected_marker) {
|
||||
spdlog::trace("Expected {} marker, but received 0x{:02X}!",
|
||||
session.version_.major == 1 ? "TinyStruct2" : "TinyStruct3",
|
||||
utils::UnderlyingCast(marker));
|
||||
session.version_.major == 1 ? "TinyStruct2" : "TinyStruct3", utils::UnderlyingCast(marker));
|
||||
return State::Close;
|
||||
}
|
||||
|
||||
@ -117,15 +111,13 @@ State HandleRun(TSession &session, State state, Marker marker) {
|
||||
return State::Close;
|
||||
}
|
||||
|
||||
DMG_ASSERT(!session.encoder_buffer_.HasData(),
|
||||
"There should be no data to write in this state");
|
||||
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
|
||||
|
||||
spdlog::debug("[Run] '{}'", query.ValueString());
|
||||
|
||||
try {
|
||||
// Interpret can throw.
|
||||
auto [header, qid] =
|
||||
session.Interpret(query.ValueString(), params.ValueMap());
|
||||
auto [header, qid] = session.Interpret(query.ValueString(), params.ValueMap());
|
||||
// Convert std::string to Value
|
||||
std::vector<Value> vec;
|
||||
std::map<std::string, Value> data;
|
||||
@ -146,12 +138,10 @@ State HandleRun(TSession &session, State state, Marker marker) {
|
||||
namespace detail {
|
||||
template <bool is_pull, typename TSession>
|
||||
State HandlePullDiscard(TSession &session, State state, Marker marker) {
|
||||
const auto expected_marker =
|
||||
session.version_.major == 1 ? Marker::TinyStruct : Marker::TinyStruct1;
|
||||
const auto expected_marker = session.version_.major == 1 ? Marker::TinyStruct : Marker::TinyStruct1;
|
||||
if (marker != expected_marker) {
|
||||
spdlog::trace("Expected {} marker, but received 0x{:02X}!",
|
||||
session.version_.major == 1 ? "TinyStruct" : "TinyStruct1",
|
||||
utils::UnderlyingCast(marker));
|
||||
session.version_.major == 1 ? "TinyStruct" : "TinyStruct1", utils::UnderlyingCast(marker));
|
||||
return State::Close;
|
||||
}
|
||||
|
||||
@ -176,15 +166,13 @@ State HandlePullDiscard(TSession &session, State state, Marker marker) {
|
||||
}
|
||||
const auto &extra_map = extra.ValueMap();
|
||||
if (extra_map.count("n")) {
|
||||
if (const auto n_value = extra_map.at("n").ValueInt();
|
||||
n_value != kPullAll) {
|
||||
if (const auto n_value = extra_map.at("n").ValueInt(); n_value != kPullAll) {
|
||||
n = n_value;
|
||||
}
|
||||
}
|
||||
|
||||
if (extra_map.count("qid")) {
|
||||
if (const auto qid_value = extra_map.at("qid").ValueInt();
|
||||
qid_value != kPullLast) {
|
||||
if (const auto qid_value = extra_map.at("qid").ValueInt(); qid_value != kPullLast) {
|
||||
qid = qid_value;
|
||||
}
|
||||
}
|
||||
@ -236,8 +224,7 @@ State HandleReset(Session &session, State, Marker marker) {
|
||||
// now this command only resets the session to a clean state. It
|
||||
// does not IGNORE running and pending commands as it should.
|
||||
if (marker != Marker::TinyStruct) {
|
||||
spdlog::trace("Expected TinyStruct marker, but received 0x{:02X}!",
|
||||
utils::UnderlyingCast(marker));
|
||||
spdlog::trace("Expected TinyStruct marker, but received 0x{:02X}!", utils::UnderlyingCast(marker));
|
||||
return State::Close;
|
||||
}
|
||||
|
||||
@ -262,8 +249,7 @@ State HandleBegin(Session &session, State state, Marker marker) {
|
||||
}
|
||||
|
||||
if (marker != Marker::TinyStruct1) {
|
||||
spdlog::trace("Expected TinyStruct1 marker, but received 0x{:02x}!",
|
||||
utils::UnderlyingCast(marker));
|
||||
spdlog::trace("Expected TinyStruct1 marker, but received 0x{:02x}!", utils::UnderlyingCast(marker));
|
||||
return State::Close;
|
||||
}
|
||||
|
||||
@ -278,8 +264,7 @@ State HandleBegin(Session &session, State state, Marker marker) {
|
||||
return State::Close;
|
||||
}
|
||||
|
||||
DMG_ASSERT(!session.encoder_buffer_.HasData(),
|
||||
"There should be no data to write in this state");
|
||||
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
|
||||
|
||||
if (!session.encoder_.MessageSuccess({})) {
|
||||
spdlog::trace("Couldn't send success message!");
|
||||
@ -303,8 +288,7 @@ State HandleCommit(Session &session, State state, Marker marker) {
|
||||
}
|
||||
|
||||
if (marker != Marker::TinyStruct) {
|
||||
spdlog::trace("Expected TinyStruct marker, but received 0x{:02x}!",
|
||||
utils::UnderlyingCast(marker));
|
||||
spdlog::trace("Expected TinyStruct marker, but received 0x{:02x}!", utils::UnderlyingCast(marker));
|
||||
return State::Close;
|
||||
}
|
||||
|
||||
@ -313,8 +297,7 @@ State HandleCommit(Session &session, State state, Marker marker) {
|
||||
return State::Close;
|
||||
}
|
||||
|
||||
DMG_ASSERT(!session.encoder_buffer_.HasData(),
|
||||
"There should be no data to write in this state");
|
||||
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
|
||||
|
||||
try {
|
||||
if (!session.encoder_.MessageSuccess({})) {
|
||||
@ -336,8 +319,7 @@ State HandleRollback(Session &session, State state, Marker marker) {
|
||||
}
|
||||
|
||||
if (marker != Marker::TinyStruct) {
|
||||
spdlog::trace("Expected TinyStruct marker, but received 0x{:02x}!",
|
||||
utils::UnderlyingCast(marker));
|
||||
spdlog::trace("Expected TinyStruct marker, but received 0x{:02x}!", utils::UnderlyingCast(marker));
|
||||
return State::Close;
|
||||
}
|
||||
|
||||
@ -346,8 +328,7 @@ State HandleRollback(Session &session, State state, Marker marker) {
|
||||
return State::Close;
|
||||
}
|
||||
|
||||
DMG_ASSERT(!session.encoder_buffer_.HasData(),
|
||||
"There should be no data to write in this state");
|
||||
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
|
||||
|
||||
try {
|
||||
if (!session.encoder_.MessageSuccess({})) {
|
||||
@ -376,8 +357,7 @@ State StateExecutingRun(Session &session, State state) {
|
||||
return State::Close;
|
||||
}
|
||||
|
||||
if (UNLIKELY(signature == Signature::Noop && session.version_.major == 4 &&
|
||||
session.version_.minor == 1)) {
|
||||
if (UNLIKELY(signature == Signature::Noop && session.version_.major == 4 && session.version_.minor == 1)) {
|
||||
spdlog::trace("Received NOOP message");
|
||||
return state;
|
||||
}
|
||||
@ -399,8 +379,7 @@ State StateExecutingRun(Session &session, State state) {
|
||||
} else if (signature == Signature::Goodbye && session.version_.major != 1) {
|
||||
throw SessionClosedException("Closing connection.");
|
||||
} else {
|
||||
spdlog::trace("Unrecognized signature received (0x{:02X})!",
|
||||
utils::UnderlyingCast(signature));
|
||||
spdlog::trace("Unrecognized signature received (0x{:02X})!", utils::UnderlyingCast(signature));
|
||||
return State::Close;
|
||||
}
|
||||
}
|
||||
|
@ -17,15 +17,13 @@ namespace communication::bolt {
|
||||
*/
|
||||
template <typename TSession>
|
||||
State StateHandshakeRun(TSession &session) {
|
||||
auto precmp =
|
||||
std::memcmp(session.input_stream_.data(), kPreamble, sizeof(kPreamble));
|
||||
auto precmp = std::memcmp(session.input_stream_.data(), kPreamble, sizeof(kPreamble));
|
||||
if (UNLIKELY(precmp != 0)) {
|
||||
spdlog::trace("Received a wrong preamble!");
|
||||
return State::Close;
|
||||
}
|
||||
|
||||
DMG_ASSERT(session.input_stream_.size() >= kHandshakeSize,
|
||||
"Wrong size of the handshake data!");
|
||||
DMG_ASSERT(session.input_stream_.size() >= kHandshakeSize, "Wrong size of the handshake data!");
|
||||
|
||||
auto dataPosition = session.input_stream_.data() + sizeof(kPreamble);
|
||||
|
||||
@ -61,8 +59,7 @@ State StateHandshakeRun(TSession &session) {
|
||||
return State::Close;
|
||||
}
|
||||
|
||||
spdlog::info("Using version {}.{} of protocol", session.version_.major,
|
||||
session.version_.minor);
|
||||
spdlog::info("Using version {}.{} of protocol", session.version_.major, session.version_.minor);
|
||||
|
||||
// Delete data from the input stream. It is guaranteed that there will more
|
||||
// than, or equal to 20 bytes (kHandshakeSize) in the buffer.
|
||||
|
@ -15,8 +15,7 @@ namespace detail {
|
||||
template <typename TSession>
|
||||
std::optional<Value> StateInitRunV1(TSession &session, const Marker marker) {
|
||||
if (UNLIKELY(marker != Marker::TinyStruct2)) {
|
||||
spdlog::trace("Expected TinyStruct2 marker, but received 0x{:02X}!",
|
||||
utils::UnderlyingCast(marker));
|
||||
spdlog::trace("Expected TinyStruct2 marker, but received 0x{:02X}!", utils::UnderlyingCast(marker));
|
||||
spdlog::trace(
|
||||
"The client sent malformed data, but we are continuing "
|
||||
"because the official Neo4j Java driver sends malformed "
|
||||
@ -45,8 +44,7 @@ std::optional<Value> StateInitRunV1(TSession &session, const Marker marker) {
|
||||
template <typename TSession>
|
||||
std::optional<Value> StateInitRunV4(TSession &session, const Marker marker) {
|
||||
if (UNLIKELY(marker != Marker::TinyStruct1)) {
|
||||
spdlog::trace("Expected TinyStruct1 marker, but received 0x{:02X}!",
|
||||
utils::UnderlyingCast(marker));
|
||||
spdlog::trace("Expected TinyStruct1 marker, but received 0x{:02X}!", utils::UnderlyingCast(marker));
|
||||
spdlog::trace(
|
||||
"The client sent malformed data, but we are continuing "
|
||||
"because the official Neo4j Java driver sends malformed "
|
||||
@ -80,8 +78,7 @@ std::optional<Value> StateInitRunV4(TSession &session, const Marker marker) {
|
||||
*/
|
||||
template <typename Session>
|
||||
State StateInitRun(Session &session) {
|
||||
DMG_ASSERT(!session.encoder_buffer_.HasData(),
|
||||
"There should be no data to write in this state");
|
||||
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
|
||||
|
||||
Marker marker;
|
||||
Signature signature;
|
||||
@ -90,21 +87,18 @@ State StateInitRun(Session &session) {
|
||||
return State::Close;
|
||||
}
|
||||
|
||||
if (UNLIKELY(signature == Signature::Noop && session.version_.major == 4 &&
|
||||
session.version_.minor == 1)) {
|
||||
if (UNLIKELY(signature == Signature::Noop && session.version_.major == 4 && session.version_.minor == 1)) {
|
||||
SPDLOG_DEBUG("Received NOOP message");
|
||||
return State::Init;
|
||||
}
|
||||
|
||||
if (UNLIKELY(signature != Signature::Init)) {
|
||||
spdlog::trace("Expected Init signature, but received 0x{:02X}!",
|
||||
utils::UnderlyingCast(signature));
|
||||
spdlog::trace("Expected Init signature, but received 0x{:02X}!", utils::UnderlyingCast(signature));
|
||||
return State::Close;
|
||||
}
|
||||
|
||||
auto maybeMetadata = session.version_.major == 1
|
||||
? detail::StateInitRunV1(session, marker)
|
||||
: detail::StateInitRunV4(session, marker);
|
||||
auto maybeMetadata =
|
||||
session.version_.major == 1 ? detail::StateInitRunV1(session, marker) : detail::StateInitRunV4(session, marker);
|
||||
|
||||
if (!maybeMetadata) {
|
||||
return State::Close;
|
||||
@ -126,16 +120,14 @@ State StateInitRun(Session &session) {
|
||||
username = data["principal"].ValueString();
|
||||
password = data["credentials"].ValueString();
|
||||
} else if (data["scheme"].ValueString() != "none") {
|
||||
spdlog::warn("Unsupported authentication scheme: {}",
|
||||
data["scheme"].ValueString());
|
||||
spdlog::warn("Unsupported authentication scheme: {}", data["scheme"].ValueString());
|
||||
return State::Close;
|
||||
}
|
||||
|
||||
// Authenticate the user.
|
||||
if (!session.Authenticate(username, password)) {
|
||||
if (!session.encoder_.MessageFailure(
|
||||
{{"code", "Memgraph.ClientError.Security.Unauthenticated"},
|
||||
{"message", "Authentication failure"}})) {
|
||||
{{"code", "Memgraph.ClientError.Security.Unauthenticated"}, {"message", "Authentication failure"}})) {
|
||||
spdlog::trace("Couldn't send failure message to the client!");
|
||||
}
|
||||
// Throw an exception to indicate to the network stack that the session
|
||||
|
@ -198,8 +198,7 @@ Value &Value::operator=(Value &&other) noexcept {
|
||||
new (&edge_v) Edge(std::move(other.edge_v));
|
||||
break;
|
||||
case Type::UnboundedEdge:
|
||||
new (&unbounded_edge_v)
|
||||
UnboundedEdge(std::move(other.unbounded_edge_v));
|
||||
new (&unbounded_edge_v) UnboundedEdge(std::move(other.unbounded_edge_v));
|
||||
break;
|
||||
case Type::Path:
|
||||
new (&path_v) Path(std::move(other.path_v));
|
||||
@ -258,17 +257,14 @@ std::ostream &operator<<(std::ostream &os, const Vertex &vertex) {
|
||||
if (vertex.labels.size() > 0) {
|
||||
os << ":";
|
||||
}
|
||||
utils::PrintIterable(os, vertex.labels, ":",
|
||||
[&](auto &stream, auto label) { stream << label; });
|
||||
utils::PrintIterable(os, vertex.labels, ":", [&](auto &stream, auto label) { stream << label; });
|
||||
if (vertex.labels.size() > 0 && vertex.properties.size() > 0) {
|
||||
os << " ";
|
||||
}
|
||||
if (vertex.properties.size() > 0) {
|
||||
os << "{";
|
||||
utils::PrintIterable(os, vertex.properties, ", ",
|
||||
[&](auto &stream, const auto &pair) {
|
||||
stream << pair.first << ": " << pair.second;
|
||||
});
|
||||
[&](auto &stream, const auto &pair) { stream << pair.first << ": " << pair.second; });
|
||||
os << "}";
|
||||
}
|
||||
return os << ")";
|
||||
@ -279,9 +275,7 @@ std::ostream &operator<<(std::ostream &os, const Edge &edge) {
|
||||
if (edge.properties.size() > 0) {
|
||||
os << " {";
|
||||
utils::PrintIterable(os, edge.properties, ", ",
|
||||
[&](auto &stream, const auto &pair) {
|
||||
stream << pair.first << ": " << pair.second;
|
||||
});
|
||||
[&](auto &stream, const auto &pair) { stream << pair.first << ": " << pair.second; });
|
||||
os << "}";
|
||||
}
|
||||
return os << "]";
|
||||
@ -292,9 +286,7 @@ std::ostream &operator<<(std::ostream &os, const UnboundedEdge &edge) {
|
||||
if (edge.properties.size() > 0) {
|
||||
os << " {";
|
||||
utils::PrintIterable(os, edge.properties, ", ",
|
||||
[&](auto &stream, const auto &pair) {
|
||||
stream << pair.first << ": " << pair.second;
|
||||
});
|
||||
[&](auto &stream, const auto &pair) { stream << pair.first << ": " << pair.second; });
|
||||
os << "}";
|
||||
}
|
||||
return os << "]";
|
||||
@ -339,9 +331,7 @@ std::ostream &operator<<(std::ostream &os, const Value &value) {
|
||||
case Value::Type::Map:
|
||||
os << "{";
|
||||
utils::PrintIterable(os, value.ValueMap(), ", ",
|
||||
[](auto &stream, const auto &pair) {
|
||||
stream << pair.first << ": " << pair.second;
|
||||
});
|
||||
[](auto &stream, const auto &pair) { stream << pair.first << ": " << pair.second; });
|
||||
return os << "}";
|
||||
case Value::Type::Vertex:
|
||||
return os << value.ValueVertex();
|
||||
|
@ -33,9 +33,7 @@ class Id {
|
||||
int64_t id_;
|
||||
};
|
||||
|
||||
inline bool operator==(const Id &id1, const Id &id2) {
|
||||
return id1.AsInt() == id2.AsInt();
|
||||
}
|
||||
inline bool operator==(const Id &id1, const Id &id2) { return id1.AsInt() == id2.AsInt(); }
|
||||
|
||||
inline bool operator!=(const Id &id1, const Id &id2) { return !(id1 == id2); }
|
||||
|
||||
@ -84,13 +82,10 @@ struct Path {
|
||||
// into the collection and puts that index into `indices`. A multiplier is
|
||||
// added to switch between positive and negative indices (that define edge
|
||||
// direction).
|
||||
auto add_element = [this](auto &collection, const auto &element,
|
||||
int multiplier, int offset) {
|
||||
auto add_element = [this](auto &collection, const auto &element, int multiplier, int offset) {
|
||||
auto found =
|
||||
std::find_if(collection.begin(), collection.end(),
|
||||
[&](const auto &e) { return e.id == element.id; });
|
||||
indices.emplace_back(multiplier *
|
||||
(std::distance(collection.begin(), found) + offset));
|
||||
std::find_if(collection.begin(), collection.end(), [&](const auto &e) { return e.id == element.id; });
|
||||
indices.emplace_back(multiplier * (std::distance(collection.begin(), found) + offset));
|
||||
if (found == collection.end()) collection.push_back(element);
|
||||
};
|
||||
|
||||
@ -125,19 +120,7 @@ class Value {
|
||||
Value() : type_(Type::Null) {}
|
||||
|
||||
/** Types that can be stored in a Value. */
|
||||
enum class Type : unsigned {
|
||||
Null,
|
||||
Bool,
|
||||
Int,
|
||||
Double,
|
||||
String,
|
||||
List,
|
||||
Map,
|
||||
Vertex,
|
||||
Edge,
|
||||
UnboundedEdge,
|
||||
Path
|
||||
};
|
||||
enum class Type : unsigned { Null, Bool, Int, Double, String, List, Map, Vertex, Edge, UnboundedEdge, Path };
|
||||
|
||||
// constructors for primitive types
|
||||
Value(bool value) : type_(Type::Bool) { bool_v = value; }
|
||||
@ -146,47 +129,29 @@ class Value {
|
||||
Value(double value) : type_(Type::Double) { double_v = value; }
|
||||
|
||||
// constructors for non-primitive types
|
||||
Value(const std::string &value) : type_(Type::String) {
|
||||
new (&string_v) std::string(value);
|
||||
}
|
||||
Value(const std::string &value) : type_(Type::String) { new (&string_v) std::string(value); }
|
||||
Value(const char *value) : Value(std::string(value)) {}
|
||||
Value(const std::vector<Value> &value) : type_(Type::List) {
|
||||
new (&list_v) std::vector<Value>(value);
|
||||
}
|
||||
Value(const std::vector<Value> &value) : type_(Type::List) { new (&list_v) std::vector<Value>(value); }
|
||||
Value(const std::map<std::string, Value> &value) : type_(Type::Map) {
|
||||
new (&map_v) std::map<std::string, Value>(value);
|
||||
}
|
||||
Value(const Vertex &value) : type_(Type::Vertex) {
|
||||
new (&vertex_v) Vertex(value);
|
||||
}
|
||||
Value(const Vertex &value) : type_(Type::Vertex) { new (&vertex_v) Vertex(value); }
|
||||
Value(const Edge &value) : type_(Type::Edge) { new (&edge_v) Edge(value); }
|
||||
Value(const UnboundedEdge &value) : type_(Type::UnboundedEdge) {
|
||||
new (&unbounded_edge_v) UnboundedEdge(value);
|
||||
}
|
||||
Value(const UnboundedEdge &value) : type_(Type::UnboundedEdge) { new (&unbounded_edge_v) UnboundedEdge(value); }
|
||||
Value(const Path &value) : type_(Type::Path) { new (&path_v) Path(value); }
|
||||
|
||||
// move constructors for non-primitive values
|
||||
Value(std::string &&value) noexcept : type_(Type::String) {
|
||||
new (&string_v) std::string(std::move(value));
|
||||
}
|
||||
Value(std::vector<Value> &&value) noexcept : type_(Type::List) {
|
||||
new (&list_v) std::vector<Value>(std::move(value));
|
||||
}
|
||||
Value(std::string &&value) noexcept : type_(Type::String) { new (&string_v) std::string(std::move(value)); }
|
||||
Value(std::vector<Value> &&value) noexcept : type_(Type::List) { new (&list_v) std::vector<Value>(std::move(value)); }
|
||||
Value(std::map<std::string, Value> &&value) noexcept : type_(Type::Map) {
|
||||
new (&map_v) std::map<std::string, Value>(std::move(value));
|
||||
}
|
||||
Value(Vertex &&value) noexcept : type_(Type::Vertex) {
|
||||
new (&vertex_v) Vertex(std::move(value));
|
||||
}
|
||||
Value(Edge &&value) noexcept : type_(Type::Edge) {
|
||||
new (&edge_v) Edge(std::move(value));
|
||||
}
|
||||
Value(Vertex &&value) noexcept : type_(Type::Vertex) { new (&vertex_v) Vertex(std::move(value)); }
|
||||
Value(Edge &&value) noexcept : type_(Type::Edge) { new (&edge_v) Edge(std::move(value)); }
|
||||
Value(UnboundedEdge &&value) noexcept : type_(Type::UnboundedEdge) {
|
||||
new (&unbounded_edge_v) UnboundedEdge(std::move(value));
|
||||
}
|
||||
Value(Path &&value) noexcept : type_(Type::Path) {
|
||||
new (&path_v) Path(std::move(value));
|
||||
}
|
||||
Value(Path &&value) noexcept : type_(Type::Path) { new (&path_v) Path(std::move(value)); }
|
||||
|
||||
Value &operator=(const Value &other);
|
||||
Value &operator=(Value &&other) noexcept;
|
||||
|
@ -4,8 +4,7 @@
|
||||
|
||||
namespace communication {
|
||||
|
||||
Buffer::Buffer()
|
||||
: data_(kBufferInitialSize, 0), read_end_(this), write_end_(this) {}
|
||||
Buffer::Buffer() : data_(kBufferInitialSize, 0), read_end_(this), write_end_(this) {}
|
||||
|
||||
Buffer::ReadEnd::ReadEnd(Buffer *buffer) : buffer_(buffer) {}
|
||||
|
||||
@ -21,9 +20,7 @@ void Buffer::ReadEnd::Clear() { buffer_->Clear(); }
|
||||
|
||||
Buffer::WriteEnd::WriteEnd(Buffer *buffer) : buffer_(buffer) {}
|
||||
|
||||
io::network::StreamBuffer Buffer::WriteEnd::Allocate() {
|
||||
return buffer_->Allocate();
|
||||
}
|
||||
io::network::StreamBuffer Buffer::WriteEnd::Allocate() { return buffer_->Allocate(); }
|
||||
|
||||
void Buffer::WriteEnd::Written(size_t len) { buffer_->Written(len); }
|
||||
|
||||
|
@ -195,8 +195,7 @@ bool Client::Write(const uint8_t *data, size_t len, bool have_more) {
|
||||
}
|
||||
|
||||
bool Client::Write(const std::string &str, bool have_more) {
|
||||
return Write(reinterpret_cast<const uint8_t *>(str.data()), str.size(),
|
||||
have_more);
|
||||
return Write(reinterpret_cast<const uint8_t *>(str.data()), str.size(), have_more);
|
||||
}
|
||||
|
||||
const io::network::Endpoint &Client::endpoint() { return socket_.endpoint(); }
|
||||
@ -224,12 +223,9 @@ void ClientInputStream::Clear() { client_.ClearData(); }
|
||||
|
||||
ClientOutputStream::ClientOutputStream(Client &client) : client_(client) {}
|
||||
|
||||
bool ClientOutputStream::Write(const uint8_t *data, size_t len,
|
||||
bool have_more) {
|
||||
bool ClientOutputStream::Write(const uint8_t *data, size_t len, bool have_more) {
|
||||
return client_.Write(data, len, have_more);
|
||||
}
|
||||
bool ClientOutputStream::Write(const std::string &str, bool have_more) {
|
||||
return client_.Write(str, have_more);
|
||||
}
|
||||
bool ClientOutputStream::Write(const std::string &str, bool have_more) { return client_.Write(str, have_more); }
|
||||
|
||||
} // namespace communication
|
||||
|
@ -19,21 +19,16 @@ ClientContext::ClientContext(bool use_ssl) : use_ssl_(use_ssl), ctx_(nullptr) {
|
||||
}
|
||||
}
|
||||
|
||||
ClientContext::ClientContext(const std::string &key_file,
|
||||
const std::string &cert_file)
|
||||
: ClientContext(true) {
|
||||
ClientContext::ClientContext(const std::string &key_file, const std::string &cert_file) : ClientContext(true) {
|
||||
if (key_file != "" && cert_file != "") {
|
||||
MG_ASSERT(SSL_CTX_use_certificate_file(ctx_, cert_file.c_str(),
|
||||
SSL_FILETYPE_PEM) == 1,
|
||||
MG_ASSERT(SSL_CTX_use_certificate_file(ctx_, cert_file.c_str(), SSL_FILETYPE_PEM) == 1,
|
||||
"Couldn't load client certificate from file: {}", cert_file);
|
||||
MG_ASSERT(SSL_CTX_use_PrivateKey_file(ctx_, key_file.c_str(),
|
||||
SSL_FILETYPE_PEM) == 1,
|
||||
MG_ASSERT(SSL_CTX_use_PrivateKey_file(ctx_, key_file.c_str(), SSL_FILETYPE_PEM) == 1,
|
||||
"Couldn't load client private key from file: ", key_file);
|
||||
}
|
||||
}
|
||||
|
||||
ClientContext::ClientContext(ClientContext &&other) noexcept
|
||||
: use_ssl_(other.use_ssl_), ctx_(other.ctx_) {
|
||||
ClientContext::ClientContext(ClientContext &&other) noexcept : use_ssl_(other.use_ssl_), ctx_(other.ctx_) {
|
||||
other.use_ssl_ = false;
|
||||
other.ctx_ = nullptr;
|
||||
}
|
||||
@ -69,9 +64,8 @@ bool ClientContext::use_ssl() { return use_ssl_; }
|
||||
|
||||
ServerContext::ServerContext() : use_ssl_(false), ctx_(nullptr) {}
|
||||
|
||||
ServerContext::ServerContext(const std::string &key_file,
|
||||
const std::string &cert_file,
|
||||
const std::string &ca_file, bool verify_peer)
|
||||
ServerContext::ServerContext(const std::string &key_file, const std::string &cert_file, const std::string &ca_file,
|
||||
bool verify_peer)
|
||||
: use_ssl_(true),
|
||||
#if OPENSSL_VERSION_NUMBER < 0x10100000L
|
||||
ctx_(SSL_CTX_new(SSLv23_server_method()))
|
||||
@ -81,11 +75,9 @@ ServerContext::ServerContext(const std::string &key_file,
|
||||
{
|
||||
// TODO (mferencevic): add support for encrypted private keys
|
||||
// TODO (mferencevic): add certificate revocation list (CRL)
|
||||
MG_ASSERT(SSL_CTX_use_certificate_file(ctx_, cert_file.c_str(),
|
||||
SSL_FILETYPE_PEM) == 1,
|
||||
MG_ASSERT(SSL_CTX_use_certificate_file(ctx_, cert_file.c_str(), SSL_FILETYPE_PEM) == 1,
|
||||
"Couldn't load server certificate from file: {}", cert_file);
|
||||
MG_ASSERT(SSL_CTX_use_PrivateKey_file(ctx_, key_file.c_str(),
|
||||
SSL_FILETYPE_PEM) == 1,
|
||||
MG_ASSERT(SSL_CTX_use_PrivateKey_file(ctx_, key_file.c_str(), SSL_FILETYPE_PEM) == 1,
|
||||
"Couldn't load server private key from file: {}", key_file);
|
||||
|
||||
// Disable legacy SSL support. Other options can be seen here:
|
||||
@ -94,29 +86,25 @@ ServerContext::ServerContext(const std::string &key_file,
|
||||
|
||||
if (ca_file != "") {
|
||||
// Load the certificate authority file.
|
||||
MG_ASSERT(
|
||||
SSL_CTX_load_verify_locations(ctx_, ca_file.c_str(), nullptr) == 1,
|
||||
MG_ASSERT(SSL_CTX_load_verify_locations(ctx_, ca_file.c_str(), nullptr) == 1,
|
||||
"Couldn't load certificate authority from file: {}", ca_file);
|
||||
|
||||
if (verify_peer) {
|
||||
// Add the CA to list of accepted CAs that is sent to the client.
|
||||
STACK_OF(X509_NAME) *ca_names = SSL_load_client_CA_file(ca_file.c_str());
|
||||
MG_ASSERT(ca_names != nullptr,
|
||||
"Couldn't load certificate authority from file: {}", ca_file);
|
||||
MG_ASSERT(ca_names != nullptr, "Couldn't load certificate authority from file: {}", ca_file);
|
||||
// `ca_names` doesn' need to be free'd because we pass it to
|
||||
// `SSL_CTX_set_client_CA_list`:
|
||||
// https://mta.openssl.org/pipermail/openssl-users/2015-May/001363.html
|
||||
SSL_CTX_set_client_CA_list(ctx_, ca_names);
|
||||
|
||||
// Enable verification of the client certificate.
|
||||
SSL_CTX_set_verify(
|
||||
ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
|
||||
SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ServerContext::ServerContext(ServerContext &&other) noexcept
|
||||
: use_ssl_(other.use_ssl_), ctx_(other.ctx_) {
|
||||
ServerContext::ServerContext(ServerContext &&other) noexcept : use_ssl_(other.use_ssl_), ctx_(other.ctx_) {
|
||||
other.use_ssl_ = false;
|
||||
other.ctx_ = nullptr;
|
||||
}
|
||||
|
@ -72,8 +72,8 @@ class ServerContext final {
|
||||
* to check that the client certificate is valid, then you need to supply a
|
||||
* valid `ca_file` as well.
|
||||
*/
|
||||
ServerContext(const std::string &key_file, const std::string &cert_file,
|
||||
const std::string &ca_file = "", bool verify_peer = false);
|
||||
ServerContext(const std::string &key_file, const std::string &cert_file, const std::string &ca_file = "",
|
||||
bool verify_peer = false);
|
||||
|
||||
// This object can't be copied because the underlying SSL implementation is
|
||||
// messy and ownership can't be handled correctly.
|
||||
|
@ -28,10 +28,7 @@ void LockingFunction(int mode, int n, const char *file, int line) {
|
||||
}
|
||||
}
|
||||
|
||||
unsigned long IdFunction() {
|
||||
return (unsigned long)std::hash<std::thread::id>()(
|
||||
std::this_thread::get_id());
|
||||
}
|
||||
unsigned long IdFunction() { return (unsigned long)std::hash<std::thread::id>()(std::this_thread::get_id()); }
|
||||
|
||||
void SetupThreading() {
|
||||
crypto_locks.resize(CRYPTO_num_locks());
|
||||
@ -58,8 +55,7 @@ SSLInit::SSLInit() {
|
||||
ERR_load_crypto_strings();
|
||||
|
||||
// Ignore SIGPIPE.
|
||||
MG_ASSERT(utils::SignalIgnore(utils::Signal::Pipe),
|
||||
"Couldn't ignore SIGPIPE!");
|
||||
MG_ASSERT(utils::SignalIgnore(utils::Signal::Pipe), "Couldn't ignore SIGPIPE!");
|
||||
|
||||
SetupThreading();
|
||||
}
|
||||
|
@ -40,8 +40,7 @@ class Listener final {
|
||||
using SessionHandler = Session<TSession, TSessionData>;
|
||||
|
||||
public:
|
||||
Listener(TSessionData *data, ServerContext *context,
|
||||
int inactivity_timeout_sec, const std::string &service_name,
|
||||
Listener(TSessionData *data, ServerContext *context, int inactivity_timeout_sec, const std::string &service_name,
|
||||
size_t workers_count)
|
||||
: data_(data),
|
||||
alive_(false),
|
||||
@ -77,8 +76,8 @@ class Listener final {
|
||||
int fd = connection.fd();
|
||||
|
||||
// Create a new Session for the connection.
|
||||
sessions_.push_back(std::make_unique<SessionHandler>(
|
||||
std::move(connection), data_, context_, inactivity_timeout_sec_));
|
||||
sessions_.push_back(
|
||||
std::make_unique<SessionHandler>(std::move(connection), data_, context_, inactivity_timeout_sec_));
|
||||
|
||||
// Register the connection in Epoll.
|
||||
// We want to listen to an incoming event which is edge triggered and
|
||||
@ -86,8 +85,7 @@ class Listener final {
|
||||
// concurrently and that is why we use `EPOLLONESHOT`, for a detailed
|
||||
// description what are the problems and why this is correct see:
|
||||
// https://idea.popcount.org/2017-02-20-epoll-is-fundamentally-broken-12/
|
||||
epoll_.Add(fd, EPOLLIN | EPOLLET | EPOLLRDHUP | EPOLLONESHOT,
|
||||
sessions_.back().get());
|
||||
epoll_.Add(fd, EPOLLIN | EPOLLET | EPOLLRDHUP | EPOLLONESHOT, sessions_.back().get());
|
||||
}
|
||||
|
||||
/**
|
||||
@ -117,8 +115,7 @@ class Listener final {
|
||||
std::lock_guard<utils::SpinLock> guard(lock_);
|
||||
for (auto &session : sessions_) {
|
||||
if (session->TimedOut()) {
|
||||
spdlog::warn("{} session associated with {} timed out",
|
||||
service_name, session->socket().endpoint());
|
||||
spdlog::warn("{} session associated with {} timed out", service_name, session->socket().endpoint());
|
||||
// Here we shutdown the socket to terminate any leftover
|
||||
// blocking `Write` calls and to signal an event that the
|
||||
// session is closed. Session cleanup will be done in the event
|
||||
@ -178,8 +175,7 @@ class Listener final {
|
||||
// dereference it here. It is safe to dereference the pointer because
|
||||
// this design guarantees that there will never be an event that has
|
||||
// a stale Session pointer.
|
||||
SessionHandler &session =
|
||||
*reinterpret_cast<SessionHandler *>(event.data.ptr);
|
||||
SessionHandler &session = *reinterpret_cast<SessionHandler *>(event.data.ptr);
|
||||
|
||||
// Process epoll events. We use epoll in edge-triggered mode so we process
|
||||
// all events here. Only one of the `if` statements must be executed
|
||||
@ -192,20 +188,16 @@ class Listener final {
|
||||
;
|
||||
} else if (event.events & EPOLLRDHUP) {
|
||||
// The client closed the connection.
|
||||
spdlog::info("{} client {} closed the connection.", service_name_,
|
||||
session.socket().endpoint());
|
||||
spdlog::info("{} client {} closed the connection.", service_name_, session.socket().endpoint());
|
||||
CloseSession(session);
|
||||
} else if (!(event.events & EPOLLIN) ||
|
||||
event.events & (EPOLLHUP | EPOLLERR)) {
|
||||
} else if (!(event.events & EPOLLIN) || event.events & (EPOLLHUP | EPOLLERR)) {
|
||||
// There was an error on the server side.
|
||||
spdlog::error("Error occured in {} session associated with {}",
|
||||
service_name_, session.socket().endpoint());
|
||||
spdlog::error("Error occured in {} session associated with {}", service_name_, session.socket().endpoint());
|
||||
CloseSession(session);
|
||||
} else {
|
||||
// Unhandled epoll event.
|
||||
spdlog::error(
|
||||
"Unhandled event occured in {} session associated with {} events: {}",
|
||||
service_name_, session.socket().endpoint(), event.events);
|
||||
spdlog::error("Unhandled event occured in {} session associated with {} events: {}", service_name_,
|
||||
session.socket().endpoint(), event.events);
|
||||
CloseSession(session);
|
||||
}
|
||||
}
|
||||
@ -215,13 +207,11 @@ class Listener final {
|
||||
if (session.Execute()) {
|
||||
// Session execution done, rearm epoll to send events for this
|
||||
// socket.
|
||||
epoll_.Modify(session.socket().fd(),
|
||||
EPOLLIN | EPOLLET | EPOLLRDHUP | EPOLLONESHOT, &session);
|
||||
epoll_.Modify(session.socket().fd(), EPOLLIN | EPOLLET | EPOLLRDHUP | EPOLLONESHOT, &session);
|
||||
return false;
|
||||
}
|
||||
} catch (const SessionClosedException &e) {
|
||||
spdlog::info("{} client {} closed the connection.", service_name_,
|
||||
session.socket().endpoint());
|
||||
spdlog::info("{} client {} closed the connection.", service_name_, session.socket().endpoint());
|
||||
CloseSession(session);
|
||||
return false;
|
||||
} catch (const std::exception &e) {
|
||||
@ -245,11 +235,9 @@ class Listener final {
|
||||
epoll_.Delete(session.socket().fd());
|
||||
|
||||
std::lock_guard<utils::SpinLock> guard(lock_);
|
||||
auto it = std::find_if(sessions_.begin(), sessions_.end(),
|
||||
[&](const auto &l) { return l.get() == &session; });
|
||||
auto it = std::find_if(sessions_.begin(), sessions_.end(), [&](const auto &l) { return l.get() == &session; });
|
||||
|
||||
MG_ASSERT(it != sessions_.end(),
|
||||
"Trying to remove session that is not found in sessions!");
|
||||
MG_ASSERT(it != sessions_.end(), "Trying to remove session that is not found in sessions!");
|
||||
int i = it - sessions_.begin();
|
||||
swap(sessions_[i], sessions_.back());
|
||||
|
||||
|
@ -27,9 +27,7 @@ class ResultStreamFaker {
|
||||
|
||||
void Header(const std::vector<std::string> &fields) { header_ = fields; }
|
||||
|
||||
void Result(const std::vector<communication::bolt::Value> &values) {
|
||||
results_.push_back(values);
|
||||
}
|
||||
void Result(const std::vector<communication::bolt::Value> &values) { results_.push_back(values); }
|
||||
|
||||
void Result(const std::vector<query::TypedValue> &values) {
|
||||
std::vector<communication::bolt::Value> bvalues;
|
||||
@ -42,16 +40,12 @@ class ResultStreamFaker {
|
||||
results_.push_back(std::move(bvalues));
|
||||
}
|
||||
|
||||
void Summary(
|
||||
const std::map<std::string, communication::bolt::Value> &summary) {
|
||||
summary_ = summary;
|
||||
}
|
||||
void Summary(const std::map<std::string, communication::bolt::Value> &summary) { summary_ = summary; }
|
||||
|
||||
void Summary(const std::map<std::string, query::TypedValue> &summary) {
|
||||
std::map<std::string, communication::bolt::Value> bsummary;
|
||||
for (const auto &item : summary) {
|
||||
auto maybe_value =
|
||||
glue::ToBoltValue(item.second, *store_, storage::View::NEW);
|
||||
auto maybe_value = glue::ToBoltValue(item.second, *store_, storage::View::NEW);
|
||||
MG_ASSERT(maybe_value.HasValue());
|
||||
bsummary.insert({item.first, std::move(*maybe_value)});
|
||||
}
|
||||
@ -64,8 +58,7 @@ class ResultStreamFaker {
|
||||
|
||||
const auto &GetSummary() const { return summary_; }
|
||||
|
||||
friend std::ostream &operator<<(std::ostream &os,
|
||||
const ResultStreamFaker &results) {
|
||||
friend std::ostream &operator<<(std::ostream &os, const ResultStreamFaker &results) {
|
||||
auto decoded_value_to_string = [](const auto &value) {
|
||||
std::stringstream ss;
|
||||
ss << value;
|
||||
@ -73,21 +66,16 @@ class ResultStreamFaker {
|
||||
};
|
||||
const std::vector<std::string> &header = results.GetHeader();
|
||||
std::vector<int> column_widths(header.size());
|
||||
std::transform(header.begin(), header.end(), column_widths.begin(),
|
||||
[](const auto &s) { return s.size(); });
|
||||
std::transform(header.begin(), header.end(), column_widths.begin(), [](const auto &s) { return s.size(); });
|
||||
|
||||
// convert all the results into strings, and track max column width
|
||||
auto &results_data = results.GetResults();
|
||||
std::vector<std::vector<std::string>> result_strings(
|
||||
results_data.size(), std::vector<std::string>(column_widths.size()));
|
||||
for (int row_ind = 0; row_ind < static_cast<int>(results_data.size());
|
||||
++row_ind) {
|
||||
for (int col_ind = 0; col_ind < static_cast<int>(column_widths.size());
|
||||
++col_ind) {
|
||||
std::string string_val =
|
||||
decoded_value_to_string(results_data[row_ind][col_ind]);
|
||||
column_widths[col_ind] =
|
||||
std::max(column_widths[col_ind], (int)string_val.size());
|
||||
std::vector<std::vector<std::string>> result_strings(results_data.size(),
|
||||
std::vector<std::string>(column_widths.size()));
|
||||
for (int row_ind = 0; row_ind < static_cast<int>(results_data.size()); ++row_ind) {
|
||||
for (int col_ind = 0; col_ind < static_cast<int>(column_widths.size()); ++col_ind) {
|
||||
std::string string_val = decoded_value_to_string(results_data[row_ind][col_ind]);
|
||||
column_widths[col_ind] = std::max(column_widths[col_ind], (int)string_val.size());
|
||||
result_strings[row_ind][col_ind] = string_val;
|
||||
}
|
||||
}
|
||||
@ -96,15 +84,13 @@ class ResultStreamFaker {
|
||||
// first define some helper functions
|
||||
auto emit_horizontal_line = [&]() {
|
||||
os << "+";
|
||||
for (auto col_width : column_widths)
|
||||
os << std::string((unsigned long)col_width + 2, '-') << "+";
|
||||
for (auto col_width : column_widths) os << std::string((unsigned long)col_width + 2, '-') << "+";
|
||||
os << std::endl;
|
||||
};
|
||||
|
||||
auto emit_result_vec = [&](const std::vector<std::string> result_vec) {
|
||||
os << "| ";
|
||||
for (int col_ind = 0; col_ind < static_cast<int>(column_widths.size());
|
||||
++col_ind) {
|
||||
for (int col_ind = 0; col_ind < static_cast<int>(column_widths.size()); ++col_ind) {
|
||||
const std::string &res = result_vec[col_ind];
|
||||
os << res << std::string(column_widths[col_ind] - res.size(), ' ');
|
||||
os << " | ";
|
||||
@ -123,9 +109,7 @@ class ResultStreamFaker {
|
||||
// output the summary
|
||||
os << "Query summary: {";
|
||||
utils::PrintIterable(os, results.GetSummary(), ", ",
|
||||
[&](auto &stream, const auto &kv) {
|
||||
stream << kv.first << ": " << kv.second;
|
||||
});
|
||||
[&](auto &stream, const auto &kv) { stream << kv.first << ": " << kv.second; });
|
||||
os << "}" << std::endl;
|
||||
|
||||
return os;
|
||||
|
@ -46,14 +46,12 @@ class Server final {
|
||||
* Constructs and binds server to endpoint, operates on session data and
|
||||
* invokes workers_count workers
|
||||
*/
|
||||
Server(const io::network::Endpoint &endpoint, TSessionData *session_data,
|
||||
ServerContext *context, int inactivity_timeout_sec,
|
||||
const std::string &service_name,
|
||||
Server(const io::network::Endpoint &endpoint, TSessionData *session_data, ServerContext *context,
|
||||
int inactivity_timeout_sec, const std::string &service_name,
|
||||
size_t workers_count = std::thread::hardware_concurrency())
|
||||
: alive_(false),
|
||||
endpoint_(endpoint),
|
||||
listener_(session_data, context, inactivity_timeout_sec, service_name,
|
||||
workers_count),
|
||||
listener_(session_data, context, inactivity_timeout_sec, service_name, workers_count),
|
||||
service_name_(service_name) {}
|
||||
|
||||
~Server() {
|
||||
@ -69,8 +67,7 @@ class Server final {
|
||||
Server &operator=(Server &&) = delete;
|
||||
|
||||
const auto &endpoint() const {
|
||||
MG_ASSERT(alive_,
|
||||
"You can't get the server endpoint when it's not running!");
|
||||
MG_ASSERT(alive_, "You can't get the server endpoint when it's not running!");
|
||||
return socket_.endpoint();
|
||||
}
|
||||
|
||||
@ -138,8 +135,7 @@ class Server final {
|
||||
// Connection is not available anymore or configuration failed.
|
||||
return;
|
||||
}
|
||||
spdlog::info("Accepted a {} connection from {}", service_name_,
|
||||
s->endpoint());
|
||||
spdlog::info("Accepted a {} connection from {}", service_name_, s->endpoint());
|
||||
listener_.AddConnection(std::move(*s));
|
||||
}
|
||||
|
||||
|
@ -35,22 +35,17 @@ using InputStream = Buffer::ReadEnd;
|
||||
*/
|
||||
class OutputStream final {
|
||||
public:
|
||||
OutputStream(
|
||||
std::function<bool(const uint8_t *, size_t, bool)> write_function)
|
||||
: write_function_(write_function) {}
|
||||
OutputStream(std::function<bool(const uint8_t *, size_t, bool)> write_function) : write_function_(write_function) {}
|
||||
|
||||
OutputStream(const OutputStream &) = delete;
|
||||
OutputStream(OutputStream &&) = delete;
|
||||
OutputStream &operator=(const OutputStream &) = delete;
|
||||
OutputStream &operator=(OutputStream &&) = delete;
|
||||
|
||||
bool Write(const uint8_t *data, size_t len, bool have_more = false) {
|
||||
return write_function_(data, len, have_more);
|
||||
}
|
||||
bool Write(const uint8_t *data, size_t len, bool have_more = false) { return write_function_(data, len, have_more); }
|
||||
|
||||
bool Write(const std::string &str, bool have_more = false) {
|
||||
return Write(reinterpret_cast<const uint8_t *>(str.data()), str.size(),
|
||||
have_more);
|
||||
return Write(reinterpret_cast<const uint8_t *>(str.data()), str.size(), have_more);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -65,14 +60,10 @@ class OutputStream final {
|
||||
template <class TSession, class TSessionData>
|
||||
class Session final {
|
||||
public:
|
||||
Session(io::network::Socket &&socket, TSessionData *data,
|
||||
ServerContext *context, int inactivity_timeout_sec)
|
||||
Session(io::network::Socket &&socket, TSessionData *data, ServerContext *context, int inactivity_timeout_sec)
|
||||
: socket_(std::move(socket)),
|
||||
output_stream_([this](const uint8_t *data, size_t len, bool have_more) {
|
||||
return Write(data, len, have_more);
|
||||
}),
|
||||
session_(data, socket_.endpoint(), input_buffer_.read_end(),
|
||||
&output_stream_),
|
||||
output_stream_([this](const uint8_t *data, size_t len, bool have_more) { return Write(data, len, have_more); }),
|
||||
session_(data, socket_.endpoint(), input_buffer_.read_end(), &output_stream_),
|
||||
inactivity_timeout_sec_(inactivity_timeout_sec) {
|
||||
// Set socket options.
|
||||
// The socket is set to be a non-blocking socket. We use the socket in a
|
||||
@ -243,8 +234,7 @@ class Session final {
|
||||
bool TimedOut() {
|
||||
std::unique_lock<utils::SpinLock> guard(lock_);
|
||||
if (execution_active_) return false;
|
||||
return last_event_time_ + std::chrono::seconds(inactivity_timeout_sec_) <
|
||||
std::chrono::steady_clock::now();
|
||||
return last_event_time_ + std::chrono::seconds(inactivity_timeout_sec_) < std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
/**
|
||||
@ -316,8 +306,7 @@ class Session final {
|
||||
TSession session_;
|
||||
|
||||
// Time of the last event and associated lock.
|
||||
std::chrono::time_point<std::chrono::steady_clock> last_event_time_{
|
||||
std::chrono::steady_clock::now()};
|
||||
std::chrono::time_point<std::chrono::steady_clock> last_event_time_{std::chrono::steady_clock::now()};
|
||||
bool execution_active_{false};
|
||||
utils::SpinLock lock_;
|
||||
const int inactivity_timeout_sec_;
|
||||
|
@ -20,9 +20,7 @@
|
||||
template <typename TElement>
|
||||
class RingBuffer {
|
||||
public:
|
||||
explicit RingBuffer(int capacity) : capacity_(capacity) {
|
||||
buffer_ = std::make_unique<TElement[]>(capacity_);
|
||||
}
|
||||
explicit RingBuffer(int capacity) : capacity_(capacity) { buffer_ = std::make_unique<TElement[]>(capacity_); }
|
||||
|
||||
RingBuffer(const RingBuffer &) = delete;
|
||||
RingBuffer(RingBuffer &&) = delete;
|
||||
|
@ -32,34 +32,28 @@ query::TypedValue ToTypedValue(const Value &value) {
|
||||
}
|
||||
case Value::Type::Map: {
|
||||
std::map<std::string, query::TypedValue> map;
|
||||
for (const auto &kv : value.ValueMap())
|
||||
map.emplace(kv.first, ToTypedValue(kv.second));
|
||||
for (const auto &kv : value.ValueMap()) map.emplace(kv.first, ToTypedValue(kv.second));
|
||||
return query::TypedValue(std::move(map));
|
||||
}
|
||||
case Value::Type::Vertex:
|
||||
case Value::Type::Edge:
|
||||
case Value::Type::UnboundedEdge:
|
||||
case Value::Type::Path:
|
||||
throw communication::bolt::ValueException(
|
||||
"Unsupported conversion from Value to TypedValue");
|
||||
throw communication::bolt::ValueException("Unsupported conversion from Value to TypedValue");
|
||||
}
|
||||
}
|
||||
|
||||
storage::Result<communication::bolt::Vertex> ToBoltVertex(
|
||||
const query::VertexAccessor &vertex, const storage::Storage &db,
|
||||
storage::View view) {
|
||||
storage::Result<communication::bolt::Vertex> ToBoltVertex(const query::VertexAccessor &vertex,
|
||||
const storage::Storage &db, storage::View view) {
|
||||
return ToBoltVertex(vertex.impl_, db, view);
|
||||
}
|
||||
|
||||
storage::Result<communication::bolt::Edge> ToBoltEdge(
|
||||
const query::EdgeAccessor &edge, const storage::Storage &db,
|
||||
storage::Result<communication::bolt::Edge> ToBoltEdge(const query::EdgeAccessor &edge, const storage::Storage &db,
|
||||
storage::View view) {
|
||||
return ToBoltEdge(edge.impl_, db, view);
|
||||
}
|
||||
|
||||
storage::Result<Value> ToBoltValue(const query::TypedValue &value,
|
||||
const storage::Storage &db,
|
||||
storage::View view) {
|
||||
storage::Result<Value> ToBoltValue(const query::TypedValue &value, const storage::Storage &db, storage::View view) {
|
||||
switch (value.type()) {
|
||||
case query::TypedValue::Type::Null:
|
||||
return Value();
|
||||
@ -90,20 +84,17 @@ storage::Result<Value> ToBoltValue(const query::TypedValue &value,
|
||||
}
|
||||
return Value(std::move(map));
|
||||
}
|
||||
case query::TypedValue::Type::Vertex:
|
||||
{
|
||||
case query::TypedValue::Type::Vertex: {
|
||||
auto maybe_vertex = ToBoltVertex(value.ValueVertex(), db, view);
|
||||
if (maybe_vertex.HasError()) return maybe_vertex.GetError();
|
||||
return Value(std::move(*maybe_vertex));
|
||||
}
|
||||
case query::TypedValue::Type::Edge:
|
||||
{
|
||||
case query::TypedValue::Type::Edge: {
|
||||
auto maybe_edge = ToBoltEdge(value.ValueEdge(), db, view);
|
||||
if (maybe_edge.HasError()) return maybe_edge.GetError();
|
||||
return Value(std::move(*maybe_edge));
|
||||
}
|
||||
case query::TypedValue::Type::Path:
|
||||
{
|
||||
case query::TypedValue::Type::Path: {
|
||||
auto maybe_path = ToBoltPath(value.ValuePath(), db, view);
|
||||
if (maybe_path.HasError()) return maybe_path.GetError();
|
||||
return Value(std::move(*maybe_path));
|
||||
@ -111,9 +102,8 @@ storage::Result<Value> ToBoltValue(const query::TypedValue &value,
|
||||
}
|
||||
}
|
||||
|
||||
storage::Result<communication::bolt::Vertex> ToBoltVertex(
|
||||
const storage::VertexAccessor &vertex, const storage::Storage &db,
|
||||
storage::View view) {
|
||||
storage::Result<communication::bolt::Vertex> ToBoltVertex(const storage::VertexAccessor &vertex,
|
||||
const storage::Storage &db, storage::View view) {
|
||||
auto id = communication::bolt::Id::FromUint(vertex.Gid().AsUint());
|
||||
auto maybe_labels = vertex.Labels(view);
|
||||
if (maybe_labels.HasError()) return maybe_labels.GetError();
|
||||
@ -131,12 +121,10 @@ storage::Result<communication::bolt::Vertex> ToBoltVertex(
|
||||
return communication::bolt::Vertex{id, labels, properties};
|
||||
}
|
||||
|
||||
storage::Result<communication::bolt::Edge> ToBoltEdge(
|
||||
const storage::EdgeAccessor &edge, const storage::Storage &db,
|
||||
storage::Result<communication::bolt::Edge> ToBoltEdge(const storage::EdgeAccessor &edge, const storage::Storage &db,
|
||||
storage::View view) {
|
||||
auto id = communication::bolt::Id::FromUint(edge.Gid().AsUint());
|
||||
auto from =
|
||||
communication::bolt::Id::FromUint(edge.FromVertex().Gid().AsUint());
|
||||
auto from = communication::bolt::Id::FromUint(edge.FromVertex().Gid().AsUint());
|
||||
auto to = communication::bolt::Id::FromUint(edge.ToVertex().Gid().AsUint());
|
||||
auto type = db.EdgeTypeToName(edge.EdgeType());
|
||||
auto maybe_properties = edge.Properties(view);
|
||||
@ -148,8 +136,8 @@ storage::Result<communication::bolt::Edge> ToBoltEdge(
|
||||
return communication::bolt::Edge{id, from, to, type, properties};
|
||||
}
|
||||
|
||||
storage::Result<communication::bolt::Path> ToBoltPath(
|
||||
const query::Path &path, const storage::Storage &db, storage::View view) {
|
||||
storage::Result<communication::bolt::Path> ToBoltPath(const query::Path &path, const storage::Storage &db,
|
||||
storage::View view) {
|
||||
std::vector<communication::bolt::Vertex> vertices;
|
||||
vertices.reserve(path.vertices().size());
|
||||
for (const auto &v : path.vertices()) {
|
||||
@ -182,22 +170,19 @@ storage::PropertyValue ToPropertyValue(const Value &value) {
|
||||
case Value::Type::List: {
|
||||
std::vector<storage::PropertyValue> vec;
|
||||
vec.reserve(value.ValueList().size());
|
||||
for (const auto &value : value.ValueList())
|
||||
vec.emplace_back(ToPropertyValue(value));
|
||||
for (const auto &value : value.ValueList()) vec.emplace_back(ToPropertyValue(value));
|
||||
return storage::PropertyValue(std::move(vec));
|
||||
}
|
||||
case Value::Type::Map: {
|
||||
std::map<std::string, storage::PropertyValue> map;
|
||||
for (const auto &kv : value.ValueMap())
|
||||
map.emplace(kv.first, ToPropertyValue(kv.second));
|
||||
for (const auto &kv : value.ValueMap()) map.emplace(kv.first, ToPropertyValue(kv.second));
|
||||
return storage::PropertyValue(std::move(map));
|
||||
}
|
||||
case Value::Type::Vertex:
|
||||
case Value::Type::Edge:
|
||||
case Value::Type::UnboundedEdge:
|
||||
case Value::Type::Path:
|
||||
throw communication::bolt::ValueException(
|
||||
"Unsupported conversion from Value to PropertyValue");
|
||||
throw communication::bolt::ValueException("Unsupported conversion from Value to PropertyValue");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -21,17 +21,15 @@ namespace glue {
|
||||
/// @param storage::View for deciding which vertex attributes are visible.
|
||||
///
|
||||
/// @throw std::bad_alloc
|
||||
storage::Result<communication::bolt::Vertex> ToBoltVertex(
|
||||
const storage::VertexAccessor &vertex, const storage::Storage &db,
|
||||
storage::View view);
|
||||
storage::Result<communication::bolt::Vertex> ToBoltVertex(const storage::VertexAccessor &vertex,
|
||||
const storage::Storage &db, storage::View view);
|
||||
|
||||
/// @param storage::EdgeAccessor for converting to communication::bolt::Edge.
|
||||
/// @param storage::Storage for getting edge type and property names.
|
||||
/// @param storage::View for deciding which edge attributes are visible.
|
||||
///
|
||||
/// @throw std::bad_alloc
|
||||
storage::Result<communication::bolt::Edge> ToBoltEdge(
|
||||
const storage::EdgeAccessor &edge, const storage::Storage &db,
|
||||
storage::Result<communication::bolt::Edge> ToBoltEdge(const storage::EdgeAccessor &edge, const storage::Storage &db,
|
||||
storage::View view);
|
||||
|
||||
/// @param query::Path for converting to communication::bolt::Path.
|
||||
@ -39,16 +37,15 @@ storage::Result<communication::bolt::Edge> ToBoltEdge(
|
||||
/// @param storage::View for ToBoltVertex and ToBoltEdge.
|
||||
///
|
||||
/// @throw std::bad_alloc
|
||||
storage::Result<communication::bolt::Path> ToBoltPath(
|
||||
const query::Path &path, const storage::Storage &db, storage::View view);
|
||||
storage::Result<communication::bolt::Path> ToBoltPath(const query::Path &path, const storage::Storage &db,
|
||||
storage::View view);
|
||||
|
||||
/// @param query::TypedValue for converting to communication::bolt::Value.
|
||||
/// @param storage::Storage for ToBoltVertex and ToBoltEdge.
|
||||
/// @param storage::View for ToBoltVertex and ToBoltEdge.
|
||||
///
|
||||
/// @throw std::bad_alloc
|
||||
storage::Result<communication::bolt::Value> ToBoltValue(
|
||||
const query::TypedValue &value, const storage::Storage &db,
|
||||
storage::Result<communication::bolt::Value> ToBoltValue(const query::TypedValue &value, const storage::Storage &db,
|
||||
storage::View view);
|
||||
|
||||
query::TypedValue ToTypedValue(const communication::bolt::Value &value);
|
||||
|
@ -17,16 +17,12 @@
|
||||
inline void LoadConfig(const std::string &product_name) {
|
||||
namespace fs = std::filesystem;
|
||||
std::vector<fs::path> configs = {fs::path("/etc/memgraph/memgraph.conf")};
|
||||
if (getenv("HOME") != nullptr)
|
||||
configs.emplace_back(fs::path(getenv("HOME")) /
|
||||
fs::path(".memgraph/config"));
|
||||
if (getenv("HOME") != nullptr) configs.emplace_back(fs::path(getenv("HOME")) / fs::path(".memgraph/config"));
|
||||
{
|
||||
auto memgraph_config = getenv("MEMGRAPH_CONFIG");
|
||||
if (memgraph_config != nullptr) {
|
||||
auto path = fs::path(memgraph_config);
|
||||
MG_ASSERT(
|
||||
fs::exists(path),
|
||||
"MEMGRAPH_CONFIG environment variable set to nonexisting path: {}",
|
||||
MG_ASSERT(fs::exists(path), "MEMGRAPH_CONFIG environment variable set to nonexisting path: {}",
|
||||
path.generic_string());
|
||||
configs.emplace_back(path);
|
||||
}
|
||||
@ -35,8 +31,7 @@ inline void LoadConfig(const std::string &product_name) {
|
||||
std::vector<std::string> flagfile_arguments;
|
||||
for (const auto &config : configs)
|
||||
if (fs::exists(config)) {
|
||||
flagfile_arguments.emplace_back(
|
||||
std::string("--flag-file=" + config.generic_string()));
|
||||
flagfile_arguments.emplace_back(std::string("--flag-file=" + config.generic_string()));
|
||||
}
|
||||
|
||||
int custom_argc = static_cast<int>(flagfile_arguments.size()) + 1;
|
||||
|
@ -25,10 +25,8 @@ Endpoint::IpFamily Endpoint::GetIpFamily(const std::string &ip_address) {
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<std::pair<std::string, uint16_t>>
|
||||
Endpoint::ParseSocketOrIpAddress(
|
||||
const std::string &address,
|
||||
const std::optional<uint16_t> default_port = {}) {
|
||||
std::optional<std::pair<std::string, uint16_t>> Endpoint::ParseSocketOrIpAddress(
|
||||
const std::string &address, const std::optional<uint16_t> default_port = {}) {
|
||||
/// expected address format:
|
||||
/// - "ip_address:port_number"
|
||||
/// - "ip_address"
|
||||
@ -80,8 +78,7 @@ std::string Endpoint::SocketAddress() const {
|
||||
}
|
||||
|
||||
Endpoint::Endpoint() {}
|
||||
Endpoint::Endpoint(std::string ip_address, uint16_t port)
|
||||
: address(std::move(ip_address)), port(port) {
|
||||
Endpoint::Endpoint(std::string ip_address, uint16_t port) : address(std::move(ip_address)), port(port) {
|
||||
IpFamily ip_family = GetIpFamily(address);
|
||||
if (ip_family == IpFamily::NONE) {
|
||||
throw NetworkError("Not a valid IPv4 or IPv6 address: {}", ip_address);
|
||||
|
@ -21,13 +21,11 @@ class Epoll {
|
||||
public:
|
||||
using Event = struct epoll_event;
|
||||
|
||||
Epoll(bool set_cloexec = false)
|
||||
: epoll_fd_(epoll_create1(set_cloexec ? EPOLL_CLOEXEC : 0)) {
|
||||
Epoll(bool set_cloexec = false) : epoll_fd_(epoll_create1(set_cloexec ? EPOLL_CLOEXEC : 0)) {
|
||||
// epoll_create1 returns an error if there is a logical error in our code
|
||||
// (for example invalid flags) or if there is irrecoverable error. In both
|
||||
// cases it is best to terminate.
|
||||
MG_ASSERT(epoll_fd_ != -1, "Error on epoll create: ({}) {}", errno,
|
||||
strerror(errno));
|
||||
MG_ASSERT(epoll_fd_ != -1, "Error on epoll create: ({}) {}", errno, strerror(errno));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -42,15 +40,13 @@ class Epoll {
|
||||
Event event;
|
||||
event.events = events;
|
||||
event.data.ptr = ptr;
|
||||
int status = epoll_ctl(epoll_fd_, (modify ? EPOLL_CTL_MOD : EPOLL_CTL_ADD),
|
||||
fd, &event);
|
||||
int status = epoll_ctl(epoll_fd_, (modify ? EPOLL_CTL_MOD : EPOLL_CTL_ADD), fd, &event);
|
||||
// epoll_ctl can return an error on our logical error or on irrecoverable
|
||||
// error. There is a third possibility that some system limit is reached. In
|
||||
// that case we could return an erorr and close connection. Chances of
|
||||
// reaching system limit in normally working memgraph is extremely unlikely,
|
||||
// so it is correct to terminate even in that case.
|
||||
MG_ASSERT(!status, "Error on epoll {}: ({}) {}",
|
||||
(modify ? "modify" : "add"), errno, strerror(errno));
|
||||
MG_ASSERT(!status, "Error on epoll {}: ({}) {}", (modify ? "modify" : "add"), errno, strerror(errno));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -60,9 +56,7 @@ class Epoll {
|
||||
* @param events epoll events mask
|
||||
* @param ptr pointer to the associated event handler
|
||||
*/
|
||||
void Modify(int fd, uint32_t events, void *ptr) {
|
||||
Add(fd, events, ptr, true);
|
||||
}
|
||||
void Modify(int fd, uint32_t events, void *ptr) { Add(fd, events, ptr, true); }
|
||||
|
||||
/**
|
||||
* This function deletes a file descriptor that is listened for events.
|
||||
@ -76,8 +70,7 @@ class Epoll {
|
||||
// that case we could return an erorr and close connection. Chances of
|
||||
// reaching system limit in normally working memgraph is extremely unlikely,
|
||||
// so it is correct to terminate even in that case.
|
||||
MG_ASSERT(!status, "Error on epoll delete: ({}) {}", errno,
|
||||
strerror(errno));
|
||||
MG_ASSERT(!status, "Error on epoll delete: ({}) {}", errno, strerror(errno));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -91,8 +84,7 @@ class Epoll {
|
||||
int Wait(Event *events, int max_events, int timeout) {
|
||||
auto num_events = epoll_wait(epoll_fd_, events, max_events, timeout);
|
||||
// If this check fails there was logical error in our code.
|
||||
MG_ASSERT(num_events != -1 || errno == EINTR,
|
||||
"Error on epoll wait: ({}) {}", errno, strerror(errno));
|
||||
MG_ASSERT(num_events != -1 || errno == EINTR, "Error on epoll wait: ({}) {}", errno, strerror(errno));
|
||||
// num_events can be -1 if errno was EINTR (epoll_wait interrupted by signal
|
||||
// handler). We treat that as no events, so we return 0.
|
||||
return num_events == -1 ? 0 : num_events;
|
||||
|
@ -8,4 +8,4 @@ class NetworkError : public utils::StacktraceException {
|
||||
public:
|
||||
using utils::StacktraceException::StacktraceException;
|
||||
};
|
||||
}
|
||||
} // namespace io::network
|
||||
|
@ -59,8 +59,7 @@ bool Socket::IsOpen() const { return socket_ != -1; }
|
||||
bool Socket::Connect(const Endpoint &endpoint) {
|
||||
if (socket_ != -1) return false;
|
||||
|
||||
auto info = AddrInfo::Get(endpoint.address.c_str(),
|
||||
std::to_string(endpoint.port).c_str());
|
||||
auto info = AddrInfo::Get(endpoint.address.c_str(), std::to_string(endpoint.port).c_str());
|
||||
|
||||
for (struct addrinfo *it = info; it != nullptr; it = it->ai_next) {
|
||||
int sfd = socket(it->ai_family, it->ai_socktype, it->ai_protocol);
|
||||
@ -83,8 +82,7 @@ bool Socket::Connect(const Endpoint &endpoint) {
|
||||
bool Socket::Bind(const Endpoint &endpoint) {
|
||||
if (socket_ != -1) return false;
|
||||
|
||||
auto info = AddrInfo::Get(endpoint.address.c_str(),
|
||||
std::to_string(endpoint.port).c_str());
|
||||
auto info = AddrInfo::Get(endpoint.address.c_str(), std::to_string(endpoint.port).c_str());
|
||||
|
||||
for (struct addrinfo *it = info; it != nullptr; it = it->ai_next) {
|
||||
int sfd = socket(it->ai_family, it->ai_socktype, it->ai_protocol);
|
||||
@ -130,38 +128,30 @@ void Socket::SetNonBlocking() {
|
||||
int flags = fcntl(socket_, F_GETFL, 0);
|
||||
MG_ASSERT(flags != -1, "Can't get socket mode");
|
||||
flags |= O_NONBLOCK;
|
||||
MG_ASSERT(fcntl(socket_, F_SETFL, flags) != -1,
|
||||
"Can't set socket nonblocking");
|
||||
MG_ASSERT(fcntl(socket_, F_SETFL, flags) != -1, "Can't set socket nonblocking");
|
||||
}
|
||||
|
||||
void Socket::SetKeepAlive() {
|
||||
int optval = 1;
|
||||
socklen_t optlen = sizeof(optval);
|
||||
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_SOCKET, SO_KEEPALIVE, &optval, optlen),
|
||||
"Can't set socket keep alive");
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_SOCKET, SO_KEEPALIVE, &optval, optlen), "Can't set socket keep alive");
|
||||
|
||||
optval = 20; // wait 20s before sending keep-alive packets
|
||||
MG_ASSERT(
|
||||
!setsockopt(socket_, SOL_TCP, TCP_KEEPIDLE, (void *)&optval, optlen),
|
||||
"Can't set socket keep alive");
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPIDLE, (void *)&optval, optlen), "Can't set socket keep alive");
|
||||
|
||||
optval = 4; // 4 keep-alive packets must fail to close
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPCNT, (void *)&optval, optlen),
|
||||
"Can't set socket keep alive");
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPCNT, (void *)&optval, optlen), "Can't set socket keep alive");
|
||||
|
||||
optval = 15; // send keep-alive packets every 15s
|
||||
MG_ASSERT(
|
||||
!setsockopt(socket_, SOL_TCP, TCP_KEEPINTVL, (void *)&optval, optlen),
|
||||
"Can't set socket keep alive");
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_KEEPINTVL, (void *)&optval, optlen), "Can't set socket keep alive");
|
||||
}
|
||||
|
||||
void Socket::SetNoDelay() {
|
||||
int optval = 1;
|
||||
socklen_t optlen = sizeof(optval);
|
||||
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_NODELAY, (void *)&optval, optlen),
|
||||
"Can't set socket no delay");
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_TCP, TCP_NODELAY, (void *)&optval, optlen), "Can't set socket no delay");
|
||||
}
|
||||
|
||||
void Socket::SetTimeout(long sec, long usec) {
|
||||
@ -169,11 +159,9 @@ void Socket::SetTimeout(long sec, long usec) {
|
||||
tv.tv_sec = sec;
|
||||
tv.tv_usec = usec;
|
||||
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
|
||||
"Can't set socket timeout");
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), "Can't set socket timeout");
|
||||
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
|
||||
"Can't set socket timeout");
|
||||
MG_ASSERT(!setsockopt(socket_, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), "Can't set socket timeout");
|
||||
}
|
||||
|
||||
int Socket::ErrorStatus() const {
|
||||
@ -238,8 +226,7 @@ bool Socket::Write(const uint8_t *data, size_t len, bool have_more) {
|
||||
}
|
||||
|
||||
bool Socket::Write(const std::string &s, bool have_more) {
|
||||
return Write(reinterpret_cast<const uint8_t *>(s.data()), s.size(),
|
||||
have_more);
|
||||
return Write(reinterpret_cast<const uint8_t *>(s.data()), s.size(), have_more);
|
||||
}
|
||||
|
||||
ssize_t Socket::Read(void *buffer, size_t len, bool nonblock) {
|
||||
|
@ -14,4 +14,4 @@ struct StreamBuffer {
|
||||
uint8_t *data;
|
||||
size_t len;
|
||||
};
|
||||
}
|
||||
} // namespace io::network
|
||||
|
@ -23,9 +23,8 @@ std::string ResolveHostname(std::string hostname) {
|
||||
|
||||
int addr_result;
|
||||
addrinfo *servinfo;
|
||||
MG_ASSERT((addr_result =
|
||||
getaddrinfo(hostname.c_str(), NULL, &hints, &servinfo)) == 0,
|
||||
"Error with getaddrinfo: {}", gai_strerror(addr_result));
|
||||
MG_ASSERT((addr_result = getaddrinfo(hostname.c_str(), NULL, &hints, &servinfo)) == 0, "Error with getaddrinfo: {}",
|
||||
gai_strerror(addr_result));
|
||||
MG_ASSERT(servinfo, "Could not resolve address: {}", hostname);
|
||||
|
||||
std::string address;
|
||||
|
@ -12,18 +12,16 @@ struct KVStore::impl {
|
||||
rocksdb::Options options;
|
||||
};
|
||||
|
||||
KVStore::KVStore(std::filesystem::path storage)
|
||||
: pimpl_(std::make_unique<impl>()) {
|
||||
KVStore::KVStore(std::filesystem::path storage) : pimpl_(std::make_unique<impl>()) {
|
||||
pimpl_->storage = storage;
|
||||
if (!utils::EnsureDir(pimpl_->storage))
|
||||
throw KVStoreError("Folder for the key-value store " +
|
||||
pimpl_->storage.string() + " couldn't be initialized!");
|
||||
throw KVStoreError("Folder for the key-value store " + pimpl_->storage.string() + " couldn't be initialized!");
|
||||
pimpl_->options.create_if_missing = true;
|
||||
rocksdb::DB *db = nullptr;
|
||||
auto s = rocksdb::DB::Open(pimpl_->options, storage.c_str(), &db);
|
||||
if (!s.ok())
|
||||
throw KVStoreError("RocksDB couldn't be initialized inside " +
|
||||
storage.string() + " -- " + std::string(s.ToString()));
|
||||
throw KVStoreError("RocksDB couldn't be initialized inside " + storage.string() + " -- " +
|
||||
std::string(s.ToString()));
|
||||
pimpl_->db.reset(db);
|
||||
}
|
||||
|
||||
@ -72,18 +70,15 @@ bool KVStore::DeleteMultiple(const std::vector<std::string> &keys) {
|
||||
}
|
||||
|
||||
bool KVStore::DeletePrefix(const std::string &prefix) {
|
||||
std::unique_ptr<rocksdb::Iterator> iter = std::unique_ptr<rocksdb::Iterator>(
|
||||
pimpl_->db->NewIterator(rocksdb::ReadOptions()));
|
||||
for (iter->Seek(prefix); iter->Valid() && iter->key().starts_with(prefix);
|
||||
iter->Next()) {
|
||||
if (!pimpl_->db->Delete(rocksdb::WriteOptions(), iter->key()).ok())
|
||||
return false;
|
||||
std::unique_ptr<rocksdb::Iterator> iter =
|
||||
std::unique_ptr<rocksdb::Iterator>(pimpl_->db->NewIterator(rocksdb::ReadOptions()));
|
||||
for (iter->Seek(prefix); iter->Valid() && iter->key().starts_with(prefix); iter->Next()) {
|
||||
if (!pimpl_->db->Delete(rocksdb::WriteOptions(), iter->key()).ok()) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool KVStore::PutAndDeleteMultiple(
|
||||
const std::map<std::string, std::string> &items,
|
||||
bool KVStore::PutAndDeleteMultiple(const std::map<std::string, std::string> &items,
|
||||
const std::vector<std::string> &keys) {
|
||||
rocksdb::WriteBatch batch;
|
||||
for (const auto &item : items) {
|
||||
@ -105,22 +100,16 @@ struct KVStore::iterator::impl {
|
||||
std::pair<std::string, std::string> disk_prop;
|
||||
};
|
||||
|
||||
KVStore::iterator::iterator(const KVStore *kvstore, const std::string &prefix,
|
||||
bool at_end)
|
||||
KVStore::iterator::iterator(const KVStore *kvstore, const std::string &prefix, bool at_end)
|
||||
: pimpl_(std::make_unique<impl>()) {
|
||||
pimpl_->kvstore = kvstore;
|
||||
pimpl_->prefix = prefix;
|
||||
pimpl_->it = std::unique_ptr<rocksdb::Iterator>(
|
||||
pimpl_->kvstore->pimpl_->db->NewIterator(rocksdb::ReadOptions()));
|
||||
pimpl_->it = std::unique_ptr<rocksdb::Iterator>(pimpl_->kvstore->pimpl_->db->NewIterator(rocksdb::ReadOptions()));
|
||||
pimpl_->it->Seek(pimpl_->prefix);
|
||||
if (!pimpl_->it->Valid() || !pimpl_->it->key().starts_with(pimpl_->prefix) ||
|
||||
at_end)
|
||||
pimpl_->it = nullptr;
|
||||
if (!pimpl_->it->Valid() || !pimpl_->it->key().starts_with(pimpl_->prefix) || at_end) pimpl_->it = nullptr;
|
||||
}
|
||||
|
||||
KVStore::iterator::iterator(KVStore::iterator &&other) {
|
||||
pimpl_ = std::move(other.pimpl_);
|
||||
}
|
||||
KVStore::iterator::iterator(KVStore::iterator &&other) { pimpl_ = std::move(other.pimpl_); }
|
||||
|
||||
KVStore::iterator::~iterator() {}
|
||||
|
||||
@ -131,24 +120,19 @@ KVStore::iterator &KVStore::iterator::operator=(KVStore::iterator &&other) {
|
||||
|
||||
KVStore::iterator &KVStore::iterator::operator++() {
|
||||
pimpl_->it->Next();
|
||||
if (!pimpl_->it->Valid() || !pimpl_->it->key().starts_with(pimpl_->prefix))
|
||||
pimpl_->it = nullptr;
|
||||
if (!pimpl_->it->Valid() || !pimpl_->it->key().starts_with(pimpl_->prefix)) pimpl_->it = nullptr;
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool KVStore::iterator::operator==(const iterator &other) const {
|
||||
return pimpl_->kvstore == other.pimpl_->kvstore &&
|
||||
pimpl_->prefix == other.pimpl_->prefix &&
|
||||
return pimpl_->kvstore == other.pimpl_->kvstore && pimpl_->prefix == other.pimpl_->prefix &&
|
||||
pimpl_->it == other.pimpl_->it;
|
||||
}
|
||||
|
||||
bool KVStore::iterator::operator!=(const iterator &other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
bool KVStore::iterator::operator!=(const iterator &other) const { return !(*this == other); }
|
||||
|
||||
KVStore::iterator::reference KVStore::iterator::operator*() {
|
||||
pimpl_->disk_prop = {pimpl_->it->key().ToString(),
|
||||
pimpl_->it->value().ToString()};
|
||||
pimpl_->disk_prop = {pimpl_->it->key().ToString(), pimpl_->it->value().ToString()};
|
||||
return pimpl_->disk_prop;
|
||||
}
|
||||
|
||||
@ -166,8 +150,7 @@ size_t KVStore::Size(const std::string &prefix) {
|
||||
return size;
|
||||
}
|
||||
|
||||
bool KVStore::CompactRange(const std::string &begin_prefix,
|
||||
const std::string &end_prefix) {
|
||||
bool KVStore::CompactRange(const std::string &begin_prefix, const std::string &end_prefix) {
|
||||
rocksdb::CompactRangeOptions options;
|
||||
rocksdb::Slice begin(begin_prefix);
|
||||
rocksdb::Slice end(end_prefix);
|
||||
|
@ -114,8 +114,7 @@ class KVStore final {
|
||||
* @return true if the items have been successfully stored and deleted.
|
||||
* In case of any error false is going to be returned.
|
||||
*/
|
||||
bool PutAndDeleteMultiple(const std::map<std::string, std::string> &items,
|
||||
const std::vector<std::string> &keys);
|
||||
bool PutAndDeleteMultiple(const std::map<std::string, std::string> &items, const std::vector<std::string> &keys);
|
||||
|
||||
/**
|
||||
* Returns total number of stored (key, value) pairs. The function takes an
|
||||
@ -140,8 +139,7 @@ class KVStore final {
|
||||
*
|
||||
* @return - true if the compaction finished successfully, false otherwise.
|
||||
*/
|
||||
bool CompactRange(const std::string &begin_prefix,
|
||||
const std::string &end_prefix);
|
||||
bool CompactRange(const std::string &begin_prefix, const std::string &end_prefix);
|
||||
|
||||
/**
|
||||
* Custom prefix-based iterator over kvstore.
|
||||
@ -150,17 +148,14 @@ class KVStore final {
|
||||
* and behaves as if all of those pairs are stored in a single iterable
|
||||
* collection of std::pair<std::string, std::string>.
|
||||
*/
|
||||
class iterator final
|
||||
: public std::iterator<
|
||||
std::input_iterator_tag, // iterator_category
|
||||
class iterator final : public std::iterator<std::input_iterator_tag, // iterator_category
|
||||
std::pair<std::string, std::string>, // value_type
|
||||
long, // difference_type
|
||||
const std::pair<std::string, std::string> *, // pointer
|
||||
const std::pair<std::string, std::string> & // reference
|
||||
> {
|
||||
public:
|
||||
explicit iterator(const KVStore *kvstore, const std::string &prefix = "",
|
||||
bool at_end = false);
|
||||
explicit iterator(const KVStore *kvstore, const std::string &prefix = "", bool at_end = false);
|
||||
|
||||
iterator(const iterator &other) = delete;
|
||||
|
||||
@ -191,13 +186,9 @@ class KVStore final {
|
||||
std::unique_ptr<impl> pimpl_;
|
||||
};
|
||||
|
||||
iterator begin(const std::string &prefix = "") {
|
||||
return iterator(this, prefix);
|
||||
}
|
||||
iterator begin(const std::string &prefix = "") { return iterator(this, prefix); }
|
||||
|
||||
iterator end(const std::string &prefix = "") {
|
||||
return iterator(this, prefix, true);
|
||||
}
|
||||
iterator end(const std::string &prefix = "") { return iterator(this, prefix, true); }
|
||||
|
||||
private:
|
||||
struct impl;
|
||||
|
@ -26,8 +26,7 @@ std::optional<std::string> KVStore::Get(const std::string &key) const noexcept {
|
||||
}
|
||||
|
||||
bool KVStore::Delete(const std::string &key) {
|
||||
LOG_FATAL(
|
||||
"Unsupported operation (KVStore::Delete) -- this is a dummy kvstore");
|
||||
LOG_FATAL("Unsupported operation (KVStore::Delete) -- this is a dummy kvstore");
|
||||
}
|
||||
|
||||
bool KVStore::DeleteMultiple(const std::vector<std::string> &keys) {
|
||||
@ -42,8 +41,7 @@ bool KVStore::DeletePrefix(const std::string &prefix) {
|
||||
"dummy kvstore");
|
||||
}
|
||||
|
||||
bool KVStore::PutAndDeleteMultiple(
|
||||
const std::map<std::string, std::string> &items,
|
||||
bool KVStore::PutAndDeleteMultiple(const std::map<std::string, std::string> &items,
|
||||
const std::vector<std::string> &keys) {
|
||||
LOG_FATAL(
|
||||
"Unsupported operation (KVStore::PutAndDeleteMultiple) -- this is a "
|
||||
@ -54,13 +52,9 @@ bool KVStore::PutAndDeleteMultiple(
|
||||
|
||||
struct KVStore::iterator::impl {};
|
||||
|
||||
KVStore::iterator::iterator(const KVStore *kvstore, const std::string &prefix,
|
||||
bool at_end)
|
||||
: pimpl_(new impl()) {}
|
||||
KVStore::iterator::iterator(const KVStore *kvstore, const std::string &prefix, bool at_end) : pimpl_(new impl()) {}
|
||||
|
||||
KVStore::iterator::iterator(KVStore::iterator &&other) {
|
||||
pimpl_ = std::move(other.pimpl_);
|
||||
}
|
||||
KVStore::iterator::iterator(KVStore::iterator &&other) { pimpl_ = std::move(other.pimpl_); }
|
||||
|
||||
KVStore::iterator::~iterator() {}
|
||||
|
||||
@ -77,9 +71,7 @@ KVStore::iterator &KVStore::iterator::operator++() {
|
||||
|
||||
bool KVStore::iterator::operator==(const iterator &other) const { return true; }
|
||||
|
||||
bool KVStore::iterator::operator!=(const iterator &other) const {
|
||||
return false;
|
||||
}
|
||||
bool KVStore::iterator::operator!=(const iterator &other) const { return false; }
|
||||
|
||||
KVStore::iterator::reference KVStore::iterator::operator*() {
|
||||
LOG_FATAL(
|
||||
@ -99,8 +91,7 @@ bool KVStore::iterator::IsValid() { return false; }
|
||||
|
||||
size_t KVStore::Size(const std::string &prefix) { return 0; }
|
||||
|
||||
bool KVStore::CompactRange(const std::string &begin_prefix,
|
||||
const std::string &end_prefix) {
|
||||
bool KVStore::CompactRange(const std::string &begin_prefix, const std::string &end_prefix) {
|
||||
LOG_FATAL(
|
||||
"Unsupported operation (KVStore::Compact) -- this is a "
|
||||
"dummy kvstore");
|
||||
|
365
src/memgraph.cpp
365
src/memgraph.cpp
@ -63,25 +63,19 @@
|
||||
#endif
|
||||
|
||||
// Bolt server flags.
|
||||
DEFINE_string(bolt_address, "0.0.0.0",
|
||||
"IP address on which the Bolt server should listen.");
|
||||
DEFINE_VALIDATED_int32(bolt_port, 7687,
|
||||
"Port on which the Bolt server should listen.",
|
||||
DEFINE_string(bolt_address, "0.0.0.0", "IP address on which the Bolt server should listen.");
|
||||
DEFINE_VALIDATED_int32(bolt_port, 7687, "Port on which the Bolt server should listen.",
|
||||
FLAG_IN_RANGE(0, std::numeric_limits<uint16_t>::max()));
|
||||
DEFINE_VALIDATED_int32(
|
||||
bolt_num_workers, std::max(std::thread::hardware_concurrency(), 1U),
|
||||
DEFINE_VALIDATED_int32(bolt_num_workers, std::max(std::thread::hardware_concurrency(), 1U),
|
||||
"Number of workers used by the Bolt server. By default, this will be the "
|
||||
"number of processing units available on the machine.",
|
||||
FLAG_IN_RANGE(1, INT32_MAX));
|
||||
DEFINE_VALIDATED_int32(
|
||||
bolt_session_inactivity_timeout, 1800,
|
||||
DEFINE_VALIDATED_int32(bolt_session_inactivity_timeout, 1800,
|
||||
"Time in seconds after which inactive Bolt sessions will be "
|
||||
"closed.",
|
||||
FLAG_IN_RANGE(1, INT32_MAX));
|
||||
DEFINE_string(bolt_cert_file, "",
|
||||
"Certificate file which should be used for the Bolt server.");
|
||||
DEFINE_string(bolt_key_file, "",
|
||||
"Key file which should be used for the Bolt server.");
|
||||
DEFINE_string(bolt_cert_file, "", "Certificate file which should be used for the Bolt server.");
|
||||
DEFINE_string(bolt_key_file, "", "Key file which should be used for the Bolt server.");
|
||||
DEFINE_string(bolt_server_name_for_init, "",
|
||||
"Server name which the database should send to the client in the "
|
||||
"Bolt INIT message.");
|
||||
@ -89,26 +83,20 @@ DEFINE_string(bolt_server_name_for_init, "",
|
||||
// General purpose flags.
|
||||
// NOTE: The `data_directory` flag must be the same here and in
|
||||
// `mg_import_csv`. If you change it, make sure to change it there as well.
|
||||
DEFINE_string(data_directory, "mg_data",
|
||||
"Path to directory in which to save all permanent data.");
|
||||
DEFINE_HIDDEN_string(
|
||||
log_link_basename, "",
|
||||
"Basename used for symlink creation to the last log file.");
|
||||
DEFINE_string(data_directory, "mg_data", "Path to directory in which to save all permanent data.");
|
||||
DEFINE_HIDDEN_string(log_link_basename, "", "Basename used for symlink creation to the last log file.");
|
||||
DEFINE_uint64(memory_warning_threshold, 1024,
|
||||
"Memory warning threshold, in MB. If Memgraph detects there is "
|
||||
"less available RAM it will log a warning. Set to 0 to "
|
||||
"disable.");
|
||||
|
||||
// Storage flags.
|
||||
DEFINE_VALIDATED_uint64(storage_gc_cycle_sec, 30,
|
||||
"Storage garbage collector interval (in seconds).",
|
||||
DEFINE_VALIDATED_uint64(storage_gc_cycle_sec, 30, "Storage garbage collector interval (in seconds).",
|
||||
FLAG_IN_RANGE(1, 24 * 3600));
|
||||
// NOTE: The `storage_properties_on_edges` flag must be the same here and in
|
||||
// `mg_import_csv`. If you change it, make sure to change it there as well.
|
||||
DEFINE_bool(storage_properties_on_edges, false,
|
||||
"Controls whether edges have properties.");
|
||||
DEFINE_bool(storage_recover_on_startup, false,
|
||||
"Controls whether the storage recovers persisted data on startup.");
|
||||
DEFINE_bool(storage_properties_on_edges, false, "Controls whether edges have properties.");
|
||||
DEFINE_bool(storage_recover_on_startup, false, "Controls whether the storage recovers persisted data on startup.");
|
||||
DEFINE_VALIDATED_uint64(storage_snapshot_interval_sec, 0,
|
||||
"Storage snapshot creation interval (in seconds). Set "
|
||||
"to 0 to disable periodic snapshot creation.",
|
||||
@ -116,21 +104,15 @@ DEFINE_VALIDATED_uint64(storage_snapshot_interval_sec, 0,
|
||||
DEFINE_bool(storage_wal_enabled, false,
|
||||
"Controls whether the storage uses write-ahead-logging. To enable "
|
||||
"WAL periodic snapshots must be enabled.");
|
||||
DEFINE_VALIDATED_uint64(storage_snapshot_retention_count, 3,
|
||||
"The number of snapshots that should always be kept.",
|
||||
DEFINE_VALIDATED_uint64(storage_snapshot_retention_count, 3, "The number of snapshots that should always be kept.",
|
||||
FLAG_IN_RANGE(1, 1000000));
|
||||
DEFINE_VALIDATED_uint64(storage_wal_file_size_kib,
|
||||
storage::Config::Durability().wal_file_size_kibibytes,
|
||||
"Minimum file size of each WAL file.",
|
||||
FLAG_IN_RANGE(1, 1000 * 1024));
|
||||
DEFINE_VALIDATED_uint64(
|
||||
storage_wal_file_flush_every_n_tx,
|
||||
storage::Config::Durability().wal_file_flush_every_n_tx,
|
||||
DEFINE_VALIDATED_uint64(storage_wal_file_size_kib, storage::Config::Durability().wal_file_size_kibibytes,
|
||||
"Minimum file size of each WAL file.", FLAG_IN_RANGE(1, 1000 * 1024));
|
||||
DEFINE_VALIDATED_uint64(storage_wal_file_flush_every_n_tx, storage::Config::Durability().wal_file_flush_every_n_tx,
|
||||
"Issue a 'fsync' call after this amount of transactions are written to the "
|
||||
"WAL file. Set to 1 for fully synchronous operation.",
|
||||
FLAG_IN_RANGE(1, 1000000));
|
||||
DEFINE_bool(storage_snapshot_on_exit, false,
|
||||
"Controls whether the storage creates another snapshot on exit.");
|
||||
DEFINE_bool(storage_snapshot_on_exit, false, "Controls whether the storage creates another snapshot on exit.");
|
||||
|
||||
DEFINE_bool(telemetry_enabled, false,
|
||||
"Set to true to enable telemetry. We collect information about the "
|
||||
@ -141,11 +123,9 @@ DEFINE_bool(telemetry_enabled, false,
|
||||
// Audit logging flags.
|
||||
#ifdef MG_ENTERPRISE
|
||||
DEFINE_bool(audit_enabled, false, "Set to true to enable audit logging.");
|
||||
DEFINE_VALIDATED_int32(audit_buffer_size, audit::kBufferSizeDefault,
|
||||
"Maximum number of items in the audit log buffer.",
|
||||
DEFINE_VALIDATED_int32(audit_buffer_size, audit::kBufferSizeDefault, "Maximum number of items in the audit log buffer.",
|
||||
FLAG_IN_RANGE(1, INT32_MAX));
|
||||
DEFINE_VALIDATED_int32(
|
||||
audit_buffer_flush_interval_ms, audit::kBufferFlushIntervalMillisDefault,
|
||||
DEFINE_VALIDATED_int32(audit_buffer_flush_interval_ms, audit::kBufferFlushIntervalMillisDefault,
|
||||
"Interval (in milliseconds) used for flushing the audit log buffer.",
|
||||
FLAG_IN_RANGE(10, INT32_MAX));
|
||||
#endif
|
||||
@ -155,41 +135,34 @@ DEFINE_uint64(query_execution_timeout_sec, 180,
|
||||
"Maximum allowed query execution time. Queries exceeding this "
|
||||
"limit will be aborted. Value of 0 means no limit.");
|
||||
|
||||
DEFINE_VALIDATED_string(
|
||||
query_modules_directory, "",
|
||||
"Directory where modules with custom query procedures are stored.", {
|
||||
DEFINE_VALIDATED_string(query_modules_directory, "", "Directory where modules with custom query procedures are stored.",
|
||||
{
|
||||
if (value.empty()) return true;
|
||||
if (utils::DirExists(value)) return true;
|
||||
std::cout << "Expected --" << flagname << " to point to a directory."
|
||||
<< std::endl;
|
||||
std::cout << "Expected --" << flagname << " to point to a directory." << std::endl;
|
||||
return false;
|
||||
});
|
||||
|
||||
// Logging flags
|
||||
DEFINE_bool(also_log_to_stderr, false,
|
||||
"Log messages go to stderr in addition to logfiles");
|
||||
DEFINE_bool(also_log_to_stderr, false, "Log messages go to stderr in addition to logfiles");
|
||||
DEFINE_string(log_file, "", "Path to where the log should be stored.");
|
||||
|
||||
namespace {
|
||||
constexpr std::array log_level_mappings{
|
||||
std::pair{"TRACE", spdlog::level::trace},
|
||||
std::pair{"DEBUG", spdlog::level::debug},
|
||||
std::pair{"INFO", spdlog::level::info},
|
||||
std::pair{"WARNING", spdlog::level::warn},
|
||||
std::pair{"ERROR", spdlog::level::err},
|
||||
std::pair{"CRITICAL", spdlog::level::critical}};
|
||||
std::pair{"TRACE", spdlog::level::trace}, std::pair{"DEBUG", spdlog::level::debug},
|
||||
std::pair{"INFO", spdlog::level::info}, std::pair{"WARNING", spdlog::level::warn},
|
||||
std::pair{"ERROR", spdlog::level::err}, std::pair{"CRITICAL", spdlog::level::critical}};
|
||||
|
||||
std::string GetAllowedLogLevelsString() {
|
||||
std::vector<std::string> allowed_log_levels;
|
||||
allowed_log_levels.reserve(log_level_mappings.size());
|
||||
std::transform(log_level_mappings.cbegin(), log_level_mappings.cend(),
|
||||
std::back_inserter(allowed_log_levels),
|
||||
std::transform(log_level_mappings.cbegin(), log_level_mappings.cend(), std::back_inserter(allowed_log_levels),
|
||||
[](const auto &mapping) { return mapping.first; });
|
||||
return utils::Join(allowed_log_levels, ", ");
|
||||
}
|
||||
|
||||
const std::string log_level_help_string = fmt::format(
|
||||
"Minimum log level. Allowed values: {}", GetAllowedLogLevelsString());
|
||||
const std::string log_level_help_string =
|
||||
fmt::format("Minimum log level. Allowed values: {}", GetAllowedLogLevelsString());
|
||||
} // namespace
|
||||
|
||||
DEFINE_VALIDATED_string(log_level, "WARNING", log_level_help_string.c_str(), {
|
||||
@ -199,11 +172,8 @@ DEFINE_VALIDATED_string(log_level, "WARNING", log_level_help_string.c_str(), {
|
||||
}
|
||||
|
||||
if (std::find_if(log_level_mappings.cbegin(), log_level_mappings.cend(),
|
||||
[&](const auto &mapping) {
|
||||
return mapping.first == value;
|
||||
}) == log_level_mappings.cend()) {
|
||||
std::cout << "Invalid value for log level. Allowed values: "
|
||||
<< GetAllowedLogLevelsString() << std::endl;
|
||||
[&](const auto &mapping) { return mapping.first == value; }) == log_level_mappings.cend()) {
|
||||
std::cout << "Invalid value for log level. Allowed values: " << GetAllowedLogLevelsString() << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -212,8 +182,7 @@ DEFINE_VALIDATED_string(log_level, "WARNING", log_level_help_string.c_str(), {
|
||||
|
||||
namespace {
|
||||
void ParseLogLevel() {
|
||||
const auto mapping_iter = std::find_if(
|
||||
log_level_mappings.cbegin(), log_level_mappings.cend(),
|
||||
const auto mapping_iter = std::find_if(log_level_mappings.cbegin(), log_level_mappings.cend(),
|
||||
[](const auto &mapping) { return mapping.first == FLAGS_log_level; });
|
||||
MG_ASSERT(mapping_iter != log_level_mappings.cend(), "Invalid log level");
|
||||
|
||||
@ -227,8 +196,7 @@ void ConfigureLogging() {
|
||||
std::vector<spdlog::sink_ptr> loggers;
|
||||
|
||||
if (FLAGS_also_log_to_stderr) {
|
||||
loggers.emplace_back(
|
||||
std::make_shared<spdlog::sinks::stderr_color_sink_mt>());
|
||||
loggers.emplace_back(std::make_shared<spdlog::sinks::stderr_color_sink_mt>());
|
||||
}
|
||||
|
||||
if (!FLAGS_log_file.empty()) {
|
||||
@ -240,12 +208,10 @@ void ConfigureLogging() {
|
||||
local_time = localtime(¤t_time);
|
||||
|
||||
loggers.emplace_back(std::make_shared<spdlog::sinks::daily_file_sink_mt>(
|
||||
FLAGS_log_file, local_time->tm_hour, local_time->tm_min, false,
|
||||
log_retention_count));
|
||||
FLAGS_log_file, local_time->tm_hour, local_time->tm_min, false, log_retention_count));
|
||||
}
|
||||
|
||||
spdlog::set_default_logger(std::make_shared<spdlog::logger>(
|
||||
"memgraph_log", loggers.begin(), loggers.end()));
|
||||
spdlog::set_default_logger(std::make_shared<spdlog::logger>("memgraph_log", loggers.begin(), loggers.end()));
|
||||
|
||||
spdlog::flush_on(spdlog::level::trace);
|
||||
ParseLogLevel();
|
||||
@ -258,13 +224,9 @@ void ConfigureLogging() {
|
||||
struct SessionData {
|
||||
// Explicit constructor here to ensure that pointers to all objects are
|
||||
// supplied.
|
||||
SessionData(storage::Storage *db,
|
||||
query::InterpreterContext *interpreter_context, auth::Auth *auth,
|
||||
SessionData(storage::Storage *db, query::InterpreterContext *interpreter_context, auth::Auth *auth,
|
||||
audit::Log *audit_log)
|
||||
: db(db),
|
||||
interpreter_context(interpreter_context),
|
||||
auth(auth),
|
||||
audit_log(audit_log) {}
|
||||
: db(db), interpreter_context(interpreter_context), auth(auth), audit_log(audit_log) {}
|
||||
storage::Storage *db;
|
||||
query::InterpreterContext *interpreter_context;
|
||||
auth::Auth *auth;
|
||||
@ -274,24 +236,19 @@ struct SessionData {
|
||||
struct SessionData {
|
||||
// Explicit constructor here to ensure that pointers to all objects are
|
||||
// supplied.
|
||||
SessionData(storage::Storage *db,
|
||||
query::InterpreterContext *interpreter_context)
|
||||
SessionData(storage::Storage *db, query::InterpreterContext *interpreter_context)
|
||||
: db(db), interpreter_context(interpreter_context) {}
|
||||
storage::Storage *db;
|
||||
query::InterpreterContext *interpreter_context;
|
||||
};
|
||||
#endif
|
||||
|
||||
class BoltSession final
|
||||
: public communication::bolt::Session<communication::InputStream,
|
||||
communication::OutputStream> {
|
||||
class BoltSession final : public communication::bolt::Session<communication::InputStream, communication::OutputStream> {
|
||||
public:
|
||||
BoltSession(SessionData *data, const io::network::Endpoint &endpoint,
|
||||
communication::InputStream *input_stream,
|
||||
BoltSession(SessionData *data, const io::network::Endpoint &endpoint, communication::InputStream *input_stream,
|
||||
communication::OutputStream *output_stream)
|
||||
: communication::bolt::Session<communication::InputStream,
|
||||
communication::OutputStream>(
|
||||
input_stream, output_stream),
|
||||
: communication::bolt::Session<communication::InputStream, communication::OutputStream>(input_stream,
|
||||
output_stream),
|
||||
db_(data->db),
|
||||
interpreter_(data->interpreter_context),
|
||||
#ifdef MG_ENTERPRISE
|
||||
@ -301,8 +258,7 @@ class BoltSession final
|
||||
endpoint_(endpoint) {
|
||||
}
|
||||
|
||||
using communication::bolt::Session<communication::InputStream,
|
||||
communication::OutputStream>::TEncoder;
|
||||
using communication::bolt::Session<communication::InputStream, communication::OutputStream>::TEncoder;
|
||||
|
||||
void BeginTransaction() override { interpreter_.BeginTransaction(); }
|
||||
|
||||
@ -311,15 +267,11 @@ class BoltSession final
|
||||
void RollbackTransaction() override { interpreter_.RollbackTransaction(); }
|
||||
|
||||
std::pair<std::vector<std::string>, std::optional<int>> Interpret(
|
||||
const std::string &query,
|
||||
const std::map<std::string, communication::bolt::Value> ¶ms)
|
||||
override {
|
||||
const std::string &query, const std::map<std::string, communication::bolt::Value> ¶ms) override {
|
||||
std::map<std::string, storage::PropertyValue> params_pv;
|
||||
for (const auto &kv : params)
|
||||
params_pv.emplace(kv.first, glue::ToPropertyValue(kv.second));
|
||||
for (const auto &kv : params) params_pv.emplace(kv.first, glue::ToPropertyValue(kv.second));
|
||||
#ifdef MG_ENTERPRISE
|
||||
audit_log_->Record(endpoint_.address, user_ ? user_->username() : "", query,
|
||||
storage::PropertyValue(params_pv));
|
||||
audit_log_->Record(endpoint_.address, user_ ? user_->username() : "", query, storage::PropertyValue(params_pv));
|
||||
#endif
|
||||
try {
|
||||
auto result = interpreter_.Prepare(query, params_pv);
|
||||
@ -327,8 +279,7 @@ class BoltSession final
|
||||
if (user_) {
|
||||
const auto &permissions = user_->GetPermissions();
|
||||
for (const auto &privilege : result.privileges) {
|
||||
if (permissions.Has(glue::PrivilegeToPermission(privilege)) !=
|
||||
auth::PermissionLevel::GRANT) {
|
||||
if (permissions.Has(glue::PrivilegeToPermission(privilege)) != auth::PermissionLevel::GRANT) {
|
||||
interpreter_.Abort();
|
||||
throw communication::bolt::ClientError(
|
||||
"You are not authorized to execute this query! Please contact "
|
||||
@ -346,23 +297,20 @@ class BoltSession final
|
||||
}
|
||||
}
|
||||
|
||||
std::map<std::string, communication::bolt::Value> Pull(
|
||||
TEncoder *encoder, std::optional<int> n,
|
||||
std::map<std::string, communication::bolt::Value> Pull(TEncoder *encoder, std::optional<int> n,
|
||||
std::optional<int> qid) override {
|
||||
TypedValueResultStream stream(encoder, db_);
|
||||
return PullResults(stream, n, qid);
|
||||
}
|
||||
|
||||
std::map<std::string, communication::bolt::Value> Discard(
|
||||
std::optional<int> n, std::optional<int> qid) override {
|
||||
std::map<std::string, communication::bolt::Value> Discard(std::optional<int> n, std::optional<int> qid) override {
|
||||
DiscardValueResultStream stream;
|
||||
return PullResults(stream, n, qid);
|
||||
}
|
||||
|
||||
void Abort() override { interpreter_.Abort(); }
|
||||
|
||||
bool Authenticate(const std::string &username,
|
||||
const std::string &password) override {
|
||||
bool Authenticate(const std::string &username, const std::string &password) override {
|
||||
#ifdef MG_ENTERPRISE
|
||||
if (!auth_->HasUsers()) return true;
|
||||
user_ = auth_->Authenticate(username, password);
|
||||
@ -379,14 +327,13 @@ class BoltSession final
|
||||
|
||||
private:
|
||||
template <typename TStream>
|
||||
std::map<std::string, communication::bolt::Value> PullResults(
|
||||
TStream &stream, std::optional<int> n, std::optional<int> qid) {
|
||||
std::map<std::string, communication::bolt::Value> PullResults(TStream &stream, std::optional<int> n,
|
||||
std::optional<int> qid) {
|
||||
try {
|
||||
const auto &summary = interpreter_.Pull(&stream, n, qid);
|
||||
std::map<std::string, communication::bolt::Value> decoded_summary;
|
||||
for (const auto &kv : summary) {
|
||||
auto maybe_value =
|
||||
glue::ToBoltValue(kv.second, *db_, storage::View::NEW);
|
||||
auto maybe_value = glue::ToBoltValue(kv.second, *db_, storage::View::NEW);
|
||||
if (maybe_value.HasError()) {
|
||||
switch (maybe_value.GetError()) {
|
||||
case storage::Error::DELETED_OBJECT:
|
||||
@ -394,8 +341,7 @@ class BoltSession final
|
||||
case storage::Error::VERTEX_HAS_EDGES:
|
||||
case storage::Error::PROPERTIES_DISABLED:
|
||||
case storage::Error::NONEXISTENT_OBJECT:
|
||||
throw communication::bolt::ClientError(
|
||||
"Unexpected storage error when streaming summary.");
|
||||
throw communication::bolt::ClientError("Unexpected storage error when streaming summary.");
|
||||
}
|
||||
}
|
||||
decoded_summary.emplace(kv.first, std::move(*maybe_value));
|
||||
@ -412,8 +358,7 @@ class BoltSession final
|
||||
/// before forwarding the calls to original TEncoder.
|
||||
class TypedValueResultStream {
|
||||
public:
|
||||
TypedValueResultStream(TEncoder *encoder, const storage::Storage *db)
|
||||
: encoder_(encoder), db_(db) {}
|
||||
TypedValueResultStream(TEncoder *encoder, const storage::Storage *db) : encoder_(encoder), db_(db) {}
|
||||
|
||||
void Result(const std::vector<query::TypedValue> &values) {
|
||||
std::vector<communication::bolt::Value> decoded_values;
|
||||
@ -423,16 +368,13 @@ class BoltSession final
|
||||
if (maybe_value.HasError()) {
|
||||
switch (maybe_value.GetError()) {
|
||||
case storage::Error::DELETED_OBJECT:
|
||||
throw communication::bolt::ClientError(
|
||||
"Returning a deleted object as a result.");
|
||||
throw communication::bolt::ClientError("Returning a deleted object as a result.");
|
||||
case storage::Error::NONEXISTENT_OBJECT:
|
||||
throw communication::bolt::ClientError(
|
||||
"Returning a nonexistent object as a result.");
|
||||
throw communication::bolt::ClientError("Returning a nonexistent object as a result.");
|
||||
case storage::Error::VERTEX_HAS_EDGES:
|
||||
case storage::Error::SERIALIZATION_ERROR:
|
||||
case storage::Error::PROPERTIES_DISABLED:
|
||||
throw communication::bolt::ClientError(
|
||||
"Unexpected storage error when streaming results.");
|
||||
throw communication::bolt::ClientError("Unexpected storage error when streaming results.");
|
||||
}
|
||||
}
|
||||
decoded_values.emplace_back(std::move(*maybe_value));
|
||||
@ -467,8 +409,7 @@ using ServerT = communication::Server<BoltSession, SessionData>;
|
||||
using communication::ServerContext;
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
DEFINE_string(
|
||||
auth_user_or_role_name_regex, "[a-zA-Z0-9_.+-@]+",
|
||||
DEFINE_string(auth_user_or_role_name_regex, "[a-zA-Z0-9_.+-@]+",
|
||||
"Set to the regular expression that each user or role name must fulfill.");
|
||||
|
||||
class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
@ -476,11 +417,9 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
std::regex name_regex_;
|
||||
|
||||
public:
|
||||
AuthQueryHandler(auth::Auth *auth, const std::regex &name_regex)
|
||||
: auth_(auth), name_regex_(name_regex) {}
|
||||
AuthQueryHandler(auth::Auth *auth, const std::regex &name_regex) : auth_(auth), name_regex_(name_regex) {}
|
||||
|
||||
bool CreateUser(const std::string &username,
|
||||
const std::optional<std::string> &password) override {
|
||||
bool CreateUser(const std::string &username, const std::optional<std::string> &password) override {
|
||||
if (!std::regex_match(username, name_regex_)) {
|
||||
throw query::QueryRuntimeException("Invalid user name.");
|
||||
}
|
||||
@ -506,8 +445,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
}
|
||||
}
|
||||
|
||||
void SetPassword(const std::string &username,
|
||||
const std::optional<std::string> &password) override {
|
||||
void SetPassword(const std::string &username, const std::optional<std::string> &password) override {
|
||||
if (!std::regex_match(username, name_regex_)) {
|
||||
throw query::QueryRuntimeException("Invalid user name.");
|
||||
}
|
||||
@ -515,8 +453,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
std::lock_guard<std::mutex> lock(auth_->WithLock());
|
||||
auto user = auth_->GetUser(username);
|
||||
if (!user) {
|
||||
throw query::QueryRuntimeException("User '{}' doesn't exist.",
|
||||
username);
|
||||
throw query::QueryRuntimeException("User '{}' doesn't exist.", username);
|
||||
}
|
||||
user->UpdatePassword(password);
|
||||
auth_->SaveUser(*user);
|
||||
@ -581,8 +518,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<std::string> GetRolenameForUser(
|
||||
const std::string &username) override {
|
||||
std::optional<std::string> GetRolenameForUser(const std::string &username) override {
|
||||
if (!std::regex_match(username, name_regex_)) {
|
||||
throw query::QueryRuntimeException("Invalid user name.");
|
||||
}
|
||||
@ -590,8 +526,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
std::lock_guard<std::mutex> lock(auth_->WithLock());
|
||||
auto user = auth_->GetUser(username);
|
||||
if (!user) {
|
||||
throw query::QueryRuntimeException("User '{}' doesn't exist .",
|
||||
username);
|
||||
throw query::QueryRuntimeException("User '{}' doesn't exist .", username);
|
||||
}
|
||||
if (user->role()) return user->role()->rolename();
|
||||
return std::nullopt;
|
||||
@ -600,8 +535,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<query::TypedValue> GetUsernamesForRole(
|
||||
const std::string &rolename) override {
|
||||
std::vector<query::TypedValue> GetUsernamesForRole(const std::string &rolename) override {
|
||||
if (!std::regex_match(rolename, name_regex_)) {
|
||||
throw query::QueryRuntimeException("Invalid role name.");
|
||||
}
|
||||
@ -609,8 +543,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
std::lock_guard<std::mutex> lock(auth_->WithLock());
|
||||
auto role = auth_->GetRole(rolename);
|
||||
if (!role) {
|
||||
throw query::QueryRuntimeException("Role '{}' doesn't exist.",
|
||||
rolename);
|
||||
throw query::QueryRuntimeException("Role '{}' doesn't exist.", rolename);
|
||||
}
|
||||
std::vector<query::TypedValue> usernames;
|
||||
const auto &users = auth_->AllUsersForRole(rolename);
|
||||
@ -624,8 +557,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
}
|
||||
}
|
||||
|
||||
void SetRole(const std::string &username,
|
||||
const std::string &rolename) override {
|
||||
void SetRole(const std::string &username, const std::string &rolename) override {
|
||||
if (!std::regex_match(username, name_regex_)) {
|
||||
throw query::QueryRuntimeException("Invalid user name.");
|
||||
}
|
||||
@ -636,17 +568,14 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
std::lock_guard<std::mutex> lock(auth_->WithLock());
|
||||
auto user = auth_->GetUser(username);
|
||||
if (!user) {
|
||||
throw query::QueryRuntimeException("User '{}' doesn't exist .",
|
||||
username);
|
||||
throw query::QueryRuntimeException("User '{}' doesn't exist .", username);
|
||||
}
|
||||
auto role = auth_->GetRole(rolename);
|
||||
if (!role) {
|
||||
throw query::QueryRuntimeException("Role '{}' doesn't exist .",
|
||||
rolename);
|
||||
throw query::QueryRuntimeException("Role '{}' doesn't exist .", rolename);
|
||||
}
|
||||
if (user->role()) {
|
||||
throw query::QueryRuntimeException(
|
||||
"User '{}' is already a member of role '{}'.", username,
|
||||
throw query::QueryRuntimeException("User '{}' is already a member of role '{}'.", username,
|
||||
user->role()->rolename());
|
||||
}
|
||||
user->SetRole(*role);
|
||||
@ -664,8 +593,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
std::lock_guard<std::mutex> lock(auth_->WithLock());
|
||||
auto user = auth_->GetUser(username);
|
||||
if (!user) {
|
||||
throw query::QueryRuntimeException("User '{}' doesn't exist .",
|
||||
username);
|
||||
throw query::QueryRuntimeException("User '{}' doesn't exist .", username);
|
||||
}
|
||||
user->ClearRole();
|
||||
auth_->SaveUser(*user);
|
||||
@ -674,8 +602,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<query::TypedValue>> GetPrivileges(
|
||||
const std::string &user_or_role) override {
|
||||
std::vector<std::vector<query::TypedValue>> GetPrivileges(const std::string &user_or_role) override {
|
||||
if (!std::regex_match(user_or_role, name_regex_)) {
|
||||
throw query::QueryRuntimeException("Invalid user or role name.");
|
||||
}
|
||||
@ -685,8 +612,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
auto user = auth_->GetUser(user_or_role);
|
||||
auto role = auth_->GetRole(user_or_role);
|
||||
if (!user && !role) {
|
||||
throw query::QueryRuntimeException("User or role '{}' doesn't exist.",
|
||||
user_or_role);
|
||||
throw query::QueryRuntimeException("User or role '{}' doesn't exist.", user_or_role);
|
||||
}
|
||||
if (user) {
|
||||
const auto &permissions = user->GetPermissions();
|
||||
@ -709,8 +635,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
description.emplace_back("DENIED TO ROLE");
|
||||
}
|
||||
}
|
||||
grants.push_back(
|
||||
{query::TypedValue(auth::PermissionToString(permission)),
|
||||
grants.push_back({query::TypedValue(auth::PermissionToString(permission)),
|
||||
query::TypedValue(auth::PermissionLevelToString(effective)),
|
||||
query::TypedValue(utils::Join(description, ", "))});
|
||||
}
|
||||
@ -727,8 +652,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
} else if (effective == auth::PermissionLevel::DENY) {
|
||||
description = "DENIED TO ROLE";
|
||||
}
|
||||
grants.push_back(
|
||||
{query::TypedValue(auth::PermissionToString(permission)),
|
||||
grants.push_back({query::TypedValue(auth::PermissionToString(permission)),
|
||||
query::TypedValue(auth::PermissionLevelToString(effective)),
|
||||
query::TypedValue(description)});
|
||||
}
|
||||
@ -740,11 +664,9 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
}
|
||||
}
|
||||
|
||||
void GrantPrivilege(
|
||||
const std::string &user_or_role,
|
||||
void GrantPrivilege(const std::string &user_or_role,
|
||||
const std::vector<query::AuthQuery::Privilege> &privileges) override {
|
||||
EditPermissions(user_or_role, privileges,
|
||||
[](auto *permissions, const auto &permission) {
|
||||
EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
|
||||
// TODO (mferencevic): should we first check that the
|
||||
// privilege is granted/denied/revoked before
|
||||
// unconditionally granting/denying/revoking it?
|
||||
@ -752,11 +674,9 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
});
|
||||
}
|
||||
|
||||
void DenyPrivilege(
|
||||
const std::string &user_or_role,
|
||||
void DenyPrivilege(const std::string &user_or_role,
|
||||
const std::vector<query::AuthQuery::Privilege> &privileges) override {
|
||||
EditPermissions(user_or_role, privileges,
|
||||
[](auto *permissions, const auto &permission) {
|
||||
EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
|
||||
// TODO (mferencevic): should we first check that the
|
||||
// privilege is granted/denied/revoked before
|
||||
// unconditionally granting/denying/revoking it?
|
||||
@ -764,11 +684,9 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
});
|
||||
}
|
||||
|
||||
void RevokePrivilege(
|
||||
const std::string &user_or_role,
|
||||
void RevokePrivilege(const std::string &user_or_role,
|
||||
const std::vector<query::AuthQuery::Privilege> &privileges) override {
|
||||
EditPermissions(user_or_role, privileges,
|
||||
[](auto *permissions, const auto &permission) {
|
||||
EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
|
||||
// TODO (mferencevic): should we first check that the
|
||||
// privilege is granted/denied/revoked before
|
||||
// unconditionally granting/denying/revoking it?
|
||||
@ -778,9 +696,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
|
||||
private:
|
||||
template <class TEditFun>
|
||||
void EditPermissions(
|
||||
const std::string &user_or_role,
|
||||
const std::vector<query::AuthQuery::Privilege> &privileges,
|
||||
void EditPermissions(const std::string &user_or_role, const std::vector<query::AuthQuery::Privilege> &privileges,
|
||||
const TEditFun &edit_fun) {
|
||||
if (!std::regex_match(user_or_role, name_regex_)) {
|
||||
throw query::QueryRuntimeException("Invalid user or role name.");
|
||||
@ -795,8 +711,7 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
auto user = auth_->GetUser(user_or_role);
|
||||
auto role = auth_->GetRole(user_or_role);
|
||||
if (!user && !role) {
|
||||
throw query::QueryRuntimeException("User or role '{}' doesn't exist.",
|
||||
user_or_role);
|
||||
throw query::QueryRuntimeException("User or role '{}' doesn't exist.", user_or_role);
|
||||
}
|
||||
if (user) {
|
||||
for (const auto &permission : permissions) {
|
||||
@ -818,71 +733,44 @@ class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
class NoAuthInCommunity : public query::QueryRuntimeException {
|
||||
public:
|
||||
NoAuthInCommunity()
|
||||
: query::QueryRuntimeException::QueryRuntimeException(
|
||||
"Auth is not supported in Memgraph Community!") {}
|
||||
: query::QueryRuntimeException::QueryRuntimeException("Auth is not supported in Memgraph Community!") {}
|
||||
};
|
||||
|
||||
class AuthQueryHandler final : public query::AuthQueryHandler {
|
||||
public:
|
||||
bool CreateUser(const std::string &,
|
||||
const std::optional<std::string> &) override {
|
||||
throw NoAuthInCommunity();
|
||||
}
|
||||
bool CreateUser(const std::string &, const std::optional<std::string> &) override { throw NoAuthInCommunity(); }
|
||||
|
||||
bool DropUser(const std::string &) override { throw NoAuthInCommunity(); }
|
||||
|
||||
void SetPassword(const std::string &,
|
||||
const std::optional<std::string> &) override {
|
||||
throw NoAuthInCommunity();
|
||||
}
|
||||
void SetPassword(const std::string &, const std::optional<std::string> &) override { throw NoAuthInCommunity(); }
|
||||
|
||||
bool CreateRole(const std::string &) override { throw NoAuthInCommunity(); }
|
||||
|
||||
bool DropRole(const std::string &) override { throw NoAuthInCommunity(); }
|
||||
|
||||
std::vector<query::TypedValue> GetUsernames() override {
|
||||
throw NoAuthInCommunity();
|
||||
}
|
||||
std::vector<query::TypedValue> GetUsernames() override { throw NoAuthInCommunity(); }
|
||||
|
||||
std::vector<query::TypedValue> GetRolenames() override {
|
||||
throw NoAuthInCommunity();
|
||||
}
|
||||
std::vector<query::TypedValue> GetRolenames() override { throw NoAuthInCommunity(); }
|
||||
|
||||
std::optional<std::string> GetRolenameForUser(const std::string &) override {
|
||||
throw NoAuthInCommunity();
|
||||
}
|
||||
std::optional<std::string> GetRolenameForUser(const std::string &) override { throw NoAuthInCommunity(); }
|
||||
|
||||
std::vector<query::TypedValue> GetUsernamesForRole(
|
||||
const std::string &) override {
|
||||
throw NoAuthInCommunity();
|
||||
}
|
||||
std::vector<query::TypedValue> GetUsernamesForRole(const std::string &) override { throw NoAuthInCommunity(); }
|
||||
|
||||
void SetRole(const std::string &, const std::string &) override {
|
||||
throw NoAuthInCommunity();
|
||||
}
|
||||
void SetRole(const std::string &, const std::string &) override { throw NoAuthInCommunity(); }
|
||||
|
||||
void ClearRole(const std::string &) override { throw NoAuthInCommunity(); }
|
||||
|
||||
std::vector<std::vector<query::TypedValue>> GetPrivileges(
|
||||
const std::string &) override {
|
||||
std::vector<std::vector<query::TypedValue>> GetPrivileges(const std::string &) override { throw NoAuthInCommunity(); }
|
||||
|
||||
void GrantPrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override {
|
||||
throw NoAuthInCommunity();
|
||||
}
|
||||
|
||||
void GrantPrivilege(
|
||||
const std::string &,
|
||||
const std::vector<query::AuthQuery::Privilege> &) override {
|
||||
void DenyPrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override {
|
||||
throw NoAuthInCommunity();
|
||||
}
|
||||
|
||||
void DenyPrivilege(
|
||||
const std::string &,
|
||||
const std::vector<query::AuthQuery::Privilege> &) override {
|
||||
throw NoAuthInCommunity();
|
||||
}
|
||||
|
||||
void RevokePrivilege(
|
||||
const std::string &,
|
||||
const std::vector<query::AuthQuery::Privilege> &) override {
|
||||
void RevokePrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override {
|
||||
throw NoAuthInCommunity();
|
||||
}
|
||||
};
|
||||
@ -911,11 +799,9 @@ void InitSignalHandlers(const std::function<void()> &shutdown_fun) {
|
||||
shutdown_fun();
|
||||
};
|
||||
|
||||
MG_ASSERT(utils::SignalHandler::RegisterHandler(
|
||||
utils::Signal::Terminate, shutdown, block_shutdown_signals),
|
||||
MG_ASSERT(utils::SignalHandler::RegisterHandler(utils::Signal::Terminate, shutdown, block_shutdown_signals),
|
||||
"Unable to register SIGTERM handler!");
|
||||
MG_ASSERT(utils::SignalHandler::RegisterHandler(
|
||||
utils::Signal::Interupt, shutdown, block_shutdown_signals),
|
||||
MG_ASSERT(utils::SignalHandler::RegisterHandler(utils::Signal::Interupt, shutdown, block_shutdown_signals),
|
||||
"Unable to register SIGINT handler!");
|
||||
}
|
||||
|
||||
@ -952,13 +838,10 @@ int main(int argc, char **argv) {
|
||||
auto gil = py::EnsureGIL();
|
||||
auto maybe_exc = py::AppendToSysPath(py_support_dir.c_str());
|
||||
if (maybe_exc) {
|
||||
spdlog::error("Unable to load support for embedded Python: {}",
|
||||
*maybe_exc);
|
||||
spdlog::error("Unable to load support for embedded Python: {}", *maybe_exc);
|
||||
}
|
||||
} else {
|
||||
spdlog::error(
|
||||
"Unable to load support for embedded Python: missing directory {}",
|
||||
py_support_dir);
|
||||
spdlog::error("Unable to load support for embedded Python: missing directory {}", py_support_dir);
|
||||
}
|
||||
} catch (const std::filesystem::filesystem_error &e) {
|
||||
spdlog::error("Unable to load support for embedded Python: {}", e.what());
|
||||
@ -978,8 +861,7 @@ int main(int argc, char **argv) {
|
||||
mem_log_scheduler.Run("Memory warning", std::chrono::seconds(3), [] {
|
||||
auto free_ram = utils::sysinfo::AvailableMemoryKilobytes();
|
||||
if (free_ram && *free_ram / 1024 < FLAGS_memory_warning_threshold)
|
||||
spdlog::warn("Running out of available RAM, only {} MB left",
|
||||
*free_ram / 1024);
|
||||
spdlog::warn("Running out of available RAM, only {} MB left", *free_ram / 1024);
|
||||
});
|
||||
} else {
|
||||
// Kernel version for the `MemAvailable` value is from: man procfs
|
||||
@ -990,8 +872,7 @@ int main(int argc, char **argv) {
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "You are running Memgraph v" << gflags::VersionString()
|
||||
<< std::endl;
|
||||
std::cout << "You are running Memgraph v" << gflags::VersionString() << std::endl;
|
||||
|
||||
auto data_directory = std::filesystem::path(FLAGS_data_directory);
|
||||
|
||||
@ -1011,17 +892,14 @@ int main(int argc, char **argv) {
|
||||
auth::Auth auth{data_directory / "auth"};
|
||||
|
||||
// Audit log
|
||||
audit::Log audit_log{data_directory / "audit", FLAGS_audit_buffer_size,
|
||||
FLAGS_audit_buffer_flush_interval_ms};
|
||||
audit::Log audit_log{data_directory / "audit", FLAGS_audit_buffer_size, FLAGS_audit_buffer_flush_interval_ms};
|
||||
// Start the log if enabled.
|
||||
if (FLAGS_audit_enabled) {
|
||||
audit_log.Start();
|
||||
}
|
||||
// Setup SIGUSR2 to be used for reopening audit log files, when e.g. logrotate
|
||||
// rotates our audit logs.
|
||||
MG_ASSERT(
|
||||
utils::SignalHandler::RegisterHandler(
|
||||
utils::Signal::User2, [&audit_log]() { audit_log.ReopenLog(); }),
|
||||
MG_ASSERT(utils::SignalHandler::RegisterHandler(utils::Signal::User2, [&audit_log]() { audit_log.ReopenLog(); }),
|
||||
"Unable to register SIGUSR2 handler!");
|
||||
|
||||
// End enterprise features initialization
|
||||
@ -1030,11 +908,9 @@ int main(int argc, char **argv) {
|
||||
// Main storage and execution engines initialization
|
||||
|
||||
storage::Config db_config{
|
||||
.gc = {.type = storage::Config::Gc::Type::PERIODIC,
|
||||
.interval = std::chrono::seconds(FLAGS_storage_gc_cycle_sec)},
|
||||
.gc = {.type = storage::Config::Gc::Type::PERIODIC, .interval = std::chrono::seconds(FLAGS_storage_gc_cycle_sec)},
|
||||
.items = {.properties_on_edges = FLAGS_storage_properties_on_edges},
|
||||
.durability = {
|
||||
.storage_directory = FLAGS_data_directory,
|
||||
.durability = {.storage_directory = FLAGS_data_directory,
|
||||
.recover_on_startup = FLAGS_storage_recover_on_startup,
|
||||
.snapshot_retention_count = FLAGS_storage_snapshot_retention_count,
|
||||
.wal_file_size_kibibytes = FLAGS_storage_wal_file_size_kib,
|
||||
@ -1046,38 +922,31 @@ int main(int argc, char **argv) {
|
||||
"In order to use write-ahead-logging you must enable "
|
||||
"periodic snapshots by setting the snapshot interval to a "
|
||||
"value larger than 0!");
|
||||
db_config.durability.snapshot_wal_mode =
|
||||
storage::Config::Durability::SnapshotWalMode::DISABLED;
|
||||
db_config.durability.snapshot_wal_mode = storage::Config::Durability::SnapshotWalMode::DISABLED;
|
||||
}
|
||||
} else {
|
||||
if (FLAGS_storage_wal_enabled) {
|
||||
db_config.durability.snapshot_wal_mode = storage::Config::Durability::
|
||||
SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL;
|
||||
db_config.durability.snapshot_wal_mode = storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL;
|
||||
} else {
|
||||
db_config.durability.snapshot_wal_mode =
|
||||
storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT;
|
||||
db_config.durability.snapshot_wal_mode = storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT;
|
||||
}
|
||||
db_config.durability.snapshot_interval =
|
||||
std::chrono::seconds(FLAGS_storage_snapshot_interval_sec);
|
||||
db_config.durability.snapshot_interval = std::chrono::seconds(FLAGS_storage_snapshot_interval_sec);
|
||||
}
|
||||
storage::Storage db(db_config);
|
||||
query::InterpreterContext interpreter_context{&db};
|
||||
|
||||
query::SetExecutionTimeout(&interpreter_context,
|
||||
FLAGS_query_execution_timeout_sec);
|
||||
query::SetExecutionTimeout(&interpreter_context, FLAGS_query_execution_timeout_sec);
|
||||
#ifdef MG_ENTERPRISE
|
||||
SessionData session_data{&db, &interpreter_context, &auth, &audit_log};
|
||||
#else
|
||||
SessionData session_data{&db, &interpreter_context};
|
||||
#endif
|
||||
|
||||
query::procedure::gModuleRegistry.SetModulesDirectory(
|
||||
FLAGS_query_modules_directory);
|
||||
query::procedure::gModuleRegistry.SetModulesDirectory(FLAGS_query_modules_directory);
|
||||
query::procedure::gModuleRegistry.UnloadAndLoadModulesFromDirectory();
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
AuthQueryHandler auth_handler(&auth,
|
||||
std::regex(FLAGS_auth_user_or_role_name_regex));
|
||||
AuthQueryHandler auth_handler(&auth, std::regex(FLAGS_auth_user_or_role_name_regex));
|
||||
#else
|
||||
AuthQueryHandler auth_handler;
|
||||
#endif
|
||||
@ -1093,15 +962,13 @@ int main(int argc, char **argv) {
|
||||
spdlog::warn("Using non-secure Bolt connection (without SSL)");
|
||||
}
|
||||
|
||||
ServerT server({FLAGS_bolt_address, static_cast<uint16_t>(FLAGS_bolt_port)},
|
||||
&session_data, &context, FLAGS_bolt_session_inactivity_timeout,
|
||||
service_name, FLAGS_bolt_num_workers);
|
||||
ServerT server({FLAGS_bolt_address, static_cast<uint16_t>(FLAGS_bolt_port)}, &session_data, &context,
|
||||
FLAGS_bolt_session_inactivity_timeout, service_name, FLAGS_bolt_num_workers);
|
||||
|
||||
// Setup telemetry
|
||||
std::optional<telemetry::Telemetry> telemetry;
|
||||
if (FLAGS_telemetry_enabled) {
|
||||
telemetry.emplace(
|
||||
"https://telemetry.memgraph.com/88b5e7e8-746a-11e8-9f85-538a9e9690cc/",
|
||||
telemetry.emplace("https://telemetry.memgraph.com/88b5e7e8-746a-11e8-9f85-538a9e9690cc/",
|
||||
data_directory / "telemetry", std::chrono::minutes(10));
|
||||
telemetry->AddCollector("db", [&db]() -> nlohmann::json {
|
||||
auto info = db.GetInfo();
|
||||
|
@ -42,30 +42,22 @@ bool ValidateIdTypeOptions(const char *flagname, const std::string &value) {
|
||||
// They are used to automatically load the same configuration as the main
|
||||
// Memgraph binary so that the flags don't need to be specified when importing a
|
||||
// CSV file on a correctly set-up Memgraph installation.
|
||||
DEFINE_string(data_directory, "mg_data",
|
||||
"Path to directory in which to save all permanent data.");
|
||||
DEFINE_bool(storage_properties_on_edges, false,
|
||||
"Controls whether relationships have properties.");
|
||||
DEFINE_string(data_directory, "mg_data", "Path to directory in which to save all permanent data.");
|
||||
DEFINE_bool(storage_properties_on_edges, false, "Controls whether relationships have properties.");
|
||||
|
||||
// CSV import flags.
|
||||
DEFINE_string(array_delimiter, ";",
|
||||
"Delimiter between elements of array values.");
|
||||
DEFINE_string(array_delimiter, ";", "Delimiter between elements of array values.");
|
||||
DEFINE_validator(array_delimiter, &ValidateControlCharacter);
|
||||
DEFINE_string(delimiter, ",", "Delimiter between each field in the CSV.");
|
||||
DEFINE_validator(delimiter, &ValidateControlCharacter);
|
||||
DEFINE_string(quote, "\"",
|
||||
"Quotation character for data in the CSV. Cannot contain '\n'");
|
||||
DEFINE_string(quote, "\"", "Quotation character for data in the CSV. Cannot contain '\n'");
|
||||
DEFINE_validator(quote, &ValidateControlCharacter);
|
||||
DEFINE_bool(skip_duplicate_nodes, false,
|
||||
"Set to true to skip duplicate nodes instead of raising an error.");
|
||||
DEFINE_bool(skip_duplicate_nodes, false, "Set to true to skip duplicate nodes instead of raising an error.");
|
||||
DEFINE_bool(skip_bad_relationships, false,
|
||||
"Set to true to skip relationships that connect nodes that don't "
|
||||
"exist instead of raising an error.");
|
||||
DEFINE_bool(ignore_empty_strings, false,
|
||||
"Set to true to treat empty strings as null values.");
|
||||
DEFINE_bool(
|
||||
ignore_extra_columns, false,
|
||||
"Set to true to ignore columns that aren't specified in the header.");
|
||||
DEFINE_bool(ignore_empty_strings, false, "Set to true to treat empty strings as null values.");
|
||||
DEFINE_bool(ignore_extra_columns, false, "Set to true to ignore columns that aren't specified in the header.");
|
||||
DEFINE_bool(trim_strings, false,
|
||||
"Set to true to trim leading/trailing whitespace from all fields "
|
||||
"that are loaded from the CSV file.");
|
||||
@ -75,16 +67,14 @@ DEFINE_string(id_type, "STRING",
|
||||
DEFINE_validator(id_type, &ValidateIdTypeOptions);
|
||||
// Arguments `--nodes` and `--relationships` can be input multiple times and are
|
||||
// handled with custom parsing.
|
||||
DEFINE_string(
|
||||
nodes, "",
|
||||
DEFINE_string(nodes, "",
|
||||
"Files that should be parsed for nodes. The CSV header will be loaded from "
|
||||
"the first supplied file, all other files supplied in a single flag will "
|
||||
"be treated as data files. Additional labels can be specified for the node "
|
||||
"files. The flag can be specified multiple times (useful for differently "
|
||||
"formatted node files). The format of this argument is: "
|
||||
"[<label>[:<label>]...=]<file>[,<file>][,<file>]...");
|
||||
DEFINE_string(
|
||||
relationships, "",
|
||||
DEFINE_string(relationships, "",
|
||||
"Files that should be parsed for relationships. The CSV header will be "
|
||||
"loaded from the first supplied file, all other files supplied in a single "
|
||||
"flag will be treated as data files. The relationship type can be "
|
||||
@ -92,8 +82,7 @@ DEFINE_string(
|
||||
"times (useful for differently formatted relationship files). The format "
|
||||
"of this argument is: [<type>=]<file>[,<file>][,<file>]...");
|
||||
|
||||
std::vector<std::string> ParseRepeatedFlag(const std::string &flagname,
|
||||
int argc, char *argv[]) {
|
||||
std::vector<std::string> ParseRepeatedFlag(const std::string &flagname, int argc, char *argv[]) {
|
||||
std::vector<std::string> values;
|
||||
for (int i = 1; i < argc; ++i) {
|
||||
std::string flag(argv[i]);
|
||||
@ -132,9 +121,7 @@ struct NodeId {
|
||||
std::string id_space;
|
||||
};
|
||||
|
||||
bool operator==(const NodeId &a, const NodeId &b) {
|
||||
return a.id == b.id && a.id_space == b.id_space;
|
||||
}
|
||||
bool operator==(const NodeId &a, const NodeId &b) { return a.id == b.id && a.id_space == b.id_space; }
|
||||
|
||||
std::ostream &operator<<(std::ostream &stream, const NodeId &node_id) {
|
||||
if (!node_id.id_space.empty()) {
|
||||
@ -171,8 +158,7 @@ enum class CsvParserState {
|
||||
EXPECT_DELIMITER,
|
||||
};
|
||||
|
||||
bool SubstringStartsWith(const std::string_view &str, size_t pos,
|
||||
const std::string_view &what) {
|
||||
bool SubstringStartsWith(const std::string_view &str, size_t pos, const std::string_view &what) {
|
||||
return utils::StartsWith(utils::Substr(str, pos), what);
|
||||
}
|
||||
|
||||
@ -262,8 +248,7 @@ std::pair<std::vector<std::string>, uint64_t> ReadRow(std::istream &stream) {
|
||||
}
|
||||
case CsvParserState::QUOTING: {
|
||||
auto quote_now = SubstringStartsWith(line, i, FLAGS_quote);
|
||||
auto quote_next =
|
||||
SubstringStartsWith(line, i + FLAGS_quote.size(), FLAGS_quote);
|
||||
auto quote_next = SubstringStartsWith(line, i + FLAGS_quote.size(), FLAGS_quote);
|
||||
if (quote_now && quote_next) {
|
||||
// This is an escaped quote character.
|
||||
column += FLAGS_quote;
|
||||
@ -293,8 +278,7 @@ std::pair<std::vector<std::string>, uint64_t> ReadRow(std::istream &stream) {
|
||||
state = CsvParserState::NEXT_FIELD;
|
||||
i += FLAGS_delimiter.size() - 1;
|
||||
} else {
|
||||
throw LoadException("Expected '{}' after '{}', but got '{}'",
|
||||
FLAGS_delimiter, FLAGS_quote, c);
|
||||
throw LoadException("Expected '{}' after '{}', but got '{}'", FLAGS_delimiter, FLAGS_quote, c);
|
||||
}
|
||||
break;
|
||||
}
|
||||
@ -326,8 +310,7 @@ std::pair<std::vector<std::string>, uint64_t> ReadRow(std::istream &stream) {
|
||||
}
|
||||
|
||||
if (FLAGS_trim_strings) {
|
||||
std::transform(std::begin(row), std::end(row), std::begin(row),
|
||||
[](const auto &item) { return utils::Trim(item); });
|
||||
std::transform(std::begin(row), std::end(row), std::begin(row), [](const auto &item) { return utils::Trim(item); });
|
||||
}
|
||||
|
||||
return {std::move(row), lines_count};
|
||||
@ -373,13 +356,10 @@ double StringToDouble(const std::string &value) {
|
||||
}
|
||||
|
||||
/// @throw LoadException
|
||||
storage::PropertyValue StringToValue(const std::string &str,
|
||||
const std::string &type) {
|
||||
if (FLAGS_ignore_empty_strings && str.empty())
|
||||
return storage::PropertyValue();
|
||||
storage::PropertyValue StringToValue(const std::string &str, const std::string &type) {
|
||||
if (FLAGS_ignore_empty_strings && str.empty()) return storage::PropertyValue();
|
||||
auto convert = [](const auto &str, const auto &type) {
|
||||
if (type == "integer" || type == "int" || type == "long" ||
|
||||
type == "byte" || type == "short") {
|
||||
if (type == "integer" || type == "int" || type == "long" || type == "byte" || type == "short") {
|
||||
return storage::PropertyValue(StringToInt(str));
|
||||
} else if (type == "float" || type == "double") {
|
||||
return storage::PropertyValue(StringToDouble(str));
|
||||
@ -411,8 +391,7 @@ storage::PropertyValue StringToValue(const std::string &str,
|
||||
std::string GetIdSpace(const std::string &type) {
|
||||
// The format of this field is as follows:
|
||||
// [START_|END_]ID[(<id_space>)]
|
||||
std::regex format(R"(^(START_|END_)?ID(\(([^\(\)]+)\))?$)",
|
||||
std::regex::extended);
|
||||
std::regex format(R"(^(START_|END_)?ID(\(([^\(\)]+)\))?$)", std::regex::extended);
|
||||
std::smatch res;
|
||||
if (!std::regex_match(type, res, format))
|
||||
throw LoadException(
|
||||
@ -424,8 +403,7 @@ std::string GetIdSpace(const std::string &type) {
|
||||
}
|
||||
|
||||
/// @throw LoadException
|
||||
void ProcessNodeRow(storage::Storage *store, const std::vector<Field> &fields,
|
||||
const std::vector<std::string> &row,
|
||||
void ProcessNodeRow(storage::Storage *store, const std::vector<Field> &fields, const std::vector<std::string> &row,
|
||||
const std::vector<std::string> &additional_labels,
|
||||
std::unordered_map<NodeId, storage::Gid> *node_id_map) {
|
||||
std::optional<NodeId> id;
|
||||
@ -458,45 +436,32 @@ void ProcessNodeRow(storage::Storage *store, const std::vector<Field> &fields,
|
||||
} else {
|
||||
pv_id = storage::PropertyValue(node_id.id);
|
||||
}
|
||||
auto node_property =
|
||||
node.SetProperty(acc.NameToProperty(field.name), pv_id);
|
||||
if (!node_property.HasValue())
|
||||
throw LoadException("Couldn't add property '{}' to the node",
|
||||
field.name);
|
||||
if (!*node_property)
|
||||
throw LoadException("The property '{}' already exists", field.name);
|
||||
auto node_property = node.SetProperty(acc.NameToProperty(field.name), pv_id);
|
||||
if (!node_property.HasValue()) throw LoadException("Couldn't add property '{}' to the node", field.name);
|
||||
if (!*node_property) throw LoadException("The property '{}' already exists", field.name);
|
||||
}
|
||||
id = node_id;
|
||||
} else if (field.type == "LABEL") {
|
||||
for (const auto &label : utils::Split(value, FLAGS_array_delimiter)) {
|
||||
auto node_label = node.AddLabel(acc.NameToLabel(label));
|
||||
if (!node_label.HasValue())
|
||||
throw LoadException("Couldn't add label '{}' to the node", label);
|
||||
if (!*node_label)
|
||||
throw LoadException("The label '{}' already exists", label);
|
||||
if (!node_label.HasValue()) throw LoadException("Couldn't add label '{}' to the node", label);
|
||||
if (!*node_label) throw LoadException("The label '{}' already exists", label);
|
||||
}
|
||||
} else if (field.type != "IGNORE") {
|
||||
auto node_property = node.SetProperty(acc.NameToProperty(field.name),
|
||||
StringToValue(value, field.type));
|
||||
if (!node_property.HasValue())
|
||||
throw LoadException("Couldn't add property '{}' to the node",
|
||||
field.name);
|
||||
if (!*node_property)
|
||||
throw LoadException("The property '{}' already exists", field.name);
|
||||
auto node_property = node.SetProperty(acc.NameToProperty(field.name), StringToValue(value, field.type));
|
||||
if (!node_property.HasValue()) throw LoadException("Couldn't add property '{}' to the node", field.name);
|
||||
if (!*node_property) throw LoadException("The property '{}' already exists", field.name);
|
||||
}
|
||||
}
|
||||
for (const auto &label : additional_labels) {
|
||||
auto node_label = node.AddLabel(acc.NameToLabel(label));
|
||||
if (!node_label.HasValue())
|
||||
throw LoadException("Couldn't add label '{}' to the node", label);
|
||||
if (!*node_label)
|
||||
throw LoadException("The label '{}' already exists", label);
|
||||
if (!node_label.HasValue()) throw LoadException("Couldn't add label '{}' to the node", label);
|
||||
if (!*node_label) throw LoadException("The label '{}' already exists", label);
|
||||
}
|
||||
if (acc.Commit().HasError()) throw LoadException("Couldn't store the node");
|
||||
}
|
||||
|
||||
void ProcessNodes(storage::Storage *store, const std::string &nodes_path,
|
||||
std::optional<std::vector<Field>> *header,
|
||||
void ProcessNodes(storage::Storage *store, const std::string &nodes_path, std::optional<std::vector<Field>> *header,
|
||||
std::unordered_map<NodeId, storage::Gid> *node_id_map,
|
||||
const std::vector<std::string> &additional_labels) {
|
||||
std::ifstream nodes_file(nodes_path);
|
||||
@ -524,16 +489,13 @@ void ProcessNodes(storage::Storage *store, const std::string &nodes_path,
|
||||
row_number += lines_count;
|
||||
}
|
||||
} catch (const LoadException &e) {
|
||||
LOG_FATAL("Couldn't process row {} of '{}' because of: {}", row_number,
|
||||
nodes_path, e.what());
|
||||
LOG_FATAL("Couldn't process row {} of '{}' because of: {}", row_number, nodes_path, e.what());
|
||||
}
|
||||
}
|
||||
|
||||
/// @throw LoadException
|
||||
void ProcessRelationshipsRow(
|
||||
storage::Storage *store, const std::vector<Field> &fields,
|
||||
const std::vector<std::string> &row,
|
||||
std::optional<std::string> relationship_type,
|
||||
void ProcessRelationshipsRow(storage::Storage *store, const std::vector<Field> &fields,
|
||||
const std::vector<std::string> &row, std::optional<std::string> relationship_type,
|
||||
const std::unordered_map<NodeId, storage::Gid> &node_id_map) {
|
||||
std::optional<storage::Gid> start_id;
|
||||
std::optional<storage::Gid> end_id;
|
||||
@ -576,14 +538,11 @@ void ProcessRelationshipsRow(
|
||||
}
|
||||
end_id = it->second;
|
||||
} else if (field.type == "TYPE") {
|
||||
if (relationship_type)
|
||||
throw LoadException("Only one relationship TYPE must be specified");
|
||||
if (relationship_type) throw LoadException("Only one relationship TYPE must be specified");
|
||||
relationship_type = value;
|
||||
} else if (field.type != "IGNORE") {
|
||||
auto [it, inserted] =
|
||||
properties.emplace(field.name, StringToValue(value, field.type));
|
||||
if (!inserted)
|
||||
throw LoadException("The property '{}' already exists", field.name);
|
||||
auto [it, inserted] = properties.emplace(field.name, StringToValue(value, field.type));
|
||||
if (!inserted) throw LoadException("The property '{}' already exists", field.name);
|
||||
}
|
||||
}
|
||||
if (!start_id) throw LoadException("START_ID must be set");
|
||||
@ -596,18 +555,14 @@ void ProcessRelationshipsRow(
|
||||
auto to_node = acc.FindVertex(*end_id, storage::View::NEW);
|
||||
if (!to_node) throw LoadException("To node must be in the storage");
|
||||
|
||||
auto relationship = acc.CreateEdge(&*from_node, &*to_node,
|
||||
acc.NameToEdgeType(*relationship_type));
|
||||
if (!relationship.HasValue())
|
||||
throw LoadException("Couldn't create the relationship");
|
||||
auto relationship = acc.CreateEdge(&*from_node, &*to_node, acc.NameToEdgeType(*relationship_type));
|
||||
if (!relationship.HasValue()) throw LoadException("Couldn't create the relationship");
|
||||
|
||||
for (const auto &property : properties) {
|
||||
auto ret = relationship->SetProperty(acc.NameToProperty(property.first),
|
||||
property.second);
|
||||
auto ret = relationship->SetProperty(acc.NameToProperty(property.first), property.second);
|
||||
if (!ret.HasValue()) {
|
||||
if (ret.GetError() != storage::Error::PROPERTIES_DISABLED) {
|
||||
throw LoadException("Couldn't add property '{}' to the relationship",
|
||||
property.first);
|
||||
throw LoadException("Couldn't add property '{}' to the relationship", property.first);
|
||||
} else {
|
||||
throw LoadException(
|
||||
"Couldn't add property '{}' to the relationship because properties "
|
||||
@ -617,12 +572,10 @@ void ProcessRelationshipsRow(
|
||||
}
|
||||
}
|
||||
|
||||
if (acc.Commit().HasError())
|
||||
throw LoadException("Couldn't store the relationship");
|
||||
if (acc.Commit().HasError()) throw LoadException("Couldn't store the relationship");
|
||||
}
|
||||
|
||||
void ProcessRelationships(
|
||||
storage::Storage *store, const std::string &relationships_path,
|
||||
void ProcessRelationships(storage::Storage *store, const std::string &relationships_path,
|
||||
const std::optional<std::string> &relationship_type,
|
||||
std::optional<std::vector<Field>> *header,
|
||||
const std::unordered_map<NodeId, storage::Gid> &node_id_map) {
|
||||
@ -647,13 +600,11 @@ void ProcessRelationships(
|
||||
if (row.size() > (*header)->size()) {
|
||||
row.resize((*header)->size());
|
||||
}
|
||||
ProcessRelationshipsRow(store, **header, row, relationship_type,
|
||||
node_id_map);
|
||||
ProcessRelationshipsRow(store, **header, row, relationship_type, node_id_map);
|
||||
row_number += lines_count;
|
||||
}
|
||||
} catch (const LoadException &e) {
|
||||
LOG_FATAL("Couldn't process row {} of '{}' because of: {}", row_number,
|
||||
relationships_path, e.what());
|
||||
LOG_FATAL("Couldn't process row {} of '{}' because of: {}", row_number, relationships_path, e.what());
|
||||
}
|
||||
}
|
||||
|
||||
@ -735,8 +686,7 @@ int main(int argc, char *argv[]) {
|
||||
.items = {.properties_on_edges = FLAGS_storage_properties_on_edges},
|
||||
.durability = {.storage_directory = FLAGS_data_directory,
|
||||
.recover_on_startup = false,
|
||||
.snapshot_wal_mode =
|
||||
storage::Config::Durability::SnapshotWalMode::DISABLED,
|
||||
.snapshot_wal_mode = storage::Config::Durability::SnapshotWalMode::DISABLED,
|
||||
.snapshot_on_exit = true},
|
||||
}};
|
||||
|
||||
@ -748,8 +698,7 @@ int main(int argc, char *argv[]) {
|
||||
std::optional<std::vector<Field>> header;
|
||||
for (const auto &nodes_file : files) {
|
||||
spdlog::info("Loading {}", nodes_file);
|
||||
ProcessNodes(&store, nodes_file, &header, &node_id_map,
|
||||
additional_labels);
|
||||
ProcessNodes(&store, nodes_file, &header, &node_id_map, additional_labels);
|
||||
}
|
||||
}
|
||||
|
||||
@ -759,8 +708,7 @@ int main(int argc, char *argv[]) {
|
||||
std::optional<std::vector<Field>> header;
|
||||
for (const auto &relationships_file : files) {
|
||||
spdlog::info("Loading {}", relationships_file);
|
||||
ProcessRelationships(&store, relationships_file, type, &header,
|
||||
node_id_map);
|
||||
ProcessRelationships(&store, relationships_file, type, &header, node_id_map);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -104,34 +104,26 @@ class [[nodiscard]] Object final {
|
||||
/// This function always succeeds, meaning that exceptions that occur while
|
||||
/// calling __getattr__ and __getattribute__ will get suppressed. To get error
|
||||
/// reporting, use GetAttr instead.
|
||||
bool HasAttr(const char *attr_name) const {
|
||||
return PyObject_HasAttrString(ptr_, attr_name);
|
||||
}
|
||||
bool HasAttr(const char *attr_name) const { return PyObject_HasAttrString(ptr_, attr_name); }
|
||||
|
||||
/// Equivalent to `hasattr(this, attr_name)` in Python.
|
||||
///
|
||||
/// This function always succeeds, meaning that exceptions that occur while
|
||||
/// calling __getattr__ and __getattribute__ will get suppressed. To get error
|
||||
/// reporting, use GetAttr instead.
|
||||
bool HasAttr(PyObject *attr_name) const {
|
||||
return PyObject_HasAttr(ptr_, attr_name);
|
||||
}
|
||||
bool HasAttr(PyObject *attr_name) const { return PyObject_HasAttr(ptr_, attr_name); }
|
||||
|
||||
/// Equivalent to `this.attr_name` in Python.
|
||||
///
|
||||
/// Returned Object is nullptr if an error occurred.
|
||||
/// @sa FetchError
|
||||
Object GetAttr(const char *attr_name) const {
|
||||
return Object(PyObject_GetAttrString(ptr_, attr_name));
|
||||
}
|
||||
Object GetAttr(const char *attr_name) const { return Object(PyObject_GetAttrString(ptr_, attr_name)); }
|
||||
|
||||
/// Equivalent to `this.attr_name` in Python.
|
||||
///
|
||||
/// Returned Object is nullptr if an error occurred.
|
||||
/// @sa FetchError
|
||||
Object GetAttr(PyObject *attr_name) const {
|
||||
return Object(PyObject_GetAttr(ptr_, attr_name));
|
||||
}
|
||||
Object GetAttr(PyObject *attr_name) const { return Object(PyObject_GetAttr(ptr_, attr_name)); }
|
||||
|
||||
/// Equivalent to `this.attr_name = v` in Python.
|
||||
///
|
||||
@ -145,9 +137,7 @@ class [[nodiscard]] Object final {
|
||||
///
|
||||
/// False is returned if an error occurred.
|
||||
/// @sa FetchError
|
||||
[[nodiscard]] bool SetAttr(PyObject *attr_name, PyObject *v) {
|
||||
return PyObject_SetAttr(ptr_, attr_name, v) == 0;
|
||||
}
|
||||
[[nodiscard]] bool SetAttr(PyObject *attr_name, PyObject *v) { return PyObject_SetAttr(ptr_, attr_name, v) == 0; }
|
||||
|
||||
/// Equivalent to `callable()` in Python.
|
||||
///
|
||||
@ -161,8 +151,7 @@ class [[nodiscard]] Object final {
|
||||
/// @sa FetchError
|
||||
template <class... TArgs>
|
||||
Object Call(const TArgs &...args) const {
|
||||
return Object(PyObject_CallFunctionObjArgs(
|
||||
ptr_, static_cast<PyObject *>(args)..., nullptr));
|
||||
return Object(PyObject_CallFunctionObjArgs(ptr_, static_cast<PyObject *>(args)..., nullptr));
|
||||
}
|
||||
|
||||
/// Equivalent to `obj.meth_name()` in Python.
|
||||
@ -170,8 +159,7 @@ class [[nodiscard]] Object final {
|
||||
/// Returned Object is nullptr if an error occurred.
|
||||
/// @sa FetchError
|
||||
Object CallMethod(std::string_view meth_name) const {
|
||||
Object name(
|
||||
PyUnicode_FromStringAndSize(meth_name.data(), meth_name.size()));
|
||||
Object name(PyUnicode_FromStringAndSize(meth_name.data(), meth_name.size()));
|
||||
return Object(PyObject_CallMethodObjArgs(ptr_, name.Ptr(), nullptr));
|
||||
}
|
||||
|
||||
@ -181,10 +169,8 @@ class [[nodiscard]] Object final {
|
||||
/// @sa FetchError
|
||||
template <class... TArgs>
|
||||
Object CallMethod(std::string_view meth_name, const TArgs &...args) const {
|
||||
Object name(
|
||||
PyUnicode_FromStringAndSize(meth_name.data(), meth_name.size()));
|
||||
return Object(PyObject_CallMethodObjArgs(
|
||||
ptr_, name.Ptr(), static_cast<PyObject *>(args)..., nullptr));
|
||||
Object name(PyUnicode_FromStringAndSize(meth_name.data(), meth_name.size()));
|
||||
return Object(PyObject_CallMethodObjArgs(ptr_, name.Ptr(), static_cast<PyObject *>(args)..., nullptr));
|
||||
}
|
||||
};
|
||||
|
||||
@ -210,8 +196,7 @@ struct [[nodiscard]] ExceptionInfo final {
|
||||
/// argument `skip_first_line` allows the user to skip the first line of the
|
||||
/// traceback. It is useful if the first line in the traceback always prints
|
||||
/// some internal wrapper function.
|
||||
[[nodiscard]] inline std::string FormatException(const ExceptionInfo &exc_info,
|
||||
bool skip_first_line = false) {
|
||||
[[nodiscard]] inline std::string FormatException(const ExceptionInfo &exc_info, bool skip_first_line = false) {
|
||||
if (!exc_info.type) return "";
|
||||
Object traceback_mod(PyImport_ImportModule("traceback"));
|
||||
MG_ASSERT(traceback_mod);
|
||||
@ -221,8 +206,7 @@ struct [[nodiscard]] ExceptionInfo final {
|
||||
if (skip_first_line && traceback_root) {
|
||||
traceback_root = traceback_root.GetAttr("tb_next");
|
||||
}
|
||||
auto list = format_exception_fn.Call(
|
||||
exc_info.type, exc_info.value ? exc_info.value.Ptr() : Py_None,
|
||||
auto list = format_exception_fn.Call(exc_info.type, exc_info.value ? exc_info.value.Ptr() : Py_None,
|
||||
traceback_root ? traceback_root.Ptr() : Py_None);
|
||||
MG_ASSERT(list);
|
||||
std::stringstream ss;
|
||||
@ -235,8 +219,7 @@ struct [[nodiscard]] ExceptionInfo final {
|
||||
}
|
||||
|
||||
/// Write ExceptionInfo to stream just like the Python interpreter would.
|
||||
inline std::ostream &operator<<(std::ostream &os,
|
||||
const ExceptionInfo &exc_info) {
|
||||
inline std::ostream &operator<<(std::ostream &os, const ExceptionInfo &exc_info) {
|
||||
os << FormatException(exc_info);
|
||||
return os;
|
||||
}
|
||||
@ -259,16 +242,14 @@ inline std::ostream &operator<<(std::ostream &os,
|
||||
}
|
||||
|
||||
inline void RestoreError(ExceptionInfo exc_info) {
|
||||
PyErr_Restore(exc_info.type.Steal(), exc_info.value.Steal(),
|
||||
exc_info.traceback.Steal());
|
||||
PyErr_Restore(exc_info.type.Steal(), exc_info.value.Steal(), exc_info.traceback.Steal());
|
||||
}
|
||||
|
||||
/// Append `dir` to Python's `sys.path`.
|
||||
///
|
||||
/// The function does not check whether the directory exists, or is readable.
|
||||
/// ExceptionInfo is returned if an error occurred.
|
||||
[[nodiscard]] inline std::optional<ExceptionInfo> AppendToSysPath(
|
||||
const char *dir) {
|
||||
[[nodiscard]] inline std::optional<ExceptionInfo> AppendToSysPath(const char *dir) {
|
||||
MG_ASSERT(dir);
|
||||
auto *py_path = PySys_GetObject("path");
|
||||
MG_ASSERT(py_path);
|
||||
|
@ -15,9 +15,7 @@ bool TypedValueCompare(const TypedValue &a, const TypedValue &b) {
|
||||
// comparisons are from this point legal only between values of
|
||||
// the same type, or int+float combinations
|
||||
if ((a.type() != b.type() && !(a.IsNumeric() && b.IsNumeric())))
|
||||
throw QueryRuntimeException(
|
||||
"Can't compare value of type {} to value of type {}.", a.type(),
|
||||
b.type());
|
||||
throw QueryRuntimeException("Can't compare value of type {} to value of type {}.", a.type(), b.type());
|
||||
|
||||
switch (a.type()) {
|
||||
case TypedValue::Type::Bool:
|
||||
@ -39,8 +37,7 @@ bool TypedValueCompare(const TypedValue &a, const TypedValue &b) {
|
||||
case TypedValue::Type::Vertex:
|
||||
case TypedValue::Type::Edge:
|
||||
case TypedValue::Type::Path:
|
||||
throw QueryRuntimeException(
|
||||
"Comparison is not defined for values of type {}.", a.type());
|
||||
throw QueryRuntimeException("Comparison is not defined for values of type {}.", a.type());
|
||||
default:
|
||||
LOG_FATAL("Unhandled comparison for types");
|
||||
}
|
||||
|
@ -27,12 +27,10 @@ bool TypedValueCompare(const TypedValue &a, const TypedValue &b);
|
||||
class TypedValueVectorCompare final {
|
||||
public:
|
||||
TypedValueVectorCompare() {}
|
||||
explicit TypedValueVectorCompare(const std::vector<Ordering> &ordering)
|
||||
: ordering_(ordering) {}
|
||||
explicit TypedValueVectorCompare(const std::vector<Ordering> &ordering) : ordering_(ordering) {}
|
||||
|
||||
template <class TAllocator>
|
||||
bool operator()(const std::vector<TypedValue, TAllocator> &c1,
|
||||
const std::vector<TypedValue, TAllocator> &c2) const {
|
||||
bool operator()(const std::vector<TypedValue, TAllocator> &c1, const std::vector<TypedValue, TAllocator> &c2) const {
|
||||
// ordering is invalid if there are more elements in the collections
|
||||
// then there are in the ordering_ vector
|
||||
MG_ASSERT(c1.size() <= ordering_.size() && c2.size() <= ordering_.size(),
|
||||
@ -41,12 +39,9 @@ class TypedValueVectorCompare final {
|
||||
auto c1_it = c1.begin();
|
||||
auto c2_it = c2.begin();
|
||||
auto ordering_it = ordering_.begin();
|
||||
for (; c1_it != c1.end() && c2_it != c2.end();
|
||||
c1_it++, c2_it++, ordering_it++) {
|
||||
if (impl::TypedValueCompare(*c1_it, *c2_it))
|
||||
return *ordering_it == Ordering::ASC;
|
||||
if (impl::TypedValueCompare(*c2_it, *c1_it))
|
||||
return *ordering_it == Ordering::DESC;
|
||||
for (; c1_it != c1.end() && c2_it != c2.end(); c1_it++, c2_it++, ordering_it++) {
|
||||
if (impl::TypedValueCompare(*c1_it, *c2_it)) return *ordering_it == Ordering::ASC;
|
||||
if (impl::TypedValueCompare(*c2_it, *c1_it)) return *ordering_it == Ordering::DESC;
|
||||
}
|
||||
|
||||
// at least one collection is exhausted
|
||||
@ -61,41 +56,33 @@ class TypedValueVectorCompare final {
|
||||
};
|
||||
|
||||
/// Raise QueryRuntimeException if the value for symbol isn't of expected type.
|
||||
inline void ExpectType(const Symbol &symbol, const TypedValue &value,
|
||||
TypedValue::Type expected) {
|
||||
inline void ExpectType(const Symbol &symbol, const TypedValue &value, TypedValue::Type expected) {
|
||||
if (value.type() != expected)
|
||||
throw QueryRuntimeException("Expected a {} for '{}', but got {}.", expected,
|
||||
symbol.name(), value.type());
|
||||
throw QueryRuntimeException("Expected a {} for '{}', but got {}.", expected, symbol.name(), value.type());
|
||||
}
|
||||
|
||||
/// Set a property `value` mapped with given `key` on a `record`.
|
||||
///
|
||||
/// @throw QueryRuntimeException if value cannot be set as a property value
|
||||
template <class TRecordAccessor>
|
||||
void PropsSetChecked(TRecordAccessor *record, const storage::PropertyId &key,
|
||||
const TypedValue &value) {
|
||||
void PropsSetChecked(TRecordAccessor *record, const storage::PropertyId &key, const TypedValue &value) {
|
||||
try {
|
||||
auto maybe_error = record->SetProperty(key, storage::PropertyValue(value));
|
||||
if (maybe_error.HasError()) {
|
||||
switch (maybe_error.GetError()) {
|
||||
case storage::Error::SERIALIZATION_ERROR:
|
||||
throw QueryRuntimeException(
|
||||
"Can't serialize due to concurrent operations.");
|
||||
throw QueryRuntimeException("Can't serialize due to concurrent operations.");
|
||||
case storage::Error::DELETED_OBJECT:
|
||||
throw QueryRuntimeException(
|
||||
"Trying to set properties on a deleted object.");
|
||||
throw QueryRuntimeException("Trying to set properties on a deleted object.");
|
||||
case storage::Error::PROPERTIES_DISABLED:
|
||||
throw QueryRuntimeException(
|
||||
"Can't set property because properties on edges are disabled.");
|
||||
throw QueryRuntimeException("Can't set property because properties on edges are disabled.");
|
||||
case storage::Error::VERTEX_HAS_EDGES:
|
||||
case storage::Error::NONEXISTENT_OBJECT:
|
||||
throw QueryRuntimeException(
|
||||
"Unexpected error when setting a property.");
|
||||
throw QueryRuntimeException("Unexpected error when setting a property.");
|
||||
}
|
||||
}
|
||||
} catch (const TypedValueException &) {
|
||||
throw QueryRuntimeException("'{}' cannot be used as a property value.",
|
||||
value.type());
|
||||
throw QueryRuntimeException("'{}' cannot be used as a property value.", value.type());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -26,8 +26,8 @@ struct EvaluationContext {
|
||||
mutable std::unordered_map<std::string, int64_t> counters;
|
||||
};
|
||||
|
||||
inline std::vector<storage::PropertyId> NamesToProperties(
|
||||
const std::vector<std::string> &property_names, DbAccessor *dba) {
|
||||
inline std::vector<storage::PropertyId> NamesToProperties(const std::vector<std::string> &property_names,
|
||||
DbAccessor *dba) {
|
||||
std::vector<storage::PropertyId> properties;
|
||||
properties.reserve(property_names.size());
|
||||
for (const auto &name : property_names) {
|
||||
@ -36,8 +36,7 @@ inline std::vector<storage::PropertyId> NamesToProperties(
|
||||
return properties;
|
||||
}
|
||||
|
||||
inline std::vector<storage::LabelId> NamesToLabels(
|
||||
const std::vector<std::string> &label_names, DbAccessor *dba) {
|
||||
inline std::vector<storage::LabelId> NamesToLabels(const std::vector<std::string> &label_names, DbAccessor *dba) {
|
||||
std::vector<storage::LabelId> labels;
|
||||
labels.reserve(label_names.size());
|
||||
for (const auto &name : label_names) {
|
||||
@ -60,11 +59,9 @@ struct ExecutionContext {
|
||||
};
|
||||
|
||||
inline bool MustAbort(const ExecutionContext &context) {
|
||||
return (context.is_shutting_down &&
|
||||
context.is_shutting_down->load(std::memory_order_acquire)) ||
|
||||
return (context.is_shutting_down && context.is_shutting_down->load(std::memory_order_acquire)) ||
|
||||
(context.max_execution_time_sec > 0 &&
|
||||
context.execution_tsc_timer.Elapsed() >=
|
||||
context.max_execution_time_sec);
|
||||
context.execution_tsc_timer.Elapsed() >= context.max_execution_time_sec);
|
||||
}
|
||||
|
||||
} // namespace query
|
||||
|
@ -47,19 +47,15 @@ class EdgeAccessor final {
|
||||
|
||||
auto Properties(storage::View view) const { return impl_.Properties(view); }
|
||||
|
||||
storage::Result<storage::PropertyValue> GetProperty(storage::View view,
|
||||
storage::PropertyId key) const {
|
||||
storage::Result<storage::PropertyValue> GetProperty(storage::View view, storage::PropertyId key) const {
|
||||
return impl_.GetProperty(key, view);
|
||||
}
|
||||
|
||||
storage::Result<bool> SetProperty(storage::PropertyId key,
|
||||
const storage::PropertyValue &value) {
|
||||
storage::Result<bool> SetProperty(storage::PropertyId key, const storage::PropertyValue &value) {
|
||||
return impl_.SetProperty(key, value);
|
||||
}
|
||||
|
||||
storage::Result<bool> RemoveProperty(storage::PropertyId key) {
|
||||
return SetProperty(key, storage::PropertyValue());
|
||||
}
|
||||
storage::Result<bool> RemoveProperty(storage::PropertyId key) { return SetProperty(key, storage::PropertyValue()); }
|
||||
|
||||
utils::BasicResult<storage::Error, void> ClearProperties() {
|
||||
auto ret = impl_.ClearProperties();
|
||||
@ -86,44 +82,32 @@ class VertexAccessor final {
|
||||
public:
|
||||
storage::VertexAccessor impl_;
|
||||
|
||||
static EdgeAccessor MakeEdgeAccessor(const storage::EdgeAccessor impl) {
|
||||
return EdgeAccessor(impl);
|
||||
}
|
||||
static EdgeAccessor MakeEdgeAccessor(const storage::EdgeAccessor impl) { return EdgeAccessor(impl); }
|
||||
|
||||
public:
|
||||
explicit VertexAccessor(storage::VertexAccessor impl)
|
||||
: impl_(std::move(impl)) {}
|
||||
explicit VertexAccessor(storage::VertexAccessor impl) : impl_(std::move(impl)) {}
|
||||
|
||||
auto Labels(storage::View view) const { return impl_.Labels(view); }
|
||||
|
||||
storage::Result<bool> AddLabel(storage::LabelId label) {
|
||||
return impl_.AddLabel(label);
|
||||
}
|
||||
storage::Result<bool> AddLabel(storage::LabelId label) { return impl_.AddLabel(label); }
|
||||
|
||||
storage::Result<bool> RemoveLabel(storage::LabelId label) {
|
||||
return impl_.RemoveLabel(label);
|
||||
}
|
||||
storage::Result<bool> RemoveLabel(storage::LabelId label) { return impl_.RemoveLabel(label); }
|
||||
|
||||
storage::Result<bool> HasLabel(storage::View view,
|
||||
storage::LabelId label) const {
|
||||
storage::Result<bool> HasLabel(storage::View view, storage::LabelId label) const {
|
||||
return impl_.HasLabel(label, view);
|
||||
}
|
||||
|
||||
auto Properties(storage::View view) const { return impl_.Properties(view); }
|
||||
|
||||
storage::Result<storage::PropertyValue> GetProperty(storage::View view,
|
||||
storage::PropertyId key) const {
|
||||
storage::Result<storage::PropertyValue> GetProperty(storage::View view, storage::PropertyId key) const {
|
||||
return impl_.GetProperty(key, view);
|
||||
}
|
||||
|
||||
storage::Result<bool> SetProperty(storage::PropertyId key,
|
||||
const storage::PropertyValue &value) {
|
||||
storage::Result<bool> SetProperty(storage::PropertyId key, const storage::PropertyValue &value) {
|
||||
return impl_.SetProperty(key, value);
|
||||
}
|
||||
|
||||
storage::Result<bool> RemoveProperty(storage::PropertyId key) {
|
||||
return SetProperty(key, storage::PropertyValue());
|
||||
}
|
||||
storage::Result<bool> RemoveProperty(storage::PropertyId key) { return SetProperty(key, storage::PropertyValue()); }
|
||||
|
||||
utils::BasicResult<storage::Error, void> ClearProperties() {
|
||||
auto ret = impl_.ClearProperties();
|
||||
@ -131,10 +115,8 @@ class VertexAccessor final {
|
||||
return {};
|
||||
}
|
||||
|
||||
auto InEdges(storage::View view,
|
||||
const std::vector<storage::EdgeTypeId> &edge_types) const
|
||||
-> storage::Result<decltype(iter::imap(MakeEdgeAccessor,
|
||||
*impl_.InEdges(view)))> {
|
||||
auto InEdges(storage::View view, const std::vector<storage::EdgeTypeId> &edge_types) const
|
||||
-> storage::Result<decltype(iter::imap(MakeEdgeAccessor, *impl_.InEdges(view)))> {
|
||||
auto maybe_edges = impl_.InEdges(view, edge_types);
|
||||
if (maybe_edges.HasError()) return maybe_edges.GetError();
|
||||
return iter::imap(MakeEdgeAccessor, std::move(*maybe_edges));
|
||||
@ -142,20 +124,15 @@ class VertexAccessor final {
|
||||
|
||||
auto InEdges(storage::View view) const { return InEdges(view, {}); }
|
||||
|
||||
auto InEdges(storage::View view,
|
||||
const std::vector<storage::EdgeTypeId> &edge_types,
|
||||
const VertexAccessor &dest) const
|
||||
-> storage::Result<decltype(iter::imap(MakeEdgeAccessor,
|
||||
*impl_.InEdges(view)))> {
|
||||
auto InEdges(storage::View view, const std::vector<storage::EdgeTypeId> &edge_types, const VertexAccessor &dest) const
|
||||
-> storage::Result<decltype(iter::imap(MakeEdgeAccessor, *impl_.InEdges(view)))> {
|
||||
auto maybe_edges = impl_.InEdges(view, edge_types, &dest.impl_);
|
||||
if (maybe_edges.HasError()) return maybe_edges.GetError();
|
||||
return iter::imap(MakeEdgeAccessor, std::move(*maybe_edges));
|
||||
}
|
||||
|
||||
auto OutEdges(storage::View view,
|
||||
const std::vector<storage::EdgeTypeId> &edge_types) const
|
||||
-> storage::Result<decltype(iter::imap(MakeEdgeAccessor,
|
||||
*impl_.OutEdges(view)))> {
|
||||
auto OutEdges(storage::View view, const std::vector<storage::EdgeTypeId> &edge_types) const
|
||||
-> storage::Result<decltype(iter::imap(MakeEdgeAccessor, *impl_.OutEdges(view)))> {
|
||||
auto maybe_edges = impl_.OutEdges(view, edge_types);
|
||||
if (maybe_edges.HasError()) return maybe_edges.GetError();
|
||||
return iter::imap(MakeEdgeAccessor, std::move(*maybe_edges));
|
||||
@ -163,23 +140,17 @@ class VertexAccessor final {
|
||||
|
||||
auto OutEdges(storage::View view) const { return OutEdges(view, {}); }
|
||||
|
||||
auto OutEdges(storage::View view,
|
||||
const std::vector<storage::EdgeTypeId> &edge_types,
|
||||
auto OutEdges(storage::View view, const std::vector<storage::EdgeTypeId> &edge_types,
|
||||
const VertexAccessor &dest) const
|
||||
-> storage::Result<decltype(iter::imap(MakeEdgeAccessor,
|
||||
*impl_.OutEdges(view)))> {
|
||||
-> storage::Result<decltype(iter::imap(MakeEdgeAccessor, *impl_.OutEdges(view)))> {
|
||||
auto maybe_edges = impl_.OutEdges(view, edge_types, &dest.impl_);
|
||||
if (maybe_edges.HasError()) return maybe_edges.GetError();
|
||||
return iter::imap(MakeEdgeAccessor, std::move(*maybe_edges));
|
||||
}
|
||||
|
||||
storage::Result<size_t> InDegree(storage::View view) const {
|
||||
return impl_.InDegree(view);
|
||||
}
|
||||
storage::Result<size_t> InDegree(storage::View view) const { return impl_.InDegree(view); }
|
||||
|
||||
storage::Result<size_t> OutDegree(storage::View view) const {
|
||||
return impl_.OutDegree(view);
|
||||
}
|
||||
storage::Result<size_t> OutDegree(storage::View view) const { return impl_.OutDegree(view); }
|
||||
|
||||
int64_t CypherId() const { return impl_.Gid().AsInt(); }
|
||||
|
||||
@ -190,13 +161,9 @@ class VertexAccessor final {
|
||||
bool operator!=(const VertexAccessor &v) const { return !(*this == v); }
|
||||
};
|
||||
|
||||
inline VertexAccessor EdgeAccessor::To() const {
|
||||
return VertexAccessor(impl_.ToVertex());
|
||||
}
|
||||
inline VertexAccessor EdgeAccessor::To() const { return VertexAccessor(impl_.ToVertex()); }
|
||||
|
||||
inline VertexAccessor EdgeAccessor::From() const {
|
||||
return VertexAccessor(impl_.FromVertex());
|
||||
}
|
||||
inline VertexAccessor EdgeAccessor::From() const { return VertexAccessor(impl_.FromVertex()); }
|
||||
|
||||
inline bool EdgeAccessor::IsCycle() const { return To() == From(); }
|
||||
|
||||
@ -225,8 +192,7 @@ class DbAccessor final {
|
||||
bool operator!=(const Iterator &other) const { return !(other == *this); }
|
||||
};
|
||||
|
||||
explicit VerticesIterable(storage::VerticesIterable iterable)
|
||||
: iterable_(std::move(iterable)) {}
|
||||
explicit VerticesIterable(storage::VerticesIterable iterable) : iterable_(std::move(iterable)) {}
|
||||
|
||||
Iterator begin() { return Iterator(iterable_.begin()); }
|
||||
|
||||
@ -234,60 +200,45 @@ class DbAccessor final {
|
||||
};
|
||||
|
||||
public:
|
||||
explicit DbAccessor(storage::Storage::Accessor *accessor)
|
||||
: accessor_(accessor) {}
|
||||
explicit DbAccessor(storage::Storage::Accessor *accessor) : accessor_(accessor) {}
|
||||
|
||||
std::optional<VertexAccessor> FindVertex(storage::Gid gid,
|
||||
storage::View view) {
|
||||
std::optional<VertexAccessor> FindVertex(storage::Gid gid, storage::View view) {
|
||||
auto maybe_vertex = accessor_->FindVertex(gid, view);
|
||||
if (maybe_vertex) return VertexAccessor(*maybe_vertex);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
VerticesIterable Vertices(storage::View view) {
|
||||
return VerticesIterable(accessor_->Vertices(view));
|
||||
}
|
||||
VerticesIterable Vertices(storage::View view) { return VerticesIterable(accessor_->Vertices(view)); }
|
||||
|
||||
VerticesIterable Vertices(storage::View view, storage::LabelId label) {
|
||||
return VerticesIterable(accessor_->Vertices(label, view));
|
||||
}
|
||||
|
||||
VerticesIterable Vertices(storage::View view, storage::LabelId label,
|
||||
storage::PropertyId property) {
|
||||
VerticesIterable Vertices(storage::View view, storage::LabelId label, storage::PropertyId property) {
|
||||
return VerticesIterable(accessor_->Vertices(label, property, view));
|
||||
}
|
||||
|
||||
VerticesIterable Vertices(storage::View view, storage::LabelId label,
|
||||
storage::PropertyId property,
|
||||
VerticesIterable Vertices(storage::View view, storage::LabelId label, storage::PropertyId property,
|
||||
const storage::PropertyValue &value) {
|
||||
return VerticesIterable(accessor_->Vertices(label, property, value, view));
|
||||
}
|
||||
|
||||
VerticesIterable Vertices(
|
||||
storage::View view, storage::LabelId label, storage::PropertyId property,
|
||||
VerticesIterable Vertices(storage::View view, storage::LabelId label, storage::PropertyId property,
|
||||
const std::optional<utils::Bound<storage::PropertyValue>> &lower,
|
||||
const std::optional<utils::Bound<storage::PropertyValue>> &upper) {
|
||||
return VerticesIterable(
|
||||
accessor_->Vertices(label, property, lower, upper, view));
|
||||
return VerticesIterable(accessor_->Vertices(label, property, lower, upper, view));
|
||||
}
|
||||
|
||||
VertexAccessor InsertVertex() {
|
||||
return VertexAccessor(accessor_->CreateVertex());
|
||||
}
|
||||
VertexAccessor InsertVertex() { return VertexAccessor(accessor_->CreateVertex()); }
|
||||
|
||||
storage::Result<EdgeAccessor> InsertEdge(VertexAccessor *from,
|
||||
VertexAccessor *to,
|
||||
storage::Result<EdgeAccessor> InsertEdge(VertexAccessor *from, VertexAccessor *to,
|
||||
const storage::EdgeTypeId &edge_type) {
|
||||
auto maybe_edge =
|
||||
accessor_->CreateEdge(&from->impl_, &to->impl_, edge_type);
|
||||
if (maybe_edge.HasError())
|
||||
return storage::Result<EdgeAccessor>(maybe_edge.GetError());
|
||||
auto maybe_edge = accessor_->CreateEdge(&from->impl_, &to->impl_, edge_type);
|
||||
if (maybe_edge.HasError()) return storage::Result<EdgeAccessor>(maybe_edge.GetError());
|
||||
return EdgeAccessor(std::move(*maybe_edge));
|
||||
}
|
||||
|
||||
storage::Result<bool> RemoveEdge(EdgeAccessor *edge) {
|
||||
return accessor_->DeleteEdge(&edge->impl_);
|
||||
}
|
||||
storage::Result<bool> RemoveEdge(EdgeAccessor *edge) { return accessor_->DeleteEdge(&edge->impl_); }
|
||||
|
||||
storage::Result<bool> DetachRemoveVertex(VertexAccessor *vertex_accessor) {
|
||||
return accessor_->DetachDeleteVertex(&vertex_accessor->impl_);
|
||||
@ -297,55 +248,35 @@ class DbAccessor final {
|
||||
return accessor_->DeleteVertex(&vertex_accessor->impl_);
|
||||
}
|
||||
|
||||
storage::PropertyId NameToProperty(const std::string_view &name) {
|
||||
return accessor_->NameToProperty(name);
|
||||
}
|
||||
storage::PropertyId NameToProperty(const std::string_view &name) { return accessor_->NameToProperty(name); }
|
||||
|
||||
storage::LabelId NameToLabel(const std::string_view &name) {
|
||||
return accessor_->NameToLabel(name);
|
||||
}
|
||||
storage::LabelId NameToLabel(const std::string_view &name) { return accessor_->NameToLabel(name); }
|
||||
|
||||
storage::EdgeTypeId NameToEdgeType(const std::string_view &name) {
|
||||
return accessor_->NameToEdgeType(name);
|
||||
}
|
||||
storage::EdgeTypeId NameToEdgeType(const std::string_view &name) { return accessor_->NameToEdgeType(name); }
|
||||
|
||||
const std::string &PropertyToName(storage::PropertyId prop) const {
|
||||
return accessor_->PropertyToName(prop);
|
||||
}
|
||||
const std::string &PropertyToName(storage::PropertyId prop) const { return accessor_->PropertyToName(prop); }
|
||||
|
||||
const std::string &LabelToName(storage::LabelId label) const {
|
||||
return accessor_->LabelToName(label);
|
||||
}
|
||||
const std::string &LabelToName(storage::LabelId label) const { return accessor_->LabelToName(label); }
|
||||
|
||||
const std::string &EdgeTypeToName(storage::EdgeTypeId type) const {
|
||||
return accessor_->EdgeTypeToName(type);
|
||||
}
|
||||
const std::string &EdgeTypeToName(storage::EdgeTypeId type) const { return accessor_->EdgeTypeToName(type); }
|
||||
|
||||
void AdvanceCommand() { accessor_->AdvanceCommand(); }
|
||||
|
||||
utils::BasicResult<storage::ConstraintViolation, void> Commit() {
|
||||
return accessor_->Commit();
|
||||
}
|
||||
utils::BasicResult<storage::ConstraintViolation, void> Commit() { return accessor_->Commit(); }
|
||||
|
||||
void Abort() { accessor_->Abort(); }
|
||||
|
||||
bool LabelIndexExists(storage::LabelId label) const {
|
||||
return accessor_->LabelIndexExists(label);
|
||||
}
|
||||
bool LabelIndexExists(storage::LabelId label) const { return accessor_->LabelIndexExists(label); }
|
||||
|
||||
bool LabelPropertyIndexExists(storage::LabelId label,
|
||||
storage::PropertyId prop) const {
|
||||
bool LabelPropertyIndexExists(storage::LabelId label, storage::PropertyId prop) const {
|
||||
return accessor_->LabelPropertyIndexExists(label, prop);
|
||||
}
|
||||
|
||||
int64_t VerticesCount() const { return accessor_->ApproximateVertexCount(); }
|
||||
|
||||
int64_t VerticesCount(storage::LabelId label) const {
|
||||
return accessor_->ApproximateVertexCount(label);
|
||||
}
|
||||
int64_t VerticesCount(storage::LabelId label) const { return accessor_->ApproximateVertexCount(label); }
|
||||
|
||||
int64_t VerticesCount(storage::LabelId label,
|
||||
storage::PropertyId property) const {
|
||||
int64_t VerticesCount(storage::LabelId label, storage::PropertyId property) const {
|
||||
return accessor_->ApproximateVertexCount(label, property);
|
||||
}
|
||||
|
||||
@ -354,20 +285,15 @@ class DbAccessor final {
|
||||
return accessor_->ApproximateVertexCount(label, property, value);
|
||||
}
|
||||
|
||||
int64_t VerticesCount(
|
||||
storage::LabelId label, storage::PropertyId property,
|
||||
int64_t VerticesCount(storage::LabelId label, storage::PropertyId property,
|
||||
const std::optional<utils::Bound<storage::PropertyValue>> &lower,
|
||||
const std::optional<utils::Bound<storage::PropertyValue>> &upper) const {
|
||||
return accessor_->ApproximateVertexCount(label, property, lower, upper);
|
||||
}
|
||||
|
||||
storage::IndicesInfo ListAllIndices() const {
|
||||
return accessor_->ListAllIndices();
|
||||
}
|
||||
storage::IndicesInfo ListAllIndices() const { return accessor_->ListAllIndices(); }
|
||||
|
||||
storage::ConstraintsInfo ListAllConstraints() const {
|
||||
return accessor_->ListAllConstraints();
|
||||
}
|
||||
storage::ConstraintsInfo ListAllConstraints() const { return accessor_->ListAllConstraints(); }
|
||||
};
|
||||
|
||||
} // namespace query
|
||||
@ -376,16 +302,12 @@ namespace std {
|
||||
|
||||
template <>
|
||||
struct hash<query::VertexAccessor> {
|
||||
size_t operator()(const query::VertexAccessor &v) const {
|
||||
return std::hash<decltype(v.impl_)>{}(v.impl_);
|
||||
}
|
||||
size_t operator()(const query::VertexAccessor &v) const { return std::hash<decltype(v.impl_)>{}(v.impl_); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hash<query::EdgeAccessor> {
|
||||
size_t operator()(const query::EdgeAccessor &e) const {
|
||||
return std::hash<decltype(e.impl_)>{}(e.impl_);
|
||||
}
|
||||
size_t operator()(const query::EdgeAccessor &e) const { return std::hash<decltype(e.impl_)>{}(e.impl_); }
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
|
@ -50,8 +50,7 @@ void DumpPreciseDouble(std::ostream *os, double value) {
|
||||
// A temporary stream is used to keep precision of the original output
|
||||
// stream unchanged.
|
||||
std::ostringstream temp_oss;
|
||||
temp_oss << std::setprecision(std::numeric_limits<double>::max_digits10)
|
||||
<< value;
|
||||
temp_oss << std::setprecision(std::numeric_limits<double>::max_digits10) << value;
|
||||
*os << temp_oss.str();
|
||||
}
|
||||
|
||||
@ -75,9 +74,7 @@ void DumpPropertyValue(std::ostream *os, const storage::PropertyValue &value) {
|
||||
case storage::PropertyValue::Type::List: {
|
||||
*os << "[";
|
||||
const auto &list = value.ValueList();
|
||||
utils::PrintIterable(*os, list, ", ", [](auto &os, const auto &item) {
|
||||
DumpPropertyValue(&os, item);
|
||||
});
|
||||
utils::PrintIterable(*os, list, ", ", [](auto &os, const auto &item) { DumpPropertyValue(&os, item); });
|
||||
*os << "]";
|
||||
return;
|
||||
}
|
||||
@ -94,8 +91,7 @@ void DumpPropertyValue(std::ostream *os, const storage::PropertyValue &value) {
|
||||
}
|
||||
}
|
||||
|
||||
void DumpProperties(
|
||||
std::ostream *os, query::DbAccessor *dba,
|
||||
void DumpProperties(std::ostream *os, query::DbAccessor *dba,
|
||||
const std::map<storage::PropertyId, storage::PropertyValue> &store,
|
||||
std::optional<int64_t> property_id = std::nullopt) {
|
||||
*os << "{";
|
||||
@ -110,24 +106,20 @@ void DumpProperties(
|
||||
*os << "}";
|
||||
}
|
||||
|
||||
void DumpVertex(std::ostream *os, query::DbAccessor *dba,
|
||||
const query::VertexAccessor &vertex) {
|
||||
void DumpVertex(std::ostream *os, query::DbAccessor *dba, const query::VertexAccessor &vertex) {
|
||||
*os << "CREATE (";
|
||||
*os << ":" << kInternalVertexLabel;
|
||||
auto maybe_labels = vertex.Labels(storage::View::OLD);
|
||||
if (maybe_labels.HasError()) {
|
||||
switch (maybe_labels.GetError()) {
|
||||
case storage::Error::DELETED_OBJECT:
|
||||
throw query::QueryRuntimeException(
|
||||
"Trying to get labels from a deleted node.");
|
||||
throw query::QueryRuntimeException("Trying to get labels from a deleted node.");
|
||||
case storage::Error::NONEXISTENT_OBJECT:
|
||||
throw query::QueryRuntimeException(
|
||||
"Trying to get labels from a node that doesn't exist.");
|
||||
throw query::QueryRuntimeException("Trying to get labels from a node that doesn't exist.");
|
||||
case storage::Error::SERIALIZATION_ERROR:
|
||||
case storage::Error::VERTEX_HAS_EDGES:
|
||||
case storage::Error::PROPERTIES_DISABLED:
|
||||
throw query::QueryRuntimeException(
|
||||
"Unexpected error when getting labels.");
|
||||
throw query::QueryRuntimeException("Unexpected error when getting labels.");
|
||||
}
|
||||
}
|
||||
for (const auto &label : *maybe_labels) {
|
||||
@ -138,24 +130,20 @@ void DumpVertex(std::ostream *os, query::DbAccessor *dba,
|
||||
if (maybe_props.HasError()) {
|
||||
switch (maybe_props.GetError()) {
|
||||
case storage::Error::DELETED_OBJECT:
|
||||
throw query::QueryRuntimeException(
|
||||
"Trying to get properties from a deleted object.");
|
||||
throw query::QueryRuntimeException("Trying to get properties from a deleted object.");
|
||||
case storage::Error::NONEXISTENT_OBJECT:
|
||||
throw query::QueryRuntimeException(
|
||||
"Trying to get properties from a node that doesn't exist.");
|
||||
throw query::QueryRuntimeException("Trying to get properties from a node that doesn't exist.");
|
||||
case storage::Error::SERIALIZATION_ERROR:
|
||||
case storage::Error::VERTEX_HAS_EDGES:
|
||||
case storage::Error::PROPERTIES_DISABLED:
|
||||
throw query::QueryRuntimeException(
|
||||
"Unexpected error when getting properties.");
|
||||
throw query::QueryRuntimeException("Unexpected error when getting properties.");
|
||||
}
|
||||
}
|
||||
DumpProperties(os, dba, *maybe_props, vertex.CypherId());
|
||||
*os << ");";
|
||||
}
|
||||
|
||||
void DumpEdge(std::ostream *os, query::DbAccessor *dba,
|
||||
const query::EdgeAccessor &edge) {
|
||||
void DumpEdge(std::ostream *os, query::DbAccessor *dba, const query::EdgeAccessor &edge) {
|
||||
*os << "MATCH ";
|
||||
*os << "(u:" << kInternalVertexLabel << "), ";
|
||||
*os << "(v:" << kInternalVertexLabel << ")";
|
||||
@ -169,16 +157,13 @@ void DumpEdge(std::ostream *os, query::DbAccessor *dba,
|
||||
if (maybe_props.HasError()) {
|
||||
switch (maybe_props.GetError()) {
|
||||
case storage::Error::DELETED_OBJECT:
|
||||
throw query::QueryRuntimeException(
|
||||
"Trying to get properties from a deleted object.");
|
||||
throw query::QueryRuntimeException("Trying to get properties from a deleted object.");
|
||||
case storage::Error::NONEXISTENT_OBJECT:
|
||||
throw query::QueryRuntimeException(
|
||||
"Trying to get properties from an edge that doesn't exist.");
|
||||
throw query::QueryRuntimeException("Trying to get properties from an edge that doesn't exist.");
|
||||
case storage::Error::SERIALIZATION_ERROR:
|
||||
case storage::Error::VERTEX_HAS_EDGES:
|
||||
case storage::Error::PROPERTIES_DISABLED:
|
||||
throw query::QueryRuntimeException(
|
||||
"Unexpected error when getting properties.");
|
||||
throw query::QueryRuntimeException("Unexpected error when getting properties.");
|
||||
}
|
||||
}
|
||||
if (maybe_props->size() > 0) {
|
||||
@ -188,33 +173,26 @@ void DumpEdge(std::ostream *os, query::DbAccessor *dba,
|
||||
*os << "]->(v);";
|
||||
}
|
||||
|
||||
void DumpLabelIndex(std::ostream *os, query::DbAccessor *dba,
|
||||
const storage::LabelId label) {
|
||||
void DumpLabelIndex(std::ostream *os, query::DbAccessor *dba, const storage::LabelId label) {
|
||||
*os << "CREATE INDEX ON :" << EscapeName(dba->LabelToName(label)) << ";";
|
||||
}
|
||||
|
||||
void DumpLabelPropertyIndex(std::ostream *os, query::DbAccessor *dba,
|
||||
storage::LabelId label,
|
||||
void DumpLabelPropertyIndex(std::ostream *os, query::DbAccessor *dba, storage::LabelId label,
|
||||
storage::PropertyId property) {
|
||||
*os << "CREATE INDEX ON :" << EscapeName(dba->LabelToName(label)) << "("
|
||||
<< EscapeName(dba->PropertyToName(property)) << ");";
|
||||
}
|
||||
|
||||
void DumpExistenceConstraint(std::ostream *os, query::DbAccessor *dba,
|
||||
storage::LabelId label,
|
||||
storage::PropertyId property) {
|
||||
*os << "CREATE CONSTRAINT ON (u:" << EscapeName(dba->LabelToName(label))
|
||||
<< ") ASSERT EXISTS (u." << EscapeName(dba->PropertyToName(property))
|
||||
*os << "CREATE INDEX ON :" << EscapeName(dba->LabelToName(label)) << "(" << EscapeName(dba->PropertyToName(property))
|
||||
<< ");";
|
||||
}
|
||||
|
||||
void DumpUniqueConstraint(std::ostream *os, query::DbAccessor *dba,
|
||||
storage::LabelId label,
|
||||
void DumpExistenceConstraint(std::ostream *os, query::DbAccessor *dba, storage::LabelId label,
|
||||
storage::PropertyId property) {
|
||||
*os << "CREATE CONSTRAINT ON (u:" << EscapeName(dba->LabelToName(label)) << ") ASSERT EXISTS (u."
|
||||
<< EscapeName(dba->PropertyToName(property)) << ");";
|
||||
}
|
||||
|
||||
void DumpUniqueConstraint(std::ostream *os, query::DbAccessor *dba, storage::LabelId label,
|
||||
const std::set<storage::PropertyId> &properties) {
|
||||
*os << "CREATE CONSTRAINT ON (u:" << EscapeName(dba->LabelToName(label))
|
||||
<< ") ASSERT ";
|
||||
utils::PrintIterable(
|
||||
*os, properties, ", ", [&dba](auto &stream, const auto &property) {
|
||||
*os << "CREATE CONSTRAINT ON (u:" << EscapeName(dba->LabelToName(label)) << ") ASSERT ";
|
||||
utils::PrintIterable(*os, properties, ", ", [&dba](auto &stream, const auto &property) {
|
||||
stream << "u." << EscapeName(dba->PropertyToName(property));
|
||||
});
|
||||
*os << " IS UNIQUE;";
|
||||
@ -250,8 +228,7 @@ bool PullPlanDump::Pull(AnyStream *stream, std::optional<int> n) {
|
||||
// finishes. If the function did not finish streaming all the results,
|
||||
// std::nullopt should be returned because n results have already been sent.
|
||||
while (current_chunk_index_ < pull_chunks_.size() && (!n || *n > 0)) {
|
||||
const auto maybe_streamed_count =
|
||||
pull_chunks_[current_chunk_index_](stream, n);
|
||||
const auto maybe_streamed_count = pull_chunks_[current_chunk_index_](stream, n);
|
||||
|
||||
if (!maybe_streamed_count) {
|
||||
// n wasn't large enough to stream all the results from the current chunk
|
||||
@ -273,9 +250,7 @@ bool PullPlanDump::Pull(AnyStream *stream, std::optional<int> n) {
|
||||
|
||||
PullPlanDump::PullChunk PullPlanDump::CreateLabelIndicesPullChunk() {
|
||||
// Dump all label indices
|
||||
return [this, global_index = 0U](
|
||||
AnyStream *stream,
|
||||
std::optional<int> n) mutable -> std::optional<size_t> {
|
||||
return [this, global_index = 0U](AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
|
||||
// Delay the construction of indices vectors
|
||||
if (!indices_info_) {
|
||||
indices_info_.emplace(dba_->ListAllIndices());
|
||||
@ -301,9 +276,7 @@ PullPlanDump::PullChunk PullPlanDump::CreateLabelIndicesPullChunk() {
|
||||
}
|
||||
|
||||
PullPlanDump::PullChunk PullPlanDump::CreateLabelPropertyIndicesPullChunk() {
|
||||
return [this, global_index = 0U](
|
||||
AnyStream *stream,
|
||||
std::optional<int> n) mutable -> std::optional<size_t> {
|
||||
return [this, global_index = 0U](AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
|
||||
// Delay the construction of indices vectors
|
||||
if (!indices_info_) {
|
||||
indices_info_.emplace(dba_->ListAllIndices());
|
||||
@ -314,8 +287,7 @@ PullPlanDump::PullChunk PullPlanDump::CreateLabelPropertyIndicesPullChunk() {
|
||||
while (global_index < label_property.size() && (!n || local_counter < *n)) {
|
||||
std::ostringstream os;
|
||||
const auto &label_property_index = label_property[global_index];
|
||||
DumpLabelPropertyIndex(&os, dba_, label_property_index.first,
|
||||
label_property_index.second);
|
||||
DumpLabelPropertyIndex(&os, dba_, label_property_index.first, label_property_index.second);
|
||||
stream->Result({TypedValue(os.str())});
|
||||
|
||||
++global_index;
|
||||
@ -331,9 +303,7 @@ PullPlanDump::PullChunk PullPlanDump::CreateLabelPropertyIndicesPullChunk() {
|
||||
}
|
||||
|
||||
PullPlanDump::PullChunk PullPlanDump::CreateExistenceConstraintsPullChunk() {
|
||||
return [this, global_index = 0U](
|
||||
AnyStream *stream,
|
||||
std::optional<int> n) mutable -> std::optional<size_t> {
|
||||
return [this, global_index = 0U](AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
|
||||
// Delay the construction of constraint vectors
|
||||
if (!constraints_info_) {
|
||||
constraints_info_.emplace(dba_->ListAllConstraints());
|
||||
@ -360,9 +330,7 @@ PullPlanDump::PullChunk PullPlanDump::CreateExistenceConstraintsPullChunk() {
|
||||
}
|
||||
|
||||
PullPlanDump::PullChunk PullPlanDump::CreateUniqueConstraintsPullChunk() {
|
||||
return [this, global_index = 0U](
|
||||
AnyStream *stream,
|
||||
std::optional<int> n) mutable -> std::optional<size_t> {
|
||||
return [this, global_index = 0U](AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
|
||||
// Delay the construction of constraint vectors
|
||||
if (!constraints_info_) {
|
||||
constraints_info_.emplace(dba_->ListAllConstraints());
|
||||
@ -389,12 +357,10 @@ PullPlanDump::PullChunk PullPlanDump::CreateUniqueConstraintsPullChunk() {
|
||||
}
|
||||
|
||||
PullPlanDump::PullChunk PullPlanDump::CreateInternalIndexPullChunk() {
|
||||
return [this](AnyStream *stream,
|
||||
std::optional<int>) mutable -> std::optional<size_t> {
|
||||
return [this](AnyStream *stream, std::optional<int>) mutable -> std::optional<size_t> {
|
||||
if (vertices_iterable_.begin() != vertices_iterable_.end()) {
|
||||
std::ostringstream os;
|
||||
os << "CREATE INDEX ON :" << kInternalVertexLabel << "("
|
||||
<< kInternalPropertyId << ");";
|
||||
os << "CREATE INDEX ON :" << kInternalVertexLabel << "(" << kInternalPropertyId << ");";
|
||||
stream->Result({TypedValue(os.str())});
|
||||
internal_index_created_ = true;
|
||||
return 1;
|
||||
@ -404,10 +370,8 @@ PullPlanDump::PullChunk PullPlanDump::CreateInternalIndexPullChunk() {
|
||||
}
|
||||
|
||||
PullPlanDump::PullChunk PullPlanDump::CreateVertexPullChunk() {
|
||||
return [this,
|
||||
maybe_current_iter = std::optional<VertexAccessorIterableIterator>{}](
|
||||
AnyStream *stream,
|
||||
std::optional<int> n) mutable -> std::optional<size_t> {
|
||||
return [this, maybe_current_iter = std::optional<VertexAccessorIterableIterator>{}](
|
||||
AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
|
||||
// Delay the call of begin() function
|
||||
// If multiple begins are called before an iteration,
|
||||
// one iteration will make the rest of iterators be in undefined
|
||||
@ -419,8 +383,7 @@ PullPlanDump::PullChunk PullPlanDump::CreateVertexPullChunk() {
|
||||
auto ¤t_iter{*maybe_current_iter};
|
||||
|
||||
size_t local_counter = 0;
|
||||
while (current_iter != vertices_iterable_.end() &&
|
||||
(!n || local_counter < *n)) {
|
||||
while (current_iter != vertices_iterable_.end() && (!n || local_counter < *n)) {
|
||||
std::ostringstream os;
|
||||
DumpVertex(&os, dba_, *current_iter);
|
||||
stream->Result({TypedValue(os.str())});
|
||||
@ -436,16 +399,12 @@ PullPlanDump::PullChunk PullPlanDump::CreateVertexPullChunk() {
|
||||
}
|
||||
|
||||
PullPlanDump::PullChunk PullPlanDump::CreateEdgePullChunk() {
|
||||
return
|
||||
[this,
|
||||
maybe_current_vertex_iter =
|
||||
std::optional<VertexAccessorIterableIterator>{},
|
||||
return [this, maybe_current_vertex_iter = std::optional<VertexAccessorIterableIterator>{},
|
||||
// we need to save the iterable which contains list of accessor so
|
||||
// our saved iterator is valid in the next run
|
||||
maybe_edge_iterable = std::shared_ptr<EdgeAccessorIterable>{nullptr},
|
||||
maybe_current_edge_iter = std::optional<EdgeAccessorIterableIterator>{}](
|
||||
AnyStream *stream,
|
||||
std::optional<int> n) mutable -> std::optional<size_t> {
|
||||
AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
|
||||
// Delay the call of begin() function
|
||||
// If multiple begins are called before an iteration,
|
||||
// one iteration will make the rest of iterators be in undefined
|
||||
@ -456,24 +415,17 @@ PullPlanDump::PullChunk PullPlanDump::CreateEdgePullChunk() {
|
||||
|
||||
auto ¤t_vertex_iter{*maybe_current_vertex_iter};
|
||||
size_t local_counter = 0U;
|
||||
for (; current_vertex_iter != vertices_iterable_.end() &&
|
||||
(!n || local_counter < *n);
|
||||
++current_vertex_iter) {
|
||||
for (; current_vertex_iter != vertices_iterable_.end() && (!n || local_counter < *n); ++current_vertex_iter) {
|
||||
const auto &vertex = *current_vertex_iter;
|
||||
// If we have a saved iterable from a previous pull
|
||||
// we need to use the same iterable
|
||||
if (!maybe_edge_iterable) {
|
||||
maybe_edge_iterable = std::make_shared<EdgeAccessorIterable>(
|
||||
vertex.OutEdges(storage::View::OLD));
|
||||
maybe_edge_iterable = std::make_shared<EdgeAccessorIterable>(vertex.OutEdges(storage::View::OLD));
|
||||
}
|
||||
auto &maybe_edges = *maybe_edge_iterable;
|
||||
MG_ASSERT(maybe_edges.HasValue(), "Invalid database state!");
|
||||
auto current_edge_iter = maybe_current_edge_iter
|
||||
? *maybe_current_edge_iter
|
||||
: maybe_edges->begin();
|
||||
for (; current_edge_iter != maybe_edges->end() &&
|
||||
(!n || local_counter < *n);
|
||||
++current_edge_iter) {
|
||||
auto current_edge_iter = maybe_current_edge_iter ? *maybe_current_edge_iter : maybe_edges->begin();
|
||||
for (; current_edge_iter != maybe_edges->end() && (!n || local_counter < *n); ++current_edge_iter) {
|
||||
std::ostringstream os;
|
||||
DumpEdge(&os, dba_, *current_edge_iter);
|
||||
stream->Result({TypedValue(os.str())});
|
||||
@ -502,8 +454,7 @@ PullPlanDump::PullChunk PullPlanDump::CreateDropInternalIndexPullChunk() {
|
||||
return [this](AnyStream *stream, std::optional<int>) {
|
||||
if (internal_index_created_) {
|
||||
std::ostringstream os;
|
||||
os << "DROP INDEX ON :" << kInternalVertexLabel << "("
|
||||
<< kInternalPropertyId << ");";
|
||||
os << "DROP INDEX ON :" << kInternalVertexLabel << "(" << kInternalPropertyId << ");";
|
||||
stream->Result({TypedValue(os.str())});
|
||||
return 1;
|
||||
}
|
||||
@ -515,8 +466,7 @@ PullPlanDump::PullChunk PullPlanDump::CreateInternalIndexCleanupPullChunk() {
|
||||
return [this](AnyStream *stream, std::optional<int>) {
|
||||
if (internal_index_created_) {
|
||||
std::ostringstream os;
|
||||
os << "MATCH (u) REMOVE u:" << kInternalVertexLabel << ", u."
|
||||
<< kInternalPropertyId << ";";
|
||||
os << "MATCH (u) REMOVE u:" << kInternalVertexLabel << ", u." << kInternalPropertyId << ";";
|
||||
stream->Result({TypedValue(os.str())});
|
||||
return 1;
|
||||
}
|
||||
@ -524,8 +474,6 @@ PullPlanDump::PullChunk PullPlanDump::CreateInternalIndexCleanupPullChunk() {
|
||||
};
|
||||
}
|
||||
|
||||
void DumpDatabaseToCypherQueries(query::DbAccessor *dba, AnyStream *stream) {
|
||||
PullPlanDump(dba).Pull(stream, {});
|
||||
}
|
||||
void DumpDatabaseToCypherQueries(query::DbAccessor *dba, AnyStream *stream) { PullPlanDump(dba).Pull(stream, {}); }
|
||||
|
||||
} // namespace query
|
||||
|
@ -23,23 +23,18 @@ struct PullPlanDump {
|
||||
std::optional<storage::IndicesInfo> indices_info_ = std::nullopt;
|
||||
std::optional<storage::ConstraintsInfo> constraints_info_ = std::nullopt;
|
||||
|
||||
using VertexAccessorIterable =
|
||||
decltype(std::declval<query::DbAccessor>().Vertices(storage::View::OLD));
|
||||
using VertexAccessorIterableIterator =
|
||||
decltype(std::declval<VertexAccessorIterable>().begin());
|
||||
using VertexAccessorIterable = decltype(std::declval<query::DbAccessor>().Vertices(storage::View::OLD));
|
||||
using VertexAccessorIterableIterator = decltype(std::declval<VertexAccessorIterable>().begin());
|
||||
|
||||
using EdgeAccessorIterable =
|
||||
decltype(std::declval<VertexAccessor>().OutEdges(storage::View::OLD));
|
||||
using EdgeAccessorIterableIterator =
|
||||
decltype(std::declval<EdgeAccessorIterable>().GetValue().begin());
|
||||
using EdgeAccessorIterable = decltype(std::declval<VertexAccessor>().OutEdges(storage::View::OLD));
|
||||
using EdgeAccessorIterableIterator = decltype(std::declval<EdgeAccessorIterable>().GetValue().begin());
|
||||
|
||||
VertexAccessorIterable vertices_iterable_;
|
||||
bool internal_index_created_ = false;
|
||||
|
||||
size_t current_chunk_index_ = 0;
|
||||
|
||||
using PullChunk = std::function<std::optional<size_t>(AnyStream *stream,
|
||||
std::optional<int> n)>;
|
||||
using PullChunk = std::function<std::optional<size_t>(AnyStream *stream, std::optional<int> n)>;
|
||||
// We define every part of the dump query in a self contained function.
|
||||
// Each functions is responsible of keeping track of its execution status.
|
||||
// If a function did finish its execution, it should return number of results
|
||||
|
@ -45,23 +45,19 @@ class SemanticException : public QueryException {
|
||||
|
||||
class UnboundVariableError : public SemanticException {
|
||||
public:
|
||||
explicit UnboundVariableError(const std::string &name)
|
||||
: SemanticException("Unbound variable: " + name + ".") {}
|
||||
explicit UnboundVariableError(const std::string &name) : SemanticException("Unbound variable: " + name + ".") {}
|
||||
};
|
||||
|
||||
class RedeclareVariableError : public SemanticException {
|
||||
public:
|
||||
explicit RedeclareVariableError(const std::string &name)
|
||||
: SemanticException("Redeclaring variable: " + name + ".") {}
|
||||
explicit RedeclareVariableError(const std::string &name) : SemanticException("Redeclaring variable: " + name + ".") {}
|
||||
};
|
||||
|
||||
class TypeMismatchError : public SemanticException {
|
||||
public:
|
||||
TypeMismatchError(const std::string &name, const std::string &datum,
|
||||
const std::string &expected)
|
||||
: SemanticException(
|
||||
fmt::format("Type mismatch: {} already defined as {}, expected {}.",
|
||||
name, datum, expected)) {}
|
||||
TypeMismatchError(const std::string &name, const std::string &datum, const std::string &expected)
|
||||
: SemanticException(fmt::format("Type mismatch: {} already defined as {}, expected {}.", name, datum, expected)) {
|
||||
}
|
||||
};
|
||||
|
||||
class UnprovidedParameterError : public QueryException {
|
||||
@ -72,16 +68,13 @@ class UnprovidedParameterError : public QueryException {
|
||||
class ProfileInMulticommandTxException : public QueryException {
|
||||
public:
|
||||
using QueryException::QueryException;
|
||||
ProfileInMulticommandTxException()
|
||||
: QueryException("PROFILE not allowed in multicommand transactions.") {}
|
||||
ProfileInMulticommandTxException() : QueryException("PROFILE not allowed in multicommand transactions.") {}
|
||||
};
|
||||
|
||||
class IndexInMulticommandTxException : public QueryException {
|
||||
public:
|
||||
using QueryException::QueryException;
|
||||
IndexInMulticommandTxException()
|
||||
: QueryException(
|
||||
"Index manipulation not allowed in multicommand transactions.") {}
|
||||
IndexInMulticommandTxException() : QueryException("Index manipulation not allowed in multicommand transactions.") {}
|
||||
};
|
||||
|
||||
class ConstraintInMulticommandTxException : public QueryException {
|
||||
@ -96,9 +89,7 @@ class ConstraintInMulticommandTxException : public QueryException {
|
||||
class InfoInMulticommandTxException : public QueryException {
|
||||
public:
|
||||
using QueryException::QueryException;
|
||||
InfoInMulticommandTxException()
|
||||
: QueryException(
|
||||
"Info reporting not allowed in multicommand transactions.") {}
|
||||
InfoInMulticommandTxException() : QueryException("Info reporting not allowed in multicommand transactions.") {}
|
||||
};
|
||||
|
||||
/**
|
||||
@ -147,38 +138,30 @@ class RemoveAttachedVertexException : public QueryRuntimeException {
|
||||
class UserModificationInMulticommandTxException : public QueryException {
|
||||
public:
|
||||
UserModificationInMulticommandTxException()
|
||||
: QueryException(
|
||||
"Authentication clause not allowed in multicommand transactions.") {
|
||||
}
|
||||
: QueryException("Authentication clause not allowed in multicommand transactions.") {}
|
||||
};
|
||||
|
||||
class StreamClauseInMulticommandTxException : public QueryException {
|
||||
public:
|
||||
StreamClauseInMulticommandTxException()
|
||||
: QueryException(
|
||||
"Stream clause not allowed in multicommand transactions.") {}
|
||||
StreamClauseInMulticommandTxException() : QueryException("Stream clause not allowed in multicommand transactions.") {}
|
||||
};
|
||||
|
||||
class InvalidArgumentsException : public QueryException {
|
||||
public:
|
||||
InvalidArgumentsException(const std::string &argument_name,
|
||||
const std::string &message)
|
||||
: QueryException(fmt::format("Invalid arguments sent: {} - {}",
|
||||
argument_name, message)) {}
|
||||
InvalidArgumentsException(const std::string &argument_name, const std::string &message)
|
||||
: QueryException(fmt::format("Invalid arguments sent: {} - {}", argument_name, message)) {}
|
||||
};
|
||||
|
||||
class ReplicationModificationInMulticommandTxException : public QueryException {
|
||||
public:
|
||||
ReplicationModificationInMulticommandTxException()
|
||||
: QueryException(
|
||||
"Replication clause not allowed in multicommand transactions.") {}
|
||||
: QueryException("Replication clause not allowed in multicommand transactions.") {}
|
||||
};
|
||||
|
||||
class LockPathModificationInMulticommandTxException : public QueryException {
|
||||
public:
|
||||
LockPathModificationInMulticommandTxException()
|
||||
: QueryException(
|
||||
"Lock path clause not allowed in multicommand transactions.") {}
|
||||
: QueryException("Lock path clause not allowed in multicommand transactions.") {}
|
||||
};
|
||||
|
||||
} // namespace query
|
||||
|
@ -76,23 +76,17 @@ class ReplicationQuery;
|
||||
class LockPathQuery;
|
||||
|
||||
using TreeCompositeVisitor = ::utils::CompositeVisitor<
|
||||
SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator,
|
||||
AndOperator, NotOperator, AdditionOperator, SubtractionOperator,
|
||||
MultiplicationOperator, DivisionOperator, ModOperator, NotEqualOperator,
|
||||
EqualOperator, LessOperator, GreaterOperator, LessEqualOperator,
|
||||
GreaterEqualOperator, InListOperator, SubscriptOperator,
|
||||
ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator,
|
||||
IsNullOperator, ListLiteral, MapLiteral, PropertyLookup, LabelsTest,
|
||||
Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Any, None,
|
||||
CallProcedure, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom,
|
||||
Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty,
|
||||
RemoveLabels, Merge, Unwind, RegexMatch>;
|
||||
SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator,
|
||||
SubtractionOperator, MultiplicationOperator, DivisionOperator, ModOperator, NotEqualOperator, EqualOperator,
|
||||
LessOperator, GreaterOperator, LessEqualOperator, GreaterEqualOperator, InListOperator, SubscriptOperator,
|
||||
ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral,
|
||||
PropertyLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Any, None, CallProcedure,
|
||||
Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels,
|
||||
RemoveProperty, RemoveLabels, Merge, Unwind, RegexMatch>;
|
||||
|
||||
using TreeLeafVisitor =
|
||||
::utils::LeafVisitor<Identifier, PrimitiveLiteral, ParameterLookup>;
|
||||
using TreeLeafVisitor = ::utils::LeafVisitor<Identifier, PrimitiveLiteral, ParameterLookup>;
|
||||
|
||||
class HierarchicalTreeVisitor : public TreeCompositeVisitor,
|
||||
public TreeLeafVisitor {
|
||||
class HierarchicalTreeVisitor : public TreeCompositeVisitor, public TreeLeafVisitor {
|
||||
public:
|
||||
using TreeCompositeVisitor::PostVisit;
|
||||
using TreeCompositeVisitor::PreVisit;
|
||||
@ -103,21 +97,15 @@ class HierarchicalTreeVisitor : public TreeCompositeVisitor,
|
||||
template <class TResult>
|
||||
class ExpressionVisitor
|
||||
: public ::utils::Visitor<
|
||||
TResult, NamedExpression, OrOperator, XorOperator, AndOperator,
|
||||
NotOperator, AdditionOperator, SubtractionOperator,
|
||||
MultiplicationOperator, DivisionOperator, ModOperator,
|
||||
NotEqualOperator, EqualOperator, LessOperator, GreaterOperator,
|
||||
LessEqualOperator, GreaterEqualOperator, InListOperator,
|
||||
SubscriptOperator, ListSlicingOperator, IfOperator, UnaryPlusOperator,
|
||||
UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral,
|
||||
PropertyLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce,
|
||||
Extract, All, Single, Any, None, ParameterLookup, Identifier,
|
||||
PrimitiveLiteral, RegexMatch> {};
|
||||
TResult, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator,
|
||||
SubtractionOperator, MultiplicationOperator, DivisionOperator, ModOperator, NotEqualOperator, EqualOperator,
|
||||
LessOperator, GreaterOperator, LessEqualOperator, GreaterEqualOperator, InListOperator, SubscriptOperator,
|
||||
ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral,
|
||||
MapLiteral, PropertyLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Any,
|
||||
None, ParameterLookup, Identifier, PrimitiveLiteral, RegexMatch> {};
|
||||
|
||||
template <class TResult>
|
||||
class QueryVisitor
|
||||
: public ::utils::Visitor<TResult, CypherQuery, ExplainQuery, ProfileQuery,
|
||||
IndexQuery, AuthQuery, InfoQuery, ConstraintQuery,
|
||||
DumpQuery, ReplicationQuery, LockPathQuery> {};
|
||||
class QueryVisitor : public ::utils::Visitor<TResult, CypherQuery, ExplainQuery, ProfileQuery, IndexQuery, AuthQuery,
|
||||
InfoQuery, ConstraintQuery, DumpQuery, ReplicationQuery, LockPathQuery> {};
|
||||
|
||||
} // namespace query
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -22,12 +22,10 @@ struct ParsingContext {
|
||||
|
||||
class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
public:
|
||||
explicit CypherMainVisitor(ParsingContext context, AstStorage *storage)
|
||||
: context_(context), storage_(storage) {}
|
||||
explicit CypherMainVisitor(ParsingContext context, AstStorage *storage) : context_(context), storage_(storage) {}
|
||||
|
||||
private:
|
||||
Expression *CreateBinaryOperatorByToken(size_t token, Expression *e1,
|
||||
Expression *e2) {
|
||||
Expression *CreateBinaryOperatorByToken(size_t token, Expression *e1, Expression *e2) {
|
||||
switch (token) {
|
||||
case MemgraphCypher::OR:
|
||||
return storage_->Create<OrOperator>(e1, e2);
|
||||
@ -82,8 +80,7 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
for (auto *child : all_children) {
|
||||
antlr4::tree::TerminalNode *operator_node = nullptr;
|
||||
if ((operator_node = dynamic_cast<antlr4::tree::TerminalNode *>(child))) {
|
||||
if (std::find(allowed_operators.begin(), allowed_operators.end(),
|
||||
operator_node->getSymbol()->getType()) !=
|
||||
if (std::find(allowed_operators.begin(), allowed_operators.end(), operator_node->getSymbol()->getType()) !=
|
||||
allowed_operators.end()) {
|
||||
operators.push_back(operator_node->getSymbol()->getType());
|
||||
}
|
||||
@ -100,8 +97,7 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
* expression7.
|
||||
*/
|
||||
template <typename TExpression>
|
||||
Expression *LeftAssociativeOperatorExpression(
|
||||
std::vector<TExpression *> _expressions,
|
||||
Expression *LeftAssociativeOperatorExpression(std::vector<TExpression *> _expressions,
|
||||
std::vector<antlr4::tree::ParseTree *> all_children,
|
||||
const std::vector<size_t> &allowed_operators) {
|
||||
DMG_ASSERT(_expressions.size(), "can't happen");
|
||||
@ -114,16 +110,13 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
|
||||
Expression *first_operand = expressions[0];
|
||||
for (int i = 1; i < (int)expressions.size(); ++i) {
|
||||
first_operand = CreateBinaryOperatorByToken(
|
||||
operators[i - 1], first_operand, expressions[i]);
|
||||
first_operand = CreateBinaryOperatorByToken(operators[i - 1], first_operand, expressions[i]);
|
||||
}
|
||||
return first_operand;
|
||||
}
|
||||
|
||||
template <typename TExpression>
|
||||
Expression *PrefixUnaryOperator(
|
||||
TExpression *_expression,
|
||||
std::vector<antlr4::tree::ParseTree *> all_children,
|
||||
Expression *PrefixUnaryOperator(TExpression *_expression, std::vector<antlr4::tree::ParseTree *> all_children,
|
||||
const std::vector<size_t> &allowed_operators) {
|
||||
DMG_ASSERT(_expression, "can't happen");
|
||||
auto operators = ExtractOperators(all_children, allowed_operators);
|
||||
@ -138,26 +131,22 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
/**
|
||||
* @return CypherQuery*
|
||||
*/
|
||||
antlrcpp::Any visitCypherQuery(
|
||||
MemgraphCypher::CypherQueryContext *ctx) override;
|
||||
antlrcpp::Any visitCypherQuery(MemgraphCypher::CypherQueryContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return IndexQuery*
|
||||
*/
|
||||
antlrcpp::Any visitIndexQuery(
|
||||
MemgraphCypher::IndexQueryContext *ctx) override;
|
||||
antlrcpp::Any visitIndexQuery(MemgraphCypher::IndexQueryContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return ExplainQuery*
|
||||
*/
|
||||
antlrcpp::Any visitExplainQuery(
|
||||
MemgraphCypher::ExplainQueryContext *ctx) override;
|
||||
antlrcpp::Any visitExplainQuery(MemgraphCypher::ExplainQueryContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return ProfileQuery*
|
||||
*/
|
||||
antlrcpp::Any visitProfileQuery(
|
||||
MemgraphCypher::ProfileQueryContext *ctx) override;
|
||||
antlrcpp::Any visitProfileQuery(MemgraphCypher::ProfileQueryContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return InfoQuery*
|
||||
@ -167,14 +156,12 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
/**
|
||||
* @return Constraint
|
||||
*/
|
||||
antlrcpp::Any visitConstraint(
|
||||
MemgraphCypher::ConstraintContext *ctx) override;
|
||||
antlrcpp::Any visitConstraint(MemgraphCypher::ConstraintContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return ConstraintQuery*
|
||||
*/
|
||||
antlrcpp::Any visitConstraintQuery(
|
||||
MemgraphCypher::ConstraintQueryContext *ctx) override;
|
||||
antlrcpp::Any visitConstraintQuery(MemgraphCypher::ConstraintQueryContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return AuthQuery*
|
||||
@ -189,56 +176,47 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
/**
|
||||
* @return ReplicationQuery*
|
||||
*/
|
||||
antlrcpp::Any visitReplicationQuery(
|
||||
MemgraphCypher::ReplicationQueryContext *ctx) override;
|
||||
antlrcpp::Any visitReplicationQuery(MemgraphCypher::ReplicationQueryContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return ReplicationQuery*
|
||||
*/
|
||||
antlrcpp::Any visitSetReplicationRole(
|
||||
MemgraphCypher::SetReplicationRoleContext *ctx) override;
|
||||
antlrcpp::Any visitSetReplicationRole(MemgraphCypher::SetReplicationRoleContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return ReplicationQuery*
|
||||
*/
|
||||
antlrcpp::Any visitShowReplicationRole(
|
||||
MemgraphCypher::ShowReplicationRoleContext *ctx) override;
|
||||
antlrcpp::Any visitShowReplicationRole(MemgraphCypher::ShowReplicationRoleContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return ReplicationQuery*
|
||||
*/
|
||||
antlrcpp::Any visitRegisterReplica(
|
||||
MemgraphCypher::RegisterReplicaContext *ctx) override;
|
||||
antlrcpp::Any visitRegisterReplica(MemgraphCypher::RegisterReplicaContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return ReplicationQuery*
|
||||
*/
|
||||
antlrcpp::Any visitDropReplica(
|
||||
MemgraphCypher::DropReplicaContext *ctx) override;
|
||||
antlrcpp::Any visitDropReplica(MemgraphCypher::DropReplicaContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return ReplicationQuery*
|
||||
*/
|
||||
antlrcpp::Any visitShowReplicas(
|
||||
MemgraphCypher::ShowReplicasContext *ctx) override;
|
||||
antlrcpp::Any visitShowReplicas(MemgraphCypher::ShowReplicasContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return LockPathQuery*
|
||||
*/
|
||||
antlrcpp::Any visitLockPathQuery(
|
||||
MemgraphCypher::LockPathQueryContext *ctx) override;
|
||||
antlrcpp::Any visitLockPathQuery(MemgraphCypher::LockPathQueryContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return CypherUnion*
|
||||
*/
|
||||
antlrcpp::Any visitCypherUnion(
|
||||
MemgraphCypher::CypherUnionContext *ctx) override;
|
||||
antlrcpp::Any visitCypherUnion(MemgraphCypher::CypherUnionContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return SingleQuery*
|
||||
*/
|
||||
antlrcpp::Any visitSingleQuery(
|
||||
MemgraphCypher::SingleQueryContext *ctx) override;
|
||||
antlrcpp::Any visitSingleQuery(MemgraphCypher::SingleQueryContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return Clause* or vector<Clause*>!!!
|
||||
@ -248,8 +226,7 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
/**
|
||||
* @return Match*
|
||||
*/
|
||||
antlrcpp::Any visitCypherMatch(
|
||||
MemgraphCypher::CypherMatchContext *ctx) override;
|
||||
antlrcpp::Any visitCypherMatch(MemgraphCypher::CypherMatchContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return Create*
|
||||
@ -259,20 +236,17 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
/**
|
||||
* @return CallProcedure*
|
||||
*/
|
||||
antlrcpp::Any visitCallProcedure(
|
||||
MemgraphCypher::CallProcedureContext *ctx) override;
|
||||
antlrcpp::Any visitCallProcedure(MemgraphCypher::CallProcedureContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return std::string
|
||||
*/
|
||||
antlrcpp::Any visitUserOrRoleName(
|
||||
MemgraphCypher::UserOrRoleNameContext *ctx) override;
|
||||
antlrcpp::Any visitUserOrRoleName(MemgraphCypher::UserOrRoleNameContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any visitCreateRole(
|
||||
MemgraphCypher::CreateRoleContext *ctx) override;
|
||||
antlrcpp::Any visitCreateRole(MemgraphCypher::CreateRoleContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return AuthQuery*
|
||||
@ -287,8 +261,7 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
/**
|
||||
* @return IndexQuery*
|
||||
*/
|
||||
antlrcpp::Any visitCreateIndex(
|
||||
MemgraphCypher::CreateIndexContext *ctx) override;
|
||||
antlrcpp::Any visitCreateIndex(MemgraphCypher::CreateIndexContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return DropIndex*
|
||||
@ -298,14 +271,12 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
/**
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any visitCreateUser(
|
||||
MemgraphCypher::CreateUserContext *ctx) override;
|
||||
antlrcpp::Any visitCreateUser(MemgraphCypher::CreateUserContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any visitSetPassword(
|
||||
MemgraphCypher::SetPasswordContext *ctx) override;
|
||||
antlrcpp::Any visitSetPassword(MemgraphCypher::SetPasswordContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return AuthQuery*
|
||||
@ -330,20 +301,17 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
/**
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any visitGrantPrivilege(
|
||||
MemgraphCypher::GrantPrivilegeContext *ctx) override;
|
||||
antlrcpp::Any visitGrantPrivilege(MemgraphCypher::GrantPrivilegeContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any visitDenyPrivilege(
|
||||
MemgraphCypher::DenyPrivilegeContext *ctx) override;
|
||||
antlrcpp::Any visitDenyPrivilege(MemgraphCypher::DenyPrivilegeContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any visitRevokePrivilege(
|
||||
MemgraphCypher::RevokePrivilegeContext *ctx) override;
|
||||
antlrcpp::Any visitRevokePrivilege(MemgraphCypher::RevokePrivilegeContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return AuthQuery::Privilege
|
||||
@ -353,46 +321,39 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
/**
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any visitShowPrivileges(
|
||||
MemgraphCypher::ShowPrivilegesContext *ctx) override;
|
||||
antlrcpp::Any visitShowPrivileges(MemgraphCypher::ShowPrivilegesContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any visitShowRoleForUser(
|
||||
MemgraphCypher::ShowRoleForUserContext *ctx) override;
|
||||
antlrcpp::Any visitShowRoleForUser(MemgraphCypher::ShowRoleForUserContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return AuthQuery*
|
||||
*/
|
||||
antlrcpp::Any visitShowUsersForRole(
|
||||
MemgraphCypher::ShowUsersForRoleContext *ctx) override;
|
||||
antlrcpp::Any visitShowUsersForRole(MemgraphCypher::ShowUsersForRoleContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return Return*
|
||||
*/
|
||||
antlrcpp::Any visitCypherReturn(
|
||||
MemgraphCypher::CypherReturnContext *ctx) override;
|
||||
antlrcpp::Any visitCypherReturn(MemgraphCypher::CypherReturnContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return Return*
|
||||
*/
|
||||
antlrcpp::Any visitReturnBody(
|
||||
MemgraphCypher::ReturnBodyContext *ctx) override;
|
||||
antlrcpp::Any visitReturnBody(MemgraphCypher::ReturnBodyContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return pair<bool, vector<NamedExpression*>> first member is true if
|
||||
* asterisk was found in return
|
||||
* expressions.
|
||||
*/
|
||||
antlrcpp::Any visitReturnItems(
|
||||
MemgraphCypher::ReturnItemsContext *ctx) override;
|
||||
antlrcpp::Any visitReturnItems(MemgraphCypher::ReturnItemsContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return vector<NamedExpression*>
|
||||
*/
|
||||
antlrcpp::Any visitReturnItem(
|
||||
MemgraphCypher::ReturnItemContext *ctx) override;
|
||||
antlrcpp::Any visitReturnItem(MemgraphCypher::ReturnItemContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return vector<SortItem>
|
||||
@ -407,44 +368,37 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
/**
|
||||
* @return NodeAtom*
|
||||
*/
|
||||
antlrcpp::Any visitNodePattern(
|
||||
MemgraphCypher::NodePatternContext *ctx) override;
|
||||
antlrcpp::Any visitNodePattern(MemgraphCypher::NodePatternContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return vector<LabelIx>
|
||||
*/
|
||||
antlrcpp::Any visitNodeLabels(
|
||||
MemgraphCypher::NodeLabelsContext *ctx) override;
|
||||
antlrcpp::Any visitNodeLabels(MemgraphCypher::NodeLabelsContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return unordered_map<PropertyIx, Expression*>
|
||||
*/
|
||||
antlrcpp::Any visitProperties(
|
||||
MemgraphCypher::PropertiesContext *ctx) override;
|
||||
antlrcpp::Any visitProperties(MemgraphCypher::PropertiesContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return map<std::string, Expression*>
|
||||
*/
|
||||
antlrcpp::Any visitMapLiteral(
|
||||
MemgraphCypher::MapLiteralContext *ctx) override;
|
||||
antlrcpp::Any visitMapLiteral(MemgraphCypher::MapLiteralContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return vector<Expression*>
|
||||
*/
|
||||
antlrcpp::Any visitListLiteral(
|
||||
MemgraphCypher::ListLiteralContext *ctx) override;
|
||||
antlrcpp::Any visitListLiteral(MemgraphCypher::ListLiteralContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return PropertyIx
|
||||
*/
|
||||
antlrcpp::Any visitPropertyKeyName(
|
||||
MemgraphCypher::PropertyKeyNameContext *ctx) override;
|
||||
antlrcpp::Any visitPropertyKeyName(MemgraphCypher::PropertyKeyNameContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return string
|
||||
*/
|
||||
antlrcpp::Any visitSymbolicName(
|
||||
MemgraphCypher::SymbolicNameContext *ctx) override;
|
||||
antlrcpp::Any visitSymbolicName(MemgraphCypher::SymbolicNameContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return vector<Pattern*>
|
||||
@ -454,185 +408,160 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
/**
|
||||
* @return Pattern*
|
||||
*/
|
||||
antlrcpp::Any visitPatternPart(
|
||||
MemgraphCypher::PatternPartContext *ctx) override;
|
||||
antlrcpp::Any visitPatternPart(MemgraphCypher::PatternPartContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return Pattern*
|
||||
*/
|
||||
antlrcpp::Any visitPatternElement(
|
||||
MemgraphCypher::PatternElementContext *ctx) override;
|
||||
antlrcpp::Any visitPatternElement(MemgraphCypher::PatternElementContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return vector<pair<EdgeAtom*, NodeAtom*>>
|
||||
*/
|
||||
antlrcpp::Any visitPatternElementChain(
|
||||
MemgraphCypher::PatternElementChainContext *ctx) override;
|
||||
antlrcpp::Any visitPatternElementChain(MemgraphCypher::PatternElementChainContext *ctx) override;
|
||||
|
||||
/**
|
||||
*@return EdgeAtom*
|
||||
*/
|
||||
antlrcpp::Any visitRelationshipPattern(
|
||||
MemgraphCypher::RelationshipPatternContext *ctx) override;
|
||||
antlrcpp::Any visitRelationshipPattern(MemgraphCypher::RelationshipPatternContext *ctx) override;
|
||||
|
||||
/**
|
||||
* This should never be called. Everything is done directly in
|
||||
* visitRelationshipPattern.
|
||||
*/
|
||||
antlrcpp::Any visitRelationshipDetail(
|
||||
MemgraphCypher::RelationshipDetailContext *ctx) override;
|
||||
antlrcpp::Any visitRelationshipDetail(MemgraphCypher::RelationshipDetailContext *ctx) override;
|
||||
|
||||
/**
|
||||
* This should never be called. Everything is done directly in
|
||||
* visitRelationshipPattern.
|
||||
*/
|
||||
antlrcpp::Any visitRelationshipLambda(
|
||||
MemgraphCypher::RelationshipLambdaContext *ctx) override;
|
||||
antlrcpp::Any visitRelationshipLambda(MemgraphCypher::RelationshipLambdaContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return vector<EdgeTypeIx>
|
||||
*/
|
||||
antlrcpp::Any visitRelationshipTypes(
|
||||
MemgraphCypher::RelationshipTypesContext *ctx) override;
|
||||
antlrcpp::Any visitRelationshipTypes(MemgraphCypher::RelationshipTypesContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return std::tuple<EdgeAtom::Type, int64_t, int64_t>.
|
||||
*/
|
||||
antlrcpp::Any visitVariableExpansion(
|
||||
MemgraphCypher::VariableExpansionContext *ctx) override;
|
||||
antlrcpp::Any visitVariableExpansion(MemgraphCypher::VariableExpansionContext *ctx) override;
|
||||
|
||||
/**
|
||||
* Top level expression, does nothing.
|
||||
*
|
||||
* @return Expression*
|
||||
*/
|
||||
antlrcpp::Any visitExpression(
|
||||
MemgraphCypher::ExpressionContext *ctx) override;
|
||||
antlrcpp::Any visitExpression(MemgraphCypher::ExpressionContext *ctx) override;
|
||||
|
||||
/**
|
||||
* OR.
|
||||
*
|
||||
* @return Expression*
|
||||
*/
|
||||
antlrcpp::Any visitExpression12(
|
||||
MemgraphCypher::Expression12Context *ctx) override;
|
||||
antlrcpp::Any visitExpression12(MemgraphCypher::Expression12Context *ctx) override;
|
||||
|
||||
/**
|
||||
* XOR.
|
||||
*
|
||||
* @return Expression*
|
||||
*/
|
||||
antlrcpp::Any visitExpression11(
|
||||
MemgraphCypher::Expression11Context *ctx) override;
|
||||
antlrcpp::Any visitExpression11(MemgraphCypher::Expression11Context *ctx) override;
|
||||
|
||||
/**
|
||||
* AND.
|
||||
*
|
||||
* @return Expression*
|
||||
*/
|
||||
antlrcpp::Any visitExpression10(
|
||||
MemgraphCypher::Expression10Context *ctx) override;
|
||||
antlrcpp::Any visitExpression10(MemgraphCypher::Expression10Context *ctx) override;
|
||||
|
||||
/**
|
||||
* NOT.
|
||||
*
|
||||
* @return Expression*
|
||||
*/
|
||||
antlrcpp::Any visitExpression9(
|
||||
MemgraphCypher::Expression9Context *ctx) override;
|
||||
antlrcpp::Any visitExpression9(MemgraphCypher::Expression9Context *ctx) override;
|
||||
|
||||
/**
|
||||
* Comparisons.
|
||||
*
|
||||
* @return Expression*
|
||||
*/
|
||||
antlrcpp::Any visitExpression8(
|
||||
MemgraphCypher::Expression8Context *ctx) override;
|
||||
antlrcpp::Any visitExpression8(MemgraphCypher::Expression8Context *ctx) override;
|
||||
|
||||
/**
|
||||
* Never call this. Everything related to generating code for comparison
|
||||
* operators should be done in visitExpression8.
|
||||
*/
|
||||
antlrcpp::Any visitPartialComparisonExpression(
|
||||
MemgraphCypher::PartialComparisonExpressionContext *ctx) override;
|
||||
antlrcpp::Any visitPartialComparisonExpression(MemgraphCypher::PartialComparisonExpressionContext *ctx) override;
|
||||
|
||||
/**
|
||||
* Addition and subtraction.
|
||||
*
|
||||
* @return Expression*
|
||||
*/
|
||||
antlrcpp::Any visitExpression7(
|
||||
MemgraphCypher::Expression7Context *ctx) override;
|
||||
antlrcpp::Any visitExpression7(MemgraphCypher::Expression7Context *ctx) override;
|
||||
|
||||
/**
|
||||
* Multiplication, division, modding.
|
||||
*
|
||||
* @return Expression*
|
||||
*/
|
||||
antlrcpp::Any visitExpression6(
|
||||
MemgraphCypher::Expression6Context *ctx) override;
|
||||
antlrcpp::Any visitExpression6(MemgraphCypher::Expression6Context *ctx) override;
|
||||
|
||||
/**
|
||||
* Power.
|
||||
*
|
||||
* @return Expression*
|
||||
*/
|
||||
antlrcpp::Any visitExpression5(
|
||||
MemgraphCypher::Expression5Context *ctx) override;
|
||||
antlrcpp::Any visitExpression5(MemgraphCypher::Expression5Context *ctx) override;
|
||||
|
||||
/**
|
||||
* Unary minus and plus.
|
||||
*
|
||||
* @return Expression*
|
||||
*/
|
||||
antlrcpp::Any visitExpression4(
|
||||
MemgraphCypher::Expression4Context *ctx) override;
|
||||
antlrcpp::Any visitExpression4(MemgraphCypher::Expression4Context *ctx) override;
|
||||
|
||||
/**
|
||||
* IS NULL, IS NOT NULL, STARTS WITH, END WITH, =~, ...
|
||||
*
|
||||
* @return Expression*
|
||||
*/
|
||||
antlrcpp::Any visitExpression3a(
|
||||
MemgraphCypher::Expression3aContext *ctx) override;
|
||||
antlrcpp::Any visitExpression3a(MemgraphCypher::Expression3aContext *ctx) override;
|
||||
|
||||
/**
|
||||
* Does nothing, everything is done in visitExpression3a.
|
||||
*
|
||||
* @return Expression*
|
||||
*/
|
||||
antlrcpp::Any visitStringAndNullOperators(
|
||||
MemgraphCypher::StringAndNullOperatorsContext *ctx) override;
|
||||
antlrcpp::Any visitStringAndNullOperators(MemgraphCypher::StringAndNullOperatorsContext *ctx) override;
|
||||
|
||||
/**
|
||||
* List indexing and slicing.
|
||||
*
|
||||
* @return Expression*
|
||||
*/
|
||||
antlrcpp::Any visitExpression3b(
|
||||
MemgraphCypher::Expression3bContext *ctx) override;
|
||||
antlrcpp::Any visitExpression3b(MemgraphCypher::Expression3bContext *ctx) override;
|
||||
|
||||
/**
|
||||
* Does nothing, everything is done in visitExpression3b.
|
||||
*/
|
||||
antlrcpp::Any visitListIndexingOrSlicing(
|
||||
MemgraphCypher::ListIndexingOrSlicingContext *ctx) override;
|
||||
antlrcpp::Any visitListIndexingOrSlicing(MemgraphCypher::ListIndexingOrSlicingContext *ctx) override;
|
||||
|
||||
/**
|
||||
* Node labels test.
|
||||
*
|
||||
* @return Expression*
|
||||
*/
|
||||
antlrcpp::Any visitExpression2a(
|
||||
MemgraphCypher::Expression2aContext *ctx) override;
|
||||
antlrcpp::Any visitExpression2a(MemgraphCypher::Expression2aContext *ctx) override;
|
||||
|
||||
/**
|
||||
* Property lookup.
|
||||
*
|
||||
* @return Expression*
|
||||
*/
|
||||
antlrcpp::Any visitExpression2b(
|
||||
MemgraphCypher::Expression2bContext *ctx) override;
|
||||
antlrcpp::Any visitExpression2b(MemgraphCypher::Expression2bContext *ctx) override;
|
||||
|
||||
/**
|
||||
* Literals, params, list comprehension...
|
||||
@ -649,20 +578,17 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
/**
|
||||
* @return Expression*
|
||||
*/
|
||||
antlrcpp::Any visitParenthesizedExpression(
|
||||
MemgraphCypher::ParenthesizedExpressionContext *ctx) override;
|
||||
antlrcpp::Any visitParenthesizedExpression(MemgraphCypher::ParenthesizedExpressionContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return Expression*
|
||||
*/
|
||||
antlrcpp::Any visitFunctionInvocation(
|
||||
MemgraphCypher::FunctionInvocationContext *ctx) override;
|
||||
antlrcpp::Any visitFunctionInvocation(MemgraphCypher::FunctionInvocationContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return string - uppercased
|
||||
*/
|
||||
antlrcpp::Any visitFunctionName(
|
||||
MemgraphCypher::FunctionNameContext *ctx) override;
|
||||
antlrcpp::Any visitFunctionName(MemgraphCypher::FunctionNameContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return Expression*
|
||||
@ -679,32 +605,27 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
/**
|
||||
* @return bool
|
||||
*/
|
||||
antlrcpp::Any visitBooleanLiteral(
|
||||
MemgraphCypher::BooleanLiteralContext *ctx) override;
|
||||
antlrcpp::Any visitBooleanLiteral(MemgraphCypher::BooleanLiteralContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return TypedValue with either double or int
|
||||
*/
|
||||
antlrcpp::Any visitNumberLiteral(
|
||||
MemgraphCypher::NumberLiteralContext *ctx) override;
|
||||
antlrcpp::Any visitNumberLiteral(MemgraphCypher::NumberLiteralContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return int64_t
|
||||
*/
|
||||
antlrcpp::Any visitIntegerLiteral(
|
||||
MemgraphCypher::IntegerLiteralContext *ctx) override;
|
||||
antlrcpp::Any visitIntegerLiteral(MemgraphCypher::IntegerLiteralContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return double
|
||||
*/
|
||||
antlrcpp::Any visitDoubleLiteral(
|
||||
MemgraphCypher::DoubleLiteralContext *ctx) override;
|
||||
antlrcpp::Any visitDoubleLiteral(MemgraphCypher::DoubleLiteralContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return Delete*
|
||||
*/
|
||||
antlrcpp::Any visitCypherDelete(
|
||||
MemgraphCypher::CypherDeleteContext *ctx) override;
|
||||
antlrcpp::Any visitCypherDelete(MemgraphCypher::CypherDeleteContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return Where*
|
||||
@ -729,27 +650,23 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
/**
|
||||
* @return Clause*
|
||||
*/
|
||||
antlrcpp::Any visitRemoveItem(
|
||||
MemgraphCypher::RemoveItemContext *ctx) override;
|
||||
antlrcpp::Any visitRemoveItem(MemgraphCypher::RemoveItemContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return PropertyLookup*
|
||||
*/
|
||||
antlrcpp::Any visitPropertyExpression(
|
||||
MemgraphCypher::PropertyExpressionContext *ctx) override;
|
||||
antlrcpp::Any visitPropertyExpression(MemgraphCypher::PropertyExpressionContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return IfOperator*
|
||||
*/
|
||||
antlrcpp::Any visitCaseExpression(
|
||||
MemgraphCypher::CaseExpressionContext *ctx) override;
|
||||
antlrcpp::Any visitCaseExpression(MemgraphCypher::CaseExpressionContext *ctx) override;
|
||||
|
||||
/**
|
||||
* Never call this. Ast generation for this production is done in
|
||||
* @c visitCaseExpression.
|
||||
*/
|
||||
antlrcpp::Any visitCaseAlternatives(
|
||||
MemgraphCypher::CaseAlternativesContext *ctx) override;
|
||||
antlrcpp::Any visitCaseAlternatives(MemgraphCypher::CaseAlternativesContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return With*
|
||||
@ -770,8 +687,7 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
* Never call this. Ast generation for these expressions should be done by
|
||||
* explicitly visiting the members of @c FilterExpressionContext.
|
||||
*/
|
||||
antlrcpp::Any visitFilterExpression(
|
||||
MemgraphCypher::FilterExpressionContext *) override;
|
||||
antlrcpp::Any visitFilterExpression(MemgraphCypher::FilterExpressionContext *) override;
|
||||
|
||||
public:
|
||||
Query *query() { return query_; }
|
||||
|
@ -89,8 +89,7 @@ void PrintObject(std::ostream *out, const std::map<K, V> &map);
|
||||
|
||||
template <typename T>
|
||||
void PrintObject(std::ostream *out, const T &arg) {
|
||||
static_assert(
|
||||
!std::is_convertible<T, Expression *>::value,
|
||||
static_assert(!std::is_convertible<T, Expression *>::value,
|
||||
"This overload shouldn't be called with pointers convertible "
|
||||
"to Expression *. This means your other PrintObject overloads aren't "
|
||||
"being called for certain AST nodes when they should (or perhaps such "
|
||||
@ -98,13 +97,9 @@ void PrintObject(std::ostream *out, const T &arg) {
|
||||
*out << arg;
|
||||
}
|
||||
|
||||
void PrintObject(std::ostream *out, const std::string &str) {
|
||||
*out << utils::Escape(str);
|
||||
}
|
||||
void PrintObject(std::ostream *out, const std::string &str) { *out << utils::Escape(str); }
|
||||
|
||||
void PrintObject(std::ostream *out, Aggregation::Op op) {
|
||||
*out << Aggregation::OpToString(op);
|
||||
}
|
||||
void PrintObject(std::ostream *out, Aggregation::Op op) { *out << Aggregation::OpToString(op); }
|
||||
|
||||
void PrintObject(std::ostream *out, Expression *expr) {
|
||||
if (expr) {
|
||||
@ -115,9 +110,7 @@ void PrintObject(std::ostream *out, Expression *expr) {
|
||||
}
|
||||
}
|
||||
|
||||
void PrintObject(std::ostream *out, Identifier *expr) {
|
||||
PrintObject(out, static_cast<Expression *>(expr));
|
||||
}
|
||||
void PrintObject(std::ostream *out, Identifier *expr) { PrintObject(out, static_cast<Expression *>(expr)); }
|
||||
|
||||
void PrintObject(std::ostream *out, const storage::PropertyValue &value) {
|
||||
switch (value.type()) {
|
||||
@ -154,9 +147,7 @@ void PrintObject(std::ostream *out, const storage::PropertyValue &value) {
|
||||
template <typename T>
|
||||
void PrintObject(std::ostream *out, const std::vector<T> &vec) {
|
||||
*out << "[";
|
||||
utils::PrintIterable(*out, vec, ", ", [](auto &stream, const auto &item) {
|
||||
PrintObject(&stream, item);
|
||||
});
|
||||
utils::PrintIterable(*out, vec, ", ", [](auto &stream, const auto &item) { PrintObject(&stream, item); });
|
||||
*out << "]";
|
||||
}
|
||||
|
||||
@ -179,26 +170,22 @@ void PrintOperatorArgs(std::ostream *out, const T &arg) {
|
||||
}
|
||||
|
||||
template <typename T, typename... Ts>
|
||||
void PrintOperatorArgs(std::ostream *out, const T &arg, const Ts &... args) {
|
||||
void PrintOperatorArgs(std::ostream *out, const T &arg, const Ts &...args) {
|
||||
*out << " ";
|
||||
PrintObject(out, arg);
|
||||
PrintOperatorArgs(out, args...);
|
||||
}
|
||||
|
||||
template <typename... Ts>
|
||||
void PrintOperator(std::ostream *out, const std::string &name,
|
||||
const Ts &... args) {
|
||||
void PrintOperator(std::ostream *out, const std::string &name, const Ts &...args) {
|
||||
*out << "(" << name;
|
||||
PrintOperatorArgs(out, args...);
|
||||
}
|
||||
|
||||
ExpressionPrettyPrinter::ExpressionPrettyPrinter(std::ostream *out)
|
||||
: out_(out) {}
|
||||
ExpressionPrettyPrinter::ExpressionPrettyPrinter(std::ostream *out) : out_(out) {}
|
||||
|
||||
#define UNARY_OPERATOR_VISIT(OP_NODE, OP_STR) \
|
||||
void ExpressionPrettyPrinter::Visit(OP_NODE &op) { \
|
||||
PrintOperator(out_, OP_STR, op.expression_); \
|
||||
}
|
||||
void ExpressionPrettyPrinter::Visit(OP_NODE &op) { PrintOperator(out_, OP_STR, op.expression_); }
|
||||
|
||||
UNARY_OPERATOR_VISIT(NotOperator, "Not");
|
||||
UNARY_OPERATOR_VISIT(UnaryPlusOperator, "+");
|
||||
@ -208,9 +195,7 @@ UNARY_OPERATOR_VISIT(IsNullOperator, "IsNull");
|
||||
#undef UNARY_OPERATOR_VISIT
|
||||
|
||||
#define BINARY_OPERATOR_VISIT(OP_NODE, OP_STR) \
|
||||
void ExpressionPrettyPrinter::Visit(OP_NODE &op) { \
|
||||
PrintOperator(out_, OP_STR, op.expression1_, op.expression2_); \
|
||||
}
|
||||
void ExpressionPrettyPrinter::Visit(OP_NODE &op) { PrintOperator(out_, OP_STR, op.expression1_, op.expression2_); }
|
||||
|
||||
BINARY_OPERATOR_VISIT(OrOperator, "Or");
|
||||
BINARY_OPERATOR_VISIT(XorOperator, "Xor");
|
||||
@ -232,18 +217,14 @@ BINARY_OPERATOR_VISIT(SubscriptOperator, "Subscript");
|
||||
#undef BINARY_OPERATOR_VISIT
|
||||
|
||||
void ExpressionPrettyPrinter::Visit(ListSlicingOperator &op) {
|
||||
PrintOperator(out_, "ListSlicing", op.list_, op.lower_bound_,
|
||||
op.upper_bound_);
|
||||
PrintOperator(out_, "ListSlicing", op.list_, op.lower_bound_, op.upper_bound_);
|
||||
}
|
||||
|
||||
void ExpressionPrettyPrinter::Visit(IfOperator &op) {
|
||||
PrintOperator(out_, "If", op.condition_, op.then_expression_,
|
||||
op.else_expression_);
|
||||
PrintOperator(out_, "If", op.condition_, op.then_expression_, op.else_expression_);
|
||||
}
|
||||
|
||||
void ExpressionPrettyPrinter::Visit(ListLiteral &op) {
|
||||
PrintOperator(out_, "ListLiteral", op.elements_);
|
||||
}
|
||||
void ExpressionPrettyPrinter::Visit(ListLiteral &op) { PrintOperator(out_, "ListLiteral", op.elements_); }
|
||||
|
||||
void ExpressionPrettyPrinter::Visit(MapLiteral &op) {
|
||||
std::map<std::string, Expression *> map;
|
||||
@ -253,74 +234,53 @@ void ExpressionPrettyPrinter::Visit(MapLiteral &op) {
|
||||
PrintObject(out_, map);
|
||||
}
|
||||
|
||||
void ExpressionPrettyPrinter::Visit(LabelsTest &op) {
|
||||
PrintOperator(out_, "LabelsTest", op.expression_);
|
||||
}
|
||||
void ExpressionPrettyPrinter::Visit(LabelsTest &op) { PrintOperator(out_, "LabelsTest", op.expression_); }
|
||||
|
||||
void ExpressionPrettyPrinter::Visit(Aggregation &op) {
|
||||
PrintOperator(out_, "Aggregation", op.op_);
|
||||
}
|
||||
void ExpressionPrettyPrinter::Visit(Aggregation &op) { PrintOperator(out_, "Aggregation", op.op_); }
|
||||
|
||||
void ExpressionPrettyPrinter::Visit(Function &op) {
|
||||
PrintOperator(out_, "Function", op.function_name_, op.arguments_);
|
||||
}
|
||||
void ExpressionPrettyPrinter::Visit(Function &op) { PrintOperator(out_, "Function", op.function_name_, op.arguments_); }
|
||||
|
||||
void ExpressionPrettyPrinter::Visit(Reduce &op) {
|
||||
PrintOperator(out_, "Reduce", op.accumulator_, op.initializer_,
|
||||
op.identifier_, op.list_, op.expression_);
|
||||
PrintOperator(out_, "Reduce", op.accumulator_, op.initializer_, op.identifier_, op.list_, op.expression_);
|
||||
}
|
||||
|
||||
void ExpressionPrettyPrinter::Visit(Coalesce &op) {
|
||||
PrintOperator(out_, "Coalesce", op.expressions_);
|
||||
}
|
||||
void ExpressionPrettyPrinter::Visit(Coalesce &op) { PrintOperator(out_, "Coalesce", op.expressions_); }
|
||||
|
||||
void ExpressionPrettyPrinter::Visit(Extract &op) {
|
||||
PrintOperator(out_, "Extract", op.identifier_, op.list_, op.expression_);
|
||||
}
|
||||
|
||||
void ExpressionPrettyPrinter::Visit(All &op) {
|
||||
PrintOperator(out_, "All", op.identifier_, op.list_expression_,
|
||||
op.where_->expression_);
|
||||
PrintOperator(out_, "All", op.identifier_, op.list_expression_, op.where_->expression_);
|
||||
}
|
||||
|
||||
void ExpressionPrettyPrinter::Visit(Single &op) {
|
||||
PrintOperator(out_, "Single", op.identifier_, op.list_expression_,
|
||||
op.where_->expression_);
|
||||
PrintOperator(out_, "Single", op.identifier_, op.list_expression_, op.where_->expression_);
|
||||
}
|
||||
|
||||
void ExpressionPrettyPrinter::Visit(Any &op) {
|
||||
PrintOperator(out_, "Any", op.identifier_, op.list_expression_,
|
||||
op.where_->expression_);
|
||||
PrintOperator(out_, "Any", op.identifier_, op.list_expression_, op.where_->expression_);
|
||||
}
|
||||
|
||||
void ExpressionPrettyPrinter::Visit(None &op) {
|
||||
PrintOperator(out_, "None", op.identifier_, op.list_expression_,
|
||||
op.where_->expression_);
|
||||
PrintOperator(out_, "None", op.identifier_, op.list_expression_, op.where_->expression_);
|
||||
}
|
||||
|
||||
void ExpressionPrettyPrinter::Visit(Identifier &op) {
|
||||
PrintOperator(out_, "Identifier", op.name_);
|
||||
}
|
||||
void ExpressionPrettyPrinter::Visit(Identifier &op) { PrintOperator(out_, "Identifier", op.name_); }
|
||||
|
||||
void ExpressionPrettyPrinter::Visit(PrimitiveLiteral &op) {
|
||||
PrintObject(out_, op.value_);
|
||||
}
|
||||
void ExpressionPrettyPrinter::Visit(PrimitiveLiteral &op) { PrintObject(out_, op.value_); }
|
||||
|
||||
void ExpressionPrettyPrinter::Visit(PropertyLookup &op) {
|
||||
PrintOperator(out_, "PropertyLookup", op.expression_, op.property_.name);
|
||||
}
|
||||
|
||||
void ExpressionPrettyPrinter::Visit(ParameterLookup &op) {
|
||||
PrintOperator(out_, "ParameterLookup", op.token_position_);
|
||||
}
|
||||
void ExpressionPrettyPrinter::Visit(ParameterLookup &op) { PrintOperator(out_, "ParameterLookup", op.token_position_); }
|
||||
|
||||
void ExpressionPrettyPrinter::Visit(NamedExpression &op) {
|
||||
PrintOperator(out_, "NamedExpression", op.name_, op.expression_);
|
||||
}
|
||||
|
||||
void ExpressionPrettyPrinter::Visit(RegexMatch &op) {
|
||||
PrintOperator(out_, "=~", op.string_expr_, op.regex_);
|
||||
}
|
||||
void ExpressionPrettyPrinter::Visit(RegexMatch &op) { PrintOperator(out_, "=~", op.string_expr_, op.regex_); }
|
||||
|
||||
} // namespace
|
||||
|
||||
|
@ -35,12 +35,10 @@ class Parser {
|
||||
|
||||
private:
|
||||
class FirstMessageErrorListener : public antlr4::BaseErrorListener {
|
||||
void syntaxError(antlr4::IRecognizer *, antlr4::Token *, size_t line,
|
||||
size_t position, const std::string &message,
|
||||
void syntaxError(antlr4::IRecognizer *, antlr4::Token *, size_t line, size_t position, const std::string &message,
|
||||
std::exception_ptr) override {
|
||||
if (error_.empty()) {
|
||||
error_ = "line " + std::to_string(line) + ":" +
|
||||
std::to_string(position + 1) + " " + message;
|
||||
error_ = "line " + std::to_string(line) + ":" + std::to_string(position + 1) + " " + message;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -26,8 +26,7 @@ std::string ParseStringLiteral(const std::string &s) {
|
||||
auto EncodeEscapedUnicodeCodepointUtf32 = [](const std::string &s, int &i) {
|
||||
const int kLongUnicodeLength = 8;
|
||||
int j = i + 1;
|
||||
while (j < static_cast<int>(s.size()) - 1 &&
|
||||
j < i + kLongUnicodeLength + 1 && isxdigit(s[j])) {
|
||||
while (j < static_cast<int>(s.size()) - 1 && j < i + kLongUnicodeLength + 1 && isxdigit(s[j])) {
|
||||
++j;
|
||||
}
|
||||
if (j - i == kLongUnicodeLength + 1) {
|
||||
@ -43,8 +42,7 @@ std::string ParseStringLiteral(const std::string &s) {
|
||||
auto EncodeEscapedUnicodeCodepointUtf16 = [](const std::string &s, int &i) {
|
||||
const int kShortUnicodeLength = 4;
|
||||
int j = i + 1;
|
||||
while (j < static_cast<int>(s.size()) - 1 &&
|
||||
j < i + kShortUnicodeLength + 1 && isxdigit(s[j])) {
|
||||
while (j < static_cast<int>(s.size()) - 1 && j < i + kShortUnicodeLength + 1 && isxdigit(s[j])) {
|
||||
++j;
|
||||
}
|
||||
if (j - i >= kShortUnicodeLength + 1) {
|
||||
@ -56,31 +54,24 @@ std::string ParseStringLiteral(const std::string &s) {
|
||||
throw SemanticException("Invalid UTF codepoint.");
|
||||
}
|
||||
++j;
|
||||
if (j >= static_cast<int>(s.size()) - 1 ||
|
||||
(s[j] != 'u' && s[j] != 'U')) {
|
||||
if (j >= static_cast<int>(s.size()) - 1 || (s[j] != 'u' && s[j] != 'U')) {
|
||||
throw SemanticException("Invalid UTF codepoint.");
|
||||
}
|
||||
++j;
|
||||
int k = j;
|
||||
while (k < static_cast<int>(s.size()) - 1 &&
|
||||
k < j + kShortUnicodeLength && isxdigit(s[k])) {
|
||||
while (k < static_cast<int>(s.size()) - 1 && k < j + kShortUnicodeLength && isxdigit(s[k])) {
|
||||
++k;
|
||||
}
|
||||
if (k != j + kShortUnicodeLength) {
|
||||
throw SemanticException("Invalid UTF codepoint.");
|
||||
}
|
||||
char16_t surrogates[3] = {t,
|
||||
static_cast<char16_t>(stoi(
|
||||
s.substr(j, kShortUnicodeLength), 0, 16)),
|
||||
0};
|
||||
char16_t surrogates[3] = {t, static_cast<char16_t>(stoi(s.substr(j, kShortUnicodeLength), 0, 16)), 0};
|
||||
i += kShortUnicodeLength + 2 + kShortUnicodeLength;
|
||||
std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t>
|
||||
converter;
|
||||
std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t> converter;
|
||||
return converter.to_bytes(surrogates);
|
||||
} else {
|
||||
i += kShortUnicodeLength;
|
||||
std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t>
|
||||
converter;
|
||||
std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t> converter;
|
||||
return converter.to_bytes(t);
|
||||
}
|
||||
}
|
||||
@ -168,8 +159,7 @@ std::string ParseParameter(const std::string &s) {
|
||||
if (s[1] != '`') return s.substr(1);
|
||||
// If parameter name is escaped symbolic name then symbolic name should be
|
||||
// unescaped and leading and trailing backquote should be removed.
|
||||
DMG_ASSERT(s.size() > 3U && s.back() == '`',
|
||||
"Invalid string passed as parameter name");
|
||||
DMG_ASSERT(s.size() > 3U && s.back() == '`', "Invalid string passed as parameter name");
|
||||
std::string out;
|
||||
for (int i = 2; i < static_cast<int>(s.size()) - 1; ++i) {
|
||||
if (s[i] == '`') {
|
||||
|
@ -2,8 +2,7 @@
|
||||
|
||||
namespace query {
|
||||
|
||||
class PrivilegeExtractor : public QueryVisitor<void>,
|
||||
public HierarchicalTreeVisitor {
|
||||
class PrivilegeExtractor : public QueryVisitor<void>, public HierarchicalTreeVisitor {
|
||||
public:
|
||||
using HierarchicalTreeVisitor::PostVisit;
|
||||
using HierarchicalTreeVisitor::PreVisit;
|
||||
@ -12,19 +11,13 @@ class PrivilegeExtractor : public QueryVisitor<void>,
|
||||
|
||||
std::vector<AuthQuery::Privilege> privileges() { return privileges_; }
|
||||
|
||||
void Visit(IndexQuery &) override {
|
||||
AddPrivilege(AuthQuery::Privilege::INDEX);
|
||||
}
|
||||
void Visit(IndexQuery &) override { AddPrivilege(AuthQuery::Privilege::INDEX); }
|
||||
|
||||
void Visit(AuthQuery &) override { AddPrivilege(AuthQuery::Privilege::AUTH); }
|
||||
|
||||
void Visit(ExplainQuery &query) override {
|
||||
query.cypher_query_->Accept(*this);
|
||||
}
|
||||
void Visit(ExplainQuery &query) override { query.cypher_query_->Accept(*this); }
|
||||
|
||||
void Visit(ProfileQuery &query) override {
|
||||
query.cypher_query_->Accept(*this);
|
||||
}
|
||||
void Visit(ProfileQuery &query) override { query.cypher_query_->Accept(*this); }
|
||||
|
||||
void Visit(InfoQuery &info_query) override {
|
||||
switch (info_query.info_type_) {
|
||||
@ -44,9 +37,7 @@ class PrivilegeExtractor : public QueryVisitor<void>,
|
||||
}
|
||||
}
|
||||
|
||||
void Visit(ConstraintQuery &constraint_query) override {
|
||||
AddPrivilege(AuthQuery::Privilege::CONSTRAINT);
|
||||
}
|
||||
void Visit(ConstraintQuery &constraint_query) override { AddPrivilege(AuthQuery::Privilege::CONSTRAINT); }
|
||||
|
||||
void Visit(CypherQuery &query) override {
|
||||
query.single_query_->Accept(*this);
|
||||
@ -55,13 +46,9 @@ class PrivilegeExtractor : public QueryVisitor<void>,
|
||||
}
|
||||
}
|
||||
|
||||
void Visit(DumpQuery &dump_query) override {
|
||||
AddPrivilege(AuthQuery::Privilege::DUMP);
|
||||
}
|
||||
void Visit(DumpQuery &dump_query) override { AddPrivilege(AuthQuery::Privilege::DUMP); }
|
||||
|
||||
void Visit(LockPathQuery &lock_path_query) override {
|
||||
AddPrivilege(AuthQuery::Privilege::LOCK_PATH);
|
||||
}
|
||||
void Visit(LockPathQuery &lock_path_query) override { AddPrivilege(AuthQuery::Privilege::LOCK_PATH); }
|
||||
|
||||
void Visit(ReplicationQuery &replication_query) override {
|
||||
switch (replication_query.action_) {
|
||||
|
@ -12,24 +12,19 @@
|
||||
|
||||
namespace query {
|
||||
|
||||
auto SymbolGenerator::CreateSymbol(const std::string &name, bool user_declared,
|
||||
Symbol::Type type, int token_position) {
|
||||
auto symbol =
|
||||
symbol_table_.CreateSymbol(name, user_declared, type, token_position);
|
||||
auto SymbolGenerator::CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type, int token_position) {
|
||||
auto symbol = symbol_table_.CreateSymbol(name, user_declared, type, token_position);
|
||||
scope_.symbols[name] = symbol;
|
||||
return symbol;
|
||||
}
|
||||
|
||||
auto SymbolGenerator::GetOrCreateSymbol(const std::string &name,
|
||||
bool user_declared, Symbol::Type type) {
|
||||
auto SymbolGenerator::GetOrCreateSymbol(const std::string &name, bool user_declared, Symbol::Type type) {
|
||||
auto search = scope_.symbols.find(name);
|
||||
if (search != scope_.symbols.end()) {
|
||||
auto symbol = search->second;
|
||||
// Unless we have `ANY` type, check that types match.
|
||||
if (type != Symbol::Type::ANY && symbol.type() != Symbol::Type::ANY &&
|
||||
type != symbol.type()) {
|
||||
throw TypeMismatchError(name, Symbol::TypeToString(symbol.type()),
|
||||
Symbol::TypeToString(type));
|
||||
if (type != Symbol::Type::ANY && symbol.type() != Symbol::Type::ANY && type != symbol.type()) {
|
||||
throw TypeMismatchError(name, Symbol::TypeToString(symbol.type()), Symbol::TypeToString(type));
|
||||
}
|
||||
return search->second;
|
||||
}
|
||||
@ -50,8 +45,7 @@ void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) {
|
||||
user_symbols.emplace_back(sym_pair.second);
|
||||
}
|
||||
if (user_symbols.empty()) {
|
||||
throw SemanticException(
|
||||
"There are no variables in scope to use for '*'.");
|
||||
throw SemanticException("There are no variables in scope to use for '*'.");
|
||||
}
|
||||
}
|
||||
// WITH/RETURN clause removes declarations of all the previous variables and
|
||||
@ -74,13 +68,11 @@ void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) {
|
||||
for (auto &named_expr : body.named_expressions) {
|
||||
const auto &name = named_expr->name_;
|
||||
if (!new_names.insert(name).second) {
|
||||
throw SemanticException(
|
||||
"Multiple results with the same name '{}' are not allowed.", name);
|
||||
throw SemanticException("Multiple results with the same name '{}' are not allowed.", name);
|
||||
}
|
||||
// An improvement would be to infer the type of the expression, so that the
|
||||
// new symbol would have a more specific type.
|
||||
named_expr->MapTo(CreateSymbol(name, true, Symbol::Type::ANY,
|
||||
named_expr->token_position_));
|
||||
named_expr->MapTo(CreateSymbol(name, true, Symbol::Type::ANY, named_expr->token_position_));
|
||||
}
|
||||
scope_.in_order_by = true;
|
||||
for (const auto &order_pair : body.order_by) {
|
||||
@ -102,8 +94,7 @@ void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) {
|
||||
// We have an ORDER BY or WHERE, but no aggregation, which means we didn't
|
||||
// clear the old symbols, so do it now. We cannot just call clear, because
|
||||
// we've added new symbols.
|
||||
for (auto sym_it = scope_.symbols.begin();
|
||||
sym_it != scope_.symbols.end();) {
|
||||
for (auto sym_it = scope_.symbols.begin(); sym_it != scope_.symbols.end();) {
|
||||
if (new_names.find(sym_it->first) == new_names.end()) {
|
||||
sym_it = scope_.symbols.erase(sym_it);
|
||||
} else {
|
||||
@ -131,8 +122,7 @@ bool SymbolGenerator::PreVisit(CypherUnion &) {
|
||||
|
||||
bool SymbolGenerator::PostVisit(CypherUnion &cypher_union) {
|
||||
if (prev_return_names_ != curr_return_names_) {
|
||||
throw SemanticException(
|
||||
"All subqueries in an UNION must have the same column names.");
|
||||
throw SemanticException("All subqueries in an UNION must have the same column names.");
|
||||
}
|
||||
|
||||
// create new symbols for the result of the union
|
||||
@ -180,8 +170,7 @@ bool SymbolGenerator::PreVisit(Return &ret) {
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PostVisit(Return &) {
|
||||
for (const auto &name_symbol : scope_.symbols)
|
||||
curr_return_names_.insert(name_symbol.first);
|
||||
for (const auto &name_symbol : scope_.symbols) curr_return_names_.insert(name_symbol.first);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -239,15 +228,13 @@ bool SymbolGenerator::PostVisit(Match &) {
|
||||
|
||||
SymbolGenerator::ReturnType SymbolGenerator::Visit(Identifier &ident) {
|
||||
if (scope_.in_skip || scope_.in_limit) {
|
||||
throw SemanticException("Variables are not allowed in {}.",
|
||||
scope_.in_skip ? "SKIP" : "LIMIT");
|
||||
throw SemanticException("Variables are not allowed in {}.", scope_.in_skip ? "SKIP" : "LIMIT");
|
||||
}
|
||||
Symbol symbol;
|
||||
if (scope_.in_pattern && !(scope_.in_node_atom || scope_.visiting_edge)) {
|
||||
// If we are in the pattern, and outside of a node or an edge, the
|
||||
// identifier is the pattern name.
|
||||
symbol = GetOrCreateSymbol(ident.name_, ident.user_declared_,
|
||||
Symbol::Type::PATH);
|
||||
symbol = GetOrCreateSymbol(ident.name_, ident.user_declared_, Symbol::Type::PATH);
|
||||
} else if (scope_.in_pattern && scope_.in_pattern_atom_identifier) {
|
||||
// Patterns used to create nodes and edges cannot redeclare already
|
||||
// established bindings. Declaration only happens in single node
|
||||
@ -255,8 +242,7 @@ SymbolGenerator::ReturnType SymbolGenerator::Visit(Identifier &ident) {
|
||||
// `MATCH (n) CREATE (n)` should throw an error that `n` is already
|
||||
// declared. While `MATCH (n) CREATE (n) -[:R]-> (n)` is allowed,
|
||||
// since `n` now references the bound node instead of declaring it.
|
||||
if ((scope_.in_create_node || scope_.in_create_edge) &&
|
||||
HasSymbol(ident.name_)) {
|
||||
if ((scope_.in_create_node || scope_.in_create_edge) && HasSymbol(ident.name_)) {
|
||||
throw RedeclareVariableError(ident.name_);
|
||||
}
|
||||
auto type = Symbol::Type::VERTEX;
|
||||
@ -266,14 +252,11 @@ SymbolGenerator::ReturnType SymbolGenerator::Visit(Identifier &ident) {
|
||||
if (HasSymbol(ident.name_)) {
|
||||
throw RedeclareVariableError(ident.name_);
|
||||
}
|
||||
type = scope_.visiting_edge->IsVariable() ? Symbol::Type::EDGE_LIST
|
||||
: Symbol::Type::EDGE;
|
||||
type = scope_.visiting_edge->IsVariable() ? Symbol::Type::EDGE_LIST : Symbol::Type::EDGE;
|
||||
}
|
||||
symbol = GetOrCreateSymbol(ident.name_, ident.user_declared_, type);
|
||||
} else if (scope_.in_pattern && !scope_.in_pattern_atom_identifier &&
|
||||
scope_.in_match) {
|
||||
if (scope_.in_edge_range &&
|
||||
scope_.visiting_edge->identifier_->name_ == ident.name_) {
|
||||
} else if (scope_.in_pattern && !scope_.in_pattern_atom_identifier && scope_.in_match) {
|
||||
if (scope_.in_edge_range && scope_.visiting_edge->identifier_->name_ == ident.name_) {
|
||||
// Prevent variable path bounds to reference the identifier which is bound
|
||||
// by the variable path itself.
|
||||
throw UnboundVariableError(ident.name_);
|
||||
@ -295,10 +278,9 @@ bool SymbolGenerator::PreVisit(Aggregation &aggr) {
|
||||
// Check if the aggregation can be used in this context. This check should
|
||||
// probably move to a separate phase, which checks if the query is well
|
||||
// formed.
|
||||
if ((!scope_.in_return && !scope_.in_with) || scope_.in_order_by ||
|
||||
scope_.in_skip || scope_.in_limit || scope_.in_where) {
|
||||
throw SemanticException(
|
||||
"Aggregation functions are only allowed in WITH and RETURN.");
|
||||
if ((!scope_.in_return && !scope_.in_with) || scope_.in_order_by || scope_.in_skip || scope_.in_limit ||
|
||||
scope_.in_where) {
|
||||
throw SemanticException("Aggregation functions are only allowed in WITH and RETURN.");
|
||||
}
|
||||
if (scope_.in_aggregation) {
|
||||
throw SemanticException(
|
||||
@ -313,13 +295,11 @@ bool SymbolGenerator::PreVisit(Aggregation &aggr) {
|
||||
// CASE count(n) WHEN 10 THEN "YES" ELSE "NO" END.
|
||||
// TODO: Rethink of allowing aggregations in some parts of the CASE
|
||||
// construct.
|
||||
throw SemanticException(
|
||||
"Using aggregation functions inside of CASE is not allowed.");
|
||||
throw SemanticException("Using aggregation functions inside of CASE is not allowed.");
|
||||
}
|
||||
// Create a virtual symbol for aggregation result.
|
||||
// Currently, we only have aggregation operators which return numbers.
|
||||
auto aggr_name =
|
||||
Aggregation::OpToString(aggr.op_) + std::to_string(aggr.symbol_pos_);
|
||||
auto aggr_name = Aggregation::OpToString(aggr.op_) + std::to_string(aggr.symbol_pos_);
|
||||
aggr.MapTo(CreateSymbol(aggr_name, false, Symbol::Type::NUMBER));
|
||||
scope_.in_aggregation = true;
|
||||
scope_.has_aggregation = true;
|
||||
@ -368,8 +348,7 @@ bool SymbolGenerator::PreVisit(None &none) {
|
||||
bool SymbolGenerator::PreVisit(Reduce &reduce) {
|
||||
reduce.initializer_->Accept(*this);
|
||||
reduce.list_->Accept(*this);
|
||||
VisitWithIdentifiers(reduce.expression_,
|
||||
{reduce.accumulator_, reduce.identifier_});
|
||||
VisitWithIdentifiers(reduce.expression_, {reduce.accumulator_, reduce.identifier_});
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -384,8 +363,7 @@ bool SymbolGenerator::PreVisit(Extract &extract) {
|
||||
bool SymbolGenerator::PreVisit(Pattern &pattern) {
|
||||
scope_.in_pattern = true;
|
||||
if ((scope_.in_create || scope_.in_merge) && pattern.atoms_.size() == 1U) {
|
||||
MG_ASSERT(utils::IsSubtype(*pattern.atoms_[0], NodeAtom::kType),
|
||||
"Expected a single NodeAtom in Pattern");
|
||||
MG_ASSERT(utils::IsSubtype(*pattern.atoms_[0], NodeAtom::kType), "Expected a single NodeAtom in Pattern");
|
||||
scope_.in_create_node = true;
|
||||
}
|
||||
return true;
|
||||
@ -399,13 +377,10 @@ bool SymbolGenerator::PostVisit(Pattern &) {
|
||||
|
||||
bool SymbolGenerator::PreVisit(NodeAtom &node_atom) {
|
||||
scope_.in_node_atom = true;
|
||||
bool props_or_labels =
|
||||
!node_atom.properties_.empty() || !node_atom.labels_.empty();
|
||||
bool props_or_labels = !node_atom.properties_.empty() || !node_atom.labels_.empty();
|
||||
const auto &node_name = node_atom.identifier_->name_;
|
||||
if ((scope_.in_create || scope_.in_merge) && props_or_labels &&
|
||||
HasSymbol(node_name)) {
|
||||
throw SemanticException(
|
||||
"Cannot create node '" + node_name +
|
||||
if ((scope_.in_create || scope_.in_merge) && props_or_labels && HasSymbol(node_name)) {
|
||||
throw SemanticException("Cannot create node '" + node_name +
|
||||
"' with labels or properties, because it is already declared.");
|
||||
}
|
||||
for (auto kv : node_atom.properties_) {
|
||||
@ -458,22 +433,19 @@ bool SymbolGenerator::PreVisit(EdgeAtom &edge_atom) {
|
||||
scope_.in_pattern = false;
|
||||
if (edge_atom.filter_lambda_.expression) {
|
||||
VisitWithIdentifiers(edge_atom.filter_lambda_.expression,
|
||||
{edge_atom.filter_lambda_.inner_edge,
|
||||
edge_atom.filter_lambda_.inner_node});
|
||||
{edge_atom.filter_lambda_.inner_edge, edge_atom.filter_lambda_.inner_node});
|
||||
} else {
|
||||
// Create inner symbols, but don't bind them in scope, since they are to
|
||||
// be used in the missing filter expression.
|
||||
auto *inner_edge = edge_atom.filter_lambda_.inner_edge;
|
||||
inner_edge->MapTo(symbol_table_.CreateSymbol(
|
||||
inner_edge->name_, inner_edge->user_declared_, Symbol::Type::EDGE));
|
||||
inner_edge->MapTo(symbol_table_.CreateSymbol(inner_edge->name_, inner_edge->user_declared_, Symbol::Type::EDGE));
|
||||
auto *inner_node = edge_atom.filter_lambda_.inner_node;
|
||||
inner_node->MapTo(symbol_table_.CreateSymbol(
|
||||
inner_node->name_, inner_node->user_declared_, Symbol::Type::VERTEX));
|
||||
inner_node->MapTo(
|
||||
symbol_table_.CreateSymbol(inner_node->name_, inner_node->user_declared_, Symbol::Type::VERTEX));
|
||||
}
|
||||
if (edge_atom.weight_lambda_.expression) {
|
||||
VisitWithIdentifiers(edge_atom.weight_lambda_.expression,
|
||||
{edge_atom.weight_lambda_.inner_edge,
|
||||
edge_atom.weight_lambda_.inner_node});
|
||||
{edge_atom.weight_lambda_.inner_edge, edge_atom.weight_lambda_.inner_node});
|
||||
}
|
||||
scope_.in_pattern = true;
|
||||
}
|
||||
@ -484,9 +456,8 @@ bool SymbolGenerator::PreVisit(EdgeAtom &edge_atom) {
|
||||
if (HasSymbol(edge_atom.total_weight_->name_)) {
|
||||
throw RedeclareVariableError(edge_atom.total_weight_->name_);
|
||||
}
|
||||
edge_atom.total_weight_->MapTo(GetOrCreateSymbol(
|
||||
edge_atom.total_weight_->name_, edge_atom.total_weight_->user_declared_,
|
||||
Symbol::Type::NUMBER));
|
||||
edge_atom.total_weight_->MapTo(GetOrCreateSymbol(edge_atom.total_weight_->name_,
|
||||
edge_atom.total_weight_->user_declared_, Symbol::Type::NUMBER));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@ -497,8 +468,7 @@ bool SymbolGenerator::PostVisit(EdgeAtom &) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void SymbolGenerator::VisitWithIdentifiers(
|
||||
Expression *expr, const std::vector<Identifier *> &identifiers) {
|
||||
void SymbolGenerator::VisitWithIdentifiers(Expression *expr, const std::vector<Identifier *> &identifiers) {
|
||||
std::vector<std::pair<std::optional<Symbol>, Identifier *>> prev_symbols;
|
||||
// Collect previous symbols if they exist.
|
||||
for (const auto &identifier : identifiers) {
|
||||
@ -507,8 +477,7 @@ void SymbolGenerator::VisitWithIdentifiers(
|
||||
if (prev_symbol_it != scope_.symbols.end()) {
|
||||
prev_symbol = prev_symbol_it->second;
|
||||
}
|
||||
identifier->MapTo(
|
||||
CreateSymbol(identifier->name_, identifier->user_declared_));
|
||||
identifier->MapTo(CreateSymbol(identifier->name_, identifier->user_declared_));
|
||||
prev_symbols.emplace_back(prev_symbol, identifier);
|
||||
}
|
||||
// Visit the expression with the new symbols bound.
|
||||
@ -525,8 +494,6 @@ void SymbolGenerator::VisitWithIdentifiers(
|
||||
}
|
||||
}
|
||||
|
||||
bool SymbolGenerator::HasSymbol(const std::string &name) {
|
||||
return scope_.symbols.find(name) != scope_.symbols.end();
|
||||
}
|
||||
bool SymbolGenerator::HasSymbol(const std::string &name) { return scope_.symbols.find(name) != scope_.symbols.end(); }
|
||||
|
||||
} // namespace query
|
||||
|
@ -17,8 +17,7 @@ namespace query {
|
||||
/// variable types.
|
||||
class SymbolGenerator : public HierarchicalTreeVisitor {
|
||||
public:
|
||||
explicit SymbolGenerator(SymbolTable &symbol_table)
|
||||
: symbol_table_(symbol_table) {}
|
||||
explicit SymbolGenerator(SymbolTable &symbol_table) : symbol_table_(symbol_table) {}
|
||||
|
||||
using HierarchicalTreeVisitor::PostVisit;
|
||||
using HierarchicalTreeVisitor::PreVisit;
|
||||
@ -117,14 +116,12 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
|
||||
|
||||
// Returns a freshly generated symbol. Previous mapping of the same name to a
|
||||
// different symbol is replaced with the new one.
|
||||
auto CreateSymbol(const std::string &name, bool user_declared,
|
||||
Symbol::Type type = Symbol::Type::ANY,
|
||||
auto CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY,
|
||||
int token_position = -1);
|
||||
|
||||
// Returns the symbol by name. If the mapping already exists, checks if the
|
||||
// types match. Otherwise, returns a new symbol.
|
||||
auto GetOrCreateSymbol(const std::string &name, bool user_declared,
|
||||
Symbol::Type type = Symbol::Type::ANY);
|
||||
auto GetOrCreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY);
|
||||
|
||||
void VisitReturnBody(ReturnBody &body, Where *where = nullptr);
|
||||
|
||||
|
@ -12,13 +12,11 @@ namespace query {
|
||||
class SymbolTable final {
|
||||
public:
|
||||
SymbolTable() {}
|
||||
const Symbol &CreateSymbol(const std::string &name, bool user_declared,
|
||||
Symbol::Type type = Symbol::Type::ANY,
|
||||
const Symbol &CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY,
|
||||
int32_t token_position = -1) {
|
||||
MG_ASSERT(table_.size() <= std::numeric_limits<int32_t>::max(),
|
||||
"SymbolTable size doesn't fit into 32-bit integer!");
|
||||
auto got = table_.emplace(position_, Symbol(name, position_, user_declared,
|
||||
type, token_position));
|
||||
auto got = table_.emplace(position_, Symbol(name, position_, user_declared, type, token_position));
|
||||
MG_ASSERT(got.second, "Duplicate symbol ID!");
|
||||
position_++;
|
||||
return got.first->second;
|
||||
@ -31,8 +29,7 @@ class SymbolTable final {
|
||||
while (true) {
|
||||
static const std::string &kAnonPrefix = "anon";
|
||||
std::string name_candidate = kAnonPrefix + std::to_string(id++);
|
||||
if (std::find_if(std::begin(table_), std::end(table_),
|
||||
[&name_candidate](const auto &item) -> bool {
|
||||
if (std::find_if(std::begin(table_), std::end(table_), [&name_candidate](const auto &item) -> bool {
|
||||
return item.second.name_ == name_candidate;
|
||||
}) == std::end(table_)) {
|
||||
return CreateSymbol(name_candidate, false, type);
|
||||
@ -40,15 +37,9 @@ class SymbolTable final {
|
||||
}
|
||||
}
|
||||
|
||||
const Symbol &at(const Identifier &ident) const {
|
||||
return table_.at(ident.symbol_pos_);
|
||||
}
|
||||
const Symbol &at(const NamedExpression &nexpr) const {
|
||||
return table_.at(nexpr.symbol_pos_);
|
||||
}
|
||||
const Symbol &at(const Aggregation &aggr) const {
|
||||
return table_.at(aggr.symbol_pos_);
|
||||
}
|
||||
const Symbol &at(const Identifier &ident) const { return table_.at(ident.symbol_pos_); }
|
||||
const Symbol &at(const NamedExpression &nexpr) const { return table_.at(nexpr.symbol_pos_); }
|
||||
const Symbol &at(const Aggregation &aggr) const { return table_.at(aggr.symbol_pos_); }
|
||||
|
||||
// TODO: Remove these since members are public
|
||||
int32_t max_position() const { return static_cast<int32_t>(table_.size()); }
|
||||
|
@ -64,9 +64,7 @@ StrippedQuery::StrippedQuery(const std::string &query) : original_(query) {
|
||||
// A helper function that stores literal and its token position in a
|
||||
// literals_. In stripped query text literal is replaced with a new_value.
|
||||
// new_value can be any value that is lexed as a literal.
|
||||
auto replace_stripped = [this, &token_strings](int position,
|
||||
const auto &value,
|
||||
const std::string &new_value) {
|
||||
auto replace_stripped = [this, &token_strings](int position, const auto &value, const std::string &new_value) {
|
||||
literals_.Add(position, storage::PropertyValue(value));
|
||||
token_strings.push_back(new_value);
|
||||
};
|
||||
@ -101,16 +99,13 @@ StrippedQuery::StrippedQuery(const std::string &query) : original_(query) {
|
||||
case Token::SPACE:
|
||||
break;
|
||||
case Token::STRING:
|
||||
replace_stripped(token_index, ParseStringLiteral(token.second),
|
||||
kStrippedStringToken);
|
||||
replace_stripped(token_index, ParseStringLiteral(token.second), kStrippedStringToken);
|
||||
break;
|
||||
case Token::INT:
|
||||
replace_stripped(token_index, ParseIntegerLiteral(token.second),
|
||||
kStrippedIntToken);
|
||||
replace_stripped(token_index, ParseIntegerLiteral(token.second), kStrippedIntToken);
|
||||
break;
|
||||
case Token::REAL:
|
||||
replace_stripped(token_index, ParseDoubleLiteral(token.second),
|
||||
kStrippedDoubleToken);
|
||||
replace_stripped(token_index, ParseDoubleLiteral(token.second), kStrippedDoubleToken);
|
||||
break;
|
||||
case Token::SPECIAL:
|
||||
case Token::ESCAPED_NAME:
|
||||
@ -135,9 +130,7 @@ StrippedQuery::StrippedQuery(const std::string &query) : original_(query) {
|
||||
while (it != tokens.end()) {
|
||||
// Store nonaliased named expressions in returns in named_exprs_.
|
||||
it = std::find_if(it, tokens.end(),
|
||||
[](const std::pair<Token, std::string> &a) {
|
||||
return utils::IEquals(a.second, "return");
|
||||
});
|
||||
[](const std::pair<Token, std::string> &a) { return utils::IEquals(a.second, "return"); });
|
||||
// There is no RETURN so there is nothing to do here.
|
||||
if (it == tokens.end()) return;
|
||||
// Skip RETURN;
|
||||
@ -172,13 +165,10 @@ StrippedQuery::StrippedQuery(const std::string &query) : original_(query) {
|
||||
int num_open_braces = 0;
|
||||
int num_open_parantheses = 0;
|
||||
int num_open_brackets = 0;
|
||||
for (; jt != tokens.end() &&
|
||||
(jt->second != "," || num_open_braces || num_open_parantheses ||
|
||||
num_open_brackets) &&
|
||||
!utils::IEquals(jt->second, "order") &&
|
||||
!utils::IEquals(jt->second, "skip") &&
|
||||
!utils::IEquals(jt->second, "limit") &&
|
||||
!utils::IEquals(jt->second, "union") && jt->second != ";";
|
||||
for (;
|
||||
jt != tokens.end() && (jt->second != "," || num_open_braces || num_open_parantheses || num_open_brackets) &&
|
||||
!utils::IEquals(jt->second, "order") && !utils::IEquals(jt->second, "skip") &&
|
||||
!utils::IEquals(jt->second, "limit") && !utils::IEquals(jt->second, "union") && jt->second != ";";
|
||||
++jt) {
|
||||
if (jt->second == "(") {
|
||||
++num_open_parantheses;
|
||||
@ -203,8 +193,7 @@ StrippedQuery::StrippedQuery(const std::string &query) : original_(query) {
|
||||
// trailing whitespaces.
|
||||
std::string s;
|
||||
auto begin_token = it - tokens.begin() + original_tokens.begin();
|
||||
auto end_token =
|
||||
last_non_space - tokens.begin() + original_tokens.begin() + 1;
|
||||
auto end_token = last_non_space - tokens.begin() + original_tokens.begin() + 1;
|
||||
for (auto kt = begin_token; kt != end_token; ++kt) {
|
||||
s += kt->second;
|
||||
}
|
||||
@ -280,9 +269,7 @@ std::pair<int, int> GetFirstUtf8SymbolCodepoint(const char *_s) {
|
||||
if ((*s2 >> 6) != 0x02) throw LexingException("Invalid character.");
|
||||
auto *s3 = s + 3;
|
||||
if ((*s3 >> 6) != 0x02) throw LexingException("Invalid character.");
|
||||
return {((*s & 0x07) << 18) | ((*s1 & 0x3f) << 12) | ((*s2 & 0x3f) << 6) |
|
||||
(*s3 & 0x3f),
|
||||
4};
|
||||
return {((*s & 0x07) << 18) | ((*s1 & 0x3f) << 12) | ((*s2 & 0x3f) << 6) | (*s3 & 0x3f), 4};
|
||||
}
|
||||
throw LexingException("Invalid character.");
|
||||
}
|
||||
@ -301,13 +288,9 @@ std::pair<int, int> GetFirstUtf8SymbolCodepoint(const char *_s) {
|
||||
// //\ ) , / '%%%%(%%'
|
||||
// , _.'/ `\<-- \<
|
||||
// `^^^` ^^ ^^
|
||||
int StrippedQuery::MatchKeyword(int start) const {
|
||||
return kKeywords.Match<tolower>(original_.c_str() + start);
|
||||
}
|
||||
int StrippedQuery::MatchKeyword(int start) const { return kKeywords.Match<tolower>(original_.c_str() + start); }
|
||||
|
||||
int StrippedQuery::MatchSpecial(int start) const {
|
||||
return kSpecialTokens.Match(original_.c_str() + start);
|
||||
}
|
||||
int StrippedQuery::MatchSpecial(int start) const { return kSpecialTokens.Match(original_.c_str() + start); }
|
||||
|
||||
int StrippedQuery::MatchString(int start) const {
|
||||
if (original_[start] != '"' && original_[start] != '\'') return 0;
|
||||
@ -316,9 +299,8 @@ int StrippedQuery::MatchString(int start) const {
|
||||
if (*p == start_char) return p - (original_.data() + start) + 1;
|
||||
if (*p == '\\') {
|
||||
++p;
|
||||
if (*p == '\\' || *p == '\'' || *p == '"' || *p == 'B' || *p == 'b' ||
|
||||
*p == 'F' || *p == 'f' || *p == 'N' || *p == 'n' || *p == 'R' ||
|
||||
*p == 'r' || *p == 'T' || *p == 't') {
|
||||
if (*p == '\\' || *p == '\'' || *p == '"' || *p == 'B' || *p == 'b' || *p == 'F' || *p == 'f' || *p == 'N' ||
|
||||
*p == 'n' || *p == 'R' || *p == 'r' || *p == 'T' || *p == 't') {
|
||||
// Allowed escaped characters.
|
||||
continue;
|
||||
} else if (*p == 'U' || *p == 'u') {
|
||||
@ -356,8 +338,7 @@ int StrippedQuery::MatchDecimalInt(int start) const {
|
||||
int StrippedQuery::MatchOctalInt(int start) const {
|
||||
if (original_[start] != '0') return 0;
|
||||
int i = start + 1;
|
||||
while (i < static_cast<int>(original_.size()) && '0' <= original_[i] &&
|
||||
original_[i] <= '7') {
|
||||
while (i < static_cast<int>(original_.size()) && '0' <= original_[i] && original_[i] <= '7') {
|
||||
++i;
|
||||
}
|
||||
if (i == start + 1) return 0;
|
||||
@ -440,15 +421,13 @@ int StrippedQuery::MatchEscapedName(int start) const {
|
||||
int StrippedQuery::MatchUnescapedName(int start) const {
|
||||
auto i = start;
|
||||
auto got = GetFirstUtf8SymbolCodepoint(original_.data() + i);
|
||||
if (got.first >= lexer_constants::kBitsetSize ||
|
||||
!kUnescapedNameAllowedStarts[got.first]) {
|
||||
if (got.first >= lexer_constants::kBitsetSize || !kUnescapedNameAllowedStarts[got.first]) {
|
||||
return 0;
|
||||
}
|
||||
i += got.second;
|
||||
while (i < static_cast<int>(original_.size())) {
|
||||
got = GetFirstUtf8SymbolCodepoint(original_.data() + i);
|
||||
if (got.first >= lexer_constants::kBitsetSize ||
|
||||
!kUnescapedNameAllowedParts[got.first]) {
|
||||
if (got.first >= lexer_constants::kBitsetSize || !kUnescapedNameAllowedParts[got.first]) {
|
||||
break;
|
||||
}
|
||||
i += got.second;
|
||||
@ -469,13 +448,11 @@ int StrippedQuery::MatchWhitespaceAndComments(int start) const {
|
||||
auto got = GetFirstUtf8SymbolCodepoint(original_.data() + i);
|
||||
if (got.first < lexer_constants::kBitsetSize && kSpaceParts[got.first]) {
|
||||
i += got.second;
|
||||
} else if (i + 1 < len && original_[i] == '/' &&
|
||||
original_[i + 1] == '*') {
|
||||
} else if (i + 1 < len && original_[i] == '/' && original_[i + 1] == '*') {
|
||||
comment_position = i;
|
||||
state = State::IN_BLOCK_COMMENT;
|
||||
i += 2;
|
||||
} else if (i + 1 < len && original_[i] == '/' &&
|
||||
original_[i + 1] == '/') {
|
||||
} else if (i + 1 < len && original_[i] == '/' && original_[i + 1] == '/') {
|
||||
comment_position = i;
|
||||
if (i + 2 < len) {
|
||||
// Special case for an empty line comment starting right at the end of
|
||||
@ -490,8 +467,7 @@ int StrippedQuery::MatchWhitespaceAndComments(int start) const {
|
||||
if (original_[i] == '\n') {
|
||||
state = State::OUT;
|
||||
++i;
|
||||
} else if (i + 1 < len && original_[i] == '\r' &&
|
||||
original_[i + 1] == '\n') {
|
||||
} else if (i + 1 < len && original_[i] == '\r' && original_[i + 1] == '\n') {
|
||||
state = State::OUT;
|
||||
i += 2;
|
||||
} else if (original_[i] == '\r') {
|
||||
|
@ -79,23 +79,19 @@ class Trie {
|
||||
const int kBitsetSize = 65536;
|
||||
|
||||
const trie::Trie kKeywords = {
|
||||
"union", "all", "optional", "match", "unwind", "as",
|
||||
"merge", "on", "create", "set", "detach", "delete",
|
||||
"remove", "with", "distinct", "return", "order", "by",
|
||||
"skip", "limit", "ascending", "asc", "descending", "desc",
|
||||
"where", "or", "xor", "and", "not", "in",
|
||||
"starts", "ends", "contains", "is", "null", "case",
|
||||
"when", "then", "else", "end", "count", "filter",
|
||||
"extract", "any", "none", "single", "true", "false",
|
||||
"reduce", "coalesce", "user", "password", "alter", "drop",
|
||||
"show", "stats", "unique", "explain", "profile", "storage",
|
||||
"index", "info", "exists", "assert", "constraint", "node",
|
||||
"key", "dump", "database", "call", "yield", "memory",
|
||||
"union", "all", "optional", "match", "unwind", "as", "merge", "on", "create",
|
||||
"set", "detach", "delete", "remove", "with", "distinct", "return", "order", "by",
|
||||
"skip", "limit", "ascending", "asc", "descending", "desc", "where", "or", "xor",
|
||||
"and", "not", "in", "starts", "ends", "contains", "is", "null", "case",
|
||||
"when", "then", "else", "end", "count", "filter", "extract", "any", "none",
|
||||
"single", "true", "false", "reduce", "coalesce", "user", "password", "alter", "drop",
|
||||
"show", "stats", "unique", "explain", "profile", "storage", "index", "info", "exists",
|
||||
"assert", "constraint", "node", "key", "dump", "database", "call", "yield", "memory",
|
||||
"mb", "kb", "unlimited"};
|
||||
|
||||
// Unicode codepoints that are allowed at the start of the unescaped name.
|
||||
const std::bitset<kBitsetSize> kUnescapedNameAllowedStarts(std::string(
|
||||
"00000000000000000000000000000000000111001111110011111100111111000111111111"
|
||||
const std::bitset<kBitsetSize> kUnescapedNameAllowedStarts(
|
||||
std::string("00000000000000000000000000000000000111001111110011111100111111000111111111"
|
||||
"11111111111111111111111111111111111111111111111111111111111111111111111111"
|
||||
"11111100000000000111111111111111111111111110100001111111111111111111111111"
|
||||
"10000000000000000000000000000000000001111111111111111111111111111111111111"
|
||||
@ -983,8 +979,8 @@ const std::bitset<kBitsetSize> kUnescapedNameAllowedStarts(std::string(
|
||||
"0000000000000000000000000000000000000000000000"));
|
||||
|
||||
// Unicode codepoints that are allowed at the middle of the unescaped name.
|
||||
const std::bitset<kBitsetSize> kUnescapedNameAllowedParts(std::string(
|
||||
"00000000000000000000000001100011000111001111110011111100111111000111111111"
|
||||
const std::bitset<kBitsetSize> kUnescapedNameAllowedParts(
|
||||
std::string("00000000000000000000000001100011000111001111110011111100111111000111111111"
|
||||
"11111111111111111111111111111111111111111111111111111111111111111111111111"
|
||||
"11111100000000000111111111111111111111111110100001111111111111111111111111"
|
||||
"10000000111111111100000000000100000001111111111111111111111111111111111111"
|
||||
@ -1871,8 +1867,8 @@ const std::bitset<kBitsetSize> kUnescapedNameAllowedParts(std::string(
|
||||
"11111111111111111111111010000111111111111111111111111110000000111111111100"
|
||||
"0000000000000000000000000000000000000000000000"));
|
||||
|
||||
const std::bitset<kBitsetSize> kSpaceParts(std::string(
|
||||
"00000000000000000000000000000000000000000000000000000000000000000000000000"
|
||||
const std::bitset<kBitsetSize> kSpaceParts(
|
||||
std::string("00000000000000000000000000000000000000000000000000000000000000000000000000"
|
||||
"00000000000000000000000000000000000000000000000000000000000000000000000000"
|
||||
"00000000000000000000000000000000000000000000000000000000000000000000000000"
|
||||
"00000000000000000000000000000000000000000000000000000000000000000000000000"
|
||||
|
@ -180,11 +180,9 @@ struct Or<ArgType, ArgTypes...> {
|
||||
|
||||
static std::string TypeNames() {
|
||||
if constexpr (sizeof...(ArgTypes) > 1) {
|
||||
return fmt::format("'{}', {}", ArgTypeName<ArgType>(),
|
||||
Or<ArgTypes...>::TypeNames());
|
||||
return fmt::format("'{}', {}", ArgTypeName<ArgType>(), Or<ArgTypes...>::TypeNames());
|
||||
} else {
|
||||
return fmt::format("'{}' or '{}'", ArgTypeName<ArgType>(),
|
||||
Or<ArgTypes...>::TypeNames());
|
||||
return fmt::format("'{}' or '{}'", ArgTypeName<ArgType>(), Or<ArgTypes...>::TypeNames());
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -206,20 +204,17 @@ template <class ArgType>
|
||||
struct Optional<ArgType> {
|
||||
static constexpr size_t size = 1;
|
||||
|
||||
static void Check(const char *name, const TypedValue *args, int64_t nargs,
|
||||
int64_t pos) {
|
||||
static void Check(const char *name, const TypedValue *args, int64_t nargs, int64_t pos) {
|
||||
if (nargs == 0) return;
|
||||
const TypedValue &arg = args[0];
|
||||
if constexpr (IsOrType<ArgType>::value) {
|
||||
if (!ArgType::Check(arg)) {
|
||||
throw QueryRuntimeException(
|
||||
"Optional '{}' argument at position {} must be either {}.", name,
|
||||
pos, ArgType::TypeNames());
|
||||
throw QueryRuntimeException("Optional '{}' argument at position {} must be either {}.", name, pos,
|
||||
ArgType::TypeNames());
|
||||
}
|
||||
} else {
|
||||
if (!ArgIsType<ArgType>(arg))
|
||||
throw QueryRuntimeException(
|
||||
"Optional '{}' argument at position {} must be '{}'.", name, pos,
|
||||
throw QueryRuntimeException("Optional '{}' argument at position {} must be '{}'.", name, pos,
|
||||
ArgTypeName<ArgType>());
|
||||
}
|
||||
}
|
||||
@ -229,8 +224,7 @@ template <class ArgType, class... ArgTypes>
|
||||
struct Optional<ArgType, ArgTypes...> {
|
||||
static constexpr size_t size = 1 + sizeof...(ArgTypes);
|
||||
|
||||
static void Check(const char *name, const TypedValue *args, int64_t nargs,
|
||||
int64_t pos) {
|
||||
static void Check(const char *name, const TypedValue *args, int64_t nargs, int64_t pos) {
|
||||
if (nargs == 0) return;
|
||||
Optional<ArgType>::Check(name, args, nargs, pos);
|
||||
Optional<ArgTypes...>::Check(name, args + 1, nargs - 1, pos + 1);
|
||||
@ -272,8 +266,7 @@ constexpr size_t FTypeOptionalArgs() {
|
||||
}
|
||||
|
||||
template <class ArgType, class... ArgTypes>
|
||||
void FType(const char *name, const TypedValue *args, int64_t nargs,
|
||||
int64_t pos = 1) {
|
||||
void FType(const char *name, const TypedValue *args, int64_t nargs, int64_t pos = 1) {
|
||||
if constexpr (std::is_same_v<ArgType, void>) {
|
||||
if (nargs != 0) {
|
||||
throw QueryRuntimeException("'{}' requires no arguments.", name);
|
||||
@ -285,30 +278,25 @@ void FType(const char *name, const TypedValue *args, int64_t nargs,
|
||||
constexpr int64_t total_args = required_args + optional_args;
|
||||
if constexpr (optional_args > 0) {
|
||||
if (nargs < required_args || nargs > total_args) {
|
||||
throw QueryRuntimeException("'{}' requires between {} and {} arguments.",
|
||||
name, required_args, total_args);
|
||||
throw QueryRuntimeException("'{}' requires between {} and {} arguments.", name, required_args, total_args);
|
||||
}
|
||||
} else {
|
||||
if (nargs != required_args) {
|
||||
throw QueryRuntimeException(
|
||||
"'{}' requires exactly {} {}.", name, required_args,
|
||||
throw QueryRuntimeException("'{}' requires exactly {} {}.", name, required_args,
|
||||
required_args == 1 ? "argument" : "arguments");
|
||||
}
|
||||
}
|
||||
const TypedValue &arg = args[0];
|
||||
if constexpr (IsOrType<ArgType>::value) {
|
||||
if (!ArgType::Check(arg)) {
|
||||
throw QueryRuntimeException(
|
||||
"'{}' argument at position {} must be either {}.", name, pos,
|
||||
ArgType::TypeNames());
|
||||
throw QueryRuntimeException("'{}' argument at position {} must be either {}.", name, pos, ArgType::TypeNames());
|
||||
}
|
||||
} else if constexpr (IsOptional<ArgType>::value) {
|
||||
static_assert(sizeof...(ArgTypes) == 0, "Optional arguments must be last!");
|
||||
ArgType::Check(name, args, nargs, pos);
|
||||
} else {
|
||||
if (!ArgIsType<ArgType>(arg)) {
|
||||
throw QueryRuntimeException("'{}' argument at position {} must be '{}'",
|
||||
name, pos, ArgTypeName<ArgType>());
|
||||
throw QueryRuntimeException("'{}' argument at position {} must be '{}'", name, pos, ArgTypeName<ArgType>());
|
||||
}
|
||||
}
|
||||
if constexpr (sizeof...(ArgTypes) > 0) {
|
||||
@ -339,15 +327,13 @@ void FType(const char *name, const TypedValue *args, int64_t nargs,
|
||||
// TODO: Implement degrees, haversin, radians
|
||||
// TODO: Implement spatial functions
|
||||
|
||||
TypedValue EndNode(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue EndNode(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, Edge>>("endNode", args, nargs);
|
||||
if (args[0].IsNull()) return TypedValue(ctx.memory);
|
||||
return TypedValue(args[0].ValueEdge().To(), ctx.memory);
|
||||
}
|
||||
|
||||
TypedValue Head(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Head(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, List>>("head", args, nargs);
|
||||
if (args[0].IsNull()) return TypedValue(ctx.memory);
|
||||
const auto &list = args[0].ValueList();
|
||||
@ -355,8 +341,7 @@ TypedValue Head(const TypedValue *args, int64_t nargs,
|
||||
return TypedValue(list[0], ctx.memory);
|
||||
}
|
||||
|
||||
TypedValue Last(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Last(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, List>>("last", args, nargs);
|
||||
if (args[0].IsNull()) return TypedValue(ctx.memory);
|
||||
const auto &list = args[0].ValueList();
|
||||
@ -364,8 +349,7 @@ TypedValue Last(const TypedValue *args, int64_t nargs,
|
||||
return TypedValue(list.back(), ctx.memory);
|
||||
}
|
||||
|
||||
TypedValue Properties(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Properties(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, Vertex, Edge>>("properties", args, nargs);
|
||||
auto *dba = ctx.db_accessor;
|
||||
auto get_properties = [&](const auto &record_accessor) {
|
||||
@ -374,16 +358,13 @@ TypedValue Properties(const TypedValue *args, int64_t nargs,
|
||||
if (maybe_props.HasError()) {
|
||||
switch (maybe_props.GetError()) {
|
||||
case storage::Error::DELETED_OBJECT:
|
||||
throw QueryRuntimeException(
|
||||
"Trying to get properties from a deleted object.");
|
||||
throw QueryRuntimeException("Trying to get properties from a deleted object.");
|
||||
case storage::Error::NONEXISTENT_OBJECT:
|
||||
throw query::QueryRuntimeException(
|
||||
"Trying to get properties from an object that doesn't exist.");
|
||||
throw query::QueryRuntimeException("Trying to get properties from an object that doesn't exist.");
|
||||
case storage::Error::SERIALIZATION_ERROR:
|
||||
case storage::Error::VERTEX_HAS_EDGES:
|
||||
case storage::Error::PROPERTIES_DISABLED:
|
||||
throw QueryRuntimeException(
|
||||
"Unexpected error when getting properties.");
|
||||
throw QueryRuntimeException("Unexpected error when getting properties.");
|
||||
}
|
||||
}
|
||||
for (const auto &property : *maybe_props) {
|
||||
@ -401,31 +382,25 @@ TypedValue Properties(const TypedValue *args, int64_t nargs,
|
||||
}
|
||||
}
|
||||
|
||||
TypedValue Size(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Size(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, List, String, Map, Path>>("size", args, nargs);
|
||||
const auto &value = args[0];
|
||||
if (value.IsNull()) {
|
||||
return TypedValue(ctx.memory);
|
||||
} else if (value.IsList()) {
|
||||
return TypedValue(static_cast<int64_t>(value.ValueList().size()),
|
||||
ctx.memory);
|
||||
return TypedValue(static_cast<int64_t>(value.ValueList().size()), ctx.memory);
|
||||
} else if (value.IsString()) {
|
||||
return TypedValue(static_cast<int64_t>(value.ValueString().size()),
|
||||
ctx.memory);
|
||||
return TypedValue(static_cast<int64_t>(value.ValueString().size()), ctx.memory);
|
||||
} else if (value.IsMap()) {
|
||||
// neo4j doesn't implement size for map, but I don't see a good reason not
|
||||
// to do it.
|
||||
return TypedValue(static_cast<int64_t>(value.ValueMap().size()),
|
||||
ctx.memory);
|
||||
return TypedValue(static_cast<int64_t>(value.ValueMap().size()), ctx.memory);
|
||||
} else {
|
||||
return TypedValue(static_cast<int64_t>(value.ValuePath().edges().size()),
|
||||
ctx.memory);
|
||||
return TypedValue(static_cast<int64_t>(value.ValuePath().edges().size()), ctx.memory);
|
||||
}
|
||||
}
|
||||
|
||||
TypedValue StartNode(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue StartNode(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, Edge>>("startNode", args, nargs);
|
||||
if (args[0].IsNull()) return TypedValue(ctx.memory);
|
||||
return TypedValue(args[0].ValueEdge().From(), ctx.memory);
|
||||
@ -439,13 +414,11 @@ size_t UnwrapDegreeResult(storage::Result<size_t> maybe_degree) {
|
||||
case storage::Error::DELETED_OBJECT:
|
||||
throw QueryRuntimeException("Trying to get degree of a deleted node.");
|
||||
case storage::Error::NONEXISTENT_OBJECT:
|
||||
throw query::QueryRuntimeException(
|
||||
"Trying to get degree of a node that doesn't exist.");
|
||||
throw query::QueryRuntimeException("Trying to get degree of a node that doesn't exist.");
|
||||
case storage::Error::SERIALIZATION_ERROR:
|
||||
case storage::Error::VERTEX_HAS_EDGES:
|
||||
case storage::Error::PROPERTIES_DISABLED:
|
||||
throw QueryRuntimeException(
|
||||
"Unexpected error when getting node degree.");
|
||||
throw QueryRuntimeException("Unexpected error when getting node degree.");
|
||||
}
|
||||
}
|
||||
return *maybe_degree;
|
||||
@ -453,8 +426,7 @@ size_t UnwrapDegreeResult(storage::Result<size_t> maybe_degree) {
|
||||
|
||||
} // namespace
|
||||
|
||||
TypedValue Degree(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Degree(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, Vertex>>("degree", args, nargs);
|
||||
if (args[0].IsNull()) return TypedValue(ctx.memory);
|
||||
const auto &vertex = args[0].ValueVertex();
|
||||
@ -463,8 +435,7 @@ TypedValue Degree(const TypedValue *args, int64_t nargs,
|
||||
return TypedValue(static_cast<int64_t>(out_degree + in_degree), ctx.memory);
|
||||
}
|
||||
|
||||
TypedValue InDegree(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue InDegree(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, Vertex>>("inDegree", args, nargs);
|
||||
if (args[0].IsNull()) return TypedValue(ctx.memory);
|
||||
const auto &vertex = args[0].ValueVertex();
|
||||
@ -472,8 +443,7 @@ TypedValue InDegree(const TypedValue *args, int64_t nargs,
|
||||
return TypedValue(static_cast<int64_t>(in_degree), ctx.memory);
|
||||
}
|
||||
|
||||
TypedValue OutDegree(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue OutDegree(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, Vertex>>("outDegree", args, nargs);
|
||||
if (args[0].IsNull()) return TypedValue(ctx.memory);
|
||||
const auto &vertex = args[0].ValueVertex();
|
||||
@ -481,8 +451,7 @@ TypedValue OutDegree(const TypedValue *args, int64_t nargs,
|
||||
return TypedValue(static_cast<int64_t>(out_degree), ctx.memory);
|
||||
}
|
||||
|
||||
TypedValue ToBoolean(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue ToBoolean(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, Bool, Integer, String>>("toBoolean", args, nargs);
|
||||
const auto &value = args[0];
|
||||
if (value.IsNull()) {
|
||||
@ -501,8 +470,7 @@ TypedValue ToBoolean(const TypedValue *args, int64_t nargs,
|
||||
}
|
||||
}
|
||||
|
||||
TypedValue ToFloat(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue ToFloat(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, Number, String>>("toFloat", args, nargs);
|
||||
const auto &value = args[0];
|
||||
if (value.IsNull()) {
|
||||
@ -513,16 +481,14 @@ TypedValue ToFloat(const TypedValue *args, int64_t nargs,
|
||||
return TypedValue(value, ctx.memory);
|
||||
} else {
|
||||
try {
|
||||
return TypedValue(utils::ParseDouble(utils::Trim(value.ValueString())),
|
||||
ctx.memory);
|
||||
return TypedValue(utils::ParseDouble(utils::Trim(value.ValueString())), ctx.memory);
|
||||
} catch (const utils::BasicException &) {
|
||||
return TypedValue(ctx.memory);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TypedValue ToInteger(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue ToInteger(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, Bool, Number, String>>("toInteger", args, nargs);
|
||||
const auto &value = args[0];
|
||||
if (value.IsNull()) {
|
||||
@ -537,28 +503,22 @@ TypedValue ToInteger(const TypedValue *args, int64_t nargs,
|
||||
try {
|
||||
// Yup, this is correct. String is valid if it has floating point
|
||||
// number, then it is parsed and converted to int.
|
||||
return TypedValue(static_cast<int64_t>(utils::ParseDouble(
|
||||
utils::Trim(value.ValueString()))),
|
||||
ctx.memory);
|
||||
return TypedValue(static_cast<int64_t>(utils::ParseDouble(utils::Trim(value.ValueString()))), ctx.memory);
|
||||
} catch (const utils::BasicException &) {
|
||||
return TypedValue(ctx.memory);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TypedValue Type(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Type(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, Edge>>("type", args, nargs);
|
||||
auto *dba = ctx.db_accessor;
|
||||
if (args[0].IsNull()) return TypedValue(ctx.memory);
|
||||
return TypedValue(dba->EdgeTypeToName(args[0].ValueEdge().EdgeType()),
|
||||
ctx.memory);
|
||||
return TypedValue(dba->EdgeTypeToName(args[0].ValueEdge().EdgeType()), ctx.memory);
|
||||
}
|
||||
|
||||
TypedValue ValueType(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
FType<Or<Null, Bool, Integer, Double, String, List, Map, Vertex, Edge, Path>>(
|
||||
"type", args, nargs);
|
||||
TypedValue ValueType(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, Bool, Integer, Double, String, List, Map, Vertex, Edge, Path>>("type", args, nargs);
|
||||
// The type names returned should be standardized openCypher type names.
|
||||
// https://github.com/opencypher/openCypher/blob/master/docs/openCypher9.pdf
|
||||
switch (args[0].type()) {
|
||||
@ -586,8 +546,7 @@ TypedValue ValueType(const TypedValue *args, int64_t nargs,
|
||||
}
|
||||
|
||||
// TODO: How is Keys different from Properties function?
|
||||
TypedValue Keys(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Keys(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, Vertex, Edge>>("keys", args, nargs);
|
||||
auto *dba = ctx.db_accessor;
|
||||
auto get_keys = [&](const auto &record_accessor) {
|
||||
@ -596,11 +555,9 @@ TypedValue Keys(const TypedValue *args, int64_t nargs,
|
||||
if (maybe_props.HasError()) {
|
||||
switch (maybe_props.GetError()) {
|
||||
case storage::Error::DELETED_OBJECT:
|
||||
throw QueryRuntimeException(
|
||||
"Trying to get keys from a deleted object.");
|
||||
throw QueryRuntimeException("Trying to get keys from a deleted object.");
|
||||
case storage::Error::NONEXISTENT_OBJECT:
|
||||
throw query::QueryRuntimeException(
|
||||
"Trying to get keys from an object that doesn't exist.");
|
||||
throw query::QueryRuntimeException("Trying to get keys from an object that doesn't exist.");
|
||||
case storage::Error::SERIALIZATION_ERROR:
|
||||
case storage::Error::VERTEX_HAS_EDGES:
|
||||
case storage::Error::PROPERTIES_DISABLED:
|
||||
@ -622,8 +579,7 @@ TypedValue Keys(const TypedValue *args, int64_t nargs,
|
||||
}
|
||||
}
|
||||
|
||||
TypedValue Labels(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Labels(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, Vertex>>("labels", args, nargs);
|
||||
auto *dba = ctx.db_accessor;
|
||||
if (args[0].IsNull()) return TypedValue(ctx.memory);
|
||||
@ -632,11 +588,9 @@ TypedValue Labels(const TypedValue *args, int64_t nargs,
|
||||
if (maybe_labels.HasError()) {
|
||||
switch (maybe_labels.GetError()) {
|
||||
case storage::Error::DELETED_OBJECT:
|
||||
throw QueryRuntimeException(
|
||||
"Trying to get labels from a deleted node.");
|
||||
throw QueryRuntimeException("Trying to get labels from a deleted node.");
|
||||
case storage::Error::NONEXISTENT_OBJECT:
|
||||
throw query::QueryRuntimeException(
|
||||
"Trying to get labels from a node that doesn't exist.");
|
||||
throw query::QueryRuntimeException("Trying to get labels from a node that doesn't exist.");
|
||||
case storage::Error::SERIALIZATION_ERROR:
|
||||
case storage::Error::VERTEX_HAS_EDGES:
|
||||
case storage::Error::PROPERTIES_DISABLED:
|
||||
@ -649,8 +603,7 @@ TypedValue Labels(const TypedValue *args, int64_t nargs,
|
||||
return TypedValue(std::move(labels));
|
||||
}
|
||||
|
||||
TypedValue Nodes(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Nodes(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, Path>>("nodes", args, nargs);
|
||||
if (args[0].IsNull()) return TypedValue(ctx.memory);
|
||||
const auto &vertices = args[0].ValuePath().vertices();
|
||||
@ -660,8 +613,7 @@ TypedValue Nodes(const TypedValue *args, int64_t nargs,
|
||||
return TypedValue(std::move(values));
|
||||
}
|
||||
|
||||
TypedValue Relationships(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Relationships(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, Path>>("relationships", args, nargs);
|
||||
if (args[0].IsNull()) return TypedValue(ctx.memory);
|
||||
const auto &edges = args[0].ValuePath().edges();
|
||||
@ -671,10 +623,8 @@ TypedValue Relationships(const TypedValue *args, int64_t nargs,
|
||||
return TypedValue(std::move(values));
|
||||
}
|
||||
|
||||
TypedValue Range(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
FType<Or<Null, Integer>, Or<Null, Integer>,
|
||||
Optional<Or<Null, NonZeroInteger>>>("range", args, nargs);
|
||||
TypedValue Range(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, Integer>, Or<Null, Integer>, Optional<Or<Null, NonZeroInteger>>>("range", args, nargs);
|
||||
for (int64_t i = 0; i < nargs; ++i)
|
||||
if (args[i].IsNull()) return TypedValue(ctx.memory);
|
||||
auto lbound = args[0].ValueInt();
|
||||
@ -693,8 +643,7 @@ TypedValue Range(const TypedValue *args, int64_t nargs,
|
||||
return TypedValue(std::move(list));
|
||||
}
|
||||
|
||||
TypedValue Tail(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Tail(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, List>>("tail", args, nargs);
|
||||
if (args[0].IsNull()) return TypedValue(ctx.memory);
|
||||
TypedValue::TVector list(args[0].ValueList(), ctx.memory);
|
||||
@ -703,10 +652,8 @@ TypedValue Tail(const TypedValue *args, int64_t nargs,
|
||||
return TypedValue(std::move(list));
|
||||
}
|
||||
|
||||
TypedValue UniformSample(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
FType<Or<Null, List>, Or<Null, NonNegativeInteger>>("uniformSample", args,
|
||||
nargs);
|
||||
TypedValue UniformSample(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, List>, Or<Null, NonNegativeInteger>>("uniformSample", args, nargs);
|
||||
static thread_local std::mt19937 pseudo_rand_gen_{std::random_device{}()};
|
||||
if (args[0].IsNull() || args[1].IsNull()) return TypedValue(ctx.memory);
|
||||
const auto &population = args[0].ValueList();
|
||||
@ -722,8 +669,7 @@ TypedValue UniformSample(const TypedValue *args, int64_t nargs,
|
||||
return TypedValue(std::move(sampled));
|
||||
}
|
||||
|
||||
TypedValue Abs(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Abs(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, Number>>("abs", args, nargs);
|
||||
const auto &value = args[0];
|
||||
if (value.IsNull()) {
|
||||
@ -736,8 +682,7 @@ TypedValue Abs(const TypedValue *args, int64_t nargs,
|
||||
}
|
||||
|
||||
#define WRAP_CMATH_FLOAT_FUNCTION(name, lowercased_name) \
|
||||
TypedValue name(const TypedValue *args, int64_t nargs, \
|
||||
const FunctionContext &ctx) { \
|
||||
TypedValue name(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { \
|
||||
FType<Or<Null, Number>>(#lowercased_name, args, nargs); \
|
||||
const auto &value = args[0]; \
|
||||
if (value.IsNull()) { \
|
||||
@ -767,8 +712,7 @@ WRAP_CMATH_FLOAT_FUNCTION(Tan, tan)
|
||||
|
||||
#undef WRAP_CMATH_FLOAT_FUNCTION
|
||||
|
||||
TypedValue Atan2(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Atan2(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, Number>, Or<Null, Number>>("atan2", args, nargs);
|
||||
if (args[0].IsNull() || args[1].IsNull()) return TypedValue(ctx.memory);
|
||||
auto to_double = [](const TypedValue &t) -> double {
|
||||
@ -783,8 +727,7 @@ TypedValue Atan2(const TypedValue *args, int64_t nargs,
|
||||
return TypedValue(atan2(y, x), ctx.memory);
|
||||
}
|
||||
|
||||
TypedValue Sign(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Sign(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, Number>>("sign", args, nargs);
|
||||
auto sign = [&](auto x) { return TypedValue((0 < x) - (x < 0), ctx.memory); };
|
||||
const auto &value = args[0];
|
||||
@ -797,20 +740,17 @@ TypedValue Sign(const TypedValue *args, int64_t nargs,
|
||||
}
|
||||
}
|
||||
|
||||
TypedValue E(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue E(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<void>("e", args, nargs);
|
||||
return TypedValue(M_E, ctx.memory);
|
||||
}
|
||||
|
||||
TypedValue Pi(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Pi(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<void>("pi", args, nargs);
|
||||
return TypedValue(M_PI, ctx.memory);
|
||||
}
|
||||
|
||||
TypedValue Rand(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Rand(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<void>("rand", args, nargs);
|
||||
static thread_local std::mt19937 pseudo_rand_gen_{std::random_device{}()};
|
||||
static thread_local std::uniform_real_distribution<> rand_dist_{0, 1};
|
||||
@ -818,8 +758,7 @@ TypedValue Rand(const TypedValue *args, int64_t nargs,
|
||||
}
|
||||
|
||||
template <class TPredicate>
|
||||
TypedValue StringMatchOperator(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue StringMatchOperator(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, String>, Or<Null, String>>(TPredicate::name, args, nargs);
|
||||
if (args[0].IsNull() || args[1].IsNull()) return TypedValue(ctx.memory);
|
||||
const auto &s1 = args[0].ValueString();
|
||||
@ -830,8 +769,7 @@ TypedValue StringMatchOperator(const TypedValue *args, int64_t nargs,
|
||||
// Check if s1 starts with s2.
|
||||
struct StartsWithPredicate {
|
||||
constexpr static const char *name = "startsWith";
|
||||
bool operator()(const TypedValue::TString &s1,
|
||||
const TypedValue::TString &s2) const {
|
||||
bool operator()(const TypedValue::TString &s1, const TypedValue::TString &s2) const {
|
||||
if (s1.size() < s2.size()) return false;
|
||||
return std::equal(s2.begin(), s2.end(), s1.begin());
|
||||
}
|
||||
@ -841,8 +779,7 @@ auto StartsWith = StringMatchOperator<StartsWithPredicate>;
|
||||
// Check if s1 ends with s2.
|
||||
struct EndsWithPredicate {
|
||||
constexpr static const char *name = "endsWith";
|
||||
bool operator()(const TypedValue::TString &s1,
|
||||
const TypedValue::TString &s2) const {
|
||||
bool operator()(const TypedValue::TString &s1, const TypedValue::TString &s2) const {
|
||||
if (s1.size() < s2.size()) return false;
|
||||
return std::equal(s2.rbegin(), s2.rend(), s1.rbegin());
|
||||
}
|
||||
@ -852,16 +789,14 @@ auto EndsWith = StringMatchOperator<EndsWithPredicate>;
|
||||
// Check if s1 contains s2.
|
||||
struct ContainsPredicate {
|
||||
constexpr static const char *name = "contains";
|
||||
bool operator()(const TypedValue::TString &s1,
|
||||
const TypedValue::TString &s2) const {
|
||||
bool operator()(const TypedValue::TString &s1, const TypedValue::TString &s2) const {
|
||||
if (s1.size() < s2.size()) return false;
|
||||
return s1.find(s2) != std::string::npos;
|
||||
}
|
||||
};
|
||||
auto Contains = StringMatchOperator<ContainsPredicate>;
|
||||
|
||||
TypedValue Assert(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Assert(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Bool, Optional<String>>("assert", args, nargs);
|
||||
if (!args[0].ValueBool()) {
|
||||
std::string message("Assertion failed");
|
||||
@ -875,24 +810,21 @@ TypedValue Assert(const TypedValue *args, int64_t nargs,
|
||||
return TypedValue(args[0], ctx.memory);
|
||||
}
|
||||
|
||||
TypedValue Counter(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &context) {
|
||||
TypedValue Counter(const TypedValue *args, int64_t nargs, const FunctionContext &context) {
|
||||
FType<String, Integer, Optional<NonZeroInteger>>("counter", args, nargs);
|
||||
int64_t step = 1;
|
||||
if (nargs == 3) {
|
||||
step = args[2].ValueInt();
|
||||
}
|
||||
|
||||
auto [it, inserted] =
|
||||
context.counters->emplace(args[0].ValueString(), args[1].ValueInt());
|
||||
auto [it, inserted] = context.counters->emplace(args[0].ValueString(), args[1].ValueInt());
|
||||
auto value = it->second;
|
||||
it->second += step;
|
||||
|
||||
return TypedValue(value, context.memory);
|
||||
}
|
||||
|
||||
TypedValue Id(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Id(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, Vertex, Edge>>("id", args, nargs);
|
||||
const auto &arg = args[0];
|
||||
if (arg.IsNull()) {
|
||||
@ -904,8 +836,7 @@ TypedValue Id(const TypedValue *args, int64_t nargs,
|
||||
}
|
||||
}
|
||||
|
||||
TypedValue ToString(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue ToString(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, String, Number, Bool>>("toString", args, nargs);
|
||||
const auto &arg = args[0];
|
||||
if (arg.IsNull()) {
|
||||
@ -923,106 +854,80 @@ TypedValue ToString(const TypedValue *args, int64_t nargs,
|
||||
}
|
||||
}
|
||||
|
||||
TypedValue Timestamp(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Timestamp(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<void>("timestamp", args, nargs);
|
||||
return TypedValue(ctx.timestamp, ctx.memory);
|
||||
}
|
||||
|
||||
TypedValue Left(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Left(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, String>, Or<Null, NonNegativeInteger>>("left", args, nargs);
|
||||
if (args[0].IsNull() || args[1].IsNull()) return TypedValue(ctx.memory);
|
||||
return TypedValue(utils::Substr(args[0].ValueString(), 0, args[1].ValueInt()),
|
||||
ctx.memory);
|
||||
return TypedValue(utils::Substr(args[0].ValueString(), 0, args[1].ValueInt()), ctx.memory);
|
||||
}
|
||||
|
||||
TypedValue Right(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Right(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, String>, Or<Null, NonNegativeInteger>>("right", args, nargs);
|
||||
if (args[0].IsNull() || args[1].IsNull()) return TypedValue(ctx.memory);
|
||||
const auto &str = args[0].ValueString();
|
||||
auto len = args[1].ValueInt();
|
||||
return len <= str.size()
|
||||
? TypedValue(utils::Substr(str, str.size() - len, len), ctx.memory)
|
||||
return len <= str.size() ? TypedValue(utils::Substr(str, str.size() - len, len), ctx.memory)
|
||||
: TypedValue(str, ctx.memory);
|
||||
}
|
||||
|
||||
TypedValue CallStringFunction(
|
||||
const TypedValue *args, int64_t nargs, utils::MemoryResource *memory,
|
||||
const char *name,
|
||||
TypedValue CallStringFunction(const TypedValue *args, int64_t nargs, utils::MemoryResource *memory, const char *name,
|
||||
std::function<TypedValue::TString(const TypedValue::TString &)> fun) {
|
||||
FType<Or<Null, String>>(name, args, nargs);
|
||||
if (args[0].IsNull()) return TypedValue(memory);
|
||||
return TypedValue(fun(args[0].ValueString()), memory);
|
||||
}
|
||||
|
||||
TypedValue LTrim(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
return CallStringFunction(
|
||||
args, nargs, ctx.memory, "lTrim", [&](const auto &str) {
|
||||
return TypedValue::TString(utils::LTrim(str), ctx.memory);
|
||||
});
|
||||
TypedValue LTrim(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
return CallStringFunction(args, nargs, ctx.memory, "lTrim",
|
||||
[&](const auto &str) { return TypedValue::TString(utils::LTrim(str), ctx.memory); });
|
||||
}
|
||||
|
||||
TypedValue RTrim(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
return CallStringFunction(
|
||||
args, nargs, ctx.memory, "rTrim", [&](const auto &str) {
|
||||
return TypedValue::TString(utils::RTrim(str), ctx.memory);
|
||||
});
|
||||
TypedValue RTrim(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
return CallStringFunction(args, nargs, ctx.memory, "rTrim",
|
||||
[&](const auto &str) { return TypedValue::TString(utils::RTrim(str), ctx.memory); });
|
||||
}
|
||||
|
||||
TypedValue Trim(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
return CallStringFunction(
|
||||
args, nargs, ctx.memory, "trim", [&](const auto &str) {
|
||||
return TypedValue::TString(utils::Trim(str), ctx.memory);
|
||||
});
|
||||
TypedValue Trim(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
return CallStringFunction(args, nargs, ctx.memory, "trim",
|
||||
[&](const auto &str) { return TypedValue::TString(utils::Trim(str), ctx.memory); });
|
||||
}
|
||||
|
||||
TypedValue Reverse(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
return CallStringFunction(
|
||||
args, nargs, ctx.memory, "reverse",
|
||||
TypedValue Reverse(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
return CallStringFunction(args, nargs, ctx.memory, "reverse",
|
||||
[&](const auto &str) { return utils::Reversed(str, ctx.memory); });
|
||||
}
|
||||
|
||||
TypedValue ToLower(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
return CallStringFunction(args, nargs, ctx.memory, "toLower",
|
||||
[&](const auto &str) {
|
||||
TypedValue ToLower(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
return CallStringFunction(args, nargs, ctx.memory, "toLower", [&](const auto &str) {
|
||||
TypedValue::TString res(ctx.memory);
|
||||
utils::ToLowerCase(&res, str);
|
||||
return res;
|
||||
});
|
||||
}
|
||||
|
||||
TypedValue ToUpper(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
return CallStringFunction(args, nargs, ctx.memory, "toUpper",
|
||||
[&](const auto &str) {
|
||||
TypedValue ToUpper(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
return CallStringFunction(args, nargs, ctx.memory, "toUpper", [&](const auto &str) {
|
||||
TypedValue::TString res(ctx.memory);
|
||||
utils::ToUpperCase(&res, str);
|
||||
return res;
|
||||
});
|
||||
}
|
||||
|
||||
TypedValue Replace(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
FType<Or<Null, String>, Or<Null, String>, Or<Null, String>>("replace", args,
|
||||
nargs);
|
||||
TypedValue Replace(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, String>, Or<Null, String>, Or<Null, String>>("replace", args, nargs);
|
||||
if (args[0].IsNull() || args[1].IsNull() || args[2].IsNull()) {
|
||||
return TypedValue(ctx.memory);
|
||||
}
|
||||
TypedValue::TString replaced(ctx.memory);
|
||||
utils::Replace(&replaced, args[0].ValueString(), args[1].ValueString(),
|
||||
args[2].ValueString());
|
||||
utils::Replace(&replaced, args[0].ValueString(), args[1].ValueString(), args[2].ValueString());
|
||||
return TypedValue(std::move(replaced));
|
||||
}
|
||||
|
||||
TypedValue Split(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue Split(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, String>, Or<Null, String>>("split", args, nargs);
|
||||
if (args[0].IsNull() || args[1].IsNull()) {
|
||||
return TypedValue(ctx.memory);
|
||||
@ -1032,10 +937,8 @@ TypedValue Split(const TypedValue *args, int64_t nargs,
|
||||
return TypedValue(std::move(result));
|
||||
}
|
||||
|
||||
TypedValue Substring(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
FType<Or<Null, String>, NonNegativeInteger, Optional<NonNegativeInteger>>(
|
||||
"substring", args, nargs);
|
||||
TypedValue Substring(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<Or<Null, String>, NonNegativeInteger, Optional<NonNegativeInteger>>("substring", args, nargs);
|
||||
if (args[0].IsNull()) return TypedValue(ctx.memory);
|
||||
const auto &str = args[0].ValueString();
|
||||
auto start = args[1].ValueInt();
|
||||
@ -1044,8 +947,7 @@ TypedValue Substring(const TypedValue *args, int64_t nargs,
|
||||
return TypedValue(utils::Substr(str, start, len), ctx.memory);
|
||||
}
|
||||
|
||||
TypedValue ToByteString(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue ToByteString(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<String>("toByteString", args, nargs);
|
||||
const auto &str = args[0].ValueString();
|
||||
if (str.empty()) return TypedValue("", ctx.memory);
|
||||
@ -1057,8 +959,7 @@ TypedValue ToByteString(const TypedValue *args, int64_t nargs,
|
||||
if (ch >= '0' && ch <= '9') return ch - '0';
|
||||
if (ch >= 'a' && ch <= 'f') return ch - 'a' + 10;
|
||||
if (ch >= 'A' && ch <= 'F') return ch - 'A' + 10;
|
||||
throw QueryRuntimeException(
|
||||
"'toByteString' argument has an invalid character '{}'", ch);
|
||||
throw QueryRuntimeException("'toByteString' argument has an invalid character '{}'", ch);
|
||||
};
|
||||
utils::pmr::string bytes(ctx.memory);
|
||||
bytes.reserve((1 + hex_str.size()) / 2);
|
||||
@ -1074,27 +975,23 @@ TypedValue ToByteString(const TypedValue *args, int64_t nargs,
|
||||
return TypedValue(std::move(bytes));
|
||||
}
|
||||
|
||||
TypedValue FromByteString(const TypedValue *args, int64_t nargs,
|
||||
const FunctionContext &ctx) {
|
||||
TypedValue FromByteString(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
|
||||
FType<String, Optional<PositiveInteger>>("fromByteString", args, nargs);
|
||||
const auto &bytes = args[0].ValueString();
|
||||
if (bytes.empty()) return TypedValue("", ctx.memory);
|
||||
size_t min_length = bytes.size();
|
||||
if (nargs == 2)
|
||||
min_length = std::max(min_length, static_cast<size_t>(args[1].ValueInt()));
|
||||
if (nargs == 2) min_length = std::max(min_length, static_cast<size_t>(args[1].ValueInt()));
|
||||
utils::pmr::string str(ctx.memory);
|
||||
str.reserve(min_length * 2 + 2);
|
||||
str.append("0x");
|
||||
for (size_t pad = 0; pad < min_length - bytes.size(); ++pad)
|
||||
str.append(2, '0');
|
||||
for (size_t pad = 0; pad < min_length - bytes.size(); ++pad) str.append(2, '0');
|
||||
// Convert the bytes to a character string in hex representation.
|
||||
// Unfortunately, we don't know whether the default `char` is signed or
|
||||
// unsigned, so we have to work around any potential undefined behaviour when
|
||||
// conversions between the 2 occur. That's why this function is more
|
||||
// complicated than it should be.
|
||||
auto to_hex = [](const unsigned char val) -> char {
|
||||
unsigned char ch = val < 10U ? static_cast<unsigned char>('0') + val
|
||||
: static_cast<unsigned char>('a') + val - 10U;
|
||||
unsigned char ch = val < 10U ? static_cast<unsigned char>('0') + val : static_cast<unsigned char>('a') + val - 10U;
|
||||
return utils::MemcpyCast<char>(ch);
|
||||
};
|
||||
for (unsigned char byte : bytes) {
|
||||
@ -1106,9 +1003,8 @@ TypedValue FromByteString(const TypedValue *args, int64_t nargs,
|
||||
|
||||
} // namespace
|
||||
|
||||
std::function<TypedValue(const TypedValue *, int64_t,
|
||||
const FunctionContext &ctx)>
|
||||
NameToFunction(const std::string &function_name) {
|
||||
std::function<TypedValue(const TypedValue *, int64_t, const FunctionContext &ctx)> NameToFunction(
|
||||
const std::string &function_name) {
|
||||
// Scalar functions
|
||||
if (function_name == "DEGREE") return Degree;
|
||||
if (function_name == "INDEGREE") return InDegree;
|
||||
|
@ -34,8 +34,7 @@ struct FunctionContext {
|
||||
/// having an array stored anywhere the caller likes, as long as it is
|
||||
/// contiguous in memory. Since most functions don't take many arguments, it's
|
||||
/// convenient to have them stored in the calling stack frame.
|
||||
std::function<TypedValue(const TypedValue *arguments, int64_t num_arguments,
|
||||
const FunctionContext &context)>
|
||||
std::function<TypedValue(const TypedValue *arguments, int64_t num_arguments, const FunctionContext &context)>
|
||||
NameToFunction(const std::string &function_name);
|
||||
|
||||
} // namespace query
|
||||
|
@ -21,14 +21,9 @@ namespace query {
|
||||
|
||||
class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
public:
|
||||
ExpressionEvaluator(Frame *frame, const SymbolTable &symbol_table,
|
||||
const EvaluationContext &ctx, DbAccessor *dba,
|
||||
ExpressionEvaluator(Frame *frame, const SymbolTable &symbol_table, const EvaluationContext &ctx, DbAccessor *dba,
|
||||
storage::View view)
|
||||
: frame_(frame),
|
||||
symbol_table_(&symbol_table),
|
||||
ctx_(&ctx),
|
||||
dba_(dba),
|
||||
view_(view) {}
|
||||
: frame_(frame), symbol_table_(&symbol_table), ctx_(&ctx), dba_(dba), view_(view) {}
|
||||
|
||||
using ExpressionVisitor<TypedValue>::Visit;
|
||||
|
||||
@ -52,8 +47,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
try { \
|
||||
return val1 CPP_OP val2; \
|
||||
} catch (const TypedValueException &) { \
|
||||
throw QueryRuntimeException("Invalid types: {} and {} for '{}'.", \
|
||||
val1.type(), val2.type(), #CYPHER_OP); \
|
||||
throw QueryRuntimeException("Invalid types: {} and {} for '{}'.", val1.type(), val2.type(), #CYPHER_OP); \
|
||||
} \
|
||||
}
|
||||
|
||||
@ -63,8 +57,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
try { \
|
||||
return CPP_OP val; \
|
||||
} catch (const TypedValueException &) { \
|
||||
throw QueryRuntimeException("Invalid type {} for '{}'.", val.type(), \
|
||||
#CYPHER_OP); \
|
||||
throw QueryRuntimeException("Invalid type {} for '{}'.", val.type(), #CYPHER_OP); \
|
||||
} \
|
||||
}
|
||||
|
||||
@ -99,8 +92,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
try {
|
||||
return value1 && value2;
|
||||
} catch (const TypedValueException &) {
|
||||
throw QueryRuntimeException("Invalid types: {} and {} for AND.",
|
||||
value1.type(), value2.type());
|
||||
throw QueryRuntimeException("Invalid types: {} and {} for AND.", value1.type(), value2.type());
|
||||
}
|
||||
}
|
||||
|
||||
@ -111,8 +103,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
}
|
||||
if (condition.type() != TypedValue::Type::Bool) {
|
||||
// At the moment IfOperator is used only in CASE construct.
|
||||
throw QueryRuntimeException("CASE expected boolean expression, got {}.",
|
||||
condition.type());
|
||||
throw QueryRuntimeException("CASE expected boolean expression, got {}.", condition.type());
|
||||
}
|
||||
if (condition.ValueBool()) {
|
||||
return if_operator.then_expression_->Accept(*this);
|
||||
@ -158,17 +149,14 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
TypedValue Visit(SubscriptOperator &list_indexing) override {
|
||||
auto lhs = list_indexing.expression1_->Accept(*this);
|
||||
auto index = list_indexing.expression2_->Accept(*this);
|
||||
if (!lhs.IsList() && !lhs.IsMap() && !lhs.IsVertex() && !lhs.IsEdge() &&
|
||||
!lhs.IsNull())
|
||||
if (!lhs.IsList() && !lhs.IsMap() && !lhs.IsVertex() && !lhs.IsEdge() && !lhs.IsNull())
|
||||
throw QueryRuntimeException(
|
||||
"Expected a list, a map, a node or an edge to index with '[]', got "
|
||||
"{}.",
|
||||
lhs.type());
|
||||
if (lhs.IsNull() || index.IsNull()) return TypedValue(ctx_->memory);
|
||||
if (lhs.IsList()) {
|
||||
if (!index.IsInt())
|
||||
throw QueryRuntimeException(
|
||||
"Expected an integer as a list index, got {}.", index.type());
|
||||
if (!index.IsInt()) throw QueryRuntimeException("Expected an integer as a list index, got {}.", index.type());
|
||||
auto index_int = index.ValueInt();
|
||||
// NOTE: Take non-const reference to list, so that we can move out the
|
||||
// indexed element as the result.
|
||||
@ -176,17 +164,14 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
if (index_int < 0) {
|
||||
index_int += static_cast<int64_t>(list.size());
|
||||
}
|
||||
if (index_int >= static_cast<int64_t>(list.size()) || index_int < 0)
|
||||
return TypedValue(ctx_->memory);
|
||||
if (index_int >= static_cast<int64_t>(list.size()) || index_int < 0) return TypedValue(ctx_->memory);
|
||||
// NOTE: Explicit move is needed, so that we return the move constructed
|
||||
// value and preserve the correct MemoryResource.
|
||||
return std::move(list[index_int]);
|
||||
}
|
||||
|
||||
if (lhs.IsMap()) {
|
||||
if (!index.IsString())
|
||||
throw QueryRuntimeException("Expected a string as a map index, got {}.",
|
||||
index.type());
|
||||
if (!index.IsString()) throw QueryRuntimeException("Expected a string as a map index, got {}.", index.type());
|
||||
// NOTE: Take non-const reference to map, so that we can move out the
|
||||
// looked-up element as the result.
|
||||
auto &map = lhs.ValueMap();
|
||||
@ -198,19 +183,13 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
}
|
||||
|
||||
if (lhs.IsVertex()) {
|
||||
if (!index.IsString())
|
||||
throw QueryRuntimeException(
|
||||
"Expected a string as a property name, got {}.", index.type());
|
||||
return TypedValue(GetProperty(lhs.ValueVertex(), index.ValueString()),
|
||||
ctx_->memory);
|
||||
if (!index.IsString()) throw QueryRuntimeException("Expected a string as a property name, got {}.", index.type());
|
||||
return TypedValue(GetProperty(lhs.ValueVertex(), index.ValueString()), ctx_->memory);
|
||||
}
|
||||
|
||||
if (lhs.IsEdge()) {
|
||||
if (!index.IsString())
|
||||
throw QueryRuntimeException(
|
||||
"Expected a string as a property name, got {}.", index.type());
|
||||
return TypedValue(GetProperty(lhs.ValueEdge(), index.ValueString()),
|
||||
ctx_->memory);
|
||||
if (!index.IsString()) throw QueryRuntimeException("Expected a string as a property name, got {}.", index.type());
|
||||
return TypedValue(GetProperty(lhs.ValueEdge(), index.ValueString()), ctx_->memory);
|
||||
}
|
||||
|
||||
// lhs is Null
|
||||
@ -227,24 +206,20 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
if (bound.type() == TypedValue::Type::Null) {
|
||||
is_null = true;
|
||||
} else if (bound.type() != TypedValue::Type::Int) {
|
||||
throw QueryRuntimeException(
|
||||
"Expected an integer for a bound in list slicing, got {}.",
|
||||
bound.type());
|
||||
throw QueryRuntimeException("Expected an integer for a bound in list slicing, got {}.", bound.type());
|
||||
}
|
||||
return bound;
|
||||
}
|
||||
return TypedValue(default_value, ctx_->memory);
|
||||
};
|
||||
auto _upper_bound =
|
||||
get_bound(op.upper_bound_, std::numeric_limits<int64_t>::max());
|
||||
auto _upper_bound = get_bound(op.upper_bound_, std::numeric_limits<int64_t>::max());
|
||||
auto _lower_bound = get_bound(op.lower_bound_, 0);
|
||||
|
||||
auto _list = op.list_->Accept(*this);
|
||||
if (_list.type() == TypedValue::Type::Null) {
|
||||
is_null = true;
|
||||
} else if (_list.type() != TypedValue::Type::List) {
|
||||
throw QueryRuntimeException("Expected a list to slice, got {}.",
|
||||
_list.type());
|
||||
throw QueryRuntimeException("Expected a list to slice, got {}.", _list.type());
|
||||
}
|
||||
|
||||
if (is_null) {
|
||||
@ -255,16 +230,14 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
if (bound < 0) {
|
||||
bound = static_cast<int64_t>(list.size()) + bound;
|
||||
}
|
||||
return std::max(static_cast<int64_t>(0),
|
||||
std::min(bound, static_cast<int64_t>(list.size())));
|
||||
return std::max(static_cast<int64_t>(0), std::min(bound, static_cast<int64_t>(list.size())));
|
||||
};
|
||||
auto lower_bound = normalise_bound(_lower_bound.ValueInt());
|
||||
auto upper_bound = normalise_bound(_upper_bound.ValueInt());
|
||||
if (upper_bound <= lower_bound) {
|
||||
return TypedValue(TypedValue::TVector(ctx_->memory), ctx_->memory);
|
||||
}
|
||||
return TypedValue(TypedValue::TVector(
|
||||
list.begin() + lower_bound, list.begin() + upper_bound, ctx_->memory));
|
||||
return TypedValue(TypedValue::TVector(list.begin() + lower_bound, list.begin() + upper_bound, ctx_->memory));
|
||||
}
|
||||
|
||||
TypedValue Visit(IsNullOperator &is_null) override {
|
||||
@ -278,13 +251,9 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
case TypedValue::Type::Null:
|
||||
return TypedValue(ctx_->memory);
|
||||
case TypedValue::Type::Vertex:
|
||||
return TypedValue(GetProperty(expression_result.ValueVertex(),
|
||||
property_lookup.property_),
|
||||
ctx_->memory);
|
||||
return TypedValue(GetProperty(expression_result.ValueVertex(), property_lookup.property_), ctx_->memory);
|
||||
case TypedValue::Type::Edge:
|
||||
return TypedValue(GetProperty(expression_result.ValueEdge(),
|
||||
property_lookup.property_),
|
||||
ctx_->memory);
|
||||
return TypedValue(GetProperty(expression_result.ValueEdge(), property_lookup.property_), ctx_->memory);
|
||||
case TypedValue::Type::Map: {
|
||||
// NOTE: Take non-const reference to map, so that we can move out the
|
||||
// looked-up element as the result.
|
||||
@ -296,8 +265,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
return std::move(found->second);
|
||||
}
|
||||
default:
|
||||
throw QueryRuntimeException(
|
||||
"Only nodes, edges and maps have properties to be looked-up.");
|
||||
throw QueryRuntimeException("Only nodes, edges and maps have properties to be looked-up.");
|
||||
}
|
||||
}
|
||||
|
||||
@ -310,8 +278,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
const auto &vertex = expression_result.ValueVertex();
|
||||
for (const auto &label : labels_test.labels_) {
|
||||
auto has_label = vertex.HasLabel(view_, GetLabel(label));
|
||||
if (has_label.HasError() &&
|
||||
has_label.GetError() == storage::Error::NONEXISTENT_OBJECT) {
|
||||
if (has_label.HasError() && has_label.GetError() == storage::Error::NONEXISTENT_OBJECT) {
|
||||
// This is a very nasty and temporary hack in order to make MERGE
|
||||
// work. The old storage had the following logic when returning an
|
||||
// `OLD` view: `return old ? old : new`. That means that if the
|
||||
@ -324,16 +291,13 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
if (has_label.HasError()) {
|
||||
switch (has_label.GetError()) {
|
||||
case storage::Error::DELETED_OBJECT:
|
||||
throw QueryRuntimeException(
|
||||
"Trying to access labels on a deleted node.");
|
||||
throw QueryRuntimeException("Trying to access labels on a deleted node.");
|
||||
case storage::Error::NONEXISTENT_OBJECT:
|
||||
throw query::QueryRuntimeException(
|
||||
"Trying to access labels from a node that doesn't exist.");
|
||||
throw query::QueryRuntimeException("Trying to access labels from a node that doesn't exist.");
|
||||
case storage::Error::SERIALIZATION_ERROR:
|
||||
case storage::Error::VERTEX_HAS_EDGES:
|
||||
case storage::Error::PROPERTIES_DISABLED:
|
||||
throw QueryRuntimeException(
|
||||
"Unexpected error when accessing labels.");
|
||||
throw QueryRuntimeException("Unexpected error when accessing labels.");
|
||||
}
|
||||
}
|
||||
if (!*has_label) {
|
||||
@ -356,15 +320,13 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
TypedValue Visit(ListLiteral &literal) override {
|
||||
TypedValue::TVector result(ctx_->memory);
|
||||
result.reserve(literal.elements_.size());
|
||||
for (const auto &expression : literal.elements_)
|
||||
result.emplace_back(expression->Accept(*this));
|
||||
for (const auto &expression : literal.elements_) result.emplace_back(expression->Accept(*this));
|
||||
return TypedValue(result, ctx_->memory);
|
||||
}
|
||||
|
||||
TypedValue Visit(MapLiteral &literal) override {
|
||||
TypedValue::TMap result(ctx_->memory);
|
||||
for (const auto &pair : literal.elements_)
|
||||
result.emplace(pair.first.name, pair.second->Accept(*this));
|
||||
for (const auto &pair : literal.elements_) result.emplace(pair.first.name, pair.second->Accept(*this));
|
||||
return TypedValue(result, ctx_->memory);
|
||||
}
|
||||
|
||||
@ -390,20 +352,16 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
}
|
||||
|
||||
TypedValue Visit(Function &function) override {
|
||||
FunctionContext function_ctx{dba_, ctx_->memory, ctx_->timestamp,
|
||||
&ctx_->counters, view_};
|
||||
FunctionContext function_ctx{dba_, ctx_->memory, ctx_->timestamp, &ctx_->counters, view_};
|
||||
// Stack allocate evaluated arguments when there's a small number of them.
|
||||
if (function.arguments_.size() <= 8) {
|
||||
TypedValue arguments[8] = {
|
||||
TypedValue(ctx_->memory), TypedValue(ctx_->memory),
|
||||
TypedValue(ctx_->memory), TypedValue(ctx_->memory),
|
||||
TypedValue(ctx_->memory), TypedValue(ctx_->memory),
|
||||
TypedValue arguments[8] = {TypedValue(ctx_->memory), TypedValue(ctx_->memory), TypedValue(ctx_->memory),
|
||||
TypedValue(ctx_->memory), TypedValue(ctx_->memory), TypedValue(ctx_->memory),
|
||||
TypedValue(ctx_->memory), TypedValue(ctx_->memory)};
|
||||
for (size_t i = 0; i < function.arguments_.size(); ++i) {
|
||||
arguments[i] = function.arguments_[i]->Accept(*this);
|
||||
}
|
||||
auto res = function.function_(arguments, function.arguments_.size(),
|
||||
function_ctx);
|
||||
auto res = function.function_(arguments, function.arguments_.size(), function_ctx);
|
||||
MG_ASSERT(res.GetMemoryResource() == ctx_->memory);
|
||||
return res;
|
||||
} else {
|
||||
@ -412,8 +370,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
for (const auto &argument : function.arguments_) {
|
||||
arguments.emplace_back(argument->Accept(*this));
|
||||
}
|
||||
auto res =
|
||||
function.function_(arguments.data(), arguments.size(), function_ctx);
|
||||
auto res = function.function_(arguments.data(), arguments.size(), function_ctx);
|
||||
MG_ASSERT(res.GetMemoryResource() == ctx_->memory);
|
||||
return res;
|
||||
}
|
||||
@ -425,8 +382,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
return TypedValue(ctx_->memory);
|
||||
}
|
||||
if (list_value.type() != TypedValue::Type::List) {
|
||||
throw QueryRuntimeException("REDUCE expected a list, got {}.",
|
||||
list_value.type());
|
||||
throw QueryRuntimeException("REDUCE expected a list, got {}.", list_value.type());
|
||||
}
|
||||
const auto &list = list_value.ValueList();
|
||||
const auto &element_symbol = symbol_table_->at(*reduce.identifier_);
|
||||
@ -446,8 +402,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
return TypedValue(ctx_->memory);
|
||||
}
|
||||
if (list_value.type() != TypedValue::Type::List) {
|
||||
throw QueryRuntimeException("EXTRACT expected a list, got {}.",
|
||||
list_value.type());
|
||||
throw QueryRuntimeException("EXTRACT expected a list, got {}.", list_value.type());
|
||||
}
|
||||
const auto &list = list_value.ValueList();
|
||||
const auto &element_symbol = symbol_table_->at(*extract.identifier_);
|
||||
@ -470,8 +425,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
return TypedValue(ctx_->memory);
|
||||
}
|
||||
if (list_value.type() != TypedValue::Type::List) {
|
||||
throw QueryRuntimeException("ALL expected a list, got {}.",
|
||||
list_value.type());
|
||||
throw QueryRuntimeException("ALL expected a list, got {}.", list_value.type());
|
||||
}
|
||||
const auto &list = list_value.ValueList();
|
||||
const auto &symbol = symbol_table_->at(*all.identifier_);
|
||||
@ -481,9 +435,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
frame_->at(symbol) = element;
|
||||
auto result = all.where_->expression_->Accept(*this);
|
||||
if (!result.IsNull() && result.type() != TypedValue::Type::Bool) {
|
||||
throw QueryRuntimeException(
|
||||
"Predicate of ALL must evaluate to boolean, got {}.",
|
||||
result.type());
|
||||
throw QueryRuntimeException("Predicate of ALL must evaluate to boolean, got {}.", result.type());
|
||||
}
|
||||
if (!result.IsNull()) {
|
||||
has_value = true;
|
||||
@ -510,8 +462,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
return TypedValue(ctx_->memory);
|
||||
}
|
||||
if (list_value.type() != TypedValue::Type::List) {
|
||||
throw QueryRuntimeException("SINGLE expected a list, got {}.",
|
||||
list_value.type());
|
||||
throw QueryRuntimeException("SINGLE expected a list, got {}.", list_value.type());
|
||||
}
|
||||
const auto &list = list_value.ValueList();
|
||||
const auto &symbol = symbol_table_->at(*single.identifier_);
|
||||
@ -521,9 +472,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
frame_->at(symbol) = element;
|
||||
auto result = single.where_->expression_->Accept(*this);
|
||||
if (!result.IsNull() && result.type() != TypedValue::Type::Bool) {
|
||||
throw QueryRuntimeException(
|
||||
"Predicate of SINGLE must evaluate to boolean, got {}.",
|
||||
result.type());
|
||||
throw QueryRuntimeException("Predicate of SINGLE must evaluate to boolean, got {}.", result.type());
|
||||
}
|
||||
if (result.type() == TypedValue::Type::Bool) {
|
||||
has_value = true;
|
||||
@ -551,8 +500,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
return TypedValue(ctx_->memory);
|
||||
}
|
||||
if (list_value.type() != TypedValue::Type::List) {
|
||||
throw QueryRuntimeException("ANY expected a list, got {}.",
|
||||
list_value.type());
|
||||
throw QueryRuntimeException("ANY expected a list, got {}.", list_value.type());
|
||||
}
|
||||
const auto &list = list_value.ValueList();
|
||||
const auto &symbol = symbol_table_->at(*any.identifier_);
|
||||
@ -561,9 +509,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
frame_->at(symbol) = element;
|
||||
auto result = any.where_->expression_->Accept(*this);
|
||||
if (!result.IsNull() && result.type() != TypedValue::Type::Bool) {
|
||||
throw QueryRuntimeException(
|
||||
"Predicate of ANY must evaluate to boolean, got {}.",
|
||||
result.type());
|
||||
throw QueryRuntimeException("Predicate of ANY must evaluate to boolean, got {}.", result.type());
|
||||
}
|
||||
if (!result.IsNull()) {
|
||||
has_value = true;
|
||||
@ -586,8 +532,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
return TypedValue(ctx_->memory);
|
||||
}
|
||||
if (list_value.type() != TypedValue::Type::List) {
|
||||
throw QueryRuntimeException("NONE expected a list, got {}.",
|
||||
list_value.type());
|
||||
throw QueryRuntimeException("NONE expected a list, got {}.", list_value.type());
|
||||
}
|
||||
const auto &list = list_value.ValueList();
|
||||
const auto &symbol = symbol_table_->at(*none.identifier_);
|
||||
@ -596,9 +541,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
frame_->at(symbol) = element;
|
||||
auto result = none.where_->expression_->Accept(*this);
|
||||
if (!result.IsNull() && result.type() != TypedValue::Type::Bool) {
|
||||
throw QueryRuntimeException(
|
||||
"Predicate of NONE must evaluate to boolean, got {}.",
|
||||
result.type());
|
||||
throw QueryRuntimeException("Predicate of NONE must evaluate to boolean, got {}.", result.type());
|
||||
}
|
||||
if (!result.IsNull()) {
|
||||
has_value = true;
|
||||
@ -616,9 +559,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
}
|
||||
|
||||
TypedValue Visit(ParameterLookup ¶m_lookup) override {
|
||||
return TypedValue(
|
||||
ctx_->parameters.AtTokenPosition(param_lookup.token_position_),
|
||||
ctx_->memory);
|
||||
return TypedValue(ctx_->parameters.AtTokenPosition(param_lookup.token_position_), ctx_->memory);
|
||||
}
|
||||
|
||||
TypedValue Visit(RegexMatch ®ex_match) override {
|
||||
@ -628,9 +569,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
return TypedValue(ctx_->memory);
|
||||
}
|
||||
if (regex_value.type() != TypedValue::Type::String) {
|
||||
throw QueryRuntimeException(
|
||||
"Regular expression must evaluate to a string, got {}.",
|
||||
regex_value.type());
|
||||
throw QueryRuntimeException("Regular expression must evaluate to a string, got {}.", regex_value.type());
|
||||
}
|
||||
if (target_string_value.type() != TypedValue::Type::String) {
|
||||
// Instead of error, we return Null which makes it compatible in case we
|
||||
@ -643,75 +582,60 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
std::regex regex(regex_value.ValueString());
|
||||
return TypedValue(std::regex_match(target_string, regex), ctx_->memory);
|
||||
} catch (const std::regex_error &e) {
|
||||
throw QueryRuntimeException("Regex error in '{}': {}",
|
||||
regex_value.ValueString(), e.what());
|
||||
throw QueryRuntimeException("Regex error in '{}': {}", regex_value.ValueString(), e.what());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
template <class TRecordAccessor>
|
||||
storage::PropertyValue GetProperty(const TRecordAccessor &record_accessor,
|
||||
PropertyIx prop) {
|
||||
auto maybe_prop =
|
||||
record_accessor.GetProperty(view_, ctx_->properties[prop.ix]);
|
||||
if (maybe_prop.HasError() &&
|
||||
maybe_prop.GetError() == storage::Error::NONEXISTENT_OBJECT) {
|
||||
storage::PropertyValue GetProperty(const TRecordAccessor &record_accessor, PropertyIx prop) {
|
||||
auto maybe_prop = record_accessor.GetProperty(view_, ctx_->properties[prop.ix]);
|
||||
if (maybe_prop.HasError() && maybe_prop.GetError() == storage::Error::NONEXISTENT_OBJECT) {
|
||||
// This is a very nasty and temporary hack in order to make MERGE work.
|
||||
// The old storage had the following logic when returning an `OLD` view:
|
||||
// `return old ? old : new`. That means that if the `OLD` view didn't
|
||||
// exist, it returned the NEW view. With this hack we simulate that
|
||||
// behavior.
|
||||
// TODO (mferencevic, teon.banek): Remove once MERGE is reimplemented.
|
||||
maybe_prop = record_accessor.GetProperty(storage::View::NEW,
|
||||
ctx_->properties[prop.ix]);
|
||||
maybe_prop = record_accessor.GetProperty(storage::View::NEW, ctx_->properties[prop.ix]);
|
||||
}
|
||||
if (maybe_prop.HasError()) {
|
||||
switch (maybe_prop.GetError()) {
|
||||
case storage::Error::DELETED_OBJECT:
|
||||
throw QueryRuntimeException(
|
||||
"Trying to get a property from a deleted object.");
|
||||
throw QueryRuntimeException("Trying to get a property from a deleted object.");
|
||||
case storage::Error::NONEXISTENT_OBJECT:
|
||||
throw query::QueryRuntimeException(
|
||||
"Trying to get a property from an object that doesn't exist.");
|
||||
throw query::QueryRuntimeException("Trying to get a property from an object that doesn't exist.");
|
||||
case storage::Error::SERIALIZATION_ERROR:
|
||||
case storage::Error::VERTEX_HAS_EDGES:
|
||||
case storage::Error::PROPERTIES_DISABLED:
|
||||
throw QueryRuntimeException(
|
||||
"Unexpected error when getting a property.");
|
||||
throw QueryRuntimeException("Unexpected error when getting a property.");
|
||||
}
|
||||
}
|
||||
return *maybe_prop;
|
||||
}
|
||||
|
||||
template <class TRecordAccessor>
|
||||
storage::PropertyValue GetProperty(const TRecordAccessor &record_accessor,
|
||||
const std::string_view &name) {
|
||||
auto maybe_prop =
|
||||
record_accessor.GetProperty(view_, dba_->NameToProperty(name));
|
||||
if (maybe_prop.HasError() &&
|
||||
maybe_prop.GetError() == storage::Error::NONEXISTENT_OBJECT) {
|
||||
storage::PropertyValue GetProperty(const TRecordAccessor &record_accessor, const std::string_view &name) {
|
||||
auto maybe_prop = record_accessor.GetProperty(view_, dba_->NameToProperty(name));
|
||||
if (maybe_prop.HasError() && maybe_prop.GetError() == storage::Error::NONEXISTENT_OBJECT) {
|
||||
// This is a very nasty and temporary hack in order to make MERGE work.
|
||||
// The old storage had the following logic when returning an `OLD` view:
|
||||
// `return old ? old : new`. That means that if the `OLD` view didn't
|
||||
// exist, it returned the NEW view. With this hack we simulate that
|
||||
// behavior.
|
||||
// TODO (mferencevic, teon.banek): Remove once MERGE is reimplemented.
|
||||
maybe_prop =
|
||||
record_accessor.GetProperty(view_, dba_->NameToProperty(name));
|
||||
maybe_prop = record_accessor.GetProperty(view_, dba_->NameToProperty(name));
|
||||
}
|
||||
if (maybe_prop.HasError()) {
|
||||
switch (maybe_prop.GetError()) {
|
||||
case storage::Error::DELETED_OBJECT:
|
||||
throw QueryRuntimeException(
|
||||
"Trying to get a property from a deleted object.");
|
||||
throw QueryRuntimeException("Trying to get a property from a deleted object.");
|
||||
case storage::Error::NONEXISTENT_OBJECT:
|
||||
throw query::QueryRuntimeException(
|
||||
"Trying to get a property from an object that doesn't exist.");
|
||||
throw query::QueryRuntimeException("Trying to get a property from an object that doesn't exist.");
|
||||
case storage::Error::SERIALIZATION_ERROR:
|
||||
case storage::Error::VERTEX_HAS_EDGES:
|
||||
case storage::Error::PROPERTIES_DISABLED:
|
||||
throw QueryRuntimeException(
|
||||
"Unexpected error when getting a property.");
|
||||
throw QueryRuntimeException("Unexpected error when getting a property.");
|
||||
}
|
||||
}
|
||||
return *maybe_prop;
|
||||
@ -732,8 +656,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
|
||||
/// @param what - Name of what's getting evaluated. Used for user feedback (via
|
||||
/// exception) when the evaluated value is not an int.
|
||||
/// @throw QueryRuntimeException if expression doesn't evaluate to an int.
|
||||
inline int64_t EvaluateInt(ExpressionEvaluator *evaluator, Expression *expr,
|
||||
const std::string &what) {
|
||||
inline int64_t EvaluateInt(ExpressionEvaluator *evaluator, Expression *expr, const std::string &what) {
|
||||
TypedValue value = expr->Accept(*evaluator);
|
||||
try {
|
||||
return value.ValueInt();
|
||||
|
@ -13,31 +13,19 @@ namespace query {
|
||||
class Frame {
|
||||
public:
|
||||
/// Create a Frame of given size backed by a utils::NewDeleteResource()
|
||||
explicit Frame(int64_t size) : elems_(size, utils::NewDeleteResource()) {
|
||||
MG_ASSERT(size >= 0);
|
||||
}
|
||||
explicit Frame(int64_t size) : elems_(size, utils::NewDeleteResource()) { MG_ASSERT(size >= 0); }
|
||||
|
||||
Frame(int64_t size, utils::MemoryResource *memory) : elems_(size, memory) {
|
||||
MG_ASSERT(size >= 0);
|
||||
}
|
||||
Frame(int64_t size, utils::MemoryResource *memory) : elems_(size, memory) { MG_ASSERT(size >= 0); }
|
||||
|
||||
TypedValue &operator[](const Symbol &symbol) {
|
||||
return elems_[symbol.position()];
|
||||
}
|
||||
const TypedValue &operator[](const Symbol &symbol) const {
|
||||
return elems_[symbol.position()];
|
||||
}
|
||||
TypedValue &operator[](const Symbol &symbol) { return elems_[symbol.position()]; }
|
||||
const TypedValue &operator[](const Symbol &symbol) const { return elems_[symbol.position()]; }
|
||||
|
||||
TypedValue &at(const Symbol &symbol) { return elems_.at(symbol.position()); }
|
||||
const TypedValue &at(const Symbol &symbol) const {
|
||||
return elems_.at(symbol.position());
|
||||
}
|
||||
const TypedValue &at(const Symbol &symbol) const { return elems_.at(symbol.position()); }
|
||||
|
||||
auto &elems() { return elems_; }
|
||||
|
||||
utils::MemoryResource *GetMemoryResource() const {
|
||||
return elems_.get_allocator().GetMemoryResource();
|
||||
}
|
||||
utils::MemoryResource *GetMemoryResource() const { return elems_.get_allocator().GetMemoryResource(); }
|
||||
|
||||
private:
|
||||
utils::pmr::vector<TypedValue> elems_;
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -39,16 +39,14 @@ class AuthQueryHandler {
|
||||
|
||||
/// Return false if the user already exists.
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
virtual bool CreateUser(const std::string &username,
|
||||
const std::optional<std::string> &password) = 0;
|
||||
virtual bool CreateUser(const std::string &username, const std::optional<std::string> &password) = 0;
|
||||
|
||||
/// Return false if the user does not exist.
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
virtual bool DropUser(const std::string &username) = 0;
|
||||
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
virtual void SetPassword(const std::string &username,
|
||||
const std::optional<std::string> &password) = 0;
|
||||
virtual void SetPassword(const std::string &username, const std::optional<std::string> &password) = 0;
|
||||
|
||||
/// Return false if the role already exists.
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
@ -65,36 +63,27 @@ class AuthQueryHandler {
|
||||
virtual std::vector<TypedValue> GetRolenames() = 0;
|
||||
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
virtual std::optional<std::string> GetRolenameForUser(
|
||||
const std::string &username) = 0;
|
||||
virtual std::optional<std::string> GetRolenameForUser(const std::string &username) = 0;
|
||||
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
virtual std::vector<TypedValue> GetUsernamesForRole(
|
||||
const std::string &rolename) = 0;
|
||||
virtual std::vector<TypedValue> GetUsernamesForRole(const std::string &rolename) = 0;
|
||||
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
virtual void SetRole(const std::string &username,
|
||||
const std::string &rolename) = 0;
|
||||
virtual void SetRole(const std::string &username, const std::string &rolename) = 0;
|
||||
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
virtual void ClearRole(const std::string &username) = 0;
|
||||
|
||||
virtual std::vector<std::vector<TypedValue>> GetPrivileges(
|
||||
const std::string &user_or_role) = 0;
|
||||
virtual std::vector<std::vector<TypedValue>> GetPrivileges(const std::string &user_or_role) = 0;
|
||||
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
virtual void GrantPrivilege(
|
||||
const std::string &user_or_role,
|
||||
const std::vector<AuthQuery::Privilege> &privileges) = 0;
|
||||
virtual void GrantPrivilege(const std::string &user_or_role, const std::vector<AuthQuery::Privilege> &privileges) = 0;
|
||||
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
virtual void DenyPrivilege(
|
||||
const std::string &user_or_role,
|
||||
const std::vector<AuthQuery::Privilege> &privileges) = 0;
|
||||
virtual void DenyPrivilege(const std::string &user_or_role, const std::vector<AuthQuery::Privilege> &privileges) = 0;
|
||||
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
virtual void RevokePrivilege(
|
||||
const std::string &user_or_role,
|
||||
virtual void RevokePrivilege(const std::string &user_or_role,
|
||||
const std::vector<AuthQuery::Privilege> &privileges) = 0;
|
||||
};
|
||||
|
||||
@ -119,18 +108,14 @@ class ReplicationQueryHandler {
|
||||
};
|
||||
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
virtual void SetReplicationRole(
|
||||
ReplicationQuery::ReplicationRole replication_role,
|
||||
std::optional<int64_t> port) = 0;
|
||||
virtual void SetReplicationRole(ReplicationQuery::ReplicationRole replication_role, std::optional<int64_t> port) = 0;
|
||||
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
virtual ReplicationQuery::ReplicationRole ShowReplicationRole() const = 0;
|
||||
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
virtual void RegisterReplica(const std::string &name,
|
||||
const std::string &socket_address,
|
||||
const ReplicationQuery::SyncMode sync_mode,
|
||||
const std::optional<double> timeout) = 0;
|
||||
virtual void RegisterReplica(const std::string &name, const std::string &socket_address,
|
||||
const ReplicationQuery::SyncMode sync_mode, const std::optional<double> timeout) = 0;
|
||||
|
||||
/// @throw QueryRuntimeException if an error ocurred.
|
||||
virtual void DropReplica(const std::string &replica_name) = 0;
|
||||
@ -145,9 +130,7 @@ class ReplicationQueryHandler {
|
||||
struct PreparedQuery {
|
||||
std::vector<std::string> header;
|
||||
std::vector<AuthQuery::Privilege> privileges;
|
||||
std::function<std::optional<QueryHandlerResult>(AnyStream *stream,
|
||||
std::optional<int> n)>
|
||||
query_handler;
|
||||
std::function<std::optional<QueryHandlerResult>(AnyStream *stream, std::optional<int> n)> query_handler;
|
||||
plan::ReadWriteTypeChecker::RWType rw_type;
|
||||
};
|
||||
|
||||
@ -172,10 +155,7 @@ class CachedPlan {
|
||||
const auto &symbol_table() const { return plan_->GetSymbolTable(); }
|
||||
const auto &ast_storage() const { return plan_->GetAstStorage(); }
|
||||
|
||||
bool IsExpired() const {
|
||||
return cache_timer_.Elapsed() >
|
||||
std::chrono::seconds(FLAGS_query_plan_cache_ttl);
|
||||
};
|
||||
bool IsExpired() const { return cache_timer_.Elapsed() > std::chrono::seconds(FLAGS_query_plan_cache_ttl); };
|
||||
|
||||
private:
|
||||
std::unique_ptr<LogicalPlan> plan_;
|
||||
@ -189,12 +169,8 @@ struct CachedQuery {
|
||||
};
|
||||
|
||||
struct QueryCacheEntry {
|
||||
bool operator==(const QueryCacheEntry &other) const {
|
||||
return first == other.first;
|
||||
}
|
||||
bool operator<(const QueryCacheEntry &other) const {
|
||||
return first < other.first;
|
||||
}
|
||||
bool operator==(const QueryCacheEntry &other) const { return first == other.first; }
|
||||
bool operator<(const QueryCacheEntry &other) const { return first < other.first; }
|
||||
bool operator==(const uint64_t &other) const { return first == other; }
|
||||
bool operator<(const uint64_t &other) const { return first < other; }
|
||||
|
||||
@ -205,12 +181,8 @@ struct QueryCacheEntry {
|
||||
};
|
||||
|
||||
struct PlanCacheEntry {
|
||||
bool operator==(const PlanCacheEntry &other) const {
|
||||
return first == other.first;
|
||||
}
|
||||
bool operator<(const PlanCacheEntry &other) const {
|
||||
return first < other.first;
|
||||
}
|
||||
bool operator==(const PlanCacheEntry &other) const { return first == other.first; }
|
||||
bool operator<(const PlanCacheEntry &other) const { return first < other.first; }
|
||||
bool operator==(const uint64_t &other) const { return first == other; }
|
||||
bool operator<(const uint64_t &other) const { return first < other; }
|
||||
|
||||
@ -228,9 +200,7 @@ struct PlanCacheEntry {
|
||||
* been passed to an `Interpreter` instance.
|
||||
*/
|
||||
struct InterpreterContext {
|
||||
explicit InterpreterContext(storage::Storage *db) : db(db) {
|
||||
MG_ASSERT(db, "Storage must not be NULL");
|
||||
}
|
||||
explicit InterpreterContext(storage::Storage *db) : db(db) { MG_ASSERT(db, "Storage must not be NULL"); }
|
||||
|
||||
storage::Storage *db;
|
||||
|
||||
@ -254,9 +224,7 @@ struct InterpreterContext {
|
||||
|
||||
/// Function that is used to tell all active interpreters that they should stop
|
||||
/// their ongoing execution.
|
||||
inline void Shutdown(InterpreterContext *context) {
|
||||
context->is_shutting_down.store(true, std::memory_order_release);
|
||||
}
|
||||
inline void Shutdown(InterpreterContext *context) { context->is_shutting_down.store(true, std::memory_order_release); }
|
||||
|
||||
/// Function used to set the maximum execution timeout in seconds.
|
||||
inline void SetExecutionTimeout(InterpreterContext *context, double timeout) {
|
||||
@ -286,9 +254,7 @@ class Interpreter final {
|
||||
*
|
||||
* @throw query::QueryException
|
||||
*/
|
||||
PrepareResult Prepare(
|
||||
const std::string &query,
|
||||
const std::map<std::string, storage::PropertyValue> ¶ms);
|
||||
PrepareResult Prepare(const std::string &query, const std::map<std::string, storage::PropertyValue> ¶ms);
|
||||
|
||||
/**
|
||||
* Execute the last prepared query and stream *all* of the results into the
|
||||
@ -329,8 +295,7 @@ class Interpreter final {
|
||||
* @throw query::QueryException
|
||||
*/
|
||||
template <typename TStream>
|
||||
std::map<std::string, TypedValue> Pull(TStream *result_stream,
|
||||
std::optional<int> n = {},
|
||||
std::map<std::string, TypedValue> Pull(TStream *result_stream, std::optional<int> n = {},
|
||||
std::optional<int> qid = {});
|
||||
|
||||
void BeginTransaction();
|
||||
@ -392,35 +357,27 @@ class Interpreter final {
|
||||
|
||||
size_t ActiveQueryExecutions() {
|
||||
return std::count_if(query_executions_.begin(), query_executions_.end(),
|
||||
[](const auto &execution) {
|
||||
return execution && execution->prepared_query;
|
||||
});
|
||||
[](const auto &execution) { return execution && execution->prepared_query; });
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TStream>
|
||||
std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream,
|
||||
std::optional<int> n,
|
||||
std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream, std::optional<int> n,
|
||||
std::optional<int> qid) {
|
||||
MG_ASSERT(in_explicit_transaction_ || !qid,
|
||||
"qid can be only used in explicit transaction!");
|
||||
const int qid_value =
|
||||
qid ? *qid : static_cast<int>(query_executions_.size() - 1);
|
||||
MG_ASSERT(in_explicit_transaction_ || !qid, "qid can be only used in explicit transaction!");
|
||||
const int qid_value = qid ? *qid : static_cast<int>(query_executions_.size() - 1);
|
||||
|
||||
if (qid_value < 0 || qid_value >= query_executions_.size()) {
|
||||
throw InvalidArgumentsException("qid",
|
||||
"Query with specified ID does not exist!");
|
||||
throw InvalidArgumentsException("qid", "Query with specified ID does not exist!");
|
||||
}
|
||||
|
||||
if (n && n < 0) {
|
||||
throw InvalidArgumentsException("n",
|
||||
"Cannot fetch negative number of results!");
|
||||
throw InvalidArgumentsException("n", "Cannot fetch negative number of results!");
|
||||
}
|
||||
|
||||
auto &query_execution = query_executions_[qid_value];
|
||||
|
||||
MG_ASSERT(query_execution && query_execution->prepared_query,
|
||||
"Query already finished executing!");
|
||||
MG_ASSERT(query_execution && query_execution->prepared_query, "Query already finished executing!");
|
||||
|
||||
// Each prepared query has its own summary so we need to somehow preserve
|
||||
// it after it finishes executing because it gets destroyed alongside
|
||||
@ -430,8 +387,7 @@ std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream,
|
||||
// Wrap the (statically polymorphic) stream type into a common type which
|
||||
// the handler knows.
|
||||
AnyStream stream{result_stream, &query_execution->execution_memory};
|
||||
const auto maybe_res =
|
||||
query_execution->prepared_query->query_handler(&stream, n);
|
||||
const auto maybe_res = query_execution->prepared_query->query_handler(&stream, n);
|
||||
// Stream is using execution memory of the query_execution which
|
||||
// can be deleted after its execution so the stream should be cleared
|
||||
// first.
|
||||
|
@ -21,9 +21,7 @@ struct Parameters {
|
||||
* @param position Token position in query of value.
|
||||
* @param value
|
||||
*/
|
||||
void Add(int position, const storage::PropertyValue &value) {
|
||||
storage_.emplace_back(position, value);
|
||||
}
|
||||
void Add(int position, const storage::PropertyValue &value) { storage_.emplace_back(position, value); }
|
||||
|
||||
/**
|
||||
* Returns the value found for the given token position.
|
||||
@ -32,11 +30,8 @@ struct Parameters {
|
||||
* @return Value for the given token position.
|
||||
*/
|
||||
const storage::PropertyValue &AtTokenPosition(int position) const {
|
||||
auto found =
|
||||
std::find_if(storage_.begin(), storage_.end(),
|
||||
[&](const auto &a) { return a.first == position; });
|
||||
MG_ASSERT(found != storage_.end(),
|
||||
"Token position must be present in container");
|
||||
auto found = std::find_if(storage_.begin(), storage_.end(), [&](const auto &a) { return a.first == position; });
|
||||
MG_ASSERT(found != storage_.end(), "Token position must be present in container");
|
||||
return found->second;
|
||||
}
|
||||
|
||||
|
@ -24,8 +24,7 @@ class Path {
|
||||
* Create the path starting with the given vertex.
|
||||
* Allocations are done using the given MemoryResource.
|
||||
*/
|
||||
explicit Path(const VertexAccessor &vertex,
|
||||
utils::MemoryResource *memory = utils::NewDeleteResource())
|
||||
explicit Path(const VertexAccessor &vertex, utils::MemoryResource *memory = utils::NewDeleteResource())
|
||||
: vertices_(memory), edges_(memory) {
|
||||
Expand(vertex);
|
||||
}
|
||||
@ -37,8 +36,7 @@ class Path {
|
||||
*/
|
||||
template <typename... TOthers>
|
||||
explicit Path(const VertexAccessor &vertex, const TOthers &...others)
|
||||
: vertices_(utils::NewDeleteResource()),
|
||||
edges_(utils::NewDeleteResource()) {
|
||||
: vertices_(utils::NewDeleteResource()), edges_(utils::NewDeleteResource()) {
|
||||
Expand(vertex);
|
||||
Expand(others...);
|
||||
}
|
||||
@ -49,8 +47,7 @@ class Path {
|
||||
* Allocations are done using the given MemoryResource.
|
||||
*/
|
||||
template <typename... TOthers>
|
||||
Path(std::allocator_arg_t, utils::MemoryResource *memory,
|
||||
const VertexAccessor &vertex, const TOthers &...others)
|
||||
Path(std::allocator_arg_t, utils::MemoryResource *memory, const VertexAccessor &vertex, const TOthers &...others)
|
||||
: vertices_(memory), edges_(memory) {
|
||||
Expand(vertex);
|
||||
Expand(others...);
|
||||
@ -65,9 +62,8 @@ class Path {
|
||||
* will default to utils::NewDeleteResource().
|
||||
*/
|
||||
Path(const Path &other)
|
||||
: Path(other, std::allocator_traits<allocator_type>::
|
||||
select_on_container_copy_construction(
|
||||
other.GetMemoryResource())
|
||||
: Path(other,
|
||||
std::allocator_traits<allocator_type>::select_on_container_copy_construction(other.GetMemoryResource())
|
||||
.GetMemoryResource()) {}
|
||||
|
||||
/** Construct a copy using the given utils::MemoryResource */
|
||||
@ -79,8 +75,7 @@ class Path {
|
||||
* utils::MemoryResource is obtained from other. After the move, other will be
|
||||
* empty.
|
||||
*/
|
||||
Path(Path &&other) noexcept
|
||||
: Path(std::move(other), other.GetMemoryResource()) {}
|
||||
Path(Path &&other) noexcept : Path(std::move(other), other.GetMemoryResource()) {}
|
||||
|
||||
/**
|
||||
* Construct with the value of other, but use the given utils::MemoryResource.
|
||||
@ -89,8 +84,7 @@ class Path {
|
||||
* performed.
|
||||
*/
|
||||
Path(Path &&other, utils::MemoryResource *memory)
|
||||
: vertices_(std::move(other.vertices_), memory),
|
||||
edges_(std::move(other.edges_), memory) {}
|
||||
: vertices_(std::move(other.vertices_), memory), edges_(std::move(other.edges_), memory) {}
|
||||
|
||||
/** Copy assign other, utils::MemoryResource of `this` is used */
|
||||
Path &operator=(const Path &) = default;
|
||||
@ -102,15 +96,13 @@ class Path {
|
||||
|
||||
/** Expands the path with the given vertex. */
|
||||
void Expand(const VertexAccessor &vertex) {
|
||||
DMG_ASSERT(vertices_.size() == edges_.size(),
|
||||
"Illegal path construction order");
|
||||
DMG_ASSERT(vertices_.size() == edges_.size(), "Illegal path construction order");
|
||||
vertices_.emplace_back(vertex);
|
||||
}
|
||||
|
||||
/** Expands the path with the given edge. */
|
||||
void Expand(const EdgeAccessor &edge) {
|
||||
DMG_ASSERT(vertices_.size() - 1 == edges_.size(),
|
||||
"Illegal path construction order");
|
||||
DMG_ASSERT(vertices_.size() - 1 == edges_.size(), "Illegal path construction order");
|
||||
edges_.emplace_back(edge);
|
||||
}
|
||||
|
||||
@ -129,13 +121,9 @@ class Path {
|
||||
const auto &vertices() const { return vertices_; }
|
||||
const auto &edges() const { return edges_; }
|
||||
|
||||
utils::MemoryResource *GetMemoryResource() const {
|
||||
return vertices_.get_allocator().GetMemoryResource();
|
||||
}
|
||||
utils::MemoryResource *GetMemoryResource() const { return vertices_.get_allocator().GetMemoryResource(); }
|
||||
|
||||
bool operator==(const Path &other) const {
|
||||
return vertices_ == other.vertices_ && edges_ == other.edges_;
|
||||
}
|
||||
bool operator==(const Path &other) const { return vertices_ == other.vertices_ && edges_ == other.edges_; }
|
||||
|
||||
private:
|
||||
// Contains all the vertices in the path.
|
||||
|
@ -91,13 +91,10 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor {
|
||||
double factor = 1.0;
|
||||
if (property_value)
|
||||
// get the exact influence based on ScanAll(label, property, value)
|
||||
factor = db_accessor_->VerticesCount(
|
||||
logical_op.label_, logical_op.property_, property_value.value());
|
||||
factor = db_accessor_->VerticesCount(logical_op.label_, logical_op.property_, property_value.value());
|
||||
else
|
||||
// estimate the influence as ScanAll(label, property) * filtering
|
||||
factor =
|
||||
db_accessor_->VerticesCount(logical_op.label_, logical_op.property_) *
|
||||
CardParam::kFilter;
|
||||
factor = db_accessor_->VerticesCount(logical_op.label_, logical_op.property_) * CardParam::kFilter;
|
||||
|
||||
cardinality_ *= factor;
|
||||
|
||||
@ -115,18 +112,14 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor {
|
||||
int64_t factor = 1;
|
||||
if (upper || lower)
|
||||
// if we have either Bound<PropertyValue>, use the value index
|
||||
factor = db_accessor_->VerticesCount(logical_op.label_,
|
||||
logical_op.property_, lower, upper);
|
||||
factor = db_accessor_->VerticesCount(logical_op.label_, logical_op.property_, lower, upper);
|
||||
else
|
||||
// no values, but we still have the label
|
||||
factor =
|
||||
db_accessor_->VerticesCount(logical_op.label_, logical_op.property_);
|
||||
factor = db_accessor_->VerticesCount(logical_op.label_, logical_op.property_);
|
||||
|
||||
// if we failed to take either bound from the op into account, then apply
|
||||
// the filtering constant to the factor
|
||||
if ((logical_op.upper_bound_ && !upper) ||
|
||||
(logical_op.lower_bound_ && !lower))
|
||||
factor *= CardParam::kFilter;
|
||||
if ((logical_op.upper_bound_ && !upper) || (logical_op.lower_bound_ && !lower)) factor *= CardParam::kFilter;
|
||||
|
||||
cardinality_ *= factor;
|
||||
|
||||
@ -136,8 +129,7 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor {
|
||||
}
|
||||
|
||||
bool PostVisit(ScanAllByLabelProperty &logical_op) override {
|
||||
const auto factor =
|
||||
db_accessor_->VerticesCount(logical_op.label_, logical_op.property_);
|
||||
const auto factor = db_accessor_->VerticesCount(logical_op.label_, logical_op.property_);
|
||||
cardinality_ *= factor;
|
||||
IncrementCost(CostParam::MakeScanAllByLabelProperty);
|
||||
return true;
|
||||
@ -181,8 +173,7 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor {
|
||||
// if the Unwind expression is a list literal, we can deduce cardinality
|
||||
// exactly, otherwise we approximate
|
||||
double unwind_value;
|
||||
if (auto *literal =
|
||||
utils::Downcast<query::ListLiteral>(unwind.input_expression_))
|
||||
if (auto *literal = utils::Downcast<query::ListLiteral>(unwind.input_expression_))
|
||||
unwind_value = literal->elements_.size();
|
||||
else
|
||||
unwind_value = MiscParam::kUnwindNoLiteral;
|
||||
@ -218,21 +209,17 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor {
|
||||
std::optional<ScanAllByLabelPropertyRange::Bound> bound) {
|
||||
if (bound) {
|
||||
auto property_value = ConstPropertyValue(bound->value());
|
||||
if (property_value)
|
||||
return utils::Bound<storage::PropertyValue>(*property_value,
|
||||
bound->type());
|
||||
if (property_value) return utils::Bound<storage::PropertyValue>(*property_value, bound->type());
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// If the expression is a constant property value, it is returned. Otherwise,
|
||||
// return nullopt.
|
||||
std::optional<storage::PropertyValue> ConstPropertyValue(
|
||||
const Expression *expression) {
|
||||
std::optional<storage::PropertyValue> ConstPropertyValue(const Expression *expression) {
|
||||
if (auto *literal = utils::Downcast<const PrimitiveLiteral>(expression)) {
|
||||
return literal->value_;
|
||||
} else if (auto *param_lookup =
|
||||
utils::Downcast<const ParameterLookup>(expression)) {
|
||||
} else if (auto *param_lookup = utils::Downcast<const ParameterLookup>(expression)) {
|
||||
return parameters.AtTokenPosition(param_lookup->token_position_);
|
||||
}
|
||||
return std::nullopt;
|
||||
@ -241,8 +228,7 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor {
|
||||
|
||||
/** Returns the estimated cost of the given plan. */
|
||||
template <class TDbAccessor>
|
||||
double EstimatePlanCost(TDbAccessor *db, const Parameters ¶meters,
|
||||
LogicalOperator &plan) {
|
||||
double EstimatePlanCost(TDbAccessor *db, const Parameters ¶meters, LogicalOperator &plan) {
|
||||
CostEstimator<TDbAccessor> estimator(db, parameters);
|
||||
plan.Accept(estimator);
|
||||
return estimator.cost();
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -28,38 +28,31 @@ class PostProcessor final {
|
||||
public:
|
||||
using ProcessedPlan = std::unique_ptr<LogicalOperator>;
|
||||
|
||||
explicit PostProcessor(const Parameters ¶meters)
|
||||
: parameters_(parameters) {}
|
||||
explicit PostProcessor(const Parameters ¶meters) : parameters_(parameters) {}
|
||||
|
||||
template <class TPlanningContext>
|
||||
std::unique_ptr<LogicalOperator> Rewrite(
|
||||
std::unique_ptr<LogicalOperator> plan, TPlanningContext *context) {
|
||||
return RewriteWithIndexLookup(std::move(plan), context->symbol_table,
|
||||
context->ast_storage, context->db);
|
||||
std::unique_ptr<LogicalOperator> Rewrite(std::unique_ptr<LogicalOperator> plan, TPlanningContext *context) {
|
||||
return RewriteWithIndexLookup(std::move(plan), context->symbol_table, context->ast_storage, context->db);
|
||||
}
|
||||
|
||||
template <class TVertexCounts>
|
||||
double EstimatePlanCost(const std::unique_ptr<LogicalOperator> &plan,
|
||||
TVertexCounts *vertex_counts) {
|
||||
double EstimatePlanCost(const std::unique_ptr<LogicalOperator> &plan, TVertexCounts *vertex_counts) {
|
||||
return ::query::plan::EstimatePlanCost(vertex_counts, parameters_, *plan);
|
||||
}
|
||||
|
||||
template <class TPlanningContext>
|
||||
std::unique_ptr<LogicalOperator> MergeWithCombinator(
|
||||
std::unique_ptr<LogicalOperator> curr_op,
|
||||
std::unique_ptr<LogicalOperator> MergeWithCombinator(std::unique_ptr<LogicalOperator> curr_op,
|
||||
std::unique_ptr<LogicalOperator> last_op, const Tree &combinator,
|
||||
TPlanningContext *context) {
|
||||
if (const auto *union_ = utils::Downcast<const CypherUnion>(&combinator)) {
|
||||
return std::unique_ptr<LogicalOperator>(
|
||||
impl::GenUnion(*union_, std::move(last_op), std::move(curr_op),
|
||||
*context->symbol_table));
|
||||
impl::GenUnion(*union_, std::move(last_op), std::move(curr_op), *context->symbol_table));
|
||||
}
|
||||
throw utils::NotYetImplemented("query combinator");
|
||||
}
|
||||
|
||||
template <class TPlanningContext>
|
||||
std::unique_ptr<LogicalOperator> MakeDistinct(
|
||||
std::unique_ptr<LogicalOperator> last_op, TPlanningContext *context) {
|
||||
std::unique_ptr<LogicalOperator> MakeDistinct(std::unique_ptr<LogicalOperator> last_op, TPlanningContext *context) {
|
||||
auto output_symbols = last_op->OutputSymbols(*context->symbol_table);
|
||||
return std::make_unique<Distinct>(std::move(last_op), output_symbols);
|
||||
}
|
||||
@ -78,12 +71,10 @@ class PostProcessor final {
|
||||
/// @sa RuleBasedPlanner
|
||||
/// @sa VariableStartPlanner
|
||||
template <template <class> class TPlanner, class TDbAccessor>
|
||||
auto MakeLogicalPlanForSingleQuery(
|
||||
std::vector<SingleQueryPart> single_query_parts,
|
||||
auto MakeLogicalPlanForSingleQuery(std::vector<SingleQueryPart> single_query_parts,
|
||||
PlanningContext<TDbAccessor> *context) {
|
||||
context->bound_symbols.clear();
|
||||
return TPlanner<PlanningContext<TDbAccessor>>(context).Plan(
|
||||
single_query_parts);
|
||||
return TPlanner<PlanningContext<TDbAccessor>>(context).Plan(single_query_parts);
|
||||
}
|
||||
|
||||
/// Generates the LogicalOperator tree and returns the resulting plan.
|
||||
@ -98,10 +89,8 @@ auto MakeLogicalPlanForSingleQuery(
|
||||
/// @return pair consisting of the final `TPlanPostProcess::ProcessedPlan` and
|
||||
/// the estimated cost of that plan as a `double`.
|
||||
template <class TPlanningContext, class TPlanPostProcess>
|
||||
auto MakeLogicalPlan(TPlanningContext *context, TPlanPostProcess *post_process,
|
||||
bool use_variable_planner) {
|
||||
auto query_parts = CollectQueryParts(*context->symbol_table,
|
||||
*context->ast_storage, context->query);
|
||||
auto MakeLogicalPlan(TPlanningContext *context, TPlanPostProcess *post_process, bool use_variable_planner) {
|
||||
auto query_parts = CollectQueryParts(*context->symbol_table, *context->ast_storage, context->query);
|
||||
auto &vertex_counts = *context->db;
|
||||
double total_cost = 0;
|
||||
|
||||
@ -113,22 +102,19 @@ auto MakeLogicalPlan(TPlanningContext *context, TPlanPostProcess *post_process,
|
||||
double min_cost = std::numeric_limits<double>::max();
|
||||
|
||||
if (use_variable_planner) {
|
||||
auto plans = MakeLogicalPlanForSingleQuery<VariableStartPlanner>(
|
||||
query_part.single_query_parts, context);
|
||||
auto plans = MakeLogicalPlanForSingleQuery<VariableStartPlanner>(query_part.single_query_parts, context);
|
||||
for (auto plan : plans) {
|
||||
// Plans are generated lazily and the current plan will disappear, so
|
||||
// it's ok to move it.
|
||||
auto rewritten_plan = post_process->Rewrite(std::move(plan), context);
|
||||
double cost =
|
||||
post_process->EstimatePlanCost(rewritten_plan, &vertex_counts);
|
||||
double cost = post_process->EstimatePlanCost(rewritten_plan, &vertex_counts);
|
||||
if (!curr_plan || cost < min_cost) {
|
||||
curr_plan.emplace(std::move(rewritten_plan));
|
||||
min_cost = cost;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>(
|
||||
query_part.single_query_parts, context);
|
||||
auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>(query_part.single_query_parts, context);
|
||||
auto rewritten_plan = post_process->Rewrite(std::move(plan), context);
|
||||
min_cost = post_process->EstimatePlanCost(rewritten_plan, &vertex_counts);
|
||||
curr_plan.emplace(std::move(rewritten_plan));
|
||||
@ -136,8 +122,7 @@ auto MakeLogicalPlan(TPlanningContext *context, TPlanPostProcess *post_process,
|
||||
|
||||
total_cost += min_cost;
|
||||
if (query_part.query_combinator) {
|
||||
last_plan = post_process->MergeWithCombinator(
|
||||
std::move(*curr_plan), std::move(last_plan),
|
||||
last_plan = post_process->MergeWithCombinator(std::move(*curr_plan), std::move(last_plan),
|
||||
*query_part.query_combinator, context);
|
||||
} else {
|
||||
last_plan = std::move(*curr_plan);
|
||||
@ -152,8 +137,7 @@ auto MakeLogicalPlan(TPlanningContext *context, TPlanPostProcess *post_process,
|
||||
}
|
||||
|
||||
template <class TPlanningContext>
|
||||
auto MakeLogicalPlan(TPlanningContext *context, const Parameters ¶meters,
|
||||
bool use_variable_planner) {
|
||||
auto MakeLogicalPlan(TPlanningContext *context, const Parameters ¶meters, bool use_variable_planner) {
|
||||
PostProcessor post_processor(parameters);
|
||||
return MakeLogicalPlan(context, &post_processor, use_variable_planner);
|
||||
}
|
||||
|
@ -8,8 +8,7 @@ namespace query::plan {
|
||||
|
||||
namespace {
|
||||
|
||||
void ForEachPattern(
|
||||
Pattern &pattern, std::function<void(NodeAtom *)> base,
|
||||
void ForEachPattern(Pattern &pattern, std::function<void(NodeAtom *)> base,
|
||||
std::function<void(NodeAtom *, EdgeAtom *, NodeAtom *)> collect) {
|
||||
DMG_ASSERT(!pattern.atoms_.empty(), "Missing atoms in pattern");
|
||||
auto atoms_it = pattern.atoms_.begin();
|
||||
@ -20,8 +19,7 @@ void ForEachPattern(
|
||||
while (atoms_it != pattern.atoms_.end()) {
|
||||
auto edge = utils::Downcast<EdgeAtom>(*atoms_it++);
|
||||
DMG_ASSERT(edge, "Expected an edge atom in pattern.");
|
||||
DMG_ASSERT(atoms_it != pattern.atoms_.end(),
|
||||
"Edge atom should not end the pattern.");
|
||||
DMG_ASSERT(atoms_it != pattern.atoms_.end(), "Edge atom should not end the pattern.");
|
||||
auto prev_node = current_node;
|
||||
current_node = utils::Downcast<NodeAtom>(*atoms_it++);
|
||||
DMG_ASSERT(current_node, "Expected a node atom in pattern.");
|
||||
@ -37,32 +35,24 @@ void ForEachPattern(
|
||||
// (m) -[e]- (n), (n) -[f]- (o).
|
||||
// This representation makes it easier to permute from which node or edge we
|
||||
// want to start expanding.
|
||||
std::vector<Expansion> NormalizePatterns(
|
||||
const SymbolTable &symbol_table, const std::vector<Pattern *> &patterns) {
|
||||
std::vector<Expansion> NormalizePatterns(const SymbolTable &symbol_table, const std::vector<Pattern *> &patterns) {
|
||||
std::vector<Expansion> expansions;
|
||||
auto ignore_node = [&](auto *) {};
|
||||
auto collect_expansion = [&](auto *prev_node, auto *edge,
|
||||
auto *current_node) {
|
||||
auto collect_expansion = [&](auto *prev_node, auto *edge, auto *current_node) {
|
||||
UsedSymbolsCollector collector(symbol_table);
|
||||
if (edge->IsVariable()) {
|
||||
if (edge->lower_bound_) edge->lower_bound_->Accept(collector);
|
||||
if (edge->upper_bound_) edge->upper_bound_->Accept(collector);
|
||||
if (edge->filter_lambda_.expression)
|
||||
edge->filter_lambda_.expression->Accept(collector);
|
||||
if (edge->filter_lambda_.expression) edge->filter_lambda_.expression->Accept(collector);
|
||||
// Remove symbols which are bound by lambda arguments.
|
||||
collector.symbols_.erase(
|
||||
symbol_table.at(*edge->filter_lambda_.inner_edge));
|
||||
collector.symbols_.erase(
|
||||
symbol_table.at(*edge->filter_lambda_.inner_node));
|
||||
collector.symbols_.erase(symbol_table.at(*edge->filter_lambda_.inner_edge));
|
||||
collector.symbols_.erase(symbol_table.at(*edge->filter_lambda_.inner_node));
|
||||
if (edge->type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) {
|
||||
collector.symbols_.erase(
|
||||
symbol_table.at(*edge->weight_lambda_.inner_edge));
|
||||
collector.symbols_.erase(
|
||||
symbol_table.at(*edge->weight_lambda_.inner_node));
|
||||
collector.symbols_.erase(symbol_table.at(*edge->weight_lambda_.inner_edge));
|
||||
collector.symbols_.erase(symbol_table.at(*edge->weight_lambda_.inner_node));
|
||||
}
|
||||
}
|
||||
expansions.emplace_back(Expansion{prev_node, edge, edge->direction_, false,
|
||||
collector.symbols_, current_node});
|
||||
expansions.emplace_back(Expansion{prev_node, edge, edge->direction_, false, collector.symbols_, current_node});
|
||||
};
|
||||
for (const auto &pattern : patterns) {
|
||||
if (pattern->atoms_.size() == 1U) {
|
||||
@ -81,8 +71,7 @@ std::vector<Expansion> NormalizePatterns(
|
||||
// as well as edge symbols which determine Cyphermorphism. Collecting filters
|
||||
// will lift them out of a pattern and generate new expressions (just like they
|
||||
// were in a Where clause).
|
||||
void AddMatching(const std::vector<Pattern *> &patterns, Where *where,
|
||||
SymbolTable &symbol_table, AstStorage &storage,
|
||||
void AddMatching(const std::vector<Pattern *> &patterns, Where *where, SymbolTable &symbol_table, AstStorage &storage,
|
||||
Matching &matching) {
|
||||
auto expansions = NormalizePatterns(symbol_table, patterns);
|
||||
std::unordered_set<Symbol> edge_symbols;
|
||||
@ -116,18 +105,15 @@ void AddMatching(const std::vector<Pattern *> &patterns, Where *where,
|
||||
std::vector<Symbol> path_elements;
|
||||
for (auto *pattern_atom : pattern->atoms_)
|
||||
path_elements.emplace_back(symbol_table.at(*pattern_atom->identifier_));
|
||||
matching.named_paths.emplace(symbol_table.at(*pattern->identifier_),
|
||||
std::move(path_elements));
|
||||
matching.named_paths.emplace(symbol_table.at(*pattern->identifier_), std::move(path_elements));
|
||||
}
|
||||
}
|
||||
if (where) {
|
||||
matching.filters.CollectWhereFilter(*where, symbol_table);
|
||||
}
|
||||
}
|
||||
void AddMatching(const Match &match, SymbolTable &symbol_table,
|
||||
AstStorage &storage, Matching &matching) {
|
||||
return AddMatching(match.patterns_, match.where_, symbol_table, storage,
|
||||
matching);
|
||||
void AddMatching(const Match &match, SymbolTable &symbol_table, AstStorage &storage, Matching &matching) {
|
||||
return AddMatching(match.patterns_, match.where_, symbol_table, storage, matching);
|
||||
}
|
||||
|
||||
auto SplitExpressionOnAnd(Expression *expression) {
|
||||
@ -151,8 +137,7 @@ auto SplitExpressionOnAnd(Expression *expression) {
|
||||
|
||||
} // namespace
|
||||
|
||||
PropertyFilter::PropertyFilter(const SymbolTable &symbol_table,
|
||||
const Symbol &symbol, PropertyIx property,
|
||||
PropertyFilter::PropertyFilter(const SymbolTable &symbol_table, const Symbol &symbol, PropertyIx property,
|
||||
Expression *value, Type type)
|
||||
: symbol_(symbol), property_(property), type_(type), value_(value) {
|
||||
MG_ASSERT(type != Type::RANGE);
|
||||
@ -161,15 +146,10 @@ PropertyFilter::PropertyFilter(const SymbolTable &symbol_table,
|
||||
is_symbol_in_value_ = utils::Contains(collector.symbols_, symbol);
|
||||
}
|
||||
|
||||
PropertyFilter::PropertyFilter(
|
||||
const SymbolTable &symbol_table, const Symbol &symbol, PropertyIx property,
|
||||
PropertyFilter::PropertyFilter(const SymbolTable &symbol_table, const Symbol &symbol, PropertyIx property,
|
||||
const std::optional<PropertyFilter::Bound> &lower_bound,
|
||||
const std::optional<PropertyFilter::Bound> &upper_bound)
|
||||
: symbol_(symbol),
|
||||
property_(property),
|
||||
type_(Type::RANGE),
|
||||
lower_bound_(lower_bound),
|
||||
upper_bound_(upper_bound) {
|
||||
: symbol_(symbol), property_(property), type_(Type::RANGE), lower_bound_(lower_bound), upper_bound_(upper_bound) {
|
||||
UsedSymbolsCollector collector(symbol_table);
|
||||
if (lower_bound) {
|
||||
lower_bound->value()->Accept(collector);
|
||||
@ -180,8 +160,7 @@ PropertyFilter::PropertyFilter(
|
||||
is_symbol_in_value_ = utils::Contains(collector.symbols_, symbol);
|
||||
}
|
||||
|
||||
PropertyFilter::PropertyFilter(const Symbol &symbol, PropertyIx property,
|
||||
Type type)
|
||||
PropertyFilter::PropertyFilter(const Symbol &symbol, PropertyIx property, Type type)
|
||||
: symbol_(symbol), property_(property), type_(type) {
|
||||
// As this constructor is used for property filters where
|
||||
// we don't have to evaluate the filter expression, we set
|
||||
@ -190,8 +169,7 @@ PropertyFilter::PropertyFilter(const Symbol &symbol, PropertyIx property,
|
||||
// we may be looking up.
|
||||
}
|
||||
|
||||
IdFilter::IdFilter(const SymbolTable &symbol_table, const Symbol &symbol,
|
||||
Expression *value)
|
||||
IdFilter::IdFilter(const SymbolTable &symbol_table, const Symbol &symbol, Expression *value)
|
||||
: symbol_(symbol), value_(value) {
|
||||
MG_ASSERT(value);
|
||||
UsedSymbolsCollector collector(symbol_table);
|
||||
@ -203,16 +181,12 @@ void Filters::EraseFilter(const FilterInfo &filter) {
|
||||
// TODO: Ideally, we want to determine the equality of both expression trees,
|
||||
// instead of a simple pointer compare.
|
||||
all_filters_.erase(std::remove_if(all_filters_.begin(), all_filters_.end(),
|
||||
[&filter](const auto &f) {
|
||||
return f.expression == filter.expression;
|
||||
}),
|
||||
[&filter](const auto &f) { return f.expression == filter.expression; }),
|
||||
all_filters_.end());
|
||||
}
|
||||
|
||||
void Filters::EraseLabelFilter(const Symbol &symbol, LabelIx label,
|
||||
std::vector<Expression *> *removed_filters) {
|
||||
for (auto filter_it = all_filters_.begin();
|
||||
filter_it != all_filters_.end();) {
|
||||
void Filters::EraseLabelFilter(const Symbol &symbol, LabelIx label, std::vector<Expression *> *removed_filters) {
|
||||
for (auto filter_it = all_filters_.begin(); filter_it != all_filters_.end();) {
|
||||
if (filter_it->type != FilterInfo::Type::Label) {
|
||||
++filter_it;
|
||||
continue;
|
||||
@ -221,15 +195,13 @@ void Filters::EraseLabelFilter(const Symbol &symbol, LabelIx label,
|
||||
++filter_it;
|
||||
continue;
|
||||
}
|
||||
auto label_it =
|
||||
std::find(filter_it->labels.begin(), filter_it->labels.end(), label);
|
||||
auto label_it = std::find(filter_it->labels.begin(), filter_it->labels.end(), label);
|
||||
if (label_it == filter_it->labels.end()) {
|
||||
++filter_it;
|
||||
continue;
|
||||
}
|
||||
filter_it->labels.erase(label_it);
|
||||
DMG_ASSERT(!utils::Contains(filter_it->labels, label),
|
||||
"Didn't expect duplicated labels");
|
||||
DMG_ASSERT(!utils::Contains(filter_it->labels, label), "Didn't expect duplicated labels");
|
||||
if (filter_it->labels.empty()) {
|
||||
// If there are no labels to filter, then erase the whole FilterInfo.
|
||||
if (removed_filters) {
|
||||
@ -242,8 +214,7 @@ void Filters::EraseLabelFilter(const Symbol &symbol, LabelIx label,
|
||||
}
|
||||
}
|
||||
|
||||
void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
|
||||
AstStorage &storage) {
|
||||
void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table, AstStorage &storage) {
|
||||
UsedSymbolsCollector collector(symbol_table);
|
||||
auto add_properties_variable = [&](EdgeAtom *atom) {
|
||||
const auto &symbol = symbol_table.at(*atom->identifier_);
|
||||
@ -256,18 +227,13 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
|
||||
{
|
||||
collector.symbols_.clear();
|
||||
prop_pair.second->Accept(collector);
|
||||
collector.symbols_.emplace(
|
||||
symbol_table.at(*atom->filter_lambda_.inner_node));
|
||||
collector.symbols_.emplace(
|
||||
symbol_table.at(*atom->filter_lambda_.inner_edge));
|
||||
collector.symbols_.emplace(symbol_table.at(*atom->filter_lambda_.inner_node));
|
||||
collector.symbols_.emplace(symbol_table.at(*atom->filter_lambda_.inner_edge));
|
||||
// First handle the inline property filter.
|
||||
auto *property_lookup = storage.Create<PropertyLookup>(
|
||||
atom->filter_lambda_.inner_edge, prop_pair.first);
|
||||
auto *prop_equal =
|
||||
storage.Create<EqualOperator>(property_lookup, prop_pair.second);
|
||||
auto *property_lookup = storage.Create<PropertyLookup>(atom->filter_lambda_.inner_edge, prop_pair.first);
|
||||
auto *prop_equal = storage.Create<EqualOperator>(property_lookup, prop_pair.second);
|
||||
// Currently, variable expand has no gains if we set PropertyFilter.
|
||||
all_filters_.emplace_back(FilterInfo{FilterInfo::Type::Generic,
|
||||
prop_equal, collector.symbols_});
|
||||
all_filters_.emplace_back(FilterInfo{FilterInfo::Type::Generic, prop_equal, collector.symbols_});
|
||||
}
|
||||
{
|
||||
collector.symbols_.clear();
|
||||
@ -275,23 +241,15 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
|
||||
collector.symbols_.insert(symbol); // PropertyLookup uses the symbol.
|
||||
// Now handle the post-expansion filter.
|
||||
// Create a new identifier and a symbol which will be filled in All.
|
||||
auto *identifier =
|
||||
storage
|
||||
.Create<Identifier>(atom->identifier_->name_,
|
||||
atom->identifier_->user_declared_)
|
||||
->MapTo(
|
||||
symbol_table.CreateSymbol(atom->identifier_->name_, false));
|
||||
auto *identifier = storage.Create<Identifier>(atom->identifier_->name_, atom->identifier_->user_declared_)
|
||||
->MapTo(symbol_table.CreateSymbol(atom->identifier_->name_, false));
|
||||
// Create an equality expression and store it in all_filters_.
|
||||
auto *property_lookup =
|
||||
storage.Create<PropertyLookup>(identifier, prop_pair.first);
|
||||
auto *prop_equal =
|
||||
storage.Create<EqualOperator>(property_lookup, prop_pair.second);
|
||||
auto *property_lookup = storage.Create<PropertyLookup>(identifier, prop_pair.first);
|
||||
auto *prop_equal = storage.Create<EqualOperator>(property_lookup, prop_pair.second);
|
||||
// Currently, variable expand has no gains if we set PropertyFilter.
|
||||
all_filters_.emplace_back(
|
||||
FilterInfo{FilterInfo::Type::Generic,
|
||||
storage.Create<All>(identifier, atom->identifier_,
|
||||
storage.Create<Where>(prop_equal)),
|
||||
collector.symbols_});
|
||||
all_filters_.emplace_back(FilterInfo{
|
||||
FilterInfo::Type::Generic,
|
||||
storage.Create<All>(identifier, atom->identifier_, storage.Create<Where>(prop_equal)), collector.symbols_});
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -299,17 +257,13 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
|
||||
const auto &symbol = symbol_table.at(*atom->identifier_);
|
||||
for (auto &prop_pair : atom->properties_) {
|
||||
// Create an equality expression and store it in all_filters_.
|
||||
auto *property_lookup =
|
||||
storage.Create<PropertyLookup>(atom->identifier_, prop_pair.first);
|
||||
auto *prop_equal =
|
||||
storage.Create<EqualOperator>(property_lookup, prop_pair.second);
|
||||
auto *property_lookup = storage.Create<PropertyLookup>(atom->identifier_, prop_pair.first);
|
||||
auto *prop_equal = storage.Create<EqualOperator>(property_lookup, prop_pair.second);
|
||||
collector.symbols_.clear();
|
||||
prop_equal->Accept(collector);
|
||||
FilterInfo filter_info{FilterInfo::Type::Property, prop_equal,
|
||||
collector.symbols_};
|
||||
FilterInfo filter_info{FilterInfo::Type::Property, prop_equal, collector.symbols_};
|
||||
// Store a PropertyFilter on the value of the property.
|
||||
filter_info.property_filter.emplace(symbol_table, symbol, prop_pair.first,
|
||||
prop_pair.second,
|
||||
filter_info.property_filter.emplace(symbol_table, symbol, prop_pair.first, prop_pair.second,
|
||||
PropertyFilter::Type::EQUAL);
|
||||
all_filters_.emplace_back(filter_info);
|
||||
}
|
||||
@ -318,10 +272,8 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
|
||||
const auto &node_symbol = symbol_table.at(*node->identifier_);
|
||||
if (!node->labels_.empty()) {
|
||||
// Create a LabelsTest and store it.
|
||||
auto *labels_test =
|
||||
storage.Create<LabelsTest>(node->identifier_, node->labels_);
|
||||
auto label_filter = FilterInfo{FilterInfo::Type::Label, labels_test,
|
||||
std::unordered_set<Symbol>{node_symbol}};
|
||||
auto *labels_test = storage.Create<LabelsTest>(node->identifier_, node->labels_);
|
||||
auto label_filter = FilterInfo{FilterInfo::Type::Label, labels_test, std::unordered_set<Symbol>{node_symbol}};
|
||||
label_filter.labels = node->labels_;
|
||||
all_filters_.emplace_back(label_filter);
|
||||
}
|
||||
@ -339,15 +291,13 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
|
||||
|
||||
// Adds the where filter expression to `all_filters_` and collects additional
|
||||
// information for potential property and label indexing.
|
||||
void Filters::CollectWhereFilter(Where &where,
|
||||
const SymbolTable &symbol_table) {
|
||||
void Filters::CollectWhereFilter(Where &where, const SymbolTable &symbol_table) {
|
||||
CollectFilterExpression(where.expression_, symbol_table);
|
||||
}
|
||||
|
||||
// Adds the expression to `all_filters_` and collects additional
|
||||
// information for potential property and label indexing.
|
||||
void Filters::CollectFilterExpression(Expression *expr,
|
||||
const SymbolTable &symbol_table) {
|
||||
void Filters::CollectFilterExpression(Expression *expr, const SymbolTable &symbol_table) {
|
||||
auto filters = SplitExpressionOnAnd(expr);
|
||||
for (const auto &filter : filters) {
|
||||
AnalyzeAndStoreFilter(filter, symbol_table);
|
||||
@ -356,16 +306,12 @@ void Filters::CollectFilterExpression(Expression *expr,
|
||||
|
||||
// Analyzes the filter expression by collecting information on filtering labels
|
||||
// and properties to be used with indexing.
|
||||
void Filters::AnalyzeAndStoreFilter(Expression *expr,
|
||||
const SymbolTable &symbol_table) {
|
||||
void Filters::AnalyzeAndStoreFilter(Expression *expr, const SymbolTable &symbol_table) {
|
||||
using Bound = PropertyFilter::Bound;
|
||||
UsedSymbolsCollector collector(symbol_table);
|
||||
expr->Accept(collector);
|
||||
auto make_filter = [&collector, &expr](FilterInfo::Type type) {
|
||||
return FilterInfo{type, expr, collector.symbols_};
|
||||
};
|
||||
auto get_property_lookup = [](auto *maybe_lookup, auto *&prop_lookup,
|
||||
auto *&ident) -> bool {
|
||||
auto make_filter = [&collector, &expr](FilterInfo::Type type) { return FilterInfo{type, expr, collector.symbols_}; };
|
||||
auto get_property_lookup = [](auto *maybe_lookup, auto *&prop_lookup, auto *&ident) -> bool {
|
||||
return (prop_lookup = utils::Downcast<PropertyLookup>(maybe_lookup)) &&
|
||||
(ident = utils::Downcast<Identifier>(prop_lookup->expression_));
|
||||
};
|
||||
@ -376,9 +322,8 @@ void Filters::AnalyzeAndStoreFilter(Expression *expr,
|
||||
Identifier *ident = nullptr;
|
||||
if (get_property_lookup(maybe_lookup, prop_lookup, ident)) {
|
||||
auto filter = make_filter(FilterInfo::Type::Property);
|
||||
filter.property_filter = PropertyFilter(
|
||||
symbol_table, symbol_table.at(*ident), prop_lookup->property_,
|
||||
val_expr, PropertyFilter::Type::EQUAL);
|
||||
filter.property_filter = PropertyFilter(symbol_table, symbol_table.at(*ident), prop_lookup->property_, val_expr,
|
||||
PropertyFilter::Type::EQUAL);
|
||||
all_filters_.emplace_back(filter);
|
||||
return true;
|
||||
}
|
||||
@ -390,9 +335,8 @@ void Filters::AnalyzeAndStoreFilter(Expression *expr,
|
||||
Identifier *ident = nullptr;
|
||||
if (get_property_lookup(maybe_lookup, prop_lookup, ident)) {
|
||||
auto filter = make_filter(FilterInfo::Type::Property);
|
||||
filter.property_filter = PropertyFilter(
|
||||
symbol_table, symbol_table.at(*ident), prop_lookup->property_,
|
||||
val_expr, PropertyFilter::Type::REGEX_MATCH);
|
||||
filter.property_filter = PropertyFilter(symbol_table, symbol_table.at(*ident), prop_lookup->property_, val_expr,
|
||||
PropertyFilter::Type::REGEX_MATCH);
|
||||
all_filters_.emplace_back(filter);
|
||||
return true;
|
||||
}
|
||||
@ -400,16 +344,14 @@ void Filters::AnalyzeAndStoreFilter(Expression *expr,
|
||||
};
|
||||
// Checks if either the expr1 and expr2 are property lookups, adds them as
|
||||
// PropertyFilter and returns true. Otherwise, returns false.
|
||||
auto add_prop_greater = [&](auto *expr1, auto *expr2,
|
||||
auto bound_type) -> bool {
|
||||
auto add_prop_greater = [&](auto *expr1, auto *expr2, auto bound_type) -> bool {
|
||||
PropertyLookup *prop_lookup = nullptr;
|
||||
Identifier *ident = nullptr;
|
||||
bool is_prop_filter = false;
|
||||
if (get_property_lookup(expr1, prop_lookup, ident)) {
|
||||
// n.prop > value
|
||||
auto filter = make_filter(FilterInfo::Type::Property);
|
||||
filter.property_filter.emplace(symbol_table, symbol_table.at(*ident),
|
||||
prop_lookup->property_,
|
||||
filter.property_filter.emplace(symbol_table, symbol_table.at(*ident), prop_lookup->property_,
|
||||
Bound(expr2, bound_type), std::nullopt);
|
||||
all_filters_.emplace_back(filter);
|
||||
is_prop_filter = true;
|
||||
@ -417,8 +359,7 @@ void Filters::AnalyzeAndStoreFilter(Expression *expr,
|
||||
if (get_property_lookup(expr2, prop_lookup, ident)) {
|
||||
// value > n.prop
|
||||
auto filter = make_filter(FilterInfo::Type::Property);
|
||||
filter.property_filter.emplace(symbol_table, symbol_table.at(*ident),
|
||||
prop_lookup->property_, std::nullopt,
|
||||
filter.property_filter.emplace(symbol_table, symbol_table.at(*ident), prop_lookup->property_, std::nullopt,
|
||||
Bound(expr1, bound_type));
|
||||
all_filters_.emplace_back(filter);
|
||||
is_prop_filter = true;
|
||||
@ -447,9 +388,8 @@ void Filters::AnalyzeAndStoreFilter(Expression *expr,
|
||||
Identifier *ident = nullptr;
|
||||
if (get_property_lookup(maybe_lookup, prop_lookup, ident)) {
|
||||
auto filter = make_filter(FilterInfo::Type::Property);
|
||||
filter.property_filter = PropertyFilter(
|
||||
symbol_table, symbol_table.at(*ident), prop_lookup->property_,
|
||||
val_expr, PropertyFilter::Type::IN);
|
||||
filter.property_filter = PropertyFilter(symbol_table, symbol_table.at(*ident), prop_lookup->property_, val_expr,
|
||||
PropertyFilter::Type::IN);
|
||||
all_filters_.emplace_back(filter);
|
||||
return true;
|
||||
}
|
||||
@ -466,30 +406,26 @@ void Filters::AnalyzeAndStoreFilter(Expression *expr,
|
||||
return false;
|
||||
}
|
||||
|
||||
auto *maybe_is_null_check =
|
||||
utils::Downcast<IsNullOperator>(maybe_is_not_null_check->expression_);
|
||||
auto *maybe_is_null_check = utils::Downcast<IsNullOperator>(maybe_is_not_null_check->expression_);
|
||||
if (!maybe_is_null_check) {
|
||||
return false;
|
||||
}
|
||||
PropertyLookup *prop_lookup = nullptr;
|
||||
Identifier *ident = nullptr;
|
||||
|
||||
if (!get_property_lookup(maybe_is_null_check->expression_, prop_lookup,
|
||||
ident)) {
|
||||
if (!get_property_lookup(maybe_is_null_check->expression_, prop_lookup, ident)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto filter = make_filter(FilterInfo::Type::Property);
|
||||
filter.property_filter =
|
||||
PropertyFilter(symbol_table.at(*ident), prop_lookup->property_,
|
||||
PropertyFilter::Type::IS_NOT_NULL);
|
||||
PropertyFilter(symbol_table.at(*ident), prop_lookup->property_, PropertyFilter::Type::IS_NOT_NULL);
|
||||
all_filters_.emplace_back(filter);
|
||||
return true;
|
||||
};
|
||||
// We are only interested to see the insides of And, because Or prevents
|
||||
// indexing since any labels and properties found there may be optional.
|
||||
DMG_ASSERT(!utils::IsSubtype(*expr, AndOperator::kType),
|
||||
"Expected AndOperators have been split.");
|
||||
DMG_ASSERT(!utils::IsSubtype(*expr, AndOperator::kType), "Expected AndOperators have been split.");
|
||||
if (auto *labels_test = utils::Downcast<LabelsTest>(expr)) {
|
||||
// Since LabelsTest may contain any expression, we can only use the
|
||||
// simplest test on an identifier.
|
||||
@ -527,25 +463,21 @@ void Filters::AnalyzeAndStoreFilter(Expression *expr,
|
||||
all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic));
|
||||
}
|
||||
} else if (auto *gt = utils::Downcast<GreaterOperator>(expr)) {
|
||||
if (!add_prop_greater(gt->expression1_, gt->expression2_,
|
||||
Bound::Type::EXCLUSIVE)) {
|
||||
if (!add_prop_greater(gt->expression1_, gt->expression2_, Bound::Type::EXCLUSIVE)) {
|
||||
all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic));
|
||||
}
|
||||
} else if (auto *ge = utils::Downcast<GreaterEqualOperator>(expr)) {
|
||||
if (!add_prop_greater(ge->expression1_, ge->expression2_,
|
||||
Bound::Type::INCLUSIVE)) {
|
||||
if (!add_prop_greater(ge->expression1_, ge->expression2_, Bound::Type::INCLUSIVE)) {
|
||||
all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic));
|
||||
}
|
||||
} else if (auto *lt = utils::Downcast<LessOperator>(expr)) {
|
||||
// Like greater, but in reverse.
|
||||
if (!add_prop_greater(lt->expression2_, lt->expression1_,
|
||||
Bound::Type::EXCLUSIVE)) {
|
||||
if (!add_prop_greater(lt->expression2_, lt->expression1_, Bound::Type::EXCLUSIVE)) {
|
||||
all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic));
|
||||
}
|
||||
} else if (auto *le = utils::Downcast<LessEqualOperator>(expr)) {
|
||||
// Like greater equal, but in reverse.
|
||||
if (!add_prop_greater(le->expression2_, le->expression1_,
|
||||
Bound::Type::INCLUSIVE)) {
|
||||
if (!add_prop_greater(le->expression2_, le->expression1_, Bound::Type::INCLUSIVE)) {
|
||||
all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic));
|
||||
}
|
||||
} else if (auto *in = utils::Downcast<InListOperator>(expr)) {
|
||||
@ -571,29 +503,25 @@ void Filters::AnalyzeAndStoreFilter(Expression *expr,
|
||||
|
||||
// Converts a Query to multiple QueryParts. In the process new Ast nodes may be
|
||||
// created, e.g. filter expressions.
|
||||
std::vector<SingleQueryPart> CollectSingleQueryParts(
|
||||
SymbolTable &symbol_table, AstStorage &storage, SingleQuery *single_query) {
|
||||
std::vector<SingleQueryPart> CollectSingleQueryParts(SymbolTable &symbol_table, AstStorage &storage,
|
||||
SingleQuery *single_query) {
|
||||
std::vector<SingleQueryPart> query_parts(1);
|
||||
auto *query_part = &query_parts.back();
|
||||
for (auto &clause : single_query->clauses_) {
|
||||
if (auto *match = utils::Downcast<Match>(clause)) {
|
||||
if (match->optional_) {
|
||||
query_part->optional_matching.emplace_back(Matching{});
|
||||
AddMatching(*match, symbol_table, storage,
|
||||
query_part->optional_matching.back());
|
||||
AddMatching(*match, symbol_table, storage, query_part->optional_matching.back());
|
||||
} else {
|
||||
DMG_ASSERT(query_part->optional_matching.empty(),
|
||||
"Match clause cannot follow optional match.");
|
||||
DMG_ASSERT(query_part->optional_matching.empty(), "Match clause cannot follow optional match.");
|
||||
AddMatching(*match, symbol_table, storage, query_part->matching);
|
||||
}
|
||||
} else {
|
||||
query_part->remaining_clauses.push_back(clause);
|
||||
if (auto *merge = utils::Downcast<query::Merge>(clause)) {
|
||||
query_part->merge_matching.emplace_back(Matching{});
|
||||
AddMatching({merge->pattern_}, nullptr, symbol_table, storage,
|
||||
query_part->merge_matching.back());
|
||||
} else if (utils::IsSubtype(*clause, With::kType) ||
|
||||
utils::IsSubtype(*clause, query::Unwind::kType) ||
|
||||
AddMatching({merge->pattern_}, nullptr, symbol_table, storage, query_part->merge_matching.back());
|
||||
} else if (utils::IsSubtype(*clause, With::kType) || utils::IsSubtype(*clause, query::Unwind::kType) ||
|
||||
utils::IsSubtype(*clause, query::CallProcedure::kType)) {
|
||||
// This query part is done, continue with a new one.
|
||||
query_parts.emplace_back(SingleQueryPart{});
|
||||
@ -606,14 +534,12 @@ std::vector<SingleQueryPart> CollectSingleQueryParts(
|
||||
return query_parts;
|
||||
}
|
||||
|
||||
QueryParts CollectQueryParts(SymbolTable &symbol_table, AstStorage &storage,
|
||||
CypherQuery *query) {
|
||||
QueryParts CollectQueryParts(SymbolTable &symbol_table, AstStorage &storage, CypherQuery *query) {
|
||||
std::vector<QueryPart> query_parts;
|
||||
|
||||
auto *single_query = query->single_query_;
|
||||
MG_ASSERT(single_query, "Expected at least a single query");
|
||||
query_parts.push_back(
|
||||
QueryPart{CollectSingleQueryParts(symbol_table, storage, single_query)});
|
||||
query_parts.push_back(QueryPart{CollectSingleQueryParts(symbol_table, storage, single_query)});
|
||||
|
||||
bool distinct = false;
|
||||
for (auto *cypher_union : query->cypher_unions_) {
|
||||
@ -623,9 +549,7 @@ QueryParts CollectQueryParts(SymbolTable &symbol_table, AstStorage &storage,
|
||||
|
||||
auto *single_query = cypher_union->single_query_;
|
||||
MG_ASSERT(single_query, "Expected UNION to have a query");
|
||||
query_parts.push_back(
|
||||
QueryPart{CollectSingleQueryParts(symbol_table, storage, single_query),
|
||||
cypher_union});
|
||||
query_parts.push_back(QueryPart{CollectSingleQueryParts(symbol_table, storage, single_query), cypher_union});
|
||||
}
|
||||
return QueryParts{query_parts, distinct};
|
||||
}
|
||||
|
@ -16,8 +16,7 @@ namespace query::plan {
|
||||
/// Collects symbols from identifiers found in visited AST nodes.
|
||||
class UsedSymbolsCollector : public HierarchicalTreeVisitor {
|
||||
public:
|
||||
explicit UsedSymbolsCollector(const SymbolTable &symbol_table)
|
||||
: symbol_table_(symbol_table) {}
|
||||
explicit UsedSymbolsCollector(const SymbolTable &symbol_table) : symbol_table_(symbol_table) {}
|
||||
|
||||
using HierarchicalTreeVisitor::PostVisit;
|
||||
using HierarchicalTreeVisitor::PreVisit;
|
||||
@ -99,11 +98,10 @@ class PropertyFilter {
|
||||
enum class Type { EQUAL, REGEX_MATCH, RANGE, IN, IS_NOT_NULL };
|
||||
|
||||
/// Construct with Expression being the equality or regex match check.
|
||||
PropertyFilter(const SymbolTable &, const Symbol &, PropertyIx, Expression *,
|
||||
Type);
|
||||
PropertyFilter(const SymbolTable &, const Symbol &, PropertyIx, Expression *, Type);
|
||||
/// Construct the range based filter.
|
||||
PropertyFilter(const SymbolTable &, const Symbol &, PropertyIx,
|
||||
const std::optional<Bound> &, const std::optional<Bound> &);
|
||||
PropertyFilter(const SymbolTable &, const Symbol &, PropertyIx, const std::optional<Bound> &,
|
||||
const std::optional<Bound> &);
|
||||
/// Construct a filter without an expression that produces a value.
|
||||
/// Used for the "PROP IS NOT NULL" filter, and can be used for any
|
||||
/// property filter that doesn't need to use an expression to produce
|
||||
@ -177,20 +175,14 @@ class Filters final {
|
||||
|
||||
auto erase(iterator pos) { return all_filters_.erase(pos); }
|
||||
auto erase(const_iterator pos) { return all_filters_.erase(pos); }
|
||||
auto erase(iterator first, iterator last) {
|
||||
return all_filters_.erase(first, last);
|
||||
}
|
||||
auto erase(const_iterator first, const_iterator last) {
|
||||
return all_filters_.erase(first, last);
|
||||
}
|
||||
auto erase(iterator first, iterator last) { return all_filters_.erase(first, last); }
|
||||
auto erase(const_iterator first, const_iterator last) { return all_filters_.erase(first, last); }
|
||||
|
||||
auto FilteredLabels(const Symbol &symbol) const {
|
||||
std::unordered_set<LabelIx> labels;
|
||||
for (const auto &filter : all_filters_) {
|
||||
if (filter.type == FilterInfo::Type::Label &&
|
||||
utils::Contains(filter.used_symbols, symbol)) {
|
||||
MG_ASSERT(filter.used_symbols.size() == 1U,
|
||||
"Expected a single used symbol for label filter");
|
||||
if (filter.type == FilterInfo::Type::Label && utils::Contains(filter.used_symbols, symbol)) {
|
||||
MG_ASSERT(filter.used_symbols.size() == 1U, "Expected a single used symbol for label filter");
|
||||
labels.insert(filter.labels.begin(), filter.labels.end());
|
||||
}
|
||||
}
|
||||
@ -205,15 +197,13 @@ class Filters final {
|
||||
/// Remove a label filter for symbol; may invalidate iterators.
|
||||
/// If removed_filters is not nullptr, fills the vector with original
|
||||
/// `Expression *` which are now completely removed.
|
||||
void EraseLabelFilter(const Symbol &, LabelIx,
|
||||
std::vector<Expression *> *removed_filters = nullptr);
|
||||
void EraseLabelFilter(const Symbol &, LabelIx, std::vector<Expression *> *removed_filters = nullptr);
|
||||
|
||||
/// Returns a vector of FilterInfo for properties.
|
||||
auto PropertyFilters(const Symbol &symbol) const {
|
||||
std::vector<FilterInfo> filters;
|
||||
for (const auto &filter : all_filters_) {
|
||||
if (filter.type == FilterInfo::Type::Property &&
|
||||
filter.property_filter->symbol_ == symbol) {
|
||||
if (filter.type == FilterInfo::Type::Property && filter.property_filter->symbol_ == symbol) {
|
||||
filters.push_back(filter);
|
||||
}
|
||||
}
|
||||
@ -224,8 +214,7 @@ class Filters final {
|
||||
auto IdFilters(const Symbol &symbol) const {
|
||||
std::vector<FilterInfo> filters;
|
||||
for (const auto &filter : all_filters_) {
|
||||
if (filter.type == FilterInfo::Type::Id &&
|
||||
filter.id_filter->symbol_ == symbol) {
|
||||
if (filter.type == FilterInfo::Type::Id && filter.id_filter->symbol_ == symbol) {
|
||||
filters.push_back(filter);
|
||||
}
|
||||
}
|
||||
|
@ -6,8 +6,7 @@
|
||||
|
||||
namespace query::plan {
|
||||
|
||||
PlanPrinter::PlanPrinter(const DbAccessor *dba, std::ostream *out)
|
||||
: dba_(dba), out_(out) {}
|
||||
PlanPrinter::PlanPrinter(const DbAccessor *dba, std::ostream *out) : dba_(dba), out_(out) {}
|
||||
|
||||
#define PRE_VISIT(TOp) \
|
||||
bool PlanPrinter::PreVisit(TOp &) { \
|
||||
@ -20,13 +19,10 @@ PRE_VISIT(CreateNode);
|
||||
bool PlanPrinter::PreVisit(CreateExpand &op) {
|
||||
WithPrintLn([&](auto &out) {
|
||||
out << "* CreateExpand (" << op.input_symbol_.name() << ")"
|
||||
<< (op.edge_info_.direction == query::EdgeAtom::Direction::IN ? "<-"
|
||||
: "-")
|
||||
<< "[" << op.edge_info_.symbol.name() << ":"
|
||||
<< dba_->EdgeTypeToName(op.edge_info_.edge_type) << "]"
|
||||
<< (op.edge_info_.direction == query::EdgeAtom::Direction::OUT ? "->"
|
||||
: "-")
|
||||
<< "(" << op.node_info_.symbol.name() << ")";
|
||||
<< (op.edge_info_.direction == query::EdgeAtom::Direction::IN ? "<-" : "-") << "["
|
||||
<< op.edge_info_.symbol.name() << ":" << dba_->EdgeTypeToName(op.edge_info_.edge_type) << "]"
|
||||
<< (op.edge_info_.direction == query::EdgeAtom::Direction::OUT ? "->" : "-") << "("
|
||||
<< op.node_info_.symbol.name() << ")";
|
||||
});
|
||||
return true;
|
||||
}
|
||||
@ -44,8 +40,7 @@ bool PlanPrinter::PreVisit(query::plan::ScanAll &op) {
|
||||
bool PlanPrinter::PreVisit(query::plan::ScanAllByLabel &op) {
|
||||
WithPrintLn([&](auto &out) {
|
||||
out << "* ScanAllByLabel"
|
||||
<< " (" << op.output_symbol_.name() << " :"
|
||||
<< dba_->LabelToName(op.label_) << ")";
|
||||
<< " (" << op.output_symbol_.name() << " :" << dba_->LabelToName(op.label_) << ")";
|
||||
});
|
||||
return true;
|
||||
}
|
||||
@ -53,8 +48,7 @@ bool PlanPrinter::PreVisit(query::plan::ScanAllByLabel &op) {
|
||||
bool PlanPrinter::PreVisit(query::plan::ScanAllByLabelPropertyValue &op) {
|
||||
WithPrintLn([&](auto &out) {
|
||||
out << "* ScanAllByLabelPropertyValue"
|
||||
<< " (" << op.output_symbol_.name() << " :"
|
||||
<< dba_->LabelToName(op.label_) << " {"
|
||||
<< " (" << op.output_symbol_.name() << " :" << dba_->LabelToName(op.label_) << " {"
|
||||
<< dba_->PropertyToName(op.property_) << "})";
|
||||
});
|
||||
return true;
|
||||
@ -63,8 +57,7 @@ bool PlanPrinter::PreVisit(query::plan::ScanAllByLabelPropertyValue &op) {
|
||||
bool PlanPrinter::PreVisit(query::plan::ScanAllByLabelPropertyRange &op) {
|
||||
WithPrintLn([&](auto &out) {
|
||||
out << "* ScanAllByLabelPropertyRange"
|
||||
<< " (" << op.output_symbol_.name() << " :"
|
||||
<< dba_->LabelToName(op.label_) << " {"
|
||||
<< " (" << op.output_symbol_.name() << " :" << dba_->LabelToName(op.label_) << " {"
|
||||
<< dba_->PropertyToName(op.property_) << "})";
|
||||
});
|
||||
return true;
|
||||
@ -73,8 +66,7 @@ bool PlanPrinter::PreVisit(query::plan::ScanAllByLabelPropertyRange &op) {
|
||||
bool PlanPrinter::PreVisit(query::plan::ScanAllByLabelProperty &op) {
|
||||
WithPrintLn([&](auto &out) {
|
||||
out << "* ScanAllByLabelProperty"
|
||||
<< " (" << op.output_symbol_.name() << " :"
|
||||
<< dba_->LabelToName(op.label_) << " {"
|
||||
<< " (" << op.output_symbol_.name() << " :" << dba_->LabelToName(op.label_) << " {"
|
||||
<< dba_->PropertyToName(op.property_) << "})";
|
||||
});
|
||||
return true;
|
||||
@ -91,17 +83,13 @@ bool PlanPrinter::PreVisit(ScanAllById &op) {
|
||||
bool PlanPrinter::PreVisit(query::plan::Expand &op) {
|
||||
WithPrintLn([&](auto &out) {
|
||||
*out_ << "* Expand (" << op.input_symbol_.name() << ")"
|
||||
<< (op.common_.direction == query::EdgeAtom::Direction::IN ? "<-"
|
||||
: "-")
|
||||
<< "[" << op.common_.edge_symbol.name();
|
||||
utils::PrintIterable(*out_, op.common_.edge_types, "|",
|
||||
[this](auto &stream, const auto &edge_type) {
|
||||
<< (op.common_.direction == query::EdgeAtom::Direction::IN ? "<-" : "-") << "["
|
||||
<< op.common_.edge_symbol.name();
|
||||
utils::PrintIterable(*out_, op.common_.edge_types, "|", [this](auto &stream, const auto &edge_type) {
|
||||
stream << ":" << dba_->EdgeTypeToName(edge_type);
|
||||
});
|
||||
*out_ << "]"
|
||||
<< (op.common_.direction == query::EdgeAtom::Direction::OUT ? "->"
|
||||
: "-")
|
||||
<< "(" << op.common_.node_symbol.name() << ")";
|
||||
*out_ << "]" << (op.common_.direction == query::EdgeAtom::Direction::OUT ? "->" : "-") << "("
|
||||
<< op.common_.node_symbol.name() << ")";
|
||||
});
|
||||
return true;
|
||||
}
|
||||
@ -124,17 +112,13 @@ bool PlanPrinter::PreVisit(query::plan::ExpandVariable &op) {
|
||||
LOG_FATAL("Unexpected ExpandVariable::type_");
|
||||
}
|
||||
*out_ << " (" << op.input_symbol_.name() << ")"
|
||||
<< (op.common_.direction == query::EdgeAtom::Direction::IN ? "<-"
|
||||
: "-")
|
||||
<< "[" << op.common_.edge_symbol.name();
|
||||
utils::PrintIterable(*out_, op.common_.edge_types, "|",
|
||||
[this](auto &stream, const auto &edge_type) {
|
||||
<< (op.common_.direction == query::EdgeAtom::Direction::IN ? "<-" : "-") << "["
|
||||
<< op.common_.edge_symbol.name();
|
||||
utils::PrintIterable(*out_, op.common_.edge_types, "|", [this](auto &stream, const auto &edge_type) {
|
||||
stream << ":" << dba_->EdgeTypeToName(edge_type);
|
||||
});
|
||||
*out_ << "]"
|
||||
<< (op.common_.direction == query::EdgeAtom::Direction::OUT ? "->"
|
||||
: "-")
|
||||
<< "(" << op.common_.node_symbol.name() << ")";
|
||||
*out_ << "]" << (op.common_.direction == query::EdgeAtom::Direction::OUT ? "->" : "-") << "("
|
||||
<< op.common_.node_symbol.name() << ")";
|
||||
});
|
||||
return true;
|
||||
}
|
||||
@ -142,9 +126,7 @@ bool PlanPrinter::PreVisit(query::plan::ExpandVariable &op) {
|
||||
bool PlanPrinter::PreVisit(query::plan::Produce &op) {
|
||||
WithPrintLn([&](auto &out) {
|
||||
out << "* Produce {";
|
||||
utils::PrintIterable(
|
||||
out, op.named_expressions_, ", ",
|
||||
[](auto &out, const auto &nexpr) { out << nexpr->name_; });
|
||||
utils::PrintIterable(out, op.named_expressions_, ", ", [](auto &out, const auto &nexpr) { out << nexpr->name_; });
|
||||
out << "}";
|
||||
});
|
||||
return true;
|
||||
@ -163,12 +145,10 @@ PRE_VISIT(Accumulate);
|
||||
bool PlanPrinter::PreVisit(query::plan::Aggregate &op) {
|
||||
WithPrintLn([&](auto &out) {
|
||||
out << "* Aggregate {";
|
||||
utils::PrintIterable(
|
||||
out, op.aggregations_, ", ",
|
||||
utils::PrintIterable(out, op.aggregations_, ", ",
|
||||
[](auto &out, const auto &aggr) { out << aggr.output_sym.name(); });
|
||||
out << "} {";
|
||||
utils::PrintIterable(out, op.remember_, ", ",
|
||||
[](auto &out, const auto &sym) { out << sym.name(); });
|
||||
utils::PrintIterable(out, op.remember_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
|
||||
out << "}";
|
||||
});
|
||||
return true;
|
||||
@ -180,8 +160,7 @@ PRE_VISIT(Limit);
|
||||
bool PlanPrinter::PreVisit(query::plan::OrderBy &op) {
|
||||
WithPrintLn([&op](auto &out) {
|
||||
out << "* OrderBy {";
|
||||
utils::PrintIterable(out, op.output_symbols_, ", ",
|
||||
[](auto &out, const auto &sym) { out << sym.name(); });
|
||||
utils::PrintIterable(out, op.output_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
|
||||
out << "}";
|
||||
});
|
||||
return true;
|
||||
@ -208,11 +187,9 @@ PRE_VISIT(Distinct);
|
||||
bool PlanPrinter::PreVisit(query::plan::Union &op) {
|
||||
WithPrintLn([&op](auto &out) {
|
||||
out << "* Union {";
|
||||
utils::PrintIterable(out, op.left_symbols_, ", ",
|
||||
[](auto &out, const auto &sym) { out << sym.name(); });
|
||||
utils::PrintIterable(out, op.left_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
|
||||
out << " : ";
|
||||
utils::PrintIterable(out, op.right_symbols_, ", ",
|
||||
[](auto &out, const auto &sym) { out << sym.name(); });
|
||||
utils::PrintIterable(out, op.right_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
|
||||
out << "}";
|
||||
});
|
||||
Branch(*op.right_op_);
|
||||
@ -223,8 +200,7 @@ bool PlanPrinter::PreVisit(query::plan::Union &op) {
|
||||
bool PlanPrinter::PreVisit(query::plan::CallProcedure &op) {
|
||||
WithPrintLn([&op](auto &out) {
|
||||
out << "* CallProcedure<" << op.procedure_name_ << "> {";
|
||||
utils::PrintIterable(out, op.result_symbols_, ", ",
|
||||
[](auto &out, const auto &sym) { out << sym.name(); });
|
||||
utils::PrintIterable(out, op.result_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
|
||||
out << "}";
|
||||
});
|
||||
return true;
|
||||
@ -238,11 +214,9 @@ bool PlanPrinter::Visit(query::plan::Once &op) {
|
||||
bool PlanPrinter::PreVisit(query::plan::Cartesian &op) {
|
||||
WithPrintLn([&op](auto &out) {
|
||||
out << "* Cartesian {";
|
||||
utils::PrintIterable(out, op.left_symbols_, ", ",
|
||||
[](auto &out, const auto &sym) { out << sym.name(); });
|
||||
utils::PrintIterable(out, op.left_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
|
||||
out << " : ";
|
||||
utils::PrintIterable(out, op.right_symbols_, ", ",
|
||||
[](auto &out, const auto &sym) { out << sym.name(); });
|
||||
utils::PrintIterable(out, op.right_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
|
||||
out << "}";
|
||||
});
|
||||
Branch(*op.right_op_);
|
||||
@ -257,23 +231,20 @@ bool PlanPrinter::DefaultPreVisit() {
|
||||
return true;
|
||||
}
|
||||
|
||||
void PlanPrinter::Branch(query::plan::LogicalOperator &op,
|
||||
const std::string &branch_name) {
|
||||
void PlanPrinter::Branch(query::plan::LogicalOperator &op, const std::string &branch_name) {
|
||||
WithPrintLn([&](auto &out) { out << "|\\ " << branch_name; });
|
||||
++depth_;
|
||||
op.Accept(*this);
|
||||
--depth_;
|
||||
}
|
||||
|
||||
void PrettyPrint(const DbAccessor &dba, const LogicalOperator *plan_root,
|
||||
std::ostream *out) {
|
||||
void PrettyPrint(const DbAccessor &dba, const LogicalOperator *plan_root, std::ostream *out) {
|
||||
PlanPrinter printer(&dba, out);
|
||||
// FIXME(mtomic): We should make visitors that take const arguments.
|
||||
const_cast<LogicalOperator *>(plan_root)->Accept(printer);
|
||||
}
|
||||
|
||||
nlohmann::json PlanToJson(const DbAccessor &dba,
|
||||
const LogicalOperator *plan_root) {
|
||||
nlohmann::json PlanToJson(const DbAccessor &dba, const LogicalOperator *plan_root) {
|
||||
impl::PlanToJsonVisitor visitor(&dba);
|
||||
// FIXME(mtomic): We should make visitors that take const arguments.
|
||||
const_cast<LogicalOperator *>(plan_root)->Accept(visitor);
|
||||
@ -352,17 +323,11 @@ json ToJson(const utils::Bound<Expression *> &bound) {
|
||||
|
||||
json ToJson(const Symbol &symbol) { return symbol.name(); }
|
||||
|
||||
json ToJson(storage::EdgeTypeId edge_type, const DbAccessor &dba) {
|
||||
return dba.EdgeTypeToName(edge_type);
|
||||
}
|
||||
json ToJson(storage::EdgeTypeId edge_type, const DbAccessor &dba) { return dba.EdgeTypeToName(edge_type); }
|
||||
|
||||
json ToJson(storage::LabelId label, const DbAccessor &dba) {
|
||||
return dba.LabelToName(label);
|
||||
}
|
||||
json ToJson(storage::LabelId label, const DbAccessor &dba) { return dba.LabelToName(label); }
|
||||
|
||||
json ToJson(storage::PropertyId property, const DbAccessor &dba) {
|
||||
return dba.PropertyToName(property);
|
||||
}
|
||||
json ToJson(storage::PropertyId property, const DbAccessor &dba) { return dba.PropertyToName(property); }
|
||||
|
||||
json ToJson(NamedExpression *nexpr) {
|
||||
json json;
|
||||
@ -371,9 +336,7 @@ json ToJson(NamedExpression *nexpr) {
|
||||
return json;
|
||||
}
|
||||
|
||||
json ToJson(
|
||||
const std::vector<std::pair<storage::PropertyId, Expression *>> &properties,
|
||||
const DbAccessor &dba) {
|
||||
json ToJson(const std::vector<std::pair<storage::PropertyId, Expression *>> &properties, const DbAccessor &dba) {
|
||||
json json;
|
||||
for (const auto &prop_pair : properties) {
|
||||
json.emplace(ToJson(prop_pair.first, dba), ToJson(prop_pair.second));
|
||||
@ -558,9 +521,7 @@ bool PlanToJsonVisitor::PreVisit(ExpandVariable &op) {
|
||||
self["upper_bound"] = op.upper_bound_ ? ToJson(op.upper_bound_) : json();
|
||||
self["existing_node"] = op.common_.existing_node;
|
||||
|
||||
self["filter_lambda"] = op.filter_lambda_.expression
|
||||
? ToJson(op.filter_lambda_.expression)
|
||||
: json();
|
||||
self["filter_lambda"] = op.filter_lambda_.expression ? ToJson(op.filter_lambda_.expression) : json();
|
||||
|
||||
if (op.type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) {
|
||||
self["weight_lambda"] = ToJson(op.weight_lambda_->expression);
|
||||
|
@ -19,19 +19,16 @@ class LogicalOperator;
|
||||
/// DbAccessor is needed for resolving label and property names.
|
||||
/// Note that `plan_root` isn't modified, but we can't take it as a const
|
||||
/// because we don't have support for visiting a const LogicalOperator.
|
||||
void PrettyPrint(const DbAccessor &dba, const LogicalOperator *plan_root,
|
||||
std::ostream *out);
|
||||
void PrettyPrint(const DbAccessor &dba, const LogicalOperator *plan_root, std::ostream *out);
|
||||
|
||||
/// Overload of `PrettyPrint` which defaults the `std::ostream` to `std::cout`.
|
||||
inline void PrettyPrint(const DbAccessor &dba,
|
||||
const LogicalOperator *plan_root) {
|
||||
inline void PrettyPrint(const DbAccessor &dba, const LogicalOperator *plan_root) {
|
||||
PrettyPrint(dba, plan_root, &std::cout);
|
||||
}
|
||||
|
||||
/// Convert a `LogicalOperator` plan to a JSON representation.
|
||||
/// DbAccessor is needed for resolving label and property names.
|
||||
nlohmann::json PlanToJson(const DbAccessor &dba,
|
||||
const LogicalOperator *plan_root);
|
||||
nlohmann::json PlanToJson(const DbAccessor &dba, const LogicalOperator *plan_root);
|
||||
|
||||
class PlanPrinter : public virtual HierarchicalLogicalOperatorVisitor {
|
||||
public:
|
||||
@ -130,8 +127,7 @@ nlohmann::json ToJson(storage::PropertyId property, const DbAccessor &dba);
|
||||
|
||||
nlohmann::json ToJson(NamedExpression *nexpr);
|
||||
|
||||
nlohmann::json ToJson(
|
||||
const std::vector<std::pair<storage::PropertyId, Expression *>> &properties,
|
||||
nlohmann::json ToJson(const std::vector<std::pair<storage::PropertyId, Expression *>> &properties,
|
||||
const DbAccessor &dba);
|
||||
|
||||
nlohmann::json ToJson(const NodeCreationInfo &node_info, const DbAccessor &dba);
|
||||
@ -141,7 +137,7 @@ nlohmann::json ToJson(const EdgeCreationInfo &edge_info, const DbAccessor &dba);
|
||||
nlohmann::json ToJson(const Aggregate::Element &elem);
|
||||
|
||||
template <class T, class... Args>
|
||||
nlohmann::json ToJson(const std::vector<T> &items, Args &&... args) {
|
||||
nlohmann::json ToJson(const std::vector<T> &items, Args &&...args) {
|
||||
nlohmann::json json;
|
||||
for (const auto &item : items) {
|
||||
json.emplace_back(ToJson(item, std::forward<Args>(args)...));
|
||||
|
@ -14,23 +14,18 @@ namespace query::plan {
|
||||
namespace {
|
||||
|
||||
unsigned long long IndividualCycles(const ProfilingStats &cumulative_stats) {
|
||||
return cumulative_stats.num_cycles -
|
||||
std::accumulate(
|
||||
cumulative_stats.children.begin(), cumulative_stats.children.end(),
|
||||
0ULL,
|
||||
return cumulative_stats.num_cycles - std::accumulate(cumulative_stats.children.begin(),
|
||||
cumulative_stats.children.end(), 0ULL,
|
||||
[](auto acc, auto &stats) { return acc + stats.num_cycles; });
|
||||
}
|
||||
|
||||
double RelativeTime(unsigned long long num_cycles,
|
||||
unsigned long long total_cycles) {
|
||||
double RelativeTime(unsigned long long num_cycles, unsigned long long total_cycles) {
|
||||
return static_cast<double>(num_cycles) / total_cycles;
|
||||
}
|
||||
|
||||
double AbsoluteTime(unsigned long long num_cycles,
|
||||
unsigned long long total_cycles,
|
||||
double AbsoluteTime(unsigned long long num_cycles, unsigned long long total_cycles,
|
||||
std::chrono::duration<double> total_time) {
|
||||
return (RelativeTime(num_cycles, total_cycles) *
|
||||
static_cast<std::chrono::duration<double, std::milli>>(total_time))
|
||||
return (RelativeTime(num_cycles, total_cycles) * static_cast<std::chrono::duration<double, std::milli>>(total_time))
|
||||
.count();
|
||||
}
|
||||
|
||||
@ -44,18 +39,15 @@ namespace {
|
||||
|
||||
class ProfilingStatsToTableHelper {
|
||||
public:
|
||||
ProfilingStatsToTableHelper(unsigned long long total_cycles,
|
||||
std::chrono::duration<double> total_time)
|
||||
ProfilingStatsToTableHelper(unsigned long long total_cycles, std::chrono::duration<double> total_time)
|
||||
: total_cycles_(total_cycles), total_time_(total_time) {}
|
||||
|
||||
void Output(const ProfilingStats &cumulative_stats) {
|
||||
auto cycles = IndividualCycles(cumulative_stats);
|
||||
|
||||
rows_.emplace_back(std::vector<TypedValue>{
|
||||
TypedValue(FormatOperator(cumulative_stats.name)),
|
||||
TypedValue(cumulative_stats.actual_hits),
|
||||
TypedValue(FormatRelativeTime(cycles)),
|
||||
TypedValue(FormatAbsoluteTime(cycles))});
|
||||
TypedValue(FormatOperator(cumulative_stats.name)), TypedValue(cumulative_stats.actual_hits),
|
||||
TypedValue(FormatRelativeTime(cycles)), TypedValue(FormatAbsoluteTime(cycles))});
|
||||
|
||||
for (size_t i = 1; i < cumulative_stats.children.size(); ++i) {
|
||||
Branch(cumulative_stats.children[i]);
|
||||
@ -70,8 +62,7 @@ class ProfilingStatsToTableHelper {
|
||||
|
||||
private:
|
||||
void Branch(const ProfilingStats &cumulative_stats) {
|
||||
rows_.emplace_back(std::vector<TypedValue>{
|
||||
TypedValue("|\\"), TypedValue(""), TypedValue(""), TypedValue("")});
|
||||
rows_.emplace_back(std::vector<TypedValue>{TypedValue("|\\"), TypedValue(""), TypedValue(""), TypedValue("")});
|
||||
|
||||
++depth_;
|
||||
Output(cumulative_stats);
|
||||
@ -89,18 +80,14 @@ class ProfilingStatsToTableHelper {
|
||||
|
||||
std::string Format(const std::string &str) { return Format(str.c_str()); }
|
||||
|
||||
std::string FormatOperator(const char *str) {
|
||||
return Format(std::string("* ") + str);
|
||||
}
|
||||
std::string FormatOperator(const char *str) { return Format(std::string("* ") + str); }
|
||||
|
||||
std::string FormatRelativeTime(unsigned long long num_cycles) {
|
||||
return fmt::format("{: 10.6f} %",
|
||||
RelativeTime(num_cycles, total_cycles_) * 100);
|
||||
return fmt::format("{: 10.6f} %", RelativeTime(num_cycles, total_cycles_) * 100);
|
||||
}
|
||||
|
||||
std::string FormatAbsoluteTime(unsigned long long num_cycles) {
|
||||
return fmt::format("{: 10.6f} ms",
|
||||
AbsoluteTime(num_cycles, total_cycles_, total_time_));
|
||||
return fmt::format("{: 10.6f} ms", AbsoluteTime(num_cycles, total_cycles_, total_time_));
|
||||
}
|
||||
|
||||
int64_t depth_{0};
|
||||
@ -111,8 +98,7 @@ class ProfilingStatsToTableHelper {
|
||||
|
||||
} // namespace
|
||||
|
||||
std::vector<std::vector<TypedValue>> ProfilingStatsToTable(
|
||||
const ProfilingStats &cumulative_stats,
|
||||
std::vector<std::vector<TypedValue>> ProfilingStatsToTable(const ProfilingStats &cumulative_stats,
|
||||
std::chrono::duration<double> total_time) {
|
||||
ProfilingStatsToTableHelper helper{cumulative_stats.num_cycles, total_time};
|
||||
helper.Output(cumulative_stats);
|
||||
@ -130,13 +116,10 @@ class ProfilingStatsToJsonHelper {
|
||||
using json = nlohmann::json;
|
||||
|
||||
public:
|
||||
ProfilingStatsToJsonHelper(unsigned long long total_cycles,
|
||||
std::chrono::duration<double> total_time)
|
||||
ProfilingStatsToJsonHelper(unsigned long long total_cycles, std::chrono::duration<double> total_time)
|
||||
: total_cycles_(total_cycles), total_time_(total_time) {}
|
||||
|
||||
void Output(const ProfilingStats &cumulative_stats) {
|
||||
return Output(cumulative_stats, &json_);
|
||||
}
|
||||
void Output(const ProfilingStats &cumulative_stats) { return Output(cumulative_stats, &json_); }
|
||||
|
||||
json ToJson() { return json_; }
|
||||
|
||||
@ -147,8 +130,7 @@ class ProfilingStatsToJsonHelper {
|
||||
obj->emplace("name", cumulative_stats.name);
|
||||
obj->emplace("actual_hits", cumulative_stats.actual_hits);
|
||||
obj->emplace("relative_time", RelativeTime(cycles, total_cycles_));
|
||||
obj->emplace("absolute_time",
|
||||
AbsoluteTime(cycles, total_cycles_, total_time_));
|
||||
obj->emplace("absolute_time", AbsoluteTime(cycles, total_cycles_, total_time_));
|
||||
obj->emplace("children", json::array());
|
||||
|
||||
for (size_t i = 0; i < cumulative_stats.children.size(); ++i) {
|
||||
@ -165,8 +147,7 @@ class ProfilingStatsToJsonHelper {
|
||||
|
||||
} // namespace
|
||||
|
||||
nlohmann::json ProfilingStatsToJson(const ProfilingStats &cumulative_stats,
|
||||
std::chrono::duration<double> total_time) {
|
||||
nlohmann::json ProfilingStatsToJson(const ProfilingStats &cumulative_stats, std::chrono::duration<double> total_time) {
|
||||
ProfilingStatsToJsonHelper helper{cumulative_stats.num_cycles, total_time};
|
||||
helper.Output(cumulative_stats);
|
||||
return helper.ToJson();
|
||||
|
@ -23,11 +23,10 @@ struct ProfilingStats {
|
||||
std::vector<ProfilingStats> children;
|
||||
};
|
||||
|
||||
std::vector<std::vector<TypedValue>> ProfilingStatsToTable(
|
||||
const ProfilingStats &cumulative_stats, std::chrono::duration<double>);
|
||||
|
||||
nlohmann::json ProfilingStatsToJson(const ProfilingStats &cumulative_stats,
|
||||
std::vector<std::vector<TypedValue>> ProfilingStatsToTable(const ProfilingStats &cumulative_stats,
|
||||
std::chrono::duration<double>);
|
||||
|
||||
nlohmann::json ProfilingStatsToJson(const ProfilingStats &cumulative_stats, std::chrono::duration<double>);
|
||||
|
||||
} // namespace plan
|
||||
} // namespace query
|
||||
|
@ -81,9 +81,7 @@ void ReadWriteTypeChecker::UpdateType(RWType op_type) {
|
||||
}
|
||||
}
|
||||
|
||||
void ReadWriteTypeChecker::InferRWType(LogicalOperator &root) {
|
||||
root.Accept(*this);
|
||||
}
|
||||
void ReadWriteTypeChecker::InferRWType(LogicalOperator &root) { root.Accept(*this); }
|
||||
|
||||
std::string ReadWriteTypeChecker::TypeToString(const RWType type) {
|
||||
switch (type) {
|
||||
|
@ -2,8 +2,7 @@
|
||||
|
||||
#include "utils/flag_validation.hpp"
|
||||
|
||||
DEFINE_VALIDATED_HIDDEN_int64(
|
||||
query_vertex_count_to_expand_existing, 10,
|
||||
DEFINE_VALIDATED_HIDDEN_int64(query_vertex_count_to_expand_existing, 10,
|
||||
"Maximum count of indexed vertices which provoke "
|
||||
"indexed lookup and then expand to existing, instead of "
|
||||
"a regular expand. Default is 10, to turn off use -1.",
|
||||
@ -11,8 +10,7 @@ DEFINE_VALIDATED_HIDDEN_int64(
|
||||
|
||||
namespace query::plan::impl {
|
||||
|
||||
Expression *RemoveAndExpressions(
|
||||
Expression *expr, const std::unordered_set<Expression *> &exprs_to_remove) {
|
||||
Expression *RemoveAndExpressions(Expression *expr, const std::unordered_set<Expression *> &exprs_to_remove) {
|
||||
auto *and_op = utils::Downcast<AndOperator>(expr);
|
||||
if (!and_op) return expr;
|
||||
if (utils::Contains(exprs_to_remove, and_op)) {
|
||||
@ -24,10 +22,8 @@ Expression *RemoveAndExpressions(
|
||||
if (utils::Contains(exprs_to_remove, and_op->expression2_)) {
|
||||
and_op->expression2_ = nullptr;
|
||||
}
|
||||
and_op->expression1_ =
|
||||
RemoveAndExpressions(and_op->expression1_, exprs_to_remove);
|
||||
and_op->expression2_ =
|
||||
RemoveAndExpressions(and_op->expression2_, exprs_to_remove);
|
||||
and_op->expression1_ = RemoveAndExpressions(and_op->expression1_, exprs_to_remove);
|
||||
and_op->expression2_ = RemoveAndExpressions(and_op->expression2_, exprs_to_remove);
|
||||
if (!and_op->expression1_ && !and_op->expression2_) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -25,14 +25,12 @@ namespace impl {
|
||||
|
||||
// Return the new root expression after removing the given expressions from the
|
||||
// given expression tree.
|
||||
Expression *RemoveAndExpressions(
|
||||
Expression *expr, const std::unordered_set<Expression *> &exprs_to_remove);
|
||||
Expression *RemoveAndExpressions(Expression *expr, const std::unordered_set<Expression *> &exprs_to_remove);
|
||||
|
||||
template <class TDbAccessor>
|
||||
class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
|
||||
public:
|
||||
IndexLookupRewriter(SymbolTable *symbol_table, AstStorage *ast_storage,
|
||||
TDbAccessor *db)
|
||||
IndexLookupRewriter(SymbolTable *symbol_table, AstStorage *ast_storage, TDbAccessor *db)
|
||||
: symbol_table_(symbol_table), ast_storage_(ast_storage), db_(db) {}
|
||||
|
||||
using HierarchicalLogicalOperatorVisitor::PostVisit;
|
||||
@ -52,10 +50,8 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
|
||||
// free the memory.
|
||||
bool PostVisit(Filter &op) override {
|
||||
prev_ops_.pop_back();
|
||||
op.expression_ =
|
||||
RemoveAndExpressions(op.expression_, filter_exprs_for_removal_);
|
||||
if (!op.expression_ ||
|
||||
utils::Contains(filter_exprs_for_removal_, op.expression_)) {
|
||||
op.expression_ = RemoveAndExpressions(op.expression_, filter_exprs_for_removal_);
|
||||
if (!op.expression_ || utils::Contains(filter_exprs_for_removal_, op.expression_)) {
|
||||
SetOnParent(op.input());
|
||||
}
|
||||
return true;
|
||||
@ -91,8 +87,7 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
|
||||
return true;
|
||||
}
|
||||
ScanAll dst_scan(expand.input(), expand.common_.node_symbol, expand.view_);
|
||||
auto indexed_scan =
|
||||
GenScanByIndex(dst_scan, FLAGS_query_vertex_count_to_expand_existing);
|
||||
auto indexed_scan = GenScanByIndex(dst_scan, FLAGS_query_vertex_count_to_expand_existing);
|
||||
if (indexed_scan) {
|
||||
expand.set_input(std::move(indexed_scan));
|
||||
expand.common_.existing_node = true;
|
||||
@ -113,8 +108,7 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
|
||||
return true;
|
||||
}
|
||||
std::unique_ptr<ScanAll> indexed_scan;
|
||||
ScanAll dst_scan(expand.input(), expand.common_.node_symbol,
|
||||
storage::View::OLD);
|
||||
ScanAll dst_scan(expand.input(), expand.common_.node_symbol, storage::View::OLD);
|
||||
// With expand to existing we only get real gains with BFS, because we use a
|
||||
// different algorithm then, so prefer expand to existing.
|
||||
if (expand.type_ == EdgeAtom::Type::BREADTH_FIRST) {
|
||||
@ -122,8 +116,7 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
|
||||
// unconditionally creating an indexed scan.
|
||||
indexed_scan = GenScanByIndex(dst_scan);
|
||||
} else {
|
||||
indexed_scan =
|
||||
GenScanByIndex(dst_scan, FLAGS_query_vertex_count_to_expand_existing);
|
||||
indexed_scan = GenScanByIndex(dst_scan, FLAGS_query_vertex_count_to_expand_existing);
|
||||
}
|
||||
if (indexed_scan) {
|
||||
expand.set_input(std::move(indexed_scan));
|
||||
@ -449,9 +442,7 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
|
||||
int64_t vertex_count;
|
||||
};
|
||||
|
||||
bool DefaultPreVisit() override {
|
||||
throw utils::NotYetImplemented("optimizing index lookup");
|
||||
}
|
||||
bool DefaultPreVisit() override { throw utils::NotYetImplemented("optimizing index lookup"); }
|
||||
|
||||
void SetOnParent(const std::shared_ptr<LogicalOperator> &input) {
|
||||
MG_ASSERT(input);
|
||||
@ -471,18 +462,12 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
|
||||
}
|
||||
}
|
||||
|
||||
storage::LabelId GetLabel(LabelIx label) {
|
||||
return db_->NameToLabel(label.name);
|
||||
}
|
||||
storage::LabelId GetLabel(LabelIx label) { return db_->NameToLabel(label.name); }
|
||||
|
||||
storage::PropertyId GetProperty(PropertyIx prop) {
|
||||
return db_->NameToProperty(prop.name);
|
||||
}
|
||||
storage::PropertyId GetProperty(PropertyIx prop) { return db_->NameToProperty(prop.name); }
|
||||
|
||||
std::optional<LabelIx> FindBestLabelIndex(
|
||||
const std::unordered_set<LabelIx> &labels) {
|
||||
MG_ASSERT(!labels.empty(),
|
||||
"Trying to find the best label without any labels.");
|
||||
std::optional<LabelIx> FindBestLabelIndex(const std::unordered_set<LabelIx> &labels) {
|
||||
MG_ASSERT(!labels.empty(), "Trying to find the best label without any labels.");
|
||||
std::optional<LabelIx> best_label;
|
||||
for (const auto &label : labels) {
|
||||
if (!db_->LabelIndexExists(GetLabel(label))) continue;
|
||||
@ -490,17 +475,15 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
|
||||
best_label = label;
|
||||
continue;
|
||||
}
|
||||
if (db_->VerticesCount(GetLabel(label)) <
|
||||
db_->VerticesCount(GetLabel(*best_label)))
|
||||
best_label = label;
|
||||
if (db_->VerticesCount(GetLabel(label)) < db_->VerticesCount(GetLabel(*best_label))) best_label = label;
|
||||
}
|
||||
return best_label;
|
||||
}
|
||||
|
||||
// Finds the label-property combination which has indexed the lowest amount of
|
||||
// vertices. If the index cannot be found, nullopt is returned.
|
||||
std::optional<LabelPropertyIndex> FindBestLabelPropertyIndex(
|
||||
const Symbol &symbol, const std::unordered_set<Symbol> &bound_symbols) {
|
||||
std::optional<LabelPropertyIndex> FindBestLabelPropertyIndex(const Symbol &symbol,
|
||||
const std::unordered_set<Symbol> &bound_symbols) {
|
||||
auto are_bound = [&bound_symbols](const auto &used_symbols) {
|
||||
for (const auto &used_symbol : used_symbols) {
|
||||
if (!utils::Contains(bound_symbols, used_symbol)) {
|
||||
@ -512,8 +495,7 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
|
||||
std::optional<LabelPropertyIndex> found;
|
||||
for (const auto &label : filters_.FilteredLabels(symbol)) {
|
||||
for (const auto &filter : filters_.PropertyFilters(symbol)) {
|
||||
if (filter.property_filter->is_symbol_in_value_ ||
|
||||
!are_bound(filter.used_symbols)) {
|
||||
if (filter.property_filter->is_symbol_in_value_ || !are_bound(filter.used_symbols)) {
|
||||
// Skip filter expressions which use the symbol whose property we are
|
||||
// looking up or aren't bound. We cannot scan by such expressions. For
|
||||
// example, in `n.a = 2 + n.b` both sides of `=` refer to `n`, so we
|
||||
@ -521,27 +503,20 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
|
||||
continue;
|
||||
}
|
||||
const auto &property = filter.property_filter->property_;
|
||||
if (!db_->LabelPropertyIndexExists(GetLabel(label),
|
||||
GetProperty(property))) {
|
||||
if (!db_->LabelPropertyIndexExists(GetLabel(label), GetProperty(property))) {
|
||||
continue;
|
||||
}
|
||||
int64_t vertex_count =
|
||||
db_->VerticesCount(GetLabel(label), GetProperty(property));
|
||||
int64_t vertex_count = db_->VerticesCount(GetLabel(label), GetProperty(property));
|
||||
auto is_better_type = [&found](PropertyFilter::Type type) {
|
||||
// Order the types by the most preferred index lookup type.
|
||||
static const PropertyFilter::Type kFilterTypeOrder[] = {
|
||||
PropertyFilter::Type::EQUAL, PropertyFilter::Type::RANGE,
|
||||
PropertyFilter::Type::REGEX_MATCH};
|
||||
auto *found_sort_ix =
|
||||
std::find(kFilterTypeOrder, kFilterTypeOrder + 3,
|
||||
found->filter.property_filter->type_);
|
||||
auto *type_sort_ix =
|
||||
std::find(kFilterTypeOrder, kFilterTypeOrder + 3, type);
|
||||
PropertyFilter::Type::EQUAL, PropertyFilter::Type::RANGE, PropertyFilter::Type::REGEX_MATCH};
|
||||
auto *found_sort_ix = std::find(kFilterTypeOrder, kFilterTypeOrder + 3, found->filter.property_filter->type_);
|
||||
auto *type_sort_ix = std::find(kFilterTypeOrder, kFilterTypeOrder + 3, type);
|
||||
return type_sort_ix < found_sort_ix;
|
||||
};
|
||||
if (!found || vertex_count < found->vertex_count ||
|
||||
(vertex_count == found->vertex_count &&
|
||||
is_better_type(filter.property_filter->type_))) {
|
||||
(vertex_count == found->vertex_count && is_better_type(filter.property_filter->type_))) {
|
||||
found = LabelPropertyIndex{label, filter, vertex_count};
|
||||
}
|
||||
}
|
||||
@ -556,15 +531,13 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
|
||||
// `max_vertex_count` controls, whether no operator should be created if the
|
||||
// vertex count in the best index exceeds this number. In such a case,
|
||||
// `nullptr` is returned and `input` is not chained.
|
||||
std::unique_ptr<ScanAll> GenScanByIndex(
|
||||
const ScanAll &scan,
|
||||
std::unique_ptr<ScanAll> GenScanByIndex(const ScanAll &scan,
|
||||
const std::optional<int64_t> &max_vertex_count = std::nullopt) {
|
||||
const auto &input = scan.input();
|
||||
const auto &node_symbol = scan.output_symbol_;
|
||||
const auto &view = scan.view_;
|
||||
const auto &modified_symbols = scan.ModifiedSymbols(*symbol_table_);
|
||||
std::unordered_set<Symbol> bound_symbols(modified_symbols.begin(),
|
||||
modified_symbols.end());
|
||||
std::unordered_set<Symbol> bound_symbols(modified_symbols.begin(), modified_symbols.end());
|
||||
auto are_bound = [&bound_symbols](const auto &used_symbols) {
|
||||
for (const auto &used_symbol : used_symbols) {
|
||||
if (!utils::Contains(bound_symbols, used_symbol)) {
|
||||
@ -576,9 +549,7 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
|
||||
// First, try to see if we can find a vertex by ID.
|
||||
if (!max_vertex_count || *max_vertex_count >= 1) {
|
||||
for (const auto &filter : filters_.IdFilters(node_symbol)) {
|
||||
if (filter.id_filter->is_symbol_in_value_ ||
|
||||
!are_bound(filter.used_symbols))
|
||||
continue;
|
||||
if (filter.id_filter->is_symbol_in_value_ || !are_bound(filter.used_symbols)) continue;
|
||||
auto *value = filter.id_filter->value_;
|
||||
filter_exprs_for_removal_.insert(filter.expression);
|
||||
filters_.EraseFilter(filter);
|
||||
@ -606,76 +577,62 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
|
||||
}
|
||||
filters_.EraseFilter(found_index->filter);
|
||||
std::vector<Expression *> removed_expressions;
|
||||
filters_.EraseLabelFilter(node_symbol, found_index->label,
|
||||
&removed_expressions);
|
||||
filter_exprs_for_removal_.insert(removed_expressions.begin(),
|
||||
removed_expressions.end());
|
||||
filters_.EraseLabelFilter(node_symbol, found_index->label, &removed_expressions);
|
||||
filter_exprs_for_removal_.insert(removed_expressions.begin(), removed_expressions.end());
|
||||
if (prop_filter.lower_bound_ || prop_filter.upper_bound_) {
|
||||
return std::make_unique<ScanAllByLabelPropertyRange>(
|
||||
input, node_symbol, GetLabel(found_index->label),
|
||||
GetProperty(prop_filter.property_), prop_filter.property_.name,
|
||||
prop_filter.lower_bound_, prop_filter.upper_bound_, view);
|
||||
input, node_symbol, GetLabel(found_index->label), GetProperty(prop_filter.property_),
|
||||
prop_filter.property_.name, prop_filter.lower_bound_, prop_filter.upper_bound_, view);
|
||||
} else if (prop_filter.type_ == PropertyFilter::Type::REGEX_MATCH) {
|
||||
// Generate index scan using the empty string as a lower bound.
|
||||
Expression *empty_string = ast_storage_->Create<PrimitiveLiteral>("");
|
||||
auto lower_bound = utils::MakeBoundInclusive(empty_string);
|
||||
return std::make_unique<ScanAllByLabelPropertyRange>(
|
||||
input, node_symbol, GetLabel(found_index->label),
|
||||
GetProperty(prop_filter.property_), prop_filter.property_.name,
|
||||
std::make_optional(lower_bound), std::nullopt, view);
|
||||
input, node_symbol, GetLabel(found_index->label), GetProperty(prop_filter.property_),
|
||||
prop_filter.property_.name, std::make_optional(lower_bound), std::nullopt, view);
|
||||
} else if (prop_filter.type_ == PropertyFilter::Type::IN) {
|
||||
// TODO(buda): ScanAllByLabelProperty + Filter should be considered
|
||||
// here once the operator and the right cardinality estimation exist.
|
||||
auto const &symbol = symbol_table_->CreateAnonymousSymbol();
|
||||
auto *expression = ast_storage_->Create<Identifier>(symbol.name_);
|
||||
expression->MapTo(symbol);
|
||||
auto unwind_operator =
|
||||
std::make_unique<Unwind>(input, prop_filter.value_, symbol);
|
||||
auto unwind_operator = std::make_unique<Unwind>(input, prop_filter.value_, symbol);
|
||||
return std::make_unique<ScanAllByLabelPropertyValue>(
|
||||
std::move(unwind_operator), node_symbol,
|
||||
GetLabel(found_index->label), GetProperty(prop_filter.property_),
|
||||
std::move(unwind_operator), node_symbol, GetLabel(found_index->label), GetProperty(prop_filter.property_),
|
||||
prop_filter.property_.name, expression, view);
|
||||
} else if (prop_filter.type_ == PropertyFilter::Type::IS_NOT_NULL) {
|
||||
return std::make_unique<ScanAllByLabelProperty>(
|
||||
input, node_symbol, GetLabel(found_index->label),
|
||||
return std::make_unique<ScanAllByLabelProperty>(input, node_symbol, GetLabel(found_index->label),
|
||||
GetProperty(prop_filter.property_), prop_filter.property_.name,
|
||||
view);
|
||||
} else {
|
||||
MG_ASSERT(
|
||||
prop_filter.value_,
|
||||
"Property filter should either have bounds or a value expression.");
|
||||
return std::make_unique<ScanAllByLabelPropertyValue>(
|
||||
input, node_symbol, GetLabel(found_index->label),
|
||||
GetProperty(prop_filter.property_), prop_filter.property_.name,
|
||||
prop_filter.value_, view);
|
||||
MG_ASSERT(prop_filter.value_, "Property filter should either have bounds or a value expression.");
|
||||
return std::make_unique<ScanAllByLabelPropertyValue>(input, node_symbol, GetLabel(found_index->label),
|
||||
GetProperty(prop_filter.property_),
|
||||
prop_filter.property_.name, prop_filter.value_, view);
|
||||
}
|
||||
}
|
||||
auto maybe_label = FindBestLabelIndex(labels);
|
||||
if (!maybe_label) return nullptr;
|
||||
const auto &label = *maybe_label;
|
||||
if (max_vertex_count &&
|
||||
db_->VerticesCount(GetLabel(label)) > *max_vertex_count) {
|
||||
if (max_vertex_count && db_->VerticesCount(GetLabel(label)) > *max_vertex_count) {
|
||||
// Don't create an indexed lookup, since we have more labeled vertices
|
||||
// than the allowed count.
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<Expression *> removed_expressions;
|
||||
filters_.EraseLabelFilter(node_symbol, label, &removed_expressions);
|
||||
filter_exprs_for_removal_.insert(removed_expressions.begin(),
|
||||
removed_expressions.end());
|
||||
return std::make_unique<ScanAllByLabel>(input, node_symbol, GetLabel(label),
|
||||
view);
|
||||
filter_exprs_for_removal_.insert(removed_expressions.begin(), removed_expressions.end());
|
||||
return std::make_unique<ScanAllByLabel>(input, node_symbol, GetLabel(label), view);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace impl
|
||||
|
||||
template <class TDbAccessor>
|
||||
std::unique_ptr<LogicalOperator> RewriteWithIndexLookup(
|
||||
std::unique_ptr<LogicalOperator> root_op, SymbolTable *symbol_table,
|
||||
AstStorage *ast_storage, TDbAccessor *db) {
|
||||
impl::IndexLookupRewriter<TDbAccessor> rewriter(symbol_table, ast_storage,
|
||||
db);
|
||||
std::unique_ptr<LogicalOperator> RewriteWithIndexLookup(std::unique_ptr<LogicalOperator> root_op,
|
||||
SymbolTable *symbol_table, AstStorage *ast_storage,
|
||||
TDbAccessor *db) {
|
||||
impl::IndexLookupRewriter<TDbAccessor> rewriter(symbol_table, ast_storage, db);
|
||||
root_op->Accept(rewriter);
|
||||
if (rewriter.new_root_) {
|
||||
// This shouldn't happen in real use case, because IndexLookupRewriter
|
||||
|
@ -14,8 +14,7 @@ namespace query::plan {
|
||||
|
||||
namespace {
|
||||
|
||||
bool HasBoundFilterSymbols(const std::unordered_set<Symbol> &bound_symbols,
|
||||
const FilterInfo &filter) {
|
||||
bool HasBoundFilterSymbols(const std::unordered_set<Symbol> &bound_symbols, const FilterInfo &filter) {
|
||||
for (const auto &symbol : filter.used_symbols) {
|
||||
if (bound_symbols.find(symbol) == bound_symbols.end()) {
|
||||
return false;
|
||||
@ -37,14 +36,9 @@ bool HasBoundFilterSymbols(const std::unordered_set<Symbol> &bound_symbols,
|
||||
// aggregations and expressions used for group by.
|
||||
class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
public:
|
||||
ReturnBodyContext(const ReturnBody &body, SymbolTable &symbol_table,
|
||||
const std::unordered_set<Symbol> &bound_symbols,
|
||||
ReturnBodyContext(const ReturnBody &body, SymbolTable &symbol_table, const std::unordered_set<Symbol> &bound_symbols,
|
||||
AstStorage &storage, Where *where = nullptr)
|
||||
: body_(body),
|
||||
symbol_table_(symbol_table),
|
||||
bound_symbols_(bound_symbols),
|
||||
storage_(storage),
|
||||
where_(where) {
|
||||
: body_(body), symbol_table_(symbol_table), bound_symbols_(bound_symbols), storage_(storage), where_(where) {
|
||||
// Collect symbols from named expressions.
|
||||
output_symbols_.reserve(body_.named_expressions.size());
|
||||
if (body.all_identifiers) {
|
||||
@ -78,8 +72,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
if (where) {
|
||||
where->Accept(*this);
|
||||
}
|
||||
MG_ASSERT(aggregations_.empty(),
|
||||
"Unexpected aggregations in ORDER BY or WHERE");
|
||||
MG_ASSERT(aggregations_.empty(), "Unexpected aggregations in ORDER BY or WHERE");
|
||||
}
|
||||
}
|
||||
|
||||
@ -94,8 +87,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
|
||||
private:
|
||||
template <typename TLiteral, typename TIteratorToExpression>
|
||||
void PostVisitCollectionLiteral(
|
||||
TLiteral &literal, TIteratorToExpression iterator_to_expression) {
|
||||
void PostVisitCollectionLiteral(TLiteral &literal, TIteratorToExpression iterator_to_expression) {
|
||||
// If there is an aggregation in the list, and there are group-bys, then we
|
||||
// need to add the group-bys manually. If there are no aggregations, the
|
||||
// whole list will be added as a group-by.
|
||||
@ -115,8 +107,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
}
|
||||
has_aggregation_.emplace_back(has_aggr);
|
||||
if (has_aggr) {
|
||||
for (auto expression_ptr : literal_group_by)
|
||||
group_by_.emplace_back(expression_ptr);
|
||||
for (auto expression_ptr : literal_group_by) group_by_.emplace_back(expression_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
@ -130,8 +121,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
}
|
||||
|
||||
bool PostVisit(MapLiteral &map_literal) override {
|
||||
MG_ASSERT(
|
||||
map_literal.elements_.size() <= has_aggregation_.size(),
|
||||
MG_ASSERT(map_literal.elements_.size() <= has_aggregation_.size(),
|
||||
"Expected has_aggregation_ flags as much as there are map elements.");
|
||||
PostVisitCollectionLiteral(map_literal, [](auto it) { return it->second; });
|
||||
return true;
|
||||
@ -141,8 +131,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
// Remove the symbol which is bound by all, because we are only interested
|
||||
// in free (unbound) symbols.
|
||||
used_symbols_.erase(symbol_table_.at(*all.identifier_));
|
||||
MG_ASSERT(has_aggregation_.size() >= 3U,
|
||||
"Expected 3 has_aggregation_ flags for ALL arguments");
|
||||
MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for ALL arguments");
|
||||
bool has_aggr = false;
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
has_aggr = has_aggr || has_aggregation_.back();
|
||||
@ -156,8 +145,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
// Remove the symbol which is bound by single, because we are only
|
||||
// interested in free (unbound) symbols.
|
||||
used_symbols_.erase(symbol_table_.at(*single.identifier_));
|
||||
MG_ASSERT(has_aggregation_.size() >= 3U,
|
||||
"Expected 3 has_aggregation_ flags for SINGLE arguments");
|
||||
MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for SINGLE arguments");
|
||||
bool has_aggr = false;
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
has_aggr = has_aggr || has_aggregation_.back();
|
||||
@ -171,8 +159,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
// Remove the symbol which is bound by any, because we are only interested
|
||||
// in free (unbound) symbols.
|
||||
used_symbols_.erase(symbol_table_.at(*any.identifier_));
|
||||
MG_ASSERT(has_aggregation_.size() >= 3U,
|
||||
"Expected 3 has_aggregation_ flags for ANY arguments");
|
||||
MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for ANY arguments");
|
||||
bool has_aggr = false;
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
has_aggr = has_aggr || has_aggregation_.back();
|
||||
@ -186,8 +173,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
// Remove the symbol which is bound by none, because we are only interested
|
||||
// in free (unbound) symbols.
|
||||
used_symbols_.erase(symbol_table_.at(*none.identifier_));
|
||||
MG_ASSERT(has_aggregation_.size() >= 3U,
|
||||
"Expected 3 has_aggregation_ flags for NONE arguments");
|
||||
MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for NONE arguments");
|
||||
bool has_aggr = false;
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
has_aggr = has_aggr || has_aggregation_.back();
|
||||
@ -202,8 +188,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
// in free (unbound) symbols.
|
||||
used_symbols_.erase(symbol_table_.at(*reduce.accumulator_));
|
||||
used_symbols_.erase(symbol_table_.at(*reduce.identifier_));
|
||||
MG_ASSERT(has_aggregation_.size() >= 5U,
|
||||
"Expected 5 has_aggregation_ flags for REDUCE arguments");
|
||||
MG_ASSERT(has_aggregation_.size() >= 5U, "Expected 5 has_aggregation_ flags for REDUCE arguments");
|
||||
bool has_aggr = false;
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
has_aggr = has_aggr || has_aggregation_.back();
|
||||
@ -215,8 +200,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
|
||||
bool PostVisit(Coalesce &coalesce) override {
|
||||
MG_ASSERT(has_aggregation_.size() >= coalesce.expressions_.size(),
|
||||
"Expected >= {} has_aggregation_ flags for COALESCE arguments",
|
||||
has_aggregation_.size());
|
||||
"Expected >= {} has_aggregation_ flags for COALESCE arguments", has_aggregation_.size());
|
||||
bool has_aggr = false;
|
||||
for (size_t i = 0; i < coalesce.expressions_.size(); ++i) {
|
||||
has_aggr = has_aggr || has_aggregation_.back();
|
||||
@ -230,8 +214,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
// Remove the symbol bound by extract, because we are only interested
|
||||
// in free (unbound) symbols.
|
||||
used_symbols_.erase(symbol_table_.at(*extract.identifier_));
|
||||
MG_ASSERT(has_aggregation_.size() >= 3U,
|
||||
"Expected 3 has_aggregation_ flags for EXTRACT arguments");
|
||||
MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for EXTRACT arguments");
|
||||
bool has_aggr = false;
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
has_aggr = has_aggr || has_aggregation_.back();
|
||||
@ -310,8 +293,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
|
||||
#define VISIT_BINARY_OPERATOR(BinaryOperator) \
|
||||
bool PostVisit(BinaryOperator &op) override { \
|
||||
MG_ASSERT(has_aggregation_.size() >= 2U, \
|
||||
"Expected at least 2 has_aggregation_ flags."); \
|
||||
MG_ASSERT(has_aggregation_.size() >= 2U, "Expected at least 2 has_aggregation_ flags."); \
|
||||
/* has_aggregation_ stack is reversed, last result is from the 2nd */ \
|
||||
/* expression. */ \
|
||||
bool aggr2 = has_aggregation_.back(); \
|
||||
@ -351,8 +333,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
bool PostVisit(Aggregation &aggr) override {
|
||||
// Aggregation contains a virtual symbol, where the result will be stored.
|
||||
const auto &symbol = symbol_table_.at(aggr);
|
||||
aggregations_.emplace_back(Aggregate::Element{
|
||||
aggr.expression1_, aggr.expression2_, aggr.op_, symbol});
|
||||
aggregations_.emplace_back(Aggregate::Element{aggr.expression1_, aggr.expression2_, aggr.op_, symbol});
|
||||
// Aggregation expression1_ is optional in COUNT(*), and COLLECT_MAP uses
|
||||
// two expressions, so we can have 0, 1 or 2 elements on the
|
||||
// has_aggregation_stack for this Aggregation expression.
|
||||
@ -368,8 +349,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
}
|
||||
|
||||
bool PostVisit(NamedExpression &named_expr) override {
|
||||
MG_ASSERT(has_aggregation_.size() == 1U,
|
||||
"Expected to reduce has_aggregation_ to single boolean.");
|
||||
MG_ASSERT(has_aggregation_.size() == 1U, "Expected to reduce has_aggregation_ to single boolean.");
|
||||
if (!has_aggregation_.back()) {
|
||||
group_by_.emplace_back(named_expr.expression_);
|
||||
}
|
||||
@ -383,8 +363,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
}
|
||||
|
||||
bool PostVisit(RegexMatch ®ex_match) override {
|
||||
MG_ASSERT(has_aggregation_.size() >= 2U,
|
||||
"Expected 2 has_aggregation_ flags for RegexMatch arguments");
|
||||
MG_ASSERT(has_aggregation_.size() >= 2U, "Expected 2 has_aggregation_ flags for RegexMatch arguments");
|
||||
bool has_aggr = has_aggregation_.back();
|
||||
has_aggregation_.pop_back();
|
||||
has_aggregation_.back() |= has_aggr;
|
||||
@ -395,17 +374,14 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
// This should be used when body.all_identifiers is true, to generate
|
||||
// expressions for Produce operator.
|
||||
void ExpandUserSymbols() {
|
||||
MG_ASSERT(named_expressions_.empty(),
|
||||
"ExpandUserSymbols should be first to fill named_expressions_");
|
||||
MG_ASSERT(output_symbols_.empty(),
|
||||
"ExpandUserSymbols should be first to fill output_symbols_");
|
||||
MG_ASSERT(named_expressions_.empty(), "ExpandUserSymbols should be first to fill named_expressions_");
|
||||
MG_ASSERT(output_symbols_.empty(), "ExpandUserSymbols should be first to fill output_symbols_");
|
||||
for (const auto &symbol : bound_symbols_) {
|
||||
if (!symbol.user_declared()) {
|
||||
continue;
|
||||
}
|
||||
auto *ident = storage_.Create<Identifier>(symbol.name())->MapTo(symbol);
|
||||
auto *named_expr =
|
||||
storage_.Create<NamedExpression>(symbol.name(), ident)->MapTo(symbol);
|
||||
auto *named_expr = storage_.Create<NamedExpression>(symbol.name(), ident)->MapTo(symbol);
|
||||
// Fill output expressions and symbols with expanded identifiers.
|
||||
named_expressions_.emplace_back(named_expr);
|
||||
output_symbols_.emplace_back(symbol);
|
||||
@ -471,38 +447,30 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
std::vector<NamedExpression *> named_expressions_;
|
||||
};
|
||||
|
||||
std::unique_ptr<LogicalOperator> GenReturnBody(
|
||||
std::unique_ptr<LogicalOperator> input_op, bool advance_command,
|
||||
std::unique_ptr<LogicalOperator> GenReturnBody(std::unique_ptr<LogicalOperator> input_op, bool advance_command,
|
||||
const ReturnBodyContext &body, bool accumulate = false) {
|
||||
std::vector<Symbol> used_symbols(body.used_symbols().begin(),
|
||||
body.used_symbols().end());
|
||||
std::vector<Symbol> used_symbols(body.used_symbols().begin(), body.used_symbols().end());
|
||||
auto last_op = std::move(input_op);
|
||||
if (accumulate) {
|
||||
// We only advance the command in Accumulate. This is done for WITH clause,
|
||||
// when the first part updated the database. RETURN clause may only need an
|
||||
// accumulation after updates, without advancing the command.
|
||||
last_op = std::make_unique<Accumulate>(std::move(last_op), used_symbols,
|
||||
advance_command);
|
||||
last_op = std::make_unique<Accumulate>(std::move(last_op), used_symbols, advance_command);
|
||||
}
|
||||
if (!body.aggregations().empty()) {
|
||||
// When we have aggregation, SKIP/LIMIT should always come after it.
|
||||
std::vector<Symbol> remember(body.group_by_used_symbols().begin(),
|
||||
body.group_by_used_symbols().end());
|
||||
last_op = std::make_unique<Aggregate>(
|
||||
std::move(last_op), body.aggregations(), body.group_by(), remember);
|
||||
std::vector<Symbol> remember(body.group_by_used_symbols().begin(), body.group_by_used_symbols().end());
|
||||
last_op = std::make_unique<Aggregate>(std::move(last_op), body.aggregations(), body.group_by(), remember);
|
||||
}
|
||||
last_op =
|
||||
std::make_unique<Produce>(std::move(last_op), body.named_expressions());
|
||||
last_op = std::make_unique<Produce>(std::move(last_op), body.named_expressions());
|
||||
// Distinct in ReturnBody only makes Produce values unique, so plan after it.
|
||||
if (body.distinct()) {
|
||||
last_op =
|
||||
std::make_unique<Distinct>(std::move(last_op), body.output_symbols());
|
||||
last_op = std::make_unique<Distinct>(std::move(last_op), body.output_symbols());
|
||||
}
|
||||
// Like Where, OrderBy can read from symbols established by named expressions
|
||||
// in Produce, so it must come after it.
|
||||
if (!body.order_by().empty()) {
|
||||
last_op = std::make_unique<OrderBy>(std::move(last_op), body.order_by(),
|
||||
body.output_symbols());
|
||||
last_op = std::make_unique<OrderBy>(std::move(last_op), body.order_by(), body.output_symbols());
|
||||
}
|
||||
// Finally, Skip and Limit must come after OrderBy.
|
||||
if (body.skip()) {
|
||||
@ -515,8 +483,7 @@ std::unique_ptr<LogicalOperator> GenReturnBody(
|
||||
// Where may see new symbols so it comes after we generate Produce and in
|
||||
// general, comes after any OrderBy, Skip or Limit.
|
||||
if (body.where()) {
|
||||
last_op =
|
||||
std::make_unique<Filter>(std::move(last_op), body.where()->expression_);
|
||||
last_op = std::make_unique<Filter>(std::move(last_op), body.where()->expression_);
|
||||
}
|
||||
return last_op;
|
||||
}
|
||||
@ -525,13 +492,11 @@ std::unique_ptr<LogicalOperator> GenReturnBody(
|
||||
|
||||
namespace impl {
|
||||
|
||||
Expression *ExtractFilters(const std::unordered_set<Symbol> &bound_symbols,
|
||||
Filters &filters, AstStorage &storage) {
|
||||
Expression *ExtractFilters(const std::unordered_set<Symbol> &bound_symbols, Filters &filters, AstStorage &storage) {
|
||||
Expression *filter_expr = nullptr;
|
||||
for (auto filters_it = filters.begin(); filters_it != filters.end();) {
|
||||
if (HasBoundFilterSymbols(bound_symbols, *filters_it)) {
|
||||
filter_expr = impl::BoolJoin<AndOperator>(storage, filter_expr,
|
||||
filters_it->expression);
|
||||
filter_expr = impl::BoolJoin<AndOperator>(storage, filter_expr, filters_it->expression);
|
||||
filters_it = filters.erase(filters_it);
|
||||
} else {
|
||||
filters_it++;
|
||||
@ -540,8 +505,7 @@ Expression *ExtractFilters(const std::unordered_set<Symbol> &bound_symbols,
|
||||
return filter_expr;
|
||||
}
|
||||
|
||||
std::unique_ptr<LogicalOperator> GenFilters(
|
||||
std::unique_ptr<LogicalOperator> last_op,
|
||||
std::unique_ptr<LogicalOperator> GenFilters(std::unique_ptr<LogicalOperator> last_op,
|
||||
const std::unordered_set<Symbol> &bound_symbols, Filters &filters,
|
||||
AstStorage &storage) {
|
||||
auto *filter_expr = ExtractFilters(bound_symbols, filters, storage);
|
||||
@ -551,8 +515,7 @@ std::unique_ptr<LogicalOperator> GenFilters(
|
||||
return last_op;
|
||||
}
|
||||
|
||||
std::unique_ptr<LogicalOperator> GenNamedPaths(
|
||||
std::unique_ptr<LogicalOperator> last_op,
|
||||
std::unique_ptr<LogicalOperator> GenNamedPaths(std::unique_ptr<LogicalOperator> last_op,
|
||||
std::unordered_set<Symbol> &bound_symbols,
|
||||
std::unordered_map<Symbol, std::vector<Symbol>> &named_paths) {
|
||||
auto all_are_bound = [&bound_symbols](const std::vector<Symbol> &syms) {
|
||||
@ -560,11 +523,9 @@ std::unique_ptr<LogicalOperator> GenNamedPaths(
|
||||
if (bound_symbols.find(sym) == bound_symbols.end()) return false;
|
||||
return true;
|
||||
};
|
||||
for (auto named_path_it = named_paths.begin();
|
||||
named_path_it != named_paths.end();) {
|
||||
for (auto named_path_it = named_paths.begin(); named_path_it != named_paths.end();) {
|
||||
if (all_are_bound(named_path_it->second)) {
|
||||
last_op = std::make_unique<ConstructNamedPath>(
|
||||
std::move(last_op), named_path_it->first,
|
||||
last_op = std::make_unique<ConstructNamedPath>(std::move(last_op), named_path_it->first,
|
||||
std::move(named_path_it->second));
|
||||
bound_symbols.insert(named_path_it->first);
|
||||
named_path_it = named_paths.erase(named_path_it);
|
||||
@ -576,8 +537,7 @@ std::unique_ptr<LogicalOperator> GenNamedPaths(
|
||||
return last_op;
|
||||
}
|
||||
|
||||
std::unique_ptr<LogicalOperator> GenReturn(
|
||||
Return &ret, std::unique_ptr<LogicalOperator> input_op,
|
||||
std::unique_ptr<LogicalOperator> GenReturn(Return &ret, std::unique_ptr<LogicalOperator> input_op,
|
||||
SymbolTable &symbol_table, bool is_write,
|
||||
const std::unordered_set<Symbol> &bound_symbols, AstStorage &storage) {
|
||||
// Similar to WITH clause, but we want to accumulate when the query writes to
|
||||
@ -592,8 +552,7 @@ std::unique_ptr<LogicalOperator> GenReturn(
|
||||
return GenReturnBody(std::move(input_op), advance_command, body, accumulate);
|
||||
}
|
||||
|
||||
std::unique_ptr<LogicalOperator> GenWith(
|
||||
With &with, std::unique_ptr<LogicalOperator> input_op,
|
||||
std::unique_ptr<LogicalOperator> GenWith(With &with, std::unique_ptr<LogicalOperator> input_op,
|
||||
SymbolTable &symbol_table, bool is_write,
|
||||
std::unordered_set<Symbol> &bound_symbols, AstStorage &storage) {
|
||||
// WITH clause is Accumulate/Aggregate (advance_command) + Produce and
|
||||
@ -603,10 +562,8 @@ std::unique_ptr<LogicalOperator> GenWith(
|
||||
bool accumulate = is_write;
|
||||
// No need to advance the command if we only performed reads.
|
||||
bool advance_command = is_write;
|
||||
ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage,
|
||||
with.where_);
|
||||
auto last_op =
|
||||
GenReturnBody(std::move(input_op), advance_command, body, accumulate);
|
||||
ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage, with.where_);
|
||||
auto last_op = GenReturnBody(std::move(input_op), advance_command, body, accumulate);
|
||||
// Reset bound symbols, so that only those in WITH are exposed.
|
||||
bound_symbols.clear();
|
||||
for (const auto &symbol : body.output_symbols()) {
|
||||
@ -615,11 +572,9 @@ std::unique_ptr<LogicalOperator> GenWith(
|
||||
return last_op;
|
||||
}
|
||||
|
||||
std::unique_ptr<LogicalOperator> GenUnion(
|
||||
const CypherUnion &cypher_union, std::shared_ptr<LogicalOperator> left_op,
|
||||
std::unique_ptr<LogicalOperator> GenUnion(const CypherUnion &cypher_union, std::shared_ptr<LogicalOperator> left_op,
|
||||
std::shared_ptr<LogicalOperator> right_op, SymbolTable &symbol_table) {
|
||||
return std::make_unique<Union>(left_op, right_op, cypher_union.union_symbols_,
|
||||
left_op->OutputSymbols(symbol_table),
|
||||
return std::make_unique<Union>(left_op, right_op, cypher_union.union_symbols_, left_op->OutputSymbols(symbol_table),
|
||||
right_op->OutputSymbols(symbol_table));
|
||||
}
|
||||
|
||||
|
@ -38,8 +38,7 @@ struct PlanningContext {
|
||||
};
|
||||
|
||||
template <class TDbAccessor>
|
||||
auto MakePlanningContext(AstStorage *ast_storage, SymbolTable *symbol_table,
|
||||
CypherQuery *query, TDbAccessor *db) {
|
||||
auto MakePlanningContext(AstStorage *ast_storage, SymbolTable *symbol_table, CypherQuery *query, TDbAccessor *db) {
|
||||
return PlanningContext<TDbAccessor>{symbol_table, ast_storage, query, db};
|
||||
}
|
||||
|
||||
@ -66,11 +65,9 @@ namespace impl {
|
||||
// Iterates over `Filters` joining them in one expression via
|
||||
// `AndOperator` if symbols they use are bound.. All the joined filters are
|
||||
// removed from `Filters`.
|
||||
Expression *ExtractFilters(const std::unordered_set<Symbol> &, Filters &,
|
||||
AstStorage &);
|
||||
Expression *ExtractFilters(const std::unordered_set<Symbol> &, Filters &, AstStorage &);
|
||||
|
||||
std::unique_ptr<LogicalOperator> GenFilters(std::unique_ptr<LogicalOperator>,
|
||||
const std::unordered_set<Symbol> &,
|
||||
std::unique_ptr<LogicalOperator> GenFilters(std::unique_ptr<LogicalOperator>, const std::unordered_set<Symbol> &,
|
||||
Filters &, AstStorage &);
|
||||
|
||||
/// Utility function for iterating pattern atoms and accumulating a result.
|
||||
@ -92,8 +89,7 @@ std::unique_ptr<LogicalOperator> GenFilters(std::unique_ptr<LogicalOperator>,
|
||||
// TODO: It might be a good idea to move this somewhere else, for easier usage
|
||||
// in other files.
|
||||
template <typename T>
|
||||
auto ReducePattern(
|
||||
Pattern &pattern, std::function<T(NodeAtom *)> base,
|
||||
auto ReducePattern(Pattern &pattern, std::function<T(NodeAtom *)> base,
|
||||
std::function<T(T, NodeAtom *, EdgeAtom *, NodeAtom *)> collect) {
|
||||
MG_ASSERT(!pattern.atoms_.empty(), "Missing atoms in pattern");
|
||||
auto atoms_it = pattern.atoms_.begin();
|
||||
@ -104,8 +100,7 @@ auto ReducePattern(
|
||||
while (atoms_it != pattern.atoms_.end()) {
|
||||
auto edge = utils::Downcast<EdgeAtom>(*atoms_it++);
|
||||
MG_ASSERT(edge, "Expected an edge atom in pattern.");
|
||||
MG_ASSERT(atoms_it != pattern.atoms_.end(),
|
||||
"Edge atom should not end the pattern.");
|
||||
MG_ASSERT(atoms_it != pattern.atoms_.end(), "Edge atom should not end the pattern.");
|
||||
auto prev_node = current_node;
|
||||
current_node = utils::Downcast<NodeAtom>(*atoms_it++);
|
||||
MG_ASSERT(current_node, "Expected a node atom in pattern.");
|
||||
@ -118,28 +113,23 @@ auto ReducePattern(
|
||||
// If so, it creates a logical operator for named path generation, binds its
|
||||
// symbol, removes that path from the collection of unhandled ones and returns
|
||||
// the new op. Otherwise, returns `last_op`.
|
||||
std::unique_ptr<LogicalOperator> GenNamedPaths(
|
||||
std::unique_ptr<LogicalOperator> last_op,
|
||||
std::unique_ptr<LogicalOperator> GenNamedPaths(std::unique_ptr<LogicalOperator> last_op,
|
||||
std::unordered_set<Symbol> &bound_symbols,
|
||||
std::unordered_map<Symbol, std::vector<Symbol>> &named_paths);
|
||||
|
||||
std::unique_ptr<LogicalOperator> GenReturn(
|
||||
Return &ret, std::unique_ptr<LogicalOperator> input_op,
|
||||
std::unique_ptr<LogicalOperator> GenReturn(Return &ret, std::unique_ptr<LogicalOperator> input_op,
|
||||
SymbolTable &symbol_table, bool is_write,
|
||||
const std::unordered_set<Symbol> &bound_symbols, AstStorage &storage);
|
||||
|
||||
std::unique_ptr<LogicalOperator> GenWith(
|
||||
With &with, std::unique_ptr<LogicalOperator> input_op,
|
||||
std::unique_ptr<LogicalOperator> GenWith(With &with, std::unique_ptr<LogicalOperator> input_op,
|
||||
SymbolTable &symbol_table, bool is_write,
|
||||
std::unordered_set<Symbol> &bound_symbols, AstStorage &storage);
|
||||
|
||||
std::unique_ptr<LogicalOperator> GenUnion(
|
||||
const CypherUnion &cypher_union, std::shared_ptr<LogicalOperator> left_op,
|
||||
std::unique_ptr<LogicalOperator> GenUnion(const CypherUnion &cypher_union, std::shared_ptr<LogicalOperator> left_op,
|
||||
std::shared_ptr<LogicalOperator> right_op, SymbolTable &symbol_table);
|
||||
|
||||
template <class TBoolOperator>
|
||||
Expression *BoolJoin(AstStorage &storage, Expression *expr1,
|
||||
Expression *expr2) {
|
||||
Expression *BoolJoin(AstStorage &storage, Expression *expr1, Expression *expr2) {
|
||||
if (expr1 && expr2) {
|
||||
return storage.Create<TBoolOperator>(expr1, expr2);
|
||||
}
|
||||
@ -166,52 +156,40 @@ class RuleBasedPlanner {
|
||||
// Set to true if a query command writes to the database.
|
||||
bool is_write = false;
|
||||
for (const auto &query_part : query_parts) {
|
||||
MatchContext match_ctx{query_part.matching, *context.symbol_table,
|
||||
context.bound_symbols};
|
||||
MatchContext match_ctx{query_part.matching, *context.symbol_table, context.bound_symbols};
|
||||
input_op = PlanMatching(match_ctx, std::move(input_op));
|
||||
for (const auto &matching : query_part.optional_matching) {
|
||||
MatchContext opt_ctx{matching, *context.symbol_table,
|
||||
context.bound_symbols};
|
||||
MatchContext opt_ctx{matching, *context.symbol_table, context.bound_symbols};
|
||||
auto match_op = PlanMatching(opt_ctx, nullptr);
|
||||
if (match_op) {
|
||||
input_op = std::make_unique<Optional>(
|
||||
std::move(input_op), std::move(match_op), opt_ctx.new_symbols);
|
||||
input_op = std::make_unique<Optional>(std::move(input_op), std::move(match_op), opt_ctx.new_symbols);
|
||||
}
|
||||
}
|
||||
uint64_t merge_id = 0;
|
||||
for (auto *clause : query_part.remaining_clauses) {
|
||||
MG_ASSERT(!utils::IsSubtype(*clause, Match::kType),
|
||||
"Unexpected Match in remaining clauses");
|
||||
MG_ASSERT(!utils::IsSubtype(*clause, Match::kType), "Unexpected Match in remaining clauses");
|
||||
if (auto *ret = utils::Downcast<Return>(clause)) {
|
||||
input_op = impl::GenReturn(
|
||||
*ret, std::move(input_op), *context.symbol_table, is_write,
|
||||
context.bound_symbols, *context.ast_storage);
|
||||
input_op = impl::GenReturn(*ret, std::move(input_op), *context.symbol_table, is_write, context.bound_symbols,
|
||||
*context.ast_storage);
|
||||
} else if (auto *merge = utils::Downcast<query::Merge>(clause)) {
|
||||
input_op = GenMerge(*merge, std::move(input_op),
|
||||
query_part.merge_matching[merge_id++]);
|
||||
input_op = GenMerge(*merge, std::move(input_op), query_part.merge_matching[merge_id++]);
|
||||
// Treat MERGE clause as write, because we do not know if it will
|
||||
// create anything.
|
||||
is_write = true;
|
||||
} else if (auto *with = utils::Downcast<query::With>(clause)) {
|
||||
input_op = impl::GenWith(*with, std::move(input_op),
|
||||
*context.symbol_table, is_write,
|
||||
context.bound_symbols, *context.ast_storage);
|
||||
input_op = impl::GenWith(*with, std::move(input_op), *context.symbol_table, is_write, context.bound_symbols,
|
||||
*context.ast_storage);
|
||||
// WITH clause advances the command, so reset the flag.
|
||||
is_write = false;
|
||||
} else if (auto op = HandleWriteClause(clause, input_op,
|
||||
*context.symbol_table,
|
||||
context.bound_symbols)) {
|
||||
} else if (auto op = HandleWriteClause(clause, input_op, *context.symbol_table, context.bound_symbols)) {
|
||||
is_write = true;
|
||||
input_op = std::move(op);
|
||||
} else if (auto *unwind = utils::Downcast<query::Unwind>(clause)) {
|
||||
const auto &symbol =
|
||||
context.symbol_table->at(*unwind->named_expression_);
|
||||
const auto &symbol = context.symbol_table->at(*unwind->named_expression_);
|
||||
context.bound_symbols.insert(symbol);
|
||||
input_op = std::make_unique<plan::Unwind>(
|
||||
std::move(input_op), unwind->named_expression_->expression_,
|
||||
symbol);
|
||||
} else if (auto *call_proc =
|
||||
utils::Downcast<query::CallProcedure>(clause)) {
|
||||
input_op =
|
||||
std::make_unique<plan::Unwind>(std::move(input_op), unwind->named_expression_->expression_, symbol);
|
||||
} else if (auto *call_proc = utils::Downcast<query::CallProcedure>(clause)) {
|
||||
std::vector<Symbol> result_symbols;
|
||||
result_symbols.reserve(call_proc->result_identifiers_.size());
|
||||
for (const auto *ident : call_proc->result_identifiers_) {
|
||||
@ -223,13 +201,10 @@ class RuleBasedPlanner {
|
||||
// need to plan this operator with Accumulate and pass in
|
||||
// storage::View::NEW.
|
||||
input_op = std::make_unique<plan::CallProcedure>(
|
||||
std::move(input_op), call_proc->procedure_name_,
|
||||
call_proc->arguments_, call_proc->result_fields_, result_symbols,
|
||||
call_proc->memory_limit_, call_proc->memory_scale_);
|
||||
std::move(input_op), call_proc->procedure_name_, call_proc->arguments_, call_proc->result_fields_,
|
||||
result_symbols, call_proc->memory_limit_, call_proc->memory_scale_);
|
||||
} else {
|
||||
throw utils::NotYetImplemented(
|
||||
"clause '{}' conversion to operator(s)",
|
||||
clause->GetTypeInfo().name);
|
||||
throw utils::NotYetImplemented("clause '{}' conversion to operator(s)", clause->GetTypeInfo().name);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -239,32 +214,23 @@ class RuleBasedPlanner {
|
||||
private:
|
||||
TPlanningContext *context_;
|
||||
|
||||
storage::LabelId GetLabel(LabelIx label) {
|
||||
return context_->db->NameToLabel(label.name);
|
||||
}
|
||||
storage::LabelId GetLabel(LabelIx label) { return context_->db->NameToLabel(label.name); }
|
||||
|
||||
storage::PropertyId GetProperty(PropertyIx prop) {
|
||||
return context_->db->NameToProperty(prop.name);
|
||||
}
|
||||
storage::PropertyId GetProperty(PropertyIx prop) { return context_->db->NameToProperty(prop.name); }
|
||||
|
||||
storage::EdgeTypeId GetEdgeType(EdgeTypeIx edge_type) {
|
||||
return context_->db->NameToEdgeType(edge_type.name);
|
||||
}
|
||||
storage::EdgeTypeId GetEdgeType(EdgeTypeIx edge_type) { return context_->db->NameToEdgeType(edge_type.name); }
|
||||
|
||||
std::unique_ptr<LogicalOperator> GenCreate(
|
||||
Create &create, std::unique_ptr<LogicalOperator> input_op,
|
||||
std::unique_ptr<LogicalOperator> GenCreate(Create &create, std::unique_ptr<LogicalOperator> input_op,
|
||||
const SymbolTable &symbol_table,
|
||||
std::unordered_set<Symbol> &bound_symbols) {
|
||||
auto last_op = std::move(input_op);
|
||||
for (auto pattern : create.patterns_) {
|
||||
last_op = GenCreateForPattern(*pattern, std::move(last_op), symbol_table,
|
||||
bound_symbols);
|
||||
last_op = GenCreateForPattern(*pattern, std::move(last_op), symbol_table, bound_symbols);
|
||||
}
|
||||
return last_op;
|
||||
}
|
||||
|
||||
std::unique_ptr<LogicalOperator> GenCreateForPattern(
|
||||
Pattern &pattern, std::unique_ptr<LogicalOperator> input_op,
|
||||
std::unique_ptr<LogicalOperator> GenCreateForPattern(Pattern &pattern, std::unique_ptr<LogicalOperator> input_op,
|
||||
const SymbolTable &symbol_table,
|
||||
std::unordered_set<Symbol> &bound_symbols) {
|
||||
auto node_to_creation_info = [&](const NodeAtom &node) {
|
||||
@ -292,8 +258,7 @@ class RuleBasedPlanner {
|
||||
}
|
||||
};
|
||||
|
||||
auto collect = [&](std::unique_ptr<LogicalOperator> last_op,
|
||||
NodeAtom *prev_node, EdgeAtom *edge, NodeAtom *node) {
|
||||
auto collect = [&](std::unique_ptr<LogicalOperator> last_op, NodeAtom *prev_node, EdgeAtom *edge, NodeAtom *node) {
|
||||
// Store the symbol from the first node as the input to CreateExpand.
|
||||
const auto &input_symbol = symbol_table.at(*prev_node->identifier_);
|
||||
// If the expand node was already bound, then we need to indicate this,
|
||||
@ -312,27 +277,18 @@ class RuleBasedPlanner {
|
||||
for (const auto &kv : edge->properties_) {
|
||||
properties.push_back({GetProperty(kv.first), kv.second});
|
||||
}
|
||||
MG_ASSERT(
|
||||
edge->edge_types_.size() == 1,
|
||||
"Creating an edge with a single type should be required by syntax");
|
||||
EdgeCreationInfo edge_info{edge_symbol, properties,
|
||||
GetEdgeType(edge->edge_types_[0]),
|
||||
edge->direction_};
|
||||
return std::make_unique<CreateExpand>(node_info, edge_info,
|
||||
std::move(last_op), input_symbol,
|
||||
node_existing);
|
||||
MG_ASSERT(edge->edge_types_.size() == 1, "Creating an edge with a single type should be required by syntax");
|
||||
EdgeCreationInfo edge_info{edge_symbol, properties, GetEdgeType(edge->edge_types_[0]), edge->direction_};
|
||||
return std::make_unique<CreateExpand>(node_info, edge_info, std::move(last_op), input_symbol, node_existing);
|
||||
};
|
||||
|
||||
auto last_op = impl::ReducePattern<std::unique_ptr<LogicalOperator>>(
|
||||
pattern, base, collect);
|
||||
auto last_op = impl::ReducePattern<std::unique_ptr<LogicalOperator>>(pattern, base, collect);
|
||||
|
||||
// If the pattern is named, append the path constructing logical operator.
|
||||
if (pattern.identifier_->user_declared_) {
|
||||
std::vector<Symbol> path_elements;
|
||||
for (const PatternAtom *atom : pattern.atoms_)
|
||||
path_elements.emplace_back(symbol_table.at(*atom->identifier_));
|
||||
last_op = std::make_unique<ConstructNamedPath>(
|
||||
std::move(last_op), symbol_table.at(*pattern.identifier_),
|
||||
for (const PatternAtom *atom : pattern.atoms_) path_elements.emplace_back(symbol_table.at(*atom->identifier_));
|
||||
last_op = std::make_unique<ConstructNamedPath>(std::move(last_op), symbol_table.at(*pattern.identifier_),
|
||||
path_elements);
|
||||
}
|
||||
|
||||
@ -342,26 +298,20 @@ class RuleBasedPlanner {
|
||||
// Generate an operator for a clause which writes to the database. Ownership
|
||||
// of input_op is transferred to the newly created operator. If the clause
|
||||
// isn't handled, returns nullptr and input_op is left as is.
|
||||
std::unique_ptr<LogicalOperator> HandleWriteClause(
|
||||
Clause *clause, std::unique_ptr<LogicalOperator> &input_op,
|
||||
std::unique_ptr<LogicalOperator> HandleWriteClause(Clause *clause, std::unique_ptr<LogicalOperator> &input_op,
|
||||
const SymbolTable &symbol_table,
|
||||
std::unordered_set<Symbol> &bound_symbols) {
|
||||
if (auto *create = utils::Downcast<Create>(clause)) {
|
||||
return GenCreate(*create, std::move(input_op), symbol_table,
|
||||
bound_symbols);
|
||||
return GenCreate(*create, std::move(input_op), symbol_table, bound_symbols);
|
||||
} else if (auto *del = utils::Downcast<query::Delete>(clause)) {
|
||||
return std::make_unique<plan::Delete>(std::move(input_op),
|
||||
del->expressions_, del->detach_);
|
||||
return std::make_unique<plan::Delete>(std::move(input_op), del->expressions_, del->detach_);
|
||||
} else if (auto *set = utils::Downcast<query::SetProperty>(clause)) {
|
||||
return std::make_unique<plan::SetProperty>(
|
||||
std::move(input_op), GetProperty(set->property_lookup_->property_),
|
||||
return std::make_unique<plan::SetProperty>(std::move(input_op), GetProperty(set->property_lookup_->property_),
|
||||
set->property_lookup_, set->expression_);
|
||||
} else if (auto *set = utils::Downcast<query::SetProperties>(clause)) {
|
||||
auto op = set->update_ ? plan::SetProperties::Op::UPDATE
|
||||
: plan::SetProperties::Op::REPLACE;
|
||||
auto op = set->update_ ? plan::SetProperties::Op::UPDATE : plan::SetProperties::Op::REPLACE;
|
||||
const auto &input_symbol = symbol_table.at(*set->identifier_);
|
||||
return std::make_unique<plan::SetProperties>(
|
||||
std::move(input_op), input_symbol, set->expression_, op);
|
||||
return std::make_unique<plan::SetProperties>(std::move(input_op), input_symbol, set->expression_, op);
|
||||
} else if (auto *set = utils::Downcast<query::SetLabels>(clause)) {
|
||||
const auto &input_symbol = symbol_table.at(*set->identifier_);
|
||||
std::vector<storage::LabelId> labels;
|
||||
@ -369,11 +319,9 @@ class RuleBasedPlanner {
|
||||
for (const auto &label : set->labels_) {
|
||||
labels.push_back(GetLabel(label));
|
||||
}
|
||||
return std::make_unique<plan::SetLabels>(std::move(input_op),
|
||||
input_symbol, labels);
|
||||
return std::make_unique<plan::SetLabels>(std::move(input_op), input_symbol, labels);
|
||||
} else if (auto *rem = utils::Downcast<query::RemoveProperty>(clause)) {
|
||||
return std::make_unique<plan::RemoveProperty>(
|
||||
std::move(input_op), GetProperty(rem->property_lookup_->property_),
|
||||
return std::make_unique<plan::RemoveProperty>(std::move(input_op), GetProperty(rem->property_lookup_->property_),
|
||||
rem->property_lookup_);
|
||||
} else if (auto *rem = utils::Downcast<query::RemoveLabels>(clause)) {
|
||||
const auto &input_symbol = symbol_table.at(*rem->identifier_);
|
||||
@ -382,14 +330,13 @@ class RuleBasedPlanner {
|
||||
for (const auto &label : rem->labels_) {
|
||||
labels.push_back(GetLabel(label));
|
||||
}
|
||||
return std::make_unique<plan::RemoveLabels>(std::move(input_op),
|
||||
input_symbol, labels);
|
||||
return std::make_unique<plan::RemoveLabels>(std::move(input_op), input_symbol, labels);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<LogicalOperator> PlanMatching(
|
||||
MatchContext &match_context, std::unique_ptr<LogicalOperator> input_op) {
|
||||
std::unique_ptr<LogicalOperator> PlanMatching(MatchContext &match_context,
|
||||
std::unique_ptr<LogicalOperator> input_op) {
|
||||
auto &bound_symbols = match_context.bound_symbols;
|
||||
auto &storage = *context_->ast_storage;
|
||||
const auto &symbol_table = match_context.symbol_table;
|
||||
@ -401,21 +348,16 @@ class RuleBasedPlanner {
|
||||
// Try to generate any filters even before the 1st match operator. This
|
||||
// optimizes the optional match which filters only on symbols bound in
|
||||
// regular match.
|
||||
auto last_op =
|
||||
impl::GenFilters(std::move(input_op), bound_symbols, filters, storage);
|
||||
auto last_op = impl::GenFilters(std::move(input_op), bound_symbols, filters, storage);
|
||||
for (const auto &expansion : matching.expansions) {
|
||||
const auto &node1_symbol = symbol_table.at(*expansion.node1->identifier_);
|
||||
if (bound_symbols.insert(node1_symbol).second) {
|
||||
// We have just bound this symbol, so generate ScanAll which fills it.
|
||||
last_op = std::make_unique<ScanAll>(std::move(last_op), node1_symbol,
|
||||
match_context.view);
|
||||
last_op = std::make_unique<ScanAll>(std::move(last_op), node1_symbol, match_context.view);
|
||||
match_context.new_symbols.emplace_back(node1_symbol);
|
||||
last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters,
|
||||
storage);
|
||||
last_op =
|
||||
impl::GenNamedPaths(std::move(last_op), bound_symbols, named_paths);
|
||||
last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters,
|
||||
storage);
|
||||
last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters, storage);
|
||||
last_op = impl::GenNamedPaths(std::move(last_op), bound_symbols, named_paths);
|
||||
last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters, storage);
|
||||
}
|
||||
// We have an edge, so generate Expand.
|
||||
if (expansion.edge) {
|
||||
@ -423,12 +365,10 @@ class RuleBasedPlanner {
|
||||
// If the expand symbols were already bound, then we need to indicate
|
||||
// that they exist. The Expand will then check whether the pattern holds
|
||||
// instead of writing the expansion to symbols.
|
||||
const auto &node_symbol =
|
||||
symbol_table.at(*expansion.node2->identifier_);
|
||||
const auto &node_symbol = symbol_table.at(*expansion.node2->identifier_);
|
||||
auto existing_node = utils::Contains(bound_symbols, node_symbol);
|
||||
const auto &edge_symbol = symbol_table.at(*edge->identifier_);
|
||||
MG_ASSERT(!utils::Contains(bound_symbols, edge_symbol),
|
||||
"Existing edges are not supported");
|
||||
MG_ASSERT(!utils::Contains(bound_symbols, edge_symbol), "Existing edges are not supported");
|
||||
std::vector<storage::EdgeTypeId> edge_types;
|
||||
edge_types.reserve(edge->edge_types_.size());
|
||||
for (const auto &type : edge->edge_types_) {
|
||||
@ -439,8 +379,7 @@ class RuleBasedPlanner {
|
||||
std::optional<Symbol> total_weight;
|
||||
|
||||
if (edge->type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) {
|
||||
weight_lambda.emplace(ExpansionLambda{
|
||||
symbol_table.at(*edge->weight_lambda_.inner_edge),
|
||||
weight_lambda.emplace(ExpansionLambda{symbol_table.at(*edge->weight_lambda_.inner_edge),
|
||||
symbol_table.at(*edge->weight_lambda_.inner_node),
|
||||
edge->weight_lambda_.expression});
|
||||
|
||||
@ -448,37 +387,28 @@ class RuleBasedPlanner {
|
||||
}
|
||||
|
||||
ExpansionLambda filter_lambda;
|
||||
filter_lambda.inner_edge_symbol =
|
||||
symbol_table.at(*edge->filter_lambda_.inner_edge);
|
||||
filter_lambda.inner_node_symbol =
|
||||
symbol_table.at(*edge->filter_lambda_.inner_node);
|
||||
filter_lambda.inner_edge_symbol = symbol_table.at(*edge->filter_lambda_.inner_edge);
|
||||
filter_lambda.inner_node_symbol = symbol_table.at(*edge->filter_lambda_.inner_node);
|
||||
{
|
||||
// Bind the inner edge and node symbols so they're available for
|
||||
// inline filtering in ExpandVariable.
|
||||
bool inner_edge_bound =
|
||||
bound_symbols.insert(filter_lambda.inner_edge_symbol).second;
|
||||
bool inner_node_bound =
|
||||
bound_symbols.insert(filter_lambda.inner_node_symbol).second;
|
||||
MG_ASSERT(inner_edge_bound && inner_node_bound,
|
||||
"An inner edge and node can't be bound from before");
|
||||
bool inner_edge_bound = bound_symbols.insert(filter_lambda.inner_edge_symbol).second;
|
||||
bool inner_node_bound = bound_symbols.insert(filter_lambda.inner_node_symbol).second;
|
||||
MG_ASSERT(inner_edge_bound && inner_node_bound, "An inner edge and node can't be bound from before");
|
||||
}
|
||||
// Join regular filters with lambda filter expression, so that they
|
||||
// are done inline together. Semantic analysis should guarantee that
|
||||
// lambda filtering uses bound symbols.
|
||||
filter_lambda.expression = impl::BoolJoin<AndOperator>(
|
||||
storage, impl::ExtractFilters(bound_symbols, filters, storage),
|
||||
edge->filter_lambda_.expression);
|
||||
storage, impl::ExtractFilters(bound_symbols, filters, storage), edge->filter_lambda_.expression);
|
||||
// At this point it's possible we have leftover filters for inline
|
||||
// filtering (they use the inner symbols. If they were not collected,
|
||||
// we have to remove them manually because no other filter-extraction
|
||||
// will ever bind them again.
|
||||
filters.erase(
|
||||
std::remove_if(
|
||||
filters.erase(std::remove_if(
|
||||
filters.begin(), filters.end(),
|
||||
[e = filter_lambda.inner_edge_symbol,
|
||||
n = filter_lambda.inner_node_symbol](FilterInfo &fi) {
|
||||
return utils::Contains(fi.used_symbols, e) ||
|
||||
utils::Contains(fi.used_symbols, n);
|
||||
[e = filter_lambda.inner_edge_symbol, n = filter_lambda.inner_node_symbol](FilterInfo &fi) {
|
||||
return utils::Contains(fi.used_symbols, e) || utils::Contains(fi.used_symbols, n);
|
||||
}),
|
||||
filters.end());
|
||||
// Unbind the temporarily bound inner symbols for filtering.
|
||||
@ -490,19 +420,15 @@ class RuleBasedPlanner {
|
||||
}
|
||||
|
||||
// TODO: Pass weight lambda.
|
||||
MG_ASSERT(
|
||||
match_context.view == storage::View::OLD,
|
||||
MG_ASSERT(match_context.view == storage::View::OLD,
|
||||
"ExpandVariable should only be planned with storage::View::OLD");
|
||||
last_op = std::make_unique<ExpandVariable>(
|
||||
std::move(last_op), node1_symbol, node_symbol, edge_symbol,
|
||||
edge->type_, expansion.direction, edge_types,
|
||||
expansion.is_flipped, edge->lower_bound_, edge->upper_bound_,
|
||||
existing_node, filter_lambda, weight_lambda, total_weight);
|
||||
last_op = std::make_unique<ExpandVariable>(std::move(last_op), node1_symbol, node_symbol, edge_symbol,
|
||||
edge->type_, expansion.direction, edge_types, expansion.is_flipped,
|
||||
edge->lower_bound_, edge->upper_bound_, existing_node,
|
||||
filter_lambda, weight_lambda, total_weight);
|
||||
} else {
|
||||
last_op = std::make_unique<Expand>(std::move(last_op), node1_symbol,
|
||||
node_symbol, edge_symbol,
|
||||
expansion.direction, edge_types,
|
||||
existing_node, match_context.view);
|
||||
last_op = std::make_unique<Expand>(std::move(last_op), node1_symbol, node_symbol, edge_symbol,
|
||||
expansion.direction, edge_types, existing_node, match_context.view);
|
||||
}
|
||||
|
||||
// Bind the expanded edge and node.
|
||||
@ -520,60 +446,47 @@ class RuleBasedPlanner {
|
||||
}
|
||||
std::vector<Symbol> other_symbols;
|
||||
for (const auto &symbol : edge_symbols) {
|
||||
if (symbol == edge_symbol ||
|
||||
bound_symbols.find(symbol) == bound_symbols.end()) {
|
||||
if (symbol == edge_symbol || bound_symbols.find(symbol) == bound_symbols.end()) {
|
||||
continue;
|
||||
}
|
||||
other_symbols.push_back(symbol);
|
||||
}
|
||||
if (!other_symbols.empty()) {
|
||||
last_op = std::make_unique<EdgeUniquenessFilter>(
|
||||
std::move(last_op), edge_symbol, other_symbols);
|
||||
last_op = std::make_unique<EdgeUniquenessFilter>(std::move(last_op), edge_symbol, other_symbols);
|
||||
}
|
||||
}
|
||||
last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters,
|
||||
storage);
|
||||
last_op =
|
||||
impl::GenNamedPaths(std::move(last_op), bound_symbols, named_paths);
|
||||
last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters,
|
||||
storage);
|
||||
last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters, storage);
|
||||
last_op = impl::GenNamedPaths(std::move(last_op), bound_symbols, named_paths);
|
||||
last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters, storage);
|
||||
}
|
||||
}
|
||||
MG_ASSERT(named_paths.empty(), "Expected to generate all named paths");
|
||||
// We bound all named path symbols, so just add them to new_symbols.
|
||||
for (const auto &named_path : matching.named_paths) {
|
||||
MG_ASSERT(utils::Contains(bound_symbols, named_path.first),
|
||||
"Expected generated named path to have bound symbol");
|
||||
MG_ASSERT(utils::Contains(bound_symbols, named_path.first), "Expected generated named path to have bound symbol");
|
||||
match_context.new_symbols.emplace_back(named_path.first);
|
||||
}
|
||||
MG_ASSERT(filters.empty(), "Expected to generate all filters");
|
||||
return last_op;
|
||||
}
|
||||
|
||||
auto GenMerge(query::Merge &merge, std::unique_ptr<LogicalOperator> input_op,
|
||||
const Matching &matching) {
|
||||
auto GenMerge(query::Merge &merge, std::unique_ptr<LogicalOperator> input_op, const Matching &matching) {
|
||||
// Copy the bound symbol set, because we don't want to use the updated
|
||||
// version when generating the create part.
|
||||
std::unordered_set<Symbol> bound_symbols_copy(context_->bound_symbols);
|
||||
MatchContext match_ctx{matching, *context_->symbol_table,
|
||||
bound_symbols_copy, storage::View::NEW};
|
||||
MatchContext match_ctx{matching, *context_->symbol_table, bound_symbols_copy, storage::View::NEW};
|
||||
auto on_match = PlanMatching(match_ctx, nullptr);
|
||||
// Use the original bound_symbols, so we fill it with new symbols.
|
||||
auto on_create =
|
||||
GenCreateForPattern(*merge.pattern_, nullptr, *context_->symbol_table,
|
||||
context_->bound_symbols);
|
||||
auto on_create = GenCreateForPattern(*merge.pattern_, nullptr, *context_->symbol_table, context_->bound_symbols);
|
||||
for (auto &set : merge.on_create_) {
|
||||
on_create = HandleWriteClause(set, on_create, *context_->symbol_table,
|
||||
context_->bound_symbols);
|
||||
on_create = HandleWriteClause(set, on_create, *context_->symbol_table, context_->bound_symbols);
|
||||
MG_ASSERT(on_create, "Expected SET in MERGE ... ON CREATE");
|
||||
}
|
||||
for (auto &set : merge.on_match_) {
|
||||
on_match = HandleWriteClause(set, on_match, *context_->symbol_table,
|
||||
context_->bound_symbols);
|
||||
on_match = HandleWriteClause(set, on_match, *context_->symbol_table, context_->bound_symbols);
|
||||
MG_ASSERT(on_match, "Expected SET in MERGE ... ON MATCH");
|
||||
}
|
||||
return std::make_unique<plan::Merge>(
|
||||
std::move(input_op), std::move(on_match), std::move(on_create));
|
||||
return std::make_unique<plan::Merge>(std::move(input_op), std::move(on_match), std::move(on_create));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -19,9 +19,7 @@ namespace plan {
|
||||
*/
|
||||
class ScopedProfile {
|
||||
public:
|
||||
ScopedProfile(uint64_t key, const char *name,
|
||||
query::ExecutionContext *context) noexcept
|
||||
: context_(context) {
|
||||
ScopedProfile(uint64_t key, const char *name, query::ExecutionContext *context) noexcept : context_(context) {
|
||||
if (UNLIKELY(context_->is_profile_query)) {
|
||||
root_ = context_->stats_root;
|
||||
|
||||
|
@ -6,8 +6,7 @@
|
||||
#include "utils/flag_validation.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
|
||||
DEFINE_VALIDATED_HIDDEN_uint64(
|
||||
query_max_plans, 1000U, "Maximum number of generated plans for a query.",
|
||||
DEFINE_VALIDATED_HIDDEN_uint64(query_max_plans, 1000U, "Maximum number of generated plans for a query.",
|
||||
FLAG_IN_RANGE(1, std::numeric_limits<std::uint64_t>::max()));
|
||||
|
||||
namespace query::plan::impl {
|
||||
@ -17,13 +16,10 @@ namespace {
|
||||
// Add applicable expansions for `node_symbol` to `next_expansions`. These
|
||||
// expansions are removed from `node_symbol_to_expansions`, while
|
||||
// `seen_expansions` and `expanded_symbols` are populated with new data.
|
||||
void AddNextExpansions(
|
||||
const Symbol &node_symbol, const Matching &matching,
|
||||
const SymbolTable &symbol_table,
|
||||
void AddNextExpansions(const Symbol &node_symbol, const Matching &matching, const SymbolTable &symbol_table,
|
||||
std::unordered_set<Symbol> &expanded_symbols,
|
||||
std::unordered_map<Symbol, std::set<size_t>> &node_symbol_to_expansions,
|
||||
std::unordered_set<size_t> &seen_expansions,
|
||||
std::queue<Expansion> &next_expansions) {
|
||||
std::unordered_set<size_t> &seen_expansions, std::queue<Expansion> &next_expansions) {
|
||||
auto node_to_expansions_it = node_symbol_to_expansions.find(node_symbol);
|
||||
if (node_to_expansions_it == node_symbol_to_expansions.end()) {
|
||||
return;
|
||||
@ -37,8 +33,7 @@ void AddNextExpansions(
|
||||
// therefore bound. If the symbols are not found in the whole expansion,
|
||||
// then the semantic analysis should guarantee that the symbols have been
|
||||
// bound long before we expand.
|
||||
if (matching.expansion_symbols.find(range_symbol) !=
|
||||
matching.expansion_symbols.end() &&
|
||||
if (matching.expansion_symbols.find(range_symbol) != matching.expansion_symbols.end() &&
|
||||
expanded_symbols.find(range_symbol) == expanded_symbols.end()) {
|
||||
return false;
|
||||
}
|
||||
@ -62,18 +57,15 @@ void AddNextExpansions(
|
||||
}
|
||||
if (symbol_table.at(*expansion.node1->identifier_) != node_symbol) {
|
||||
// We are not expanding from node1, so flip the expansion.
|
||||
DMG_ASSERT(
|
||||
expansion.node2 &&
|
||||
symbol_table.at(*expansion.node2->identifier_) == node_symbol,
|
||||
DMG_ASSERT(expansion.node2 && symbol_table.at(*expansion.node2->identifier_) == node_symbol,
|
||||
"Expected node_symbol to be bound in node2");
|
||||
if (expansion.edge->type_ != EdgeAtom::Type::BREADTH_FIRST) {
|
||||
// BFS must *not* be flipped. Doing that changes the BFS results.
|
||||
std::swap(expansion.node1, expansion.node2);
|
||||
expansion.is_flipped = true;
|
||||
if (expansion.direction != EdgeAtom::Direction::BOTH) {
|
||||
expansion.direction = expansion.direction == EdgeAtom::Direction::IN
|
||||
? EdgeAtom::Direction::OUT
|
||||
: EdgeAtom::Direction::IN;
|
||||
expansion.direction =
|
||||
expansion.direction == EdgeAtom::Direction::IN ? EdgeAtom::Direction::OUT : EdgeAtom::Direction::IN;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -95,20 +87,17 @@ void AddNextExpansions(
|
||||
// the chain can no longer be continued, a different starting node is picked
|
||||
// among remaining expansions and the process continues. This is done until all
|
||||
// matching.expansions are used.
|
||||
std::vector<Expansion> ExpansionsFrom(const NodeAtom *start_node,
|
||||
const Matching &matching,
|
||||
std::vector<Expansion> ExpansionsFrom(const NodeAtom *start_node, const Matching &matching,
|
||||
const SymbolTable &symbol_table) {
|
||||
// Make a copy of node_symbol_to_expansions, because we will modify it as
|
||||
// expansions are chained.
|
||||
auto node_symbol_to_expansions = matching.node_symbol_to_expansions;
|
||||
std::unordered_set<size_t> seen_expansions;
|
||||
std::queue<Expansion> next_expansions;
|
||||
std::unordered_set<Symbol> expanded_symbols(
|
||||
{symbol_table.at(*start_node->identifier_)});
|
||||
std::unordered_set<Symbol> expanded_symbols({symbol_table.at(*start_node->identifier_)});
|
||||
auto add_next_expansions = [&](const auto *node) {
|
||||
AddNextExpansions(symbol_table.at(*node->identifier_), matching,
|
||||
symbol_table, expanded_symbols, node_symbol_to_expansions,
|
||||
seen_expansions, next_expansions);
|
||||
AddNextExpansions(symbol_table.at(*node->identifier_), matching, symbol_table, expanded_symbols,
|
||||
node_symbol_to_expansions, seen_expansions, next_expansions);
|
||||
};
|
||||
add_next_expansions(start_node);
|
||||
// Potential optimization: expansions and next_expansions could be merge into
|
||||
@ -141,10 +130,8 @@ std::vector<Expansion> ExpansionsFrom(const NodeAtom *start_node,
|
||||
|
||||
// Collect all unique nodes from expansions. Uniqueness is determined by
|
||||
// symbol uniqueness.
|
||||
auto ExpansionNodes(const std::vector<Expansion> &expansions,
|
||||
const SymbolTable &symbol_table) {
|
||||
std::unordered_set<NodeAtom *, NodeSymbolHash, NodeSymbolEqual> nodes(
|
||||
expansions.size(), NodeSymbolHash(symbol_table),
|
||||
auto ExpansionNodes(const std::vector<Expansion> &expansions, const SymbolTable &symbol_table) {
|
||||
std::unordered_set<NodeAtom *, NodeSymbolHash, NodeSymbolEqual> nodes(expansions.size(), NodeSymbolHash(symbol_table),
|
||||
NodeSymbolEqual(symbol_table));
|
||||
for (const auto &expansion : expansions) {
|
||||
// TODO: Handle labels and properties from different node atoms.
|
||||
@ -158,11 +145,8 @@ auto ExpansionNodes(const std::vector<Expansion> &expansions,
|
||||
|
||||
} // namespace
|
||||
|
||||
VaryMatchingStart::VaryMatchingStart(Matching matching,
|
||||
const SymbolTable &symbol_table)
|
||||
: matching_(matching),
|
||||
symbol_table_(symbol_table),
|
||||
nodes_(ExpansionNodes(matching.expansions, symbol_table)) {}
|
||||
VaryMatchingStart::VaryMatchingStart(Matching matching, const SymbolTable &symbol_table)
|
||||
: matching_(matching), symbol_table_(symbol_table), nodes_(ExpansionNodes(matching.expansions, symbol_table)) {}
|
||||
|
||||
VaryMatchingStart::iterator::iterator(VaryMatchingStart *self, bool is_done)
|
||||
: self_(self),
|
||||
@ -175,11 +159,9 @@ VaryMatchingStart::iterator::iterator(VaryMatchingStart *self, bool is_done)
|
||||
// Overwrite the original matching expansions with the new ones by
|
||||
// generating it from the first start node.
|
||||
start_nodes_it_ = self_->nodes_.begin();
|
||||
current_matching_.expansions = ExpansionsFrom(
|
||||
**start_nodes_it_, self_->matching_, self_->symbol_table_);
|
||||
current_matching_.expansions = ExpansionsFrom(**start_nodes_it_, self_->matching_, self_->symbol_table_);
|
||||
}
|
||||
DMG_ASSERT(
|
||||
start_nodes_it_ || self_->nodes_.empty(),
|
||||
DMG_ASSERT(start_nodes_it_ || self_->nodes_.empty(),
|
||||
"start_nodes_it_ should only be nullopt when self_->nodes_ is empty");
|
||||
if (is_done) {
|
||||
start_nodes_it_ = self_->nodes_.end();
|
||||
@ -188,9 +170,7 @@ VaryMatchingStart::iterator::iterator(VaryMatchingStart *self, bool is_done)
|
||||
|
||||
VaryMatchingStart::iterator &VaryMatchingStart::iterator::operator++() {
|
||||
if (!start_nodes_it_) {
|
||||
DMG_ASSERT(
|
||||
self_->nodes_.empty(),
|
||||
"start_nodes_it_ should only be nullopt when self_->nodes_ is empty");
|
||||
DMG_ASSERT(self_->nodes_.empty(), "start_nodes_it_ should only be nullopt when self_->nodes_ is empty");
|
||||
start_nodes_it_ = self_->nodes_.end();
|
||||
}
|
||||
if (*start_nodes_it_ == self_->nodes_.end()) {
|
||||
@ -203,13 +183,12 @@ VaryMatchingStart::iterator &VaryMatchingStart::iterator::operator++() {
|
||||
return *this;
|
||||
}
|
||||
const auto &start_node = **start_nodes_it_;
|
||||
current_matching_.expansions =
|
||||
ExpansionsFrom(start_node, self_->matching_, self_->symbol_table_);
|
||||
current_matching_.expansions = ExpansionsFrom(start_node, self_->matching_, self_->symbol_table_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
CartesianProduct<VaryMatchingStart> VaryMultiMatchingStarts(
|
||||
const std::vector<Matching> &matchings, const SymbolTable &symbol_table) {
|
||||
CartesianProduct<VaryMatchingStart> VaryMultiMatchingStarts(const std::vector<Matching> &matchings,
|
||||
const SymbolTable &symbol_table) {
|
||||
std::vector<VaryMatchingStart> variants;
|
||||
variants.reserve(matchings.size());
|
||||
for (const auto &matching : matchings) {
|
||||
@ -218,17 +197,13 @@ CartesianProduct<VaryMatchingStart> VaryMultiMatchingStarts(
|
||||
return MakeCartesianProduct(std::move(variants));
|
||||
}
|
||||
|
||||
VaryQueryPartMatching::VaryQueryPartMatching(SingleQueryPart query_part,
|
||||
const SymbolTable &symbol_table)
|
||||
VaryQueryPartMatching::VaryQueryPartMatching(SingleQueryPart query_part, const SymbolTable &symbol_table)
|
||||
: query_part_(std::move(query_part)),
|
||||
matchings_(VaryMatchingStart(query_part_.matching, symbol_table)),
|
||||
optional_matchings_(
|
||||
VaryMultiMatchingStarts(query_part_.optional_matching, symbol_table)),
|
||||
merge_matchings_(
|
||||
VaryMultiMatchingStarts(query_part_.merge_matching, symbol_table)) {}
|
||||
optional_matchings_(VaryMultiMatchingStarts(query_part_.optional_matching, symbol_table)),
|
||||
merge_matchings_(VaryMultiMatchingStarts(query_part_.merge_matching, symbol_table)) {}
|
||||
|
||||
VaryQueryPartMatching::iterator::iterator(
|
||||
const SingleQueryPart &query_part,
|
||||
VaryQueryPartMatching::iterator::iterator(const SingleQueryPart &query_part,
|
||||
VaryMatchingStart::iterator matchings_begin,
|
||||
VaryMatchingStart::iterator matchings_end,
|
||||
CartesianProduct<VaryMatchingStart>::iterator optional_begin,
|
||||
@ -304,8 +279,7 @@ bool VaryQueryPartMatching::iterator::operator==(const iterator &other) const {
|
||||
// iterators can be at any position.
|
||||
return true;
|
||||
}
|
||||
return matchings_it_ == other.matchings_it_ &&
|
||||
optional_it_ == other.optional_it_ && merge_it_ == other.merge_it_;
|
||||
return matchings_it_ == other.matchings_it_ && optional_it_ == other.optional_it_ && merge_it_ == other.merge_it_;
|
||||
}
|
||||
|
||||
} // namespace query::plan::impl
|
||||
|
@ -39,9 +39,7 @@ class CartesianProduct {
|
||||
|
||||
public:
|
||||
CartesianProduct(std::vector<TSet> sets)
|
||||
: original_sets_(std::move(sets)),
|
||||
begin_(original_sets_.begin()),
|
||||
end_(original_sets_.end()) {}
|
||||
: original_sets_(std::move(sets)), begin_(original_sets_.begin()), end_(original_sets_.end()) {}
|
||||
|
||||
class iterator {
|
||||
public:
|
||||
@ -51,8 +49,7 @@ class CartesianProduct {
|
||||
typedef const std::vector<TElement> &reference;
|
||||
typedef const std::vector<TElement> *pointer;
|
||||
|
||||
explicit iterator(CartesianProduct *self, bool is_done)
|
||||
: self_(self), is_done_(is_done) {
|
||||
explicit iterator(CartesianProduct *self, bool is_done) : self_(self), is_done_(is_done) {
|
||||
if (is_done || self->begin_ == self->end_) {
|
||||
is_done_ = true;
|
||||
return;
|
||||
@ -92,8 +89,7 @@ class CartesianProduct {
|
||||
++sets_it->second;
|
||||
}
|
||||
// We can now collect another product from the modified set iterators.
|
||||
DMG_ASSERT(
|
||||
current_product_.size() == sets_.size(),
|
||||
DMG_ASSERT(current_product_.size() == sets_.size(),
|
||||
"Expected size of current_product_ to match the size of sets_");
|
||||
size_t i = 0;
|
||||
// Change only the prefix of the product, remaining elements (after
|
||||
@ -106,9 +102,7 @@ class CartesianProduct {
|
||||
}
|
||||
|
||||
bool operator==(const iterator &other) const {
|
||||
if (self_->begin_ != other.self_->begin_ ||
|
||||
self_->end_ != other.self_->end_)
|
||||
return false;
|
||||
if (self_->begin_ != other.self_->begin_ || self_->end_ != other.self_->end_) return false;
|
||||
return (is_done_ && other.is_done_) || (sets_ == other.sets_);
|
||||
}
|
||||
|
||||
@ -126,9 +120,7 @@ class CartesianProduct {
|
||||
// Vector of (original_sets_iterator, set_iterator) pairs. The
|
||||
// original_sets_iterator points to the set among all the sets, while the
|
||||
// set_iterator points to an element inside the pointed to set.
|
||||
std::vector<
|
||||
std::pair<decltype(self_->begin_), decltype(self_->begin_->begin())>>
|
||||
sets_;
|
||||
std::vector<std::pair<decltype(self_->begin_), decltype(self_->begin_->begin())>> sets_;
|
||||
// Currently built product from pointed to elements in all sets.
|
||||
std::vector<TElement> current_product_;
|
||||
// Set to true when we have generated all products.
|
||||
@ -153,8 +145,7 @@ namespace impl {
|
||||
|
||||
class NodeSymbolHash {
|
||||
public:
|
||||
explicit NodeSymbolHash(const SymbolTable &symbol_table)
|
||||
: symbol_table_(symbol_table) {}
|
||||
explicit NodeSymbolHash(const SymbolTable &symbol_table) : symbol_table_(symbol_table) {}
|
||||
|
||||
size_t operator()(const NodeAtom *node_atom) const {
|
||||
return std::hash<Symbol>{}(symbol_table_.at(*node_atom->identifier_));
|
||||
@ -166,13 +157,10 @@ class NodeSymbolHash {
|
||||
|
||||
class NodeSymbolEqual {
|
||||
public:
|
||||
explicit NodeSymbolEqual(const SymbolTable &symbol_table)
|
||||
: symbol_table_(symbol_table) {}
|
||||
explicit NodeSymbolEqual(const SymbolTable &symbol_table) : symbol_table_(symbol_table) {}
|
||||
|
||||
bool operator()(const NodeAtom *node_atom1,
|
||||
const NodeAtom *node_atom2) const {
|
||||
return symbol_table_.at(*node_atom1->identifier_) ==
|
||||
symbol_table_.at(*node_atom2->identifier_);
|
||||
bool operator()(const NodeAtom *node_atom1, const NodeAtom *node_atom2) const {
|
||||
return symbol_table_.at(*node_atom1->identifier_) == symbol_table_.at(*node_atom2->identifier_);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -213,9 +201,7 @@ class VaryMatchingStart {
|
||||
// being at the end. When there are no nodes, this iterator needs to produce
|
||||
// a single result, which is the original matching passed in. Setting
|
||||
// start_nodes_it_ to end signifies the end of our iteration.
|
||||
std::optional<std::unordered_set<NodeAtom *, NodeSymbolHash,
|
||||
NodeSymbolEqual>::iterator>
|
||||
start_nodes_it_;
|
||||
std::optional<std::unordered_set<NodeAtom *, NodeSymbolHash, NodeSymbolEqual>::iterator> start_nodes_it_;
|
||||
};
|
||||
|
||||
auto begin() { return iterator(this, false); }
|
||||
@ -231,8 +217,7 @@ class VaryMatchingStart {
|
||||
// Similar to VaryMatchingStart, but varies the starting nodes for all given
|
||||
// matchings. After all matchings produce multiple alternative starts, the
|
||||
// Cartesian product of all of them is returned.
|
||||
CartesianProduct<VaryMatchingStart> VaryMultiMatchingStarts(
|
||||
const std::vector<Matching> &, const SymbolTable &);
|
||||
CartesianProduct<VaryMatchingStart> VaryMultiMatchingStarts(const std::vector<Matching> &, const SymbolTable &);
|
||||
|
||||
// Produces alternative query parts out of a single part by varying how each
|
||||
// graph matching is done.
|
||||
@ -248,12 +233,9 @@ class VaryQueryPartMatching {
|
||||
typedef const SingleQueryPart &reference;
|
||||
typedef const SingleQueryPart *pointer;
|
||||
|
||||
iterator(const SingleQueryPart &, VaryMatchingStart::iterator,
|
||||
VaryMatchingStart::iterator,
|
||||
CartesianProduct<VaryMatchingStart>::iterator,
|
||||
CartesianProduct<VaryMatchingStart>::iterator,
|
||||
CartesianProduct<VaryMatchingStart>::iterator,
|
||||
CartesianProduct<VaryMatchingStart>::iterator);
|
||||
iterator(const SingleQueryPart &, VaryMatchingStart::iterator, VaryMatchingStart::iterator,
|
||||
CartesianProduct<VaryMatchingStart>::iterator, CartesianProduct<VaryMatchingStart>::iterator,
|
||||
CartesianProduct<VaryMatchingStart>::iterator, CartesianProduct<VaryMatchingStart>::iterator);
|
||||
|
||||
iterator &operator++();
|
||||
reference operator*() const { return current_query_part_; }
|
||||
@ -276,14 +258,12 @@ class VaryQueryPartMatching {
|
||||
};
|
||||
|
||||
auto begin() {
|
||||
return iterator(query_part_, matchings_.begin(), matchings_.end(),
|
||||
optional_matchings_.begin(), optional_matchings_.end(),
|
||||
merge_matchings_.begin(), merge_matchings_.end());
|
||||
return iterator(query_part_, matchings_.begin(), matchings_.end(), optional_matchings_.begin(),
|
||||
optional_matchings_.end(), merge_matchings_.begin(), merge_matchings_.end());
|
||||
}
|
||||
auto end() {
|
||||
return iterator(query_part_, matchings_.end(), matchings_.end(),
|
||||
optional_matchings_.end(), optional_matchings_.end(),
|
||||
merge_matchings_.end(), merge_matchings_.end());
|
||||
return iterator(query_part_, matchings_.end(), matchings_.end(), optional_matchings_.end(),
|
||||
optional_matchings_.end(), merge_matchings_.end(), merge_matchings_.end());
|
||||
}
|
||||
|
||||
private:
|
||||
@ -313,21 +293,17 @@ class VariableStartPlanner {
|
||||
|
||||
// Generates different, equivalent query parts by taking different graph
|
||||
// matching routes for each query part.
|
||||
auto VaryQueryMatching(const std::vector<SingleQueryPart> &query_parts,
|
||||
const SymbolTable &symbol_table) {
|
||||
auto VaryQueryMatching(const std::vector<SingleQueryPart> &query_parts, const SymbolTable &symbol_table) {
|
||||
std::vector<impl::VaryQueryPartMatching> alternative_query_parts;
|
||||
alternative_query_parts.reserve(query_parts.size());
|
||||
for (const auto &query_part : query_parts) {
|
||||
alternative_query_parts.emplace_back(
|
||||
impl::VaryQueryPartMatching(query_part, symbol_table));
|
||||
alternative_query_parts.emplace_back(impl::VaryQueryPartMatching(query_part, symbol_table));
|
||||
}
|
||||
return iter::slice(MakeCartesianProduct(std::move(alternative_query_parts)),
|
||||
0UL, FLAGS_query_max_plans);
|
||||
return iter::slice(MakeCartesianProduct(std::move(alternative_query_parts)), 0UL, FLAGS_query_max_plans);
|
||||
}
|
||||
|
||||
public:
|
||||
explicit VariableStartPlanner(TPlanningContext *context)
|
||||
: context_(context) {}
|
||||
explicit VariableStartPlanner(TPlanningContext *context) : context_(context) {}
|
||||
|
||||
/// @brief Generate multiple plans by varying the order of graph traversal.
|
||||
auto Plan(const std::vector<SingleQueryPart> &query_parts) {
|
||||
@ -342,10 +318,8 @@ class VariableStartPlanner {
|
||||
|
||||
/// @brief The result of plan generation is an iterable of roots to multiple
|
||||
/// generated operator trees.
|
||||
using PlanResult = typename std::result_of<decltype (
|
||||
&VariableStartPlanner<TPlanningContext>::Plan)(
|
||||
VariableStartPlanner<TPlanningContext>,
|
||||
std::vector<SingleQueryPart> &)>::type;
|
||||
using PlanResult = typename std::result_of<decltype (&VariableStartPlanner<TPlanningContext>::Plan)(
|
||||
VariableStartPlanner<TPlanningContext>, std::vector<SingleQueryPart> &)>::type;
|
||||
};
|
||||
|
||||
} // namespace query::plan
|
||||
|
@ -19,12 +19,8 @@ class VertexCountCache {
|
||||
VertexCountCache(TDbAccessor *db) : db_(db) {}
|
||||
|
||||
auto NameToLabel(const std::string &name) { return db_->NameToLabel(name); }
|
||||
auto NameToProperty(const std::string &name) {
|
||||
return db_->NameToProperty(name);
|
||||
}
|
||||
auto NameToEdgeType(const std::string &name) {
|
||||
return db_->NameToEdgeType(name);
|
||||
}
|
||||
auto NameToProperty(const std::string &name) { return db_->NameToProperty(name); }
|
||||
auto NameToEdgeType(const std::string &name) { return db_->NameToEdgeType(name); }
|
||||
|
||||
int64_t VerticesCount() {
|
||||
if (!vertices_count_) vertices_count_ = db_->VerticesCount();
|
||||
@ -39,14 +35,12 @@ class VertexCountCache {
|
||||
|
||||
int64_t VerticesCount(storage::LabelId label, storage::PropertyId property) {
|
||||
auto key = std::make_pair(label, property);
|
||||
if (label_property_vertex_count_.find(key) ==
|
||||
label_property_vertex_count_.end())
|
||||
if (label_property_vertex_count_.find(key) == label_property_vertex_count_.end())
|
||||
label_property_vertex_count_[key] = db_->VerticesCount(label, property);
|
||||
return label_property_vertex_count_.at(key);
|
||||
}
|
||||
|
||||
int64_t VerticesCount(storage::LabelId label, storage::PropertyId property,
|
||||
const storage::PropertyValue &value) {
|
||||
int64_t VerticesCount(storage::LabelId label, storage::PropertyId property, const storage::PropertyValue &value) {
|
||||
auto label_prop = std::make_pair(label, property);
|
||||
auto &value_vertex_count = property_value_vertex_count_[label_prop];
|
||||
// TODO: Why do we even need TypedValue in this whole file?
|
||||
@ -56,25 +50,20 @@ class VertexCountCache {
|
||||
return value_vertex_count.at(tv_value);
|
||||
}
|
||||
|
||||
int64_t VerticesCount(
|
||||
storage::LabelId label, storage::PropertyId property,
|
||||
int64_t VerticesCount(storage::LabelId label, storage::PropertyId property,
|
||||
const std::optional<utils::Bound<storage::PropertyValue>> &lower,
|
||||
const std::optional<utils::Bound<storage::PropertyValue>> &upper) {
|
||||
auto label_prop = std::make_pair(label, property);
|
||||
auto &bounds_vertex_count = property_bounds_vertex_count_[label_prop];
|
||||
BoundsKey bounds = std::make_pair(lower, upper);
|
||||
if (bounds_vertex_count.find(bounds) == bounds_vertex_count.end())
|
||||
bounds_vertex_count[bounds] =
|
||||
db_->VerticesCount(label, property, lower, upper);
|
||||
bounds_vertex_count[bounds] = db_->VerticesCount(label, property, lower, upper);
|
||||
return bounds_vertex_count.at(bounds);
|
||||
}
|
||||
|
||||
bool LabelIndexExists(storage::LabelId label) {
|
||||
return db_->LabelIndexExists(label);
|
||||
}
|
||||
bool LabelIndexExists(storage::LabelId label) { return db_->LabelIndexExists(label); }
|
||||
|
||||
bool LabelPropertyIndexExists(storage::LabelId label,
|
||||
storage::PropertyId property) {
|
||||
bool LabelPropertyIndexExists(storage::LabelId label, storage::PropertyId property) {
|
||||
return db_->LabelPropertyIndexExists(label, property);
|
||||
}
|
||||
|
||||
@ -83,8 +72,7 @@ class VertexCountCache {
|
||||
|
||||
struct LabelPropertyHash {
|
||||
size_t operator()(const LabelPropertyKey &key) const {
|
||||
return utils::HashCombine<storage::LabelId, storage::PropertyId>{}(
|
||||
key.first, key.second);
|
||||
return utils::HashCombine<storage::LabelId, storage::PropertyId>{}(key.first, key.second);
|
||||
}
|
||||
};
|
||||
|
||||
@ -107,11 +95,8 @@ class VertexCountCache {
|
||||
|
||||
struct BoundsEqual {
|
||||
bool operator()(const BoundsKey &a, const BoundsKey &b) const {
|
||||
auto bound_equal = [](const auto &maybe_bound_a,
|
||||
const auto &maybe_bound_b) {
|
||||
if (maybe_bound_a && maybe_bound_b &&
|
||||
maybe_bound_a->type() != maybe_bound_b->type())
|
||||
return false;
|
||||
auto bound_equal = [](const auto &maybe_bound_a, const auto &maybe_bound_b) {
|
||||
if (maybe_bound_a && maybe_bound_b && maybe_bound_a->type() != maybe_bound_b->type()) return false;
|
||||
query::TypedValue bound_a;
|
||||
query::TypedValue bound_b;
|
||||
if (maybe_bound_a) bound_a = TypedValue(maybe_bound_a->value());
|
||||
@ -125,17 +110,13 @@ class VertexCountCache {
|
||||
TDbAccessor *db_;
|
||||
std::optional<int64_t> vertices_count_;
|
||||
std::unordered_map<storage::LabelId, int64_t> label_vertex_count_;
|
||||
std::unordered_map<LabelPropertyKey, int64_t, LabelPropertyHash>
|
||||
label_property_vertex_count_;
|
||||
std::unordered_map<LabelPropertyKey, int64_t, LabelPropertyHash> label_property_vertex_count_;
|
||||
std::unordered_map<
|
||||
LabelPropertyKey,
|
||||
std::unordered_map<query::TypedValue, int64_t, query::TypedValue::Hash,
|
||||
query::TypedValue::BoolEqual>,
|
||||
std::unordered_map<query::TypedValue, int64_t, query::TypedValue::Hash, query::TypedValue::BoolEqual>,
|
||||
LabelPropertyHash>
|
||||
property_value_vertex_count_;
|
||||
std::unordered_map<
|
||||
LabelPropertyKey,
|
||||
std::unordered_map<BoundsKey, int64_t, BoundsHash, BoundsEqual>,
|
||||
std::unordered_map<LabelPropertyKey, std::unordered_map<BoundsKey, int64_t, BoundsHash, BoundsEqual>,
|
||||
LabelPropertyHash>
|
||||
property_bounds_vertex_count_;
|
||||
};
|
||||
|
@ -42,8 +42,7 @@ class CypherType {
|
||||
virtual const NullableType *AsNullableType() const { return nullptr; }
|
||||
};
|
||||
|
||||
using CypherTypePtr =
|
||||
std::unique_ptr<CypherType, std::function<void(CypherType *)>>;
|
||||
using CypherTypePtr = std::unique_ptr<CypherType, std::function<void(CypherType *)>>;
|
||||
|
||||
// Simple Types
|
||||
|
||||
@ -51,65 +50,45 @@ class AnyType : public CypherType {
|
||||
public:
|
||||
std::string_view GetPresentableName() const override { return "ANY"; }
|
||||
|
||||
bool SatisfiesType(const mgp_value &value) const override {
|
||||
return !mgp_value_is_null(&value);
|
||||
}
|
||||
bool SatisfiesType(const mgp_value &value) const override { return !mgp_value_is_null(&value); }
|
||||
|
||||
bool SatisfiesType(const query::TypedValue &value) const override {
|
||||
return !value.IsNull();
|
||||
}
|
||||
bool SatisfiesType(const query::TypedValue &value) const override { return !value.IsNull(); }
|
||||
};
|
||||
|
||||
class BoolType : public CypherType {
|
||||
public:
|
||||
std::string_view GetPresentableName() const override { return "BOOLEAN"; }
|
||||
|
||||
bool SatisfiesType(const mgp_value &value) const override {
|
||||
return mgp_value_is_bool(&value);
|
||||
}
|
||||
bool SatisfiesType(const mgp_value &value) const override { return mgp_value_is_bool(&value); }
|
||||
|
||||
bool SatisfiesType(const query::TypedValue &value) const override {
|
||||
return value.IsBool();
|
||||
}
|
||||
bool SatisfiesType(const query::TypedValue &value) const override { return value.IsBool(); }
|
||||
};
|
||||
|
||||
class StringType : public CypherType {
|
||||
public:
|
||||
std::string_view GetPresentableName() const override { return "STRING"; }
|
||||
|
||||
bool SatisfiesType(const mgp_value &value) const override {
|
||||
return mgp_value_is_string(&value);
|
||||
}
|
||||
bool SatisfiesType(const mgp_value &value) const override { return mgp_value_is_string(&value); }
|
||||
|
||||
bool SatisfiesType(const query::TypedValue &value) const override {
|
||||
return value.IsString();
|
||||
}
|
||||
bool SatisfiesType(const query::TypedValue &value) const override { return value.IsString(); }
|
||||
};
|
||||
|
||||
class IntType : public CypherType {
|
||||
public:
|
||||
std::string_view GetPresentableName() const override { return "INTEGER"; }
|
||||
|
||||
bool SatisfiesType(const mgp_value &value) const override {
|
||||
return mgp_value_is_int(&value);
|
||||
}
|
||||
bool SatisfiesType(const mgp_value &value) const override { return mgp_value_is_int(&value); }
|
||||
|
||||
bool SatisfiesType(const query::TypedValue &value) const override {
|
||||
return value.IsInt();
|
||||
}
|
||||
bool SatisfiesType(const query::TypedValue &value) const override { return value.IsInt(); }
|
||||
};
|
||||
|
||||
class FloatType : public CypherType {
|
||||
public:
|
||||
std::string_view GetPresentableName() const override { return "FLOAT"; }
|
||||
|
||||
bool SatisfiesType(const mgp_value &value) const override {
|
||||
return mgp_value_is_double(&value);
|
||||
}
|
||||
bool SatisfiesType(const mgp_value &value) const override { return mgp_value_is_double(&value); }
|
||||
|
||||
bool SatisfiesType(const query::TypedValue &value) const override {
|
||||
return value.IsDouble();
|
||||
}
|
||||
bool SatisfiesType(const query::TypedValue &value) const override { return value.IsDouble(); }
|
||||
};
|
||||
|
||||
class NumberType : public CypherType {
|
||||
@ -120,50 +99,34 @@ class NumberType : public CypherType {
|
||||
return mgp_value_is_int(&value) || mgp_value_is_double(&value);
|
||||
}
|
||||
|
||||
bool SatisfiesType(const query::TypedValue &value) const override {
|
||||
return value.IsInt() || value.IsDouble();
|
||||
}
|
||||
bool SatisfiesType(const query::TypedValue &value) const override { return value.IsInt() || value.IsDouble(); }
|
||||
};
|
||||
|
||||
class NodeType : public CypherType {
|
||||
public:
|
||||
std::string_view GetPresentableName() const override { return "NODE"; }
|
||||
|
||||
bool SatisfiesType(const mgp_value &value) const override {
|
||||
return mgp_value_is_vertex(&value);
|
||||
}
|
||||
bool SatisfiesType(const mgp_value &value) const override { return mgp_value_is_vertex(&value); }
|
||||
|
||||
bool SatisfiesType(const query::TypedValue &value) const override {
|
||||
return value.IsVertex();
|
||||
}
|
||||
bool SatisfiesType(const query::TypedValue &value) const override { return value.IsVertex(); }
|
||||
};
|
||||
|
||||
class RelationshipType : public CypherType {
|
||||
public:
|
||||
std::string_view GetPresentableName() const override {
|
||||
return "RELATIONSHIP";
|
||||
}
|
||||
std::string_view GetPresentableName() const override { return "RELATIONSHIP"; }
|
||||
|
||||
bool SatisfiesType(const mgp_value &value) const override {
|
||||
return mgp_value_is_edge(&value);
|
||||
}
|
||||
bool SatisfiesType(const mgp_value &value) const override { return mgp_value_is_edge(&value); }
|
||||
|
||||
bool SatisfiesType(const query::TypedValue &value) const override {
|
||||
return value.IsEdge();
|
||||
}
|
||||
bool SatisfiesType(const query::TypedValue &value) const override { return value.IsEdge(); }
|
||||
};
|
||||
|
||||
class PathType : public CypherType {
|
||||
public:
|
||||
std::string_view GetPresentableName() const override { return "PATH"; }
|
||||
|
||||
bool SatisfiesType(const mgp_value &value) const override {
|
||||
return mgp_value_is_path(&value);
|
||||
}
|
||||
bool SatisfiesType(const mgp_value &value) const override { return mgp_value_is_path(&value); }
|
||||
|
||||
bool SatisfiesType(const query::TypedValue &value) const override {
|
||||
return value.IsPath();
|
||||
}
|
||||
bool SatisfiesType(const query::TypedValue &value) const override { return value.IsPath(); }
|
||||
};
|
||||
|
||||
// TODO: There's also Temporal Types, but we currently do not support those.
|
||||
@ -178,8 +141,7 @@ class MapType : public CypherType {
|
||||
std::string_view GetPresentableName() const override { return "MAP"; }
|
||||
|
||||
bool SatisfiesType(const mgp_value &value) const override {
|
||||
return mgp_value_is_map(&value) || mgp_value_is_vertex(&value) ||
|
||||
mgp_value_is_edge(&value);
|
||||
return mgp_value_is_map(&value) || mgp_value_is_vertex(&value) || mgp_value_is_edge(&value);
|
||||
}
|
||||
|
||||
bool SatisfiesType(const query::TypedValue &value) const override {
|
||||
@ -197,14 +159,11 @@ class ListType : public CypherType {
|
||||
/// @throw std::bad_alloc
|
||||
/// @throw std::length_error
|
||||
explicit ListType(CypherTypePtr element_type, utils::MemoryResource *memory)
|
||||
: element_type_(std::move(element_type)),
|
||||
presentable_name_("LIST OF ", memory) {
|
||||
: element_type_(std::move(element_type)), presentable_name_("LIST OF ", memory) {
|
||||
presentable_name_.append(element_type_->GetPresentableName());
|
||||
}
|
||||
|
||||
std::string_view GetPresentableName() const override {
|
||||
return presentable_name_;
|
||||
}
|
||||
std::string_view GetPresentableName() const override { return presentable_name_; }
|
||||
|
||||
bool SatisfiesType(const mgp_value &value) const override {
|
||||
if (!mgp_value_is_list(&value)) return false;
|
||||
@ -239,8 +198,7 @@ class NullableType : public CypherType {
|
||||
const auto *list_type = type_->AsListType();
|
||||
// ListType is specially formatted
|
||||
if (list_type) {
|
||||
presentable_name_.assign("LIST? OF ")
|
||||
.append(list_type->element_type_->GetPresentableName());
|
||||
presentable_name_.assign("LIST? OF ").append(list_type->element_type_->GetPresentableName());
|
||||
} else {
|
||||
presentable_name_.assign(type_->GetPresentableName()).append("?");
|
||||
}
|
||||
@ -252,8 +210,7 @@ class NullableType : public CypherType {
|
||||
/// Otherwise, `type` is wrapped in a new instance of NullableType.
|
||||
/// @throw std::bad_alloc
|
||||
/// @throw std::length_error
|
||||
static CypherTypePtr Create(CypherTypePtr type,
|
||||
utils::MemoryResource *memory) {
|
||||
static CypherTypePtr Create(CypherTypePtr type, utils::MemoryResource *memory) {
|
||||
if (type->AsNullableType()) return type;
|
||||
utils::Allocator<NullableType> alloc(memory);
|
||||
auto *nullable = alloc.allocate(1);
|
||||
@ -268,9 +225,7 @@ class NullableType : public CypherType {
|
||||
});
|
||||
}
|
||||
|
||||
std::string_view GetPresentableName() const override {
|
||||
return presentable_name_;
|
||||
}
|
||||
std::string_view GetPresentableName() const override { return presentable_name_; }
|
||||
|
||||
bool SatisfiesType(const mgp_value &value) const override {
|
||||
return mgp_value_is_null(&value) || type_->SatisfiesType(value);
|
||||
|
@ -20,8 +20,7 @@ void *mgp_alloc(mgp_memory *memory, size_t size_in_bytes) {
|
||||
return mgp_aligned_alloc(memory, size_in_bytes, alignof(std::max_align_t));
|
||||
}
|
||||
|
||||
void *mgp_aligned_alloc(mgp_memory *memory, const size_t size_in_bytes,
|
||||
const size_t alignment) {
|
||||
void *mgp_aligned_alloc(mgp_memory *memory, const size_t size_in_bytes, const size_t alignment) {
|
||||
if (size_in_bytes == 0U || !utils::IsPow2(alignment)) return nullptr;
|
||||
// Simplify alignment by always using values greater or equal to max_align.
|
||||
const size_t alloc_align = std::max(alignment, alignof(std::max_align_t));
|
||||
@ -32,8 +31,7 @@ void *mgp_aligned_alloc(mgp_memory *memory, const size_t size_in_bytes,
|
||||
// just allocate an additional multiple of `alloc_align` of bytes such that
|
||||
// the header fits. `data` will then be aligned after this multiple of bytes.
|
||||
static_assert(std::is_same_v<size_t, uint64_t>);
|
||||
const auto maybe_bytes_for_header =
|
||||
utils::RoundUint64ToMultiple(header_size, alloc_align);
|
||||
const auto maybe_bytes_for_header = utils::RoundUint64ToMultiple(header_size, alloc_align);
|
||||
if (!maybe_bytes_for_header) return nullptr;
|
||||
const size_t bytes_for_header = *maybe_bytes_for_header;
|
||||
const size_t alloc_size = bytes_for_header + size_in_bytes;
|
||||
@ -41,10 +39,8 @@ void *mgp_aligned_alloc(mgp_memory *memory, const size_t size_in_bytes,
|
||||
try {
|
||||
void *ptr = memory->impl->Allocate(alloc_size, alloc_align);
|
||||
char *data = reinterpret_cast<char *>(ptr) + bytes_for_header;
|
||||
std::memcpy(data - sizeof(size_in_bytes), &size_in_bytes,
|
||||
sizeof(size_in_bytes));
|
||||
std::memcpy(data - sizeof(size_in_bytes) - sizeof(alloc_align),
|
||||
&alloc_align, sizeof(alloc_align));
|
||||
std::memcpy(data - sizeof(size_in_bytes), &size_in_bytes, sizeof(size_in_bytes));
|
||||
std::memcpy(data - sizeof(size_in_bytes) - sizeof(alloc_align), &alloc_align, sizeof(alloc_align));
|
||||
return data;
|
||||
} catch (...) {
|
||||
return nullptr;
|
||||
@ -56,17 +52,14 @@ void mgp_free(mgp_memory *memory, void *const p) {
|
||||
char *const data = reinterpret_cast<char *>(p);
|
||||
// Read the header containing size & alignment info.
|
||||
size_t size_in_bytes;
|
||||
std::memcpy(&size_in_bytes, data - sizeof(size_in_bytes),
|
||||
sizeof(size_in_bytes));
|
||||
std::memcpy(&size_in_bytes, data - sizeof(size_in_bytes), sizeof(size_in_bytes));
|
||||
size_t alloc_align;
|
||||
std::memcpy(&alloc_align, data - sizeof(size_in_bytes) - sizeof(alloc_align),
|
||||
sizeof(alloc_align));
|
||||
std::memcpy(&alloc_align, data - sizeof(size_in_bytes) - sizeof(alloc_align), sizeof(alloc_align));
|
||||
// Reconstruct how many bytes we allocated on top of the original request.
|
||||
// We need not check allocation request overflow, since we did so already in
|
||||
// mgp_aligned_alloc.
|
||||
const size_t header_size = sizeof(size_in_bytes) + sizeof(alloc_align);
|
||||
const size_t bytes_for_header =
|
||||
*utils::RoundUint64ToMultiple(header_size, alloc_align);
|
||||
const size_t bytes_for_header = *utils::RoundUint64ToMultiple(header_size, alloc_align);
|
||||
const size_t alloc_size = bytes_for_header + size_in_bytes;
|
||||
// Get the original ptr we allocated.
|
||||
void *const original_ptr = data - bytes_for_header;
|
||||
@ -89,8 +82,7 @@ U *new_mgp_object(utils::MemoryResource *memory, TArgs &&...args) {
|
||||
|
||||
template <class U, class... TArgs>
|
||||
U *new_mgp_object(mgp_memory *memory, TArgs &&...args) {
|
||||
return new_mgp_object<U, TArgs...>(memory->impl,
|
||||
std::forward<TArgs>(args)...);
|
||||
return new_mgp_object<U, TArgs...>(memory->impl, std::forward<TArgs>(args)...);
|
||||
}
|
||||
|
||||
// Assume that deallocation and object destruction never throws. If it does,
|
||||
@ -127,14 +119,12 @@ mgp_value_type FromTypedValueType(query::TypedValue::Type type) {
|
||||
}
|
||||
}
|
||||
|
||||
query::TypedValue ToTypedValue(const mgp_value &val,
|
||||
utils::MemoryResource *memory) {
|
||||
query::TypedValue ToTypedValue(const mgp_value &val, utils::MemoryResource *memory) {
|
||||
switch (mgp_value_get_type(&val)) {
|
||||
case MGP_VALUE_TYPE_NULL:
|
||||
return query::TypedValue(memory);
|
||||
case MGP_VALUE_TYPE_BOOL:
|
||||
return query::TypedValue(static_cast<bool>(mgp_value_get_bool(&val)),
|
||||
memory);
|
||||
return query::TypedValue(static_cast<bool>(mgp_value_get_bool(&val)), memory);
|
||||
case MGP_VALUE_TYPE_INT:
|
||||
return query::TypedValue(mgp_value_get_int(&val), memory);
|
||||
case MGP_VALUE_TYPE_DOUBLE:
|
||||
@ -178,11 +168,9 @@ query::TypedValue ToTypedValue(const mgp_value &val,
|
||||
|
||||
} // namespace
|
||||
|
||||
mgp_value::mgp_value(utils::MemoryResource *m) noexcept
|
||||
: type(MGP_VALUE_TYPE_NULL), memory(m) {}
|
||||
mgp_value::mgp_value(utils::MemoryResource *m) noexcept : type(MGP_VALUE_TYPE_NULL), memory(m) {}
|
||||
|
||||
mgp_value::mgp_value(bool val, utils::MemoryResource *m) noexcept
|
||||
: type(MGP_VALUE_TYPE_BOOL), memory(m), bool_v(val) {}
|
||||
mgp_value::mgp_value(bool val, utils::MemoryResource *m) noexcept : type(MGP_VALUE_TYPE_BOOL), memory(m), bool_v(val) {}
|
||||
|
||||
mgp_value::mgp_value(int64_t val, utils::MemoryResource *m) noexcept
|
||||
: type(MGP_VALUE_TYPE_INT), memory(m), int_v(val) {}
|
||||
@ -195,36 +183,30 @@ mgp_value::mgp_value(const char *val, utils::MemoryResource *m)
|
||||
|
||||
mgp_value::mgp_value(mgp_list *val, utils::MemoryResource *m) noexcept
|
||||
: type(MGP_VALUE_TYPE_LIST), memory(m), list_v(val) {
|
||||
MG_ASSERT(val->GetMemoryResource() == m,
|
||||
"Unable to take ownership of a pointer with different allocator.");
|
||||
MG_ASSERT(val->GetMemoryResource() == m, "Unable to take ownership of a pointer with different allocator.");
|
||||
}
|
||||
|
||||
mgp_value::mgp_value(mgp_map *val, utils::MemoryResource *m) noexcept
|
||||
: type(MGP_VALUE_TYPE_MAP), memory(m), map_v(val) {
|
||||
MG_ASSERT(val->GetMemoryResource() == m,
|
||||
"Unable to take ownership of a pointer with different allocator.");
|
||||
MG_ASSERT(val->GetMemoryResource() == m, "Unable to take ownership of a pointer with different allocator.");
|
||||
}
|
||||
|
||||
mgp_value::mgp_value(mgp_vertex *val, utils::MemoryResource *m) noexcept
|
||||
: type(MGP_VALUE_TYPE_VERTEX), memory(m), vertex_v(val) {
|
||||
MG_ASSERT(val->GetMemoryResource() == m,
|
||||
"Unable to take ownership of a pointer with different allocator.");
|
||||
MG_ASSERT(val->GetMemoryResource() == m, "Unable to take ownership of a pointer with different allocator.");
|
||||
}
|
||||
|
||||
mgp_value::mgp_value(mgp_edge *val, utils::MemoryResource *m) noexcept
|
||||
: type(MGP_VALUE_TYPE_EDGE), memory(m), edge_v(val) {
|
||||
MG_ASSERT(val->GetMemoryResource() == m,
|
||||
"Unable to take ownership of a pointer with different allocator.");
|
||||
MG_ASSERT(val->GetMemoryResource() == m, "Unable to take ownership of a pointer with different allocator.");
|
||||
}
|
||||
|
||||
mgp_value::mgp_value(mgp_path *val, utils::MemoryResource *m) noexcept
|
||||
: type(MGP_VALUE_TYPE_PATH), memory(m), path_v(val) {
|
||||
MG_ASSERT(val->GetMemoryResource() == m,
|
||||
"Unable to take ownership of a pointer with different allocator.");
|
||||
MG_ASSERT(val->GetMemoryResource() == m, "Unable to take ownership of a pointer with different allocator.");
|
||||
}
|
||||
|
||||
mgp_value::mgp_value(const query::TypedValue &tv, const mgp_graph *graph,
|
||||
utils::MemoryResource *m)
|
||||
mgp_value::mgp_value(const query::TypedValue &tv, const mgp_graph *graph, utils::MemoryResource *m)
|
||||
: type(FromTypedValueType(tv.type())), memory(m) {
|
||||
switch (type) {
|
||||
case MGP_VALUE_TYPE_NULL:
|
||||
@ -299,8 +281,7 @@ mgp_value::mgp_value(const query::TypedValue &tv, const mgp_graph *graph,
|
||||
}
|
||||
}
|
||||
|
||||
mgp_value::mgp_value(const storage::PropertyValue &pv, utils::MemoryResource *m)
|
||||
: memory(m) {
|
||||
mgp_value::mgp_value(const storage::PropertyValue &pv, utils::MemoryResource *m) : memory(m) {
|
||||
switch (pv.type()) {
|
||||
case storage::PropertyValue::Type::Null:
|
||||
type = MGP_VALUE_TYPE_NULL;
|
||||
@ -353,8 +334,7 @@ mgp_value::mgp_value(const storage::PropertyValue &pv, utils::MemoryResource *m)
|
||||
}
|
||||
}
|
||||
|
||||
mgp_value::mgp_value(const mgp_value &other, utils::MemoryResource *m)
|
||||
: type(other.type), memory(m) {
|
||||
mgp_value::mgp_value(const mgp_value &other, utils::MemoryResource *m) : type(other.type), memory(m) {
|
||||
switch (other.type) {
|
||||
case MGP_VALUE_TYPE_NULL:
|
||||
break;
|
||||
@ -433,8 +413,7 @@ void DeleteValueMember(mgp_value *value) noexcept {
|
||||
|
||||
} // namespace
|
||||
|
||||
mgp_value::mgp_value(mgp_value &&other, utils::MemoryResource *m)
|
||||
: type(other.type), memory(m) {
|
||||
mgp_value::mgp_value(mgp_value &&other, utils::MemoryResource *m) : type(other.type), memory(m) {
|
||||
switch (other.type) {
|
||||
case MGP_VALUE_TYPE_NULL:
|
||||
break;
|
||||
@ -451,8 +430,7 @@ mgp_value::mgp_value(mgp_value &&other, utils::MemoryResource *m)
|
||||
new (&string_v) utils::pmr::string(std::move(other.string_v), m);
|
||||
break;
|
||||
case MGP_VALUE_TYPE_LIST:
|
||||
static_assert(std::is_pointer_v<decltype(list_v)>,
|
||||
"Expected to move list_v by copying pointers.");
|
||||
static_assert(std::is_pointer_v<decltype(list_v)>, "Expected to move list_v by copying pointers.");
|
||||
if (*other.GetMemoryResource() == *m) {
|
||||
list_v = other.list_v;
|
||||
other.type = MGP_VALUE_TYPE_NULL;
|
||||
@ -462,8 +440,7 @@ mgp_value::mgp_value(mgp_value &&other, utils::MemoryResource *m)
|
||||
}
|
||||
break;
|
||||
case MGP_VALUE_TYPE_MAP:
|
||||
static_assert(std::is_pointer_v<decltype(map_v)>,
|
||||
"Expected to move map_v by copying pointers.");
|
||||
static_assert(std::is_pointer_v<decltype(map_v)>, "Expected to move map_v by copying pointers.");
|
||||
if (*other.GetMemoryResource() == *m) {
|
||||
map_v = other.map_v;
|
||||
other.type = MGP_VALUE_TYPE_NULL;
|
||||
@ -473,8 +450,7 @@ mgp_value::mgp_value(mgp_value &&other, utils::MemoryResource *m)
|
||||
}
|
||||
break;
|
||||
case MGP_VALUE_TYPE_VERTEX:
|
||||
static_assert(std::is_pointer_v<decltype(vertex_v)>,
|
||||
"Expected to move vertex_v by copying pointers.");
|
||||
static_assert(std::is_pointer_v<decltype(vertex_v)>, "Expected to move vertex_v by copying pointers.");
|
||||
if (*other.GetMemoryResource() == *m) {
|
||||
vertex_v = other.vertex_v;
|
||||
other.type = MGP_VALUE_TYPE_NULL;
|
||||
@ -484,8 +460,7 @@ mgp_value::mgp_value(mgp_value &&other, utils::MemoryResource *m)
|
||||
}
|
||||
break;
|
||||
case MGP_VALUE_TYPE_EDGE:
|
||||
static_assert(std::is_pointer_v<decltype(edge_v)>,
|
||||
"Expected to move edge_v by copying pointers.");
|
||||
static_assert(std::is_pointer_v<decltype(edge_v)>, "Expected to move edge_v by copying pointers.");
|
||||
if (*other.GetMemoryResource() == *m) {
|
||||
edge_v = other.edge_v;
|
||||
other.type = MGP_VALUE_TYPE_NULL;
|
||||
@ -495,8 +470,7 @@ mgp_value::mgp_value(mgp_value &&other, utils::MemoryResource *m)
|
||||
}
|
||||
break;
|
||||
case MGP_VALUE_TYPE_PATH:
|
||||
static_assert(std::is_pointer_v<decltype(path_v)>,
|
||||
"Expected to move path_v by copying pointers.");
|
||||
static_assert(std::is_pointer_v<decltype(path_v)>, "Expected to move path_v by copying pointers.");
|
||||
if (*other.GetMemoryResource() == *m) {
|
||||
path_v = other.path_v;
|
||||
other.type = MGP_VALUE_TYPE_NULL;
|
||||
@ -514,21 +488,13 @@ mgp_value::~mgp_value() noexcept { DeleteValueMember(this); }
|
||||
|
||||
void mgp_value_destroy(mgp_value *val) { delete_mgp_object(val); }
|
||||
|
||||
mgp_value *mgp_value_make_null(mgp_memory *memory) {
|
||||
return new_mgp_object<mgp_value>(memory);
|
||||
}
|
||||
mgp_value *mgp_value_make_null(mgp_memory *memory) { return new_mgp_object<mgp_value>(memory); }
|
||||
|
||||
mgp_value *mgp_value_make_bool(int val, mgp_memory *memory) {
|
||||
return new_mgp_object<mgp_value>(memory, val != 0);
|
||||
}
|
||||
mgp_value *mgp_value_make_bool(int val, mgp_memory *memory) { return new_mgp_object<mgp_value>(memory, val != 0); }
|
||||
|
||||
mgp_value *mgp_value_make_int(int64_t val, mgp_memory *memory) {
|
||||
return new_mgp_object<mgp_value>(memory, val);
|
||||
}
|
||||
mgp_value *mgp_value_make_int(int64_t val, mgp_memory *memory) { return new_mgp_object<mgp_value>(memory, val); }
|
||||
|
||||
mgp_value *mgp_value_make_double(double val, mgp_memory *memory) {
|
||||
return new_mgp_object<mgp_value>(memory, val);
|
||||
}
|
||||
mgp_value *mgp_value_make_double(double val, mgp_memory *memory) { return new_mgp_object<mgp_value>(memory, val); }
|
||||
|
||||
mgp_value *mgp_value_make_string(const char *val, mgp_memory *memory) {
|
||||
try {
|
||||
@ -540,67 +506,37 @@ mgp_value *mgp_value_make_string(const char *val, mgp_memory *memory) {
|
||||
}
|
||||
}
|
||||
|
||||
mgp_value *mgp_value_make_list(mgp_list *val) {
|
||||
return new_mgp_object<mgp_value>(val->GetMemoryResource(), val);
|
||||
}
|
||||
mgp_value *mgp_value_make_list(mgp_list *val) { return new_mgp_object<mgp_value>(val->GetMemoryResource(), val); }
|
||||
|
||||
mgp_value *mgp_value_make_map(mgp_map *val) {
|
||||
return new_mgp_object<mgp_value>(val->GetMemoryResource(), val);
|
||||
}
|
||||
mgp_value *mgp_value_make_map(mgp_map *val) { return new_mgp_object<mgp_value>(val->GetMemoryResource(), val); }
|
||||
|
||||
mgp_value *mgp_value_make_vertex(mgp_vertex *val) {
|
||||
return new_mgp_object<mgp_value>(val->GetMemoryResource(), val);
|
||||
}
|
||||
mgp_value *mgp_value_make_vertex(mgp_vertex *val) { return new_mgp_object<mgp_value>(val->GetMemoryResource(), val); }
|
||||
|
||||
mgp_value *mgp_value_make_edge(mgp_edge *val) {
|
||||
return new_mgp_object<mgp_value>(val->GetMemoryResource(), val);
|
||||
}
|
||||
mgp_value *mgp_value_make_edge(mgp_edge *val) { return new_mgp_object<mgp_value>(val->GetMemoryResource(), val); }
|
||||
|
||||
mgp_value *mgp_value_make_path(mgp_path *val) {
|
||||
return new_mgp_object<mgp_value>(val->GetMemoryResource(), val);
|
||||
}
|
||||
mgp_value *mgp_value_make_path(mgp_path *val) { return new_mgp_object<mgp_value>(val->GetMemoryResource(), val); }
|
||||
|
||||
mgp_value_type mgp_value_get_type(const mgp_value *val) { return val->type; }
|
||||
|
||||
int mgp_value_is_null(const mgp_value *val) {
|
||||
return mgp_value_get_type(val) == MGP_VALUE_TYPE_NULL;
|
||||
}
|
||||
int mgp_value_is_null(const mgp_value *val) { return mgp_value_get_type(val) == MGP_VALUE_TYPE_NULL; }
|
||||
|
||||
int mgp_value_is_bool(const mgp_value *val) {
|
||||
return mgp_value_get_type(val) == MGP_VALUE_TYPE_BOOL;
|
||||
}
|
||||
int mgp_value_is_bool(const mgp_value *val) { return mgp_value_get_type(val) == MGP_VALUE_TYPE_BOOL; }
|
||||
|
||||
int mgp_value_is_int(const mgp_value *val) {
|
||||
return mgp_value_get_type(val) == MGP_VALUE_TYPE_INT;
|
||||
}
|
||||
int mgp_value_is_int(const mgp_value *val) { return mgp_value_get_type(val) == MGP_VALUE_TYPE_INT; }
|
||||
|
||||
int mgp_value_is_double(const mgp_value *val) {
|
||||
return mgp_value_get_type(val) == MGP_VALUE_TYPE_DOUBLE;
|
||||
}
|
||||
int mgp_value_is_double(const mgp_value *val) { return mgp_value_get_type(val) == MGP_VALUE_TYPE_DOUBLE; }
|
||||
|
||||
int mgp_value_is_string(const mgp_value *val) {
|
||||
return mgp_value_get_type(val) == MGP_VALUE_TYPE_STRING;
|
||||
}
|
||||
int mgp_value_is_string(const mgp_value *val) { return mgp_value_get_type(val) == MGP_VALUE_TYPE_STRING; }
|
||||
|
||||
int mgp_value_is_list(const mgp_value *val) {
|
||||
return mgp_value_get_type(val) == MGP_VALUE_TYPE_LIST;
|
||||
}
|
||||
int mgp_value_is_list(const mgp_value *val) { return mgp_value_get_type(val) == MGP_VALUE_TYPE_LIST; }
|
||||
|
||||
int mgp_value_is_map(const mgp_value *val) {
|
||||
return mgp_value_get_type(val) == MGP_VALUE_TYPE_MAP;
|
||||
}
|
||||
int mgp_value_is_map(const mgp_value *val) { return mgp_value_get_type(val) == MGP_VALUE_TYPE_MAP; }
|
||||
|
||||
int mgp_value_is_vertex(const mgp_value *val) {
|
||||
return mgp_value_get_type(val) == MGP_VALUE_TYPE_VERTEX;
|
||||
}
|
||||
int mgp_value_is_vertex(const mgp_value *val) { return mgp_value_get_type(val) == MGP_VALUE_TYPE_VERTEX; }
|
||||
|
||||
int mgp_value_is_edge(const mgp_value *val) {
|
||||
return mgp_value_get_type(val) == MGP_VALUE_TYPE_EDGE;
|
||||
}
|
||||
int mgp_value_is_edge(const mgp_value *val) { return mgp_value_get_type(val) == MGP_VALUE_TYPE_EDGE; }
|
||||
|
||||
int mgp_value_is_path(const mgp_value *val) {
|
||||
return mgp_value_get_type(val) == MGP_VALUE_TYPE_PATH;
|
||||
}
|
||||
int mgp_value_is_path(const mgp_value *val) { return mgp_value_get_type(val) == MGP_VALUE_TYPE_PATH; }
|
||||
|
||||
int mgp_value_get_bool(const mgp_value *val) { return val->bool_v ? 1 : 0; }
|
||||
|
||||
@ -608,17 +544,13 @@ int64_t mgp_value_get_int(const mgp_value *val) { return val->int_v; }
|
||||
|
||||
double mgp_value_get_double(const mgp_value *val) { return val->double_v; }
|
||||
|
||||
const char *mgp_value_get_string(const mgp_value *val) {
|
||||
return val->string_v.c_str();
|
||||
}
|
||||
const char *mgp_value_get_string(const mgp_value *val) { return val->string_v.c_str(); }
|
||||
|
||||
const mgp_list *mgp_value_get_list(const mgp_value *val) { return val->list_v; }
|
||||
|
||||
const mgp_map *mgp_value_get_map(const mgp_value *val) { return val->map_v; }
|
||||
|
||||
const mgp_vertex *mgp_value_get_vertex(const mgp_value *val) {
|
||||
return val->vertex_v;
|
||||
}
|
||||
const mgp_vertex *mgp_value_get_vertex(const mgp_value *val) { return val->vertex_v; }
|
||||
|
||||
const mgp_edge *mgp_value_get_edge(const mgp_value *val) { return val->edge_v; }
|
||||
|
||||
@ -656,18 +588,14 @@ int mgp_list_append_extend(mgp_list *list, const mgp_value *val) {
|
||||
|
||||
size_t mgp_list_size(const mgp_list *list) { return list->elems.size(); }
|
||||
|
||||
size_t mgp_list_capacity(const mgp_list *list) {
|
||||
return list->elems.capacity();
|
||||
}
|
||||
size_t mgp_list_capacity(const mgp_list *list) { return list->elems.capacity(); }
|
||||
|
||||
const mgp_value *mgp_list_at(const mgp_list *list, size_t i) {
|
||||
if (i >= mgp_list_size(list)) return nullptr;
|
||||
return &list->elems[i];
|
||||
}
|
||||
|
||||
mgp_map *mgp_map_make_empty(mgp_memory *memory) {
|
||||
return new_mgp_object<mgp_map>(memory);
|
||||
}
|
||||
mgp_map *mgp_map_make_empty(mgp_memory *memory) { return new_mgp_object<mgp_map>(memory); }
|
||||
|
||||
void mgp_map_destroy(mgp_map *map) { delete_mgp_object(map); }
|
||||
|
||||
@ -693,21 +621,15 @@ const mgp_value *mgp_map_at(const mgp_map *map, const char *key) {
|
||||
|
||||
const char *mgp_map_item_key(const mgp_map_item *item) { return item->key; }
|
||||
|
||||
const mgp_value *mgp_map_item_value(const mgp_map_item *item) {
|
||||
return item->value;
|
||||
}
|
||||
const mgp_value *mgp_map_item_value(const mgp_map_item *item) { return item->value; }
|
||||
|
||||
mgp_map_items_iterator *mgp_map_iter_items(const mgp_map *map,
|
||||
mgp_memory *memory) {
|
||||
mgp_map_items_iterator *mgp_map_iter_items(const mgp_map *map, mgp_memory *memory) {
|
||||
return new_mgp_object<mgp_map_items_iterator>(memory, map);
|
||||
}
|
||||
|
||||
void mgp_map_items_iterator_destroy(mgp_map_items_iterator *it) {
|
||||
delete_mgp_object(it);
|
||||
}
|
||||
void mgp_map_items_iterator_destroy(mgp_map_items_iterator *it) { delete_mgp_object(it); }
|
||||
|
||||
const mgp_map_item *mgp_map_items_iterator_get(
|
||||
const mgp_map_items_iterator *it) {
|
||||
const mgp_map_item *mgp_map_items_iterator_get(const mgp_map_items_iterator *it) {
|
||||
if (it->current_it == it->map->items.end()) return nullptr;
|
||||
return &it->current;
|
||||
}
|
||||
@ -720,8 +642,7 @@ const mgp_map_item *mgp_map_items_iterator_next(mgp_map_items_iterator *it) {
|
||||
return &it->current;
|
||||
}
|
||||
|
||||
mgp_path *mgp_path_make_with_start(const mgp_vertex *vertex,
|
||||
mgp_memory *memory) {
|
||||
mgp_path *mgp_path_make_with_start(const mgp_vertex *vertex, mgp_memory *memory) {
|
||||
auto *path = new_mgp_object<mgp_path>(memory);
|
||||
if (!path) return nullptr;
|
||||
try {
|
||||
@ -734,16 +655,14 @@ mgp_path *mgp_path_make_with_start(const mgp_vertex *vertex,
|
||||
}
|
||||
|
||||
mgp_path *mgp_path_copy(const mgp_path *path, mgp_memory *memory) {
|
||||
MG_ASSERT(mgp_path_size(path) == path->vertices.size() - 1,
|
||||
"Invalid mgp_path");
|
||||
MG_ASSERT(mgp_path_size(path) == path->vertices.size() - 1, "Invalid mgp_path");
|
||||
return new_mgp_object<mgp_path>(memory, *path);
|
||||
}
|
||||
|
||||
void mgp_path_destroy(mgp_path *path) { delete_mgp_object(path); }
|
||||
|
||||
int mgp_path_expand(mgp_path *path, const mgp_edge *edge) {
|
||||
MG_ASSERT(mgp_path_size(path) == path->vertices.size() - 1,
|
||||
"Invalid mgp_path");
|
||||
MG_ASSERT(mgp_path_size(path) == path->vertices.size() - 1, "Invalid mgp_path");
|
||||
// Check that the both the last vertex on path and dst_vertex are endpoints of
|
||||
// the given edge.
|
||||
const auto *src_vertex = &path->vertices.back();
|
||||
@ -820,17 +739,15 @@ mgp_result_record *mgp_result_new_record(mgp_result *res) {
|
||||
auto *memory = res->rows.get_allocator().GetMemoryResource();
|
||||
MG_ASSERT(res->signature, "Expected to have a valid signature");
|
||||
try {
|
||||
res->rows.push_back(mgp_result_record{
|
||||
res->signature,
|
||||
utils::pmr::map<utils::pmr::string, query::TypedValue>(memory)});
|
||||
res->rows.push_back(
|
||||
mgp_result_record{res->signature, utils::pmr::map<utils::pmr::string, query::TypedValue>(memory)});
|
||||
} catch (...) {
|
||||
return nullptr;
|
||||
}
|
||||
return &res->rows.back();
|
||||
}
|
||||
|
||||
int mgp_result_record_insert(mgp_result_record *record, const char *field_name,
|
||||
const mgp_value *val) {
|
||||
int mgp_result_record_insert(mgp_result_record *record, const char *field_name, const mgp_value *val) {
|
||||
auto *memory = record->values.get_allocator().GetMemoryResource();
|
||||
// Validate field_name & val satisfy the procedure's result signature.
|
||||
MG_ASSERT(record->signature, "Expected to have a valid signature");
|
||||
@ -848,12 +765,9 @@ int mgp_result_record_insert(mgp_result_record *record, const char *field_name,
|
||||
|
||||
/// Graph Constructs
|
||||
|
||||
void mgp_properties_iterator_destroy(mgp_properties_iterator *it) {
|
||||
delete_mgp_object(it);
|
||||
}
|
||||
void mgp_properties_iterator_destroy(mgp_properties_iterator *it) { delete_mgp_object(it); }
|
||||
|
||||
const mgp_property *mgp_properties_iterator_get(
|
||||
const mgp_properties_iterator *it) {
|
||||
const mgp_property *mgp_properties_iterator_get(const mgp_properties_iterator *it) {
|
||||
if (it->current) return &it->property;
|
||||
return nullptr;
|
||||
}
|
||||
@ -878,9 +792,7 @@ const mgp_property *mgp_properties_iterator_next(mgp_properties_iterator *it) {
|
||||
return nullptr;
|
||||
}
|
||||
it->current.emplace(
|
||||
utils::pmr::string(
|
||||
it->graph->impl->PropertyToName(it->current_it->first),
|
||||
it->GetMemoryResource()),
|
||||
utils::pmr::string(it->graph->impl->PropertyToName(it->current_it->first), it->GetMemoryResource()),
|
||||
mgp_value(it->current_it->second, it->GetMemoryResource()));
|
||||
it->property.name = it->current->first.c_str();
|
||||
it->property.value = &it->current->second;
|
||||
@ -891,19 +803,13 @@ const mgp_property *mgp_properties_iterator_next(mgp_properties_iterator *it) {
|
||||
}
|
||||
}
|
||||
|
||||
mgp_vertex_id mgp_vertex_get_id(const mgp_vertex *v) {
|
||||
return mgp_vertex_id{.as_int = v->impl.Gid().AsInt()};
|
||||
}
|
||||
mgp_vertex_id mgp_vertex_get_id(const mgp_vertex *v) { return mgp_vertex_id{.as_int = v->impl.Gid().AsInt()}; }
|
||||
|
||||
mgp_vertex *mgp_vertex_copy(const mgp_vertex *v, mgp_memory *memory) {
|
||||
return new_mgp_object<mgp_vertex>(memory, *v);
|
||||
}
|
||||
mgp_vertex *mgp_vertex_copy(const mgp_vertex *v, mgp_memory *memory) { return new_mgp_object<mgp_vertex>(memory, *v); }
|
||||
|
||||
void mgp_vertex_destroy(mgp_vertex *v) { delete_mgp_object(v); }
|
||||
|
||||
int mgp_vertex_equal(const mgp_vertex *a, const mgp_vertex *b) {
|
||||
return a->impl == b->impl ? 1 : 0;
|
||||
}
|
||||
int mgp_vertex_equal(const mgp_vertex *a, const mgp_vertex *b) { return a->impl == b->impl ? 1 : 0; }
|
||||
|
||||
size_t mgp_vertex_labels_count(const mgp_vertex *v) {
|
||||
auto maybe_labels = v->impl.Labels(v->graph->view);
|
||||
@ -940,8 +846,7 @@ mgp_label mgp_vertex_label_at(const mgp_vertex *v, size_t i) {
|
||||
}
|
||||
if (i >= maybe_labels->size()) return mgp_label{nullptr};
|
||||
const auto &label = (*maybe_labels)[i];
|
||||
static_assert(
|
||||
std::is_lvalue_reference_v<decltype(v->graph->impl->LabelToName(label))>,
|
||||
static_assert(std::is_lvalue_reference_v<decltype(v->graph->impl->LabelToName(label))>,
|
||||
"Expected LabelToName to return a pointer or reference, so we "
|
||||
"don't have to take a copy and manage memory.");
|
||||
const auto &name = v->graph->impl->LabelToName(label);
|
||||
@ -979,12 +884,9 @@ int mgp_vertex_has_label_named(const mgp_vertex *v, const char *name) {
|
||||
return *maybe_has_label;
|
||||
}
|
||||
|
||||
int mgp_vertex_has_label(const mgp_vertex *v, mgp_label label) {
|
||||
return mgp_vertex_has_label_named(v, label.name);
|
||||
}
|
||||
int mgp_vertex_has_label(const mgp_vertex *v, mgp_label label) { return mgp_vertex_has_label_named(v, label.name); }
|
||||
|
||||
mgp_value *mgp_vertex_get_property(const mgp_vertex *v, const char *name,
|
||||
mgp_memory *memory) {
|
||||
mgp_value *mgp_vertex_get_property(const mgp_vertex *v, const char *name, mgp_memory *memory) {
|
||||
try {
|
||||
const auto &key = v->graph->impl->NameToProperty(name);
|
||||
auto maybe_prop = v->impl.GetProperty(v->graph->view, key);
|
||||
@ -1009,8 +911,7 @@ mgp_value *mgp_vertex_get_property(const mgp_vertex *v, const char *name,
|
||||
}
|
||||
}
|
||||
|
||||
mgp_properties_iterator *mgp_vertex_iter_properties(const mgp_vertex *v,
|
||||
mgp_memory *memory) {
|
||||
mgp_properties_iterator *mgp_vertex_iter_properties(const mgp_vertex *v, mgp_memory *memory) {
|
||||
// NOTE: This copies the whole properties into the iterator.
|
||||
// TODO: Think of a good way to avoid the copy which doesn't just rely on some
|
||||
// assumption that storage may return a pointer to the property store. This
|
||||
@ -1030,8 +931,7 @@ mgp_properties_iterator *mgp_vertex_iter_properties(const mgp_vertex *v,
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return new_mgp_object<mgp_properties_iterator>(memory, v->graph,
|
||||
std::move(*maybe_props));
|
||||
return new_mgp_object<mgp_properties_iterator>(memory, v->graph, std::move(*maybe_props));
|
||||
} catch (...) {
|
||||
// Since we are copying stuff, we may get std::bad_alloc. Hopefully, no
|
||||
// other exceptions are possible, but catch them all just in case.
|
||||
@ -1039,12 +939,9 @@ mgp_properties_iterator *mgp_vertex_iter_properties(const mgp_vertex *v,
|
||||
}
|
||||
}
|
||||
|
||||
void mgp_edges_iterator_destroy(mgp_edges_iterator *it) {
|
||||
delete_mgp_object(it);
|
||||
}
|
||||
void mgp_edges_iterator_destroy(mgp_edges_iterator *it) { delete_mgp_object(it); }
|
||||
|
||||
mgp_edges_iterator *mgp_vertex_iter_in_edges(const mgp_vertex *v,
|
||||
mgp_memory *memory) {
|
||||
mgp_edges_iterator *mgp_vertex_iter_in_edges(const mgp_vertex *v, mgp_memory *memory) {
|
||||
auto *it = new_mgp_object<mgp_edges_iterator>(memory, *v);
|
||||
if (!it) return nullptr;
|
||||
try {
|
||||
@ -1076,8 +973,7 @@ mgp_edges_iterator *mgp_vertex_iter_in_edges(const mgp_vertex *v,
|
||||
return it;
|
||||
}
|
||||
|
||||
mgp_edges_iterator *mgp_vertex_iter_out_edges(const mgp_vertex *v,
|
||||
mgp_memory *memory) {
|
||||
mgp_edges_iterator *mgp_vertex_iter_out_edges(const mgp_vertex *v, mgp_memory *memory) {
|
||||
auto *it = new_mgp_object<mgp_edges_iterator>(memory, *v);
|
||||
if (!it) return nullptr;
|
||||
try {
|
||||
@ -1127,8 +1023,7 @@ const mgp_edge *mgp_edges_iterator_next(mgp_edges_iterator *it) {
|
||||
it->current_e = std::nullopt;
|
||||
return nullptr;
|
||||
}
|
||||
it->current_e.emplace(**impl_it, it->source_vertex.graph,
|
||||
it->GetMemoryResource());
|
||||
it->current_e.emplace(**impl_it, it->source_vertex.graph, it->GetMemoryResource());
|
||||
return &*it->current_e;
|
||||
};
|
||||
try {
|
||||
@ -1144,9 +1039,7 @@ const mgp_edge *mgp_edges_iterator_next(mgp_edges_iterator *it) {
|
||||
}
|
||||
}
|
||||
|
||||
mgp_edge_id mgp_edge_get_id(const mgp_edge *e) {
|
||||
return mgp_edge_id{.as_int = e->impl.Gid().AsInt()};
|
||||
}
|
||||
mgp_edge_id mgp_edge_get_id(const mgp_edge *e) { return mgp_edge_id{.as_int = e->impl.Gid().AsInt()}; }
|
||||
|
||||
mgp_edge *mgp_edge_copy(const mgp_edge *v, mgp_memory *memory) {
|
||||
return new_mgp_object<mgp_edge>(memory, v->impl, v->from.graph);
|
||||
@ -1154,15 +1047,11 @@ mgp_edge *mgp_edge_copy(const mgp_edge *v, mgp_memory *memory) {
|
||||
|
||||
void mgp_edge_destroy(mgp_edge *e) { delete_mgp_object(e); }
|
||||
|
||||
int mgp_edge_equal(const struct mgp_edge *e1, const struct mgp_edge *e2) {
|
||||
return e1->impl == e2->impl ? 1 : 0;
|
||||
}
|
||||
int mgp_edge_equal(const struct mgp_edge *e1, const struct mgp_edge *e2) { return e1->impl == e2->impl ? 1 : 0; }
|
||||
|
||||
mgp_edge_type mgp_edge_get_type(const mgp_edge *e) {
|
||||
const auto &name = e->from.graph->impl->EdgeTypeToName(e->impl.EdgeType());
|
||||
static_assert(
|
||||
std::is_lvalue_reference_v<decltype(
|
||||
e->from.graph->impl->EdgeTypeToName(e->impl.EdgeType()))>,
|
||||
static_assert(std::is_lvalue_reference_v<decltype(e->from.graph->impl->EdgeTypeToName(e->impl.EdgeType()))>,
|
||||
"Expected EdgeTypeToName to return a pointer or reference, so we "
|
||||
"don't have to take a copy and manage memory.");
|
||||
return mgp_edge_type{name.c_str()};
|
||||
@ -1172,8 +1061,7 @@ const mgp_vertex *mgp_edge_get_from(const mgp_edge *e) { return &e->from; }
|
||||
|
||||
const mgp_vertex *mgp_edge_get_to(const mgp_edge *e) { return &e->to; }
|
||||
|
||||
mgp_value *mgp_edge_get_property(const mgp_edge *e, const char *name,
|
||||
mgp_memory *memory) {
|
||||
mgp_value *mgp_edge_get_property(const mgp_edge *e, const char *name, mgp_memory *memory) {
|
||||
try {
|
||||
const auto &key = e->from.graph->impl->NameToProperty(name);
|
||||
auto view = e->from.graph->view;
|
||||
@ -1199,8 +1087,7 @@ mgp_value *mgp_edge_get_property(const mgp_edge *e, const char *name,
|
||||
}
|
||||
}
|
||||
|
||||
mgp_properties_iterator *mgp_edge_iter_properties(const mgp_edge *e,
|
||||
mgp_memory *memory) {
|
||||
mgp_properties_iterator *mgp_edge_iter_properties(const mgp_edge *e, mgp_memory *memory) {
|
||||
// NOTE: This copies the whole properties into iterator.
|
||||
// TODO: Think of a good way to avoid the copy which doesn't just rely on some
|
||||
// assumption that storage may return a pointer to the property store. This
|
||||
@ -1221,8 +1108,7 @@ mgp_properties_iterator *mgp_edge_iter_properties(const mgp_edge *e,
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return new_mgp_object<mgp_properties_iterator>(memory, e->from.graph,
|
||||
std::move(*maybe_props));
|
||||
return new_mgp_object<mgp_properties_iterator>(memory, e->from.graph, std::move(*maybe_props));
|
||||
} catch (...) {
|
||||
// Since we are copying stuff, we may get std::bad_alloc. Hopefully, no
|
||||
// other exceptions are possible, but catch them all just in case.
|
||||
@ -1230,21 +1116,15 @@ mgp_properties_iterator *mgp_edge_iter_properties(const mgp_edge *e,
|
||||
}
|
||||
}
|
||||
|
||||
mgp_vertex *mgp_graph_get_vertex_by_id(const mgp_graph *graph, mgp_vertex_id id,
|
||||
mgp_memory *memory) {
|
||||
auto maybe_vertex =
|
||||
graph->impl->FindVertex(storage::Gid::FromInt(id.as_int), graph->view);
|
||||
if (maybe_vertex)
|
||||
return new_mgp_object<mgp_vertex>(memory, *maybe_vertex, graph);
|
||||
mgp_vertex *mgp_graph_get_vertex_by_id(const mgp_graph *graph, mgp_vertex_id id, mgp_memory *memory) {
|
||||
auto maybe_vertex = graph->impl->FindVertex(storage::Gid::FromInt(id.as_int), graph->view);
|
||||
if (maybe_vertex) return new_mgp_object<mgp_vertex>(memory, *maybe_vertex, graph);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void mgp_vertices_iterator_destroy(mgp_vertices_iterator *it) {
|
||||
delete_mgp_object(it);
|
||||
}
|
||||
void mgp_vertices_iterator_destroy(mgp_vertices_iterator *it) { delete_mgp_object(it); }
|
||||
|
||||
mgp_vertices_iterator *mgp_graph_iter_vertices(const mgp_graph *graph,
|
||||
mgp_memory *memory) {
|
||||
mgp_vertices_iterator *mgp_graph_iter_vertices(const mgp_graph *graph, mgp_memory *memory) {
|
||||
try {
|
||||
return new_mgp_object<mgp_vertices_iterator>(memory, graph);
|
||||
} catch (...) {
|
||||
@ -1337,8 +1217,7 @@ const mgp_type *mgp_type_node() {
|
||||
|
||||
const mgp_type *mgp_type_relationship() {
|
||||
static RelationshipType impl;
|
||||
static mgp_type relationship_type{
|
||||
CypherTypePtr(&impl, NoOpCypherTypeDeleter)};
|
||||
static mgp_type relationship_type{CypherTypePtr(&impl, NoOpCypherTypeDeleter)};
|
||||
return &relationship_type;
|
||||
}
|
||||
|
||||
@ -1351,8 +1230,7 @@ const mgp_type *mgp_type_path() {
|
||||
const mgp_type *mgp_type_list(const mgp_type *type) {
|
||||
if (!type) return nullptr;
|
||||
// Maps `type` to corresponding instance of ListType.
|
||||
static utils::pmr::map<const mgp_type *, mgp_type> list_types(
|
||||
utils::NewDeleteResource());
|
||||
static utils::pmr::map<const mgp_type *, mgp_type> list_types(utils::NewDeleteResource());
|
||||
static utils::SpinLock lock;
|
||||
std::lock_guard<utils::SpinLock> guard(lock);
|
||||
auto found_it = list_types.find(type);
|
||||
@ -1362,11 +1240,8 @@ const mgp_type *mgp_type_list(const mgp_type *type) {
|
||||
CypherTypePtr impl(
|
||||
alloc.new_object<ListType>(
|
||||
// Just obtain the pointer to original impl, don't own it.
|
||||
CypherTypePtr(type->impl.get(), NoOpCypherTypeDeleter),
|
||||
alloc.GetMemoryResource()),
|
||||
[alloc](CypherType *base_ptr) mutable {
|
||||
alloc.delete_object(static_cast<ListType *>(base_ptr));
|
||||
});
|
||||
CypherTypePtr(type->impl.get(), NoOpCypherTypeDeleter), alloc.GetMemoryResource()),
|
||||
[alloc](CypherType *base_ptr) mutable { alloc.delete_object(static_cast<ListType *>(base_ptr)); });
|
||||
return &list_types.emplace(type, mgp_type{std::move(impl)}).first->second;
|
||||
} catch (const std::bad_alloc &) {
|
||||
return nullptr;
|
||||
@ -1376,34 +1251,28 @@ const mgp_type *mgp_type_list(const mgp_type *type) {
|
||||
const mgp_type *mgp_type_nullable(const mgp_type *type) {
|
||||
if (!type) return nullptr;
|
||||
// Maps `type` to corresponding instance of NullableType.
|
||||
static utils::pmr::map<const mgp_type *, mgp_type> gNullableTypes(
|
||||
utils::NewDeleteResource());
|
||||
static utils::pmr::map<const mgp_type *, mgp_type> gNullableTypes(utils::NewDeleteResource());
|
||||
static utils::SpinLock lock;
|
||||
std::lock_guard<utils::SpinLock> guard(lock);
|
||||
auto found_it = gNullableTypes.find(type);
|
||||
if (found_it != gNullableTypes.end()) return &found_it->second;
|
||||
try {
|
||||
auto alloc = gNullableTypes.get_allocator();
|
||||
auto impl = NullableType::Create(
|
||||
CypherTypePtr(type->impl.get(), NoOpCypherTypeDeleter),
|
||||
alloc.GetMemoryResource());
|
||||
return &gNullableTypes.emplace(type, mgp_type{std::move(impl)})
|
||||
.first->second;
|
||||
auto impl = NullableType::Create(CypherTypePtr(type->impl.get(), NoOpCypherTypeDeleter), alloc.GetMemoryResource());
|
||||
return &gNullableTypes.emplace(type, mgp_type{std::move(impl)}).first->second;
|
||||
} catch (const std::bad_alloc &) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
mgp_proc *mgp_module_add_read_procedure(mgp_module *module, const char *name,
|
||||
mgp_proc_cb cb) {
|
||||
mgp_proc *mgp_module_add_read_procedure(mgp_module *module, const char *name, mgp_proc_cb cb) {
|
||||
if (!module || !cb) return nullptr;
|
||||
if (!IsValidIdentifierName(name)) return nullptr;
|
||||
if (module->procedures.find(name) != module->procedures.end()) return nullptr;
|
||||
try {
|
||||
auto *memory = module->procedures.get_allocator().GetMemoryResource();
|
||||
// May throw std::bad_alloc, std::length_error
|
||||
return &module->procedures.emplace(name, mgp_proc(name, cb, memory))
|
||||
.first->second;
|
||||
return &module->procedures.emplace(name, mgp_proc(name, cb, memory)).first->second;
|
||||
} catch (...) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -1421,8 +1290,7 @@ int mgp_proc_add_arg(mgp_proc *proc, const char *name, const mgp_type *type) {
|
||||
}
|
||||
}
|
||||
|
||||
int mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, const mgp_type *type,
|
||||
const mgp_value *default_value) {
|
||||
int mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, const mgp_type *type, const mgp_value *default_value) {
|
||||
if (!proc || !type || !default_value) return 0;
|
||||
if (!IsValidIdentifierName(name)) return 0;
|
||||
switch (mgp_value_get_type(default_value)) {
|
||||
@ -1444,8 +1312,7 @@ int mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, const mgp_type *type,
|
||||
if (!type->impl->SatisfiesType(*default_value)) return 0;
|
||||
auto *memory = proc->opt_args.get_allocator().GetMemoryResource();
|
||||
try {
|
||||
proc->opt_args.emplace_back(utils::pmr::string(name, memory),
|
||||
type->impl.get(),
|
||||
proc->opt_args.emplace_back(utils::pmr::string(name, memory), type->impl.get(),
|
||||
ToTypedValue(*default_value, memory));
|
||||
return 1;
|
||||
} catch (...) {
|
||||
@ -1455,15 +1322,13 @@ int mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, const mgp_type *type,
|
||||
|
||||
namespace {
|
||||
|
||||
int AddResultToProc(mgp_proc *proc, const char *name, const mgp_type *type,
|
||||
bool is_deprecated) {
|
||||
int AddResultToProc(mgp_proc *proc, const char *name, const mgp_type *type, bool is_deprecated) {
|
||||
if (!proc || !type) return 0;
|
||||
if (!IsValidIdentifierName(name)) return 0;
|
||||
if (proc->results.find(name) != proc->results.end()) return 0;
|
||||
try {
|
||||
auto *memory = proc->results.get_allocator().GetMemoryResource();
|
||||
proc->results.emplace(utils::pmr::string(name, memory),
|
||||
std::make_pair(type->impl.get(), is_deprecated));
|
||||
proc->results.emplace(utils::pmr::string(name, memory), std::make_pair(type->impl.get(), is_deprecated));
|
||||
return 1;
|
||||
} catch (...) {
|
||||
return 0;
|
||||
@ -1472,13 +1337,11 @@ int AddResultToProc(mgp_proc *proc, const char *name, const mgp_type *type,
|
||||
|
||||
} // namespace
|
||||
|
||||
int mgp_proc_add_result(mgp_proc *proc, const char *name,
|
||||
const mgp_type *type) {
|
||||
int mgp_proc_add_result(mgp_proc *proc, const char *name, const mgp_type *type) {
|
||||
return AddResultToProc(proc, name, type, false);
|
||||
}
|
||||
|
||||
int mgp_proc_add_deprecated_result(mgp_proc *proc, const char *name,
|
||||
const mgp_type *type) {
|
||||
int mgp_proc_add_deprecated_result(mgp_proc *proc, const char *name, const mgp_type *type) {
|
||||
return AddResultToProc(proc, name, type, true);
|
||||
}
|
||||
|
||||
@ -1509,14 +1372,12 @@ std::ostream &PrintValue(const TypedValue &value, std::ostream *stream) {
|
||||
return (*stream) << utils::Escape(value.ValueString());
|
||||
case TypedValue::Type::List:
|
||||
(*stream) << "[";
|
||||
utils::PrintIterable(
|
||||
*stream, value.ValueList(), ", ",
|
||||
utils::PrintIterable(*stream, value.ValueList(), ", ",
|
||||
[](auto &stream, const auto &elem) { PrintValue(elem, &stream); });
|
||||
return (*stream) << "]";
|
||||
case TypedValue::Type::Map:
|
||||
(*stream) << "{";
|
||||
utils::PrintIterable(*stream, value.ValueMap(), ", ",
|
||||
[](auto &stream, const auto &item) {
|
||||
utils::PrintIterable(*stream, value.ValueMap(), ", ", [](auto &stream, const auto &item) {
|
||||
// Map keys are not escaped strings.
|
||||
stream << item.first << ": ";
|
||||
PrintValue(item.second, &stream);
|
||||
@ -1533,20 +1394,16 @@ std::ostream &PrintValue(const TypedValue &value, std::ostream *stream) {
|
||||
|
||||
void PrintProcSignature(const mgp_proc &proc, std::ostream *stream) {
|
||||
(*stream) << proc.name << "(";
|
||||
utils::PrintIterable(
|
||||
*stream, proc.args, ", ", [](auto &stream, const auto &arg) {
|
||||
utils::PrintIterable(*stream, proc.args, ", ", [](auto &stream, const auto &arg) {
|
||||
stream << arg.first << " :: " << arg.second->GetPresentableName();
|
||||
});
|
||||
if (!proc.args.empty() && !proc.opt_args.empty()) (*stream) << ", ";
|
||||
utils::PrintIterable(
|
||||
*stream, proc.opt_args, ", ", [](auto &stream, const auto &arg) {
|
||||
utils::PrintIterable(*stream, proc.opt_args, ", ", [](auto &stream, const auto &arg) {
|
||||
stream << std::get<0>(arg) << " = ";
|
||||
PrintValue(std::get<2>(arg), &stream)
|
||||
<< " :: " << std::get<1>(arg)->GetPresentableName();
|
||||
PrintValue(std::get<2>(arg), &stream) << " :: " << std::get<1>(arg)->GetPresentableName();
|
||||
});
|
||||
(*stream) << ") :: (";
|
||||
utils::PrintIterable(
|
||||
*stream, proc.results, ", ", [](auto &stream, const auto &name_result) {
|
||||
utils::PrintIterable(*stream, proc.results, ", ", [](auto &stream, const auto &name_result) {
|
||||
const auto &[type, is_deprecated] = name_result.second;
|
||||
if (is_deprecated) stream << "DEPRECATED ";
|
||||
stream << name_result.first << " :: " << type->GetPresentableName();
|
||||
|
@ -56,8 +56,7 @@ struct mgp_value {
|
||||
/// Construct by copying query::TypedValue using utils::MemoryResource.
|
||||
/// mgp_graph is needed to construct mgp_vertex and mgp_edge.
|
||||
/// @throw std::bad_alloc
|
||||
mgp_value(const query::TypedValue &, const mgp_graph *,
|
||||
utils::MemoryResource *);
|
||||
mgp_value(const query::TypedValue &, const mgp_graph *, utils::MemoryResource *);
|
||||
|
||||
/// Construct by copying storage::PropertyValue using utils::MemoryResource.
|
||||
/// @throw std::bad_alloc
|
||||
@ -112,14 +111,11 @@ struct mgp_list {
|
||||
|
||||
explicit mgp_list(utils::MemoryResource *memory) : elems(memory) {}
|
||||
|
||||
mgp_list(utils::pmr::vector<mgp_value> &&elems, utils::MemoryResource *memory)
|
||||
: elems(std::move(elems), memory) {}
|
||||
mgp_list(utils::pmr::vector<mgp_value> &&elems, utils::MemoryResource *memory) : elems(std::move(elems), memory) {}
|
||||
|
||||
mgp_list(const mgp_list &other, utils::MemoryResource *memory)
|
||||
: elems(other.elems, memory) {}
|
||||
mgp_list(const mgp_list &other, utils::MemoryResource *memory) : elems(other.elems, memory) {}
|
||||
|
||||
mgp_list(mgp_list &&other, utils::MemoryResource *memory)
|
||||
: elems(std::move(other.elems), memory) {}
|
||||
mgp_list(mgp_list &&other, utils::MemoryResource *memory) : elems(std::move(other.elems), memory) {}
|
||||
|
||||
mgp_list(mgp_list &&other) noexcept : elems(std::move(other.elems)) {}
|
||||
|
||||
@ -131,9 +127,7 @@ struct mgp_list {
|
||||
|
||||
~mgp_list() = default;
|
||||
|
||||
utils::MemoryResource *GetMemoryResource() const noexcept {
|
||||
return elems.get_allocator().GetMemoryResource();
|
||||
}
|
||||
utils::MemoryResource *GetMemoryResource() const noexcept { return elems.get_allocator().GetMemoryResource(); }
|
||||
|
||||
// C++17 vector can work with incomplete type.
|
||||
utils::pmr::vector<mgp_value> elems;
|
||||
@ -145,15 +139,12 @@ struct mgp_map {
|
||||
|
||||
explicit mgp_map(utils::MemoryResource *memory) : items(memory) {}
|
||||
|
||||
mgp_map(utils::pmr::map<utils::pmr::string, mgp_value> &&items,
|
||||
utils::MemoryResource *memory)
|
||||
mgp_map(utils::pmr::map<utils::pmr::string, mgp_value> &&items, utils::MemoryResource *memory)
|
||||
: items(std::move(items), memory) {}
|
||||
|
||||
mgp_map(const mgp_map &other, utils::MemoryResource *memory)
|
||||
: items(other.items, memory) {}
|
||||
mgp_map(const mgp_map &other, utils::MemoryResource *memory) : items(other.items, memory) {}
|
||||
|
||||
mgp_map(mgp_map &&other, utils::MemoryResource *memory)
|
||||
: items(std::move(other.items), memory) {}
|
||||
mgp_map(mgp_map &&other, utils::MemoryResource *memory) : items(std::move(other.items), memory) {}
|
||||
|
||||
mgp_map(mgp_map &&other) noexcept : items(std::move(other.items)) {}
|
||||
|
||||
@ -165,9 +156,7 @@ struct mgp_map {
|
||||
|
||||
~mgp_map() = default;
|
||||
|
||||
utils::MemoryResource *GetMemoryResource() const noexcept {
|
||||
return items.get_allocator().GetMemoryResource();
|
||||
}
|
||||
utils::MemoryResource *GetMemoryResource() const noexcept { return items.get_allocator().GetMemoryResource(); }
|
||||
|
||||
// Unfortunately using incomplete type with map is undefined, so mgp_map
|
||||
// needs to be defined after mgp_value.
|
||||
@ -215,8 +204,7 @@ struct mgp_vertex {
|
||||
// have everything noexcept here.
|
||||
static_assert(std::is_nothrow_copy_constructible_v<query::VertexAccessor>);
|
||||
|
||||
mgp_vertex(query::VertexAccessor v, const mgp_graph *graph,
|
||||
utils::MemoryResource *memory) noexcept
|
||||
mgp_vertex(query::VertexAccessor v, const mgp_graph *graph, utils::MemoryResource *memory) noexcept
|
||||
: memory(memory), impl(v), graph(graph) {}
|
||||
|
||||
mgp_vertex(const mgp_vertex &other, utils::MemoryResource *memory) noexcept
|
||||
@ -225,8 +213,7 @@ struct mgp_vertex {
|
||||
mgp_vertex(mgp_vertex &&other, utils::MemoryResource *memory) noexcept
|
||||
: memory(memory), impl(other.impl), graph(other.graph) {}
|
||||
|
||||
mgp_vertex(mgp_vertex &&other) noexcept
|
||||
: memory(other.memory), impl(other.impl), graph(other.graph) {}
|
||||
mgp_vertex(mgp_vertex &&other) noexcept : memory(other.memory), impl(other.impl), graph(other.graph) {}
|
||||
|
||||
/// Copy construction without utils::MemoryResource is not allowed.
|
||||
mgp_vertex(const mgp_vertex &) = delete;
|
||||
@ -253,30 +240,17 @@ struct mgp_edge {
|
||||
// have everything noexcept here.
|
||||
static_assert(std::is_nothrow_copy_constructible_v<query::EdgeAccessor>);
|
||||
|
||||
mgp_edge(const query::EdgeAccessor &impl, const mgp_graph *graph,
|
||||
utils::MemoryResource *memory) noexcept
|
||||
: memory(memory),
|
||||
impl(impl),
|
||||
from(impl.From(), graph, memory),
|
||||
to(impl.To(), graph, memory) {}
|
||||
mgp_edge(const query::EdgeAccessor &impl, const mgp_graph *graph, utils::MemoryResource *memory) noexcept
|
||||
: memory(memory), impl(impl), from(impl.From(), graph, memory), to(impl.To(), graph, memory) {}
|
||||
|
||||
mgp_edge(const mgp_edge &other, utils::MemoryResource *memory) noexcept
|
||||
: memory(memory),
|
||||
impl(other.impl),
|
||||
from(other.from, memory),
|
||||
to(other.to, memory) {}
|
||||
: memory(memory), impl(other.impl), from(other.from, memory), to(other.to, memory) {}
|
||||
|
||||
mgp_edge(mgp_edge &&other, utils::MemoryResource *memory) noexcept
|
||||
: memory(other.memory),
|
||||
impl(other.impl),
|
||||
from(std::move(other.from), memory),
|
||||
to(std::move(other.to), memory) {}
|
||||
: memory(other.memory), impl(other.impl), from(std::move(other.from), memory), to(std::move(other.to), memory) {}
|
||||
|
||||
mgp_edge(mgp_edge &&other) noexcept
|
||||
: memory(other.memory),
|
||||
impl(other.impl),
|
||||
from(std::move(other.from)),
|
||||
to(std::move(other.to)) {}
|
||||
: memory(other.memory), impl(other.impl), from(std::move(other.from)), to(std::move(other.to)) {}
|
||||
|
||||
/// Copy construction without utils::MemoryResource is not allowed.
|
||||
mgp_edge(const mgp_edge &) = delete;
|
||||
@ -298,18 +272,15 @@ struct mgp_path {
|
||||
/// Allocator type so that STL containers are aware that we need one.
|
||||
using allocator_type = utils::Allocator<mgp_path>;
|
||||
|
||||
explicit mgp_path(utils::MemoryResource *memory)
|
||||
: vertices(memory), edges(memory) {}
|
||||
explicit mgp_path(utils::MemoryResource *memory) : vertices(memory), edges(memory) {}
|
||||
|
||||
mgp_path(const mgp_path &other, utils::MemoryResource *memory)
|
||||
: vertices(other.vertices, memory), edges(other.edges, memory) {}
|
||||
|
||||
mgp_path(mgp_path &&other, utils::MemoryResource *memory)
|
||||
: vertices(std::move(other.vertices), memory),
|
||||
edges(std::move(other.edges), memory) {}
|
||||
: vertices(std::move(other.vertices), memory), edges(std::move(other.edges), memory) {}
|
||||
|
||||
mgp_path(mgp_path &&other) noexcept
|
||||
: vertices(std::move(other.vertices)), edges(std::move(other.edges)) {}
|
||||
mgp_path(mgp_path &&other) noexcept : vertices(std::move(other.vertices)), edges(std::move(other.edges)) {}
|
||||
|
||||
/// Copy construction without utils::MemoryResource is not allowed.
|
||||
mgp_path(const mgp_path &) = delete;
|
||||
@ -319,9 +290,7 @@ struct mgp_path {
|
||||
|
||||
~mgp_path() = default;
|
||||
|
||||
utils::MemoryResource *GetMemoryResource() const noexcept {
|
||||
return vertices.get_allocator().GetMemoryResource();
|
||||
}
|
||||
utils::MemoryResource *GetMemoryResource() const noexcept { return vertices.get_allocator().GetMemoryResource(); }
|
||||
|
||||
utils::pmr::vector<mgp_vertex> vertices;
|
||||
utils::pmr::vector<mgp_edge> edges;
|
||||
@ -329,24 +298,18 @@ struct mgp_path {
|
||||
|
||||
struct mgp_result_record {
|
||||
/// Result record signature as defined for mgp_proc.
|
||||
const utils::pmr::map<utils::pmr::string,
|
||||
std::pair<const query::procedure::CypherType *, bool>>
|
||||
*signature;
|
||||
const utils::pmr::map<utils::pmr::string, std::pair<const query::procedure::CypherType *, bool>> *signature;
|
||||
utils::pmr::map<utils::pmr::string, query::TypedValue> values;
|
||||
};
|
||||
|
||||
struct mgp_result {
|
||||
explicit mgp_result(
|
||||
const utils::pmr::map<
|
||||
utils::pmr::string,
|
||||
std::pair<const query::procedure::CypherType *, bool>> *signature,
|
||||
const utils::pmr::map<utils::pmr::string, std::pair<const query::procedure::CypherType *, bool>> *signature,
|
||||
utils::MemoryResource *mem)
|
||||
: signature(signature), rows(mem) {}
|
||||
|
||||
/// Result record signature as defined for mgp_proc.
|
||||
const utils::pmr::map<utils::pmr::string,
|
||||
std::pair<const query::procedure::CypherType *, bool>>
|
||||
*signature;
|
||||
const utils::pmr::map<utils::pmr::string, std::pair<const query::procedure::CypherType *, bool>> *signature;
|
||||
utils::pmr::vector<mgp_result_record> rows;
|
||||
std::optional<utils::pmr::string> error_msg;
|
||||
};
|
||||
@ -367,31 +330,22 @@ struct mgp_properties_iterator {
|
||||
|
||||
utils::MemoryResource *memory;
|
||||
const mgp_graph *graph;
|
||||
std::remove_reference_t<decltype(
|
||||
*std::declval<query::VertexAccessor>().Properties(graph->view))>
|
||||
pvs;
|
||||
std::remove_reference_t<decltype(*std::declval<query::VertexAccessor>().Properties(graph->view))> pvs;
|
||||
decltype(pvs.begin()) current_it;
|
||||
std::optional<std::pair<utils::pmr::string, mgp_value>> current;
|
||||
mgp_property property{nullptr, nullptr};
|
||||
|
||||
// Construct with no properties.
|
||||
explicit mgp_properties_iterator(const mgp_graph *graph,
|
||||
utils::MemoryResource *memory)
|
||||
explicit mgp_properties_iterator(const mgp_graph *graph, utils::MemoryResource *memory)
|
||||
: memory(memory), graph(graph), current_it(pvs.begin()) {}
|
||||
|
||||
// May throw who the #$@! knows what because PropertyValueStore doesn't
|
||||
// document what it throws, and it may surely throw some piece of !@#$
|
||||
// exception because it's built on top of STL and other libraries.
|
||||
mgp_properties_iterator(const mgp_graph *graph, decltype(pvs) pvs,
|
||||
utils::MemoryResource *memory)
|
||||
: memory(memory),
|
||||
graph(graph),
|
||||
pvs(std::move(pvs)),
|
||||
current_it(this->pvs.begin()) {
|
||||
mgp_properties_iterator(const mgp_graph *graph, decltype(pvs) pvs, utils::MemoryResource *memory)
|
||||
: memory(memory), graph(graph), pvs(std::move(pvs)), current_it(this->pvs.begin()) {
|
||||
if (current_it != this->pvs.end()) {
|
||||
current.emplace(
|
||||
utils::pmr::string(graph->impl->PropertyToName(current_it->first),
|
||||
memory),
|
||||
current.emplace(utils::pmr::string(graph->impl->PropertyToName(current_it->first), memory),
|
||||
mgp_value(current_it->second, memory));
|
||||
property.name = current->first.c_str();
|
||||
property.value = ¤t->second;
|
||||
@ -414,11 +368,9 @@ struct mgp_edges_iterator {
|
||||
|
||||
// Hopefully mgp_vertex copy constructor remains noexcept, so that we can
|
||||
// have everything noexcept here.
|
||||
static_assert(std::is_nothrow_constructible_v<mgp_vertex, const mgp_vertex &,
|
||||
utils::MemoryResource *>);
|
||||
static_assert(std::is_nothrow_constructible_v<mgp_vertex, const mgp_vertex &, utils::MemoryResource *>);
|
||||
|
||||
mgp_edges_iterator(const mgp_vertex &v,
|
||||
utils::MemoryResource *memory) noexcept
|
||||
mgp_edges_iterator(const mgp_vertex &v, utils::MemoryResource *memory) noexcept
|
||||
: memory(memory), source_vertex(v, memory) {}
|
||||
|
||||
mgp_edges_iterator(mgp_edges_iterator &&other) noexcept
|
||||
@ -440,13 +392,9 @@ struct mgp_edges_iterator {
|
||||
|
||||
utils::MemoryResource *memory;
|
||||
mgp_vertex source_vertex;
|
||||
std::optional<std::remove_reference_t<decltype(
|
||||
*source_vertex.impl.InEdges(source_vertex.graph->view))>>
|
||||
in;
|
||||
std::optional<std::remove_reference_t<decltype(*source_vertex.impl.InEdges(source_vertex.graph->view))>> in;
|
||||
std::optional<decltype(in->begin())> in_it;
|
||||
std::optional<std::remove_reference_t<decltype(
|
||||
*source_vertex.impl.OutEdges(source_vertex.graph->view))>>
|
||||
out;
|
||||
std::optional<std::remove_reference_t<decltype(*source_vertex.impl.OutEdges(source_vertex.graph->view))>> out;
|
||||
std::optional<decltype(out->begin())> out_it;
|
||||
std::optional<mgp_edge> current_e;
|
||||
};
|
||||
@ -456,10 +404,7 @@ struct mgp_vertices_iterator {
|
||||
|
||||
/// @throw anything VerticesIterable may throw
|
||||
mgp_vertices_iterator(const mgp_graph *graph, utils::MemoryResource *memory)
|
||||
: memory(memory),
|
||||
graph(graph),
|
||||
vertices(graph->impl->Vertices(graph->view)),
|
||||
current_it(vertices.begin()) {
|
||||
: memory(memory), graph(graph), vertices(graph->impl->Vertices(graph->view)), current_it(vertices.begin()) {
|
||||
if (current_it != vertices.end()) {
|
||||
current_v.emplace(*current_it, graph, memory);
|
||||
}
|
||||
@ -484,24 +429,13 @@ struct mgp_proc {
|
||||
/// @throw std::bad_alloc
|
||||
/// @throw std::length_error
|
||||
mgp_proc(const char *name, mgp_proc_cb cb, utils::MemoryResource *memory)
|
||||
: name(name, memory),
|
||||
cb(cb),
|
||||
args(memory),
|
||||
opt_args(memory),
|
||||
results(memory) {}
|
||||
: name(name, memory), cb(cb), args(memory), opt_args(memory), results(memory) {}
|
||||
|
||||
/// @throw std::bad_alloc
|
||||
/// @throw std::length_error
|
||||
mgp_proc(const char *name,
|
||||
std::function<void(const mgp_list *, const mgp_graph *, mgp_result *,
|
||||
mgp_memory *)>
|
||||
cb,
|
||||
mgp_proc(const char *name, std::function<void(const mgp_list *, const mgp_graph *, mgp_result *, mgp_memory *)> cb,
|
||||
utils::MemoryResource *memory)
|
||||
: name(name, memory),
|
||||
cb(cb),
|
||||
args(memory),
|
||||
opt_args(memory),
|
||||
results(memory) {}
|
||||
: name(name, memory), cb(cb), args(memory), opt_args(memory), results(memory) {}
|
||||
|
||||
/// @throw std::bad_alloc
|
||||
/// @throw std::length_error
|
||||
@ -530,22 +464,13 @@ struct mgp_proc {
|
||||
/// Name of the procedure.
|
||||
utils::pmr::string name;
|
||||
/// Entry-point for the procedure.
|
||||
std::function<void(const mgp_list *, const mgp_graph *, mgp_result *,
|
||||
mgp_memory *)>
|
||||
cb;
|
||||
std::function<void(const mgp_list *, const mgp_graph *, mgp_result *, mgp_memory *)> cb;
|
||||
/// Required, positional arguments as a (name, type) pair.
|
||||
utils::pmr::vector<
|
||||
std::pair<utils::pmr::string, const query::procedure::CypherType *>>
|
||||
args;
|
||||
utils::pmr::vector<std::pair<utils::pmr::string, const query::procedure::CypherType *>> args;
|
||||
/// Optional positional arguments as a (name, type, default_value) tuple.
|
||||
utils::pmr::vector<
|
||||
std::tuple<utils::pmr::string, const query::procedure::CypherType *,
|
||||
query::TypedValue>>
|
||||
opt_args;
|
||||
utils::pmr::vector<std::tuple<utils::pmr::string, const query::procedure::CypherType *, query::TypedValue>> opt_args;
|
||||
/// Fields this procedure returns, as a (name -> (type, is_deprecated)) map.
|
||||
utils::pmr::map<utils::pmr::string,
|
||||
std::pair<const query::procedure::CypherType *, bool>>
|
||||
results;
|
||||
utils::pmr::map<utils::pmr::string, std::pair<const query::procedure::CypherType *, bool>> results;
|
||||
};
|
||||
|
||||
struct mgp_module {
|
||||
@ -553,11 +478,9 @@ struct mgp_module {
|
||||
|
||||
explicit mgp_module(utils::MemoryResource *memory) : procedures(memory) {}
|
||||
|
||||
mgp_module(const mgp_module &other, utils::MemoryResource *memory)
|
||||
: procedures(other.procedures, memory) {}
|
||||
mgp_module(const mgp_module &other, utils::MemoryResource *memory) : procedures(other.procedures, memory) {}
|
||||
|
||||
mgp_module(mgp_module &&other, utils::MemoryResource *memory)
|
||||
: procedures(std::move(other.procedures), memory) {}
|
||||
mgp_module(mgp_module &&other, utils::MemoryResource *memory) : procedures(std::move(other.procedures), memory) {}
|
||||
|
||||
mgp_module(const mgp_module &) = default;
|
||||
mgp_module(mgp_module &&) = default;
|
||||
|
@ -30,8 +30,7 @@ class BuiltinModule final : public Module {
|
||||
|
||||
bool Close() override;
|
||||
|
||||
const std::map<std::string, mgp_proc, std::less<>> *Procedures()
|
||||
const override;
|
||||
const std::map<std::string, mgp_proc, std::less<>> *Procedures() const override;
|
||||
|
||||
void AddProcedure(std::string_view name, mgp_proc proc);
|
||||
|
||||
@ -46,19 +45,13 @@ BuiltinModule::~BuiltinModule() {}
|
||||
|
||||
bool BuiltinModule::Close() { return true; }
|
||||
|
||||
const std::map<std::string, mgp_proc, std::less<>> *BuiltinModule::Procedures()
|
||||
const {
|
||||
return &procedures_;
|
||||
}
|
||||
const std::map<std::string, mgp_proc, std::less<>> *BuiltinModule::Procedures() const { return &procedures_; }
|
||||
|
||||
void BuiltinModule::AddProcedure(std::string_view name, mgp_proc proc) {
|
||||
procedures_.emplace(name, std::move(proc));
|
||||
}
|
||||
void BuiltinModule::AddProcedure(std::string_view name, mgp_proc proc) { procedures_.emplace(name, std::move(proc)); }
|
||||
|
||||
namespace {
|
||||
|
||||
void RegisterMgLoad(ModuleRegistry *module_registry, utils::RWLock *lock,
|
||||
BuiltinModule *module) {
|
||||
void RegisterMgLoad(ModuleRegistry *module_registry, utils::RWLock *lock, BuiltinModule *module) {
|
||||
// Loading relies on the fact that regular procedure invocation through
|
||||
// CallProcedureCursor::Pull takes ModuleRegistry::lock_ with READ access. To
|
||||
// load modules we have to upgrade our READ access to WRITE access,
|
||||
@ -83,27 +76,19 @@ void RegisterMgLoad(ModuleRegistry *module_registry, utils::RWLock *lock,
|
||||
}
|
||||
lock->lock_shared();
|
||||
};
|
||||
auto load_all_cb = [module_registry, with_unlock_shared](
|
||||
const mgp_list *, const mgp_graph *, mgp_result *,
|
||||
auto load_all_cb = [module_registry, with_unlock_shared](const mgp_list *, const mgp_graph *, mgp_result *,
|
||||
mgp_memory *) {
|
||||
with_unlock_shared(
|
||||
[&]() { module_registry->UnloadAndLoadModulesFromDirectory(); });
|
||||
with_unlock_shared([&]() { module_registry->UnloadAndLoadModulesFromDirectory(); });
|
||||
};
|
||||
mgp_proc load_all("load_all", load_all_cb, utils::NewDeleteResource());
|
||||
module->AddProcedure("load_all", std::move(load_all));
|
||||
auto load_cb = [module_registry, with_unlock_shared](
|
||||
const mgp_list *args, const mgp_graph *, mgp_result *res,
|
||||
auto load_cb = [module_registry, with_unlock_shared](const mgp_list *args, const mgp_graph *, mgp_result *res,
|
||||
mgp_memory *) {
|
||||
MG_ASSERT(mgp_list_size(args) == 1U,
|
||||
"Should have been type checked already");
|
||||
MG_ASSERT(mgp_list_size(args) == 1U, "Should have been type checked already");
|
||||
const mgp_value *arg = mgp_list_at(args, 0);
|
||||
MG_ASSERT(mgp_value_is_string(arg),
|
||||
"Should have been type checked already");
|
||||
MG_ASSERT(mgp_value_is_string(arg), "Should have been type checked already");
|
||||
bool succ = false;
|
||||
with_unlock_shared([&]() {
|
||||
succ = module_registry->LoadOrReloadModuleFromName(
|
||||
mgp_value_get_string(arg));
|
||||
});
|
||||
with_unlock_shared([&]() { succ = module_registry->LoadOrReloadModuleFromName(mgp_value_get_string(arg)); });
|
||||
if (!succ) mgp_result_set_error_msg(res, "Failed to (re)load the module.");
|
||||
};
|
||||
mgp_proc load("load", load_cb, utils::NewDeleteResource());
|
||||
@ -113,11 +98,8 @@ void RegisterMgLoad(ModuleRegistry *module_registry, utils::RWLock *lock,
|
||||
|
||||
void RegisterMgProcedures(
|
||||
// We expect modules to be sorted by name.
|
||||
const std::map<std::string, std::unique_ptr<Module>, std::less<>>
|
||||
*all_modules,
|
||||
BuiltinModule *module) {
|
||||
auto procedures_cb = [all_modules](const mgp_list *, const mgp_graph *,
|
||||
mgp_result *result, mgp_memory *memory) {
|
||||
const std::map<std::string, std::unique_ptr<Module>, std::less<>> *all_modules, BuiltinModule *module) {
|
||||
auto procedures_cb = [all_modules](const mgp_list *, const mgp_graph *, mgp_result *result, mgp_memory *memory) {
|
||||
// Iterating over all_modules assumes that the standard mechanism of custom
|
||||
// procedure invocations takes the ModuleRegistry::lock_ with READ access.
|
||||
// For details on how the invocation is done, take a look at the
|
||||
@ -125,8 +107,7 @@ void RegisterMgProcedures(
|
||||
for (const auto &[module_name, module] : *all_modules) {
|
||||
// Return the results in sorted order by module and by procedure.
|
||||
static_assert(
|
||||
std::is_same_v<decltype(module->Procedures()),
|
||||
const std::map<std::string, mgp_proc, std::less<>> *>,
|
||||
std::is_same_v<decltype(module->Procedures()), const std::map<std::string, mgp_proc, std::less<>> *>,
|
||||
"Expected module procedures to be sorted by name");
|
||||
for (const auto &[proc_name, proc] : *module->Procedures()) {
|
||||
auto *record = mgp_result_new_record(result);
|
||||
@ -146,16 +127,14 @@ void RegisterMgProcedures(
|
||||
ss << module_name << ".";
|
||||
PrintProcSignature(proc, &ss);
|
||||
const auto signature = ss.str();
|
||||
auto *signature_value =
|
||||
mgp_value_make_string(signature.c_str(), memory);
|
||||
auto *signature_value = mgp_value_make_string(signature.c_str(), memory);
|
||||
if (!signature_value) {
|
||||
mgp_value_destroy(name_value);
|
||||
mgp_result_set_error_msg(result, "Not enough memory!");
|
||||
return;
|
||||
}
|
||||
int succ1 = mgp_result_record_insert(record, "name", name_value);
|
||||
int succ2 =
|
||||
mgp_result_record_insert(record, "signature", signature_value);
|
||||
int succ2 = mgp_result_record_insert(record, "signature", signature_value);
|
||||
mgp_value_destroy(name_value);
|
||||
mgp_value_destroy(signature_value);
|
||||
if (!succ1 || !succ2) {
|
||||
@ -206,8 +185,7 @@ class SharedLibraryModule final : public Module {
|
||||
|
||||
bool Close() override;
|
||||
|
||||
const std::map<std::string, mgp_proc, std::less<>> *Procedures()
|
||||
const override;
|
||||
const std::map<std::string, mgp_proc, std::less<>> *Procedures() const override;
|
||||
|
||||
private:
|
||||
/// Path as requested for loading the module from a library.
|
||||
@ -239,8 +217,7 @@ bool SharedLibraryModule::Load(const std::filesystem::path &file_path) {
|
||||
return false;
|
||||
}
|
||||
// Get required mgp_init_module
|
||||
init_fn_ = reinterpret_cast<int (*)(mgp_module *, mgp_memory *)>(
|
||||
dlsym(handle_, "mgp_init_module"));
|
||||
init_fn_ = reinterpret_cast<int (*)(mgp_module *, mgp_memory *)>(dlsym(handle_, "mgp_init_module"));
|
||||
const char *error = dlerror();
|
||||
if (!init_fn_ || error) {
|
||||
spdlog::error("Unable to load module {}; {}", file_path, error);
|
||||
@ -248,13 +225,11 @@ bool SharedLibraryModule::Load(const std::filesystem::path &file_path) {
|
||||
handle_ = nullptr;
|
||||
return false;
|
||||
}
|
||||
if (!WithModuleRegistration(&procedures_, [&](auto *module_def,
|
||||
auto *memory) {
|
||||
if (!WithModuleRegistration(&procedures_, [&](auto *module_def, auto *memory) {
|
||||
// Run mgp_init_module which must succeed.
|
||||
int init_res = init_fn_(module_def, memory);
|
||||
if (init_res != 0) {
|
||||
spdlog::error("Unable to load module {}; mgp_init_module_returned {}",
|
||||
file_path, init_res);
|
||||
spdlog::error("Unable to load module {}; mgp_init_module_returned {}", file_path, init_res);
|
||||
dlclose(handle_);
|
||||
handle_ = nullptr;
|
||||
return false;
|
||||
@ -264,8 +239,7 @@ bool SharedLibraryModule::Load(const std::filesystem::path &file_path) {
|
||||
return false;
|
||||
}
|
||||
// Get optional mgp_shutdown_module
|
||||
shutdown_fn_ =
|
||||
reinterpret_cast<int (*)()>(dlsym(handle_, "mgp_shutdown_module"));
|
||||
shutdown_fn_ = reinterpret_cast<int (*)()>(dlsym(handle_, "mgp_shutdown_module"));
|
||||
error = dlerror();
|
||||
if (error) spdlog::warn("When loading module {}; {}", file_path, error);
|
||||
spdlog::info("Loaded module {}", file_path);
|
||||
@ -273,16 +247,14 @@ bool SharedLibraryModule::Load(const std::filesystem::path &file_path) {
|
||||
}
|
||||
|
||||
bool SharedLibraryModule::Close() {
|
||||
MG_ASSERT(handle_,
|
||||
"Attempting to close a module that has not been loaded...");
|
||||
MG_ASSERT(handle_, "Attempting to close a module that has not been loaded...");
|
||||
spdlog::info("Closing module {}...", file_path_);
|
||||
// non-existent shutdown function is semantically the same as a shutdown
|
||||
// function that does nothing.
|
||||
int shutdown_res = 0;
|
||||
if (shutdown_fn_) shutdown_res = shutdown_fn_();
|
||||
if (shutdown_res != 0) {
|
||||
spdlog::warn("When closing module {}; mgp_shutdown_module returned {}",
|
||||
file_path_, shutdown_res);
|
||||
spdlog::warn("When closing module {}; mgp_shutdown_module returned {}", file_path_, shutdown_res);
|
||||
}
|
||||
if (dlclose(handle_) != 0) {
|
||||
spdlog::error("Failed to close module {}; {}", file_path_, dlerror());
|
||||
@ -294,8 +266,7 @@ bool SharedLibraryModule::Close() {
|
||||
return true;
|
||||
}
|
||||
|
||||
const std::map<std::string, mgp_proc, std::less<>>
|
||||
*SharedLibraryModule::Procedures() const {
|
||||
const std::map<std::string, mgp_proc, std::less<>> *SharedLibraryModule::Procedures() const {
|
||||
MG_ASSERT(handle_,
|
||||
"Attempting to access procedures of a module that has not "
|
||||
"been loaded...");
|
||||
@ -315,8 +286,7 @@ class PythonModule final : public Module {
|
||||
|
||||
bool Close() override;
|
||||
|
||||
const std::map<std::string, mgp_proc, std::less<>> *Procedures()
|
||||
const override;
|
||||
const std::map<std::string, mgp_proc, std::less<>> *Procedures() const override;
|
||||
|
||||
private:
|
||||
std::filesystem::path file_path_;
|
||||
@ -340,8 +310,7 @@ bool PythonModule::Load(const std::filesystem::path &file_path) {
|
||||
spdlog::error("Unable to load module {}; {}", file_path, *maybe_exc);
|
||||
return false;
|
||||
}
|
||||
py_module_ =
|
||||
WithModuleRegistration(&procedures_, [&](auto *module_def, auto *memory) {
|
||||
py_module_ = WithModuleRegistration(&procedures_, [&](auto *module_def, auto *memory) {
|
||||
return ImportPyModule(file_path.stem().c_str(), module_def);
|
||||
});
|
||||
if (py_module_) {
|
||||
@ -354,8 +323,7 @@ bool PythonModule::Load(const std::filesystem::path &file_path) {
|
||||
}
|
||||
|
||||
bool PythonModule::Close() {
|
||||
MG_ASSERT(py_module_,
|
||||
"Attempting to close a module that has not been loaded...");
|
||||
MG_ASSERT(py_module_, "Attempting to close a module that has not been loaded...");
|
||||
spdlog::info("Closing module {}...", file_path_);
|
||||
// The procedures are closures which hold references to the Python callbacks.
|
||||
// Releasing these references might result in deallocations so we need to take
|
||||
@ -365,8 +333,7 @@ bool PythonModule::Close() {
|
||||
// Delete the module from the `sys.modules` directory so that the module will
|
||||
// be properly imported if imported again.
|
||||
py::Object sys(PyImport_ImportModule("sys"));
|
||||
if (PyDict_DelItemString(sys.GetAttr("modules").Ptr(),
|
||||
file_path_.stem().c_str()) != 0) {
|
||||
if (PyDict_DelItemString(sys.GetAttr("modules").Ptr(), file_path_.stem().c_str()) != 0) {
|
||||
spdlog::warn("Failed to remove the module from sys.modules");
|
||||
py_module_ = py::Object(nullptr);
|
||||
return false;
|
||||
@ -376,8 +343,7 @@ bool PythonModule::Close() {
|
||||
return true;
|
||||
}
|
||||
|
||||
const std::map<std::string, mgp_proc, std::less<>> *PythonModule::Procedures()
|
||||
const {
|
||||
const std::map<std::string, mgp_proc, std::less<>> *PythonModule::Procedures() const {
|
||||
MG_ASSERT(py_module_,
|
||||
"Attempting to access procedures of a module that has "
|
||||
"not been loaded...");
|
||||
@ -407,8 +373,7 @@ std::unique_ptr<Module> LoadModuleFromFile(const std::filesystem::path &path) {
|
||||
|
||||
} // namespace
|
||||
|
||||
bool ModuleRegistry::RegisterModule(const std::string_view &name,
|
||||
std::unique_ptr<Module> module) {
|
||||
bool ModuleRegistry::RegisterModule(const std::string_view &name, std::unique_ptr<Module> module) {
|
||||
MG_ASSERT(!name.empty(), "Module name cannot be empty");
|
||||
MG_ASSERT(module, "Tried to register an invalid module");
|
||||
if (modules_.find(name) != modules_.end()) {
|
||||
@ -420,8 +385,7 @@ bool ModuleRegistry::RegisterModule(const std::string_view &name,
|
||||
}
|
||||
|
||||
void ModuleRegistry::DoUnloadAllModules() {
|
||||
MG_ASSERT(modules_.find("mg") != modules_.end(),
|
||||
"Expected the builtin \"mg\" module to be present.");
|
||||
MG_ASSERT(modules_.find("mg") != modules_.end(), "Expected the builtin \"mg\" module to be present.");
|
||||
// This is correct because the destructor will close each module. However,
|
||||
// we don't want to unload the builtin "mg" module.
|
||||
auto module = std::move(modules_["mg"]);
|
||||
@ -436,10 +400,7 @@ ModuleRegistry::ModuleRegistry() {
|
||||
modules_.emplace("mg", std::move(module));
|
||||
}
|
||||
|
||||
void ModuleRegistry::SetModulesDirectory(
|
||||
const std::filesystem::path &modules_dir) {
|
||||
modules_dir_ = modules_dir;
|
||||
}
|
||||
void ModuleRegistry::SetModulesDirectory(const std::filesystem::path &modules_dir) { modules_dir_ = modules_dir; }
|
||||
|
||||
bool ModuleRegistry::LoadOrReloadModuleFromName(const std::string_view &name) {
|
||||
if (modules_dir_.empty()) return false;
|
||||
@ -500,16 +461,14 @@ void ModuleRegistry::UnloadAllModules() {
|
||||
}
|
||||
|
||||
std::optional<std::pair<procedure::ModulePtr, const mgp_proc *>> FindProcedure(
|
||||
const ModuleRegistry &module_registry,
|
||||
const std::string_view &fully_qualified_procedure_name,
|
||||
const ModuleRegistry &module_registry, const std::string_view &fully_qualified_procedure_name,
|
||||
utils::MemoryResource *memory) {
|
||||
utils::pmr::vector<std::string_view> name_parts(memory);
|
||||
utils::Split(&name_parts, fully_qualified_procedure_name, ".");
|
||||
if (name_parts.size() == 1U) return std::nullopt;
|
||||
auto last_dot_pos = fully_qualified_procedure_name.find_last_of('.');
|
||||
MG_ASSERT(last_dot_pos != std::string_view::npos);
|
||||
const auto &module_name =
|
||||
fully_qualified_procedure_name.substr(0, last_dot_pos);
|
||||
const auto &module_name = fully_qualified_procedure_name.substr(0, last_dot_pos);
|
||||
const auto &proc_name = name_parts.back();
|
||||
auto module = module_registry.GetModuleNamed(module_name);
|
||||
if (!module) return std::nullopt;
|
||||
|
@ -29,8 +29,7 @@ class Module {
|
||||
virtual bool Close() = 0;
|
||||
|
||||
/// Returns registered procedures of this module
|
||||
virtual const std::map<std::string, mgp_proc, std::less<>> *Procedures()
|
||||
const = 0;
|
||||
virtual const std::map<std::string, mgp_proc, std::less<>> *Procedures() const = 0;
|
||||
};
|
||||
|
||||
/// Proxy for a registered Module, acquires a read lock from ModuleRegistry.
|
||||
@ -41,8 +40,7 @@ class ModulePtr final {
|
||||
public:
|
||||
ModulePtr() = default;
|
||||
ModulePtr(std::nullptr_t) {}
|
||||
ModulePtr(const Module *module, std::shared_lock<utils::RWLock> lock)
|
||||
: module_(module), lock_(std::move(lock)) {}
|
||||
ModulePtr(const Module *module, std::shared_lock<utils::RWLock> lock) : module_(module), lock_(std::move(lock)) {}
|
||||
|
||||
explicit operator bool() const { return static_cast<bool>(module_); }
|
||||
|
||||
@ -55,8 +53,7 @@ class ModuleRegistry final {
|
||||
std::map<std::string, std::unique_ptr<Module>, std::less<>> modules_;
|
||||
mutable utils::RWLock lock_{utils::RWLock::Priority::WRITE};
|
||||
|
||||
bool RegisterModule(const std::string_view &name,
|
||||
std::unique_ptr<Module> module);
|
||||
bool RegisterModule(const std::string_view &name, std::unique_ptr<Module> module);
|
||||
|
||||
void DoUnloadAllModules();
|
||||
|
||||
@ -105,8 +102,7 @@ extern ModuleRegistry gModuleRegistry;
|
||||
/// inside this function. ModulePtr must be kept alive to make sure it won't be
|
||||
/// unloaded.
|
||||
std::optional<std::pair<procedure::ModulePtr, const mgp_proc *>> FindProcedure(
|
||||
const ModuleRegistry &module_registry,
|
||||
const std::string_view &fully_qualified_procedure_name,
|
||||
const ModuleRegistry &module_registry, const std::string_view &fully_qualified_procedure_name,
|
||||
utils::MemoryResource *memory);
|
||||
|
||||
} // namespace query::procedure
|
||||
|
@ -65,8 +65,7 @@ void PyVerticesIteratorDealloc(PyVerticesIterator *self) {
|
||||
Py_TYPE(self)->tp_free(self);
|
||||
}
|
||||
|
||||
PyObject *PyVerticesIteratorGet(PyVerticesIterator *self,
|
||||
PyObject *Py_UNUSED(ignored)) {
|
||||
PyObject *PyVerticesIteratorGet(PyVerticesIterator *self, PyObject *Py_UNUSED(ignored)) {
|
||||
MG_ASSERT(self->it);
|
||||
MG_ASSERT(self->py_graph);
|
||||
MG_ASSERT(self->py_graph->graph);
|
||||
@ -75,8 +74,7 @@ PyObject *PyVerticesIteratorGet(PyVerticesIterator *self,
|
||||
return MakePyVertex(*vertex, self->py_graph);
|
||||
}
|
||||
|
||||
PyObject *PyVerticesIteratorNext(PyVerticesIterator *self,
|
||||
PyObject *Py_UNUSED(ignored)) {
|
||||
PyObject *PyVerticesIteratorNext(PyVerticesIterator *self, PyObject *Py_UNUSED(ignored)) {
|
||||
MG_ASSERT(self->it);
|
||||
MG_ASSERT(self->py_graph);
|
||||
MG_ASSERT(self->py_graph->graph);
|
||||
@ -86,8 +84,7 @@ PyObject *PyVerticesIteratorNext(PyVerticesIterator *self,
|
||||
}
|
||||
|
||||
static PyMethodDef PyVerticesIteratorMethods[] = {
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy),
|
||||
METH_NOARGS, "__reduce__ is not supported"},
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
|
||||
{"get", reinterpret_cast<PyCFunction>(PyVerticesIteratorGet), METH_NOARGS,
|
||||
"Get the current vertex pointed to by the iterator or return None."},
|
||||
{"next", reinterpret_cast<PyCFunction>(PyVerticesIteratorNext), METH_NOARGS,
|
||||
@ -128,8 +125,7 @@ void PyEdgesIteratorDealloc(PyEdgesIterator *self) {
|
||||
Py_TYPE(self)->tp_free(self);
|
||||
}
|
||||
|
||||
PyObject *PyEdgesIteratorGet(PyEdgesIterator *self,
|
||||
PyObject *Py_UNUSED(ignored)) {
|
||||
PyObject *PyEdgesIteratorGet(PyEdgesIterator *self, PyObject *Py_UNUSED(ignored)) {
|
||||
MG_ASSERT(self->it);
|
||||
MG_ASSERT(self->py_graph);
|
||||
MG_ASSERT(self->py_graph->graph);
|
||||
@ -138,8 +134,7 @@ PyObject *PyEdgesIteratorGet(PyEdgesIterator *self,
|
||||
return MakePyEdge(*edge, self->py_graph);
|
||||
}
|
||||
|
||||
PyObject *PyEdgesIteratorNext(PyEdgesIterator *self,
|
||||
PyObject *Py_UNUSED(ignored)) {
|
||||
PyObject *PyEdgesIteratorNext(PyEdgesIterator *self, PyObject *Py_UNUSED(ignored)) {
|
||||
MG_ASSERT(self->it);
|
||||
MG_ASSERT(self->py_graph);
|
||||
MG_ASSERT(self->py_graph->graph);
|
||||
@ -149,8 +144,7 @@ PyObject *PyEdgesIteratorNext(PyEdgesIterator *self,
|
||||
}
|
||||
|
||||
static PyMethodDef PyEdgesIteratorMethods[] = {
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy),
|
||||
METH_NOARGS, "__reduce__ is not supported"},
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
|
||||
{"get", reinterpret_cast<PyCFunction>(PyEdgesIteratorGet), METH_NOARGS,
|
||||
"Get the current edge pointed to by the iterator or return None."},
|
||||
{"next", reinterpret_cast<PyCFunction>(PyEdgesIteratorNext), METH_NOARGS,
|
||||
@ -176,9 +170,7 @@ PyObject *PyGraphInvalidate(PyGraph *self, PyObject *Py_UNUSED(ignored)) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
PyObject *PyGraphIsValid(PyGraph *self, PyObject *Py_UNUSED(ignored)) {
|
||||
return PyBool_FromLong(!!self->graph);
|
||||
}
|
||||
PyObject *PyGraphIsValid(PyGraph *self, PyObject *Py_UNUSED(ignored)) { return PyBool_FromLong(!!self->graph); }
|
||||
|
||||
PyObject *MakePyVertex(mgp_vertex *vertex, PyGraph *py_graph);
|
||||
|
||||
@ -188,11 +180,9 @@ PyObject *PyGraphGetVertexById(PyGraph *self, PyObject *args) {
|
||||
static_assert(std::is_same_v<int64_t, long>);
|
||||
int64_t id;
|
||||
if (!PyArg_ParseTuple(args, "l", &id)) return nullptr;
|
||||
auto *vertex =
|
||||
mgp_graph_get_vertex_by_id(self->graph, mgp_vertex_id{id}, self->memory);
|
||||
auto *vertex = mgp_graph_get_vertex_by_id(self->graph, mgp_vertex_id{id}, self->memory);
|
||||
if (!vertex) {
|
||||
PyErr_SetString(PyExc_IndexError,
|
||||
"Unable to find the vertex with given ID.");
|
||||
PyErr_SetString(PyExc_IndexError, "Unable to find the vertex with given ID.");
|
||||
return nullptr;
|
||||
}
|
||||
auto *py_vertex = MakePyVertex(vertex, self);
|
||||
@ -205,12 +195,10 @@ PyObject *PyGraphIterVertices(PyGraph *self, PyObject *Py_UNUSED(ignored)) {
|
||||
MG_ASSERT(self->memory);
|
||||
auto *vertices_it = mgp_graph_iter_vertices(self->graph, self->memory);
|
||||
if (!vertices_it) {
|
||||
PyErr_SetString(PyExc_MemoryError,
|
||||
"Unable to allocate mgp_vertices_iterator.");
|
||||
PyErr_SetString(PyExc_MemoryError, "Unable to allocate mgp_vertices_iterator.");
|
||||
return nullptr;
|
||||
}
|
||||
auto *py_vertices_it =
|
||||
PyObject_New(PyVerticesIterator, &PyVerticesIteratorType);
|
||||
auto *py_vertices_it = PyObject_New(PyVerticesIterator, &PyVerticesIteratorType);
|
||||
if (!py_vertices_it) {
|
||||
mgp_vertices_iterator_destroy(vertices_it);
|
||||
return nullptr;
|
||||
@ -227,17 +215,14 @@ PyObject *PyGraphMustAbort(PyGraph *self, PyObject *Py_UNUSED(ignored)) {
|
||||
}
|
||||
|
||||
static PyMethodDef PyGraphMethods[] = {
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy),
|
||||
METH_NOARGS, "__reduce__ is not supported"},
|
||||
{"invalidate", reinterpret_cast<PyCFunction>(PyGraphInvalidate),
|
||||
METH_NOARGS,
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
|
||||
{"invalidate", reinterpret_cast<PyCFunction>(PyGraphInvalidate), METH_NOARGS,
|
||||
"Invalidate the Graph context thus preventing the Graph from being used."},
|
||||
{"is_valid", reinterpret_cast<PyCFunction>(PyGraphIsValid), METH_NOARGS,
|
||||
"Return True if Graph is in valid context and may be used."},
|
||||
{"get_vertex_by_id", reinterpret_cast<PyCFunction>(PyGraphGetVertexById),
|
||||
METH_VARARGS, "Get the vertex or raise IndexError."},
|
||||
{"iter_vertices", reinterpret_cast<PyCFunction>(PyGraphIterVertices),
|
||||
METH_NOARGS, "Return _mgp.VerticesIterator."},
|
||||
{"get_vertex_by_id", reinterpret_cast<PyCFunction>(PyGraphGetVertexById), METH_VARARGS,
|
||||
"Get the vertex or raise IndexError."},
|
||||
{"iter_vertices", reinterpret_cast<PyCFunction>(PyGraphIterVertices), METH_NOARGS, "Return _mgp.VerticesIterator."},
|
||||
{"must_abort", reinterpret_cast<PyCFunction>(PyGraphMustAbort), METH_NOARGS,
|
||||
"Check whether the running procedure should abort"},
|
||||
{nullptr},
|
||||
@ -317,8 +302,7 @@ PyObject *PyQueryProcAddOptArg(PyQueryProc *self, PyObject *args) {
|
||||
const char *name = nullptr;
|
||||
PyObject *py_type = nullptr;
|
||||
PyObject *py_value = nullptr;
|
||||
if (!PyArg_ParseTuple(args, "sOO", &name, &py_type, &py_value))
|
||||
return nullptr;
|
||||
if (!PyArg_ParseTuple(args, "sOO", &name, &py_type, &py_value)) return nullptr;
|
||||
if (Py_TYPE(py_type) != &PyCypherTypeType) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type.");
|
||||
return nullptr;
|
||||
@ -379,26 +363,21 @@ PyObject *PyQueryProcAddDeprecatedResult(PyQueryProc *self, PyObject *args) {
|
||||
}
|
||||
const auto *type = reinterpret_cast<PyCypherType *>(py_type)->type;
|
||||
if (!mgp_proc_add_deprecated_result(self->proc, name, type)) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"Invalid call to mgp_proc_add_deprecated_result.");
|
||||
PyErr_SetString(PyExc_ValueError, "Invalid call to mgp_proc_add_deprecated_result.");
|
||||
return nullptr;
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyMethodDef PyQueryProcMethods[] = {
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy),
|
||||
METH_NOARGS, "__reduce__ is not supported"},
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
|
||||
{"add_arg", reinterpret_cast<PyCFunction>(PyQueryProcAddArg), METH_VARARGS,
|
||||
"Add a required argument to a procedure."},
|
||||
{"add_opt_arg", reinterpret_cast<PyCFunction>(PyQueryProcAddOptArg),
|
||||
METH_VARARGS,
|
||||
{"add_opt_arg", reinterpret_cast<PyCFunction>(PyQueryProcAddOptArg), METH_VARARGS,
|
||||
"Add an optional argument with a default value to a procedure."},
|
||||
{"add_result", reinterpret_cast<PyCFunction>(PyQueryProcAddResult),
|
||||
METH_VARARGS, "Add a result field to a procedure."},
|
||||
{"add_deprecated_result",
|
||||
reinterpret_cast<PyCFunction>(PyQueryProcAddDeprecatedResult),
|
||||
METH_VARARGS,
|
||||
{"add_result", reinterpret_cast<PyCFunction>(PyQueryProcAddResult), METH_VARARGS,
|
||||
"Add a result field to a procedure."},
|
||||
{"add_deprecated_result", reinterpret_cast<PyCFunction>(PyQueryProcAddDeprecatedResult), METH_VARARGS,
|
||||
"Add a result field to a procedure and mark it as deprecated."},
|
||||
{nullptr},
|
||||
};
|
||||
@ -447,8 +426,7 @@ py::Object MgpListToPyTuple(const mgp_list *list, PyObject *py_graph) {
|
||||
|
||||
namespace {
|
||||
|
||||
std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result,
|
||||
py::Object py_record) {
|
||||
std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result, py::Object py_record) {
|
||||
py::Object py_mgp(PyImport_ImportModule("mgp"));
|
||||
if (!py_mgp) return py::FetchError();
|
||||
auto record_cls = py_mgp.GetAttr("Record");
|
||||
@ -463,8 +441,7 @@ std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result,
|
||||
py::Object fields(py_record.GetAttr("fields"));
|
||||
if (!fields) return py::FetchError();
|
||||
if (!PyDict_Check(fields)) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
"Expected 'mgp.Record.fields' to be a 'dict'");
|
||||
PyErr_SetString(PyExc_TypeError, "Expected 'mgp.Record.fields' to be a 'dict'");
|
||||
return py::FetchError();
|
||||
}
|
||||
py::Object items(PyDict_Items(fields.Ptr()));
|
||||
@ -483,8 +460,7 @@ std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result,
|
||||
if (!key) return py::FetchError();
|
||||
if (!PyUnicode_Check(key)) {
|
||||
std::stringstream ss;
|
||||
ss << "Field name '" << py::Object::FromBorrow(key)
|
||||
<< "' is not an instance of 'str'";
|
||||
ss << "Field name '" << py::Object::FromBorrow(key) << "' is not an instance of 'str'";
|
||||
const auto &msg = ss.str();
|
||||
PyErr_SetString(PyExc_TypeError, msg.c_str());
|
||||
return py::FetchError();
|
||||
@ -505,9 +481,8 @@ std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result,
|
||||
MG_ASSERT(field_val);
|
||||
if (!mgp_result_record_insert(record, field_name, field_val)) {
|
||||
std::stringstream ss;
|
||||
ss << "Unable to insert field '" << py::Object::FromBorrow(key)
|
||||
<< "' with value: '" << py::Object::FromBorrow(val)
|
||||
<< "'; did you set the correct field type?";
|
||||
ss << "Unable to insert field '" << py::Object::FromBorrow(key) << "' with value: '"
|
||||
<< py::Object::FromBorrow(val) << "'; did you set the correct field type?";
|
||||
const auto &msg = ss.str();
|
||||
PyErr_SetString(PyExc_ValueError, msg.c_str());
|
||||
mgp_value_destroy(field_val);
|
||||
@ -518,8 +493,7 @@ std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result,
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<py::ExceptionInfo> AddMultipleRecordsFromPython(
|
||||
mgp_result *result, py::Object py_seq) {
|
||||
std::optional<py::ExceptionInfo> AddMultipleRecordsFromPython(mgp_result *result, py::Object py_seq) {
|
||||
Py_ssize_t len = PySequence_Size(py_seq.Ptr());
|
||||
if (len == -1) return py::FetchError();
|
||||
for (Py_ssize_t i = 0; i < len; ++i) {
|
||||
@ -531,13 +505,11 @@ std::optional<py::ExceptionInfo> AddMultipleRecordsFromPython(
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
void CallPythonProcedure(py::Object py_cb, const mgp_list *args,
|
||||
const mgp_graph *graph, mgp_result *result,
|
||||
void CallPythonProcedure(py::Object py_cb, const mgp_list *args, const mgp_graph *graph, mgp_result *result,
|
||||
mgp_memory *memory) {
|
||||
auto gil = py::EnsureGIL();
|
||||
|
||||
auto error_to_msg = [](const std::optional<py::ExceptionInfo> &exc_info)
|
||||
-> std::optional<std::string> {
|
||||
auto error_to_msg = [](const std::optional<py::ExceptionInfo> &exc_info) -> std::optional<std::string> {
|
||||
if (!exc_info) return std::nullopt;
|
||||
// Here we tell the traceback formatter to skip the first line of the
|
||||
// traceback because that line will always be our wrapper function in our
|
||||
@ -626,23 +598,19 @@ PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) {
|
||||
const auto *name = PyUnicode_AsUTF8(py_name.Ptr());
|
||||
if (!name) return nullptr;
|
||||
if (!IsValidIdentifierName(name)) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"Procedure name is not a valid identifier");
|
||||
PyErr_SetString(PyExc_ValueError, "Procedure name is not a valid identifier");
|
||||
return nullptr;
|
||||
}
|
||||
auto *memory = self->module->procedures.get_allocator().GetMemoryResource();
|
||||
mgp_proc proc(
|
||||
name,
|
||||
[py_cb](const mgp_list *args, const mgp_graph *graph, mgp_result *result,
|
||||
mgp_memory *memory) {
|
||||
[py_cb](const mgp_list *args, const mgp_graph *graph, mgp_result *result, mgp_memory *memory) {
|
||||
CallPythonProcedure(py_cb, args, graph, result, memory);
|
||||
},
|
||||
memory);
|
||||
const auto &[proc_it, did_insert] =
|
||||
self->module->procedures.emplace(name, std::move(proc));
|
||||
const auto &[proc_it, did_insert] = self->module->procedures.emplace(name, std::move(proc));
|
||||
if (!did_insert) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"Already registered a procedure with the same name.");
|
||||
PyErr_SetString(PyExc_ValueError, "Already registered a procedure with the same name.");
|
||||
return nullptr;
|
||||
}
|
||||
auto *py_proc = PyObject_New(PyQueryProc, &PyQueryProcType);
|
||||
@ -652,10 +620,8 @@ PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) {
|
||||
}
|
||||
|
||||
static PyMethodDef PyQueryModuleMethods[] = {
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy),
|
||||
METH_NOARGS, "__reduce__ is not supported"},
|
||||
{"add_read_procedure",
|
||||
reinterpret_cast<PyCFunction>(PyQueryModuleAddReadProcedure), METH_O,
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
|
||||
{"add_read_procedure", reinterpret_cast<PyCFunction>(PyQueryModuleAddReadProcedure), METH_O,
|
||||
"Register a read-only procedure with this module."},
|
||||
{nullptr},
|
||||
};
|
||||
@ -697,21 +663,15 @@ PyObject *PyMgpModuleTypeList(PyObject *mod, PyObject *obj) {
|
||||
return MakePyCypherType(mgp_type_list(py_type->type));
|
||||
}
|
||||
|
||||
PyObject *PyMgpModuleTypeAny(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
|
||||
return MakePyCypherType(mgp_type_any());
|
||||
}
|
||||
PyObject *PyMgpModuleTypeAny(PyObject *mod, PyObject *Py_UNUSED(ignored)) { return MakePyCypherType(mgp_type_any()); }
|
||||
|
||||
PyObject *PyMgpModuleTypeBool(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
|
||||
return MakePyCypherType(mgp_type_bool());
|
||||
}
|
||||
PyObject *PyMgpModuleTypeBool(PyObject *mod, PyObject *Py_UNUSED(ignored)) { return MakePyCypherType(mgp_type_bool()); }
|
||||
|
||||
PyObject *PyMgpModuleTypeString(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
|
||||
return MakePyCypherType(mgp_type_string());
|
||||
}
|
||||
|
||||
PyObject *PyMgpModuleTypeInt(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
|
||||
return MakePyCypherType(mgp_type_int());
|
||||
}
|
||||
PyObject *PyMgpModuleTypeInt(PyObject *mod, PyObject *Py_UNUSED(ignored)) { return MakePyCypherType(mgp_type_int()); }
|
||||
|
||||
PyObject *PyMgpModuleTypeFloat(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
|
||||
return MakePyCypherType(mgp_type_float());
|
||||
@ -721,45 +681,29 @@ PyObject *PyMgpModuleTypeNumber(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
|
||||
return MakePyCypherType(mgp_type_number());
|
||||
}
|
||||
|
||||
PyObject *PyMgpModuleTypeMap(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
|
||||
return MakePyCypherType(mgp_type_map());
|
||||
}
|
||||
PyObject *PyMgpModuleTypeMap(PyObject *mod, PyObject *Py_UNUSED(ignored)) { return MakePyCypherType(mgp_type_map()); }
|
||||
|
||||
PyObject *PyMgpModuleTypeNode(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
|
||||
return MakePyCypherType(mgp_type_node());
|
||||
}
|
||||
PyObject *PyMgpModuleTypeNode(PyObject *mod, PyObject *Py_UNUSED(ignored)) { return MakePyCypherType(mgp_type_node()); }
|
||||
|
||||
PyObject *PyMgpModuleTypeRelationship(PyObject *mod,
|
||||
PyObject *Py_UNUSED(ignored)) {
|
||||
PyObject *PyMgpModuleTypeRelationship(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
|
||||
return MakePyCypherType(mgp_type_relationship());
|
||||
}
|
||||
|
||||
PyObject *PyMgpModuleTypePath(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
|
||||
return MakePyCypherType(mgp_type_path());
|
||||
}
|
||||
PyObject *PyMgpModuleTypePath(PyObject *mod, PyObject *Py_UNUSED(ignored)) { return MakePyCypherType(mgp_type_path()); }
|
||||
|
||||
static PyMethodDef PyMgpModuleMethods[] = {
|
||||
{"type_nullable", PyMgpModuleTypeNullable, METH_O,
|
||||
"Build a type representing either a `null` value or a value of given "
|
||||
"type."},
|
||||
{"type_list", PyMgpModuleTypeList, METH_O,
|
||||
"Build a type representing a list of values of given type."},
|
||||
{"type_any", PyMgpModuleTypeAny, METH_NOARGS,
|
||||
"Get the type representing any value that isn't `null`."},
|
||||
{"type_bool", PyMgpModuleTypeBool, METH_NOARGS,
|
||||
"Get the type representing boolean values."},
|
||||
{"type_string", PyMgpModuleTypeString, METH_NOARGS,
|
||||
"Get the type representing string values."},
|
||||
{"type_int", PyMgpModuleTypeInt, METH_NOARGS,
|
||||
"Get the type representing integer values."},
|
||||
{"type_float", PyMgpModuleTypeFloat, METH_NOARGS,
|
||||
"Get the type representing floating-point values."},
|
||||
{"type_number", PyMgpModuleTypeNumber, METH_NOARGS,
|
||||
"Get the type representing any number value."},
|
||||
{"type_map", PyMgpModuleTypeMap, METH_NOARGS,
|
||||
"Get the type representing map values."},
|
||||
{"type_node", PyMgpModuleTypeNode, METH_NOARGS,
|
||||
"Get the type representing graph node values."},
|
||||
{"type_list", PyMgpModuleTypeList, METH_O, "Build a type representing a list of values of given type."},
|
||||
{"type_any", PyMgpModuleTypeAny, METH_NOARGS, "Get the type representing any value that isn't `null`."},
|
||||
{"type_bool", PyMgpModuleTypeBool, METH_NOARGS, "Get the type representing boolean values."},
|
||||
{"type_string", PyMgpModuleTypeString, METH_NOARGS, "Get the type representing string values."},
|
||||
{"type_int", PyMgpModuleTypeInt, METH_NOARGS, "Get the type representing integer values."},
|
||||
{"type_float", PyMgpModuleTypeFloat, METH_NOARGS, "Get the type representing floating-point values."},
|
||||
{"type_number", PyMgpModuleTypeNumber, METH_NOARGS, "Get the type representing any number value."},
|
||||
{"type_map", PyMgpModuleTypeMap, METH_NOARGS, "Get the type representing map values."},
|
||||
{"type_node", PyMgpModuleTypeNode, METH_NOARGS, "Get the type representing graph node values."},
|
||||
{"type_relationship", PyMgpModuleTypeRelationship, METH_NOARGS,
|
||||
"Get the type representing graph relationship values."},
|
||||
{"type_path", PyMgpModuleTypePath, METH_NOARGS,
|
||||
@ -796,8 +740,7 @@ void PyPropertiesIteratorDealloc(PyPropertiesIterator *self) {
|
||||
Py_TYPE(self)->tp_free(self);
|
||||
}
|
||||
|
||||
PyObject *PyPropertiesIteratorGet(PyPropertiesIterator *self,
|
||||
PyObject *Py_UNUSED(ignored)) {
|
||||
PyObject *PyPropertiesIteratorGet(PyPropertiesIterator *self, PyObject *Py_UNUSED(ignored)) {
|
||||
MG_ASSERT(self->it);
|
||||
MG_ASSERT(self->py_graph);
|
||||
MG_ASSERT(self->py_graph->graph);
|
||||
@ -810,8 +753,7 @@ PyObject *PyPropertiesIteratorGet(PyPropertiesIterator *self,
|
||||
return PyTuple_Pack(2, py_name.Ptr(), py_value.Ptr());
|
||||
}
|
||||
|
||||
PyObject *PyPropertiesIteratorNext(PyPropertiesIterator *self,
|
||||
PyObject *Py_UNUSED(ignored)) {
|
||||
PyObject *PyPropertiesIteratorNext(PyPropertiesIterator *self, PyObject *Py_UNUSED(ignored)) {
|
||||
MG_ASSERT(self->it);
|
||||
MG_ASSERT(self->py_graph);
|
||||
MG_ASSERT(self->py_graph->graph);
|
||||
@ -827,8 +769,8 @@ PyObject *PyPropertiesIteratorNext(PyPropertiesIterator *self,
|
||||
static PyMethodDef PyPropertiesIteratorMethods[] = {
|
||||
{"get", reinterpret_cast<PyCFunction>(PyPropertiesIteratorGet), METH_NOARGS,
|
||||
"Get the current proprety pointed to by the iterator or return None."},
|
||||
{"next", reinterpret_cast<PyCFunction>(PyPropertiesIteratorNext),
|
||||
METH_NOARGS, "Advance the iterator to the next property and return it."},
|
||||
{"next", reinterpret_cast<PyCFunction>(PyPropertiesIteratorNext), METH_NOARGS,
|
||||
"Advance the iterator to the next property and return it."},
|
||||
{nullptr},
|
||||
};
|
||||
|
||||
@ -908,15 +850,12 @@ PyObject *PyEdgeIterProperties(PyEdge *self, PyObject *Py_UNUSED(ignored)) {
|
||||
MG_ASSERT(self->edge);
|
||||
MG_ASSERT(self->py_graph);
|
||||
MG_ASSERT(self->py_graph->graph);
|
||||
auto *properties_it =
|
||||
mgp_edge_iter_properties(self->edge, self->py_graph->memory);
|
||||
auto *properties_it = mgp_edge_iter_properties(self->edge, self->py_graph->memory);
|
||||
if (!properties_it) {
|
||||
PyErr_SetString(PyExc_MemoryError,
|
||||
"Unable to allocate mgp_properties_iterator.");
|
||||
PyErr_SetString(PyExc_MemoryError, "Unable to allocate mgp_properties_iterator.");
|
||||
return nullptr;
|
||||
}
|
||||
auto *py_properties_it =
|
||||
PyObject_New(PyPropertiesIterator, &PyPropertiesIteratorType);
|
||||
auto *py_properties_it = PyObject_New(PyPropertiesIterator, &PyPropertiesIteratorType);
|
||||
if (!py_properties_it) {
|
||||
mgp_properties_iterator_destroy(properties_it);
|
||||
return nullptr;
|
||||
@ -934,11 +873,9 @@ PyObject *PyEdgeGetProperty(PyEdge *self, PyObject *args) {
|
||||
MG_ASSERT(self->py_graph->graph);
|
||||
const char *prop_name = nullptr;
|
||||
if (!PyArg_ParseTuple(args, "s", &prop_name)) return nullptr;
|
||||
auto *prop_value =
|
||||
mgp_edge_get_property(self->edge, prop_name, self->py_graph->memory);
|
||||
auto *prop_value = mgp_edge_get_property(self->edge, prop_name, self->py_graph->memory);
|
||||
if (!prop_value) {
|
||||
PyErr_SetString(PyExc_MemoryError,
|
||||
"Unable to allocate mgp_value for edge property value.");
|
||||
PyErr_SetString(PyExc_MemoryError, "Unable to allocate mgp_value for edge property value.");
|
||||
return nullptr;
|
||||
}
|
||||
auto py_prop_value = MgpValueToPyObject(*prop_value, self->py_graph);
|
||||
@ -947,22 +884,17 @@ PyObject *PyEdgeGetProperty(PyEdge *self, PyObject *args) {
|
||||
}
|
||||
|
||||
static PyMethodDef PyEdgeMethods[] = {
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy),
|
||||
METH_NOARGS, "__reduce__ is not supported."},
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported."},
|
||||
{"is_valid", reinterpret_cast<PyCFunction>(PyEdgeIsValid), METH_NOARGS,
|
||||
"Return True if Edge is in valid context and may be used."},
|
||||
{"get_id", reinterpret_cast<PyCFunction>(PyEdgeGetId), METH_NOARGS,
|
||||
"Return edge id."},
|
||||
{"get_type_name", reinterpret_cast<PyCFunction>(PyEdgeGetTypeName),
|
||||
METH_NOARGS, "Return the edge's type name."},
|
||||
{"from_vertex", reinterpret_cast<PyCFunction>(PyEdgeFromVertex),
|
||||
METH_NOARGS, "Return the edge's source vertex."},
|
||||
{"to_vertex", reinterpret_cast<PyCFunction>(PyEdgeToVertex), METH_NOARGS,
|
||||
"Return the edge's destination vertex."},
|
||||
{"iter_properties", reinterpret_cast<PyCFunction>(PyEdgeIterProperties),
|
||||
METH_NOARGS, "Return _mgp.PropertiesIterator for this edge."},
|
||||
{"get_property", reinterpret_cast<PyCFunction>(PyEdgeGetProperty),
|
||||
METH_VARARGS, "Return edge property with given name."},
|
||||
{"get_id", reinterpret_cast<PyCFunction>(PyEdgeGetId), METH_NOARGS, "Return edge id."},
|
||||
{"get_type_name", reinterpret_cast<PyCFunction>(PyEdgeGetTypeName), METH_NOARGS, "Return the edge's type name."},
|
||||
{"from_vertex", reinterpret_cast<PyCFunction>(PyEdgeFromVertex), METH_NOARGS, "Return the edge's source vertex."},
|
||||
{"to_vertex", reinterpret_cast<PyCFunction>(PyEdgeToVertex), METH_NOARGS, "Return the edge's destination vertex."},
|
||||
{"iter_properties", reinterpret_cast<PyCFunction>(PyEdgeIterProperties), METH_NOARGS,
|
||||
"Return _mgp.PropertiesIterator for this edge."},
|
||||
{"get_property", reinterpret_cast<PyCFunction>(PyEdgeGetProperty), METH_VARARGS,
|
||||
"Return edge property with given name."},
|
||||
{nullptr},
|
||||
};
|
||||
|
||||
@ -1008,8 +940,7 @@ PyObject *PyEdgeRichCompare(PyObject *self, PyObject *other, int op) {
|
||||
MG_ASSERT(self);
|
||||
MG_ASSERT(other);
|
||||
|
||||
if (Py_TYPE(self) != &PyEdgeType || Py_TYPE(other) != &PyEdgeType ||
|
||||
op != Py_EQ) {
|
||||
if (Py_TYPE(self) != &PyEdgeType || Py_TYPE(other) != &PyEdgeType || op != Py_EQ) {
|
||||
Py_RETURN_NOTIMPLEMENTED;
|
||||
}
|
||||
|
||||
@ -1065,14 +996,12 @@ PyObject *PyVertexLabelAt(PyVertex *self, PyObject *args) {
|
||||
MG_ASSERT(self->vertex);
|
||||
MG_ASSERT(self->py_graph);
|
||||
MG_ASSERT(self->py_graph->graph);
|
||||
static_assert(std::numeric_limits<Py_ssize_t>::max() <=
|
||||
std::numeric_limits<size_t>::max());
|
||||
static_assert(std::numeric_limits<Py_ssize_t>::max() <= std::numeric_limits<size_t>::max());
|
||||
Py_ssize_t id;
|
||||
if (!PyArg_ParseTuple(args, "n", &id)) return nullptr;
|
||||
auto label = mgp_vertex_label_at(self->vertex, id);
|
||||
if (label.name == nullptr || id < 0) {
|
||||
PyErr_SetString(PyExc_IndexError,
|
||||
"Unable to find the label with given ID.");
|
||||
PyErr_SetString(PyExc_IndexError, "Unable to find the label with given ID.");
|
||||
return nullptr;
|
||||
}
|
||||
return PyUnicode_FromString(label.name);
|
||||
@ -1083,11 +1012,9 @@ PyObject *PyVertexIterInEdges(PyVertex *self, PyObject *Py_UNUSED(ignored)) {
|
||||
MG_ASSERT(self->vertex);
|
||||
MG_ASSERT(self->py_graph);
|
||||
MG_ASSERT(self->py_graph->graph);
|
||||
auto *edges_it =
|
||||
mgp_vertex_iter_in_edges(self->vertex, self->py_graph->memory);
|
||||
auto *edges_it = mgp_vertex_iter_in_edges(self->vertex, self->py_graph->memory);
|
||||
if (!edges_it) {
|
||||
PyErr_SetString(PyExc_MemoryError,
|
||||
"Unable to allocate mgp_edges_iterator for in edges.");
|
||||
PyErr_SetString(PyExc_MemoryError, "Unable to allocate mgp_edges_iterator for in edges.");
|
||||
return nullptr;
|
||||
}
|
||||
auto *py_edges_it = PyObject_New(PyEdgesIterator, &PyEdgesIteratorType);
|
||||
@ -1106,11 +1033,9 @@ PyObject *PyVertexIterOutEdges(PyVertex *self, PyObject *Py_UNUSED(ignored)) {
|
||||
MG_ASSERT(self->vertex);
|
||||
MG_ASSERT(self->py_graph);
|
||||
MG_ASSERT(self->py_graph->graph);
|
||||
auto *edges_it =
|
||||
mgp_vertex_iter_out_edges(self->vertex, self->py_graph->memory);
|
||||
auto *edges_it = mgp_vertex_iter_out_edges(self->vertex, self->py_graph->memory);
|
||||
if (!edges_it) {
|
||||
PyErr_SetString(PyExc_MemoryError,
|
||||
"Unable to allocate mgp_edges_iterator for out edges.");
|
||||
PyErr_SetString(PyExc_MemoryError, "Unable to allocate mgp_edges_iterator for out edges.");
|
||||
return nullptr;
|
||||
}
|
||||
auto *py_edges_it = PyObject_New(PyEdgesIterator, &PyEdgesIteratorType);
|
||||
@ -1129,15 +1054,12 @@ PyObject *PyVertexIterProperties(PyVertex *self, PyObject *Py_UNUSED(ignored)) {
|
||||
MG_ASSERT(self->vertex);
|
||||
MG_ASSERT(self->py_graph);
|
||||
MG_ASSERT(self->py_graph->graph);
|
||||
auto *properties_it =
|
||||
mgp_vertex_iter_properties(self->vertex, self->py_graph->memory);
|
||||
auto *properties_it = mgp_vertex_iter_properties(self->vertex, self->py_graph->memory);
|
||||
if (!properties_it) {
|
||||
PyErr_SetString(PyExc_MemoryError,
|
||||
"Unable to allocate mgp_properties_iterator.");
|
||||
PyErr_SetString(PyExc_MemoryError, "Unable to allocate mgp_properties_iterator.");
|
||||
return nullptr;
|
||||
}
|
||||
auto *py_properties_it =
|
||||
PyObject_New(PyPropertiesIterator, &PyPropertiesIteratorType);
|
||||
auto *py_properties_it = PyObject_New(PyPropertiesIterator, &PyPropertiesIteratorType);
|
||||
if (!py_properties_it) {
|
||||
mgp_properties_iterator_destroy(properties_it);
|
||||
return nullptr;
|
||||
@ -1155,11 +1077,9 @@ PyObject *PyVertexGetProperty(PyVertex *self, PyObject *args) {
|
||||
MG_ASSERT(self->py_graph->graph);
|
||||
const char *prop_name = nullptr;
|
||||
if (!PyArg_ParseTuple(args, "s", &prop_name)) return nullptr;
|
||||
auto *prop_value =
|
||||
mgp_vertex_get_property(self->vertex, prop_name, self->py_graph->memory);
|
||||
auto *prop_value = mgp_vertex_get_property(self->vertex, prop_name, self->py_graph->memory);
|
||||
if (!prop_value) {
|
||||
PyErr_SetString(PyExc_MemoryError,
|
||||
"Unable to allocate mgp_value for vertex property value.");
|
||||
PyErr_SetString(PyExc_MemoryError, "Unable to allocate mgp_value for vertex property value.");
|
||||
return nullptr;
|
||||
}
|
||||
auto py_prop_value = MgpValueToPyObject(*prop_value, self->py_graph);
|
||||
@ -1168,24 +1088,22 @@ PyObject *PyVertexGetProperty(PyVertex *self, PyObject *args) {
|
||||
}
|
||||
|
||||
static PyMethodDef PyVertexMethods[] = {
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy),
|
||||
METH_NOARGS, "__reduce__ is not supported."},
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported."},
|
||||
{"is_valid", reinterpret_cast<PyCFunction>(PyVertexIsValid), METH_NOARGS,
|
||||
"Return True if Vertex is in valid context and may be used."},
|
||||
{"get_id", reinterpret_cast<PyCFunction>(PyVertexGetId), METH_NOARGS,
|
||||
"Return vertex id."},
|
||||
{"labels_count", reinterpret_cast<PyCFunction>(PyVertexLabelsCount),
|
||||
METH_NOARGS, "Return number of lables of a vertex."},
|
||||
{"get_id", reinterpret_cast<PyCFunction>(PyVertexGetId), METH_NOARGS, "Return vertex id."},
|
||||
{"labels_count", reinterpret_cast<PyCFunction>(PyVertexLabelsCount), METH_NOARGS,
|
||||
"Return number of lables of a vertex."},
|
||||
{"label_at", reinterpret_cast<PyCFunction>(PyVertexLabelAt), METH_VARARGS,
|
||||
"Return label of a vertex on a given index."},
|
||||
{"iter_in_edges", reinterpret_cast<PyCFunction>(PyVertexIterInEdges),
|
||||
METH_NOARGS, "Return _mgp.EdgesIterator for in edges."},
|
||||
{"iter_out_edges", reinterpret_cast<PyCFunction>(PyVertexIterOutEdges),
|
||||
METH_NOARGS, "Return _mgp.EdgesIterator for out edges."},
|
||||
{"iter_properties", reinterpret_cast<PyCFunction>(PyVertexIterProperties),
|
||||
METH_NOARGS, "Return _mgp.PropertiesIterator for this vertex."},
|
||||
{"get_property", reinterpret_cast<PyCFunction>(PyVertexGetProperty),
|
||||
METH_VARARGS, "Return vertex property with given name."},
|
||||
{"iter_in_edges", reinterpret_cast<PyCFunction>(PyVertexIterInEdges), METH_NOARGS,
|
||||
"Return _mgp.EdgesIterator for in edges."},
|
||||
{"iter_out_edges", reinterpret_cast<PyCFunction>(PyVertexIterOutEdges), METH_NOARGS,
|
||||
"Return _mgp.EdgesIterator for out edges."},
|
||||
{"iter_properties", reinterpret_cast<PyCFunction>(PyVertexIterProperties), METH_NOARGS,
|
||||
"Return _mgp.PropertiesIterator for this vertex."},
|
||||
{"get_property", reinterpret_cast<PyCFunction>(PyVertexGetProperty), METH_VARARGS,
|
||||
"Return vertex property with given name."},
|
||||
{nullptr},
|
||||
};
|
||||
|
||||
@ -1234,8 +1152,7 @@ PyObject *PyVertexRichCompare(PyObject *self, PyObject *other, int op) {
|
||||
MG_ASSERT(self);
|
||||
MG_ASSERT(other);
|
||||
|
||||
if (Py_TYPE(self) != &PyVertexType || Py_TYPE(other) != &PyVertexType ||
|
||||
op != Py_EQ) {
|
||||
if (Py_TYPE(self) != &PyVertexType || Py_TYPE(other) != &PyVertexType || op != Py_EQ) {
|
||||
Py_RETURN_NOTIMPLEMENTED;
|
||||
}
|
||||
|
||||
@ -1283,12 +1200,9 @@ PyObject *PyPathExpand(PyPath *self, PyObject *edge) {
|
||||
auto *py_edge = reinterpret_cast<PyEdge *>(edge);
|
||||
const auto *to = mgp_edge_get_to(py_edge->edge);
|
||||
const auto *from = mgp_edge_get_from(py_edge->edge);
|
||||
const auto *last_vertex =
|
||||
mgp_path_vertex_at(self->path, mgp_path_size(self->path));
|
||||
if (!mgp_vertex_equal(last_vertex, to) &&
|
||||
!mgp_vertex_equal(last_vertex, from)) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"Edge is not a continuation of the path.");
|
||||
const auto *last_vertex = mgp_path_vertex_at(self->path, mgp_path_size(self->path));
|
||||
if (!mgp_vertex_equal(last_vertex, to) && !mgp_vertex_equal(last_vertex, from)) {
|
||||
PyErr_SetString(PyExc_ValueError, "Edge is not a continuation of the path.");
|
||||
return nullptr;
|
||||
}
|
||||
if (!mgp_path_expand(self->path, py_edge->edge)) {
|
||||
@ -1309,8 +1223,7 @@ PyObject *PyPathVertexAt(PyPath *self, PyObject *args) {
|
||||
MG_ASSERT(self->path);
|
||||
MG_ASSERT(self->py_graph);
|
||||
MG_ASSERT(self->py_graph->graph);
|
||||
static_assert(std::numeric_limits<Py_ssize_t>::max() <=
|
||||
std::numeric_limits<size_t>::max());
|
||||
static_assert(std::numeric_limits<Py_ssize_t>::max() <= std::numeric_limits<size_t>::max());
|
||||
Py_ssize_t i;
|
||||
if (!PyArg_ParseTuple(args, "n", &i)) return nullptr;
|
||||
const auto *vertex = mgp_path_vertex_at(self->path, i);
|
||||
@ -1325,8 +1238,7 @@ PyObject *PyPathEdgeAt(PyPath *self, PyObject *args) {
|
||||
MG_ASSERT(self->path);
|
||||
MG_ASSERT(self->py_graph);
|
||||
MG_ASSERT(self->py_graph->graph);
|
||||
static_assert(std::numeric_limits<Py_ssize_t>::max() <=
|
||||
std::numeric_limits<size_t>::max());
|
||||
static_assert(std::numeric_limits<Py_ssize_t>::max() <= std::numeric_limits<size_t>::max());
|
||||
Py_ssize_t i;
|
||||
if (!PyArg_ParseTuple(args, "n", &i)) return nullptr;
|
||||
const auto *edge = mgp_path_edge_at(self->path, i);
|
||||
@ -1338,16 +1250,14 @@ PyObject *PyPathEdgeAt(PyPath *self, PyObject *args) {
|
||||
}
|
||||
|
||||
static PyMethodDef PyPathMethods[] = {
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy),
|
||||
METH_NOARGS, "__reduce__ is not supported"},
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
|
||||
{"is_valid", reinterpret_cast<PyCFunction>(PyPathIsValid), METH_NOARGS,
|
||||
"Return True if Path is in valid context and may be used."},
|
||||
{"make_with_start", reinterpret_cast<PyCFunction>(PyPathMakeWithStart),
|
||||
METH_O | METH_CLASS, "Create a path with a starting vertex."},
|
||||
{"make_with_start", reinterpret_cast<PyCFunction>(PyPathMakeWithStart), METH_O | METH_CLASS,
|
||||
"Create a path with a starting vertex."},
|
||||
{"expand", reinterpret_cast<PyCFunction>(PyPathExpand), METH_O,
|
||||
"Append an edge continuing from the last vertex on the path."},
|
||||
{"size", reinterpret_cast<PyCFunction>(PyPathSize), METH_NOARGS,
|
||||
"Return the number of edges in a mgp_path."},
|
||||
{"size", reinterpret_cast<PyCFunction>(PyPathSize), METH_NOARGS, "Return the number of edges in a mgp_path."},
|
||||
{"vertex_at", reinterpret_cast<PyCFunction>(PyPathVertexAt), METH_VARARGS,
|
||||
"Return the vertex from a path at given index."},
|
||||
{"edge_at", reinterpret_cast<PyCFunction>(PyPathEdgeAt), METH_VARARGS,
|
||||
@ -1394,18 +1304,15 @@ PyObject *MakePyPath(const mgp_path &path, PyGraph *py_graph) {
|
||||
|
||||
PyObject *PyPathMakeWithStart(PyTypeObject *type, PyObject *vertex) {
|
||||
if (type != &PyPathType) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
"Expected '<class _mgp.Path>' as the first argument.");
|
||||
PyErr_SetString(PyExc_TypeError, "Expected '<class _mgp.Path>' as the first argument.");
|
||||
return nullptr;
|
||||
}
|
||||
if (Py_TYPE(vertex) != &PyVertexType) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
"Expected a '_mgp.Vertex' as the second argument.");
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a '_mgp.Vertex' as the second argument.");
|
||||
return nullptr;
|
||||
}
|
||||
auto *py_vertex = reinterpret_cast<PyVertex *>(vertex);
|
||||
auto *path =
|
||||
mgp_path_make_with_start(py_vertex->vertex, py_vertex->py_graph->memory);
|
||||
auto *path = mgp_path_make_with_start(py_vertex->vertex, py_vertex->py_graph->memory);
|
||||
if (!path) {
|
||||
PyErr_SetString(PyExc_MemoryError, "Unable to allocate mgp_path.");
|
||||
return nullptr;
|
||||
@ -1431,10 +1338,8 @@ PyObject *PyInitMgpModule() {
|
||||
}
|
||||
return true;
|
||||
};
|
||||
if (!register_type(&PyPropertiesIteratorType, "PropertiesIterator"))
|
||||
return nullptr;
|
||||
if (!register_type(&PyVerticesIteratorType, "VerticesIterator"))
|
||||
return nullptr;
|
||||
if (!register_type(&PyPropertiesIteratorType, "PropertiesIterator")) return nullptr;
|
||||
if (!register_type(&PyVerticesIteratorType, "VerticesIterator")) return nullptr;
|
||||
if (!register_type(&PyEdgesIteratorType, "EdgesIterator")) return nullptr;
|
||||
if (!register_type(&PyGraphType, "Graph")) return nullptr;
|
||||
if (!register_type(&PyEdgeType, "Edge")) return nullptr;
|
||||
@ -1481,14 +1386,11 @@ auto WithMgpModule(mgp_module *module_def, const TFun &fun) {
|
||||
} // namespace
|
||||
|
||||
py::Object ImportPyModule(const char *name, mgp_module *module_def) {
|
||||
return WithMgpModule(
|
||||
module_def, [name]() { return py::Object(PyImport_ImportModule(name)); });
|
||||
return WithMgpModule(module_def, [name]() { return py::Object(PyImport_ImportModule(name)); });
|
||||
}
|
||||
|
||||
py::Object ReloadPyModule(PyObject *py_module, mgp_module *module_def) {
|
||||
return WithMgpModule(module_def, [py_module]() {
|
||||
return py::Object(PyImport_ReloadModule(py_module));
|
||||
});
|
||||
return WithMgpModule(module_def, [py_module]() { return py::Object(PyImport_ReloadModule(py_module)); });
|
||||
}
|
||||
|
||||
py::Object MgpValueToPyObject(const mgp_value &value, PyObject *py_graph) {
|
||||
@ -1522,8 +1424,7 @@ py::Object MgpValueToPyObject(const mgp_value &value, PyGraph *py_graph) {
|
||||
auto py_val = MgpValueToPyObject(val, py_graph);
|
||||
if (!py_val) return nullptr;
|
||||
// Unlike PyList_SET_ITEM, PyDict_SetItem does not steal the value.
|
||||
if (PyDict_SetItemString(py_dict.Ptr(), key.c_str(), py_val.Ptr()) != 0)
|
||||
return nullptr;
|
||||
if (PyDict_SetItemString(py_dict.Ptr(), key.c_str(), py_val.Ptr()) != 0) return nullptr;
|
||||
}
|
||||
return py_dict;
|
||||
}
|
||||
@ -1531,34 +1432,29 @@ py::Object MgpValueToPyObject(const mgp_value &value, PyGraph *py_graph) {
|
||||
py::Object py_mgp(PyImport_ImportModule("mgp"));
|
||||
if (!py_mgp) return nullptr;
|
||||
const auto *v = mgp_value_get_vertex(&value);
|
||||
py::Object py_vertex(
|
||||
reinterpret_cast<PyObject *>(MakePyVertex(*v, py_graph)));
|
||||
py::Object py_vertex(reinterpret_cast<PyObject *>(MakePyVertex(*v, py_graph)));
|
||||
return py_mgp.CallMethod("Vertex", py_vertex);
|
||||
}
|
||||
case MGP_VALUE_TYPE_EDGE: {
|
||||
py::Object py_mgp(PyImport_ImportModule("mgp"));
|
||||
if (!py_mgp) return nullptr;
|
||||
const auto *e = mgp_value_get_edge(&value);
|
||||
py::Object py_edge(
|
||||
reinterpret_cast<PyObject *>(MakePyEdge(*e, py_graph)));
|
||||
py::Object py_edge(reinterpret_cast<PyObject *>(MakePyEdge(*e, py_graph)));
|
||||
return py_mgp.CallMethod("Edge", py_edge);
|
||||
}
|
||||
case MGP_VALUE_TYPE_PATH: {
|
||||
py::Object py_mgp(PyImport_ImportModule("mgp"));
|
||||
if (!py_mgp) return nullptr;
|
||||
const auto *p = mgp_value_get_path(&value);
|
||||
py::Object py_path(
|
||||
reinterpret_cast<PyObject *>(MakePyPath(*p, py_graph)));
|
||||
py::Object py_path(reinterpret_cast<PyObject *>(MakePyPath(*p, py_graph)));
|
||||
return py_mgp.CallMethod("Path", py_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mgp_value *PyObjectToMgpValue(PyObject *o, mgp_memory *memory) {
|
||||
auto py_seq_to_list = [memory](PyObject *seq, Py_ssize_t len,
|
||||
const auto &py_seq_get_item) {
|
||||
static_assert(std::numeric_limits<Py_ssize_t>::max() <=
|
||||
std::numeric_limits<size_t>::max());
|
||||
auto py_seq_to_list = [memory](PyObject *seq, Py_ssize_t len, const auto &py_seq_get_item) {
|
||||
static_assert(std::numeric_limits<Py_ssize_t>::max() <= std::numeric_limits<size_t>::max());
|
||||
mgp_list *list = mgp_list_make_empty(len, memory);
|
||||
if (!list) throw std::bad_alloc();
|
||||
for (Py_ssize_t i = 0; i < len; ++i) {
|
||||
@ -1603,8 +1499,7 @@ mgp_value *PyObjectToMgpValue(PyObject *o, mgp_memory *memory) {
|
||||
if (res == -1) {
|
||||
PyErr_Clear();
|
||||
std::stringstream ss;
|
||||
ss << "Error when checking object is instance of 'mgp." << mgp_type_name
|
||||
<< "' type";
|
||||
ss << "Error when checking object is instance of 'mgp." << mgp_type_name << "' type";
|
||||
throw std::invalid_argument(ss.str());
|
||||
}
|
||||
return static_cast<bool>(res);
|
||||
@ -1628,13 +1523,9 @@ mgp_value *PyObjectToMgpValue(PyObject *o, mgp_memory *memory) {
|
||||
} else if (PyUnicode_Check(o)) {
|
||||
mgp_v = mgp_value_make_string(PyUnicode_AsUTF8(o), memory);
|
||||
} else if (PyList_Check(o)) {
|
||||
mgp_v = py_seq_to_list(o, PyList_Size(o), [](auto *list, const auto i) {
|
||||
return PyList_GET_ITEM(list, i);
|
||||
});
|
||||
mgp_v = py_seq_to_list(o, PyList_Size(o), [](auto *list, const auto i) { return PyList_GET_ITEM(list, i); });
|
||||
} else if (PyTuple_Check(o)) {
|
||||
mgp_v = py_seq_to_list(o, PyTuple_Size(o), [](auto *tuple, const auto i) {
|
||||
return PyTuple_GET_ITEM(tuple, i);
|
||||
});
|
||||
mgp_v = py_seq_to_list(o, PyTuple_Size(o), [](auto *tuple, const auto i) { return PyTuple_GET_ITEM(tuple, i); });
|
||||
} else if (PyDict_Check(o)) {
|
||||
mgp_map *map = mgp_map_make_empty(memory);
|
||||
|
||||
@ -1725,8 +1616,7 @@ mgp_value *PyObjectToMgpValue(PyObject *o, mgp_memory *memory) {
|
||||
py::Object vertex(PyObject_GetAttrString(o, "_vertex"));
|
||||
if (!vertex) {
|
||||
PyErr_Clear();
|
||||
throw std::invalid_argument(
|
||||
"'mgp.Vertex' is missing '_vertex' attribute");
|
||||
throw std::invalid_argument("'mgp.Vertex' is missing '_vertex' attribute");
|
||||
}
|
||||
return PyObjectToMgpValue(vertex.Ptr(), memory);
|
||||
} else if (is_mgp_instance(o, "Path")) {
|
||||
|
@ -22,17 +22,15 @@ class AnyStream final {
|
||||
public:
|
||||
template <class TStream>
|
||||
AnyStream(TStream *stream, utils::MemoryResource *memory_resource)
|
||||
: content_{utils::Allocator<GenericWrapper<TStream>>{memory_resource}
|
||||
.template new_object<GenericWrapper<TStream>>(stream),
|
||||
: content_{
|
||||
utils::Allocator<GenericWrapper<TStream>>{memory_resource}.template new_object<GenericWrapper<TStream>>(
|
||||
stream),
|
||||
[memory_resource](Wrapper *ptr) {
|
||||
utils::Allocator<GenericWrapper<TStream>>{memory_resource}
|
||||
.template delete_object<GenericWrapper<TStream>>(
|
||||
static_cast<GenericWrapper<TStream> *>(ptr));
|
||||
.template delete_object<GenericWrapper<TStream>>(static_cast<GenericWrapper<TStream> *>(ptr));
|
||||
}} {}
|
||||
|
||||
void Result(const std::vector<TypedValue> &values) {
|
||||
content_->Result(values);
|
||||
}
|
||||
void Result(const std::vector<TypedValue> &values) { content_->Result(values); }
|
||||
|
||||
private:
|
||||
struct Wrapper {
|
||||
@ -43,9 +41,7 @@ class AnyStream final {
|
||||
struct GenericWrapper final : public Wrapper {
|
||||
explicit GenericWrapper(TStream *stream) : stream_{stream} {}
|
||||
|
||||
void Result(const std::vector<TypedValue> &values) override {
|
||||
stream_->Result(values);
|
||||
}
|
||||
void Result(const std::vector<TypedValue> &values) override { stream_->Result(values); }
|
||||
|
||||
TStream *stream_;
|
||||
};
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user