Merged experimental repo.

Summary:
Fixed distributed init.
Add CMakeLists to build experimental/distributed.
Close unused Channels; work in progress.
Make System the owner of Reactor.
This entails changing shared_ptr -> unique_ptr and some pointers to references.
Merged experimental repository into memgraph.
Moved experimental repo to experimental directory.
Removed obsolete experimental files.
Added comments.
Merge branch 'master' of https://phabricator.memgraph.io/source/experimental
Subscription service unsubscribe.
Add Close method on EventStream.
Add placeholder for the configuration class.
Remove comments.
Merge branch 'master' of https://phabricator.memgraph.io/source/experimental
Clean-up parameters for EventQueue.
Merge branch 'master' of https://phabricator.memgraph.io/source/experimental
Add Channel::serialize method implementation.
Merge.
Add docs on event stream.
Clang-format merge conflicts.
First implementations of serialize methods.
Add hostname, port, and names as methods in Channel base class.
Add reactor name and name methods to LocalChannel.
Add reactor name to LocalChannel.
Add name to LocalChannel.
Add serialization service.
Serialize_test removed.
Merge branch 'master' of https://phabricator.memgraph.io/source/experimental
Move Message to the end of communications files.
Full example of serialization with cereal.
Fix constructor calls.
Merge branch 'master' of https://phabricator.memgraph.io/source/experimental
Avoid using `FindChannel` in the transaction code.
Merge branch 'master' of https://phabricator.memgraph.io/source/experimental
Init script creates libs folder.
Merge branch 'master' of https://phabricator.memgraph.io/source/experimental
Add System pointer to Network.
serialized_test binary is removed from the repo.
Merge branch 'master' of https://phabricator.memgraph.io/source/experimental
Cereal basic example.
Merge branch 'master' of https://phabricator.memgraph.io/source/experimental
Callbacks finished.
Always open the main channel by default.
Fixed callbacks, wrong number of emplace arguments.
Callbacks WIP.
Raise connector mutex to reactor level.
Add argument to LockedPush.
Fix data race in connector closing.
Merge branch 'master' of https://phabricator.memgraph.io/source/experimental
Add functional header.
Fixed to make the changes work.
Merge branch 'master' of https://phabricator.memgraph.io/source/experimental
Merge branch 'master' of https://phabricator.memgraph.io/source/experimental
Refactored connectors into Reactors.
Use shared pointer for the mutex.
Rename to Open and Close in implementation file.
Rename Create to Open and Destroy to Close.
Merge branch 'master' of https://phabricator.memgraph.io/source/experimental
Add callbacks to Reactors; work in progress.
Add stubs for asynchronous channel resolution.
Add stubs for the networking service.
Replace reactor pointers with shared ptrs, disable System assignment.
Forbid assignment.
Replace raw channel pointers with shared pointers.
Replace raw event stream pointer with shared pointer.
Rename default stream name.
Use recursive mutex in System.
Main uses Spawn method. All files are formatted.
Move thread local to a cpp file.
Work in progress on Spawn method.
Merge branch 'master' of https://phabricator.memgraph.io/source/experimental
Remove graph.hpp to make it compile.
Add Spawn method prototype.
Fix return type.
Merge branch 'master' of https://phabricator.memgraph.io/source/experimental
Add method used to create nameless channels.
Add format script.
Introduce the Reactor base class.
Merge branch 'master' of https://phabricator.memgraph.io/source/experimental
Add compile script.
Added comments about class terminology.
Spinner rewrite (graph data structures and algo)
Organize Spinner code
Create working version
Improves Spinner implementation and testing
Spinner fix
.arcconfig
Merge branch 'master' of https://phabricator.memgraph.io/source/experimental
Add graph
Spinner work
Spinner added
Merge branch 'master' of https://phabricator.memgraph.io/source/experimental
Add communication
.clang-format + ycm config.
Init. Distributed hackathon.
Implementation of lock-free list from Petar Sirkovic.
pro compiler
Merge branch 'master' of https://phabricator.memgraph.io/source/experimental
Implement Match
Add test data.
Insert quotes before and after props and labels
Multiple node declarations, along with edges.
After merge.
Node property creations work now.
Bug fix in visitor
After merge.
Implement node creation with labels.
Implement boolean operators
Tidy up ImplementedVistor.
Implement expression6 (addition)
Implement basic type visitor functions
Cypher Visitor Implementation class created.
Fix style.
Fix template syntax in main.cpp.
Merge remote-tracking branch 'origin/master'
Add pretty_print
Update main and BaseVisitor to present return value.
Headers included. Temporary fix.
Antlr4 module reintroduced.
Update git config.
Fix trailing space.
CMake 2.8 fix rerolled, 3.1 minimum version req.
Fix for CMake version 2.8 compatibility.
Build works.
Tidy src folder. Include generated files for antlr.
Included antlr generated files.
Changed directory structure.
CMake: include subdirectory.
GenerateRuntime, partial.
Add GenerateParser target to cmake.
Remove main.cpp
Merge remote-tracking branch 'origin/master'
Add requirements
Main file added. Run the lexer and parser with this.
Add antlr_generated to baby_compiler
Experimental memory_tracker and opencypher tck tests

Reviewers: mislav.bradac

Reviewed By: mislav.bradac

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D627
This commit is contained in:
Matej Ferencevic 2017-08-03 12:08:39 +02:00
parent 6bc9deba5f
commit a11c9885ad
35 changed files with 3704 additions and 0 deletions


@@ -210,6 +210,9 @@ message(STATUS "MEMGRAPH binary: ${MEMGRAPH}")
# proof of concept
option(POC "Build proof of concept binaries" ON)
message(STATUS "POC binaries: ${POC}")
# experimental
option(EXPERIMENTAL "Build experimental binaries" OFF)
message(STATUS "Experimental binaries: ${EXPERIMENTAL}")
# tests
option(HARDCODED_TARGETS "Make hardcoded query targets" ON)
message(STATUS "Make hardcoded query targets: ${HARDCODED_TARGETS}")
@@ -322,6 +325,11 @@ add_dependencies(memgraph_lib generate_opencypher_parser
if (POC)
add_subdirectory(poc)
endif()
# experimental
if (EXPERIMENTAL)
add_subdirectory(experimental)
endif()
# -----------------------------------------------------------------------------
# tests


@@ -0,0 +1,2 @@
# distributed
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)


@@ -0,0 +1,8 @@
---
Language: Cpp
BasedOnStyle: Google
Standard: "C++11"
UseTab: Never
DerivePointerAlignment: false
PointerAlignment: Right
...

experimental/distributed/.gitignore

@@ -0,0 +1,7 @@
*.out
*.pyc
main
libs/
*.cereal
*.backup


@@ -0,0 +1,152 @@
import os
import os.path
import fnmatch
import logging
import ycm_core
BASE_FLAGS = [
'-Wall',
'-Wextra',
'-Werror',
'-Wno-long-long',
'-Wno-variadic-macros',
'-fexceptions',
'-ferror-limit=10000',
'-DNDEBUG',
'-std=c++1z',
'-xc++',
'-I/usr/lib/',
'-I/usr/include/',
'-I./',
'-I./libs/cereal/include',
'-I./src'
]
SOURCE_EXTENSIONS = [
'.cpp',
'.cxx',
'.cc',
'.c',
'.m',
'.mm'
]
HEADER_EXTENSIONS = [
'.h',
'.hxx',
'.hpp',
'.hh'
]
def IsHeaderFile(filename):
    extension = os.path.splitext(filename)[1]
    return extension in HEADER_EXTENSIONS


def GetCompilationInfoForFile(database, filename):
    if IsHeaderFile(filename):
        basename = os.path.splitext(filename)[0]
        for extension in SOURCE_EXTENSIONS:
            replacement_file = basename + extension
            if os.path.exists(replacement_file):
                compilation_info = database.GetCompilationInfoForFile(replacement_file)
                if compilation_info.compiler_flags_:
                    return compilation_info
        return None
    return database.GetCompilationInfoForFile(filename)


def FindNearest(path, target):
    candidate = os.path.join(path, target)
    if os.path.isfile(candidate) or os.path.isdir(candidate):
        logging.info("Found nearest " + target + " at " + candidate)
        return candidate
    parent = os.path.dirname(os.path.abspath(path))
    if parent == path:
        raise RuntimeError("Could not find " + target)
    return FindNearest(parent, target)


def MakeRelativePathsInFlagsAbsolute(flags, working_directory):
    if not working_directory:
        return list(flags)
    new_flags = []
    make_next_absolute = False
    path_flags = ['-isystem', '-I', '-iquote', '--sysroot=']
    for flag in flags:
        new_flag = flag
        if make_next_absolute:
            make_next_absolute = False
            if not flag.startswith('/'):
                new_flag = os.path.join(working_directory, flag)
        for path_flag in path_flags:
            if flag == path_flag:
                make_next_absolute = True
                break
            if flag.startswith(path_flag):
                path = flag[len(path_flag):]
                new_flag = path_flag + os.path.join(working_directory, path)
                break
        if new_flag:
            new_flags.append(new_flag)
    return new_flags


def FlagsForClangComplete(root):
    try:
        clang_complete_path = FindNearest(root, '.clang_complete')
        clang_complete_flags = open(clang_complete_path, 'r').read().splitlines()
        return clang_complete_flags
    except:
        return None


def FlagsForInclude(root):
    try:
        include_path = FindNearest(root, 'include')
        flags = []
        for dirroot, dirnames, filenames in os.walk(include_path):
            for dir_path in dirnames:
                real_path = os.path.join(dirroot, dir_path)
                flags = flags + ["-I" + real_path]
        return flags
    except:
        return None


def FlagsForCompilationDatabase(root, filename):
    try:
        compilation_db_path = FindNearest(root, 'compile_commands.json')
        compilation_db_dir = os.path.dirname(compilation_db_path)
        logging.info("Set compilation database directory to " + compilation_db_dir)
        compilation_db = ycm_core.CompilationDatabase(compilation_db_dir)
        if not compilation_db:
            logging.info("Compilation database file found but unable to load")
            return None
        compilation_info = GetCompilationInfoForFile(compilation_db, filename)
        if not compilation_info:
            logging.info("No compilation info for " + filename + " in compilation database")
            return None
        return MakeRelativePathsInFlagsAbsolute(
            compilation_info.compiler_flags_,
            compilation_info.compiler_working_dir_)
    except:
        return None


def FlagsForFile(filename):
    root = os.path.realpath(filename)
    compilation_db_flags = FlagsForCompilationDatabase(root, filename)
    if compilation_db_flags:
        final_flags = compilation_db_flags
    else:
        final_flags = BASE_FLAGS
        clang_flags = FlagsForClangComplete(root)
        if clang_flags:
            final_flags = final_flags + clang_flags
        include_flags = FlagsForInclude(root)
        if include_flags:
            final_flags = final_flags + include_flags
    return {
        'flags': final_flags,
        'do_cache': True
    }


@@ -0,0 +1,38 @@
project(distributed)
# threading
find_package(Threads REQUIRED)
# set directory variables
set(src_dir ${PROJECT_SOURCE_DIR}/src)
set(libs_dir ${PROJECT_SOURCE_DIR}/libs)
# includes
include_directories(${libs_dir}/cereal/include)
include_directories(${src_dir})
# totally hacked, no idea why I need to include these again
# TODO: ask teon
include_directories(${CMAKE_SOURCE_DIR}/libs/fmt)
include_directories(${CMAKE_SOURCE_DIR}/src)
include_directories(SYSTEM ${GTEST_INCLUDE_DIRS} ${GMOCK_INCLUDE_DIRS})
include_directories(SYSTEM ${CMAKE_SOURCE_DIR}/libs)
# needed to include configured files (plan_compiler_flags.hpp)
set(generated_headers_dir ${CMAKE_BINARY_DIR}/generated_headers)
include_directories(${generated_headers_dir})
include_directories(SYSTEM ${CMAKE_SOURCE_DIR}/libs/glog/include)
include_directories(SYSTEM ${CMAKE_BINARY_DIR}/libs/gflags/include)
# library from distributed sources
file(GLOB_RECURSE src_files ${src_dir}/*.cpp)
add_library(distributed_lib STATIC ${src_files})
## executable
set(executable_name distributed)
add_executable(${executable_name} ${PROJECT_SOURCE_DIR}/main.cpp)
target_link_libraries(${executable_name} distributed_lib)
target_link_libraries(${executable_name} memgraph_lib)
target_link_libraries(${executable_name} ${MEMGRAPH_ALL_LIBS})
# tests
add_subdirectory(${PROJECT_SOURCE_DIR}/tests)

experimental/distributed/init

@@ -0,0 +1,13 @@
#!/usr/bin/env bash
working_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
libs_dir=$working_dir/libs
if [ ! -d "$libs_dir" ]; then
mkdir "$libs_dir"
fi
cd $libs_dir
git clone https://github.com/USCiLab/cereal.git
cd $libs_dir/cereal
git checkout v1.2.2


@@ -0,0 +1,384 @@
#include <atomic>
#include <chrono>
#include <cstdlib>
#include <iostream>
#include <string>
#include <thread>
#include <vector>
#include "communication.hpp"
const int NUM_WORKERS = 1;
class Txn : public SenderMessage {
public:
Txn(ChannelRefT channel, int64_t id) : SenderMessage(channel), id_(id) {}
int64_t id() const { return id_; }
template <class Archive>
void serialize(Archive &archive) {
archive(cereal::base_class<SenderMessage>(this), id_);
}
private:
int64_t id_;
};
class CreateNodeTxn : public Txn {
public:
CreateNodeTxn(ChannelRefT channel, int64_t id) : Txn(channel, id) {}
template <class Archive>
void serialize(Archive &archive) {
archive(cereal::base_class<Txn>(this));
}
};
class CountNodesTxn : public Txn {
public:
CountNodesTxn(ChannelRefT channel, int64_t id) : Txn(channel, id) {}
template <class Archive>
void serialize(Archive &archive) {
archive(cereal::base_class<Txn>(this));
}
};
class CountNodesTxnResult : public Message {
public:
CountNodesTxnResult(int64_t count) : count_(count) {}
int64_t count() const { return count_; }
template <class Archive>
void serialize(Archive &archive) {
archive(count_);
}
private:
int64_t count_;
};
class CommitRequest : public SenderMessage {
public:
CommitRequest(ChannelRefT sender, int64_t worker_id)
: SenderMessage(sender), worker_id_(worker_id) {}
int64_t worker_id() { return worker_id_; }
template <class Archive>
void serialize(Archive &archive) {
archive(cereal::base_class<SenderMessage>(this), worker_id_);
}
private:
int64_t worker_id_;
};
class AbortRequest : public SenderMessage {
public:
AbortRequest(ChannelRefT sender, int64_t worker_id)
: SenderMessage(sender), worker_id_(worker_id) {}
int64_t worker_id() { return worker_id_; }
template <class Archive>
void serialize(Archive &archive) {
archive(cereal::base_class<SenderMessage>(this), worker_id_);
}
private:
int64_t worker_id_;
};
class CommitDirective : public Message {
template <class Archive>
void serialize(Archive &archive) {
archive(cereal::base_class<Message>(this));
}
};
class AbortDirective : public Message {
template <class Archive>
void serialize(Archive &archive) {
archive(cereal::base_class<Message>(this));
}
};
class Query : public Message {
public:
Query(std::string query) : Message(), query_(query) {}
std::string query() const { return query_; }
template <class Archive>
void serialize(Archive &archive) {
archive(cereal::base_class<Message>(this), query_);
}
private:
std::string query_;
};
class Quit : public Message {
template <class Archive>
void serialize(Archive &archive) {
archive(cereal::base_class<Message>(this));
}
};
class Master : public Reactor {
public:
Master(System *system, std::string name) : Reactor(system, name), next_xid_(1) {}
virtual void Run() {
auto stream = main_.first;
FindWorkers();
std::cout << "Master is active" << std::endl;
while (true) {
auto m = stream->AwaitEvent();
if (Query *query = dynamic_cast<Query *>(m.get())) {
ProcessQuery(query);
break; // process only the first query
} else {
std::cerr << "unknown message\n";
exit(1);
}
}
stream->OnEvent([this](const Message &msg, EventStream::Subscription& subscription) {
std::cout << "Processing Query via Callback" << std::endl;
const Query &query =
dynamic_cast<const Query &>(msg); // exception bad_cast
ProcessQuery(&query);
subscription.unsubscribe();
});
}
private:
void ProcessQuery(const Query *query) {
if (query->query() == "create node") {
PerformCreateNode();
} else if (query->query() == "count nodes") {
PerformCountNodes();
} else {
std::cout << "got query: " << query->query() << std::endl;
}
}
void PerformCreateNode() {
int worker_id = rand() % NUM_WORKERS;
int64_t xid = GetTransactionId();
std::string txn_channel_name = GetTxnName(xid);
auto connector = Open(txn_channel_name);
auto stream = connector.first;
auto create_node_txn =
std::make_unique<CreateNodeTxn>(connector.second, xid);
channels_[worker_id]->Send(std::move(create_node_txn));
auto m = stream->AwaitEvent();
if (CommitRequest *req = dynamic_cast<CommitRequest *>(m.get())) {
req->sender()->Send(std::make_unique<CommitDirective>());
} else if (AbortRequest *req = dynamic_cast<AbortRequest *>(m.get())) {
req->sender()->Send(std::make_unique<AbortDirective>());
} else {
std::cerr << "unknown message\n";
exit(1);
}
Close(txn_channel_name);
}
void PerformCountNodes() {
int64_t xid = GetTransactionId();
std::string txn_channel_name = GetTxnName(xid);
auto connector = Open(txn_channel_name);
auto stream = connector.first;
for (int w_id = 0; w_id < NUM_WORKERS; ++w_id)
channels_[w_id]->Send(
std::make_unique<CountNodesTxn>(connector.second, xid));
std::vector<ChannelRefT> txn_channels;
txn_channels.resize(NUM_WORKERS, nullptr);
bool commit = true;
for (int responds = 0; responds < NUM_WORKERS; ++responds) {
auto m = stream->AwaitEvent();
if (CommitRequest *req = dynamic_cast<CommitRequest *>(m.get())) {
txn_channels[req->worker_id()] = req->sender();
commit &= true;
} else if (AbortRequest *req = dynamic_cast<AbortRequest *>(m.get())) {
txn_channels[req->worker_id()] = req->sender();
commit = false;
} else {
std::cerr << "unknown message\n";
exit(1);
}
}
if (commit) {
for (int w_id = 0; w_id < NUM_WORKERS; ++w_id)
txn_channels[w_id]->Send(std::make_unique<CommitDirective>());
} else {
for (int w_id = 0; w_id < NUM_WORKERS; ++w_id)
txn_channels[w_id]->Send(std::make_unique<AbortDirective>());
}
int64_t count = 0;
for (int responds = 0; responds < NUM_WORKERS; ++responds) {
auto m = stream->AwaitEvent();
if (CountNodesTxnResult *cnt =
dynamic_cast<CountNodesTxnResult *>(m.get())) {
count += cnt->count();
} else {
std::cerr << "unknown message\n";
exit(1);
}
}
Close(txn_channel_name);
std::cout << "graph has " << count << " vertices" << std::endl;
}
int64_t GetTransactionId() { return next_xid_++; }
std::string GetWorkerName(int worker_id) {
return "worker" + std::to_string(worker_id);
}
std::string GetTxnName(int txn_id) { return "txn" + std::to_string(txn_id); }
void FindWorkers() {
channels_.resize(NUM_WORKERS, nullptr);
int workers_found = 0;
while (workers_found < NUM_WORKERS) {
for (int64_t w_id = 0; w_id < NUM_WORKERS; ++w_id) {
if (channels_[w_id] == nullptr) {
// TODO: Resolve worker channel using the network service.
channels_[w_id] = system_->FindChannel(GetWorkerName(w_id), "main");
if (channels_[w_id] != nullptr) ++workers_found;
}
}
if (workers_found < NUM_WORKERS)
std::this_thread::sleep_for(std::chrono::seconds(1));
}
}
// TODO: Why is master atomic, it should be unique?
std::atomic<int64_t> next_xid_;
std::vector<std::shared_ptr<Channel>> channels_;
};
class Worker : public Reactor {
public:
Worker(System *system, std::string name, int64_t id) : Reactor(system, name),
worker_id_(id) {}
virtual void Run() {
std::cout << "worker " << worker_id_ << " is active" << std::endl;
auto stream = main_.first;
FindMaster();
while (true) {
auto m = stream->AwaitEvent();
if (Txn *txn = dynamic_cast<Txn *>(m.get())) {
HandleTransaction(txn);
} else {
std::cerr << "unknown message\n";
exit(1);
}
}
}
private:
void HandleTransaction(Txn *txn) {
if (CreateNodeTxn *create_txn = dynamic_cast<CreateNodeTxn *>(txn)) {
HandleCreateNode(create_txn);
} else if (CountNodesTxn *cnt_txn = dynamic_cast<CountNodesTxn *>(txn)) {
HandleCountNodes(cnt_txn);
} else {
std::cerr << "unknown transaction\n";
exit(1);
}
}
void HandleCreateNode(CreateNodeTxn *txn) {
auto connector = Open(GetTxnChannelName(txn->id()));
auto stream = connector.first;
auto masterChannel = txn->sender();
// TODO: Do the actual commit.
masterChannel->Send(
std::make_unique<CommitRequest>(connector.second, worker_id_));
auto m = stream->AwaitEvent();
if (dynamic_cast<CommitDirective *>(m.get())) {
// TODO: storage_.CreateNode();
} else if (dynamic_cast<AbortDirective *>(m.get())) {
// TODO: Rollback.
} else {
std::cerr << "unknown message\n";
exit(1);
}
Close(GetTxnChannelName(txn->id()));
}
void HandleCountNodes(CountNodesTxn *txn) {
auto connector = Open(GetTxnChannelName(txn->id()));
auto stream = connector.first;
auto masterChannel = txn->sender();
// TODO: Fix this hack -- use the storage.
int num = 123;
masterChannel->Send(
std::make_unique<CommitRequest>(connector.second, worker_id_));
auto m = stream->AwaitEvent();
if (dynamic_cast<CommitDirective *>(m.get())) {
masterChannel->Send(std::make_unique<CountNodesTxnResult>(num));
} else if (dynamic_cast<AbortDirective *>(m.get())) {
// send nothing
} else {
std::cerr << "unknown message\n";
exit(1);
}
Close(GetTxnChannelName(txn->id()));
}
// TODO: Don't repeat code from Master.
std::string GetTxnChannelName(int64_t transaction_id) {
return "txn" + std::to_string(transaction_id);
}
void FindMaster() {
// TODO: Replace with network service and channel resolution.
while (!(master_channel_ = system_->FindChannel("master", "main")))
std::this_thread::sleep_for(std::chrono::seconds(1));
}
std::shared_ptr<Channel> master_channel_ = nullptr;
int worker_id_;
// Storage storage_;
};
void ClientMain(System *system) {
std::shared_ptr<Channel> channel = nullptr;
// TODO: Replace this with network channel resolution.
while (!(channel = system->FindChannel("master", "main")))
std::this_thread::sleep_for(std::chrono::seconds(1));
std::cout << "I/O Client Main active" << std::endl;
bool active = true;
while (active) {
std::string s;
std::getline(std::cin, s);
if (s == "quit") {
active = false;
channel->Send(std::make_unique<Quit>());
} else {
channel->Send(std::make_unique<Query>(s));
}
}
}
int main() {
System system;
system.Spawn<Master>("master");
std::thread client(ClientMain, &system);
for (int i = 0; i < NUM_WORKERS; ++i)
system.Spawn<Worker>("worker" + std::to_string(i), i);
system.AwaitShutdown();
return 0;
}


@@ -0,0 +1,139 @@
#include "communication.hpp"
void EventStream::Subscription::unsubscribe() {
event_queue_.RemoveCbByUid(cb_uid_);
}
thread_local Reactor* current_reactor_ = nullptr;
std::string EventQueue::LocalChannel::Hostname() {
return system_->network().Hostname();
}
int32_t EventQueue::LocalChannel::Port() {
return system_->network().Port();
}
std::string EventQueue::LocalChannel::ReactorName() {
return reactor_name_;
}
std::string EventQueue::LocalChannel::Name() {
return name_;
}
void EventQueue::LocalEventStream::Close() {
current_reactor_->Close(name_);
}
ConnectorT Reactor::Open(const std::string &channel_name) {
std::unique_lock<std::recursive_mutex> lock(*mutex_);
// TODO: Improve the check that the channel name does not exist in the
// system.
assert(connectors_.count(channel_name) == 0);
auto it = connectors_.emplace(channel_name,
EventQueue::Params{system_, name_, channel_name, mutex_, cvar_}).first;
return ConnectorT(it->second.stream_, it->second.channel_);
}
ConnectorT Reactor::Open() {
std::unique_lock<std::recursive_mutex> lock(*mutex_);
do {
std::string channel_name = "stream-" + std::to_string(channel_name_counter_++);
if (connectors_.count(channel_name) == 0) {
// EventQueue &queue = connectors_[channel_name];
auto it = connectors_.emplace(channel_name,
EventQueue::Params{system_, name_, channel_name, mutex_, cvar_}).first;
return ConnectorT(it->second.stream_, it->second.channel_);
}
} while (true);
}
const std::shared_ptr<Channel> Reactor::FindChannel(
const std::string &channel_name) {
std::unique_lock<std::recursive_mutex> lock(*mutex_);
auto it_connector = connectors_.find(channel_name);
if (it_connector == connectors_.end()) return nullptr;
return it_connector->second.channel_;
}
void Reactor::Close(const std::string &s) {
std::unique_lock<std::recursive_mutex> lock(*mutex_);
auto it = connectors_.find(s);
assert(it != connectors_.end());
LockedCloseInternal(it->second);
connectors_.erase(it); // This calls the EventQueue destructor, which re-acquires the (recursive) mutex.
}
void Reactor::LockedCloseInternal(EventQueue& event_queue) {
// TODO(zuza): figure this out! @@@@
std::cout << "Close Channel! Reactor name = " << name_ << " Channel name = " << event_queue.name_ << std::endl;
}
void Reactor::RunEventLoop() {
std::cout << "event loop is run!" << std::endl;
while (true) {
// Clean up EventQueues without callbacks.
{
std::unique_lock<std::recursive_mutex> lock(*mutex_);
for (auto connectors_it = connectors_.begin(); connectors_it != connectors_.end(); ) {
EventQueue& event_queue = connectors_it->second;
if (event_queue.LockedCanBeClosed()) {
LockedCloseInternal(event_queue);
connectors_it = connectors_.erase(connectors_it); // This removes the element from the collection.
} else {
++connectors_it;
}
}
}
// Process and wait for events to dispatch.
MsgAndCbInfo msgAndCb;
{
std::unique_lock<std::recursive_mutex> lock(*mutex_);
// Exit the loop if there are no more EventQueues.
if (connectors_.empty()) {
return;
}
while (true) {
msgAndCb = LockedGetPendingMessages(lock);
if (msgAndCb.first != nullptr) break;
cvar_->wait(lock);
}
}
for (auto& cbAndSub : msgAndCb.second) {
auto& cb = cbAndSub.first;
const Message& msg = *msgAndCb.first;
cb(msg, cbAndSub.second);
}
}
}
/**
* Checks if there is any nonempty EventStream.
*/
auto Reactor::LockedGetPendingMessages(std::unique_lock<std::recursive_mutex> &lock) -> MsgAndCbInfo {
// Trailing return type so that MsgAndCbInfo is looked up in Reactor's scope.
for (auto& connectors_key_value : connectors_) {
EventQueue& event_queue = connectors_key_value.second;
auto msg_ptr = event_queue.LockedPop(lock);
if (msg_ptr == nullptr) continue;
std::vector<std::pair<EventStream::Callback, EventStream::Subscription> > cb_info;
for (auto& callbacks_key_value : event_queue.callbacks_) {
uint64_t uid = callbacks_key_value.first;
EventStream::Callback cb = callbacks_key_value.second;
cb_info.emplace_back(cb, EventStream::Subscription(event_queue, uid));
}
return make_pair(std::move(msg_ptr), cb_info);
}
return MsgAndCbInfo(nullptr, {});
}
Network::Network(System *system) : system_(system),
hostname_(system->config().GetString("hostname")),
port_(system->config().GetInt("port")) {}


@@ -0,0 +1,545 @@
#pragma once
#include <cassert>
#include <condition_variable>
#include <exception>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <tuple>
#include <unordered_map>
#include "cereal/types/base_class.hpp"
class Message;
class EventStream;
class Reactor;
class System;
class EventQueue;
extern thread_local Reactor* current_reactor_;
/**
* Write-end of a Connector (between two reactors).
*/
class Channel {
public:
virtual void Send(std::unique_ptr<Message>) = 0;
virtual std::string Hostname() = 0;
virtual int32_t Port() = 0;
virtual std::string ReactorName() = 0;
virtual std::string Name() = 0;
void operator=(const Channel &) = delete;
template <class Archive>
void serialize(Archive &archive) {
archive(Hostname(), Port(), ReactorName(), Name());
}
};
/**
* Read-end of a Connector (between two reactors).
*/
class EventStream {
public:
/**
* Blocks until a message arrives.
*/
virtual std::unique_ptr<Message> AwaitEvent() = 0;
/**
* Polls if there is a message available, returning null if there is none.
*/
virtual std::unique_ptr<Message> PopEvent() = 0;
/**
* Subscription Service. Lightweight object (can copy by value).
*/
class Subscription {
public:
/**
* Unsubscribe. Call only once.
*/
void unsubscribe();
private:
friend class Reactor;
Subscription(EventQueue& event_queue, uint64_t cb_uid) : event_queue_(event_queue) {
cb_uid_ = cb_uid;
}
EventQueue& event_queue_;
uint64_t cb_uid_;
};
typedef std::function<void(const Message&, Subscription&)> Callback;
/**
* Register a callback that will be called whenever an event arrives.
*/
virtual void OnEvent(Callback callback) = 0;
/**
* Close this event stream, disallowing further events from getting received.
*/
virtual void Close() = 0;
};
/**
* Implementation of a connector.
*
* This class is an internal data structure that represents the state of the connector.
* This class is not meant to be used by the clients of the messaging framework.
* The EventQueue class wraps the event queue data structure, the mutex that protects
* concurrent access to the event queue, the local channel and the event stream.
* The class is owned by the Reactor, but its LocalChannel can outlive it.
* See the LocalChannel and LocalEventStream nested classes for further information.
*/
class EventQueue {
public:
friend class Reactor;
friend class EventStream::Subscription;
struct Params {
System* system;
std::string reactor_name;
std::string name;
std::shared_ptr<std::recursive_mutex> mutex;
std::shared_ptr<std::condition_variable_any> cvar;
};
EventQueue(Params params)
: system_(params.system),
reactor_name_(params.reactor_name),
name_(params.name),
mutex_(params.mutex),
cvar_(params.cvar) {}
/**
* The destructor locks the mutex of the EventQueue and sets the stream's and
* channel's queue pointers to null.
*/
~EventQueue() {
// Ugly: this is the ONLY thing that is allowed to lock this recursive mutex twice.
// This is because we can't make both a locked and an unlocked version of the destructor.
std::unique_lock<std::recursive_mutex> lock(*mutex_);
stream_->queue_ = nullptr;
channel_->queue_ = nullptr;
}
void LockedPush(std::unique_lock<std::recursive_mutex> &, std::unique_ptr<Message> m) {
queue_.push(std::move(m));
cvar_->notify_one();
}
std::unique_ptr<Message> LockedAwaitPop(std::unique_lock<std::recursive_mutex> &lock) {
std::unique_ptr<Message> m;
while (!(m = LockedRawPop())) {
cvar_->wait(lock);
}
return m;
}
std::unique_ptr<Message> LockedPop(std::unique_lock<std::recursive_mutex> &lock) {
return LockedRawPop();
}
void LockedOnEvent(EventStream::Callback callback) {
uint64_t cb_uid = next_cb_uid++;
callbacks_[cb_uid] = callback;
}
/**
* LocalChannel represents the channels to reactors living in the same reactor system.
*
* Sending messages to the local channel requires acquiring the mutex.
* LocalChannel holds a pointer to the enclosing EventQueue object.
* The enclosing EventQueue object is destroyed when the reactor calls Close.
* When this happens, the pointer to the enclosing EventQueue object is set to null.
* After this, all the message sends on this channel are dropped.
*/
class LocalChannel : public Channel {
public:
friend class EventQueue;
LocalChannel(std::shared_ptr<std::recursive_mutex> mutex, std::string reactor_name,
std::string name, EventQueue *queue, System *system)
: mutex_(mutex),
reactor_name_(reactor_name),
name_(name),
queue_(queue),
system_(system) {}
virtual void Send(std::unique_ptr<Message> m) {
std::unique_lock<std::recursive_mutex> lock(*mutex_);
if (queue_ != nullptr) {
queue_->LockedPush(lock, std::move(m));
}
}
virtual std::string Hostname();
virtual int32_t Port();
virtual std::string ReactorName();
virtual std::string Name();
private:
std::shared_ptr<std::recursive_mutex> mutex_;
std::string reactor_name_;
std::string name_;
EventQueue *queue_;
System *system_;
};
/**
* Implementation of the event stream.
*
* After the enclosing EventQueue object is destroyed (by a call to Close),
* it is no longer legal to call any of the event stream methods.
*/
class LocalEventStream : public EventStream {
public:
friend class EventQueue;
LocalEventStream(std::shared_ptr<std::recursive_mutex> mutex, std::string name,
EventQueue *queue) : mutex_(mutex), name_(name), queue_(queue) {}
std::unique_ptr<Message> AwaitEvent() {
std::unique_lock<std::recursive_mutex> lock(*mutex_);
if (queue_ != nullptr) {
return queue_->LockedAwaitPop(lock);
}
throw std::runtime_error(
"Cannot call method after connector was closed.");
}
std::unique_ptr<Message> PopEvent() {
std::unique_lock<std::recursive_mutex> lock(*mutex_);
if (queue_ != nullptr) {
return queue_->LockedPop(lock);
}
throw std::runtime_error(
"Cannot call method after connector was closed.");
}
void OnEvent(EventStream::Callback callback) {
std::unique_lock<std::recursive_mutex> lock(*mutex_);
if (queue_ != nullptr) {
queue_->LockedOnEvent(callback);
return;
}
throw std::runtime_error(
"Cannot call this method after the connector was closed.");
}
void Close();
private:
std::shared_ptr<std::recursive_mutex> mutex_;
std::string name_;
EventQueue *queue_;
};
private:
std::unique_ptr<Message> LockedRawPop() {
if (queue_.empty()) return nullptr;
std::unique_ptr<Message> t = std::move(queue_.front());
queue_.pop();
return t;
}
/**
* Should the owner close this EventQueue?
*
* Currently this only checks whether the queue is empty and all callbacks
* have unsubscribed. This assumes the event loop has been started.
*/
bool LockedCanBeClosed() {
return callbacks_.empty() && queue_.empty();
}
void RemoveCbByUid(uint64_t uid) {
std::unique_lock<std::recursive_mutex> lock(*mutex_);
size_t num_erased = callbacks_.erase(uid);
assert(num_erased == 1);
// TODO(zuza): if no more callbacks, shut down the class (and the eventloop is started). First, figure out ownership of EventQueue?
}
System *system_;
std::string name_;
std::string reactor_name_;
std::queue<std::unique_ptr<Message>> queue_;
// Should only be locked once since it's used by a condition variable. It is also acquired in the destructor, so it must be recursive.
std::shared_ptr<std::recursive_mutex> mutex_;
std::shared_ptr<std::condition_variable_any> cvar_;
std::shared_ptr<LocalEventStream> stream_ =
std::make_shared<LocalEventStream>(mutex_, name_, this);
std::shared_ptr<LocalChannel> channel_ =
std::make_shared<LocalChannel>(mutex_, reactor_name_, name_, this, system_);
std::unordered_map<uint64_t, EventStream::Callback> callbacks_;
uint64_t next_cb_uid = 0;
};
/**
* Pair composed of read-end and write-end of a connection.
*/
using ConnectorT = std::pair<std::shared_ptr<EventStream>, std::shared_ptr<Channel>>;
using ChannelRefT = std::shared_ptr<Channel>;
/**
* A single unit of concurrent execution in the system.
*
* E.g. one worker, one client. Owned by System.
*/
class Reactor {
public:
friend class System;
Reactor(System *system, std::string name)
: system_(system), name_(name), main_(Open("main")) {}
virtual ~Reactor() {}
virtual void Run() = 0;
ConnectorT Open(const std::string &s);
ConnectorT Open();
const std::shared_ptr<Channel> FindChannel(const std::string &channel_name);
void Close(const std::string &s);
protected:
System *system_;
std::string name_;
/*
* Locks all Reactor data, including all EventQueue's in connectors_.
*
* This should be a shared_ptr because LocalChannel can outlive Reactor.
*/
std::shared_ptr<std::recursive_mutex> mutex_ =
std::make_shared<std::recursive_mutex>();
std::shared_ptr<std::condition_variable_any> cvar_ =
std::make_shared<std::condition_variable_any>();
std::unordered_map<std::string, EventQueue> connectors_;
int64_t channel_name_counter_{0};
ConnectorT main_;
private:
typedef std::pair<std::unique_ptr<Message>,
std::vector<std::pair<EventStream::Callback, EventStream::Subscription> > > MsgAndCbInfo;
/**
* Dispatches all waiting messages to callbacks. Shuts down when there are no callbacks left.
*/
void RunEventLoop();
void LockedCloseInternal(EventQueue& event_queue);
// TODO: remove proof of locking evidence ?!
MsgAndCbInfo LockedGetPendingMessages(std::unique_lock<std::recursive_mutex> &lock);
};
/**
* Configuration service.
*/
class Config {
public:
Config(System *system) : system_(system) {}
std::string GetString(std::string key) {
// TODO: Use configuration lib.
assert(key == "hostname");
return "localhost";
}
int32_t GetInt(std::string key) {
// TODO: Use configuration lib.
assert(key == "port");
return 8080;
}
private:
System *system_;
};
/**
* Networking service.
*/
class Network {
public:
Network(System *system);
std::string Hostname() { return hostname_; }
int32_t Port() { return port_; }
std::shared_ptr<Channel> Resolve(std::string hostname, int32_t port) {
// TODO: Synchronously resolve and return channel.
return nullptr;
}
std::shared_ptr<EventStream> AsyncResolve(std::string hostname, int32_t port,
int32_t retries,
std::chrono::seconds cooldown) {
// TODO: Asynchronously resolve channel, and return an event stream
// that emits the channel after it gets resolved.
return nullptr;
}
class RemoteChannel : public Channel {
public:
RemoteChannel() {}
virtual std::string Hostname() {
throw std::runtime_error("Unimplemented.");
}
virtual int32_t Port() { throw std::runtime_error("Unimplemented."); }
virtual std::string ReactorName() {
throw std::runtime_error("Unimplemented.");
}
virtual std::string Name() { throw std::runtime_error("Unimplemented."); }
virtual void Send(std::unique_ptr<Message> message) {
// TODO: Implement.
}
};
private:
System *system_;
std::string hostname_;
int32_t port_;
};
/**
* Base class for messages.
*/
class Message {
public:
virtual ~Message() {}
template <class Archive>
void serialize(Archive &) {}
};
/**
* Message that includes the sender channel used to respond.
*/
class SenderMessage : public Message {
public:
SenderMessage(ChannelRefT sender) : sender_(sender) {}
ChannelRefT sender() { return sender_; }
template <class Archive>
void serialize(Archive &ar) {
ar(sender_);
}
private:
ChannelRefT sender_;
};
/**
* Serialization service.
*/
class Serialization {
public:
using SerializedT = std::pair<char *, int64_t>;
Serialization(System *system) : system_(system) {}
SerializedT serialize(const Message &) {
SerializedT serialized;
throw std::runtime_error("Not yet implemented (Serialization::serialize)");
return serialized;
}
Message deserialize(const SerializedT &) {
Message message;
throw std::runtime_error(
"Not yet implemented (Serialization::deserialize)");
return message;
}
private:
System *system_;
};
/**
* Global placeholder for all reactors in the system. Alive through the entire process lifetime.
*
* E.g. holds set of reactors, channels for all reactors.
*/
class System {
public:
friend class Reactor;
System() : config_(this), network_(this), serialization_(this) {}
void operator=(const System &) = delete;
template <class ReactorType, class... Args>
const std::shared_ptr<Channel> Spawn(const std::string &name,
Args &&... args) {
std::unique_lock<std::recursive_mutex> lock(mutex_);
auto *raw_reactor =
new ReactorType(this, name, std::forward<Args>(args)...);
std::unique_ptr<Reactor> reactor(raw_reactor);
// Capturing a pointer isn't ideal, I would prefer to capture a Reactor&, but not sure how to do it.
std::thread reactor_thread(
[this, raw_reactor]() { this->StartReactor(*raw_reactor); });
assert(reactors_.count(name) == 0);
reactors_.emplace(name, std::pair<std::unique_ptr<Reactor>, std::thread>
(std::move(reactor), std::move(reactor_thread)));
return nullptr;
}
const std::shared_ptr<Channel> FindChannel(const std::string &reactor_name,
const std::string &channel_name) {
std::unique_lock<std::recursive_mutex> lock(mutex_);
auto it_reactor = reactors_.find(reactor_name);
if (it_reactor == reactors_.end()) return nullptr;
return it_reactor->second.first->FindChannel(channel_name);
}
void AwaitShutdown() {
for (auto &key_value : reactors_) {
auto &thread = key_value.second.second;
thread.join();
}
}
Config &config() { return config_; }
Network &network() { return network_; }
Serialization &serialization() { return serialization_; }
private:
void StartReactor(Reactor& reactor) {
current_reactor_ = &reactor;
reactor.Run();
reactor.RunEventLoop(); // Activate callbacks.
}
std::recursive_mutex mutex_;
// TODO: Replace with a map to a reactor EventQueue map to have more granular
// locking.
std::unordered_map<std::string,
std::pair<std::unique_ptr<Reactor>, std::thread>>
reactors_;
Config config_;
Network network_;
Serialization serialization_;
};


@ -0,0 +1,162 @@
#pragma once
#include <cassert>
#include <unordered_map>
#include <vector>
#include "uid.hpp"
enum class EdgeType { OUTGOING, INCOMING };
/** A node in the graph. Has incoming and outgoing edges which
* are defined as global addresses of other nodes */
class Node {
public:
Node(const GlobalId &id) : id_(id) {}
const auto &id() const { return id_; };
const auto &edges_out() const { return edges_out_; }
const auto &edges_in() const { return edges_in_; }
void AddConnection(EdgeType edge_type, const GlobalAddress &gad) {
(edge_type == EdgeType::INCOMING ? edges_in_ : edges_out_)
.emplace_back(gad);
}
/** Redirects all edges equal to old_address to point at new_worker */
void RedirectEdges(const GlobalAddress old_address, size_t new_worker) {
for (auto &address : edges_in_)
if (address == old_address) address.worker_id_ = new_worker;
for (auto &address : edges_out_)
if (address == old_address) address.worker_id_ = new_worker;
}
private:
// TODO remove id_ from Node if not necessary
GlobalId id_;
// global addresses of nodes this node is connected to
std::vector<GlobalAddress> edges_out_;
std::vector<GlobalAddress> edges_in_;
};
/** A worker / shard in the distributed system */
class Worker {
public:
// unique worker ID. uniqueness is ensured by the worker
// owner (the Distributed class)
const int64_t id_;
Worker(int64_t id) : id_(id) {}
int64_t NodeCount() const { return nodes_.size(); }
/** Gets a node. */
Node &GetNode(const GlobalId &gid) {
auto found = nodes_.find(gid);
assert(found != nodes_.end());
return found->second;
}
/** Returns the number of edges that cross from this
* graph / worker into another one */
int64_t BoundaryEdgeCount() const {
int64_t count = 0;
auto count_f = [this, &count](const auto &edges) {
for (const GlobalAddress &address : edges)
if (address.worker_id_ != id_) count++;
};
for (const auto &node : nodes_) {
count_f(node.second.edges_out());
count_f(node.second.edges_in());
}
return count;
}
/** Creates a new node on this worker. Returns its global id */
const GlobalId &MakeNode() {
GlobalId new_id(id_, next_node_sequence_++);
auto new_node = nodes_.emplace(std::make_pair(new_id, Node(new_id)));
return new_node.first->first;
};
/** Places the existing node on this worker */
void PlaceNode(const GlobalId &gid, const Node &node) {
nodes_.emplace(gid, node);
}
/** Removes the node with the given ID from this worker */
void RemoveNode(const GlobalId &gid) { nodes_.erase(gid); }
auto begin() const { return nodes_.begin(); }
auto end() const { return nodes_.end(); }
private:
// counter of sequences numbers of nodes created on this worker
int64_t next_node_sequence_{0};
// node storage of this worker
std::unordered_map<GlobalId, Node> nodes_;
};
/**
* A distributed system consisting of multiple workers.
* For the time being it does not model a distributed
* system correctly in terms of message passing; it
* operates on workers and their data directly instead.
*/
class Distributed {
public:
/** Creates a distributed with the given number of workers */
Distributed(int initial_worker_count = 0) {
for (int worker_id = 0; worker_id < initial_worker_count; worker_id++)
AddWorker();
}
int64_t AddWorker() {
int64_t new_worker_id = workers_.size();
workers_.emplace_back(new_worker_id);
return new_worker_id;
}
int WorkerCount() const { return workers_.size(); }
auto &GetWorker(int64_t worker_id) { return workers_[worker_id]; }
GlobalAddress MakeNode(int64_t worker_id) {
return {worker_id, workers_[worker_id].MakeNode()};
}
Node &GetNode(const GlobalAddress &address) {
return workers_[address.worker_id_].GetNode(address.id_);
}
/** Moves a node with the given global id to the given worker */
void MoveNode(const GlobalAddress &gid, int64_t destination) {
const Node &node = GetNode(gid);
// make sure that all edges to and from the node are updated
for (auto &edge : node.edges_in())
GetNode(edge).RedirectEdges(gid, destination);
for (auto &edge : node.edges_out())
GetNode(edge).RedirectEdges(gid, destination);
// change node destination
workers_[destination].PlaceNode(gid.id_, node);
workers_[gid.worker_id_].RemoveNode(gid.id_);
}
void MakeEdge(const GlobalAddress &from, const GlobalAddress &to) {
GetNode(from).AddConnection(EdgeType::OUTGOING, to);
GetNode(to).AddConnection(EdgeType::INCOMING, from);
}
auto begin() const { return workers_.begin(); }
auto end() const { return workers_.end(); }
private:
std::vector<Worker> workers_;
};


@ -0,0 +1,147 @@
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <experimental/tuple>
#include <iostream>
#include <limits>
#include <numeric>
#include <random>
#include <tuple>
#include <vector>
#include "graph.hpp"
namespace spinner {
// constant for the balancing penalty
const double c = 2.0;
/**
* Returns the index of the maximum score in the given vector.
* If there are multiple maxima, one is chosen at random.
*/
auto MaxRandom(const std::vector<double> &scores) {
std::vector<size_t> best_indices;
double current_max = std::numeric_limits<double>::lowest();
for (size_t ind = 0; ind < scores.size(); ind++) {
if (scores[ind] > current_max) {
current_max = scores[ind];
best_indices.clear();
}
if (scores[ind] == current_max) {
best_indices.emplace_back(ind);
}
}
return best_indices[rand() % best_indices.size()];
}
/**
* Returns the index of the best (highest scored) worker
* for the given node. If there are multiple workers with
* the best score, node prefers to remain on the same worker
* (if among the best), or one is chosen at random.
*
* @param distributed - the distributed system.
* @param node - the node which is being evaluated.
* @param penalties - a vector of penalties (per worker).
* @param current_worker - the worker on which the given
* node is currently residing.
* @return - std::pair<int, std::vector<double>> which is a
* pair of (best worker, score_per_worker).
*/
auto BestWorker(const Distributed &distributed, const Node &node,
const std::vector<double> &penalties, int current_worker) {
// scores per worker
std::vector<double> scores(distributed.WorkerCount(), 0.0);
for (auto &edge : node.edges_in()) scores[edge.worker_id_] += 1.0;
for (auto &edge : node.edges_out()) scores[edge.worker_id_] += 1.0;
for (int worker = 0; worker < distributed.WorkerCount(); ++worker) {
// normalize contribution of worker over neighbourhood size
scores[worker] /= node.edges_out().size() + node.edges_in().size();
// add balancing penalty
scores[worker] -= penalties[worker];
}
// pick the best destination, but prefer to stay if you can
size_t destination = MaxRandom(scores);
if (scores[current_worker] == scores[destination])
destination = current_worker;
return std::make_pair(destination, scores);
}
/** Indicates whether the Spinner worker penalty is calculated based on
* vertex or edge worker cardinalities */
enum class PenaltyType { Vertex, Edge };
/** Calculates Spinner penalties for workers in the given
* distributed system. */
auto Penalties(const Distributed &distributed,
PenaltyType penalty_type = PenaltyType::Edge) {
std::vector<double> penalties;
int64_t total_count{0};
for (const auto &worker : distributed) {
int64_t worker_count{0};
switch (penalty_type) {
case PenaltyType::Vertex:
worker_count += worker.NodeCount();
break;
case PenaltyType::Edge:
for (const auto &node_kv : worker) {
// Spinner counts the edges on a worker as the sum
// of degrees of nodes on that worker. In that sense
// both incoming and outgoing edges are individually
// added...
worker_count += node_kv.second.edges_out().size();
worker_count += node_kv.second.edges_in().size();
}
break;
}
total_count += worker_count;
penalties.emplace_back(worker_count);
}
for (auto &penalty : penalties)
penalty /= c * total_count / distributed.WorkerCount();
return penalties;
}
/** Do one spinner step (modifying the given distributed) */
void PerformSpinnerStep(Distributed &distributed) {
auto penalties = Penalties(distributed);
// here a strategy can be injected for limiting
// the number of movements performed in one step.
// limiting could be based on (for example):
// - limiting the number of movements per worker
// - limiting only to movements that are above
//   a threshold (score improvement or something)
// - not executing on all the workers (also prevents
//   oscillations)
//
// in the first implementation just accumulate all
// the movements and execute together.
// relocation info: contains the address of the Node
// that needs to relocate and its destination worker
std::vector<std::pair<GlobalAddress, int>> movements;
for (const Worker &worker : distributed)
for (const auto &gid_node_pair : worker) {
// (best destination, scores) pair for node
std::pair<int, std::vector<double>> destination_scores =
BestWorker(distributed, gid_node_pair.second, penalties, worker.id_);
if (destination_scores.first != worker.id_)
movements.emplace_back(GlobalAddress(worker.id_, gid_node_pair.first),
destination_scores.first);
}
// execute movements. it is likely that in the real system
// this will need to happen as a single db transaction
for (const auto &m : movements) distributed.MoveNode(m.first, m.second);
}
} // namespace spinner


@ -0,0 +1,62 @@
#pragma once
#include <cstdint>
#include <vector>
/** A globally defined identifier. Defines a worker
* and the sequence number on that worker */
class GlobalId {
public:
GlobalId(int64_t worker_id, int64_t sequence_number)
: worker_id_(worker_id), sequence_number_(sequence_number) {}
// TODO perhaps make members const and replace instead of changing
// when migrating nodes
int64_t worker_id_;
int64_t sequence_number_;
bool operator==(const GlobalId &other) const {
return worker_id_ == other.worker_id_ &&
sequence_number_ == other.sequence_number_;
}
bool operator!=(const GlobalId &other) const { return !(*this == other); }
};
/** Defines a location in the system where something can be found.
* Something can be found on some worker, for some Id */
class GlobalAddress {
public:
GlobalAddress(int64_t worker_id, GlobalId id)
: worker_id_(worker_id), id_(id) {}
// TODO perhaps make members const and replace instead of changing
// when migrating nodes
int64_t worker_id_;
GlobalId id_;
bool operator==(const GlobalAddress &other) const {
return worker_id_ == other.worker_id_ && id_ == other.id_;
}
bool operator!=(const GlobalAddress &other) const {
return !(*this == other);
}
};
namespace std {
template <>
struct hash<GlobalId> {
size_t operator()(const GlobalId &id) const {
return id.sequence_number_ << 4 ^ id.worker_id_;
}
};
template <>
struct hash<GlobalAddress> {
size_t operator()(const GlobalAddress &ga) const {
return gid_hash(ga.id_) << 4 ^ ga.worker_id_;
}
private:
std::hash<GlobalId> gid_hash{};
};
}


@ -0,0 +1,51 @@
cmake_minimum_required(VERSION 3.1)
project(${project_name}_tests)
enable_testing()
# set current directory name as a test type
get_filename_component(test_type ${CMAKE_CURRENT_SOURCE_DIR} NAME)
# get all cpp abs file names recursively starting from current directory
file(GLOB_RECURSE test_type_cpps *.cpp)
message(STATUS "Available ${test_type} cpp files are: ${test_type_cpps}")
# for each cpp file build binary and register test
foreach(test_cpp ${test_type_cpps})
# get exec name (remove extension from the abs path)
get_filename_component(exec_name ${test_cpp} NAME_WE)
# set target name in format {project_name}__{test_type}__{exec_name}
set(target_name ${project_name}__${test_type}__${exec_name})
# build exec file
add_executable(${target_name} ${test_cpp})
set_property(TARGET ${target_name} PROPERTY CXX_STANDARD ${cxx_standard})
if(${TEST_COVERAGE})
set_target_properties(${target_name} PROPERTIES COMPILE_FLAGS "-g -O0 -Wall -fprofile-arcs -ftest-coverage")
set_target_properties(${target_name} PROPERTIES LINK_FLAGS "--coverage -fprofile-arcs -ftest-coverage")
endif()
# OUTPUT_NAME sets the real name of a target when it is built and can be
# used to help create two targets of the same name even though CMake
# requires unique logical target names
set_target_properties(${target_name} PROPERTIES OUTPUT_NAME ${exec_name})
# link libraries
target_link_libraries(${target_name} distributed_lib)
target_link_libraries(${target_name} memgraph_lib)
target_link_libraries(${target_name} ${MEMGRAPH_ALL_LIBS})
# gtest
target_link_libraries(${target_name} gtest gtest_main gmock)
if(${TEST_COVERAGE})
# for code coverage
target_link_libraries(${target_name} gcov)
endif()
# register test
set(output_path ${CMAKE_BINARY_DIR}/test_results/unit/${target_name}.xml)
add_test(${target_name} ${exec_name} --gtest_output=xml:${output_path})
endforeach()


@ -0,0 +1,67 @@
#include <cassert>
#include <iostream>
#include <iterator>
#include "graph.hpp"
void test_global_id() {
GlobalId a(1, 1);
assert(a == GlobalId(1, 1));
assert(a != GlobalId(1, 2));
assert(a != GlobalId(2, 1));
}
void test_global_address() {
GlobalAddress a(1, {1, 1});
assert(a == GlobalAddress(1, {1, 1}));
assert(a != GlobalAddress(2, {1, 1}));
assert(a != GlobalAddress(1, {2, 1}));
}
void test_worker() {
Worker worker0{0};
assert(worker0.NodeCount() == 0);
GlobalId n0 = worker0.MakeNode();
assert(worker0.NodeCount() == 1);
Worker worker1{1};
worker1.PlaceNode(n0, worker0.GetNode(n0));
worker0.RemoveNode(n0);
assert(worker0.NodeCount() == 0);
assert(worker1.NodeCount() == 1);
worker1.MakeNode();
assert(worker1.NodeCount() == 2);
assert(std::distance(worker1.begin(), worker1.end()) == 2);
}
void test_distributed() {
Distributed d;
assert(d.WorkerCount() == 0);
auto w0 = d.AddWorker();
assert(d.WorkerCount() == 1);
auto w1 = d.AddWorker();
assert(d.WorkerCount() == 2);
GlobalAddress n0 = d.MakeNode(w0);
assert(d.GetWorker(w0).NodeCount() == 1);
GlobalAddress n1 = d.MakeNode(w1);
assert(d.GetNode(n0).edges_out().size() == 0);
assert(d.GetNode(n0).edges_in().size() == 0);
assert(d.GetNode(n1).edges_out().size() == 0);
assert(d.GetNode(n1).edges_in().size() == 0);
d.MakeEdge(n0, n1);
assert(d.GetNode(n0).edges_out().size() == 1);
assert(d.GetNode(n0).edges_in().size() == 0);
assert(d.GetNode(n1).edges_out().size() == 0);
assert(d.GetNode(n1).edges_in().size() == 1);
}
int main() {
test_global_id();
test_global_address();
test_worker();
test_distributed();
std::cout << "All tests passed" << std::endl;
}


@ -0,0 +1,126 @@
#include <cstring>
#include <fstream>
#include <iostream>
#include <streambuf>
#include "cereal/archives/binary.hpp"
#include "cereal/types/memory.hpp"
#include "cereal/types/string.hpp"
#include "cereal/types/utility.hpp" // utility has to be included because of std::pair
#include "cereal/types/vector.hpp"
struct BasicSerializable {
int64_t x_;
std::string y_;
BasicSerializable() = default;
BasicSerializable(int64_t x, std::string y) : x_(x), y_(y) {}
template <class Archive>
void serialize(Archive &ar) {
ar(x_, y_);
}
template <typename Archive>
static void load_and_construct(
Archive &ar, cereal::construct<BasicSerializable> &construct) {
int64_t x;
std::string y;
ar(x, y);
construct(x, y);
}
};
struct ComplexSerializable {
using VectorT = std::vector<float>;
using VectorPairT = std::vector<std::pair<std::string, BasicSerializable>>;
BasicSerializable x_;
VectorT y_;
VectorPairT z_;
ComplexSerializable(const BasicSerializable &x, const VectorT &y,
const VectorPairT &z)
: x_(x), y_(y), z_(z) {}
template <typename Archive>
void serialize(Archive &ar) {
ar(x_, y_, z_);
}
template <typename Archive>
static void load_and_construct(
Archive &ar, cereal::construct<ComplexSerializable> &construct) {
BasicSerializable x;
VectorT y;
VectorPairT z;
ar(x, y, z);
construct(x, y, z);
}
};
class DummyStreamBuf : public std::basic_streambuf<char> {
protected:
std::streamsize xsputn(const char *data, std::streamsize count) override {
for (std::streamsize i = 0; i < count; ++i) {
data_.push_back(data[i]);
}
return count;
}
std::streamsize xsgetn(char *data, std::streamsize count) override {
if (count < 0) return 0;
if (static_cast<size_t>(position_ + count) > data_.size()) {
count = data_.size() - position_;
position_ = data_.size();
}
memcpy(data, data_.data() + position_, count);
position_ += count;
return count;
}
private:
std::vector<char> data_;
std::streamsize position_{0};
};
int main() {
DummyStreamBuf sb;
std::iostream iostream(&sb);
// serialization
cereal::BinaryOutputArchive oarchive{iostream};
std::unique_ptr<BasicSerializable const> const basic_serializable_object{
new BasicSerializable{100, "Test"}};
std::unique_ptr<ComplexSerializable const> const complex_serializable_object{
new ComplexSerializable{
{100, "test"},
{3.4, 3.4},
{{"first", {10, "Basic1"}}, {"second", {20, "Basic2"}}}}};
oarchive(basic_serializable_object);
oarchive(complex_serializable_object);
// deserialization
cereal::BinaryInputArchive iarchive{iostream};
std::unique_ptr<BasicSerializable> basic_deserialized_object{nullptr};
std::unique_ptr<ComplexSerializable> complex_deserialized_object{nullptr};
iarchive(basic_deserialized_object);
iarchive(complex_deserialized_object);
// output
std::cout << "Basic Deserialized: " << basic_deserialized_object->x_ << "; "
<< basic_deserialized_object->y_ << std::endl;
auto x = complex_deserialized_object->x_;
auto y = complex_deserialized_object->y_;
auto z = complex_deserialized_object->z_;
std::cout << "Complex Deserialized" << std::endl;
std::cout << " x_ -> " << x.x_ << "; " << x.y_ << std::endl;
std::cout << " y_ -> ";
for (const auto v_item : y) std::cout << v_item << " ";
std::cout << std::endl;
std::cout << " z_ -> ";
for (const auto v_item : z)
std::cout << v_item.first << " | Pair: (" << v_item.second.x_ << ", "
<< v_item.second.y_ << ")"
<< "::";
std::cout << std::endl;
return 0;
}


@ -0,0 +1,85 @@
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "graph.hpp"
#include "spinner.hpp"
void PrintStatistics(const Distributed &distributed) {
using std::cout;
using std::endl;
for (const Worker &worker : distributed) {
cout << " Worker " << worker.id_ << ":";
cout << " #nodes = " << worker.NodeCount();
int64_t edge_count{0};
for (const auto &gid_node_pair : worker) {
edge_count += gid_node_pair.second.edges_in().size();
edge_count += gid_node_pair.second.edges_out().size();
}
cout << ", #edges = " << edge_count;
cout << ", #cuts = " << worker.BoundaryEdgeCount() << endl;
}
}
/**
* Reads an undirected graph from file.
* - first line of the file: vertices_count, edges_count
* - next edges_count lines contain vertices that form an edge
* example:
* https://snap.stanford.edu/data/facebook_combined.txt.gz
* add number of vertices and edges in the first line of that file
*/
Distributed ReadGraph(std::string filename, int worker_count) {
Distributed distributed(worker_count);
std::fstream fs;
fs.open(filename, std::fstream::in);
if (fs.fail()) return distributed;
int vertex_count, edge_count;
fs >> vertex_count >> edge_count;
// assign vertices to random workers
std::vector<GlobalAddress> nodes;
for (int i = 0; i < vertex_count; ++i)
nodes.emplace_back(distributed.MakeNode(rand() % worker_count));
// add edges
for (int i = 0; i < edge_count; ++i) {
size_t u, v;
fs >> u >> v;
assert(u < nodes.size() && v < nodes.size());
distributed.MakeEdge(nodes[u], nodes[v]);
}
fs.close();
return distributed;
}
int main(int argc, const char *argv[]) {
srand(time(NULL));
if (argc < 4) {
std::cout << "Usage:" << std::endl;
std::cout << argv[0] << " filename partitions iterations" << std::endl;
return 0;
}
std::cout << "Memgraph spinner test " << std::endl;
std::string filename(argv[1]);
int partitions = std::max(1, atoi(argv[2]));
int iterations = std::max(1, atoi(argv[3]));
Distributed distributed = ReadGraph(filename, partitions);
PrintStatistics(distributed);
for (int iter = 0; iter < iterations; ++iter) {
spinner::PerformSpinnerStep(distributed);
std::cout << "Iteration " << iter << std::endl;
PrintStatistics(distributed);
}
return 0;
}


@ -0,0 +1,2 @@
*.txt
testmain


@ -0,0 +1,14 @@
#include <cstdint>
// The least significant bit of an aligned pointer is used as a mark.
inline int is_marked(long long i) {
return ((i & 0x1L) > 0);
}
// Clears the mark bit, recovering the original pointer value.
inline long long get_unmarked(long long i) {
return i & ~0x1L;
}
// Sets the mark bit.
inline long long get_marked(long long i) {
return i | 0x1L;
}


@ -0,0 +1,3 @@
#!/bin/bash
g++ -std=c++1y testmain.cpp -o testmain -lpthread


@ -0,0 +1,212 @@
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <vector>
#include <string>
#include <cstdlib>
#include "node.h"
using namespace std;
template<typename T>
struct list_t {
node_t<T> *start_ptr;
node_t<T> *end_ptr;
atomic<int> length;
list_t<T> (const T& a,const T& b) {
end_ptr = allocate_node(b, NULL);
start_ptr = allocate_node(a, end_ptr);
length.store(0);
}
node_t<T> *allocate_node(const T& data, node_t<T>* next) {
return new node_t<T>(data,next);
}
node_t<T> *find(const T& val, node_t<T> **left){
//printf("find node %d\n",val);
node_t<T> *left_next, *right, *left_next_copy;
left_next = right = NULL;
while(1) {
node_t<T> *it = start_ptr;
node_t<T> *it_next = start_ptr->next.load();
//while (get_flag(it_next) || (it->data.load() < val)) {
while (get_flag(it_next) || (it->data < val)) {
//printf("%d\n",it->data);
if (!get_flag(it_next)) {
(*left) = it;
left_next = it_next;
}
it = get_unflagged(it_next);
if (it == end_ptr) break;
//it_next = it->next.load(memory_order_relaxed);
it_next = it->next.load();
}
right = it;
left_next_copy = left_next;
if (left_next == right){
//if (!get_flag(right->next.load(memory_order_relaxed)))
if (right == end_ptr || !get_flag(right->next.load()))
return right;
}
else {
if ((*left)->next.compare_exchange_strong(left_next_copy,right) == true) {
int previous = left_next->ref_count.fetch_add(-1);
previous = right->ref_count.fetch_add(1);
//if (!get_flag(right->next.load(memory_order_relaxed))) return right;
if (!get_flag(right->next.load())) return right;
}
}
}
}
int contains(const T& val) {
//printf("search node %d\n",val);
//node_t<T> *it = get_unflagged(start_ptr->next.load(memory_order_relaxed));
node_t<T> *it = get_unflagged(start_ptr->next.load());
while(it != end_ptr) {
//if (!get_flag(it->next) && it->data.load() >= val){
if (!get_flag(it->next) && it->data >= val){
//if (it->data.load() == val) return 1;
if (it->data == val) return 1;
else return 0;
}
//it = get_unflagged(it->next.load(memory_order_relaxed));
it = get_unflagged(it->next.load());
}
return 0;
}
int size() {
return length.load();
}
int add(const T& val) {
//printf("add node %d\n",val);
node_t<T> *right, *left;
right = left = NULL;
node_t<T> *new_elem = allocate_node(val, NULL);
while(1) {
right = find(val, &left);
//if (right != end_ptr && right->data.load() == val){
if (right != end_ptr && right->data == val){
return 0;
}
new_elem->next.store(right);
if (left->next.compare_exchange_strong(right,new_elem) == true) {
length.fetch_add(1);
return 1;
}
else {
}
}
}
node_t<T>* remove(const T& val) {
//printf("remove node %d\n",val);
node_t<T>* right, *left, *right_next, *tmp;
node_t<T>* left_next, *right_copy;
right = left = right_next = tmp = NULL;
while(1) {
right = find(val, &left);
left_next = left->next.load();
right_copy = right;
//if (right == end_ptr || right->data.load() != val){
if (right == end_ptr || right->data != val){
return NULL;
}
//right_next = right->next.load(memory_order_relaxed);
right_next = right->next.load();
if (!get_flag(right_next)){
node_t<T>* right_next_marked = get_flagged(right_next);
if ((right->next).compare_exchange_strong(right_next,right_next_marked)==true) {
if((left->next).compare_exchange_strong(right_copy,right_next) == false) {
tmp = find(val,&tmp);
} else {
int previous = right->ref_count.fetch_add(-1);
previous = right_next->ref_count.fetch_add(1);
}
length.fetch_add(-1);
return right;
}
}
}
}
int get_flag(node_t<T>* ptr) {
return is_marked(reinterpret_cast<long long>(ptr));
}
void mark_flag(node_t<T>* &ptr){
ptr = get_flagged(ptr);
}
void unmark_flag(node_t<T>* &ptr){
ptr = get_unflagged(ptr);
}
inline static node_t<T>* get_flagged(node_t<T>* ptr){
return reinterpret_cast<node_t<T>*>(get_marked(reinterpret_cast<long long>(ptr)));
}
inline static node_t<T>* get_unflagged(node_t<T>* ptr){
return reinterpret_cast<node_t<T>*>(get_unmarked(reinterpret_cast<long long>(ptr)));
}
struct iterator{
node_t<T>* ptr;
iterator(node_t<T>* ptr_) : ptr(ptr_) {
ptr->ref_count.fetch_add(1);
}
~iterator() {
if(ptr != NULL) ptr->ref_count.fetch_add(-1);
}
bool operator==(const iterator& other) {
return ptr == other.ptr;
}
bool operator!=(const iterator& other) {
return ptr != other.ptr;
}
iterator& operator++() {
node_t<T>* it_next = ptr->next.load(), *it = ptr, *it_next_unflagged = list_t<T>::get_unflagged(it_next);
while(it_next_unflagged != NULL && it_next != it_next_unflagged) {
it = it_next_unflagged;
it_next = it->next.load();
it_next_unflagged = list_t<T>::get_unflagged(it_next);
}
if(it_next_unflagged == NULL) {
it->ref_count.fetch_add(1);
ptr->ref_count.fetch_add(-1);
ptr = it;
} else {
it_next->ref_count.fetch_add(1);
ptr->ref_count.fetch_add(-1);
ptr = it_next;
}
return *this;
}
T& operator*() {
return ptr->data;
}
};
iterator begin(){
// Skip logically deleted nodes: a marked next pointer means the node is removed.
node_t<T>* it = start_ptr->next.load();
while(it != end_ptr && get_flag(it->next.load())) {
it = get_unflagged(it->next.load());
}
if(it == end_ptr) return end();
return iterator(it);
}
iterator end(){
return iterator(end_ptr);
}
};

@@ -0,0 +1,26 @@
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <vector>
#include <string>
#include <cstdlib>
#include "bitflags.h"
#include <atomic>
using namespace std;
template <typename T>
struct node_t {
//atomic<T> data;
T data;
atomic<node_t<T>* > next;
atomic<int> ref_count;
long long timestamp;
node_t(const T& data_, node_t<T>* next_) {
//data.store(data_);
timestamp = -1;
data = data_;
next.store(next_);
ref_count.store(1);
}
};

@@ -0,0 +1,3 @@
#!/bin/bash
./testmain

@@ -0,0 +1,269 @@
#include <iostream>
#include <algorithm>
#include "pthread.h"
#include <cstdio>
#include <vector>
#include <string>
#include <cstdlib>
#include "list.h"
#include <ctime>
#include <sys/time.h>
#include <climits>
using namespace std;
#define FOR(i,a,b) for(int i=(a);i<(b);++i)
#define MAXP 128
#define SCAN_ITER 32
double rand_double(){
return ((double) rand())/RAND_MAX;
}
int rand_int(int max_int){
return (int)(rand_double()*max_int);
}
struct timestamping{
vector<long long> stamps;
vector<long long> finished;
atomic<long long> maxTS;
int n;
timestamping(int n_) {
n = n_;
stamps.resize(n,0);
finished.resize(n,0);
maxTS.store(0);
}
long long get_minimal(int id) {
long long sol = LLONG_MAX;
FOR(i,0,n)
if(finished[i] == 0 && i != id) sol = min(stamps[i],sol);
return sol;
}
long long increase_timestamp(int id) {
long long newTS = maxTS.fetch_add(1) + 1;
stamps[id] = newTS;
return newTS;
}
void mark_finished(int id) {
finished[id] = 1;
}
};
struct thread_data {
int id;
int op_cnt, max_int;
int add_cnt, remove_cnt, find_cnt, total_op;
double find_threshold, add_threshold;
int* values;
int* tasks;
timestamping* timestamps;
int max_buffer_size;
int reclaimed_cnt;
vector<node_t<int>* > buffer;
list_t<int>* list;
thread_data(list_t<int>* ptr, int id_, timestamping* timestamps_, int max_buffer_size_,
int op_cnt_, int max_int_, double find_=0.6, double add_ = 0.8) {
list = ptr;
id = id_;
timestamps = timestamps_;
max_buffer_size = max_buffer_size_;
reclaimed_cnt = 0;
op_cnt = op_cnt_;
max_int = max_int_;
add_cnt = 0;
remove_cnt = 0;
find_cnt = 0;
total_op = 0;
find_threshold = find_;
add_threshold = add_;
init_tasks();
}
void init_tasks() {
values = (int *)malloc(op_cnt*sizeof(int));
tasks = (int *)malloc(op_cnt*sizeof(int));
FOR(i,0,op_cnt) {
double x = rand_double();
int n = rand_int(max_int);
values[i] = n;
if( x < find_threshold )
tasks[i] = 0;
else if (x < add_threshold)
tasks[i] = 1;
else
tasks[i] = 2;
}
}
};
void *print_info( void* data) {
thread_data* ptr = (thread_data*)data;
cout << "Thread " << ptr->id << endl;
cout << "Read: " << ptr->find_cnt << " Add: " << ptr->add_cnt << " Remove: " << ptr->remove_cnt << endl;
cout << "Deallocated: " << ptr->reclaimed_cnt << " To be freed: " << ptr->buffer.size() << endl;
return NULL;
}
void *print_set( void* ptr ){
thread_data* data = (thread_data*)ptr;
FILE *out = fopen("out.txt","w+");
FOR(i,0,10) {
list_t<int>::iterator it = data->list->begin();
list_t<int>::iterator endit = data->list->end();
while(it!= endit) {
fprintf(out,"%d ",*it);
++it;
}
fprintf(out,"\n");
fflush(out);
}
fclose(out);
return NULL;
}
void scan_buffer(void *ptr){
thread_data* data = (thread_data*)ptr;
long long min_timestamp = data->timestamps->get_minimal(data->id);
printf("Memory reclamation process %d Min timestamp %lld Size %zu\n",data->id,min_timestamp,data->buffer.size());
vector<node_t<int>* > tmp_buffer;
FOR(i,0,data->buffer.size()) {
long long ts = (data->buffer[i])->timestamp;
node_t<int>* next = list_t<int>::get_unflagged(data->buffer[i]->next.load());
//printf("Deleting: %d %d %d %d %lld %d\n",data->id,i,ts,data->buffer[i]->data,(long long)data->buffer[i],data->buffer[i]->ref_count.load());
if (ts < min_timestamp && data->buffer[i]->ref_count.load() == 0) {
next->ref_count.fetch_add(-1);
free(data->buffer[i]);
++(data->reclaimed_cnt);
}
else {
tmp_buffer.push_back(data->buffer[i]);
}
}
data->buffer = tmp_buffer;
}
void *test(void *ptr){
thread_data* data = (thread_data*)ptr;
int opcnt = data->op_cnt;
int maxint = data->max_int;
int id = data->id;
list_t<int>* list = data->list;
FOR(i,0,opcnt) {
/*
double x = rand_double();
int n = rand_int(maxint);
//cout << x << " " << n << endl;
if (x < data->find_threshold) {
//cout << 0 << endl;
list->contains(n);
++(data->find_cnt);
} else if(x < data->add_threshold) {
//cout << 1 << endl;
if(list->add(n)) {
++(data->add_cnt);
}
} else {
//cout << 2 << endl;
if(list->remove(n)) {
++(data->remove_cnt);
}
}
++(data->total_op);
*/
int n = data->values[i];
int op = data->tasks[i];
long long ts = data->timestamps->increase_timestamp(id);
//printf("Time: %lld Process: %d Operation count: %d Operation type: %d Value: %d\n",ts,id,i,op,n);
if (op == 0) {
list->contains(n);
++(data->find_cnt);
} else if(op == 1) {
if(list->add(n)) {
++(data->add_cnt);
}
} else {
node_t<int>* node_ptr = list->remove(n);
if(((long long) node_ptr)%4 != 0){
printf("removed node pointer is misaligned\n"); fflush(stdout);
exit(1);
}
if(node_ptr != NULL ) {
node_ptr->timestamp = data->timestamps->maxTS.load() + 1;
data->buffer.push_back(node_ptr);
//printf("Process %d at time %d added reclamation node: %d %lld\n",id,node_ptr->timestamp,node_ptr->data,(long long)node_ptr);
++(data->remove_cnt);
}
}
fflush(stdout);
if( i % SCAN_ITER == 0 && data->buffer.size() >= (size_t)data->max_buffer_size )
scan_buffer(ptr);
}
data->timestamps->mark_finished(id);
return NULL;
}
int main(int argc, char **argv){
int P = 1;
int op_cnt = 100;
if(argc > 2){
sscanf(argv[1],"%d",&P);
sscanf(argv[2],"%d",&op_cnt);
}
int max_int = 2048;
int limit = 1e9;
int initial = max_int/2;
int max_buffer_size = max_int/16;
struct timeval start,end;
timestamping timestamps(P);
vector<pthread_t> threads(P+1);
vector<thread_data> data;
list_t<int> *list = new list_t<int>(-limit,limit);
cout << "constructed list" << endl;
FOR(i,0,initial) {
list->add(i);
}
cout << "initialized list elements" << endl;
FOR(i,0,P) data.push_back(thread_data(list,i,&timestamps,max_buffer_size,op_cnt,max_int));
cout << "created thread inputs" << endl;
gettimeofday(&start,NULL);
FOR(i,0,P) pthread_create(&threads[i],NULL,test,((void *)(&data[i])));
pthread_create(&threads[P],NULL,print_set,((void *)(&data[0])));
cout << "created threads" << endl;
FOR(i,0,P+1) pthread_join(threads[i],NULL);
gettimeofday(&end,NULL);
cout << "execution finished" << endl;
FOR(i,0,P) print_info((void*)(&data[i]));
int exp_len = initial;
FOR(i,0,P)
exp_len += (data[i].add_cnt - data[i].remove_cnt);
uint64_t duration = (end.tv_sec*(uint64_t)1000000 + end.tv_usec) - (start.tv_sec*(uint64_t)1000000 + start.tv_usec);
cout << "Actual length: " << list->length.load() << endl;
cout << "Expected length: " << exp_len << endl;
cout << "Time(s): " << duration/1000000.0 << endl;
return 0;
}

@@ -0,0 +1 @@
g++ -std=c++11 macro_override.h test_macro.cc

@@ -0,0 +1,48 @@
#include <stdlib.h>
#include <memory>
#include <map>
#include <iostream>
size_t ALLOCATED = 0;
std::map<void*, size_t> TRACKER;
void* operator new(size_t size, const char* filename, int line_number) {
std::cout << filename << ":" << line_number << " Allocating " << size
<< " bytes." << std::endl;
void* block = malloc(size);
TRACKER[block] = size;
ALLOCATED += size;
return block;
}
void operator delete(void* block) {
TRACKER.erase(block);
free(block);
}
void *operator new[](size_t size, const char* filename, int line_number) {
std::cout << filename << ":" << line_number << " Allocating " << size
<< " bytes." << std::endl;
void* block = malloc(size);
TRACKER[block] = size;
ALLOCATED += size;
return block;
}
void operator delete[](void* block) {
TRACKER.erase(block);
free(block);
}
void print_memory() {
std::cout << "Total bytes allocated: " << ALLOCATED << std::endl;
for (const auto& el : TRACKER) {
std::cout << el.first << " " << el.second << std::endl;
}
}
#define new new (__FILE__, __LINE__)

@@ -0,0 +1 @@
MALLOCSTATS=1 HEAPPROFILE=$(pwd)/profiling_tcmalloc.hprof LD_PRELOAD=/usr/lib/libtcmalloc.so ./a.out

@@ -0,0 +1,9 @@
#include <iostream>
int main() {
int* array = new int[16];
for (int i = 0; i < 16; i++)
*(array+i) = i;
std::cout << *(array+5);
delete[] array; // release the allocation
return 0;
}

@@ -0,0 +1,16 @@
#include "macro_override.h"
int main () {
int* num = new int;
delete num;
print_memory();
ALLOCATED = 0;
int* nums = new int[16];
delete[] nums; // array allocation must be released with delete[]
print_memory();
return 0;
}

experimental/pro_compiler/.gitignore vendored Normal file
@@ -0,0 +1,2 @@
target

@@ -0,0 +1,9 @@
lazy val root = (project in file(".")).settings(
name := "cypher-compiler",
scalaVersion := "2.12.1",
version := "0.1",
libraryDependencies ++= Seq(
"org.scala-lang.modules" %% "scala-parser-combinators" % "1.0.5"
)
)

@@ -0,0 +1,434 @@
(*
* Copyright (c) 2015-2016 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*)
Cypher = [SP], Statement, [[SP], ';'], [SP] ;
Statement = Query ;
Query = RegularQuery ;
RegularQuery = SingleQuery, { [SP], Union } ;
SingleQuery = Clause, { [SP], Clause } ;
Union = ((U,N,I,O,N), SP, (A,L,L), [SP], SingleQuery)
| ((U,N,I,O,N), [SP], SingleQuery)
;
Clause = Match
| Unwind
| Merge
| Create
| Set
| Delete
| Remove
| With
| Return
;
Match = [(O,P,T,I,O,N,A,L), SP], (M,A,T,C,H), [SP], Pattern, [[SP], Where] ;
Unwind = (U,N,W,I,N,D), [SP], Expression, SP, (A,S), SP, Variable ;
Merge = (M,E,R,G,E), [SP], PatternPart, { SP, MergeAction } ;
MergeAction = ((O,N), SP, (M,A,T,C,H), SP, Set)
| ((O,N), SP, (C,R,E,A,T,E), SP, Set)
;
Create = (C,R,E,A,T,E), [SP], Pattern ;
Set = (S,E,T), [SP], SetItem, { ',', SetItem } ;
SetItem = (PropertyExpression, [SP], '=', [SP], Expression)
| (Variable, [SP], '=', [SP], Expression)
| (Variable, [SP], '+=', [SP], Expression)
| (Variable, [SP], NodeLabels)
;
Delete = [(D,E,T,A,C,H), SP], (D,E,L,E,T,E), [SP], Expression, { [SP], ',', [SP], Expression } ;
Remove = (R,E,M,O,V,E), SP, RemoveItem, { [SP], ',', [SP], RemoveItem } ;
RemoveItem = (Variable, NodeLabels)
| PropertyExpression
;
With = (W,I,T,H), [[SP], (D,I,S,T,I,N,C,T)], SP, ReturnBody, [[SP], Where] ;
Return = (R,E,T,U,R,N), [[SP], (D,I,S,T,I,N,C,T)], SP, ReturnBody ;
ReturnBody = ReturnItems, [SP, Order], [SP, Skip], [SP, Limit] ;
ReturnItems = ('*', { [SP], ',', [SP], ReturnItem })
| (ReturnItem, { [SP], ',', [SP], ReturnItem })
;
ReturnItem = (Expression, SP, (A,S), SP, Variable)
| Expression
;
Order = (O,R,D,E,R), SP, (B,Y), SP, SortItem, { ',', [SP], SortItem } ;
Skip = (S,K,I,P), SP, Expression ;
Limit = (L,I,M,I,T), SP, Expression ;
SortItem = Expression, [[SP], ((A,S,C,E,N,D,I,N,G) | (A,S,C) | (D,E,S,C,E,N,D,I,N,G) | (D,E,S,C))] ;
Where = (W,H,E,R,E), SP, Expression ;
Pattern = PatternPart, { [SP], ',', [SP], PatternPart } ;
PatternPart = (Variable, [SP], '=', [SP], AnonymousPatternPart)
| AnonymousPatternPart
;
AnonymousPatternPart = PatternElement ;
PatternElement = (NodePattern, { [SP], PatternElementChain })
| ('(', PatternElement, ')')
;
NodePattern = '(', [SP], [Variable, [SP]], [NodeLabels, [SP]], [Properties, [SP]], ')' ;
PatternElementChain = RelationshipPattern, [SP], NodePattern ;
RelationshipPattern = (LeftArrowHead, [SP], Dash, [SP], [RelationshipDetail], [SP], Dash, [SP], RightArrowHead)
| (LeftArrowHead, [SP], Dash, [SP], [RelationshipDetail], [SP], Dash)
| (Dash, [SP], [RelationshipDetail], [SP], Dash, [SP], RightArrowHead)
| (Dash, [SP], [RelationshipDetail], [SP], Dash)
;
RelationshipDetail = '[', [SP], [Variable, [SP]], [RelationshipTypes, [SP]], [RangeLiteral], [Properties, [SP]], ']' ;
Properties = MapLiteral
| Parameter
;
RelationshipTypes = ':', [SP], RelTypeName, { [SP], '|', [':'], [SP], RelTypeName } ;
NodeLabels = NodeLabel, { [SP], NodeLabel } ;
NodeLabel = ':', [SP], LabelName ;
RangeLiteral = '*', [SP], [IntegerLiteral, [SP]], ['..', [SP], [IntegerLiteral, [SP]]] ;
LabelName = SymbolicName ;
RelTypeName = SymbolicName ;
Expression = Expression12 ;
Expression12 = Expression11, { SP, (O,R), SP, Expression11 } ;
Expression11 = Expression10, { SP, (X,O,R), SP, Expression10 } ;
Expression10 = Expression9, { SP, (A,N,D), SP, Expression9 } ;
Expression9 = { (N,O,T), [SP] }, Expression8 ;
Expression8 = Expression7, { [SP], PartialComparisonExpression } ;
Expression7 = Expression6, { ([SP], '+', [SP], Expression6) | ([SP], '-', [SP], Expression6) } ;
Expression6 = Expression5, { ([SP], '*', [SP], Expression5) | ([SP], '/', [SP], Expression5) | ([SP], '%', [SP], Expression5) } ;
Expression5 = Expression4, { [SP], '^', [SP], Expression4 } ;
Expression4 = { ('+' | '-'), [SP] }, Expression3 ;
Expression3 = Expression2, { ([SP], '[', Expression, ']') | ([SP], '[', [Expression], '..', [Expression], ']') | ((([SP], '=~') | (SP, (I,N)) | (SP, (S,T,A,R,T,S), SP, (W,I,T,H)) | (SP, (E,N,D,S), SP, (W,I,T,H)) | (SP, (C,O,N,T,A,I,N,S))), [SP], Expression2) | (SP, (I,S), SP, (N,U,L,L)) | (SP, (I,S), SP, (N,O,T), SP, (N,U,L,L)) } ;
Expression2 = Atom, { [SP], (PropertyLookup | NodeLabels) } ;
Atom = Literal
| Parameter
| ((C,O,U,N,T), [SP], '(', [SP], '*', [SP], ')')
| ListComprehension
| PatternComprehension
| ((F,I,L,T,E,R), [SP], '(', [SP], FilterExpression, [SP], ')')
| ((E,X,T,R,A,C,T), [SP], '(', [SP], FilterExpression, [SP], [[SP], '|', Expression], ')')
| ((A,L,L), [SP], '(', [SP], FilterExpression, [SP], ')')
| ((A,N,Y), [SP], '(', [SP], FilterExpression, [SP], ')')
| ((N,O,N,E), [SP], '(', [SP], FilterExpression, [SP], ')')
| ((S,I,N,G,L,E), [SP], '(', [SP], FilterExpression, [SP], ')')
| RelationshipsPattern
| ParenthesizedExpression
| FunctionInvocation
| Variable
;
Literal = NumberLiteral
| StringLiteral
| BooleanLiteral
| (N,U,L,L)
| MapLiteral
| ListLiteral
;
BooleanLiteral = (T,R,U,E)
| (F,A,L,S,E)
;
ListLiteral = '[', [SP], [Expression, [SP], { ',', [SP], Expression, [SP] }], ']' ;
PartialComparisonExpression = ('=', [SP], Expression7)
| ('<>', [SP], Expression7)
| ('!=', [SP], Expression7)
| ('<', [SP], Expression7)
| ('>', [SP], Expression7)
| ('<=', [SP], Expression7)
| ('>=', [SP], Expression7)
;
ParenthesizedExpression = '(', [SP], Expression, [SP], ')' ;
RelationshipsPattern = NodePattern, { [SP], PatternElementChain }- ;
FilterExpression = IdInColl, [[SP], Where] ;
IdInColl = Variable, SP, (I,N), SP, Expression ;
FunctionInvocation = FunctionName, [SP], '(', [SP], [(D,I,S,T,I,N,C,T), [SP]], [Expression, [SP], { ',', [SP], Expression, [SP] }], ')' ;
FunctionName = SymbolicName ;
ListComprehension = '[', [SP], FilterExpression, [[SP], '|', [SP], Expression], [SP], ']' ;
PatternComprehension = '[', [SP], [Variable, [SP], '=', [SP]], RelationshipsPattern, [SP], [(W,H,E,R,E), [SP], Expression, [SP]], '|', [SP], Expression, [SP], ']' ;
PropertyLookup = '.', [SP], (PropertyKeyName) ;
Variable = SymbolicName ;
StringLiteral = ('"', { ANY - ('"' | '\') | EscapedChar }, '"')
| ("'", { ANY - ("'" | '\') | EscapedChar }, "'")
;
EscapedChar = '\', ('\' | "'" | '"' | (B) | (F) | (N) | (R) | (T) | ((U), 4 * HexDigit) | ((U), 8 * HexDigit)) ;
NumberLiteral = DoubleLiteral
| IntegerLiteral
;
MapLiteral = '{', [SP], [PropertyKeyName, [SP], ':', [SP], Expression, [SP], { ',', [SP], PropertyKeyName, [SP], ':', [SP], Expression, [SP] }], '}' ;
Parameter = '$', (SymbolicName | DecimalInteger) ;
PropertyExpression = Atom, { [SP], PropertyLookup }- ;
PropertyKeyName = SymbolicName ;
IntegerLiteral = HexInteger
| OctalInteger
| DecimalInteger
;
HexInteger = '0x', { HexDigit }- ;
DecimalInteger = ZeroDigit
| (NonZeroDigit, { Digit })
;
OctalInteger = ZeroDigit, { OctDigit }- ;
HexLetter = (A)
| (B)
| (C)
| (D)
| (E)
| (F)
;
HexDigit = Digit
| HexLetter
;
Digit = ZeroDigit
| NonZeroDigit
;
NonZeroDigit = NonZeroOctDigit
| '8'
| '9'
;
NonZeroOctDigit = '1'
| '2'
| '3'
| '4'
| '5'
| '6'
| '7'
;
OctDigit = ZeroDigit
| NonZeroOctDigit
;
ZeroDigit = '0' ;
DoubleLiteral = ExponentDecimalReal
| RegularDecimalReal
;
ExponentDecimalReal = ({ Digit }- | ({ Digit }-, '.', { Digit }-) | ('.', { Digit }-)), (E), ['-'], { Digit }- ;
RegularDecimalReal = { Digit }, '.', { Digit }- ;
SymbolicName = UnescapedSymbolicName
| EscapedSymbolicName
;
UnescapedSymbolicName = IdentifierStart, { IdentifierPart } ;
(* Based on the unicode identifier and pattern syntax
* (http://www.unicode.org/reports/tr31/)
* And extended with a few characters.
*)
IdentifierStart = ID_Start
| '_'
| '‿'
| '⁀'
| '⁔'
| '︳'
| '︴'
| '﹍'
| '﹎'
| '﹏'
| '＿'
;
(* Based on the unicode identifier and pattern syntax
* (http://www.unicode.org/reports/tr31/)
* And extended with a few characters.
*)
IdentifierPart = ID_Continue
| Sc
;
(* Any character except "`", enclosed within `backticks`. Backticks are escaped with double backticks. *)
EscapedSymbolicName = { '`', { ANY - ('`') }, '`' }- ;
SP = { whitespace }- ;
whitespace = SPACE
| TAB
| LF
| VT
| FF
| CR
| FS
| GS
| RS
| US
| ''
| ''
| ' '
| ''
| ''
| ''
| ''
| ''
| ''
| ''
| ''
| ''
| ''
| ''
| ''
| ' '
| ' '
| ''
| ''
| Comment
;
Comment = ('/*', { ANY - ('*') | ('*', ANY - ('/')) }, '*/')
| ('//', { ANY - (LF | CR) }, [CR], (LF | EOI))
;
LeftArrowHead = '<'
| '⟨'
| '〈'
| '﹤'
| '＜'
;
RightArrowHead = '>'
| '⟩'
| '〉'
| '﹥'
| '＞'
;
Dash = '-'
| '­'
| '‐'
| '‑'
| '‒'
| '–'
| '—'
| '―'
| '−'
| '﹘'
| '﹣'
| '－'
;
A = 'A' | 'a' ;
B = 'B' | 'b' ;
C = 'C' | 'c' ;
D = 'D' | 'd' ;
E = 'E' | 'e' ;
F = 'F' | 'f' ;
G = 'G' | 'g' ;
H = 'H' | 'h' ;
I = 'I' | 'i' ;
K = 'K' | 'k' ;
L = 'L' | 'l' ;
M = 'M' | 'm' ;
N = 'N' | 'n' ;
O = 'O' | 'o' ;
P = 'P' | 'p' ;
R = 'R' | 'r' ;
S = 'S' | 's' ;
T = 'T' | 't' ;
U = 'U' | 'u' ;
V = 'V' | 'v' ;
W = 'W' | 'w' ;
X = 'X' | 'x' ;
Y = 'Y' | 'y' ;

experimental/pro_compiler/sbt Executable file
@@ -0,0 +1,33 @@
#!/bin/bash
SCRIPT_PATH=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )
SBT_OPTS="-Xms768M -Xmx3072M -Xss1M -XX:+CMSClassUnloadingEnabled -XX:MaxPermSize=256M"
if [ "$JENKINS_NIGHTLY_BUILD" == "true" ]; then
SBT_ARGS="-Dsbt.log.noformat=true"
fi
if hash cygpath.exe 2>/dev/null; then
echo "Using cygpath to convert path to SBT."
SBT_CYG_JAR_PATH=`realpath "${SCRIPT_PATH}/sbt-launch.jar"`
SBT_JAR_PATH=`cygpath.exe -w "${SBT_CYG_JAR_PATH}"`
echo "Using Windows path: ${SBT_JAR_PATH}"
SBT_ARGS="-Djline.terminal=jline.UnixTerminal -Dsbt.cygwin=true ${SBT_ARGS}"
else
echo "No cygpath, apparently not using Cygwin."
SBT_JAR_PATH="${SCRIPT_PATH}/sbt-launch.jar"
fi
SBT_CMD="java ${SBT_OPTS} ${SBT_ARGS} -jar \"${SBT_JAR_PATH}\""
if hash cygpath.exe 2>/dev/null; then
stty -icanon min 1 -echo > /dev/null 2>&1
fi
echo "Running: ${SBT_CMD}"
echo "Arguments: $@"
eval ${SBT_CMD} $@
if hash cygpath.exe 2>/dev/null; then
stty icanon echo > /dev/null 2>&1
fi

Binary file not shown.

@@ -0,0 +1,626 @@
package org.cypher
import scala.collection._
import scala.reflect.ClassTag
import scala.util.parsing.combinator._
// Parsing
trait CharacterStream {
def asString: String
}
object CharacterStream {
case class Default(asString: String) extends CharacterStream
}
// Query trees
trait Indexed {
private var rawIndex: Int = -1
def index: Int = rawIndex
def index_=(i: Int) = rawIndex = i
}
sealed abstract class Tree extends Indexed
case class Query(
val clauses: Seq[Clause]
) extends Tree
sealed abstract class Clause extends Tree
case class Match(
optional: Boolean,
patterns: Seq[Pattern],
where: Where
) extends Clause
case class Where(
expr: Expr
) extends Tree
sealed abstract class Expr extends Tree
case class Literal(value: Value) extends Expr
case class Ident(name: Name) extends Expr
case class Equals(left: Expr, right: Expr) extends Expr
case class PatternExpr(pattern: Pattern) extends Expr
sealed trait Value
case class Bool(x: Boolean) extends Value
case class Return(expr: Expr) extends Clause
case class Pattern(name: Name, elems: Seq[PatternElement]) extends Tree
sealed abstract class PatternElement extends Tree
case class NodePatternElement(
name: Name,
labels: Seq[Name],
properties: Properties
) extends PatternElement
case class EdgePatternElement(
left: Boolean,
right: Boolean,
name: Name,
labels: Seq[Name],
range: RangeSpec,
properties: Properties
) extends PatternElement
case class RangeSpec(
left: Option[Int],
right: Option[Int]
) extends Tree
case class Properties(contents: Map[Name, Expr]) extends Tree
case class Name private[cypher] (private val raw: String) {
def asString = raw
override def toString = s"Name($raw)"
}
// Semantic analysis
class Symbol(val name: Name, val tpe: Type) {
override def toString = s"<${name.asString}: $tpe>"
}
trait Type
case class NodeType() extends Type
case class EdgeType() extends Type
class Table[T <: AnyRef: ClassTag] {
private var array = new Array[T](100)
private def ensureSize(sz: Int) {
if (array.length < sz) {
val narray = new Array[T](sz)
System.arraycopy(array, 0, narray, 0, array.length)
array = narray
}
}
def apply(idx: Indexed): T = {
array(idx.index)
}
def update(idx: Indexed, v: T) = {
ensureSize(idx.index + 1)
array(idx.index) = v
}
}
class SymbolTable() {
private val rawTable = mutable.Map[Name, Symbol]()
def apply(name: Name): Symbol = rawTable(name)
def update(name: Name, sym: Symbol) = {
assert(!rawTable.contains(name))
rawTable(name) = sym
}
def getOrCreate(name: Name): Symbol = rawTable.get(name) match {
case Some(sym) =>
sym
case None =>
val sym = new Symbol(name, null)
rawTable(name) = sym
sym
}
}
case class TypecheckedTree(
tree: Tree,
symbols: SymbolTable,
types: Table[Type]
)
class Namer {
private var count = 0
def freshName(): Name = {
count += 1
return Name(s"anon-$count")
}
}
// Data model
trait Node
trait Edge
trait Index
trait Database {
def nodeIndices(label: Name): Seq[Index]
def edgeIndices(label: Name): Seq[Index]
}
object Database {
case class Default() extends Database {
def nodeIndices(label: Name) = Nil
def edgeIndices(label: Name) = Nil
}
}
// Logical plan
sealed trait LogicalPlan {
def outputs: Seq[Symbol]
def children: Seq[LogicalPlan]
def operatorName = s"${getClass.getSimpleName}"
def prettySelf: String = s"${operatorName}"
def prettyVars: String = outputs.map(_.name.asString).mkString(", ")
def pretty: String = {
def titleWidth(plan: LogicalPlan): Int = {
var w = 1 + plan.operatorName.length
for (child <- plan.children) {
val cw = titleWidth(child) + (plan.children.length - 1) * 2
if (w < cw) w = cw
}
w
}
val w = math.max(12, titleWidth(this))
var s = ""
s += " Operator " + " " * (w - 7) + "| Variables \n"
s += "-" * (w + 3) + "+" + "-" * 30 + "\n"
def print(indent: Int, plan: LogicalPlan): Unit = {
val leftspace = "| " * indent
val title = plan.prettySelf
val rightspace = " " * (w - title.length - 2 * indent)
val rightpadspace = " " * (w - 2 * indent)
val vars = plan.prettyVars
s += s""" $leftspace+$title$rightspace | $vars\n"""
if (plan.children.nonEmpty) {
s += s""" $leftspace|$rightpadspace | \n"""
}
for ((child, idx) <- plan.children.zipWithIndex.reverse) {
print(idx, child)
}
}
print(0, this)
s
}
}
object LogicalPlan {
sealed trait Source extends LogicalPlan {
}
case class ScanAll(outputs: Seq[Symbol], patternElem: NodePatternElement)
extends Source {
def children = Nil
// def evaluate(db: Database): Stream = {
// db.iterator().filter { n =>
// patternElem match {
// case NodePatternElement(name, labels, properties) =>
// if (labels.subset(n.labels) && properties.subset(n.properties)) {
// Some(Seq(n))
// } else {
// None
// }
// }
// }
// }
}
case class SeekByNodeLabelIndex(
index: Index, name: Name, outputs: Seq[Symbol], patternElem: NodePatternElement
) extends Source {
def children = Nil
}
case class ExpandAll(input: LogicalPlan, edgePattern: EdgePatternElement, nodePattern: NodePatternElement) extends LogicalPlan {
def outputs = ???
def children = ???
}
case class ExpandInto(input: LogicalPlan) extends LogicalPlan {
def outputs = ???
def children = ???
}
case class Filter(input: LogicalPlan, expr: Expr)
extends LogicalPlan {
def outputs = input.outputs
def children = Seq(input)
}
sealed trait Sink extends LogicalPlan
case class Produce(input: LogicalPlan, outputs: Seq[Symbol], expr: Expr)
extends Sink {
def children = Seq(input)
}
}
trait Emitter {
def emit[T](symbol: Symbol, value: T): Unit
}
// Physical plan
case class PhysicalPlan(val logicalPlan: LogicalPlan) {
def execute(db: Database): Stream = {
println(logicalPlan.pretty)
???
}
}
case class Cost()
trait Stream
trait Target
// Phases
trait Phase[In, Out] {
def apply(input: In): Out
}
case class Parser(ctx: Context) extends Phase[CharacterStream, Tree] {
object CypherParser extends RegexParsers {
def query: Parser[Query] = rep(clause) ^^ {
case clauses => Query(clauses)
}
def clause: Parser[Clause] = `match` | `return`
def `match`: Parser[Match] = opt("optional") ~ "match" ~ patterns ~ opt(where) ^^ {
case optional ~ _ ~ p ~ Some(w) =>
Match(optional.nonEmpty, p, w)
case optional ~ _ ~ p ~ None =>
Match(optional.nonEmpty, p, Where(Literal(Bool(true))))
}
def patterns: Parser[Seq[Pattern]] = nodePattern ~ rep(edgeAndNodePattern) ^^ {
case node ~ edgeNodes =>
val ps = node +: edgeNodes.map({ case (edge, node) => Seq(edge, node) }).flatten
Seq(Pattern(ctx.namer.freshName(), ps))
}
def nodePattern: Parser[NodePatternElement] = "(" ~ ident ~ ")" ^^ {
case _ ~ ident ~ _ => NodePatternElement(ident.name, Nil, Properties(Map()))
}
def edgeAndNodePattern: Parser[(EdgePatternElement, NodePatternElement)] =
edgePattern ~ nodePattern ^^ {
case edge ~ node => (edge, node)
}
def edgePattern: Parser[EdgePatternElement] =
opt("<") ~ "--" ~ opt(">") ^^ {
case left ~ _ ~ right => EdgePatternElement(
left.nonEmpty,
right.nonEmpty,
ctx.namer.freshName(),
Nil,
RangeSpec(None, None),
Properties(Map())
)
}
def where: Parser[Where] = "where" ~ expr ^^ {
case _ ~ expr => Where(expr)
}
def `return`: Parser[Return] = "return" ~ expr ^^ {
case _ ~ expr => Return(expr)
}
// `equals` is tried first and built from non-recursive operands,
// otherwise the grammar would be left-recursive.
def expr: Parser[Expr] = equals | simpleExpr
def simpleExpr: Parser[Expr] = literal | ident
def equals: Parser[Expr] = simpleExpr ~ "=" ~ simpleExpr ^^ {
case left ~ _ ~ right => Equals(left, right)
}
def literal: Parser[Literal] = boolean ^^ {
case x => x
}
def boolean: Parser[Literal] = ("true" | "false") ^^ {
case "true" => Literal(Bool(true))
case "false" => Literal(Bool(false))
}
def ident: Parser[Ident] = "[a-z]+".r ^^ {
case s => Ident(Name(s))
}
}
def apply(tokens: CharacterStream): Tree = {
CypherParser.parseAll(CypherParser.query, tokens.asString) match {
case CypherParser.Success(tree, _) => tree
case failure: CypherParser.NoSuccess => sys.error(failure.msg)
}
}
}
case class Typechecker() extends Phase[Tree, TypecheckedTree] {
private class Instance(val symbols: SymbolTable, val types: Table[Type]) {
def traverse(tree: Tree): Unit = tree match {
case Query(clauses) =>
for (clause <- clauses) traverse(clause)
case Match(opt, patterns, where) =>
for (pattern <- patterns) traverse(pattern)
// Ignore where for now.
case Pattern(name, elems) =>
for (elem <- elems) traverse(elem)
case NodePatternElement(name, _, _) =>
symbols.getOrCreate(name)
case _ =>
// Ignore for now.
}
def typecheck(tree: Tree) = {
traverse(tree)
TypecheckedTree(tree, symbols, types)
}
}
def apply(tree: Tree): TypecheckedTree = {
val symbols = new SymbolTable()
val types = new Table[Type]
val instance = new Instance(symbols, types)
instance.typecheck(tree)
}
}
case class LogicalPlanner(val ctx: Context)
extends Phase[TypecheckedTree, LogicalPlan] {
def apply(tree: TypecheckedTree): LogicalPlan = {
ctx.config.logicalPlanGenerator.generate(tree, ctx).next()
}
}
case class PhysicalPlanner(val ctx: Context)
extends Phase[LogicalPlan, PhysicalPlan] {
def apply(plan: LogicalPlan): PhysicalPlan = {
ctx.config.physicalPlanGenerator.generate(plan).next()
}
}
trait LogicalPlanGenerator {
def generate(typedTree: TypecheckedTree, ctx: Context): Iterator[LogicalPlan]
}
object LogicalPlanGenerator {
case class Default() extends LogicalPlanGenerator {
class Instance(val typedTree: TypecheckedTree, val ctx: Context) {
private def findOutputs(pat: NodePatternElement): Seq[Symbol] = {
val sym = typedTree.symbols(pat.name)
Seq(sym)
}
private def findOutputs(expr: Expr): Seq[Symbol] = {
Seq()
}
private def genSource(pat: NodePatternElement): LogicalPlan = {
pat.labels.find(label => ctx.database.nodeIndices(label).nonEmpty) match {
case Some(label) =>
val index = ctx.database.nodeIndices(label).head
val outputs = findOutputs(pat)
LogicalPlan.SeekByNodeLabelIndex(index, label, outputs, pat)
case None =>
val outputs = findOutputs(pat)
LogicalPlan.ScanAll(outputs, pat)
}
}
private def genPattern(elems: Seq[PatternElement]): LogicalPlan = {
assert(elems.size == 1)
val source = genSource(elems.head.asInstanceOf[NodePatternElement])
source
}
private def genSourceClause(clause: Clause): LogicalPlan = clause match {
case Match(opt, patterns, where) =>
// Create source.
assert(patterns.length == 1)
val Pattern(_, elements) = patterns.head
val plan = genPattern(elements)
// Add a filter.
new LogicalPlan.Filter(plan, where.expr)
case tree =>
sys.error(s"Unsupported source clause: $tree.")
}
def genReturn(input: LogicalPlan, ret: Return): LogicalPlan.Produce = {
val outputs = findOutputs(ret.expr)
LogicalPlan.Produce(input, outputs, ret.expr)
}
def genQueryPlan(tree: Tree): LogicalPlan = tree match {
case Query(clauses) =>
var plan = genSourceClause(clauses.head)
for (clause <- clauses.tail) clause match {
case ret @ Return(_) =>
plan = genReturn(plan, ret)
case clause =>
sys.error(s"Unsupported clause: $tree.")
}
plan
case tree =>
sys.error(s"Not a valid query: $tree.")
}
}
def generate(typedTree: TypecheckedTree, ctx: Context) = {
val instance = new Instance(typedTree, ctx)
val plan = instance.genQueryPlan(typedTree.tree)
Iterator(plan)
}
}
}
trait PhysicalPlanGenerator {
def generate(tree: LogicalPlan): Iterator[PhysicalPlan]
}
object PhysicalPlanGenerator {
case class Default() extends PhysicalPlanGenerator {
def generate(plan: LogicalPlan) = {
Iterator(PhysicalPlan(plan))
}
}
}
case class Configuration(
logicalPlanGenerator: LogicalPlanGenerator,
physicalPlanGenerator: PhysicalPlanGenerator,
estimator: PhysicalPlan => Cost
)
case class Context(
config: Configuration,
namer: Namer,
database: Database
)
object Configuration {
val defaultLogicalPlanGenerator = LogicalPlanGenerator.Default()
val defaultPhysicalPlanGenerator = PhysicalPlanGenerator.Default()
val defaultEstimator = (plan: PhysicalPlan) => {
Cost()
}
def default() = Configuration(
defaultLogicalPlanGenerator,
defaultPhysicalPlanGenerator,
defaultEstimator
)
}
trait Interpreter {
def interpret(query: CharacterStream): Stream
}
class DefaultInterpreter(ctx: Context) extends Interpreter {
def interpret(query: CharacterStream): Stream = {
val tree = Parser(ctx).apply(query)
val typedTree = Typechecker().apply(tree)
val logicalPlan = LogicalPlanner(ctx).apply(typedTree)
val physicalPlan = PhysicalPlanner(ctx).apply(logicalPlan)
physicalPlan.execute(ctx.database)
}
}
trait Compiler {
def compile(query: CharacterStream): Target
}
object Main {
def main(args: Array[String]) {
val db = Database.Default()
val config = Configuration.default()
val namer = new Namer
val ctx = Context(config, namer, db)
val interpreter = new DefaultInterpreter(ctx)
val query = CharacterStream.Default("""
match (a)
return a
""")
interpreter.interpret(query)
}
}