Merge branch 'project-pineapples' into T1165-MG-add-property-based-high-level-query-test

This commit is contained in:
gvolfing 2022-12-15 17:21:39 +01:00
commit 1ebde8be74
46 changed files with 2604 additions and 364 deletions

View File

@ -283,7 +283,7 @@ std::vector<ShardToInitialize> ShardMap::AssignShards(Address storage_manager,
// TODO(tyler) avoid these triple-nested loops by having the heartbeat include better info // TODO(tyler) avoid these triple-nested loops by having the heartbeat include better info
bool machine_contains_shard = false; bool machine_contains_shard = false;
for (auto &aas : shard) { for (auto &aas : shard.peers) {
if (initialized.contains(aas.address.unique_id)) { if (initialized.contains(aas.address.unique_id)) {
machine_contains_shard = true; machine_contains_shard = true;
if (aas.status != Status::CONSENSUS_PARTICIPANT) { if (aas.status != Status::CONSENSUS_PARTICIPANT) {
@ -311,7 +311,7 @@ std::vector<ShardToInitialize> ShardMap::AssignShards(Address storage_manager,
} }
} }
if (!machine_contains_shard && shard.size() < label_space.replication_factor) { if (!machine_contains_shard && shard.peers.size() < label_space.replication_factor) {
// increment version for each new uuid for deterministic creation // increment version for each new uuid for deterministic creation
IncrementShardMapVersion(); IncrementShardMapVersion();
@ -337,7 +337,7 @@ std::vector<ShardToInitialize> ShardMap::AssignShards(Address storage_manager,
.status = Status::INITIALIZING, .status = Status::INITIALIZING,
}; };
shard.emplace_back(aas); shard.peers.emplace_back(aas);
} }
} }
} }
@ -360,9 +360,9 @@ bool ShardMap::SplitShard(Hlc previous_shard_map_version, LabelId label_id, cons
MG_ASSERT(!shards_in_map.contains(key)); MG_ASSERT(!shards_in_map.contains(key));
MG_ASSERT(label_spaces.contains(label_id)); MG_ASSERT(label_spaces.contains(label_id));
// Finding the Shard that the new PrimaryKey should map to. // Finding the ShardMetadata that the new PrimaryKey should map to.
auto prev = std::prev(shards_in_map.upper_bound(key)); auto prev = std::prev(shards_in_map.upper_bound(key));
Shard duplicated_shard = prev->second; ShardMetadata duplicated_shard = prev->second;
// Apply the split // Apply the split
shards_in_map[key] = duplicated_shard; shards_in_map[key] = duplicated_shard;
@ -383,7 +383,7 @@ std::optional<LabelId> ShardMap::InitializeNewLabel(std::string label_name, std:
labels.emplace(std::move(label_name), label_id); labels.emplace(std::move(label_name), label_id);
PrimaryKey initial_key = SchemaToMinKey(schema); PrimaryKey initial_key = SchemaToMinKey(schema);
Shard empty_shard = {}; ShardMetadata empty_shard = {};
Shards shards = { Shards shards = {
{initial_key, empty_shard}, {initial_key, empty_shard},
@ -479,7 +479,7 @@ Shards ShardMap::GetShardsForRange(const LabelName &label_name, const PrimaryKey
return shards; return shards;
} }
Shard ShardMap::GetShardForKey(const LabelName &label_name, const PrimaryKey &key) const { ShardMetadata ShardMap::GetShardForKey(const LabelName &label_name, const PrimaryKey &key) const {
MG_ASSERT(labels.contains(label_name)); MG_ASSERT(labels.contains(label_name));
LabelId label_id = labels.at(label_name); LabelId label_id = labels.at(label_name);
@ -492,7 +492,7 @@ Shard ShardMap::GetShardForKey(const LabelName &label_name, const PrimaryKey &ke
return std::prev(label_space.shards.upper_bound(key))->second; return std::prev(label_space.shards.upper_bound(key))->second;
} }
Shard ShardMap::GetShardForKey(const LabelId &label_id, const PrimaryKey &key) const { ShardMetadata ShardMap::GetShardForKey(const LabelId &label_id, const PrimaryKey &key) const {
MG_ASSERT(label_spaces.contains(label_id)); MG_ASSERT(label_spaces.contains(label_id));
const auto &label_space = label_spaces.at(label_id); const auto &label_space = label_spaces.at(label_id);
@ -556,12 +556,12 @@ EdgeTypeIdMap ShardMap::AllocateEdgeTypeIds(const std::vector<EdgeTypeName> &new
bool ShardMap::ClusterInitialized() const { bool ShardMap::ClusterInitialized() const {
for (const auto &[label_id, label_space] : label_spaces) { for (const auto &[label_id, label_space] : label_spaces) {
for (const auto &[low_key, shard] : label_space.shards) { for (const auto &[low_key, shard] : label_space.shards) {
if (shard.size() < label_space.replication_factor) { if (shard.peers.size() < label_space.replication_factor) {
spdlog::info("label_space below desired replication factor"); spdlog::info("label_space below desired replication factor");
return false; return false;
} }
for (const auto &aas : shard) { for (const auto &aas : shard.peers) {
if (aas.status != Status::CONSENSUS_PARTICIPANT) { if (aas.status != Status::CONSENSUS_PARTICIPANT) {
spdlog::info("shard member not yet a CONSENSUS_PARTICIPANT"); spdlog::info("shard member not yet a CONSENSUS_PARTICIPANT");
return false; return false;

View File

@ -76,8 +76,35 @@ struct AddressAndStatus {
}; };
using PrimaryKey = std::vector<PropertyValue>; using PrimaryKey = std::vector<PropertyValue>;
using Shard = std::vector<AddressAndStatus>;
using Shards = std::map<PrimaryKey, Shard>; struct ShardMetadata {
std::vector<AddressAndStatus> peers;
uint64_t version;
friend std::ostream &operator<<(std::ostream &in, const ShardMetadata &shard) {
using utils::print_helpers::operator<<;
in << "ShardMetadata { peers: ";
in << shard.peers;
in << " version: ";
in << shard.version;
in << " }";
return in;
}
friend bool operator==(const ShardMetadata &lhs, const ShardMetadata &rhs) = default;
friend bool operator<(const ShardMetadata &lhs, const ShardMetadata &rhs) {
if (lhs.peers != rhs.peers) {
return lhs.peers < rhs.peers;
}
return lhs.version < rhs.version;
}
};
using Shards = std::map<PrimaryKey, ShardMetadata>;
using LabelName = std::string; using LabelName = std::string;
using PropertyName = std::string; using PropertyName = std::string;
using EdgeTypeName = std::string; using EdgeTypeName = std::string;
@ -99,7 +126,7 @@ PrimaryKey SchemaToMinKey(const std::vector<SchemaProperty> &schema);
struct LabelSpace { struct LabelSpace {
std::vector<SchemaProperty> schema; std::vector<SchemaProperty> schema;
// Maps between the smallest primary key stored in the shard and the shard // Maps between the smallest primary key stored in the shard and the shard
std::map<PrimaryKey, Shard> shards; std::map<PrimaryKey, ShardMetadata> shards;
size_t replication_factor; size_t replication_factor;
friend std::ostream &operator<<(std::ostream &in, const LabelSpace &label_space) { friend std::ostream &operator<<(std::ostream &in, const LabelSpace &label_space) {
@ -160,9 +187,9 @@ struct ShardMap {
Shards GetShardsForRange(const LabelName &label_name, const PrimaryKey &start_key, const PrimaryKey &end_key) const; Shards GetShardsForRange(const LabelName &label_name, const PrimaryKey &start_key, const PrimaryKey &end_key) const;
Shard GetShardForKey(const LabelName &label_name, const PrimaryKey &key) const; ShardMetadata GetShardForKey(const LabelName &label_name, const PrimaryKey &key) const;
Shard GetShardForKey(const LabelId &label_id, const PrimaryKey &key) const; ShardMetadata GetShardForKey(const LabelId &label_id, const PrimaryKey &key) const;
PropertyMap AllocatePropertyIds(const std::vector<PropertyName> &new_properties); PropertyMap AllocatePropertyIds(const std::vector<PropertyName> &new_properties);

View File

@ -36,8 +36,8 @@ template <typename TypedValue, typename EvaluationContext, typename DbAccessor,
typename PropertyValue, typename ConvFunctor, typename Error, typename Tag = StorageTag> typename PropertyValue, typename ConvFunctor, typename Error, typename Tag = StorageTag>
class ExpressionEvaluator : public ExpressionVisitor<TypedValue> { class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
public: public:
ExpressionEvaluator(Frame<TypedValue> *frame, const SymbolTable &symbol_table, const EvaluationContext &ctx, ExpressionEvaluator(Frame *frame, const SymbolTable &symbol_table, const EvaluationContext &ctx, DbAccessor *dba,
DbAccessor *dba, StorageView view) StorageView view)
: frame_(frame), symbol_table_(&symbol_table), ctx_(&ctx), dba_(dba), view_(view) {} : frame_(frame), symbol_table_(&symbol_table), ctx_(&ctx), dba_(dba), view_(view) {}
using ExpressionVisitor<TypedValue>::Visit; using ExpressionVisitor<TypedValue>::Visit;
@ -782,7 +782,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
LabelId GetLabel(LabelIx label) { return ctx_->labels[label.ix]; } LabelId GetLabel(LabelIx label) { return ctx_->labels[label.ix]; }
Frame<TypedValue> *frame_; Frame *frame_;
const SymbolTable *symbol_table_; const SymbolTable *symbol_table_;
const EvaluationContext *ctx_; const EvaluationContext *ctx_;
DbAccessor *dba_; DbAccessor *dba_;

View File

@ -20,7 +20,6 @@
namespace memgraph::expr { namespace memgraph::expr {
template <typename TypedValue>
class Frame { class Frame {
public: public:
/// Create a Frame of given size backed by a utils::NewDeleteResource() /// Create a Frame of given size backed by a utils::NewDeleteResource()
@ -42,4 +41,18 @@ class Frame {
utils::pmr::vector<TypedValue> elems_; utils::pmr::vector<TypedValue> elems_;
}; };
class FrameWithValidity final : public Frame {
public:
explicit FrameWithValidity(int64_t size) : Frame(size), is_valid_(false) {}
FrameWithValidity(int64_t size, utils::MemoryResource *memory) : Frame(size, memory), is_valid_(false) {}
bool IsValid() const noexcept { return is_valid_; }
void MakeValid() noexcept { is_valid_ = true; }
void MakeInvalid() noexcept { is_valid_ = false; }
private:
bool is_valid_;
};
} // namespace memgraph::expr } // namespace memgraph::expr

View File

@ -35,10 +35,13 @@ class Shared {
std::optional<T> item_; std::optional<T> item_;
bool consumed_ = false; bool consumed_ = false;
bool waiting_ = false; bool waiting_ = false;
std::function<bool()> simulator_notifier_ = nullptr; bool filled_ = false;
std::function<bool()> wait_notifier_ = nullptr;
std::function<void()> fill_notifier_ = nullptr;
public: public:
explicit Shared(std::function<bool()> simulator_notifier) : simulator_notifier_(simulator_notifier) {} explicit Shared(std::function<bool()> wait_notifier, std::function<void()> fill_notifier)
: wait_notifier_(wait_notifier), fill_notifier_(fill_notifier) {}
Shared() = default; Shared() = default;
Shared(Shared &&) = delete; Shared(Shared &&) = delete;
Shared &operator=(Shared &&) = delete; Shared &operator=(Shared &&) = delete;
@ -64,7 +67,7 @@ class Shared {
waiting_ = true; waiting_ = true;
while (!item_) { while (!item_) {
if (simulator_notifier_) [[unlikely]] { if (wait_notifier_) [[unlikely]] {
// We can't hold our own lock while notifying // We can't hold our own lock while notifying
// the simulator because notifying the simulator // the simulator because notifying the simulator
// involves acquiring the simulator's mutex // involves acquiring the simulator's mutex
@ -76,7 +79,7 @@ class Shared {
// so we have to get out of its way to avoid // so we have to get out of its way to avoid
// a cyclical deadlock. // a cyclical deadlock.
lock.unlock(); lock.unlock();
std::invoke(simulator_notifier_); std::invoke(wait_notifier_);
lock.lock(); lock.lock();
if (item_) { if (item_) {
// item may have been filled while we // item may have been filled while we
@ -115,11 +118,19 @@ class Shared {
std::unique_lock<std::mutex> lock(mu_); std::unique_lock<std::mutex> lock(mu_);
MG_ASSERT(!consumed_, "Promise filled after it was already consumed!"); MG_ASSERT(!consumed_, "Promise filled after it was already consumed!");
MG_ASSERT(!item_, "Promise filled twice!"); MG_ASSERT(!filled_, "Promise filled twice!");
item_ = item; item_ = item;
filled_ = true;
} // lock released before condition variable notification } // lock released before condition variable notification
if (fill_notifier_) {
spdlog::trace("calling fill notifier");
std::invoke(fill_notifier_);
} else {
spdlog::trace("not calling fill notifier");
}
cv_.notify_all(); cv_.notify_all();
} }
@ -251,8 +262,9 @@ std::pair<Future<T>, Promise<T>> FuturePromisePair() {
} }
template <typename T> template <typename T>
std::pair<Future<T>, Promise<T>> FuturePromisePairWithNotifier(std::function<bool()> simulator_notifier) { std::pair<Future<T>, Promise<T>> FuturePromisePairWithNotifications(std::function<bool()> wait_notifier,
std::shared_ptr<details::Shared<T>> shared = std::make_shared<details::Shared<T>>(simulator_notifier); std::function<void()> fill_notifier) {
std::shared_ptr<details::Shared<T>> shared = std::make_shared<details::Shared<T>>(wait_notifier, fill_notifier);
Future<T> future = Future<T>(shared); Future<T> future = Future<T>(shared);
Promise<T> promise = Promise<T>(shared); Promise<T> promise = Promise<T>(shared);

View File

@ -31,9 +31,10 @@ class LocalTransport {
: local_transport_handle_(std::move(local_transport_handle)) {} : local_transport_handle_(std::move(local_transport_handle)) {}
template <Message RequestT, Message ResponseT> template <Message RequestT, Message ResponseT>
ResponseFuture<ResponseT> Request(Address to_address, Address from_address, RequestT request, Duration timeout) { ResponseFuture<ResponseT> Request(Address to_address, Address from_address, RequestT request,
return local_transport_handle_->template SubmitRequest<RequestT, ResponseT>(to_address, from_address, std::function<void()> fill_notifier, Duration timeout) {
std::move(request), timeout); return local_transport_handle_->template SubmitRequest<RequestT, ResponseT>(
to_address, from_address, std::move(request), timeout, fill_notifier);
} }
template <Message... Ms> template <Message... Ms>

View File

@ -140,8 +140,12 @@ class LocalTransportHandle {
template <Message RequestT, Message ResponseT> template <Message RequestT, Message ResponseT>
ResponseFuture<ResponseT> SubmitRequest(Address to_address, Address from_address, RequestT &&request, ResponseFuture<ResponseT> SubmitRequest(Address to_address, Address from_address, RequestT &&request,
Duration timeout) { Duration timeout, std::function<void()> fill_notifier) {
auto [future, promise] = memgraph::io::FuturePromisePair<ResponseResult<ResponseT>>(); auto [future, promise] = memgraph::io::FuturePromisePairWithNotifications<ResponseResult<ResponseT>>(
// set null notifier for when the Future::Wait is called
nullptr,
// set notifier for when Promise::Fill is called
std::forward<std::function<void()>>(fill_notifier));
const bool port_matches = to_address.last_known_port == from_address.last_known_port; const bool port_matches = to_address.last_known_port == from_address.last_known_port;
const bool ip_matches = to_address.last_known_ip == from_address.last_known_ip; const bool ip_matches = to_address.last_known_ip == from_address.last_known_ip;

93
src/io/notifier.hpp Normal file
View File

@ -0,0 +1,93 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <condition_variable>
#include <functional>
#include <mutex>
#include <optional>
#include <vector>
namespace memgraph::io {
class ReadinessToken {
size_t id_;
public:
explicit ReadinessToken(size_t id) : id_(id) {}
size_t GetId() const { return id_; }
};
class Inner {
std::condition_variable cv_;
std::mutex mu_;
std::vector<ReadinessToken> ready_;
std::optional<std::function<bool()>> tick_simulator_;
public:
void Notify(ReadinessToken readiness_token) {
{
std::unique_lock<std::mutex> lock(mu_);
ready_.emplace_back(readiness_token);
} // mutex dropped
cv_.notify_all();
}
ReadinessToken Await() {
std::unique_lock<std::mutex> lock(mu_);
while (ready_.empty()) {
if (tick_simulator_) [[unlikely]] {
// This avoids a deadlock in a similar way that
// Future::Wait will release its mutex while
// interacting with the simulator, due to
// the fact that the simulator may cause
// notifications that we are interested in.
lock.unlock();
std::invoke(tick_simulator_.value());
lock.lock();
} else {
cv_.wait(lock);
}
}
ReadinessToken ret = ready_.back();
ready_.pop_back();
return ret;
}
void InstallSimulatorTicker(std::function<bool()> tick_simulator) {
std::unique_lock<std::mutex> lock(mu_);
tick_simulator_ = tick_simulator;
}
};
class Notifier {
std::shared_ptr<Inner> inner_;
public:
Notifier() : inner_(std::make_shared<Inner>()) {}
Notifier(const Notifier &) = default;
Notifier &operator=(const Notifier &) = default;
Notifier(Notifier &&old) = default;
Notifier &operator=(Notifier &&old) = default;
~Notifier() = default;
void Notify(ReadinessToken readiness_token) const { inner_->Notify(readiness_token); }
ReadinessToken Await() const { return inner_->Await(); }
void InstallSimulatorTicker(std::function<bool()> tick_simulator) { inner_->InstallSimulatorTicker(tick_simulator); }
};
} // namespace memgraph::io

View File

@ -91,33 +91,43 @@ struct ReadResponse {
}; };
template <class... ReadReturn> template <class... ReadReturn>
utils::TypeInfoRef TypeInfoFor(const ReadResponse<std::variant<ReadReturn...>> &read_response) { utils::TypeInfoRef TypeInfoFor(const ReadResponse<std::variant<ReadReturn...>> &response) {
return TypeInfoForVariant(read_response.read_return); return TypeInfoForVariant(response.read_return);
} }
template <class ReadReturn> template <class ReadReturn>
utils::TypeInfoRef TypeInfoFor(const ReadResponse<ReadReturn> & /* read_response */) { utils::TypeInfoRef TypeInfoFor(const ReadResponse<ReadReturn> & /* response */) {
return typeid(ReadReturn); return typeid(ReadReturn);
} }
template <class ReadOperation>
utils::TypeInfoRef TypeInfoFor(const ReadRequest<ReadOperation> & /* request */) {
return typeid(ReadOperation);
}
template <class... ReadOperations>
utils::TypeInfoRef TypeInfoFor(const ReadRequest<std::variant<ReadOperations...>> &request) {
return TypeInfoForVariant(request.operation);
}
template <class... WriteReturn> template <class... WriteReturn>
utils::TypeInfoRef TypeInfoFor(const WriteResponse<std::variant<WriteReturn...>> &write_response) { utils::TypeInfoRef TypeInfoFor(const WriteResponse<std::variant<WriteReturn...>> &response) {
return TypeInfoForVariant(write_response.write_return); return TypeInfoForVariant(response.write_return);
} }
template <class WriteReturn> template <class WriteReturn>
utils::TypeInfoRef TypeInfoFor(const WriteResponse<WriteReturn> & /* write_response */) { utils::TypeInfoRef TypeInfoFor(const WriteResponse<WriteReturn> & /* response */) {
return typeid(WriteReturn); return typeid(WriteReturn);
} }
template <class WriteOperation> template <class WriteOperation>
utils::TypeInfoRef TypeInfoFor(const WriteRequest<WriteOperation> & /* write_request */) { utils::TypeInfoRef TypeInfoFor(const WriteRequest<WriteOperation> & /* request */) {
return typeid(WriteOperation); return typeid(WriteOperation);
} }
template <class... WriteOperations> template <class... WriteOperations>
utils::TypeInfoRef TypeInfoFor(const WriteRequest<std::variant<WriteOperations...>> &write_request) { utils::TypeInfoRef TypeInfoFor(const WriteRequest<std::variant<WriteOperations...>> &request) {
return TypeInfoForVariant(write_request.operation); return TypeInfoForVariant(request.operation);
} }
/// AppendRequest is a raft-level message that the Leader /// AppendRequest is a raft-level message that the Leader
@ -846,7 +856,9 @@ class Raft {
// Leaders are able to immediately respond to the requester (with a ReadResponseValue) applied to the ReplicatedState // Leaders are able to immediately respond to the requester (with a ReadResponseValue) applied to the ReplicatedState
std::optional<Role> Handle(Leader & /* variable */, ReadRequest<ReadOperation> &&req, RequestId request_id, std::optional<Role> Handle(Leader & /* variable */, ReadRequest<ReadOperation> &&req, RequestId request_id,
Address from_address) { Address from_address) {
Log("handling ReadOperation"); auto type_info = TypeInfoFor(req);
std::string demangled_name = boost::core::demangle(type_info.get().name());
Log("handling ReadOperation<" + demangled_name + ">");
ReadOperation read_operation = req.operation; ReadOperation read_operation = req.operation;
ReadResponseValue read_return = replicated_state_.Read(read_operation); ReadResponseValue read_return = replicated_state_.Read(read_operation);

View File

@ -19,6 +19,7 @@
#include "io/address.hpp" #include "io/address.hpp"
#include "io/errors.hpp" #include "io/errors.hpp"
#include "io/notifier.hpp"
#include "io/rsm/raft.hpp" #include "io/rsm/raft.hpp"
#include "utils/result.hpp" #include "utils/result.hpp"
@ -37,18 +38,11 @@ using memgraph::io::rsm::WriteRequest;
using memgraph::io::rsm::WriteResponse; using memgraph::io::rsm::WriteResponse;
using memgraph::utils::BasicResult; using memgraph::utils::BasicResult;
class AsyncRequestToken {
size_t id_;
public:
explicit AsyncRequestToken(size_t id) : id_(id) {}
size_t GetId() const { return id_; }
};
template <typename RequestT, typename ResponseT> template <typename RequestT, typename ResponseT>
struct AsyncRequest { struct AsyncRequest {
Time start_time; Time start_time;
RequestT request; RequestT request;
Notifier notifier;
ResponseFuture<ResponseT> future; ResponseFuture<ResponseT> future;
}; };
@ -66,8 +60,6 @@ class RsmClient {
std::unordered_map<size_t, AsyncRequest<ReadRequestT, ReadResponse<ReadResponseT>>> async_reads_; std::unordered_map<size_t, AsyncRequest<ReadRequestT, ReadResponse<ReadResponseT>>> async_reads_;
std::unordered_map<size_t, AsyncRequest<WriteRequestT, WriteResponse<WriteResponseT>>> async_writes_; std::unordered_map<size_t, AsyncRequest<WriteRequestT, WriteResponse<WriteResponseT>>> async_writes_;
size_t async_token_generator_ = 0;
void SelectRandomLeader() { void SelectRandomLeader() {
std::uniform_int_distribution<size_t> addr_distrib(0, (server_addrs_.size() - 1)); std::uniform_int_distribution<size_t> addr_distrib(0, (server_addrs_.size() - 1));
size_t addr_index = io_.Rand(addr_distrib); size_t addr_index = io_.Rand(addr_distrib);
@ -101,61 +93,63 @@ class RsmClient {
~RsmClient() = default; ~RsmClient() = default;
BasicResult<TimedOut, WriteResponseT> SendWriteRequest(WriteRequestT req) { BasicResult<TimedOut, WriteResponseT> SendWriteRequest(WriteRequestT req) {
auto token = SendAsyncWriteRequest(req); Notifier notifier;
auto poll_result = AwaitAsyncWriteRequest(token); const ReadinessToken readiness_token{0};
SendAsyncWriteRequest(req, notifier, readiness_token);
auto poll_result = AwaitAsyncWriteRequest(readiness_token);
while (!poll_result) { while (!poll_result) {
poll_result = AwaitAsyncWriteRequest(token); poll_result = AwaitAsyncWriteRequest(readiness_token);
} }
return poll_result.value(); return poll_result.value();
} }
BasicResult<TimedOut, ReadResponseT> SendReadRequest(ReadRequestT req) { BasicResult<TimedOut, ReadResponseT> SendReadRequest(ReadRequestT req) {
auto token = SendAsyncReadRequest(req); Notifier notifier;
auto poll_result = AwaitAsyncReadRequest(token); const ReadinessToken readiness_token{0};
SendAsyncReadRequest(req, notifier, readiness_token);
auto poll_result = AwaitAsyncReadRequest(readiness_token);
while (!poll_result) { while (!poll_result) {
poll_result = AwaitAsyncReadRequest(token); poll_result = AwaitAsyncReadRequest(readiness_token);
} }
return poll_result.value(); return poll_result.value();
} }
/// AsyncRead methods /// AsyncRead methods
AsyncRequestToken SendAsyncReadRequest(const ReadRequestT &req) { void SendAsyncReadRequest(const ReadRequestT &req, Notifier notifier, ReadinessToken readiness_token) {
size_t token = async_token_generator_++;
ReadRequest<ReadRequestT> read_req = {.operation = req}; ReadRequest<ReadRequestT> read_req = {.operation = req};
AsyncRequest<ReadRequestT, ReadResponse<ReadResponseT>> async_request{ AsyncRequest<ReadRequestT, ReadResponse<ReadResponseT>> async_request{
.start_time = io_.Now(), .start_time = io_.Now(),
.request = std::move(req), .request = std::move(req),
.future = io_.template Request<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(leader_, read_req), .notifier = notifier,
.future = io_.template RequestWithNotification<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(
leader_, read_req, notifier, readiness_token),
}; };
async_reads_.emplace(token, std::move(async_request)); async_reads_.emplace(readiness_token.GetId(), std::move(async_request));
return AsyncRequestToken{token};
} }
void ResendAsyncReadRequest(const AsyncRequestToken &token) { void ResendAsyncReadRequest(const ReadinessToken &readiness_token) {
auto &async_request = async_reads_.at(token.GetId()); auto &async_request = async_reads_.at(readiness_token.GetId());
ReadRequest<ReadRequestT> read_req = {.operation = async_request.request}; ReadRequest<ReadRequestT> read_req = {.operation = async_request.request};
async_request.future = async_request.future = io_.template RequestWithNotification<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(
io_.template Request<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(leader_, read_req); leader_, read_req, async_request.notifier, readiness_token);
} }
std::optional<BasicResult<TimedOut, ReadResponseT>> PollAsyncReadRequest(const AsyncRequestToken &token) { std::optional<BasicResult<TimedOut, ReadResponseT>> PollAsyncReadRequest(const ReadinessToken &readiness_token) {
auto &async_request = async_reads_.at(token.GetId()); auto &async_request = async_reads_.at(readiness_token.GetId());
if (!async_request.future.IsReady()) { if (!async_request.future.IsReady()) {
return std::nullopt; return std::nullopt;
} }
return AwaitAsyncReadRequest(); return AwaitAsyncReadRequest(readiness_token);
} }
std::optional<BasicResult<TimedOut, ReadResponseT>> AwaitAsyncReadRequest(const AsyncRequestToken &token) { std::optional<BasicResult<TimedOut, ReadResponseT>> AwaitAsyncReadRequest(const ReadinessToken &readiness_token) {
auto &async_request = async_reads_.at(token.GetId()); auto &async_request = async_reads_.at(readiness_token.GetId());
ResponseResult<ReadResponse<ReadResponseT>> get_response_result = std::move(async_request.future).Wait(); ResponseResult<ReadResponse<ReadResponseT>> get_response_result = std::move(async_request.future).Wait();
const Duration overall_timeout = io_.GetDefaultTimeout(); const Duration overall_timeout = io_.GetDefaultTimeout();
@ -165,7 +159,7 @@ class RsmClient {
if (result_has_error && past_time_out) { if (result_has_error && past_time_out) {
// TODO static assert the exact type of error. // TODO static assert the exact type of error.
spdlog::debug("client timed out while trying to communicate with leader server {}", leader_.ToString()); spdlog::debug("client timed out while trying to communicate with leader server {}", leader_.ToString());
async_reads_.erase(token.GetId()); async_reads_.erase(readiness_token.GetId());
return TimedOut{}; return TimedOut{};
} }
@ -176,7 +170,7 @@ class RsmClient {
PossiblyRedirectLeader(read_get_response); PossiblyRedirectLeader(read_get_response);
if (read_get_response.success) { if (read_get_response.success) {
async_reads_.erase(token.GetId()); async_reads_.erase(readiness_token.GetId());
spdlog::debug("returning read_return for RSM request"); spdlog::debug("returning read_return for RSM request");
return std::move(read_get_response.read_return); return std::move(read_get_response.read_return);
} }
@ -184,49 +178,48 @@ class RsmClient {
SelectRandomLeader(); SelectRandomLeader();
} }
ResendAsyncReadRequest(token); ResendAsyncReadRequest(readiness_token);
return std::nullopt; return std::nullopt;
} }
/// AsyncWrite methods /// AsyncWrite methods
AsyncRequestToken SendAsyncWriteRequest(const WriteRequestT &req) { void SendAsyncWriteRequest(const WriteRequestT &req, Notifier notifier, ReadinessToken readiness_token) {
size_t token = async_token_generator_++;
WriteRequest<WriteRequestT> write_req = {.operation = req}; WriteRequest<WriteRequestT> write_req = {.operation = req};
AsyncRequest<WriteRequestT, WriteResponse<WriteResponseT>> async_request{ AsyncRequest<WriteRequestT, WriteResponse<WriteResponseT>> async_request{
.start_time = io_.Now(), .start_time = io_.Now(),
.request = std::move(req), .request = std::move(req),
.future = io_.template Request<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(leader_, write_req), .notifier = notifier,
.future = io_.template RequestWithNotification<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(
leader_, write_req, notifier, readiness_token),
}; };
async_writes_.emplace(token, std::move(async_request)); async_writes_.emplace(readiness_token.GetId(), std::move(async_request));
return AsyncRequestToken{token};
} }
void ResendAsyncWriteRequest(const AsyncRequestToken &token) { void ResendAsyncWriteRequest(const ReadinessToken &readiness_token) {
auto &async_request = async_writes_.at(token.GetId()); auto &async_request = async_writes_.at(readiness_token.GetId());
WriteRequest<WriteRequestT> write_req = {.operation = async_request.request}; WriteRequest<WriteRequestT> write_req = {.operation = async_request.request};
async_request.future = async_request.future =
io_.template Request<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(leader_, write_req); io_.template RequestWithNotification<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(
leader_, write_req, async_request.notifier, readiness_token);
} }
std::optional<BasicResult<TimedOut, WriteResponseT>> PollAsyncWriteRequest(const AsyncRequestToken &token) { std::optional<BasicResult<TimedOut, WriteResponseT>> PollAsyncWriteRequest(const ReadinessToken &readiness_token) {
auto &async_request = async_writes_.at(token.GetId()); auto &async_request = async_writes_.at(readiness_token.GetId());
if (!async_request.future.IsReady()) { if (!async_request.future.IsReady()) {
return std::nullopt; return std::nullopt;
} }
return AwaitAsyncWriteRequest(); return AwaitAsyncWriteRequest(readiness_token);
} }
std::optional<BasicResult<TimedOut, WriteResponseT>> AwaitAsyncWriteRequest(const AsyncRequestToken &token) { std::optional<BasicResult<TimedOut, WriteResponseT>> AwaitAsyncWriteRequest(const ReadinessToken &readiness_token) {
auto &async_request = async_writes_.at(token.GetId()); auto &async_request = async_writes_.at(readiness_token.GetId());
ResponseResult<WriteResponse<WriteResponseT>> get_response_result = std::move(async_request.future).Wait(); ResponseResult<WriteResponse<WriteResponseT>> get_response_result = std::move(async_request.future).Wait();
const Duration overall_timeout = io_.GetDefaultTimeout(); const Duration overall_timeout = io_.GetDefaultTimeout();
@ -236,7 +229,7 @@ class RsmClient {
if (result_has_error && past_time_out) { if (result_has_error && past_time_out) {
// TODO static assert the exact type of error. // TODO static assert the exact type of error.
spdlog::debug("client timed out while trying to communicate with leader server {}", leader_.ToString()); spdlog::debug("client timed out while trying to communicate with leader server {}", leader_.ToString());
async_writes_.erase(token.GetId()); async_writes_.erase(readiness_token.GetId());
return TimedOut{}; return TimedOut{};
} }
@ -248,14 +241,14 @@ class RsmClient {
PossiblyRedirectLeader(write_get_response); PossiblyRedirectLeader(write_get_response);
if (write_get_response.success) { if (write_get_response.success) {
async_writes_.erase(token.GetId()); async_writes_.erase(readiness_token.GetId());
return std::move(write_get_response.write_return); return std::move(write_get_response.write_return);
} }
} else { } else {
SelectRandomLeader(); SelectRandomLeader();
} }
ResendAsyncWriteRequest(token); ResendAsyncWriteRequest(readiness_token);
return std::nullopt; return std::nullopt;
} }

View File

@ -51,5 +51,12 @@ class Simulator {
SimulatorStats Stats() { return simulator_handle_->Stats(); } SimulatorStats Stats() { return simulator_handle_->Stats(); }
std::shared_ptr<SimulatorHandle> GetSimulatorHandle() const { return simulator_handle_; } std::shared_ptr<SimulatorHandle> GetSimulatorHandle() const { return simulator_handle_; }
std::function<bool()> GetSimulatorTickClosure() {
std::function<bool()> tick_closure = [handle_copy = simulator_handle_] {
return handle_copy->MaybeTickSimulator();
};
return tick_closure;
}
}; };
}; // namespace memgraph::io::simulator }; // namespace memgraph::io::simulator

View File

@ -22,6 +22,8 @@
#include <variant> #include <variant>
#include <vector> #include <vector>
#include <boost/core/demangle.hpp>
#include "io/address.hpp" #include "io/address.hpp"
#include "io/errors.hpp" #include "io/errors.hpp"
#include "io/message_conversion.hpp" #include "io/message_conversion.hpp"
@ -105,13 +107,19 @@ class SimulatorHandle {
template <Message Request, Message Response> template <Message Request, Message Response>
ResponseFuture<Response> SubmitRequest(Address to_address, Address from_address, Request &&request, Duration timeout, ResponseFuture<Response> SubmitRequest(Address to_address, Address from_address, Request &&request, Duration timeout,
std::function<bool()> &&maybe_tick_simulator) { std::function<bool()> &&maybe_tick_simulator,
spdlog::trace("submitting request to {}", to_address.last_known_port); std::function<void()> &&fill_notifier) {
auto type_info = TypeInfoFor(request); auto type_info = TypeInfoFor(request);
std::string demangled_name = boost::core::demangle(type_info.get().name());
spdlog::trace("simulator sending request {} to {}", demangled_name, to_address);
auto [future, promise] = memgraph::io::FuturePromisePairWithNotifier<ResponseResult<Response>>( auto [future, promise] = memgraph::io::FuturePromisePairWithNotifications<ResponseResult<Response>>(
std::forward<std::function<bool()>>(maybe_tick_simulator)); // set notifier for when the Future::Wait is called
std::forward<std::function<bool()>>(maybe_tick_simulator),
// set notifier for when Promise::Fill is called
std::forward<std::function<void()>>(fill_notifier));
{
std::unique_lock<std::mutex> lock(mu_); std::unique_lock<std::mutex> lock(mu_);
RequestId request_id = ++request_id_counter_; RequestId request_id = ++request_id_counter_;
@ -140,6 +148,7 @@ class SimulatorHandle {
stats_.total_messages++; stats_.total_messages++;
stats_.total_requests++; stats_.total_requests++;
} // lock dropped here
cv_.notify_all(); cv_.notify_all();

View File

@ -15,6 +15,7 @@
#include <utility> #include <utility>
#include "io/address.hpp" #include "io/address.hpp"
#include "io/notifier.hpp"
#include "io/simulator/simulator_handle.hpp" #include "io/simulator/simulator_handle.hpp"
#include "io/time.hpp" #include "io/time.hpp"
@ -33,11 +34,14 @@ class SimulatorTransport {
: simulator_handle_(simulator_handle), address_(address), rng_(std::mt19937{seed}) {} : simulator_handle_(simulator_handle), address_(address), rng_(std::mt19937{seed}) {}
template <Message RequestT, Message ResponseT> template <Message RequestT, Message ResponseT>
ResponseFuture<ResponseT> Request(Address to_address, Address from_address, RequestT request, Duration timeout) { ResponseFuture<ResponseT> Request(Address to_address, Address from_address, RequestT request,
std::function<bool()> maybe_tick_simulator = [this] { return simulator_handle_->MaybeTickSimulator(); }; std::function<void()> notification, Duration timeout) {
std::function<bool()> tick_simulator = [handle_copy = simulator_handle_] {
return handle_copy->MaybeTickSimulator();
};
return simulator_handle_->template SubmitRequest<RequestT, ResponseT>(to_address, from_address, std::move(request), return simulator_handle_->template SubmitRequest<RequestT, ResponseT>(
timeout, std::move(maybe_tick_simulator)); to_address, from_address, std::move(request), timeout, std::move(tick_simulator), std::move(notification));
} }
template <Message... Ms> template <Message... Ms>

View File

@ -20,6 +20,7 @@
#include "io/errors.hpp" #include "io/errors.hpp"
#include "io/future.hpp" #include "io/future.hpp"
#include "io/message_histogram_collector.hpp" #include "io/message_histogram_collector.hpp"
#include "io/notifier.hpp"
#include "io/time.hpp" #include "io/time.hpp"
#include "utils/result.hpp" #include "utils/result.hpp"
@ -84,7 +85,9 @@ class Io {
template <Message RequestT, Message ResponseT> template <Message RequestT, Message ResponseT>
ResponseFuture<ResponseT> RequestWithTimeout(Address address, RequestT request, Duration timeout) { ResponseFuture<ResponseT> RequestWithTimeout(Address address, RequestT request, Duration timeout) {
const Address from_address = address_; const Address from_address = address_;
return implementation_.template Request<RequestT, ResponseT>(address, from_address, request, timeout); std::function<void()> fill_notifier = nullptr;
return implementation_.template Request<RequestT, ResponseT>(address, from_address, request, fill_notifier,
timeout);
} }
/// Issue a request that times out after the default timeout. This tends /// Issue a request that times out after the default timeout. This tends
@ -93,7 +96,30 @@ class Io {
ResponseFuture<ResponseT> Request(Address to_address, RequestT request) { ResponseFuture<ResponseT> Request(Address to_address, RequestT request) {
const Duration timeout = default_timeout_; const Duration timeout = default_timeout_;
const Address from_address = address_; const Address from_address = address_;
return implementation_.template Request<RequestT, ResponseT>(to_address, from_address, std::move(request), timeout); std::function<void()> fill_notifier = nullptr;
return implementation_.template Request<RequestT, ResponseT>(to_address, from_address, std::move(request),
fill_notifier, timeout);
}
/// Issue a request that will notify a Notifier when it is filled or times out.
template <Message RequestT, Message ResponseT>
ResponseFuture<ResponseT> RequestWithNotification(Address to_address, RequestT request, Notifier notifier,
ReadinessToken readiness_token) {
const Duration timeout = default_timeout_;
const Address from_address = address_;
std::function<void()> fill_notifier = [notifier, readiness_token]() { notifier.Notify(readiness_token); };
return implementation_.template Request<RequestT, ResponseT>(to_address, from_address, std::move(request),
fill_notifier, timeout);
}
/// Issue a request that will notify a Notifier when it is filled or times out.
template <Message RequestT, Message ResponseT>
ResponseFuture<ResponseT> RequestWithNotificationAndTimeout(Address to_address, RequestT request, Notifier notifier,
ReadinessToken readiness_token, Duration timeout) {
const Address from_address = address_;
std::function<void()> fill_notifier = [notifier, readiness_token]() { notifier.Notify(readiness_token); };
return implementation_.template Request<RequestT, ResponseT>(to_address, from_address, std::move(request),
fill_notifier, timeout);
} }
/// Wait for an explicit number of microseconds for a request of one of the /// Wait for an explicit number of microseconds for a request of one of the

View File

@ -23,7 +23,8 @@ set(mg_query_v2_sources
plan/variable_start_planner.cpp plan/variable_start_planner.cpp
serialization/property_value.cpp serialization/property_value.cpp
bindings/typed_value.cpp bindings/typed_value.cpp
accessors.cpp) accessors.cpp
multiframe.cpp)
find_package(Boost REQUIRED) find_package(Boost REQUIRED)

View File

@ -13,9 +13,10 @@
#include "query/v2/bindings/bindings.hpp" #include "query/v2/bindings/bindings.hpp"
#include "query/v2/bindings/typed_value.hpp"
#include "expr/interpret/frame.hpp" #include "expr/interpret/frame.hpp"
#include "query/v2/bindings/typed_value.hpp"
namespace memgraph::query::v2 { namespace memgraph::query::v2 {
using Frame = memgraph::expr::Frame<TypedValue>; using Frame = memgraph::expr::Frame;
using FrameWithValidity = memgraph::expr::FrameWithValidity;
} // namespace memgraph::query::v2 } // namespace memgraph::query::v2

View File

@ -41,6 +41,7 @@
#include "query/v2/frontend/ast/ast.hpp" #include "query/v2/frontend/ast/ast.hpp"
#include "query/v2/frontend/semantic/required_privileges.hpp" #include "query/v2/frontend/semantic/required_privileges.hpp"
#include "query/v2/metadata.hpp" #include "query/v2/metadata.hpp"
#include "query/v2/multiframe.hpp"
#include "query/v2/plan/planner.hpp" #include "query/v2/plan/planner.hpp"
#include "query/v2/plan/profile.hpp" #include "query/v2/plan/profile.hpp"
#include "query/v2/plan/vertex_count_cache.hpp" #include "query/v2/plan/vertex_count_cache.hpp"
@ -147,7 +148,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa
// Empty frame for evaluation of password expression. This is OK since // Empty frame for evaluation of password expression. This is OK since
// password should be either null or string literal and it's evaluation // password should be either null or string literal and it's evaluation
// should not depend on frame. // should not depend on frame.
expr::Frame<TypedValue> frame(0); expr::Frame frame(0);
SymbolTable symbol_table; SymbolTable symbol_table;
EvaluationContext evaluation_context; EvaluationContext evaluation_context;
// TODO: MemoryResource for EvaluationContext, it should probably be passed as // TODO: MemoryResource for EvaluationContext, it should probably be passed as
@ -314,7 +315,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa
Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &parameters, Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &parameters,
InterpreterContext *interpreter_context, RequestRouterInterface *request_router, InterpreterContext *interpreter_context, RequestRouterInterface *request_router,
std::vector<Notification> *notifications) { std::vector<Notification> *notifications) {
expr::Frame<TypedValue> frame(0); expr::Frame frame(0);
SymbolTable symbol_table; SymbolTable symbol_table;
EvaluationContext evaluation_context; EvaluationContext evaluation_context;
// TODO: MemoryResource for EvaluationContext, it should probably be passed as // TODO: MemoryResource for EvaluationContext, it should probably be passed as
@ -449,7 +450,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
Callback HandleSettingQuery(SettingQuery *setting_query, const Parameters &parameters, Callback HandleSettingQuery(SettingQuery *setting_query, const Parameters &parameters,
RequestRouterInterface *request_router) { RequestRouterInterface *request_router) {
expr::Frame<TypedValue> frame(0); expr::Frame frame(0);
SymbolTable symbol_table; SymbolTable symbol_table;
EvaluationContext evaluation_context; EvaluationContext evaluation_context;
// TODO: MemoryResource for EvaluationContext, it should probably be passed as // TODO: MemoryResource for EvaluationContext, it should probably be passed as
@ -655,11 +656,15 @@ struct PullPlan {
std::optional<plan::ProfilingStatsWithTotalTime> Pull(AnyStream *stream, std::optional<int> n, std::optional<plan::ProfilingStatsWithTotalTime> Pull(AnyStream *stream, std::optional<int> n,
const std::vector<Symbol> &output_symbols, const std::vector<Symbol> &output_symbols,
std::map<std::string, TypedValue> *summary); std::map<std::string, TypedValue> *summary);
std::optional<plan::ProfilingStatsWithTotalTime> PullMultiple(AnyStream *stream, std::optional<int> n,
const std::vector<Symbol> &output_symbols,
std::map<std::string, TypedValue> *summary);
private: private:
std::shared_ptr<CachedPlan> plan_ = nullptr; std::shared_ptr<CachedPlan> plan_ = nullptr;
plan::UniqueCursorPtr cursor_ = nullptr; plan::UniqueCursorPtr cursor_ = nullptr;
expr::Frame<TypedValue> frame_; expr::FrameWithValidity frame_;
MultiFrame multi_frame_;
ExecutionContext ctx_; ExecutionContext ctx_;
std::optional<size_t> memory_limit_; std::optional<size_t> memory_limit_;
@ -683,6 +688,7 @@ PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &par
: plan_(plan), : plan_(plan),
cursor_(plan->plan().MakeCursor(execution_memory)), cursor_(plan->plan().MakeCursor(execution_memory)),
frame_(plan->symbol_table().max_position(), execution_memory), frame_(plan->symbol_table().max_position(), execution_memory),
multi_frame_(plan->symbol_table().max_position(), kNumberOfFramesInMultiframe, execution_memory),
memory_limit_(memory_limit) { memory_limit_(memory_limit) {
ctx_.db_accessor = dba; ctx_.db_accessor = dba;
ctx_.symbol_table = plan->symbol_table(); ctx_.symbol_table = plan->symbol_table();
@ -699,9 +705,116 @@ PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &par
ctx_.edge_ids_alloc = &interpreter_context->edge_ids_alloc; ctx_.edge_ids_alloc = &interpreter_context->edge_ids_alloc;
} }
std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::PullMultiple(AnyStream *stream, std::optional<int> n,
const std::vector<Symbol> &output_symbols,
std::map<std::string, TypedValue> *summary) {
// Set up temporary memory for a single Pull. Initial memory comes from the
// stack. 256 KiB should fit on the stack and should be more than enough for a
// single `Pull`.
MG_ASSERT(!n.has_value(), "should pull all!");
static constexpr size_t stack_size = 256UL * 1024UL;
char stack_data[stack_size];
utils::ResourceWithOutOfMemoryException resource_with_exception;
utils::MonotonicBufferResource monotonic_memory(&stack_data[0], stack_size, &resource_with_exception);
// We can throw on every query because a simple queries for deleting will use only
// the stack allocated buffer.
// Also, we want to throw only when the query engine requests more memory and not the storage
// so we add the exception to the allocator.
// TODO (mferencevic): Tune the parameters accordingly.
utils::PoolResource pool_memory(128, 1024, &monotonic_memory);
std::optional<utils::LimitedMemoryResource> maybe_limited_resource;
if (memory_limit_) {
maybe_limited_resource.emplace(&pool_memory, *memory_limit_);
ctx_.evaluation_context.memory = &*maybe_limited_resource;
} else {
ctx_.evaluation_context.memory = &pool_memory;
}
// Returns true if a result was pulled.
const auto pull_result = [&]() -> bool {
cursor_->PullMultiple(multi_frame_, ctx_);
return multi_frame_.HasValidFrame();
};
const auto stream_values = [&output_symbols, &stream](const Frame &frame) {
// TODO: The streamed values should also probably use the above memory.
std::vector<TypedValue> values;
values.reserve(output_symbols.size());
for (const auto &symbol : output_symbols) {
values.emplace_back(frame[symbol]);
}
stream->Result(values);
};
// Get the execution time of all possible result pulls and streams.
utils::Timer timer;
int i = 0;
if (has_unsent_results_ && !output_symbols.empty()) {
// stream unsent results from previous pull
auto iterator_for_valid_frame_only = multi_frame_.GetValidFramesReader();
for (const auto &frame : iterator_for_valid_frame_only) {
stream_values(frame);
++i;
}
multi_frame_.MakeAllFramesInvalid();
}
for (; !n || i < n;) {
if (!pull_result()) {
break;
}
if (!output_symbols.empty()) {
auto iterator_for_valid_frame_only = multi_frame_.GetValidFramesReader();
for (const auto &frame : iterator_for_valid_frame_only) {
stream_values(frame);
++i;
}
}
multi_frame_.MakeAllFramesInvalid();
}
// If we finished because we streamed the requested n results,
// we try to pull the next result to see if there is more.
// If there is additional result, we leave the pulled result in the frame
// and set the flag to true.
has_unsent_results_ = i == n && pull_result();
execution_time_ += timer.Elapsed();
if (has_unsent_results_) {
return std::nullopt;
}
summary->insert_or_assign("plan_execution_time", execution_time_.count());
// We are finished with pulling all the data, therefore we can send any
// metadata about the results i.e. notifications and statistics
const bool is_any_counter_set =
std::any_of(ctx_.execution_stats.counters.begin(), ctx_.execution_stats.counters.end(),
[](const auto &counter) { return counter > 0; });
if (is_any_counter_set) {
std::map<std::string, TypedValue> stats;
for (size_t i = 0; i < ctx_.execution_stats.counters.size(); ++i) {
stats.emplace(ExecutionStatsKeyToString(ExecutionStats::Key(i)), ctx_.execution_stats.counters[i]);
}
summary->insert_or_assign("stats", std::move(stats));
}
cursor_->Shutdown();
ctx_.profile_execution_time = execution_time_;
return GetStatsWithTotalTime(ctx_);
}
std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::Pull(AnyStream *stream, std::optional<int> n, std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::Pull(AnyStream *stream, std::optional<int> n,
const std::vector<Symbol> &output_symbols, const std::vector<Symbol> &output_symbols,
std::map<std::string, TypedValue> *summary) { std::map<std::string, TypedValue> *summary) {
auto should_pull_multiple = false; // TODO on the long term, we will only use PullMultiple
if (should_pull_multiple) {
return PullMultiple(stream, n, output_symbols, summary);
}
// Set up temporary memory for a single Pull. Initial memory comes from the // Set up temporary memory for a single Pull. Initial memory comes from the
// stack. 256 KiB should fit on the stack and should be more than enough for a // stack. 256 KiB should fit on the stack and should be more than enough for a
// single `Pull`. // single `Pull`.
@ -875,7 +988,7 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string,
// TriggerContextCollector *trigger_context_collector = nullptr) { // TriggerContextCollector *trigger_context_collector = nullptr) {
auto *cypher_query = utils::Downcast<CypherQuery>(parsed_query.query); auto *cypher_query = utils::Downcast<CypherQuery>(parsed_query.query);
expr::Frame<TypedValue> frame(0); expr::Frame frame(0);
SymbolTable symbol_table; SymbolTable symbol_table;
EvaluationContext evaluation_context; EvaluationContext evaluation_context;
evaluation_context.timestamp = QueryTimestamp(); evaluation_context.timestamp = QueryTimestamp();
@ -1017,7 +1130,7 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra
auto *cypher_query = utils::Downcast<CypherQuery>(parsed_inner_query.query); auto *cypher_query = utils::Downcast<CypherQuery>(parsed_inner_query.query);
MG_ASSERT(cypher_query, "Cypher grammar should not allow other queries in PROFILE"); MG_ASSERT(cypher_query, "Cypher grammar should not allow other queries in PROFILE");
expr::Frame<TypedValue> frame(0); expr::Frame frame(0);
SymbolTable symbol_table; SymbolTable symbol_table;
EvaluationContext evaluation_context; EvaluationContext evaluation_context;
evaluation_context.timestamp = QueryTimestamp(); evaluation_context.timestamp = QueryTimestamp();

144
src/query/v2/multiframe.cpp Normal file
View File

@ -0,0 +1,144 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/multiframe.hpp"
#include <algorithm>
#include <iterator>
#include "query/v2/bindings/frame.hpp"
#include "utils/pmr/vector.hpp"
namespace memgraph::query::v2 {
static_assert(std::forward_iterator<ValidFramesReader::Iterator>);
static_assert(std::forward_iterator<ValidFramesModifier::Iterator>);
static_assert(std::forward_iterator<ValidFramesConsumer::Iterator>);
static_assert(std::forward_iterator<InvalidFramesPopulator::Iterator>);
MultiFrame::MultiFrame(int64_t size_of_frame, size_t number_of_frames, utils::MemoryResource *execution_memory)
: frames_(utils::pmr::vector<FrameWithValidity>(
number_of_frames, FrameWithValidity(size_of_frame, execution_memory), execution_memory)) {
MG_ASSERT(number_of_frames > 0);
}
MultiFrame::MultiFrame(const MultiFrame &other) : frames_{other.frames_} {}
// NOLINTNEXTLINE (bugprone-exception-escape)
MultiFrame::MultiFrame(MultiFrame &&other) noexcept : frames_(std::move(other.frames_)) {}
FrameWithValidity &MultiFrame::GetFirstFrame() {
MG_ASSERT(!frames_.empty());
return frames_.front();
}
void MultiFrame::MakeAllFramesInvalid() noexcept {
std::for_each(frames_.begin(), frames_.end(), [](auto &frame) { frame.MakeInvalid(); });
}
bool MultiFrame::HasValidFrame() const noexcept {
return std::any_of(frames_.begin(), frames_.end(), [](auto &frame) { return frame.IsValid(); });
}
// NOLINTNEXTLINE (bugprone-exception-escape)
void MultiFrame::DefragmentValidFrames() noexcept {
/*
from: https://en.cppreference.com/w/cpp/algorithm/remove
"Removing is done by shifting (by means of copy assignment (until C++11)move assignment (since C++11)) the elements
in the range in such a way that the elements that are not to be removed appear in the beginning of the range.
Relative order of the elements that remain is preserved and the physical size of the container is unchanged."
*/
// NOLINTNEXTLINE (bugprone-unused-return-value)
std::remove_if(frames_.begin(), frames_.end(), [](auto &frame) { return !frame.IsValid(); });
}
ValidFramesReader MultiFrame::GetValidFramesReader() { return ValidFramesReader{*this}; }
ValidFramesModifier MultiFrame::GetValidFramesModifier() { return ValidFramesModifier{*this}; }
ValidFramesConsumer MultiFrame::GetValidFramesConsumer() { return ValidFramesConsumer{*this}; }
InvalidFramesPopulator MultiFrame::GetInvalidFramesPopulator() { return InvalidFramesPopulator{*this}; }
ValidFramesReader::ValidFramesReader(MultiFrame &multiframe) : multiframe_(&multiframe) {
/*
From: https://en.cppreference.com/w/cpp/algorithm/find
Returns an iterator to the first element in the range [first, last) that satisfies specific criteria:
find_if searches for an element for which predicate p returns true
Return value
Iterator to the first element satisfying the condition or last if no such element is found.
-> this is what we want. We want the "after" last valid frame (weather this is vector::end or and invalid frame).
*/
auto it = std::find_if(multiframe.frames_.begin(), multiframe.frames_.end(),
[](const auto &frame) { return !frame.IsValid(); });
after_last_valid_frame_ = multiframe_->frames_.data() + std::distance(multiframe.frames_.begin(), it);
}
ValidFramesReader::Iterator ValidFramesReader::begin() {
if (multiframe_->frames_[0].IsValid()) {
return Iterator{&multiframe_->frames_[0]};
}
return end();
}
ValidFramesReader::Iterator ValidFramesReader::end() { return Iterator{after_last_valid_frame_}; }
ValidFramesModifier::ValidFramesModifier(MultiFrame &multiframe) : multiframe_(&multiframe) {}
ValidFramesModifier::Iterator ValidFramesModifier::begin() {
if (multiframe_->frames_[0].IsValid()) {
return Iterator{&multiframe_->frames_[0], *this};
}
return end();
}
ValidFramesModifier::Iterator ValidFramesModifier::end() {
return Iterator{multiframe_->frames_.data() + multiframe_->frames_.size(), *this};
}
ValidFramesConsumer::ValidFramesConsumer(MultiFrame &multiframe) : multiframe_(&multiframe) {}
// NOLINTNEXTLINE (bugprone-exception-escape)
ValidFramesConsumer::~ValidFramesConsumer() noexcept {
// TODO Possible optimisation: only DefragmentValidFrames if one frame has been invalidated? Only if does not
// cost too much to store it
multiframe_->DefragmentValidFrames();
}
ValidFramesConsumer::Iterator ValidFramesConsumer::begin() {
if (multiframe_->frames_[0].IsValid()) {
return Iterator{&multiframe_->frames_[0], *this};
}
return end();
}
ValidFramesConsumer::Iterator ValidFramesConsumer::end() {
return Iterator{multiframe_->frames_.data() + multiframe_->frames_.size(), *this};
}
InvalidFramesPopulator::InvalidFramesPopulator(MultiFrame &multiframe) : multiframe_(&multiframe) {}
InvalidFramesPopulator::Iterator InvalidFramesPopulator::begin() {
for (auto &frame : multiframe_->frames_) {
if (!frame.IsValid()) {
return Iterator{&frame};
}
}
return end();
}
InvalidFramesPopulator::Iterator InvalidFramesPopulator::end() {
return Iterator{multiframe_->frames_.data() + multiframe_->frames_.size()};
}
} // namespace memgraph::query::v2

302
src/query/v2/multiframe.hpp Normal file
View File

@ -0,0 +1,302 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <iterator>
#include "query/v2/bindings/frame.hpp"
namespace memgraph::query::v2 {
constexpr uint64_t kNumberOfFramesInMultiframe = 1000; // TODO have it configurable
class ValidFramesConsumer;
class ValidFramesModifier;
class ValidFramesReader;
class InvalidFramesPopulator;
class MultiFrame {
public:
friend class ValidFramesConsumer;
friend class ValidFramesModifier;
friend class ValidFramesReader;
friend class InvalidFramesPopulator;
MultiFrame(int64_t size_of_frame, size_t number_of_frames, utils::MemoryResource *execution_memory);
~MultiFrame() = default;
MultiFrame(const MultiFrame &other);
MultiFrame(MultiFrame &&other) noexcept;
MultiFrame &operator=(const MultiFrame &other) = delete;
MultiFrame &operator=(MultiFrame &&other) noexcept = delete;
/*
* Returns a object on which one can iterate in a for-loop. By doing so, you will only get Frames that are in a valid
* state in the MultiFrame.
* Iteration goes in a deterministic order.
* One can't modify the validity of the Frame nor its content with this implementation.
*/
ValidFramesReader GetValidFramesReader();
/*
* Returns a object on which one can iterate in a for-loop. By doing so, you will only get Frames that are in a valid
* state in the MultiFrame.
* Iteration goes in a deterministic order.
* One can't modify the validity of the Frame with this implementation. One can modify its content.
*/
ValidFramesModifier GetValidFramesModifier();
/*
* Returns a object on which one can iterate in a for-loop. By doing so, you will only get Frames that are in a valid
* state in the MultiFrame.
* Iteration goes in a deterministic order.
* One can modify the validity of the Frame with this implementation.
* If you do not plan to modify the validity of the Frames, use GetValidFramesReader/GetValidFramesModifer instead as
* this is faster.
*/
ValidFramesConsumer GetValidFramesConsumer();
/*
* Returns a object on which one can iterate in a for-loop. By doing so, you will only get Frames that are in an
* invalid state in the MultiFrame. Iteration goes in a deterministic order. One can modify the validity of
* the Frame with this implementation.
*/
InvalidFramesPopulator GetInvalidFramesPopulator();
/**
* Return the first Frame of the MultiFrame. This is only meant to be used in very specific cases. Please consider
* using the iterators instead.
* The Frame can be valid or invalid.
*/
FrameWithValidity &GetFirstFrame();
void MakeAllFramesInvalid() noexcept;
bool HasValidFrame() const noexcept;
inline utils::MemoryResource *GetMemoryResource() { return frames_[0].GetMemoryResource(); }
private:
void DefragmentValidFrames() noexcept;
utils::pmr::vector<FrameWithValidity> frames_;
};
class ValidFramesReader {
public:
explicit ValidFramesReader(MultiFrame &multiframe);
~ValidFramesReader() = default;
ValidFramesReader(const ValidFramesReader &other) = delete;
ValidFramesReader(ValidFramesReader &&other) noexcept = delete;
ValidFramesReader &operator=(const ValidFramesReader &other) = delete;
ValidFramesReader &operator=(ValidFramesReader &&other) noexcept = delete;
struct Iterator {
using iterator_category = std::forward_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = const Frame;
using pointer = value_type *;
using reference = const Frame &;
Iterator() = default;
explicit Iterator(FrameWithValidity *ptr) : ptr_(ptr) {}
reference operator*() const { return *ptr_; }
pointer operator->() { return ptr_; }
Iterator &operator++() {
ptr_++;
return *this;
}
// NOLINTNEXTLINE(cert-dcl21-cpp)
Iterator operator++(int) {
auto old = *this;
ptr_++;
return old;
}
friend bool operator==(const Iterator &lhs, const Iterator &rhs) { return lhs.ptr_ == rhs.ptr_; };
friend bool operator!=(const Iterator &lhs, const Iterator &rhs) { return lhs.ptr_ != rhs.ptr_; };
private:
FrameWithValidity *ptr_{nullptr};
};
Iterator begin();
Iterator end();
private:
FrameWithValidity *after_last_valid_frame_;
MultiFrame *multiframe_;
};
class ValidFramesModifier {
public:
explicit ValidFramesModifier(MultiFrame &multiframe);
~ValidFramesModifier() = default;
ValidFramesModifier(const ValidFramesModifier &other) = delete;
ValidFramesModifier(ValidFramesModifier &&other) noexcept = delete;
ValidFramesModifier &operator=(const ValidFramesModifier &other) = delete;
ValidFramesModifier &operator=(ValidFramesModifier &&other) noexcept = delete;
struct Iterator {
using iterator_category = std::forward_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = Frame;
using pointer = value_type *;
using reference = Frame &;
Iterator() = default;
Iterator(FrameWithValidity *ptr, ValidFramesModifier &iterator_wrapper)
: ptr_(ptr), iterator_wrapper_(&iterator_wrapper) {}
reference operator*() const { return *ptr_; }
pointer operator->() { return ptr_; }
// Prefix increment
Iterator &operator++() {
do {
ptr_++;
} while (*this != iterator_wrapper_->end() && ptr_->IsValid());
return *this;
}
// NOLINTNEXTLINE(cert-dcl21-cpp)
Iterator operator++(int) {
auto old = *this;
++*this;
return old;
}
friend bool operator==(const Iterator &lhs, const Iterator &rhs) { return lhs.ptr_ == rhs.ptr_; };
friend bool operator!=(const Iterator &lhs, const Iterator &rhs) { return lhs.ptr_ != rhs.ptr_; };
private:
FrameWithValidity *ptr_{nullptr};
ValidFramesModifier *iterator_wrapper_{nullptr};
};
Iterator begin();
Iterator end();
private:
MultiFrame *multiframe_;
};
class ValidFramesConsumer {
public:
explicit ValidFramesConsumer(MultiFrame &multiframe);
~ValidFramesConsumer() noexcept;
ValidFramesConsumer(const ValidFramesConsumer &other) = delete;
ValidFramesConsumer(ValidFramesConsumer &&other) noexcept = delete;
ValidFramesConsumer &operator=(const ValidFramesConsumer &other) = delete;
ValidFramesConsumer &operator=(ValidFramesConsumer &&other) noexcept = delete;
struct Iterator {
using iterator_category = std::forward_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = FrameWithValidity;
using pointer = value_type *;
using reference = FrameWithValidity &;
Iterator() = default;
Iterator(FrameWithValidity *ptr, ValidFramesConsumer &iterator_wrapper)
: ptr_(ptr), iterator_wrapper_(&iterator_wrapper) {}
reference operator*() const { return *ptr_; }
pointer operator->() { return ptr_; }
Iterator &operator++() {
do {
ptr_++;
} while (*this != iterator_wrapper_->end() && !ptr_->IsValid());
return *this;
}
// NOLINTNEXTLINE(cert-dcl21-cpp)
Iterator operator++(int) {
auto old = *this;
++*this;
return old;
}
friend bool operator==(const Iterator &lhs, const Iterator &rhs) { return lhs.ptr_ == rhs.ptr_; };
friend bool operator!=(const Iterator &lhs, const Iterator &rhs) { return lhs.ptr_ != rhs.ptr_; };
private:
FrameWithValidity *ptr_{nullptr};
ValidFramesConsumer *iterator_wrapper_{nullptr};
};
Iterator begin();
Iterator end();
private:
MultiFrame *multiframe_;
};
class InvalidFramesPopulator {
public:
explicit InvalidFramesPopulator(MultiFrame &multiframe);
~InvalidFramesPopulator() = default;
InvalidFramesPopulator(const InvalidFramesPopulator &other) = delete;
InvalidFramesPopulator(InvalidFramesPopulator &&other) noexcept = delete;
InvalidFramesPopulator &operator=(const InvalidFramesPopulator &other) = delete;
InvalidFramesPopulator &operator=(InvalidFramesPopulator &&other) noexcept = delete;
struct Iterator {
using iterator_category = std::forward_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = FrameWithValidity;
using pointer = value_type *;
using reference = FrameWithValidity &;
Iterator() = default;
explicit Iterator(FrameWithValidity *ptr) : ptr_(ptr) {}
reference operator*() const { return *ptr_; }
pointer operator->() { return ptr_; }
Iterator &operator++() {
ptr_->MakeValid();
ptr_++;
return *this;
}
// NOLINTNEXTLINE(cert-dcl21-cpp)
Iterator operator++(int) {
auto old = *this;
++ptr_;
return old;
}
friend bool operator==(const Iterator &lhs, const Iterator &rhs) { return lhs.ptr_ == rhs.ptr_; };
friend bool operator!=(const Iterator &lhs, const Iterator &rhs) { return lhs.ptr_ != rhs.ptr_; };
private:
FrameWithValidity *ptr_{nullptr};
};
Iterator begin();
Iterator end();
private:
MultiFrame *multiframe_;
};
} // namespace memgraph::query::v2

View File

@ -264,6 +264,16 @@ bool Once::OnceCursor::Pull(Frame &, ExecutionContext &context) {
return false; return false;
} }
void Once::OnceCursor::PullMultiple(MultiFrame &multi_frame, ExecutionContext &context) {
SCOPED_PROFILE_OP("OnceMF");
if (!did_pull_) {
auto &first_frame = multi_frame.GetFirstFrame();
first_frame.MakeValid();
did_pull_ = true;
}
}
UniqueCursorPtr Once::MakeCursor(utils::MemoryResource *mem) const { UniqueCursorPtr Once::MakeCursor(utils::MemoryResource *mem) const {
EventCounter::IncrementCounter(EventCounter::OnceOperator); EventCounter::IncrementCounter(EventCounter::OnceOperator);
@ -748,6 +758,23 @@ bool Produce::ProduceCursor::Pull(Frame &frame, ExecutionContext &context) {
return false; return false;
} }
void Produce::ProduceCursor::PullMultiple(MultiFrame &multi_frame, ExecutionContext &context) {
SCOPED_PROFILE_OP("ProduceMF");
input_cursor_->PullMultiple(multi_frame, context);
auto iterator_for_valid_frame_only = multi_frame.GetValidFramesModifier();
for (auto &frame : iterator_for_valid_frame_only) {
// Produce should always yield the latest results.
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.request_router,
storage::v3::View::NEW);
for (auto *named_expr : self_.named_expressions_) {
named_expr->Accept(evaluator);
}
}
};
void Produce::ProduceCursor::Shutdown() { input_cursor_->Shutdown(); } void Produce::ProduceCursor::Shutdown() { input_cursor_->Shutdown(); }
void Produce::ProduceCursor::Reset() { input_cursor_->Reset(); } void Produce::ProduceCursor::Reset() { input_cursor_->Reset(); }

View File

@ -28,6 +28,7 @@
#include "query/v2/bindings/typed_value.hpp" #include "query/v2/bindings/typed_value.hpp"
#include "query/v2/bindings/frame.hpp" #include "query/v2/bindings/frame.hpp"
#include "query/v2/bindings/symbol_table.hpp" #include "query/v2/bindings/symbol_table.hpp"
#include "query/v2/multiframe.hpp"
#include "storage/v3/id_types.hpp" #include "storage/v3/id_types.hpp"
#include "utils/bound.hpp" #include "utils/bound.hpp"
#include "utils/fnv.hpp" #include "utils/fnv.hpp"
@ -71,6 +72,8 @@ class Cursor {
/// @throws QueryRuntimeException if something went wrong with execution /// @throws QueryRuntimeException if something went wrong with execution
virtual bool Pull(Frame &, ExecutionContext &) = 0; virtual bool Pull(Frame &, ExecutionContext &) = 0;
virtual void PullMultiple(MultiFrame &, ExecutionContext &) { LOG_FATAL("PullMultipleIsNotImplemented"); }
/// Resets the Cursor to its initial state. /// Resets the Cursor to its initial state.
virtual void Reset() = 0; virtual void Reset() = 0;
@ -332,6 +335,7 @@ and false on every following Pull.")
class OnceCursor : public Cursor { class OnceCursor : public Cursor {
public: public:
OnceCursor() {} OnceCursor() {}
void PullMultiple(MultiFrame &, ExecutionContext &) override;
bool Pull(Frame &, ExecutionContext &) override; bool Pull(Frame &, ExecutionContext &) override;
void Shutdown() override; void Shutdown() override;
void Reset() override; void Reset() override;
@ -1207,6 +1211,7 @@ RETURN clause) the Produce's pull succeeds exactly once.")
public: public:
ProduceCursor(const Produce &, utils::MemoryResource *); ProduceCursor(const Produce &, utils::MemoryResource *);
bool Pull(Frame &, ExecutionContext &) override; bool Pull(Frame &, ExecutionContext &) override;
void PullMultiple(MultiFrame &, ExecutionContext &) override;
void Shutdown() override; void Shutdown() override;
void Reset() override; void Reset() override;

View File

@ -11,6 +11,7 @@
#pragma once #pragma once
#include <algorithm>
#include <boost/uuid/uuid.hpp> #include <boost/uuid/uuid.hpp>
#include <chrono> #include <chrono>
#include <deque> #include <deque>
@ -34,6 +35,7 @@
#include "io/address.hpp" #include "io/address.hpp"
#include "io/errors.hpp" #include "io/errors.hpp"
#include "io/local_transport/local_transport.hpp" #include "io/local_transport/local_transport.hpp"
#include "io/notifier.hpp"
#include "io/rsm/raft.hpp" #include "io/rsm/raft.hpp"
#include "io/rsm/rsm_client.hpp" #include "io/rsm/rsm_client.hpp"
#include "io/rsm/shard_rsm.hpp" #include "io/rsm/shard_rsm.hpp"
@ -46,11 +48,12 @@
#include "utils/result.hpp" #include "utils/result.hpp"
namespace memgraph::query::v2 { namespace memgraph::query::v2 {
template <typename TStorageClient> template <typename TStorageClient>
class RsmStorageClientManager { class RsmStorageClientManager {
public: public:
using CompoundKey = io::rsm::ShardRsmKey; using CompoundKey = io::rsm::ShardRsmKey;
using Shard = coordinator::Shard; using ShardMetadata = coordinator::ShardMetadata;
RsmStorageClientManager() = default; RsmStorageClientManager() = default;
RsmStorageClientManager(const RsmStorageClientManager &) = delete; RsmStorageClientManager(const RsmStorageClientManager &) = delete;
RsmStorageClientManager(RsmStorageClientManager &&) = delete; RsmStorageClientManager(RsmStorageClientManager &&) = delete;
@ -58,45 +61,31 @@ class RsmStorageClientManager {
RsmStorageClientManager &operator=(RsmStorageClientManager &&) = delete; RsmStorageClientManager &operator=(RsmStorageClientManager &&) = delete;
~RsmStorageClientManager() = default; ~RsmStorageClientManager() = default;
void AddClient(Shard key, TStorageClient client) { cli_cache_.emplace(std::move(key), std::move(client)); } void AddClient(ShardMetadata key, TStorageClient client) { cli_cache_.emplace(std::move(key), std::move(client)); }
bool Exists(const Shard &key) { return cli_cache_.contains(key); } bool Exists(const ShardMetadata &key) { return cli_cache_.contains(key); }
void PurgeCache() { cli_cache_.clear(); } void PurgeCache() { cli_cache_.clear(); }
TStorageClient &GetClient(const Shard &key) { TStorageClient &GetClient(const ShardMetadata &key) {
auto it = cli_cache_.find(key); auto it = cli_cache_.find(key);
MG_ASSERT(it != cli_cache_.end(), "Non-existing shard client"); MG_ASSERT(it != cli_cache_.end(), "Non-existing shard client");
return it->second; return it->second;
} }
private: private:
std::map<Shard, TStorageClient> cli_cache_; std::map<ShardMetadata, TStorageClient> cli_cache_;
}; };
template <typename TRequest> template <typename TRequest>
struct ShardRequestState { struct ShardRequestState {
memgraph::coordinator::Shard shard; memgraph::coordinator::ShardMetadata shard;
TRequest request; TRequest request;
std::optional<io::rsm::AsyncRequestToken> async_request_token;
}; };
// maps from ReadinessToken's internal size_t to the associated state
template <typename TRequest> template <typename TRequest>
struct ExecutionState { using RunningRequests = std::unordered_map<size_t, ShardRequestState<TRequest>>;
using CompoundKey = io::rsm::ShardRsmKey;
using Shard = coordinator::Shard;
// label is optional because some operators can create/remove etc, vertices. These kind of requests contain the label
// on the request itself.
std::optional<std::string> label;
// Transaction id to be filled by the RequestRouter implementation
coordinator::Hlc transaction_id;
// Initialized by RequestRouter implementation. This vector is filled with the shards that
// the RequestRouter impl will send requests to. When a request to a shard exhausts it, meaning that
// it pulled all the requested data from the given Shard, it will be removed from the Vector. When the Vector becomes
// empty, it means that all of the requests have completed succefully.
std::vector<ShardRequestState<TRequest>> requests;
};
class RequestRouterInterface { class RequestRouterInterface {
public: public:
@ -115,6 +104,7 @@ class RequestRouterInterface {
virtual std::vector<msgs::CreateVerticesResponse> CreateVertices(std::vector<msgs::NewVertex> new_vertices) = 0; virtual std::vector<msgs::CreateVerticesResponse> CreateVertices(std::vector<msgs::NewVertex> new_vertices) = 0;
virtual std::vector<msgs::ExpandOneResultRow> ExpandOne(msgs::ExpandOneRequest request) = 0; virtual std::vector<msgs::ExpandOneResultRow> ExpandOne(msgs::ExpandOneRequest request) = 0;
virtual std::vector<msgs::CreateExpandResponse> CreateExpand(std::vector<msgs::NewExpand> new_edges) = 0; virtual std::vector<msgs::CreateExpandResponse> CreateExpand(std::vector<msgs::NewExpand> new_edges) = 0;
virtual std::vector<msgs::GetPropertiesResultRow> GetProperties(msgs::GetPropertiesRequest request) = 0;
virtual storage::v3::EdgeTypeId NameToEdgeType(const std::string &name) const = 0; virtual storage::v3::EdgeTypeId NameToEdgeType(const std::string &name) const = 0;
virtual storage::v3::PropertyId NameToProperty(const std::string &name) const = 0; virtual storage::v3::PropertyId NameToProperty(const std::string &name) const = 0;
@ -140,7 +130,7 @@ class RequestRouter : public RequestRouterInterface {
using CoordinatorWriteRequests = coordinator::CoordinatorWriteRequests; using CoordinatorWriteRequests = coordinator::CoordinatorWriteRequests;
using CoordinatorClient = coordinator::CoordinatorClient<TTransport>; using CoordinatorClient = coordinator::CoordinatorClient<TTransport>;
using Address = io::Address; using Address = io::Address;
using Shard = coordinator::Shard; using ShardMetadata = coordinator::ShardMetadata;
using ShardMap = coordinator::ShardMap; using ShardMap = coordinator::ShardMap;
using CompoundKey = coordinator::PrimaryKey; using CompoundKey = coordinator::PrimaryKey;
using VertexAccessor = query::v2::accessors::VertexAccessor; using VertexAccessor = query::v2::accessors::VertexAccessor;
@ -153,10 +143,16 @@ class RequestRouter : public RequestRouterInterface {
~RequestRouter() override {} ~RequestRouter() override {}
void InstallSimulatorTicker(std::function<bool()> tick_simulator) {
notifier_.InstallSimulatorTicker(tick_simulator);
}
void StartTransaction() override { void StartTransaction() override {
coordinator::HlcRequest req{.last_shard_map_version = shards_map_.GetHlc()}; coordinator::HlcRequest req{.last_shard_map_version = shards_map_.GetHlc()};
CoordinatorWriteRequests write_req = req; CoordinatorWriteRequests write_req = req;
spdlog::trace("sending hlc request to start transaction");
auto write_res = coord_cli_.SendWriteRequest(write_req); auto write_res = coord_cli_.SendWriteRequest(write_req);
spdlog::trace("received hlc response to start transaction");
if (write_res.HasError()) { if (write_res.HasError()) {
throw std::runtime_error("HLC request failed"); throw std::runtime_error("HLC request failed");
} }
@ -175,7 +171,9 @@ class RequestRouter : public RequestRouterInterface {
void Commit() override { void Commit() override {
coordinator::HlcRequest req{.last_shard_map_version = shards_map_.GetHlc()}; coordinator::HlcRequest req{.last_shard_map_version = shards_map_.GetHlc()};
CoordinatorWriteRequests write_req = req; CoordinatorWriteRequests write_req = req;
spdlog::trace("sending hlc request before committing transaction");
auto write_res = coord_cli_.SendWriteRequest(write_req); auto write_res = coord_cli_.SendWriteRequest(write_req);
spdlog::trace("received hlc response before committing transaction");
if (write_res.HasError()) { if (write_res.HasError()) {
throw std::runtime_error("HLC request for commit failed"); throw std::runtime_error("HLC request for commit failed");
} }
@ -243,26 +241,25 @@ class RequestRouter : public RequestRouterInterface {
// TODO(kostasrim) Simplify return result // TODO(kostasrim) Simplify return result
std::vector<VertexAccessor> ScanVertices(std::optional<std::string> label) override { std::vector<VertexAccessor> ScanVertices(std::optional<std::string> label) override {
ExecutionState<msgs::ScanVerticesRequest> state = {};
state.label = label;
// create requests // create requests
InitializeExecutionState(state); std::vector<ShardRequestState<msgs::ScanVerticesRequest>> requests_to_be_sent = RequestsForScanVertices(label);
spdlog::trace("created {} ScanVertices requests", requests_to_be_sent.size());
// begin all requests in parallel // begin all requests in parallel
for (auto &request : state.requests) { RunningRequests<msgs::ScanVerticesRequest> running_requests = {};
running_requests.reserve(requests_to_be_sent.size());
for (size_t i = 0; i < requests_to_be_sent.size(); i++) {
auto &request = requests_to_be_sent[i];
io::ReadinessToken readiness_token{i};
auto &storage_client = GetStorageClientForShard(request.shard); auto &storage_client = GetStorageClientForShard(request.shard);
msgs::ReadRequests req = request.request; storage_client.SendAsyncReadRequest(request.request, notifier_, readiness_token);
running_requests.emplace(readiness_token.GetId(), request);
request.async_request_token = storage_client.SendAsyncReadRequest(request.request);
} }
spdlog::trace("sent {} ScanVertices requests in parallel", running_requests.size());
// drive requests to completion // drive requests to completion
std::vector<msgs::ScanVerticesResponse> responses; auto responses = DriveReadResponses<msgs::ScanVerticesRequest, msgs::ScanVerticesResponse>(running_requests);
responses.reserve(state.requests.size()); spdlog::trace("got back {} ScanVertices responses after driving to completion", responses.size());
do {
DriveReadResponses(state, responses);
} while (!state.requests.empty());
// convert responses into VertexAccessor objects to return // convert responses into VertexAccessor objects to return
std::vector<VertexAccessor> accessors; std::vector<VertexAccessor> accessors;
@ -277,62 +274,55 @@ class RequestRouter : public RequestRouterInterface {
} }
std::vector<msgs::CreateVerticesResponse> CreateVertices(std::vector<msgs::NewVertex> new_vertices) override { std::vector<msgs::CreateVerticesResponse> CreateVertices(std::vector<msgs::NewVertex> new_vertices) override {
ExecutionState<msgs::CreateVerticesRequest> state = {};
MG_ASSERT(!new_vertices.empty()); MG_ASSERT(!new_vertices.empty());
// create requests // create requests
InitializeExecutionState(state, new_vertices); std::vector<ShardRequestState<msgs::CreateVerticesRequest>> requests_to_be_sent =
RequestsForCreateVertices(new_vertices);
spdlog::trace("created {} CreateVertices requests", requests_to_be_sent.size());
// begin all requests in parallel // begin all requests in parallel
for (auto &request : state.requests) { RunningRequests<msgs::CreateVerticesRequest> running_requests = {};
auto req_deep_copy = request.request; running_requests.reserve(requests_to_be_sent.size());
for (size_t i = 0; i < requests_to_be_sent.size(); i++) {
for (auto &new_vertex : req_deep_copy.new_vertices) { auto &request = requests_to_be_sent[i];
io::ReadinessToken readiness_token{i};
for (auto &new_vertex : request.request.new_vertices) {
new_vertex.label_ids.erase(new_vertex.label_ids.begin()); new_vertex.label_ids.erase(new_vertex.label_ids.begin());
} }
auto &storage_client = GetStorageClientForShard(request.shard); auto &storage_client = GetStorageClientForShard(request.shard);
storage_client.SendAsyncWriteRequest(request.request, notifier_, readiness_token);
msgs::WriteRequests req = req_deep_copy; running_requests.emplace(readiness_token.GetId(), request);
request.async_request_token = storage_client.SendAsyncWriteRequest(req);
} }
spdlog::trace("sent {} CreateVertices requests in parallel", running_requests.size());
// drive requests to completion // drive requests to completion
std::vector<msgs::CreateVerticesResponse> responses; return DriveWriteResponses<msgs::CreateVerticesRequest, msgs::CreateVerticesResponse>(running_requests);
responses.reserve(state.requests.size());
do {
DriveWriteResponses(state, responses);
} while (!state.requests.empty());
return responses;
} }
std::vector<msgs::CreateExpandResponse> CreateExpand(std::vector<msgs::NewExpand> new_edges) override { std::vector<msgs::CreateExpandResponse> CreateExpand(std::vector<msgs::NewExpand> new_edges) override {
ExecutionState<msgs::CreateExpandRequest> state = {};
MG_ASSERT(!new_edges.empty()); MG_ASSERT(!new_edges.empty());
// create requests // create requests
InitializeExecutionState(state, new_edges); std::vector<ShardRequestState<msgs::CreateExpandRequest>> requests_to_be_sent = RequestsForCreateExpand(new_edges);
// begin all requests in parallel // begin all requests in parallel
for (auto &request : state.requests) { RunningRequests<msgs::CreateExpandRequest> running_requests = {};
running_requests.reserve(requests_to_be_sent.size());
for (size_t i = 0; i < requests_to_be_sent.size(); i++) {
auto &request = requests_to_be_sent[i];
io::ReadinessToken readiness_token{i};
auto &storage_client = GetStorageClientForShard(request.shard); auto &storage_client = GetStorageClientForShard(request.shard);
msgs::WriteRequests req = request.request; msgs::WriteRequests req = request.request;
request.async_request_token = storage_client.SendAsyncWriteRequest(req); storage_client.SendAsyncWriteRequest(req, notifier_, readiness_token);
running_requests.emplace(readiness_token.GetId(), request);
} }
// drive requests to completion // drive requests to completion
std::vector<msgs::CreateExpandResponse> responses; return DriveWriteResponses<msgs::CreateExpandRequest, msgs::CreateExpandResponse>(running_requests);
responses.reserve(state.requests.size());
do {
DriveWriteResponses(state, responses);
} while (!state.requests.empty());
return responses;
} }
std::vector<msgs::ExpandOneResultRow> ExpandOne(msgs::ExpandOneRequest request) override { std::vector<msgs::ExpandOneResultRow> ExpandOne(msgs::ExpandOneRequest request) override {
ExecutionState<msgs::ExpandOneRequest> state = {};
// TODO(kostasrim)Update to limit the batch size here // TODO(kostasrim)Update to limit the batch size here
// Expansions of the destination must be handled by the caller. For example // Expansions of the destination must be handled by the caller. For example
// match (u:L1 { prop : 1 })-[:Friend]-(v:L1) // match (u:L1 { prop : 1 })-[:Friend]-(v:L1)
@ -340,21 +330,22 @@ class RequestRouter : public RequestRouterInterface {
// must be fetched again with an ExpandOne(Edges.dst) // must be fetched again with an ExpandOne(Edges.dst)
// create requests // create requests
InitializeExecutionState(state, std::move(request)); std::vector<ShardRequestState<msgs::ExpandOneRequest>> requests_to_be_sent = RequestsForExpandOne(request);
// begin all requests in parallel // begin all requests in parallel
for (auto &request : state.requests) { RunningRequests<msgs::ExpandOneRequest> running_requests = {};
running_requests.reserve(requests_to_be_sent.size());
for (size_t i = 0; i < requests_to_be_sent.size(); i++) {
auto &request = requests_to_be_sent[i];
io::ReadinessToken readiness_token{i};
auto &storage_client = GetStorageClientForShard(request.shard); auto &storage_client = GetStorageClientForShard(request.shard);
msgs::ReadRequests req = request.request; msgs::ReadRequests req = request.request;
request.async_request_token = storage_client.SendAsyncReadRequest(req); storage_client.SendAsyncReadRequest(req, notifier_, readiness_token);
running_requests.emplace(readiness_token.GetId(), request);
} }
// drive requests to completion // drive requests to completion
std::vector<msgs::ExpandOneResponse> responses; auto responses = DriveReadResponses<msgs::ExpandOneRequest, msgs::ExpandOneResponse>(running_requests);
responses.reserve(state.requests.size());
do {
DriveReadResponses(state, responses);
} while (!state.requests.empty());
// post-process responses // post-process responses
std::vector<msgs::ExpandOneResultRow> result_rows; std::vector<msgs::ExpandOneResultRow> result_rows;
@ -372,6 +363,36 @@ class RequestRouter : public RequestRouterInterface {
return result_rows; return result_rows;
} }
std::vector<msgs::GetPropertiesResultRow> GetProperties(msgs::GetPropertiesRequest requests) override {
// create requests
std::vector<ShardRequestState<msgs::GetPropertiesRequest>> requests_to_be_sent =
RequestsForGetProperties(std::move(requests));
// begin all requests in parallel
RunningRequests<msgs::GetPropertiesRequest> running_requests = {};
running_requests.reserve(requests_to_be_sent.size());
for (size_t i = 0; i < requests_to_be_sent.size(); i++) {
auto &request = requests_to_be_sent[i];
io::ReadinessToken readiness_token{i};
auto &storage_client = GetStorageClientForShard(request.shard);
msgs::ReadRequests req = request.request;
storage_client.SendAsyncReadRequest(req, notifier_, readiness_token);
running_requests.emplace(readiness_token.GetId(), request);
}
// drive requests to completion
auto responses = DriveReadResponses<msgs::GetPropertiesRequest, msgs::GetPropertiesResponse>(running_requests);
// post-process responses
std::vector<msgs::GetPropertiesResultRow> result_rows;
for (auto &&response : responses) {
std::move(response.result_row.begin(), response.result_row.end(), std::back_inserter(result_rows));
}
return result_rows;
}
std::optional<storage::v3::PropertyId> MaybeNameToProperty(const std::string &name) const override { std::optional<storage::v3::PropertyId> MaybeNameToProperty(const std::string &name) const override {
return shards_map_.GetPropertyId(name); return shards_map_.GetPropertyId(name);
} }
@ -385,11 +406,9 @@ class RequestRouter : public RequestRouterInterface {
} }
private: private:
void InitializeExecutionState(ExecutionState<msgs::CreateVerticesRequest> &state, std::vector<ShardRequestState<msgs::CreateVerticesRequest>> RequestsForCreateVertices(
std::vector<msgs::NewVertex> new_vertices) { const std::vector<msgs::NewVertex> &new_vertices) {
state.transaction_id = transaction_id_; std::map<ShardMetadata, msgs::CreateVerticesRequest> per_shard_request_table;
std::map<Shard, msgs::CreateVerticesRequest> per_shard_request_table;
for (auto &new_vertex : new_vertices) { for (auto &new_vertex : new_vertices) {
MG_ASSERT(!new_vertex.label_ids.empty(), "No label_ids provided for new vertex in RequestRouter::CreateVertices"); MG_ASSERT(!new_vertex.label_ids.empty(), "No label_ids provided for new vertex in RequestRouter::CreateVertices");
@ -402,23 +421,24 @@ class RequestRouter : public RequestRouterInterface {
per_shard_request_table[shard].new_vertices.push_back(std::move(new_vertex)); per_shard_request_table[shard].new_vertices.push_back(std::move(new_vertex));
} }
std::vector<ShardRequestState<msgs::CreateVerticesRequest>> requests = {};
for (auto &[shard, request] : per_shard_request_table) { for (auto &[shard, request] : per_shard_request_table) {
ShardRequestState<msgs::CreateVerticesRequest> shard_request_state{ ShardRequestState<msgs::CreateVerticesRequest> shard_request_state{
.shard = shard, .shard = shard,
.request = request, .request = request,
.async_request_token = std::nullopt,
}; };
state.requests.emplace_back(std::move(shard_request_state)); requests.emplace_back(std::move(shard_request_state));
}
} }
void InitializeExecutionState(ExecutionState<msgs::CreateExpandRequest> &state, return requests;
std::vector<msgs::NewExpand> new_expands) { }
state.transaction_id = transaction_id_;
std::map<Shard, msgs::CreateExpandRequest> per_shard_request_table; std::vector<ShardRequestState<msgs::CreateExpandRequest>> RequestsForCreateExpand(
const std::vector<msgs::NewExpand> &new_expands) {
std::map<ShardMetadata, msgs::CreateExpandRequest> per_shard_request_table;
auto ensure_shard_exists_in_table = [&per_shard_request_table, auto ensure_shard_exists_in_table = [&per_shard_request_table,
transaction_id = transaction_id_](const Shard &shard) { transaction_id = transaction_id_](const ShardMetadata &shard) {
if (!per_shard_request_table.contains(shard)) { if (!per_shard_request_table.contains(shard)) {
msgs::CreateExpandRequest create_expand_request{.transaction_id = transaction_id}; msgs::CreateExpandRequest create_expand_request{.transaction_id = transaction_id};
per_shard_request_table.insert({shard, std::move(create_expand_request)}); per_shard_request_table.insert({shard, std::move(create_expand_request)});
@ -440,30 +460,36 @@ class RequestRouter : public RequestRouterInterface {
per_shard_request_table[shard_src_vertex].new_expands.push_back(std::move(new_expand)); per_shard_request_table[shard_src_vertex].new_expands.push_back(std::move(new_expand));
} }
std::vector<ShardRequestState<msgs::CreateExpandRequest>> requests = {};
for (auto &[shard, request] : per_shard_request_table) { for (auto &[shard, request] : per_shard_request_table) {
ShardRequestState<msgs::CreateExpandRequest> shard_request_state{ ShardRequestState<msgs::CreateExpandRequest> shard_request_state{
.shard = shard, .shard = shard,
.request = request, .request = request,
.async_request_token = std::nullopt,
}; };
state.requests.emplace_back(std::move(shard_request_state)); requests.emplace_back(std::move(shard_request_state));
}
} }
void InitializeExecutionState(ExecutionState<msgs::ScanVerticesRequest> &state) { return requests;
}
std::vector<ShardRequestState<msgs::ScanVerticesRequest>> RequestsForScanVertices(
const std::optional<std::string> &label) {
std::vector<coordinator::Shards> multi_shards; std::vector<coordinator::Shards> multi_shards;
state.transaction_id = transaction_id_; if (label) {
if (!state.label) { const auto label_id = shards_map_.GetLabelId(*label);
multi_shards = shards_map_.GetAllShards();
} else {
const auto label_id = shards_map_.GetLabelId(*state.label);
MG_ASSERT(label_id); MG_ASSERT(label_id);
MG_ASSERT(IsPrimaryLabel(*label_id)); MG_ASSERT(IsPrimaryLabel(*label_id));
multi_shards = {shards_map_.GetShardsForLabel(*state.label)}; multi_shards = {shards_map_.GetShardsForLabel(*label)};
} else {
multi_shards = shards_map_.GetAllShards();
} }
std::vector<ShardRequestState<msgs::ScanVerticesRequest>> requests = {};
for (auto &shards : multi_shards) { for (auto &shards : multi_shards) {
for (auto &[key, shard] : shards) { for (auto &[key, shard] : shards) {
MG_ASSERT(!shard.empty()); MG_ASSERT(!shard.peers.empty());
msgs::ScanVerticesRequest request; msgs::ScanVerticesRequest request;
request.transaction_id = transaction_id_; request.transaction_id = transaction_id_;
@ -472,22 +498,21 @@ class RequestRouter : public RequestRouterInterface {
ShardRequestState<msgs::ScanVerticesRequest> shard_request_state{ ShardRequestState<msgs::ScanVerticesRequest> shard_request_state{
.shard = shard, .shard = shard,
.request = std::move(request), .request = std::move(request),
.async_request_token = std::nullopt,
}; };
state.requests.emplace_back(std::move(shard_request_state)); requests.emplace_back(std::move(shard_request_state));
}
} }
} }
void InitializeExecutionState(ExecutionState<msgs::ExpandOneRequest> &state, msgs::ExpandOneRequest request) { return requests;
state.transaction_id = transaction_id_; }
std::map<Shard, msgs::ExpandOneRequest> per_shard_request_table; std::vector<ShardRequestState<msgs::ExpandOneRequest>> RequestsForExpandOne(const msgs::ExpandOneRequest &request) {
auto top_level_rqst_template = request; std::map<ShardMetadata, msgs::ExpandOneRequest> per_shard_request_table;
msgs::ExpandOneRequest top_level_rqst_template = request;
top_level_rqst_template.transaction_id = transaction_id_; top_level_rqst_template.transaction_id = transaction_id_;
top_level_rqst_template.src_vertices.clear(); top_level_rqst_template.src_vertices.clear();
state.requests.clear();
for (auto &vertex : request.src_vertices) { for (auto &vertex : request.src_vertices) {
auto shard = auto shard =
shards_map_.GetShardForKey(vertex.first.id, storage::conversions::ConvertPropertyVector(vertex.second)); shards_map_.GetShardForKey(vertex.first.id, storage::conversions::ConvertPropertyVector(vertex.second));
@ -497,18 +522,61 @@ class RequestRouter : public RequestRouterInterface {
per_shard_request_table[shard].src_vertices.push_back(vertex); per_shard_request_table[shard].src_vertices.push_back(vertex);
} }
std::vector<ShardRequestState<msgs::ExpandOneRequest>> requests = {};
for (auto &[shard, request] : per_shard_request_table) { for (auto &[shard, request] : per_shard_request_table) {
ShardRequestState<msgs::ExpandOneRequest> shard_request_state{ ShardRequestState<msgs::ExpandOneRequest> shard_request_state{
.shard = shard, .shard = shard,
.request = request, .request = request,
.async_request_token = std::nullopt,
}; };
state.requests.emplace_back(std::move(shard_request_state)); requests.emplace_back(std::move(shard_request_state));
}
} }
StorageClient &GetStorageClientForShard(Shard shard) { return requests;
}
std::vector<ShardRequestState<msgs::GetPropertiesRequest>> RequestsForGetProperties(
msgs::GetPropertiesRequest &&request) {
std::map<ShardMetadata, msgs::GetPropertiesRequest> per_shard_request_table;
auto top_level_rqst_template = request;
top_level_rqst_template.transaction_id = transaction_id_;
top_level_rqst_template.vertex_ids.clear();
top_level_rqst_template.vertices_and_edges.clear();
for (auto &&vertex : request.vertex_ids) {
auto shard =
shards_map_.GetShardForKey(vertex.first.id, storage::conversions::ConvertPropertyVector(vertex.second));
if (!per_shard_request_table.contains(shard)) {
per_shard_request_table.insert(std::pair(shard, top_level_rqst_template));
}
per_shard_request_table[shard].vertex_ids.emplace_back(std::move(vertex));
}
for (auto &[vertex, maybe_edge] : request.vertices_and_edges) {
auto shard =
shards_map_.GetShardForKey(vertex.first.id, storage::conversions::ConvertPropertyVector(vertex.second));
if (!per_shard_request_table.contains(shard)) {
per_shard_request_table.insert(std::pair(shard, top_level_rqst_template));
}
per_shard_request_table[shard].vertices_and_edges.emplace_back(std::move(vertex), maybe_edge);
}
std::vector<ShardRequestState<msgs::GetPropertiesRequest>> requests;
for (auto &[shard, rqst] : per_shard_request_table) {
ShardRequestState<msgs::GetPropertiesRequest> shard_request_state{
.shard = shard,
.request = std::move(rqst),
};
requests.emplace_back(std::move(shard_request_state));
}
return requests;
}
StorageClient &GetStorageClientForShard(ShardMetadata shard) {
if (!storage_cli_manager_.Exists(shard)) { if (!storage_cli_manager_.Exists(shard)) {
AddStorageClientToManager(shard); AddStorageClientToManager(shard);
} }
@ -520,12 +588,12 @@ class RequestRouter : public RequestRouterInterface {
return GetStorageClientForShard(std::move(shard)); return GetStorageClientForShard(std::move(shard));
} }
void AddStorageClientToManager(Shard target_shard) { void AddStorageClientToManager(ShardMetadata target_shard) {
MG_ASSERT(!target_shard.empty()); MG_ASSERT(!target_shard.peers.empty());
auto leader_addr = target_shard.front(); auto leader_addr = target_shard.peers.front();
std::vector<Address> addresses; std::vector<Address> addresses;
addresses.reserve(target_shard.size()); addresses.reserve(target_shard.peers.size());
for (auto &address : target_shard) { for (auto &address : target_shard.peers) {
addresses.push_back(std::move(address.address)); addresses.push_back(std::move(address.address));
} }
auto cli = StorageClient(io_, std::move(leader_addr.address), std::move(addresses)); auto cli = StorageClient(io_, std::move(leader_addr.address), std::move(addresses));
@ -533,13 +601,24 @@ class RequestRouter : public RequestRouterInterface {
} }
template <typename RequestT, typename ResponseT> template <typename RequestT, typename ResponseT>
void DriveReadResponses(ExecutionState<RequestT> &state, std::vector<ResponseT> &responses) { std::vector<ResponseT> DriveReadResponses(RunningRequests<RequestT> &running_requests) {
for (auto &request : state.requests) { // Store responses in a map based on the corresponding request
// offset, so that they can be reassembled in the correct order
// even if they came back in randomized orders.
std::map<size_t, ResponseT> response_map;
spdlog::trace("waiting on readiness for token");
while (response_map.size() < running_requests.size()) {
auto ready = notifier_.Await();
spdlog::trace("got readiness for token {}", ready.GetId());
auto &request = running_requests.at(ready.GetId());
auto &storage_client = GetStorageClientForShard(request.shard); auto &storage_client = GetStorageClientForShard(request.shard);
auto poll_result = storage_client.AwaitAsyncReadRequest(request.async_request_token.value()); std::optional<utils::BasicResult<io::TimedOut, msgs::ReadResponses>> poll_result =
while (!poll_result) { storage_client.PollAsyncReadRequest(ready);
poll_result = storage_client.AwaitAsyncReadRequest(request.async_request_token.value());
if (!poll_result.has_value()) {
continue;
} }
if (poll_result->HasError()) { if (poll_result->HasError()) {
@ -552,19 +631,40 @@ class RequestRouter : public RequestRouterInterface {
throw std::runtime_error("RequestRouter Read request did not succeed"); throw std::runtime_error("RequestRouter Read request did not succeed");
} }
responses.push_back(std::move(response)); // the readiness token has an ID based on the request vector offset
response_map.emplace(ready.GetId(), std::move(response));
} }
state.requests.clear();
std::vector<ResponseT> responses;
responses.reserve(running_requests.size());
int last = -1;
for (auto &&[offset, response] : response_map) {
MG_ASSERT(last + 1 == offset);
responses.emplace_back(std::forward<ResponseT>(response));
last = offset;
}
return responses;
} }
template <typename RequestT, typename ResponseT> template <typename RequestT, typename ResponseT>
void DriveWriteResponses(ExecutionState<RequestT> &state, std::vector<ResponseT> &responses) { std::vector<ResponseT> DriveWriteResponses(RunningRequests<RequestT> &running_requests) {
for (auto &request : state.requests) { // Store responses in a map based on the corresponding request
// offset, so that they can be reassembled in the correct order
// even if they came back in randomized orders.
std::map<size_t, ResponseT> response_map;
while (response_map.size() < running_requests.size()) {
auto ready = notifier_.Await();
auto &request = running_requests.at(ready.GetId());
auto &storage_client = GetStorageClientForShard(request.shard); auto &storage_client = GetStorageClientForShard(request.shard);
auto poll_result = storage_client.AwaitAsyncWriteRequest(request.async_request_token.value()); std::optional<utils::BasicResult<io::TimedOut, msgs::WriteResponses>> poll_result =
while (!poll_result) { storage_client.PollAsyncWriteRequest(ready);
poll_result = storage_client.AwaitAsyncWriteRequest(request.async_request_token.value());
if (!poll_result.has_value()) {
continue;
} }
if (poll_result->HasError()) { if (poll_result->HasError()) {
@ -577,9 +677,21 @@ class RequestRouter : public RequestRouterInterface {
throw std::runtime_error("RequestRouter Write request did not succeed"); throw std::runtime_error("RequestRouter Write request did not succeed");
} }
responses.push_back(std::move(response)); // the readiness token has an ID based on the request vector offset
response_map.emplace(ready.GetId(), std::move(response));
} }
state.requests.clear();
std::vector<ResponseT> responses;
responses.reserve(running_requests.size());
int last = -1;
for (auto &&[offset, response] : response_map) {
MG_ASSERT(last + 1 == offset);
responses.emplace_back(std::forward<ResponseT>(response));
last = offset;
}
return responses;
} }
void SetUpNameIdMappers() { void SetUpNameIdMappers() {
@ -625,6 +737,7 @@ class RequestRouter : public RequestRouterInterface {
RsmStorageClientManager<StorageClient> storage_cli_manager_; RsmStorageClientManager<StorageClient> storage_cli_manager_;
io::Io<TTransport> io_; io::Io<TTransport> io_;
coordinator::Hlc transaction_id_; coordinator::Hlc transaction_id_;
io::Notifier notifier_ = {};
// TODO(kostasrim) Add batch prefetching // TODO(kostasrim) Add batch prefetching
}; };

View File

@ -327,10 +327,6 @@ struct Expression {
std::string expression; std::string expression;
}; };
struct Filter {
std::string filter_expression;
};
enum class OrderingDirection { ASCENDING = 1, DESCENDING = 2 }; enum class OrderingDirection { ASCENDING = 1, DESCENDING = 2 };
struct OrderBy { struct OrderBy {
@ -372,21 +368,32 @@ struct ScanVerticesResponse {
std::vector<ScanResultRow> results; std::vector<ScanResultRow> results;
}; };
using VertexOrEdgeIds = std::variant<VertexId, EdgeId>;
struct GetPropertiesRequest { struct GetPropertiesRequest {
Hlc transaction_id; Hlc transaction_id;
// Shouldn't contain mixed vertex and edge ids std::vector<VertexId> vertex_ids;
VertexOrEdgeIds vertex_or_edge_ids; std::vector<std::pair<VertexId, EdgeId>> vertices_and_edges;
std::vector<PropertyId> property_ids;
std::vector<Expression> expressions; std::optional<std::vector<PropertyId>> property_ids;
bool only_unique = false; std::vector<std::string> expressions;
std::optional<std::vector<OrderBy>> order_by;
std::vector<OrderBy> order_by;
std::optional<size_t> limit; std::optional<size_t> limit;
std::optional<Filter> filter;
// Return only the properties of the vertices or edges that the filter predicate
// evaluates to true
std::optional<std::string> filter;
};
struct GetPropertiesResultRow {
VertexId vertex;
std::optional<EdgeId> edge;
std::vector<std::pair<PropertyId, Value>> props;
std::vector<Value> evaluated_expressions;
}; };
struct GetPropertiesResponse { struct GetPropertiesResponse {
std::vector<GetPropertiesResultRow> result_row;
std::optional<ShardError> error; std::optional<ShardError> error;
}; };

View File

@ -17,5 +17,5 @@
#include "storage/v3/bindings/typed_value.hpp" #include "storage/v3/bindings/typed_value.hpp"
namespace memgraph::storage::v3 { namespace memgraph::storage::v3 {
using Frame = memgraph::expr::Frame<TypedValue>; using Frame = memgraph::expr::Frame;
} // namespace memgraph::storage::v3 } // namespace memgraph::storage::v3

View File

@ -11,6 +11,7 @@
#include "storage/v3/request_helper.hpp" #include "storage/v3/request_helper.hpp"
#include <iterator>
#include <vector> #include <vector>
#include "storage/v3/bindings/db_accessor.hpp" #include "storage/v3/bindings/db_accessor.hpp"
@ -220,30 +221,39 @@ std::vector<TypedValue> EvaluateVertexExpressions(DbAccessor &dba, const VertexA
return evaluated_expressions; return evaluated_expressions;
} }
ShardResult<std::map<PropertyId, Value>> CollectAllPropertiesFromAccessor(const VertexAccessor &acc, View view, std::vector<TypedValue> EvaluateEdgeExpressions(DbAccessor &dba, const VertexAccessor &v_acc, const EdgeAccessor &e_acc,
const Schemas::Schema &schema) { const std::vector<std::string> &expressions) {
std::map<PropertyId, Value> ret; std::vector<TypedValue> evaluated_expressions;
auto props = acc.Properties(view); evaluated_expressions.reserve(expressions.size());
if (props.HasError()) {
spdlog::debug("Encountered an error while trying to get vertex properties."); std::transform(expressions.begin(), expressions.end(), std::back_inserter(evaluated_expressions),
return props.GetError(); [&dba, &v_acc, &e_acc](const auto &expression) {
return ComputeExpression(dba, v_acc, e_acc, expression, expr::identifier_node_symbol,
expr::identifier_edge_symbol);
});
return evaluated_expressions;
} }
auto &properties = props.GetValue(); ShardResult<std::map<PropertyId, Value>> CollectAllPropertiesFromAccessor(const VertexAccessor &acc, View view,
std::transform(properties.begin(), properties.end(), std::inserter(ret, ret.begin()), const Schemas::Schema &schema) {
[](std::pair<const PropertyId, PropertyValue> &pair) { auto ret = impl::CollectAllPropertiesImpl<VertexAccessor>(acc, view);
return std::make_pair(pair.first, FromPropertyValueToValue(std::move(pair.second))); if (ret.HasError()) {
}); return ret.GetError();
properties.clear(); }
auto pks = PrimaryKeysFromAccessor(acc, view, schema); auto pks = PrimaryKeysFromAccessor(acc, view, schema);
if (pks) { if (pks) {
ret.merge(*pks); ret.GetValue().merge(std::move(*pks));
} }
return ret; return ret;
} }
ShardResult<std::map<PropertyId, Value>> CollectAllPropertiesFromAccessor(const VertexAccessor &acc, View view) {
return impl::CollectAllPropertiesImpl(acc, view);
}
EdgeUniquenessFunction InitializeEdgeUniquenessFunction(bool only_unique_neighbor_rows) { EdgeUniquenessFunction InitializeEdgeUniquenessFunction(bool only_unique_neighbor_rows) {
// Functions to select connecting edges based on uniquness // Functions to select connecting edges based on uniquness
EdgeUniquenessFunction maybe_filter_based_on_edge_uniquness; EdgeUniquenessFunction maybe_filter_based_on_edge_uniquness;
@ -350,11 +360,20 @@ EdgeFiller InitializeEdgeFillerFunction(const msgs::ExpandOneRequest &req) {
return edge_filler; return edge_filler;
} }
bool FilterOnVertex(DbAccessor &dba, const storage::v3::VertexAccessor &v_acc, const std::vector<std::string> &filters, bool FilterOnVertex(DbAccessor &dba, const storage::v3::VertexAccessor &v_acc,
const std::string_view node_name) { const std::vector<std::string> &filters) {
return std::ranges::all_of(filters, [&node_name, &dba, &v_acc](const auto &filter_expr) { return std::ranges::all_of(filters, [&dba, &v_acc](const auto &filter_expr) {
auto res = ComputeExpression(dba, v_acc, std::nullopt, filter_expr, node_name, ""); const auto result = ComputeExpression(dba, v_acc, std::nullopt, filter_expr, expr::identifier_node_symbol, "");
return res.IsBool() && res.ValueBool(); return result.IsBool() && result.ValueBool();
});
}
bool FilterOnEdge(DbAccessor &dba, const storage::v3::VertexAccessor &v_acc, const EdgeAccessor &e_acc,
const std::vector<std::string> &filters) {
return std::ranges::all_of(filters, [&dba, &v_acc, &e_acc](const auto &filter_expr) {
const auto result =
ComputeExpression(dba, v_acc, e_acc, filter_expr, expr::identifier_node_symbol, expr::identifier_edge_symbol);
return result.IsBool() && result.ValueBool();
}); });
} }
@ -526,4 +545,36 @@ std::vector<Element<EdgeAccessor>> OrderByEdges(DbAccessor &dba, std::vector<Edg
return ordered; return ordered;
} }
std::vector<Element<std::pair<VertexAccessor, EdgeAccessor>>> OrderByEdges(
DbAccessor &dba, std::vector<EdgeAccessor> &iterable, std::vector<msgs::OrderBy> &order_by_edges,
const std::vector<VertexAccessor> &vertex_acc) {
MG_ASSERT(vertex_acc.size() == iterable.size());
std::vector<Ordering> ordering;
ordering.reserve(order_by_edges.size());
std::transform(order_by_edges.begin(), order_by_edges.end(), std::back_inserter(ordering),
[](const auto &order_by) { return ConvertMsgsOrderByToOrdering(order_by.direction); });
std::vector<Element<std::pair<VertexAccessor, EdgeAccessor>>> ordered;
VertexAccessor current = vertex_acc.front();
size_t id = 0;
for (auto it = iterable.begin(); it != iterable.end(); it++, id++) {
current = vertex_acc[id];
std::vector<TypedValue> properties_order_by;
properties_order_by.reserve(order_by_edges.size());
std::transform(order_by_edges.begin(), order_by_edges.end(), std::back_inserter(properties_order_by),
[&dba, it, current](const auto &order_by) {
return ComputeExpression(dba, current, *it, order_by.expression.expression,
expr::identifier_node_symbol, expr::identifier_edge_symbol);
});
ordered.push_back({std::move(properties_order_by), {current, *it}});
}
auto compare_typed_values = TypedValueVectorCompare(ordering);
std::sort(ordered.begin(), ordered.end(), [compare_typed_values](const auto &pair1, const auto &pair2) {
return compare_typed_values(pair1.properties_order_by, pair2.properties_order_by);
});
return ordered;
}
} // namespace memgraph::storage::v3 } // namespace memgraph::storage::v3

View File

@ -20,6 +20,7 @@
#include "storage/v3/edge_accessor.hpp" #include "storage/v3/edge_accessor.hpp"
#include "storage/v3/expr.hpp" #include "storage/v3/expr.hpp"
#include "storage/v3/shard.hpp" #include "storage/v3/shard.hpp"
#include "storage/v3/value_conversions.hpp"
#include "storage/v3/vertex_accessor.hpp" #include "storage/v3/vertex_accessor.hpp"
#include "utils/template_utils.hpp" #include "utils/template_utils.hpp"
@ -31,7 +32,7 @@ using EdgeFiller =
using msgs::Value; using msgs::Value;
template <typename T> template <typename T>
concept ObjectAccessor = utils::SameAsAnyOf<T, VertexAccessor, EdgeAccessor>; concept OrderableObject = utils::SameAsAnyOf<T, VertexAccessor, EdgeAccessor, std::pair<VertexAccessor, EdgeAccessor>>;
inline bool TypedValueCompare(const TypedValue &a, const TypedValue &b) { inline bool TypedValueCompare(const TypedValue &a, const TypedValue &b) {
// in ordering null comes after everything else // in ordering null comes after everything else
@ -125,7 +126,7 @@ class TypedValueVectorCompare final {
std::vector<Ordering> ordering_; std::vector<Ordering> ordering_;
}; };
template <ObjectAccessor TObjectAccessor> template <OrderableObject TObjectAccessor>
struct Element { struct Element {
std::vector<TypedValue> properties_order_by; std::vector<TypedValue> properties_order_by;
TObjectAccessor object_acc; TObjectAccessor object_acc;
@ -167,6 +168,10 @@ std::vector<Element<EdgeAccessor>> OrderByEdges(DbAccessor &dba, std::vector<Edg
std::vector<msgs::OrderBy> &order_by_edges, std::vector<msgs::OrderBy> &order_by_edges,
const VertexAccessor &vertex_acc); const VertexAccessor &vertex_acc);
std::vector<Element<std::pair<VertexAccessor, EdgeAccessor>>> OrderByEdges(
DbAccessor &dba, std::vector<EdgeAccessor> &iterable, std::vector<msgs::OrderBy> &order_by_edges,
const std::vector<VertexAccessor> &vertex_acc);
VerticesIterable::Iterator GetStartVertexIterator(VerticesIterable &vertex_iterable, VerticesIterable::Iterator GetStartVertexIterator(VerticesIterable &vertex_iterable,
const std::vector<PropertyValue> &primary_key, View view); const std::vector<PropertyValue> &primary_key, View view);
@ -177,19 +182,65 @@ std::vector<Element<VertexAccessor>>::const_iterator GetStartOrderedElementsIter
std::array<std::vector<EdgeAccessor>, 2> GetEdgesFromVertex(const VertexAccessor &vertex_accessor, std::array<std::vector<EdgeAccessor>, 2> GetEdgesFromVertex(const VertexAccessor &vertex_accessor,
msgs::EdgeDirection direction); msgs::EdgeDirection direction);
bool FilterOnVertex(DbAccessor &dba, const storage::v3::VertexAccessor &v_acc, const std::vector<std::string> &filters, bool FilterOnVertex(DbAccessor &dba, const storage::v3::VertexAccessor &v_acc, const std::vector<std::string> &filters);
std::string_view node_name);
bool FilterOnEdge(DbAccessor &dba, const storage::v3::VertexAccessor &v_acc, const EdgeAccessor &e_acc,
const std::vector<std::string> &filters);
std::vector<TypedValue> EvaluateVertexExpressions(DbAccessor &dba, const VertexAccessor &v_acc, std::vector<TypedValue> EvaluateVertexExpressions(DbAccessor &dba, const VertexAccessor &v_acc,
const std::vector<std::string> &expressions, const std::vector<std::string> &expressions,
std::string_view node_name); std::string_view node_name);
ShardResult<std::map<PropertyId, Value>> CollectSpecificPropertiesFromAccessor(const VertexAccessor &acc, std::vector<TypedValue> EvaluateEdgeExpressions(DbAccessor &dba, const VertexAccessor &v_acc, const EdgeAccessor &e_acc,
const std::vector<std::string> &expressions);
template <typename T>
concept PropertiesAccessor = utils::SameAsAnyOf<T, VertexAccessor, EdgeAccessor>;
template <PropertiesAccessor TAccessor>
ShardResult<std::map<PropertyId, Value>> CollectSpecificPropertiesFromAccessor(const TAccessor &acc,
const std::vector<PropertyId> &props, const std::vector<PropertyId> &props,
View view); View view) {
std::map<PropertyId, Value> ret;
for (const auto &prop : props) {
auto result = acc.GetProperty(prop, view);
if (result.HasError()) {
spdlog::debug("Encountered an Error while trying to get a vertex property.");
return result.GetError();
}
auto &value = result.GetValue();
ret.emplace(std::make_pair(prop, FromPropertyValueToValue(std::move(value))));
}
return ret;
}
ShardResult<std::map<PropertyId, Value>> CollectAllPropertiesFromAccessor(const VertexAccessor &acc, View view, ShardResult<std::map<PropertyId, Value>> CollectAllPropertiesFromAccessor(const VertexAccessor &acc, View view,
const Schemas::Schema &schema); const Schemas::Schema &schema);
namespace impl {
template <PropertiesAccessor TAccessor>
ShardResult<std::map<PropertyId, Value>> CollectAllPropertiesImpl(const TAccessor &acc, View view) {
std::map<PropertyId, Value> ret;
auto props = acc.Properties(view);
if (props.HasError()) {
spdlog::debug("Encountered an error while trying to get vertex properties.");
return props.GetError();
}
auto &properties = props.GetValue();
std::transform(properties.begin(), properties.end(), std::inserter(ret, ret.begin()),
[](std::pair<const PropertyId, PropertyValue> &pair) {
return std::make_pair(pair.first, conversions::FromPropertyValueToValue(std::move(pair.second)));
});
return ret;
}
} // namespace impl
template <PropertiesAccessor TAccessor>
ShardResult<std::map<PropertyId, Value>> CollectAllPropertiesFromAccessor(const TAccessor &acc, View view) {
return impl::CollectAllPropertiesImpl<TAccessor>(acc, view);
}
EdgeUniquenessFunction InitializeEdgeUniquenessFunction(bool only_unique_neighbor_rows); EdgeUniquenessFunction InitializeEdgeUniquenessFunction(bool only_unique_neighbor_rows);

View File

@ -10,12 +10,16 @@
// licenses/APL.txt. // licenses/APL.txt.
#include <algorithm> #include <algorithm>
#include <exception>
#include <experimental/source_location>
#include <functional> #include <functional>
#include <iterator> #include <iterator>
#include <optional> #include <optional>
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include <variant>
#include "common/errors.hpp"
#include "parser/opencypher/parser.hpp" #include "parser/opencypher/parser.hpp"
#include "query/v2/requests.hpp" #include "query/v2/requests.hpp"
#include "storage/v2/vertex.hpp" #include "storage/v2/vertex.hpp"
@ -29,6 +33,7 @@
#include "storage/v3/bindings/symbol_generator.hpp" #include "storage/v3/bindings/symbol_generator.hpp"
#include "storage/v3/bindings/symbol_table.hpp" #include "storage/v3/bindings/symbol_table.hpp"
#include "storage/v3/bindings/typed_value.hpp" #include "storage/v3/bindings/typed_value.hpp"
#include "storage/v3/conversions.hpp"
#include "storage/v3/expr.hpp" #include "storage/v3/expr.hpp"
#include "storage/v3/id_types.hpp" #include "storage/v3/id_types.hpp"
#include "storage/v3/key_store.hpp" #include "storage/v3/key_store.hpp"
@ -326,7 +331,7 @@ msgs::ReadResponses ShardRsm::HandleRead(msgs::ScanVerticesRequest &&req) {
std::vector<Value> expression_results; std::vector<Value> expression_results;
if (!req.filter_expressions.empty()) { if (!req.filter_expressions.empty()) {
// NOTE - DbAccessor might get removed in the future. // NOTE - DbAccessor might get removed in the future.
const bool eval = FilterOnVertex(dba, vertex, req.filter_expressions, expr::identifier_node_symbol); const bool eval = FilterOnVertex(dba, vertex, req.filter_expressions);
if (!eval) { if (!eval) {
return; return;
} }
@ -431,7 +436,7 @@ msgs::ReadResponses ShardRsm::HandleRead(msgs::ExpandOneRequest &&req) {
} }
if (!req.filters.empty()) { if (!req.filters.empty()) {
// NOTE - DbAccessor might get removed in the future. // NOTE - DbAccessor might get removed in the future.
const bool eval = FilterOnVertex(dba, src_vertex_acc_opt.value(), req.filters, expr::identifier_node_symbol); const bool eval = FilterOnVertex(dba, src_vertex_acc_opt.value(), req.filters);
if (!eval) { if (!eval) {
continue; continue;
} }
@ -510,9 +515,191 @@ msgs::WriteResponses ShardRsm::ApplyWrite(msgs::CommitRequest &&req) {
return msgs::CommitResponse{}; return msgs::CommitResponse{};
}; };
// NOLINTNEXTLINE(readability-convert-member-functions-to-static) msgs::ReadResponses ShardRsm::HandleRead(msgs::GetPropertiesRequest &&req) {
msgs::ReadResponses ShardRsm::HandleRead(msgs::GetPropertiesRequest && /*req*/) { if (!req.vertex_ids.empty() && !req.vertices_and_edges.empty()) {
return msgs::GetPropertiesResponse{}; auto shard_error = SHARD_ERROR(ErrorCode::NONEXISTENT_OBJECT);
auto error = CreateErrorResponse(shard_error, req.transaction_id, "");
return msgs::GetPropertiesResponse{.error = {}};
}
auto shard_acc = shard_->Access(req.transaction_id);
auto dba = DbAccessor{&shard_acc};
const auto view = storage::v3::View::NEW;
auto transform_props = [](std::map<PropertyId, Value> &&value) {
std::vector<std::pair<PropertyId, Value>> result;
result.reserve(value.size());
for (auto &[id, val] : value) {
result.emplace_back(std::make_pair(id, std::move(val)));
}
return result;
};
auto collect_props = [&req](const VertexAccessor &v_acc,
const std::optional<EdgeAccessor> &e_acc) -> ShardResult<std::map<PropertyId, Value>> {
if (!req.property_ids) {
if (e_acc) {
return CollectAllPropertiesFromAccessor(*e_acc, view);
}
return CollectAllPropertiesFromAccessor(v_acc, view);
}
if (e_acc) {
return CollectSpecificPropertiesFromAccessor(*e_acc, *req.property_ids, view);
}
return CollectSpecificPropertiesFromAccessor(v_acc, *req.property_ids, view);
};
auto find_edge = [](const VertexAccessor &v, msgs::EdgeId e) -> std::optional<EdgeAccessor> {
auto in = v.InEdges(view);
MG_ASSERT(in.HasValue());
for (auto &edge : in.GetValue()) {
if (edge.Gid().AsUint() == e.gid) {
return edge;
}
}
auto out = v.OutEdges(view);
MG_ASSERT(out.HasValue());
for (auto &edge : out.GetValue()) {
if (edge.Gid().AsUint() == e.gid) {
return edge;
}
}
return std::nullopt;
};
const auto has_expr_to_evaluate = !req.expressions.empty();
auto emplace_result_row =
[dba, transform_props, collect_props, has_expr_to_evaluate, &req](
const VertexAccessor &v_acc,
const std::optional<EdgeAccessor> e_acc) mutable -> ShardResult<msgs::GetPropertiesResultRow> {
auto maybe_id = v_acc.Id(view);
if (maybe_id.HasError()) {
return {maybe_id.GetError()};
}
const auto &id = maybe_id.GetValue();
std::optional<msgs::EdgeId> e_id;
if (e_acc) {
e_id = msgs::EdgeId{e_acc->Gid().AsUint()};
}
msgs::VertexId v_id{msgs::Label{id.primary_label}, ConvertValueVector(id.primary_key)};
auto maybe_props = collect_props(v_acc, e_acc);
if (maybe_props.HasError()) {
return {maybe_props.GetError()};
}
auto props = transform_props(std::move(maybe_props.GetValue()));
auto result = msgs::GetPropertiesResultRow{.vertex = std::move(v_id), .edge = e_id, .props = std::move(props)};
if (has_expr_to_evaluate) {
std::vector<Value> e_results;
if (e_acc) {
e_results =
ConvertToValueVectorFromTypedValueVector(EvaluateEdgeExpressions(dba, v_acc, *e_acc, req.expressions));
} else {
e_results = ConvertToValueVectorFromTypedValueVector(
EvaluateVertexExpressions(dba, v_acc, req.expressions, expr::identifier_node_symbol));
}
result.evaluated_expressions = std::move(e_results);
}
return {std::move(result)};
};
auto get_limit = [&req](const auto &elements) {
size_t limit = elements.size();
if (req.limit && *req.limit < elements.size()) {
limit = *req.limit;
}
return limit;
};
auto collect_response = [get_limit, &req](auto &elements, auto create_result_row) -> msgs::ReadResponses {
msgs::GetPropertiesResponse response;
const auto limit = get_limit(elements);
for (size_t index = 0; index != limit; ++index) {
auto result_row = create_result_row(elements[index]);
if (result_row.HasError()) {
return msgs::GetPropertiesResponse{.error = CreateErrorResponse(result_row.GetError(), req.transaction_id, "")};
}
response.result_row.push_back(std::move(result_row.GetValue()));
}
return response;
};
std::vector<VertexAccessor> vertices;
std::vector<EdgeAccessor> edges;
auto parse_and_filter = [dba, &vertices](auto &container, auto projection, auto filter, auto maybe_get_edge) mutable {
for (const auto &elem : container) {
const auto &[label, pk_v] = projection(elem);
auto pk = ConvertPropertyVector(pk_v);
auto v_acc = dba.FindVertex(pk, view);
if (!v_acc || filter(*v_acc, maybe_get_edge(elem))) {
continue;
}
vertices.push_back(*v_acc);
}
};
auto identity = [](auto &elem) { return elem; };
auto filter_vertex = [dba, req](const auto &acc, const auto & /*edge*/) mutable {
if (!req.filter) {
return false;
}
return !FilterOnVertex(dba, acc, {*req.filter});
};
auto filter_edge = [dba, &edges, &req, find_edge](const auto &acc, const auto &edge) mutable {
auto e_acc = find_edge(acc, edge);
if (!e_acc) {
return true;
}
if (req.filter && !FilterOnEdge(dba, acc, *e_acc, {*req.filter})) {
return true;
}
edges.push_back(*e_acc);
return false;
};
// Handler logic here
if (!req.vertex_ids.empty()) {
parse_and_filter(req.vertex_ids, identity, filter_vertex, identity);
} else {
parse_and_filter(
req.vertices_and_edges, [](auto &e) { return e.first; }, filter_edge, [](auto &e) { return e.second; });
}
if (!req.vertex_ids.empty()) {
if (!req.order_by.empty()) {
auto elements = OrderByVertices(dba, vertices, req.order_by);
return collect_response(elements, [emplace_result_row](auto &element) mutable {
return emplace_result_row(element.object_acc, std::nullopt);
});
}
return collect_response(vertices,
[emplace_result_row](auto &acc) mutable { return emplace_result_row(acc, std::nullopt); });
}
if (!req.order_by.empty()) {
auto elements = OrderByEdges(dba, edges, req.order_by, vertices);
return collect_response(elements, [emplace_result_row](auto &element) mutable {
return emplace_result_row(element.object_acc.first, element.object_acc.second);
});
}
struct ZipView {
ZipView(std::vector<VertexAccessor> &v, std::vector<EdgeAccessor> &e) : v(v), e(e) {}
size_t size() const { return v.size(); }
auto operator[](size_t index) { return std::make_pair(v[index], e[index]); }
private:
std::vector<VertexAccessor> &v;
std::vector<EdgeAccessor> &e;
};
ZipView vertices_and_edges(vertices, edges);
return collect_response(vertices_and_edges, [emplace_result_row](const auto &acc) mutable {
return emplace_result_row(acc.first, acc.second);
});
} }
} // namespace memgraph::storage::v3 } // namespace memgraph::storage::v3

View File

@ -9,11 +9,13 @@ function(add_benchmark test_cpp)
get_filename_component(exec_name ${test_cpp} NAME_WE) get_filename_component(exec_name ${test_cpp} NAME_WE)
set(target_name ${test_prefix}${exec_name}) set(target_name ${test_prefix}${exec_name})
add_executable(${target_name} ${test_cpp} ${ARGN}) add_executable(${target_name} ${test_cpp} ${ARGN})
# OUTPUT_NAME sets the real name of a target when it is built and can be # OUTPUT_NAME sets the real name of a target when it is built and can be
# used to help create two targets of the same name even though CMake # used to help create two targets of the same name even though CMake
# requires unique logical target names # requires unique logical target names
set_target_properties(${target_name} PROPERTIES OUTPUT_NAME ${exec_name}) set_target_properties(${target_name} PROPERTIES OUTPUT_NAME ${exec_name})
target_link_libraries(${target_name} benchmark gflags) target_link_libraries(${target_name} benchmark gflags)
# register test # register test
add_test(${target_name} ${exec_name}) add_test(${target_name} ${exec_name})
add_dependencies(memgraph__benchmark ${target_name}) add_dependencies(memgraph__benchmark ${target_name})
@ -65,3 +67,15 @@ target_link_libraries(${test_prefix}storage_v2_property_store mg-storage-v2)
add_benchmark(future.cpp) add_benchmark(future.cpp)
target_link_libraries(${test_prefix}future mg-io) target_link_libraries(${test_prefix}future mg-io)
add_benchmark(data_structures_insert.cpp)
target_link_libraries(${test_prefix}data_structures_insert mg-utils mg-storage-v3)
add_benchmark(data_structures_find.cpp)
target_link_libraries(${test_prefix}data_structures_find mg-utils mg-storage-v3)
add_benchmark(data_structures_contains.cpp)
target_link_libraries(${test_prefix}data_structures_contains mg-utils mg-storage-v3)
add_benchmark(data_structures_remove.cpp)
target_link_libraries(${test_prefix}data_structures_remove mg-utils mg-storage-v3)

View File

@ -0,0 +1,58 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <map>
#include <set>
#include <vector>
#include "coordinator/hybrid_logical_clock.hpp"
#include "storage/v3/key_store.hpp"
#include "storage/v3/lexicographically_ordered_vertex.hpp"
#include "storage/v3/mvcc.hpp"
#include "storage/v3/transaction.hpp"
#include "utils/skip_list.hpp"
namespace memgraph::benchmark {
template <typename T>
inline void PrepareData(utils::SkipList<T> &skip_list, const int64_t num_elements) {
coordinator::Hlc start_timestamp;
storage::v3::Transaction transaction{start_timestamp, storage::v3::IsolationLevel::SNAPSHOT_ISOLATION};
for (auto i{0}; i < num_elements; ++i) {
auto acc = skip_list.access();
acc.insert({storage::v3::PrimaryKey{storage::v3::PropertyValue{true}}});
}
}
template <typename TKey, typename TValue>
inline void PrepareData(std::map<TKey, TValue> &std_map, const int64_t num_elements) {
coordinator::Hlc start_timestamp;
storage::v3::Transaction transaction{start_timestamp, storage::v3::IsolationLevel::SNAPSHOT_ISOLATION};
auto *delta = storage::v3::CreateDeleteObjectDelta(&transaction);
for (auto i{0}; i < num_elements; ++i) {
std_map.insert({storage::v3::PrimaryKey{storage::v3::PropertyValue{i}},
storage::v3::LexicographicallyOrderedVertex{storage::v3::Vertex{
delta, std::vector<storage::v3::PropertyValue>{storage::v3::PropertyValue{true}}}}});
}
}
template <typename T>
inline void PrepareData(std::set<T> &std_set, const int64_t num_elements) {
coordinator::Hlc start_timestamp;
storage::v3::Transaction transaction{start_timestamp, storage::v3::IsolationLevel::SNAPSHOT_ISOLATION};
for (auto i{0}; i < num_elements; ++i) {
std_set.insert(std::vector<storage::v3::PropertyValue>{storage::v3::PropertyValue{true}});
}
}
} // namespace memgraph::benchmark

View File

@ -0,0 +1,105 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <atomic>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <map>
#include <set>
#include <stdexcept>
#include <type_traits>
#include <vector>
#include <benchmark/benchmark.h>
#include <gflags/gflags.h>
#include "data_structures_common.hpp"
#include "storage/v3/key_store.hpp"
#include "storage/v3/lexicographically_ordered_vertex.hpp"
#include "storage/v3/mvcc.hpp"
#include "storage/v3/property_value.hpp"
#include "storage/v3/transaction.hpp"
#include "storage/v3/vertex.hpp"
#include "utils/skip_list.hpp"
namespace memgraph::benchmark {
///////////////////////////////////////////////////////////////////////////////
// Testing Contains Operation
///////////////////////////////////////////////////////////////////////////////
static void BM_BenchmarkContainsSkipList(::benchmark::State &state) {
utils::SkipList<storage::v3::PrimaryKey> skip_list;
PrepareData(skip_list, state.range(0));
// So we can also have elements that does don't exist
std::mt19937 i_generator(std::random_device{}());
std::uniform_int_distribution<int64_t> i_distribution(0, state.range(0) * 2);
int64_t found_elems{0};
for (auto _ : state) {
for (auto i{0}; i < state.range(0); ++i) {
int64_t value = i_distribution(i_generator);
auto acc = skip_list.access();
if (acc.contains(storage::v3::PrimaryKey{{storage::v3::PropertyValue(value)}})) {
found_elems++;
}
}
}
state.SetItemsProcessed(found_elems);
}
static void BM_BenchmarkContainsStdMap(::benchmark::State &state) {
std::map<storage::v3::PrimaryKey, storage::v3::LexicographicallyOrderedVertex> std_map;
PrepareData(std_map, state.range(0));
// So we can also have elements that does don't exist
std::mt19937 i_generator(std::random_device{}());
std::uniform_int_distribution<int64_t> i_distribution(0, state.range(0) * 2);
int64_t found_elems{0};
for (auto _ : state) {
for (auto i{0}; i < state.range(0); ++i) {
int64_t value = i_distribution(i_generator);
if (std_map.contains(storage::v3::PrimaryKey{{storage::v3::PropertyValue(value)}})) {
found_elems++;
}
}
}
state.SetItemsProcessed(found_elems);
}
static void BM_BenchmarkContainsStdSet(::benchmark::State &state) {
std::set<storage::v3::PrimaryKey> std_set;
PrepareData(std_set, state.range(0));
// So we can also have elements that does don't exist
std::mt19937 i_generator(std::random_device{}());
std::uniform_int_distribution<int64_t> i_distribution(0, state.range(0) * 2);
int64_t found_elems{0};
for (auto _ : state) {
for (auto i{0}; i < state.range(0); ++i) {
int64_t value = i_distribution(i_generator);
if (std_set.contains(storage::v3::PrimaryKey{storage::v3::PropertyValue{value}})) {
found_elems++;
}
}
}
state.SetItemsProcessed(found_elems);
}
BENCHMARK(BM_BenchmarkContainsSkipList)->RangeMultiplier(10)->Range(1000, 10000000)->Unit(::benchmark::kMillisecond);
BENCHMARK(BM_BenchmarkContainsStdMap)->RangeMultiplier(10)->Range(1000, 10000000)->Unit(::benchmark::kMillisecond);
BENCHMARK(BM_BenchmarkContainsStdSet)->RangeMultiplier(10)->Range(1000, 10000000)->Unit(::benchmark::kMillisecond);
} // namespace memgraph::benchmark
BENCHMARK_MAIN();

View File

@ -0,0 +1,104 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <atomic>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <map>
#include <set>
#include <stdexcept>
#include <type_traits>
#include <vector>
#include <benchmark/benchmark.h>
#include <gflags/gflags.h>
#include "data_structures_common.hpp"
#include "storage/v3/key_store.hpp"
#include "storage/v3/lexicographically_ordered_vertex.hpp"
#include "storage/v3/mvcc.hpp"
#include "storage/v3/property_value.hpp"
#include "storage/v3/transaction.hpp"
#include "storage/v3/vertex.hpp"
#include "utils/skip_list.hpp"
namespace memgraph::benchmark {
///////////////////////////////////////////////////////////////////////////////
// Testing Find Operation
///////////////////////////////////////////////////////////////////////////////
static void BM_BenchmarkFindSkipList(::benchmark::State &state) {
utils::SkipList<storage::v3::PrimaryKey> skip_list;
PrepareData(skip_list, state.range(0));
// So we can also have elements that does don't exist
std::mt19937 i_generator(std::random_device{}());
std::uniform_int_distribution<int64_t> i_distribution(0, state.range(0) * 2);
int64_t found_elems{0};
for (auto _ : state) {
for (auto i{0}; i < state.range(0); ++i) {
int64_t value = i_distribution(i_generator);
auto acc = skip_list.access();
if (acc.find(storage::v3::PrimaryKey{{storage::v3::PropertyValue(value)}}) != acc.end()) {
found_elems++;
}
}
}
state.SetItemsProcessed(found_elems);
}
static void BM_BenchmarkFindStdMap(::benchmark::State &state) {
std::map<storage::v3::PrimaryKey, storage::v3::LexicographicallyOrderedVertex> std_map;
PrepareData(std_map, state.range(0));
// So we can also have elements that does don't exist
std::mt19937 i_generator(std::random_device{}());
std::uniform_int_distribution<int64_t> i_distribution(0, state.range(0) * 2);
int64_t found_elems{0};
for (auto _ : state) {
for (auto i{0}; i < state.range(0); ++i) {
int64_t value = i_distribution(i_generator);
if (std_map.find(storage::v3::PrimaryKey{{storage::v3::PropertyValue(value)}}) != std_map.end()) {
found_elems++;
}
}
}
state.SetItemsProcessed(found_elems);
}
static void BM_BenchmarkFindStdSet(::benchmark::State &state) {
std::set<storage::v3::PrimaryKey> std_set;
PrepareData(std_set, state.range(0));
// So we can also have elements that does don't exist
std::mt19937 i_generator(std::random_device{}());
std::uniform_int_distribution<int64_t> i_distribution(0, state.range(0) * 2);
int64_t found_elems{0};
for (auto _ : state) {
for (auto i{0}; i < state.range(0); ++i) {
int64_t value = i_distribution(i_generator);
if (std_set.find(storage::v3::PrimaryKey{storage::v3::PropertyValue{value}}) != std_set.end()) {
found_elems++;
}
}
}
state.SetItemsProcessed(found_elems);
}
BENCHMARK(BM_BenchmarkFindSkipList)->RangeMultiplier(10)->Range(1000, 10000000)->Unit(::benchmark::kMillisecond);
BENCHMARK(BM_BenchmarkFindStdMap)->RangeMultiplier(10)->Range(1000, 10000000)->Unit(::benchmark::kMillisecond);
BENCHMARK(BM_BenchmarkFindStdSet)->RangeMultiplier(10)->Range(1000, 10000000)->Unit(::benchmark::kMillisecond);
} // namespace memgraph::benchmark
BENCHMARK_MAIN();

View File

@ -0,0 +1,85 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <atomic>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <map>
#include <set>
#include <stdexcept>
#include <type_traits>
#include <vector>
#include <benchmark/benchmark.h>
#include <gflags/gflags.h>
#include "storage/v3/key_store.hpp"
#include "storage/v3/lexicographically_ordered_vertex.hpp"
#include "storage/v3/mvcc.hpp"
#include "storage/v3/property_value.hpp"
#include "storage/v3/transaction.hpp"
#include "storage/v3/vertex.hpp"
#include "utils/skip_list.hpp"
namespace memgraph::benchmark {
///////////////////////////////////////////////////////////////////////////////
// Testing Insert Operation
///////////////////////////////////////////////////////////////////////////////
static void BM_BenchmarkInsertSkipList(::benchmark::State &state) {
utils::SkipList<storage::v3::PrimaryKey> skip_list;
coordinator::Hlc start_timestamp;
storage::v3::Transaction transaction{start_timestamp, storage::v3::IsolationLevel::SNAPSHOT_ISOLATION};
for (auto _ : state) {
for (auto i{0}; i < state.range(0); ++i) {
auto acc = skip_list.access();
acc.insert({storage::v3::PrimaryKey{storage::v3::PropertyValue{true}}});
}
}
}
static void BM_BenchmarkInsertStdMap(::benchmark::State &state) {
std::map<storage::v3::PrimaryKey, storage::v3::LexicographicallyOrderedVertex> std_map;
coordinator::Hlc start_timestamp;
storage::v3::Transaction transaction{start_timestamp, storage::v3::IsolationLevel::SNAPSHOT_ISOLATION};
auto *delta = storage::v3::CreateDeleteObjectDelta(&transaction);
for (auto _ : state) {
for (auto i{0}; i < state.range(0); ++i) {
std_map.insert({storage::v3::PrimaryKey{storage::v3::PropertyValue{i}},
storage::v3::LexicographicallyOrderedVertex{storage::v3::Vertex{
delta, std::vector<storage::v3::PropertyValue>{storage::v3::PropertyValue{true}}}}});
}
}
}
static void BM_BenchmarkInsertStdSet(::benchmark::State &state) {
std::set<storage::v3::PrimaryKey> std_set;
for (auto _ : state) {
for (auto i{0}; i < state.range(0); ++i) {
std_set.insert(
storage::v3::PrimaryKey{std::vector<storage::v3::PropertyValue>{storage::v3::PropertyValue{true}}});
}
}
}
BENCHMARK(BM_BenchmarkInsertSkipList)->RangeMultiplier(10)->Range(1000, 10000000)->Unit(::benchmark::kMillisecond);
BENCHMARK(BM_BenchmarkInsertStdMap)->RangeMultiplier(10)->Range(1000, 10000000)->Unit(::benchmark::kMillisecond);
BENCHMARK(BM_BenchmarkInsertStdSet)->RangeMultiplier(10)->Range(1000, 10000000)->Unit(::benchmark::kMillisecond);
} // namespace memgraph::benchmark
BENCHMARK_MAIN();

View File

@ -0,0 +1,106 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <atomic>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <map>
#include <set>
#include <stdexcept>
#include <type_traits>
#include <vector>
#include <benchmark/benchmark.h>
#include <gflags/gflags.h>
#include "data_structures_common.hpp"
#include "storage/v3/key_store.hpp"
#include "storage/v3/lexicographically_ordered_vertex.hpp"
#include "storage/v3/mvcc.hpp"
#include "storage/v3/property_value.hpp"
#include "storage/v3/transaction.hpp"
#include "storage/v3/vertex.hpp"
#include "utils/skip_list.hpp"
namespace memgraph::benchmark {
///////////////////////////////////////////////////////////////////////////////
// Testing Remove Operation
///////////////////////////////////////////////////////////////////////////////
static void BM_BenchmarkRemoveSkipList(::benchmark::State &state) {
utils::SkipList<storage::v3::PrimaryKey> skip_list;
PrepareData(skip_list, state.range(0));
// So we can also have elements that don't exist
std::mt19937 i_generator(std::random_device{}());
std::uniform_int_distribution<int64_t> i_distribution(0, state.range(0) * 2);
int64_t removed_elems{0};
for (auto _ : state) {
for (auto i{0}; i < state.range(0); ++i) {
int64_t value = i_distribution(i_generator);
auto acc = skip_list.access();
if (acc.remove(storage::v3::PrimaryKey{storage::v3::PropertyValue(value)})) {
removed_elems++;
}
}
}
state.SetItemsProcessed(removed_elems);
}
static void BM_BenchmarkRemoveStdMap(::benchmark::State &state) {
std::map<storage::v3::PrimaryKey, storage::v3::LexicographicallyOrderedVertex> std_map;
PrepareData(std_map, state.range(0));
// So we can also have elements that does don't exist
std::mt19937 i_generator(std::random_device{}());
std::uniform_int_distribution<int64_t> i_distribution(0, state.range(0) * 2);
int64_t removed_elems{0};
for (auto _ : state) {
for (auto i{0}; i < state.range(0); ++i) {
int64_t value = i_distribution(i_generator);
if (std_map.erase(storage::v3::PrimaryKey{storage::v3::PropertyValue{value}}) > 0) {
removed_elems++;
}
}
}
state.SetItemsProcessed(removed_elems);
}
static void BM_BenchmarkRemoveStdSet(::benchmark::State &state) {
std::set<storage::v3::PrimaryKey> std_set;
PrepareData(std_set, state.range(0));
// So we can also have elements that does don't exist
std::mt19937 i_generator(std::random_device{}());
std::uniform_int_distribution<int64_t> i_distribution(0, state.range(0) * 2);
int64_t removed_elems{0};
for (auto _ : state) {
for (auto i{0}; i < state.range(0); ++i) {
int64_t value = i_distribution(i_generator);
if (std_set.erase(storage::v3::PrimaryKey{storage::v3::PropertyValue{value}}) > 0) {
removed_elems++;
}
}
}
state.SetItemsProcessed(removed_elems);
}
BENCHMARK(BM_BenchmarkRemoveSkipList)->RangeMultiplier(10)->Range(1000, 10000000)->Unit(::benchmark::kMillisecond);
BENCHMARK(BM_BenchmarkRemoveStdMap)->RangeMultiplier(10)->Range(1000, 10000000)->Unit(::benchmark::kMillisecond);
BENCHMARK(BM_BenchmarkRemoveStdSet)->RangeMultiplier(10)->Range(1000, 10000000)->Unit(::benchmark::kMillisecond);
} // namespace memgraph::benchmark
BENCHMARK_MAIN();

View File

@ -1,4 +1,4 @@
// Copyright 2021 Memgraph Ltd. // Copyright 2022 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -11,11 +11,14 @@
#pragma once #pragma once
#include <array>
#include <atomic> #include <atomic>
#include <chrono> #include <chrono>
#include <cstdint>
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <numeric>
#include <thread> #include <thread>
#include <vector> #include <vector>
@ -26,7 +29,7 @@ DEFINE_int32(duration, 10, "Duration of test (in seconds)");
struct Stats { struct Stats {
uint64_t total{0}; uint64_t total{0};
uint64_t succ[4] = {0, 0, 0, 0}; std::array<uint64_t, 4> succ = {0, 0, 0, 0};
}; };
const int OP_INSERT = 0; const int OP_INSERT = 0;
@ -81,3 +84,27 @@ inline void RunConcurrentTest(std::function<void(std::atomic<bool> *, Stats *)>
std::cout << "Total successful: " << tot << " (" << tot / FLAGS_duration << " calls/s)" << std::endl; std::cout << "Total successful: " << tot << " (" << tot / FLAGS_duration << " calls/s)" << std::endl;
std::cout << "Total ops: " << tops << " (" << tops / FLAGS_duration << " calls/s)" << std::endl; std::cout << "Total ops: " << tops << " (" << tops / FLAGS_duration << " calls/s)" << std::endl;
} }
inline void RunTest(std::function<void(const std::atomic<bool> &, Stats &)> test_func) {
Stats stats;
std::atomic<bool> run{true};
{
std::jthread bg_thread(test_func, std::cref(run), std::ref(stats));
std::this_thread::sleep_for(std::chrono::seconds(FLAGS_duration));
run.store(false, std::memory_order_relaxed);
}
std::cout << " Operations: " << stats.total << std::endl;
std::cout << " Successful insert: " << stats.succ[0] << std::endl;
std::cout << " Successful contains: " << stats.succ[1] << std::endl;
std::cout << " Successful remove: " << stats.succ[2] << std::endl;
std::cout << " Successful find: " << stats.succ[3] << std::endl;
std::cout << std::endl;
const auto tot = std::accumulate(stats.succ.begin(), +stats.succ.begin() + 3, 0);
const auto tops = stats.total;
std::cout << "Total successful: " << tot << " (" << tot / FLAGS_duration << " calls/s)" << std::endl;
std::cout << "Total ops: " << tops << " (" << tops / FLAGS_duration << " calls/s)" << std::endl;
}

View File

@ -33,3 +33,4 @@ add_simulation_test(sharded_map.cpp)
add_simulation_test(shard_rsm.cpp) add_simulation_test(shard_rsm.cpp)
add_simulation_test(cluster_property_test.cpp) add_simulation_test(cluster_property_test.cpp)
add_simulation_test(cluster_property_test_cypher_queries.cpp) add_simulation_test(cluster_property_test_cypher_queries.cpp)
add_simulation_test(request_router.cpp)

View File

@ -76,14 +76,10 @@ class MockedShardRsm {
using WriteRequests = msgs::WriteRequests; using WriteRequests = msgs::WriteRequests;
using WriteResponses = msgs::WriteResponses; using WriteResponses = msgs::WriteResponses;
// ExpandOneResponse Read(ExpandOneRequest rqst);
// GetPropertiesResponse Read(GetPropertiesRequest rqst);
msgs::ScanVerticesResponse ReadImpl(msgs::ScanVerticesRequest rqst) { msgs::ScanVerticesResponse ReadImpl(msgs::ScanVerticesRequest rqst) {
msgs::ScanVerticesResponse ret; msgs::ScanVerticesResponse ret;
auto as_prop_val = storage::conversions::ConvertPropertyVector(rqst.start_id.second); auto as_prop_val = storage::conversions::ConvertPropertyVector(rqst.start_id.second);
if (!IsKeyInRange(as_prop_val)) { if (as_prop_val == ShardRsmKey{PropertyValue(0), PropertyValue(0)}) {
ret.success = false;
} else if (as_prop_val == ShardRsmKey{PropertyValue(0), PropertyValue(0)}) {
msgs::Value val(int64_t(0)); msgs::Value val(int64_t(0));
ret.next_start_id = std::make_optional<msgs::VertexId>(); ret.next_start_id = std::make_optional<msgs::VertexId>();
ret.next_start_id->second = ret.next_start_id->second =
@ -91,37 +87,46 @@ class MockedShardRsm {
msgs::ScanResultRow result; msgs::ScanResultRow result;
result.props.push_back(std::make_pair(msgs::PropertyId::FromUint(0), val)); result.props.push_back(std::make_pair(msgs::PropertyId::FromUint(0), val));
ret.results.push_back(std::move(result)); ret.results.push_back(std::move(result));
ret.success = true;
} else if (as_prop_val == ShardRsmKey{PropertyValue(1), PropertyValue(0)}) { } else if (as_prop_val == ShardRsmKey{PropertyValue(1), PropertyValue(0)}) {
msgs::ScanResultRow result; msgs::ScanResultRow result;
msgs::Value val(int64_t(1)); msgs::Value val(int64_t(1));
result.props.push_back(std::make_pair(msgs::PropertyId::FromUint(0), val)); result.props.push_back(std::make_pair(msgs::PropertyId::FromUint(0), val));
ret.results.push_back(std::move(result)); ret.results.push_back(std::move(result));
ret.success = true;
} else if (as_prop_val == ShardRsmKey{PropertyValue(12), PropertyValue(13)}) { } else if (as_prop_val == ShardRsmKey{PropertyValue(12), PropertyValue(13)}) {
msgs::ScanResultRow result; msgs::ScanResultRow result;
msgs::Value val(int64_t(444)); msgs::Value val(int64_t(444));
result.props.push_back(std::make_pair(msgs::PropertyId::FromUint(0), val)); result.props.push_back(std::make_pair(msgs::PropertyId::FromUint(0), val));
ret.results.push_back(std::move(result)); ret.results.push_back(std::move(result));
ret.success = true;
} else {
ret.success = false;
} }
return ret; return ret;
} }
msgs::ExpandOneResponse ReadImpl(msgs::ExpandOneRequest rqst) { return {}; } msgs::ExpandOneResponse ReadImpl(msgs::ExpandOneRequest rqst) { return {}; }
msgs::ExpandOneResponse ReadImpl(msgs::GetPropertiesRequest rqst) { return {}; } msgs::GetPropertiesResponse ReadImpl(msgs::GetPropertiesRequest rqst) {
msgs::GetPropertiesResponse resp;
auto &vertices = rqst.vertex_ids;
for (auto &vertex : vertices) {
auto as_prop_val = storage::conversions::ConvertPropertyVector(vertex.second);
if (as_prop_val == ShardRsmKey{PropertyValue(0), PropertyValue(0)}) {
resp.result_row.push_back(msgs::GetPropertiesResultRow{.vertex = std::move(vertex)});
} else if (as_prop_val == ShardRsmKey{PropertyValue(1), PropertyValue(0)}) {
resp.result_row.push_back(msgs::GetPropertiesResultRow{.vertex = std::move(vertex)});
} else if (as_prop_val == ShardRsmKey{PropertyValue(13), PropertyValue(13)}) {
resp.result_row.push_back(msgs::GetPropertiesResultRow{.vertex = std::move(vertex)});
}
}
return resp;
}
ReadResponses Read(ReadRequests read_requests) { ReadResponses Read(ReadRequests read_requests) {
return {std::visit([this]<typename T>(T &&request) { return ReadResponses{ReadImpl(std::forward<T>(request))}; }, return {std::visit([this]<typename T>(T &&request) { return ReadResponses{ReadImpl(std::forward<T>(request))}; },
std::move(read_requests))}; std::move(read_requests))};
} }
msgs::CreateVerticesResponse ApplyImpl(msgs::CreateVerticesRequest rqst) { return {.success = true}; } msgs::CreateVerticesResponse ApplyImpl(msgs::CreateVerticesRequest rqst) { return {}; }
msgs::DeleteVerticesResponse ApplyImpl(msgs::DeleteVerticesRequest rqst) { return {}; } msgs::DeleteVerticesResponse ApplyImpl(msgs::DeleteVerticesRequest rqst) { return {}; }
msgs::UpdateVerticesResponse ApplyImpl(msgs::UpdateVerticesRequest rqst) { return {}; } msgs::UpdateVerticesResponse ApplyImpl(msgs::UpdateVerticesRequest rqst) { return {}; }
msgs::CreateExpandResponse ApplyImpl(msgs::CreateExpandRequest rqst) { return {.success = true}; } msgs::CreateExpandResponse ApplyImpl(msgs::CreateExpandRequest rqst) { return {}; }
msgs::DeleteEdgesResponse ApplyImpl(msgs::DeleteEdgesRequest rqst) { return {}; } msgs::DeleteEdgesResponse ApplyImpl(msgs::DeleteEdgesRequest rqst) { return {}; }
msgs::UpdateEdgesResponse ApplyImpl(msgs::UpdateEdgesRequest rqst) { return {}; } msgs::UpdateEdgesResponse ApplyImpl(msgs::UpdateEdgesRequest rqst) { return {}; }
msgs::CommitResponse ApplyImpl(msgs::CommitRequest rqst) { return {}; } msgs::CommitResponse ApplyImpl(msgs::CommitRequest rqst) { return {}; }

View File

@ -18,6 +18,8 @@
#include <thread> #include <thread>
#include <vector> #include <vector>
#include <spdlog/cfg/env.h>
#include "common.hpp" #include "common.hpp"
#include "common/types.hpp" #include "common/types.hpp"
#include "coordinator/coordinator_client.hpp" #include "coordinator/coordinator_client.hpp"
@ -44,8 +46,8 @@ using coordinator::CoordinatorClient;
using coordinator::CoordinatorRsm; using coordinator::CoordinatorRsm;
using coordinator::HlcRequest; using coordinator::HlcRequest;
using coordinator::HlcResponse; using coordinator::HlcResponse;
using coordinator::Shard;
using coordinator::ShardMap; using coordinator::ShardMap;
using coordinator::ShardMetadata;
using coordinator::Shards; using coordinator::Shards;
using coordinator::Status; using coordinator::Status;
using io::Address; using io::Address;
@ -111,7 +113,7 @@ ShardMap CreateDummyShardmap(coordinator::Address a_io_1, coordinator::Address a
AddressAndStatus aas1_2{.address = a_io_2, .status = Status::CONSENSUS_PARTICIPANT}; AddressAndStatus aas1_2{.address = a_io_2, .status = Status::CONSENSUS_PARTICIPANT};
AddressAndStatus aas1_3{.address = a_io_3, .status = Status::CONSENSUS_PARTICIPANT}; AddressAndStatus aas1_3{.address = a_io_3, .status = Status::CONSENSUS_PARTICIPANT};
Shard shard1 = {aas1_1, aas1_2, aas1_3}; ShardMetadata shard1 = ShardMetadata{.peers = {aas1_1, aas1_2, aas1_3}, .version = 1};
auto key1 = storage::v3::PropertyValue(0); auto key1 = storage::v3::PropertyValue(0);
auto key2 = storage::v3::PropertyValue(0); auto key2 = storage::v3::PropertyValue(0);
@ -123,7 +125,7 @@ ShardMap CreateDummyShardmap(coordinator::Address a_io_1, coordinator::Address a
AddressAndStatus aas2_2{.address = b_io_2, .status = Status::CONSENSUS_PARTICIPANT}; AddressAndStatus aas2_2{.address = b_io_2, .status = Status::CONSENSUS_PARTICIPANT};
AddressAndStatus aas2_3{.address = b_io_3, .status = Status::CONSENSUS_PARTICIPANT}; AddressAndStatus aas2_3{.address = b_io_3, .status = Status::CONSENSUS_PARTICIPANT};
Shard shard2 = {aas2_1, aas2_2, aas2_3}; ShardMetadata shard2 = ShardMetadata{.peers = {aas2_1, aas2_2, aas2_3}, .version = 1};
auto key3 = storage::v3::PropertyValue(12); auto key3 = storage::v3::PropertyValue(12);
auto key4 = storage::v3::PropertyValue(13); auto key4 = storage::v3::PropertyValue(13);
@ -152,9 +154,7 @@ void RunStorageRaft(Raft<IoImpl, MockedShardRsm, WriteRequests, WriteResponses,
} }
void TestScanVertices(query::v2::RequestRouterInterface &request_router) { void TestScanVertices(query::v2::RequestRouterInterface &request_router) {
msgs::ExecutionState<ScanVerticesRequest> state{.label = "test_label"}; auto result = request_router.ScanVertices("test_label");
auto result = request_router.Request(state);
MG_ASSERT(result.size() == 2); MG_ASSERT(result.size() == 2);
{ {
auto prop = result[0].GetProperty(msgs::PropertyId::FromUint(0)); auto prop = result[0].GetProperty(msgs::PropertyId::FromUint(0));
@ -162,18 +162,10 @@ void TestScanVertices(query::v2::RequestRouterInterface &request_router) {
prop = result[1].GetProperty(msgs::PropertyId::FromUint(0)); prop = result[1].GetProperty(msgs::PropertyId::FromUint(0));
MG_ASSERT(prop.int_v == 444); MG_ASSERT(prop.int_v == 444);
} }
result = request_router.Request(state);
{
MG_ASSERT(result.size() == 1);
auto prop = result[0].GetProperty(msgs::PropertyId::FromUint(0));
MG_ASSERT(prop.int_v == 1);
}
} }
void TestCreateVertices(query::v2::RequestRouterInterface &request_router) { void TestCreateVertices(query::v2::RequestRouterInterface &request_router) {
using PropVal = msgs::Value; using PropVal = msgs::Value;
msgs::ExecutionState<CreateVerticesRequest> state;
std::vector<msgs::NewVertex> new_vertices; std::vector<msgs::NewVertex> new_vertices;
auto label_id = request_router.NameToLabel("test_label"); auto label_id = request_router.NameToLabel("test_label");
msgs::NewVertex a1{.primary_key = {PropVal(int64_t(1)), PropVal(int64_t(0))}}; msgs::NewVertex a1{.primary_key = {PropVal(int64_t(1)), PropVal(int64_t(0))}};
@ -183,13 +175,13 @@ void TestCreateVertices(query::v2::RequestRouterInterface &request_router) {
new_vertices.push_back(std::move(a1)); new_vertices.push_back(std::move(a1));
new_vertices.push_back(std::move(a2)); new_vertices.push_back(std::move(a2));
auto result = request_router.Request(state, std::move(new_vertices)); auto result = request_router.CreateVertices(std::move(new_vertices));
MG_ASSERT(result.size() == 2); MG_ASSERT(result.size() == 2);
} }
void TestCreateExpand(query::v2::RequestRouterInterface &request_router) { void TestCreateExpand(query::v2::RequestRouterInterface &request_router) {
using PropVal = msgs::Value; using PropVal = msgs::Value;
msgs::ExecutionState<msgs::CreateExpandRequest> state; msgs::CreateExpandRequest state;
std::vector<msgs::NewExpand> new_expands; std::vector<msgs::NewExpand> new_expands;
const auto edge_type_id = request_router.NameToEdgeType("edge_type"); const auto edge_type_id = request_router.NameToEdgeType("edge_type");
@ -203,24 +195,42 @@ void TestCreateExpand(query::v2::RequestRouterInterface &request_router) {
new_expands.push_back(std::move(expand_1)); new_expands.push_back(std::move(expand_1));
new_expands.push_back(std::move(expand_2)); new_expands.push_back(std::move(expand_2));
auto responses = request_router.Request(state, std::move(new_expands)); auto responses = request_router.CreateExpand(std::move(new_expands));
MG_ASSERT(responses.size() == 2); MG_ASSERT(responses.size() == 2);
MG_ASSERT(responses[0].success); MG_ASSERT(!responses[0].error);
MG_ASSERT(responses[1].success); MG_ASSERT(!responses[1].error);
} }
void TestExpandOne(query::v2::RequestRouterInterface &request_router) { void TestExpandOne(query::v2::RequestRouterInterface &request_router) {
msgs::ExecutionState<msgs::ExpandOneRequest> state{}; msgs::ExpandOneRequest state{};
msgs::ExpandOneRequest request; msgs::ExpandOneRequest request;
const auto edge_type_id = request_router.NameToEdgeType("edge_type"); const auto edge_type_id = request_router.NameToEdgeType("edge_type");
const auto label = msgs::Label{request_router.NameToLabel("test_label")}; const auto label = msgs::Label{request_router.NameToLabel("test_label")};
request.src_vertices.push_back(msgs::VertexId{label, {msgs::Value(int64_t(0)), msgs::Value(int64_t(0))}}); request.src_vertices.push_back(msgs::VertexId{label, {msgs::Value(int64_t(0)), msgs::Value(int64_t(0))}});
request.edge_types.push_back(msgs::EdgeType{edge_type_id}); request.edge_types.push_back(msgs::EdgeType{edge_type_id});
request.direction = msgs::EdgeDirection::BOTH; request.direction = msgs::EdgeDirection::BOTH;
auto result_rows = request_router.Request(state, std::move(request)); auto result_rows = request_router.ExpandOne(std::move(request));
MG_ASSERT(result_rows.size() == 2); MG_ASSERT(result_rows.size() == 2);
} }
void TestGetProperties(query::v2::RequestRouterInterface &request_router) {
using PropVal = msgs::Value;
auto label_id = request_router.NameToLabel("test_label");
msgs::VertexId v0{{label_id}, {PropVal(int64_t(0)), PropVal(int64_t(0))}};
msgs::VertexId v1{{label_id}, {PropVal(int64_t(1)), PropVal(int64_t(0))}};
msgs::VertexId v2{{label_id}, {PropVal(int64_t(13)), PropVal(int64_t(13))}};
msgs::GetPropertiesRequest request;
request.vertex_ids.push_back({v0});
request.vertex_ids.push_back({v1});
request.vertex_ids.push_back({v2});
auto result = request_router.GetProperties(std::move(request));
MG_ASSERT(result.size() == 3);
}
template <typename RequestRouter> template <typename RequestRouter>
void TestAggregate(RequestRouter &request_router) {} void TestAggregate(RequestRouter &request_router) {}
@ -338,11 +348,14 @@ void DoTest() {
CoordinatorClient<SimulatorTransport> coordinator_client(cli_io, c_addrs[0], c_addrs); CoordinatorClient<SimulatorTransport> coordinator_client(cli_io, c_addrs[0], c_addrs);
query::v2::RequestRouter<SimulatorTransport> request_router(std::move(coordinator_client), std::move(cli_io)); query::v2::RequestRouter<SimulatorTransport> request_router(std::move(coordinator_client), std::move(cli_io));
std::function<bool()> tick_simulator = simulator.GetSimulatorTickClosure();
request_router.InstallSimulatorTicker(tick_simulator);
request_router.StartTransaction(); request_router.StartTransaction();
TestScanVertices(request_router); TestScanVertices(request_router);
TestCreateVertices(request_router); TestCreateVertices(request_router);
TestCreateExpand(request_router); TestCreateExpand(request_router);
TestGetProperties(request_router);
simulator.ShutDown(); simulator.ShutDown();
@ -359,4 +372,7 @@ void DoTest() {
} }
} // namespace memgraph::query::v2::tests } // namespace memgraph::query::v2::tests
int main() { memgraph::query::v2::tests::DoTest(); } int main() {
spdlog::cfg::load_env_levels();
memgraph::query::v2::tests::DoTest();
}

View File

@ -480,6 +480,65 @@ std::tuple<size_t, std::optional<msgs::VertexId>> AttemptToScanAllWithExpression
} }
} }
msgs::GetPropertiesResponse AttemptToGetProperties(
ShardClient &client, std::optional<std::vector<PropertyId>> properties, std::vector<msgs::VertexId> vertices,
std::vector<msgs::EdgeId> edges, std::optional<size_t> limit = std::nullopt,
std::optional<uint64_t> filter_prop = std::nullopt, bool edge = false,
std::optional<std::string> order_by = std::nullopt) {
msgs::GetPropertiesRequest req{};
req.transaction_id.logical_id = GetTransactionId();
req.property_ids = std::move(properties);
if (filter_prop) {
std::string filter_expr = (!edge) ? "MG_SYMBOL_NODE.prop1 >= " : "MG_SYMBOL_EDGE.e_prop = ";
filter_expr += std::to_string(*filter_prop);
req.filter = std::make_optional(std::move(filter_expr));
}
if (order_by) {
std::string filter_expr = (!edge) ? "MG_SYMBOL_NODE." : "MG_SYMBOL_EDGE.";
filter_expr += *order_by;
msgs::OrderBy order_by{.expression = {std::move(filter_expr)}, .direction = msgs::OrderingDirection::DESCENDING};
std::vector<msgs::OrderBy> request_order_by;
request_order_by.push_back(std::move(order_by));
req.order_by = std::move(request_order_by);
}
if (limit) {
req.limit = limit;
}
req.expressions = {std::string("5 = 5")};
std::vector<msgs::VertexId> req_v;
std::vector<msgs::EdgeId> req_e;
for (auto &v : vertices) {
req_v.push_back(std::move(v));
}
for (auto &e : edges) {
req_e.push_back(std::move(e));
}
if (!edges.empty()) {
MG_ASSERT(edges.size() == vertices.size());
size_t id = 0;
req.vertices_and_edges.reserve(req_v.size());
for (auto &v : req_v) {
req.vertices_and_edges.push_back({std::move(v), std::move(req_e[id++])});
}
} else {
req.vertex_ids = std::move(req_v);
}
while (true) {
auto read_res = client.SendReadRequest(req);
if (read_res.HasError()) {
continue;
}
auto write_response_result = read_res.GetValue();
auto write_response = std::get<msgs::GetPropertiesResponse>(write_response_result);
return write_response;
}
}
void AttemptToScanAllWithOrderByOnPrimaryProperty(ShardClient &client, msgs::VertexId start_id, uint64_t batch_limit) { void AttemptToScanAllWithOrderByOnPrimaryProperty(ShardClient &client, msgs::VertexId start_id, uint64_t batch_limit) {
msgs::ScanVerticesRequest scan_req; msgs::ScanVerticesRequest scan_req;
scan_req.batch_limit = batch_limit; scan_req.batch_limit = batch_limit;
@ -1204,6 +1263,205 @@ void TestExpandOneGraphTwo(ShardClient &client) {
} }
} }
void TestGetProperties(ShardClient &client) {
const auto unique_prop_val_1 = GetUniqueInteger();
const auto unique_prop_val_2 = GetUniqueInteger();
const auto unique_prop_val_3 = GetUniqueInteger();
const auto unique_prop_val_4 = GetUniqueInteger();
const auto unique_prop_val_5 = GetUniqueInteger();
MG_ASSERT(AttemptToCreateVertex(client, unique_prop_val_1));
MG_ASSERT(AttemptToCreateVertex(client, unique_prop_val_2));
MG_ASSERT(AttemptToCreateVertex(client, unique_prop_val_3));
MG_ASSERT(AttemptToCreateVertex(client, unique_prop_val_4));
MG_ASSERT(AttemptToCreateVertex(client, unique_prop_val_5));
const msgs::Label prim_label = {.id = get_primary_label()};
const msgs::PrimaryKey prim_key = {msgs::Value(static_cast<int64_t>(unique_prop_val_1))};
const msgs::VertexId v_id = {prim_label, prim_key};
const msgs::PrimaryKey prim_key_2 = {msgs::Value(static_cast<int64_t>(unique_prop_val_2))};
const msgs::VertexId v_id_2 = {prim_label, prim_key_2};
const msgs::PrimaryKey prim_key_3 = {msgs::Value(static_cast<int64_t>(unique_prop_val_3))};
const msgs::VertexId v_id_3 = {prim_label, prim_key_3};
const msgs::PrimaryKey prim_key_4 = {msgs::Value(static_cast<int64_t>(unique_prop_val_4))};
const msgs::VertexId v_id_4 = {prim_label, prim_key_4};
const msgs::PrimaryKey prim_key_5 = {msgs::Value(static_cast<int64_t>(unique_prop_val_5))};
const msgs::VertexId v_id_5 = {prim_label, prim_key_5};
const auto prop_id_2 = PropertyId::FromUint(2);
const auto prop_id_4 = PropertyId::FromUint(4);
const auto prop_id_5 = PropertyId::FromUint(5);
// No properties
{
const auto result = AttemptToGetProperties(client, {{}}, {v_id, v_id_2}, {});
MG_ASSERT(!result.error);
MG_ASSERT(result.result_row.size() == 2);
for (const auto &elem : result.result_row) {
MG_ASSERT(elem.props.size() == 0);
}
}
// All properties
{
const auto result = AttemptToGetProperties(client, std::nullopt, {v_id, v_id_2}, {});
MG_ASSERT(!result.error);
MG_ASSERT(result.result_row.size() == 2);
for (const auto &elem : result.result_row) {
MG_ASSERT(elem.props.size() == 3);
}
}
{
// Specific properties
const auto result =
AttemptToGetProperties(client, std::vector{prop_id_2, prop_id_4, prop_id_5}, {v_id, v_id_2, v_id_3}, {});
MG_ASSERT(!result.error);
MG_ASSERT(!result.result_row.empty());
MG_ASSERT(result.result_row.size() == 3);
for (const auto &elem : result.result_row) {
MG_ASSERT(elem.props.size() == 3);
}
}
{
// Two properties from two vertices with a filter on unique_prop_5
const auto result = AttemptToGetProperties(client, std::vector{prop_id_2, prop_id_4}, {v_id, v_id_2, v_id_5}, {},
std::nullopt, unique_prop_val_5);
MG_ASSERT(!result.error);
MG_ASSERT(result.result_row.size() == 1);
}
{
// One property from three vertices.
const auto result = AttemptToGetProperties(client, std::vector{prop_id_2}, {v_id, v_id_2, v_id_3}, {});
MG_ASSERT(!result.error);
MG_ASSERT(result.result_row.size() == 3);
MG_ASSERT(result.result_row[0].props.size() == 1);
MG_ASSERT(result.result_row[1].props.size() == 1);
MG_ASSERT(result.result_row[2].props.size() == 1);
}
{
// Same as before but with limit of 1 row
const auto result = AttemptToGetProperties(client, std::vector{prop_id_2}, {v_id, v_id_2, v_id_3}, {},
std::make_optional<size_t>(1));
MG_ASSERT(!result.error);
MG_ASSERT(result.result_row.size() == 1);
}
{
// Same as before but with a limit greater than the elements returned
const auto result = AttemptToGetProperties(client, std::vector{prop_id_2}, std::vector{v_id, v_id_2, v_id_3}, {},
std::make_optional<size_t>(5));
MG_ASSERT(!result.error);
MG_ASSERT(result.result_row.size() == 3);
}
{
// Order by on `prop1` (descending)
const auto result = AttemptToGetProperties(client, std::vector{prop_id_2}, {v_id, v_id_2, v_id_3}, {}, std::nullopt,
std::nullopt, false, "prop1");
MG_ASSERT(!result.error);
MG_ASSERT(result.result_row.size() == 3);
MG_ASSERT(result.result_row[0].vertex == v_id_3);
MG_ASSERT(result.result_row[1].vertex == v_id_2);
MG_ASSERT(result.result_row[2].vertex == v_id);
}
{
// Order by and filter on >= unique_prop_val_3 && assert result row data members
const auto result = AttemptToGetProperties(client, std::vector{prop_id_2}, {v_id, v_id_2, v_id_3, v_id_4, v_id_5},
{}, std::nullopt, unique_prop_val_3, false, "prop1");
MG_ASSERT(!result.error);
MG_ASSERT(result.result_row.size() == 3);
MG_ASSERT(result.result_row[0].vertex == v_id_5);
MG_ASSERT(result.result_row[0].props.size() == 1);
MG_ASSERT(result.result_row[0].props.front().second == prim_key_5.front());
MG_ASSERT(result.result_row[0].props.size() == 1);
MG_ASSERT(result.result_row[0].props.front().first == prop_id_2);
MG_ASSERT(result.result_row[0].evaluated_expressions.size() == 1);
MG_ASSERT(result.result_row[0].evaluated_expressions.front() == msgs::Value(true));
MG_ASSERT(result.result_row[1].vertex == v_id_4);
MG_ASSERT(result.result_row[1].props.size() == 1);
MG_ASSERT(result.result_row[1].props.front().second == prim_key_4.front());
MG_ASSERT(result.result_row[1].props.size() == 1);
MG_ASSERT(result.result_row[1].props.front().first == prop_id_2);
MG_ASSERT(result.result_row[1].evaluated_expressions.size() == 1);
MG_ASSERT(result.result_row[1].evaluated_expressions.front() == msgs::Value(true));
MG_ASSERT(result.result_row[2].vertex == v_id_3);
MG_ASSERT(result.result_row[2].props.size() == 1);
MG_ASSERT(result.result_row[2].props.front().second == prim_key_3.front());
MG_ASSERT(result.result_row[2].props.size() == 1);
MG_ASSERT(result.result_row[2].props.front().first == prop_id_2);
MG_ASSERT(result.result_row[2].evaluated_expressions.size() == 1);
MG_ASSERT(result.result_row[2].evaluated_expressions.front() == msgs::Value(true));
}
// Edges
const auto edge_gid = GetUniqueInteger();
const auto edge_type_id = EdgeTypeId::FromUint(GetUniqueInteger());
const auto unique_edge_prop_id = 7;
const auto edge_prop_val = GetUniqueInteger();
MG_ASSERT(AttemptToAddEdgeWithProperties(client, unique_prop_val_1, unique_prop_val_2, edge_gid, unique_edge_prop_id,
edge_prop_val, {edge_type_id}));
const auto edge_gid_2 = GetUniqueInteger();
const auto edge_prop_val_2 = GetUniqueInteger();
MG_ASSERT(AttemptToAddEdgeWithProperties(client, unique_prop_val_3, unique_prop_val_4, edge_gid_2,
unique_edge_prop_id, edge_prop_val_2, {edge_type_id}));
const auto edge_prop_id = PropertyId::FromUint(unique_edge_prop_id);
std::vector<msgs::EdgeId> edge_ids = {{edge_gid}, {edge_gid_2}};
// No properties
{
const auto result = AttemptToGetProperties(client, {{}}, {v_id_2, v_id_3}, edge_ids);
MG_ASSERT(!result.error);
MG_ASSERT(result.result_row.size() == 2);
for (const auto &elem : result.result_row) {
MG_ASSERT(elem.props.size() == 0);
}
}
// All properties
{
const auto result = AttemptToGetProperties(client, std::nullopt, {v_id_2, v_id_3}, edge_ids);
MG_ASSERT(!result.error);
MG_ASSERT(result.result_row.size() == 2);
for (const auto &elem : result.result_row) {
MG_ASSERT(elem.props.size() == 1);
}
}
// Properties for two vertices
{
const auto result = AttemptToGetProperties(client, std::vector{edge_prop_id}, {v_id_2, v_id_3}, edge_ids);
MG_ASSERT(!result.error);
MG_ASSERT(result.result_row.size() == 2);
}
// Filter
{
const auto result = AttemptToGetProperties(client, std::vector{edge_prop_id}, {v_id_2, v_id_3}, edge_ids, {},
{edge_prop_val}, true);
MG_ASSERT(!result.error);
MG_ASSERT(result.result_row.size() == 1);
MG_ASSERT(result.result_row.front().edge);
MG_ASSERT(result.result_row.front().edge.value().gid == edge_gid);
MG_ASSERT(result.result_row.front().props.size() == 1);
MG_ASSERT(result.result_row.front().props.front().second == msgs::Value(static_cast<int64_t>(edge_prop_val)));
}
// Order by
{
const auto result =
AttemptToGetProperties(client, std::vector{edge_prop_id}, {v_id_2, v_id_3}, edge_ids, {}, {}, true, "e_prop");
MG_ASSERT(!result.error);
MG_ASSERT(result.result_row.size() == 2);
MG_ASSERT(result.result_row[0].vertex == v_id_3);
MG_ASSERT(result.result_row[0].edge);
MG_ASSERT(result.result_row[0].edge.value().gid == edge_gid_2);
MG_ASSERT(result.result_row[0].props.size() == 1);
MG_ASSERT(result.result_row[0].props.front().second == msgs::Value(static_cast<int64_t>(edge_prop_val_2)));
MG_ASSERT(result.result_row[0].evaluated_expressions.size() == 1);
MG_ASSERT(result.result_row[0].evaluated_expressions.front() == msgs::Value(true));
MG_ASSERT(result.result_row[1].vertex == v_id_2);
MG_ASSERT(result.result_row[1].edge);
MG_ASSERT(result.result_row[1].edge.value().gid == edge_gid);
MG_ASSERT(result.result_row[1].props.size() == 1);
MG_ASSERT(result.result_row[1].props.front().second == msgs::Value(static_cast<int64_t>(edge_prop_val)));
MG_ASSERT(result.result_row[1].evaluated_expressions.size() == 1);
MG_ASSERT(result.result_row[1].evaluated_expressions.front() == msgs::Value(true));
}
}
} // namespace } // namespace
int TestMessages() { int TestMessages() {
@ -1242,9 +1500,12 @@ int TestMessages() {
auto shard_ptr2 = std::make_unique<Shard>(get_primary_label(), min_prim_key, max_prim_key, schema_prop); auto shard_ptr2 = std::make_unique<Shard>(get_primary_label(), min_prim_key, max_prim_key, schema_prop);
auto shard_ptr3 = std::make_unique<Shard>(get_primary_label(), min_prim_key, max_prim_key, schema_prop); auto shard_ptr3 = std::make_unique<Shard>(get_primary_label(), min_prim_key, max_prim_key, schema_prop);
shard_ptr1->StoreMapping({{1, "label"}, {2, "prop1"}, {3, "label1"}, {4, "prop2"}, {5, "prop3"}, {6, "prop4"}}); shard_ptr1->StoreMapping(
shard_ptr2->StoreMapping({{1, "label"}, {2, "prop1"}, {3, "label1"}, {4, "prop2"}, {5, "prop3"}, {6, "prop4"}}); {{1, "label"}, {2, "prop1"}, {3, "label1"}, {4, "prop2"}, {5, "prop3"}, {6, "prop4"}, {7, "e_prop"}});
shard_ptr3->StoreMapping({{1, "label"}, {2, "prop1"}, {3, "label1"}, {4, "prop2"}, {5, "prop3"}, {6, "prop4"}}); shard_ptr2->StoreMapping(
{{1, "label"}, {2, "prop1"}, {3, "label1"}, {4, "prop2"}, {5, "prop3"}, {6, "prop4"}, {7, "e_prop"}});
shard_ptr3->StoreMapping(
{{1, "label"}, {2, "prop1"}, {3, "label1"}, {4, "prop2"}, {5, "prop3"}, {6, "prop4"}, {7, "e_prop"}});
std::vector<Address> address_for_1{shard_server_2_address, shard_server_3_address}; std::vector<Address> address_for_1{shard_server_2_address, shard_server_3_address};
std::vector<Address> address_for_2{shard_server_1_address, shard_server_3_address}; std::vector<Address> address_for_2{shard_server_1_address, shard_server_3_address};
@ -1286,6 +1547,8 @@ int TestMessages() {
TestExpandOneGraphOne(client); TestExpandOneGraphOne(client);
TestExpandOneGraphTwo(client); TestExpandOneGraphTwo(client);
// GetProperties tests
TestGetProperties(client);
simulator.ShutDown(); simulator.ShutDown();
SimulatorStats stats = simulator.Stats(); SimulatorStats stats = simulator.Stats();

View File

@ -40,8 +40,8 @@ using memgraph::coordinator::CoordinatorRsm;
using memgraph::coordinator::HlcRequest; using memgraph::coordinator::HlcRequest;
using memgraph::coordinator::HlcResponse; using memgraph::coordinator::HlcResponse;
using memgraph::coordinator::PrimaryKey; using memgraph::coordinator::PrimaryKey;
using memgraph::coordinator::Shard;
using memgraph::coordinator::ShardMap; using memgraph::coordinator::ShardMap;
using memgraph::coordinator::ShardMetadata;
using memgraph::coordinator::Shards; using memgraph::coordinator::Shards;
using memgraph::coordinator::Status; using memgraph::coordinator::Status;
using memgraph::io::Address; using memgraph::io::Address;
@ -109,7 +109,7 @@ ShardMap CreateDummyShardmap(Address a_io_1, Address a_io_2, Address a_io_3, Add
AddressAndStatus aas1_2{.address = a_io_2, .status = Status::CONSENSUS_PARTICIPANT}; AddressAndStatus aas1_2{.address = a_io_2, .status = Status::CONSENSUS_PARTICIPANT};
AddressAndStatus aas1_3{.address = a_io_3, .status = Status::CONSENSUS_PARTICIPANT}; AddressAndStatus aas1_3{.address = a_io_3, .status = Status::CONSENSUS_PARTICIPANT};
Shard shard1 = {aas1_1, aas1_2, aas1_3}; ShardMetadata shard1 = ShardMetadata{.peers = {aas1_1, aas1_2, aas1_3}, .version = 1};
const auto key1 = PropertyValue(0); const auto key1 = PropertyValue(0);
const auto key2 = PropertyValue(0); const auto key2 = PropertyValue(0);
@ -121,7 +121,7 @@ ShardMap CreateDummyShardmap(Address a_io_1, Address a_io_2, Address a_io_3, Add
AddressAndStatus aas2_2{.address = b_io_2, .status = Status::CONSENSUS_PARTICIPANT}; AddressAndStatus aas2_2{.address = b_io_2, .status = Status::CONSENSUS_PARTICIPANT};
AddressAndStatus aas2_3{.address = b_io_3, .status = Status::CONSENSUS_PARTICIPANT}; AddressAndStatus aas2_3{.address = b_io_3, .status = Status::CONSENSUS_PARTICIPANT};
Shard shard2 = {aas2_1, aas2_2, aas2_3}; ShardMetadata shard2 = ShardMetadata{.peers = {aas2_1, aas2_2, aas2_3}, .version = 1};
auto key3 = PropertyValue(12); auto key3 = PropertyValue(12);
auto key4 = PropertyValue(13); auto key4 = PropertyValue(13);
@ -131,10 +131,10 @@ ShardMap CreateDummyShardmap(Address a_io_1, Address a_io_2, Address a_io_3, Add
return sm; return sm;
} }
std::optional<ShardClient *> DetermineShardLocation(const Shard &target_shard, const std::vector<Address> &a_addrs, std::optional<ShardClient *> DetermineShardLocation(const ShardMetadata &target_shard,
ShardClient &a_client, const std::vector<Address> &b_addrs, const std::vector<Address> &a_addrs, ShardClient &a_client,
ShardClient &b_client) { const std::vector<Address> &b_addrs, ShardClient &b_client) {
for (const auto &addr : target_shard) { for (const auto &addr : target_shard.peers) {
if (addr.address == b_addrs[0]) { if (addr.address == b_addrs[0]) {
return &b_client; return &b_client;
} }
@ -275,7 +275,7 @@ int main() {
const PrimaryKey compound_key = {cm_key_1, cm_key_2}; const PrimaryKey compound_key = {cm_key_1, cm_key_2};
// Look for Shard // Look for ShardMetadata
BasicResult<TimedOut, memgraph::coordinator::CoordinatorWriteResponses> read_res = BasicResult<TimedOut, memgraph::coordinator::CoordinatorWriteResponses> read_res =
coordinator_client.SendWriteRequest(req); coordinator_client.SendWriteRequest(req);

View File

@ -49,8 +49,8 @@ using coordinator::GetShardMapRequest;
using coordinator::GetShardMapResponse; using coordinator::GetShardMapResponse;
using coordinator::Hlc; using coordinator::Hlc;
using coordinator::HlcResponse; using coordinator::HlcResponse;
using coordinator::Shard;
using coordinator::ShardMap; using coordinator::ShardMap;
using coordinator::ShardMetadata;
using io::Address; using io::Address;
using io::Io; using io::Io;
using io::rsm::RsmClient; using io::rsm::RsmClient;
@ -246,6 +246,8 @@ std::pair<SimulatorStats, LatencyHistogramSummaries> RunClusterSimulation(const
WaitForShardsToInitialize(coordinator_client); WaitForShardsToInitialize(coordinator_client);
query::v2::RequestRouter<SimulatorTransport> request_router(std::move(coordinator_client), std::move(cli_io)); query::v2::RequestRouter<SimulatorTransport> request_router(std::move(coordinator_client), std::move(cli_io));
std::function<bool()> tick_simulator = simulator.GetSimulatorTickClosure();
request_router.InstallSimulatorTicker(tick_simulator);
request_router.StartTransaction(); request_router.StartTransaction();

View File

@ -28,13 +28,19 @@ void Wait(Future<std::string> future_1, Promise<std::string> promise_2) {
TEST(Future, BasicLifecycle) { TEST(Future, BasicLifecycle) {
std::atomic_bool waiting = false; std::atomic_bool waiting = false;
std::atomic_bool filled = false;
std::function<bool()> notifier = [&] { std::function<bool()> wait_notifier = [&] {
waiting.store(true, std::memory_order_seq_cst); waiting.store(true, std::memory_order_seq_cst);
return false; return false;
}; };
auto [future_1, promise_1] = FuturePromisePairWithNotifier<std::string>(notifier); std::function<bool()> fill_notifier = [&] {
filled.store(true, std::memory_order_seq_cst);
return false;
};
auto [future_1, promise_1] = FuturePromisePairWithNotifications<std::string>(wait_notifier, fill_notifier);
auto [future_2, promise_2] = FuturePromisePair<std::string>(); auto [future_2, promise_2] = FuturePromisePair<std::string>();
std::jthread t1(Wait, std::move(future_1), std::move(promise_2)); std::jthread t1(Wait, std::move(future_1), std::move(promise_2));
@ -50,6 +56,8 @@ TEST(Future, BasicLifecycle) {
t1.join(); t1.join();
t2.join(); t2.join();
EXPECT_TRUE(filled.load(std::memory_order_acquire));
std::string result_2 = std::move(future_2).Wait(); std::string result_2 = std::move(future_2).Wait();
EXPECT_TRUE(result_2 == "it worked"); EXPECT_TRUE(result_2 == "it worked");
} }

View File

@ -44,8 +44,8 @@ using coordinator::GetShardMapRequest;
using coordinator::GetShardMapResponse; using coordinator::GetShardMapResponse;
using coordinator::Hlc; using coordinator::Hlc;
using coordinator::HlcResponse; using coordinator::HlcResponse;
using coordinator::Shard;
using coordinator::ShardMap; using coordinator::ShardMap;
using coordinator::ShardMetadata;
using io::Address; using io::Address;
using io::Io; using io::Io;
using io::local_transport::LocalSystem; using io::local_transport::LocalSystem;
@ -194,7 +194,8 @@ void ExecuteOp(query::v2::RequestRouter<LocalTransport> &request_router, std::se
ScanAll scan_all) { ScanAll scan_all) {
auto results = request_router.ScanVertices("test_label"); auto results = request_router.ScanVertices("test_label");
MG_ASSERT(results.size() == correctness_model.size()); spdlog::error("got {} results, model size is {}", results.size(), correctness_model.size());
EXPECT_EQ(results.size(), correctness_model.size());
for (const auto &vertex_accessor : results) { for (const auto &vertex_accessor : results) {
const auto properties = vertex_accessor.Properties(); const auto properties = vertex_accessor.Properties();

View File

@ -45,8 +45,8 @@ using memgraph::coordinator::CoordinatorWriteRequests;
using memgraph::coordinator::CoordinatorWriteResponses; using memgraph::coordinator::CoordinatorWriteResponses;
using memgraph::coordinator::Hlc; using memgraph::coordinator::Hlc;
using memgraph::coordinator::HlcResponse; using memgraph::coordinator::HlcResponse;
using memgraph::coordinator::Shard;
using memgraph::coordinator::ShardMap; using memgraph::coordinator::ShardMap;
using memgraph::coordinator::ShardMetadata;
using memgraph::io::Io; using memgraph::io::Io;
using memgraph::io::local_transport::LocalSystem; using memgraph::io::local_transport::LocalSystem;
using memgraph::io::local_transport::LocalTransport; using memgraph::io::local_transport::LocalTransport;

View File

@ -51,6 +51,8 @@ using memgraph::msgs::CreateVerticesResponse;
using memgraph::msgs::ExpandOneRequest; using memgraph::msgs::ExpandOneRequest;
using memgraph::msgs::ExpandOneResponse; using memgraph::msgs::ExpandOneResponse;
using memgraph::msgs::ExpandOneResultRow; using memgraph::msgs::ExpandOneResultRow;
using memgraph::msgs::GetPropertiesRequest;
using memgraph::msgs::GetPropertiesResultRow;
using memgraph::msgs::NewExpand; using memgraph::msgs::NewExpand;
using memgraph::msgs::NewVertex; using memgraph::msgs::NewVertex;
using memgraph::msgs::ScanVerticesRequest; using memgraph::msgs::ScanVerticesRequest;
@ -84,13 +86,16 @@ class MockedRequestRouter : public RequestRouterInterface {
void Commit() override {} void Commit() override {}
std::vector<VertexAccessor> ScanVertices(std::optional<std::string> /* label */) override { return {}; } std::vector<VertexAccessor> ScanVertices(std::optional<std::string> /* label */) override { return {}; }
std::vector<CreateVerticesResponse> CreateVertices(std::vector<memgraph::msgs::NewVertex> new_vertices) override { std::vector<CreateVerticesResponse> CreateVertices(
std::vector<memgraph::msgs::NewVertex> /* new_vertices */) override {
return {}; return {};
} }
std::vector<ExpandOneResultRow> ExpandOne(ExpandOneRequest request) override { return {}; } std::vector<ExpandOneResultRow> ExpandOne(ExpandOneRequest /* request */) override { return {}; }
std::vector<CreateExpandResponse> CreateExpand(std::vector<NewExpand> new_edges) override { return {}; } std::vector<CreateExpandResponse> CreateExpand(std::vector<NewExpand> /* new_edges */) override { return {}; }
std::vector<GetPropertiesResultRow> GetProperties(GetPropertiesRequest rqst) override { return {}; }
const std::string &PropertyToName(memgraph::storage::v3::PropertyId id) const override { const std::string &PropertyToName(memgraph::storage::v3::PropertyId id) const override {
return properties_.IdToName(id.AsUint()); return properties_.IdToName(id.AsUint());

View File

@ -0,0 +1,185 @@
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
####################################
# Benchmark datastructures analyzer
####################################
# This scripts uses the output from dataset benchmark tests to plot charts
# comparing the results of different datastructures on the same operation.
#
# Note: Naming the tests is very important in order for this script to recognize
# which operation is being performed and on which DS, so it should come in this
# form: BM_Benchmark<Operation><Datastructure>/<RunArgument>
# where run_argument will be added automatically by google benchmark framework
import argparse
import json
import sys
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional
import matplotlib.pyplot as plt
class Operation(Enum):
CONTAINS = "contains"
FIND = "find"
INSERT = "insert"
RANDOM = "random"
REMOVE = "remove"
@classmethod
def to_list(cls) -> List[str]:
return list(map(lambda c: c.value, cls))
@staticmethod
def get(s: str) -> Optional["Operation"]:
try:
return Operation[s.upper()]
except ValueError:
return None
def __str__(self):
return str(self.value)
@dataclass(frozen=True)
class BenchmarkRow:
name: str
datastructure: str
operation: Operation
real_time: int
cpu_time: int
iterations: int
time_unit: str
run_arg: Optional[Any]
class GoogleBenchmarkResult:
def __init__(self):
self._operation = None
self._datastructures: Dict[str, List[BenchmarkRow]] = dict()
def add_result(self, row: BenchmarkRow) -> None:
if self._operation is None:
self._operation = row.operation
assert self._operation is row.operation
if row.datastructure not in self._datastructures:
self._datastructures[row.datastructure] = [row]
else:
self._datastructures[row.datastructure].append(row)
@property
def operation(self) -> Optional[Operation]:
return self._operation
@property
def datastructures(self) -> Dict[str, List[BenchmarkRow]]:
return self._datastructures
def get_operation(s: str) -> Operation:
for op in Operation.to_list():
if op.lower() in s.lower():
operation_enum = Operation.get(op)
if operation_enum is not None:
return operation_enum
else:
print("Operation not found!")
sys.exit(1)
print("Operation not found!")
sys.exit(1)
def get_row_data(line: Dict[str, Any]) -> BenchmarkRow:
"""
Naming is very important, first must come an Operation name, and then a data
structure to test.
"""
full_name = line["name"].split("BM_Benchmark")[1]
name_with_run_arg = full_name.split("/")
operation = get_operation(name_with_run_arg[0])
datastructure = name_with_run_arg[0].split(operation.value.capitalize())[1]
run_arg = None
if len(name_with_run_arg) > 1:
run_arg = name_with_run_arg[1]
return BenchmarkRow(
name_with_run_arg[0],
datastructure,
operation,
line["real_time"],
line["cpu_time"],
line["iterations"],
line["time_unit"],
run_arg,
)
def get_benchmark_res(args) -> Optional[GoogleBenchmarkResult]:
file_path = Path(args.log_file)
if not file_path.exists():
print("Error file {file_path} not found!")
return None
with file_path.open("r") as file:
data = json.load(file)
res = GoogleBenchmarkResult()
assert "benchmarks" in data, "There must be a benchmark list inside"
for benchmark in data["benchmarks"]:
res.add_result(get_row_data(benchmark))
return res
def plot_operation(results: GoogleBenchmarkResult, save: bool) -> None:
colors = ["red", "green", "blue", "yellow", "purple", "brown"]
assert results.operation is not None
fig = plt.figure()
for ds, benchmarks in results.datastructures.items():
if benchmarks:
# Print line chart
x_axis = [elem.real_time for elem in benchmarks]
y_axis = [elem.run_arg for elem in benchmarks]
plt.plot(x_axis, y_axis, marker="", color=colors.pop(0), linewidth="2", label=f"{ds}")
plt.title(f"Benchmark results for operation {results.operation.value}")
plt.xlabel(f"Time [{benchmarks[0].time_unit}]")
plt.grid(True)
plt.legend()
plt.draw()
else:
print(f"Nothing to do for {ds}...")
if save:
plt.savefig(f"{results.operation.value}.png")
plt.close(fig)
else:
plt.show()
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Process benchmark results.")
parser.add_argument("--log_file", type=str)
parser.add_argument("--save", type=bool, default=True)
return parser.parse_args()
def main():
args = parse_args()
res = get_benchmark_res(args)
if res is None:
print("Failed to get results from log file!")
sys.exit(1)
plot_operation(res, args.save)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,52 @@
#!/bin/bash
set -euox pipefail
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
WORKSPACE_DIR=${SCRIPT_DIR}/../../
CPUS=$(grep -c processor < /proc/cpuinfo)
# Get all benchmark files
BENCHMARK_FILES=$(find ${WORKSPACE_DIR}/tests/benchmark -type f -iname "data_structures_*")
function test_all() {
for BENCH_FILE in ${BENCHMARK_FILES[@]}; do
local BASE_NAME=$(basename $BENCH_FILE)
local NAME=${BASE_NAME%%.*}
echo "Running $NAME"
local TEST_FILE=${WORKSPACE_DIR}/build/tests/benchmark/${NAME}
if [[ -f "${TEST_FILE}" ]]; then
pushd ${WORKSPACE_DIR}/build
make -j${CPUS} memgraph__benchmark__${NAME}
popd
local JSON_OUTPUT=${NAME}_output.json
# Run benchmakr test
${WORKSPACE_DIR}/build/tests/benchmark/${NAME} --benchmark_format=json --benchmark_out=${JSON_OUTPUT}
# Run analyze script for benchmark test
python3 ${WORKSPACE_DIR}/tools/plot/benchmark_datastructures.py --log_file=${JSON_OUTPUT}
else
echo "File ${TEST_FILE} does not exist!"
fi
done
}
function test_memory() {
## We are testing only insert
local DATA_STRUCTURES=(SkipList StdMap StdSet BppTree)
for DATA_STRUCTURE in ${DATA_STRUCTURES[@]}; do
valgrind --tool=massif --massif-out-file=${DATA_STRUCTURE}.massif.out ${WORKSPACE_DIR}/build/tests/benchmark/data_structures_insert --benchmark_filter=BM_BenchmarkInsert${DATA_STRUCTURE}/10000 --benchmark_format=json --benchmark_out=${DATA_STRUCTURE}.json
done
}
ARG_1=${1:-"all"}
case ${ARG_1} in
all)
test_all
;;
memory)
test_memory
;;
*)
echo "Select either `all` or `memory` benchmark!"
;;
esac