diff --git a/tests/manual/card_fraud_local.cpp b/tests/manual/card_fraud_local.cpp new file mode 100644 index 000000000..a83931d41 --- /dev/null +++ b/tests/manual/card_fraud_local.cpp @@ -0,0 +1,84 @@ +#include +#include +#include +#include + +#include "gflags/gflags.h" + +#include "distributed_common.hpp" + +DEFINE_int32(num_tx_creators, 3, "Number of threads creating transactions"); +DEFINE_int32(tx_per_thread, 1000, "Number of transactions each thread creates"); + +int main(int argc, char *argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + Cluster cluster(5); + + // static thread_local std::mt19937 rand_dev{std::random_device{}()}; + // static thread_local std::uniform_int_distribution<> int_dist; + + // auto rint = [&rand_dev, &int_dist](int upper) { + // return int_dist(rand_dev) % upper; + // }; + + cluster.Execute("CREATE INDEX ON :Card(id)"); + cluster.Execute("CREATE INDEX ON :Transaction(id)"); + cluster.Execute("CREATE INDEX ON :Pos(id)"); + + int kCardCount = 20000; + int kPosCount = 20000; + + cluster.Execute("UNWIND range(0, $card_count) AS id CREATE (:Card {id:id})", + {{"card_count", kCardCount - 1}}); + cluster.Execute("UNWIND range(0, $pos_count) AS id CREATE (:Pos {id:id})", + {{"pos_count", kPosCount - 1}}); + + CheckResults(cluster.Execute("MATCH (:Pos) RETURN count(1)"), {{kPosCount}}, + "Failed to create POS"); + CheckResults(cluster.Execute("MATCH (:Card) RETURN count(1)"), {{kCardCount}}, + "Failed to create Cards"); + + std::atomic tx_counter{0}; + auto create_tx = [&cluster, kCardCount, kPosCount, &tx_counter](int count) { + std::mt19937 rand_dev{std::random_device{}()}; + std::uniform_int_distribution<> int_dist; + + auto rint = [&rand_dev, &int_dist](int upper) { + return int_dist(rand_dev) % upper; + }; + + for (int i = 0; i < count; ++i) { + try { + auto res = cluster.Execute( + "MATCH (p:Pos {id: $pos}), (c:Card {id: $card}) " + "CREATE (p)<-[:At]-(:Transaction {id : $tx})-[:Using]->(c) " + "RETURN count(1)", + {{"pos", rint(kPosCount)}, + {"card", rint(kCardCount)}, + {"tx", tx_counter++}}); + CheckResults(res, {{1}}, "Transaction creation"); + } catch (LockTimeoutException &) { + --i; + } catch (mvcc::SerializationError &) { + --i; + } + if (i > 0 && i % 200 == 0) + LOG(INFO) << "Created " << i << " transacitons"; + } + }; + + LOG(INFO) << "Creating " << FLAGS_num_tx_creators * FLAGS_tx_per_thread + << " transactions in " << FLAGS_num_tx_creators << " threads"; + std::vector tx_creators; + for (int i = 0; i < FLAGS_num_tx_creators; ++i) + tx_creators.emplace_back(create_tx, FLAGS_tx_per_thread); + for (auto &t : tx_creators) t.join(); + + CheckResults(cluster.Execute("MATCH (:Transaction) RETURN count(1)"), + {{FLAGS_num_tx_creators * FLAGS_tx_per_thread}}, + "Failed to create Transactions"); + + LOG(INFO) << "Test terminated successfully"; + return 0; +} diff --git a/tests/manual/distributed_common.hpp b/tests/manual/distributed_common.hpp new file mode 100644 index 000000000..88b467b80 --- /dev/null +++ b/tests/manual/distributed_common.hpp @@ -0,0 +1,96 @@ +#pragma once + +#include +#include + +#include "communication/result_stream_faker.hpp" +#include "database/graph_db_accessor.hpp" +#include "query/interpreter.hpp" +#include "query/typed_value.hpp" + +class WorkerInThread { + public: + explicit WorkerInThread(database::Config config) : worker_(config) { + thread_ = std::thread([this, config] { worker_.WaitForShutdown(); }); + } + + ~WorkerInThread() { + if (thread_.joinable()) thread_.join(); + } + + database::Worker worker_; + std::thread thread_; +}; + +class Cluster { + const std::chrono::microseconds kInitTime{200}; + const std::string kLocal = "127.0.0.1"; + + public: + Cluster(int worker_count) { + database::Config masterconfig; + masterconfig.master_endpoint = {kLocal, 0}; + master_ = std::make_unique(masterconfig); + std::this_thread::sleep_for(kInitTime); + + auto worker_config = [this](int worker_id) { + database::Config config; + config.worker_id = worker_id; + config.master_endpoint = master_->endpoint(); + config.worker_endpoint = {kLocal, 0}; + return config; + }; + + for (int i = 0; i < worker_count; ++i) { + workers_.emplace_back( + std::make_unique(worker_config(i + 1))); + std::this_thread::sleep_for(kInitTime); + } + } + + void Stop() { + master_ = nullptr; + workers_.clear(); + } + + ~Cluster() { + if (master_) Stop(); + } + + auto Execute(const std::string &query, + std::map params = {}) { + database::GraphDbAccessor dba(*master_); + ResultStreamFaker result; + interpreter_(query, dba, params, false).PullAll(result); + dba.Commit(); + return result.GetResults(); + }; + + private: + std::unique_ptr master_; + std::vector> workers_; + query::Interpreter interpreter_; +}; + +void CheckResults( + const std::vector> &results, + const std::vector> &expected_rows, + const std::string &msg) { + query::TypedValue::BoolEqual equality; + CHECK(results.size() == expected_rows.size()) + << msg << " (expected " << expected_rows.size() << " rows " + << ", got " << results.size() << ")"; + for (size_t row_id = 0; row_id < results.size(); ++row_id) { + auto &result = results[row_id]; + auto &expected = expected_rows[row_id]; + CHECK(result.size() == expected.size()) + << msg << " (expected " << expected.size() << " elements in row " + << row_id << ", got " << result.size() << ")"; + for (size_t col_id = 0; col_id < result.size(); ++col_id) { + CHECK(equality(result[col_id], expected[col_id])) + << msg << " (expected value '" << expected[col_id] << "' got '" + << result[col_id] << "' in row " << row_id << " col " << col_id + << ")"; + } + } +}