// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // License, and you may not use this file except in compliance with the Business Source License. // // As of the Change Date specified in that file, in accordance with // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. #include #include #include #include #include #include #include #include #include "io/address.hpp" #include "io/rsm/raft.hpp" #include "io/simulator/simulator.hpp" #include "io/simulator/simulator_transport.hpp" using memgraph::io::Address; using memgraph::io::Duration; using memgraph::io::Io; using memgraph::io::ResponseEnvelope; using memgraph::io::ResponseFuture; using memgraph::io::ResponseResult; using memgraph::io::Time; using memgraph::io::rsm::Raft; using memgraph::io::rsm::ReadRequest; using memgraph::io::rsm::ReadResponse; using memgraph::io::rsm::WriteRequest; using memgraph::io::rsm::WriteResponse; using memgraph::io::simulator::Simulator; using memgraph::io::simulator::SimulatorConfig; using memgraph::io::simulator::SimulatorStats; using memgraph::io::simulator::SimulatorTransport; struct CasRequest { int key; std::optional old_value; std::optional new_value; }; struct CasResponse { bool cas_success; std::optional last_value; }; struct GetRequest { int key; }; struct GetResponse { std::optional value; }; class TestState { std::map state_; public: GetResponse Read(GetRequest request) { GetResponse ret; if (state_.contains(request.key)) { ret.value = state_[request.key]; } return ret; } CasResponse Apply(CasRequest request) { CasResponse ret; // Key exist if (state_.contains(request.key)) { auto &val = state_[request.key]; /* * Delete */ if (!request.new_value) { ret.last_value = val; ret.cas_success = true; state_.erase(state_.find(request.key)); } /* * Update */ // Does old_value match? if (request.old_value == val) { ret.last_value = val; ret.cas_success = true; val = request.new_value.value(); } else { ret.last_value = val; ret.cas_success = false; } } /* * Create */ else { ret.last_value = std::nullopt; ret.cas_success = true; state_.emplace(request.key, std::move(request.new_value).value()); } return ret; } }; template void RunRaft(Raft server) { server.Run(); } void RunSimulation() { SimulatorConfig config{ .drop_percent = 5, .perform_timeouts = true, .scramble_messages = true, .rng_seed = 0, .start_time = Time::min() + std::chrono::microseconds{256 * 1024}, .abort_time = Time::min() + std::chrono::microseconds{8 * 1024 * 128}, }; auto simulator = Simulator(config); auto cli_addr = Address::TestAddress(1); auto srv_addr_1 = Address::TestAddress(2); auto srv_addr_2 = Address::TestAddress(3); auto srv_addr_3 = Address::TestAddress(4); Io cli_io = simulator.Register(cli_addr); Io srv_io_1 = simulator.Register(srv_addr_1); Io srv_io_2 = simulator.Register(srv_addr_2); Io srv_io_3 = simulator.Register(srv_addr_3); std::vector
srv_1_peers = {srv_addr_2, srv_addr_3}; std::vector
srv_2_peers = {srv_addr_1, srv_addr_3}; std::vector
srv_3_peers = {srv_addr_1, srv_addr_2}; // TODO(tyler / gabor) supply default TestState to Raft constructor using RaftClass = Raft; RaftClass srv_1{std::move(srv_io_1), srv_1_peers, TestState{}}; RaftClass srv_2{std::move(srv_io_2), srv_2_peers, TestState{}}; RaftClass srv_3{std::move(srv_io_3), srv_3_peers, TestState{}}; auto srv_thread_1 = std::jthread(RunRaft, std::move(srv_1)); simulator.IncrementServerCountAndWaitForQuiescentState(srv_addr_1); auto srv_thread_2 = std::jthread(RunRaft, std::move(srv_2)); simulator.IncrementServerCountAndWaitForQuiescentState(srv_addr_2); auto srv_thread_3 = std::jthread(RunRaft, std::move(srv_3)); simulator.IncrementServerCountAndWaitForQuiescentState(srv_addr_3); spdlog::info("beginning test after servers have become quiescent"); std::mt19937 cli_rng_{0}; Address server_addrs[]{srv_addr_1, srv_addr_2, srv_addr_3}; Address leader = server_addrs[0]; const int key = 0; std::optional last_known_value = 0; bool success = false; for (int i = 0; !success; i++) { // send request CasRequest cas_req; cas_req.key = key; cas_req.old_value = last_known_value; cas_req.new_value = i; WriteRequest cli_req; cli_req.operation = cas_req; spdlog::info("client sending CasRequest to Leader {} ", leader.last_known_port); ResponseFuture> cas_response_future = cli_io.Request, WriteResponse>(leader, cli_req); // receive cas_response ResponseResult> cas_response_result = std::move(cas_response_future).Wait(); if (cas_response_result.HasError()) { spdlog::info("client timed out while trying to communicate with assumed Leader server {}", leader.last_known_port); continue; } ResponseEnvelope> cas_response_envelope = cas_response_result.GetValue(); WriteResponse write_cas_response = cas_response_envelope.message; if (write_cas_response.retry_leader) { MG_ASSERT(!write_cas_response.success, "retry_leader should never be set for successful responses"); leader = write_cas_response.retry_leader.value(); spdlog::info("client redirected to leader server {}", leader.last_known_port); } else if (!write_cas_response.success) { std::uniform_int_distribution addr_distrib(0, 2); size_t addr_index = addr_distrib(cli_rng_); leader = server_addrs[addr_index]; spdlog::info("client NOT redirected to leader server, trying a random one at index {} with port {}", addr_index, leader.last_known_port); continue; } CasResponse cas_response = write_cas_response.write_return; bool cas_succeeded = cas_response.cas_success; spdlog::info("Client received CasResponse! success: {} last_known_value {}", cas_succeeded, (int)*last_known_value); if (cas_succeeded) { last_known_value = i; } else { last_known_value = cas_response.last_value; continue; } GetRequest get_req; get_req.key = key; ReadRequest read_req; read_req.operation = get_req; spdlog::info("client sending GetRequest to Leader {}", leader.last_known_port); ResponseFuture> get_response_future = cli_io.Request, ReadResponse>(leader, read_req); // receive response ResponseResult> get_response_result = std::move(get_response_future).Wait(); if (get_response_result.HasError()) { spdlog::info("client timed out while trying to communicate with Leader server {}", leader.last_known_port); continue; } ResponseEnvelope> get_response_envelope = get_response_result.GetValue(); ReadResponse read_get_response = get_response_envelope.message; if (!read_get_response.success) { // sent to a non-leader continue; } if (read_get_response.retry_leader) { MG_ASSERT(!read_get_response.success, "retry_leader should never be set for successful responses"); leader = read_get_response.retry_leader.value(); spdlog::info("client redirected to Leader server {}", leader.last_known_port); } else if (!read_get_response.success) { std::uniform_int_distribution addr_distrib(0, 2); size_t addr_index = addr_distrib(cli_rng_); leader = server_addrs[addr_index]; spdlog::info("client NOT redirected to leader server, trying a random one at index {} with port {}", addr_index, leader.last_known_port); } GetResponse get_response = read_get_response.read_return; MG_ASSERT(get_response.value == i); spdlog::info("client successfully cas'd a value and read it back! value: {}", i); success = true; } MG_ASSERT(success); simulator.ShutDown(); SimulatorStats stats = simulator.Stats(); spdlog::info("total messages: {}", stats.total_messages); spdlog::info("dropped messages: {}", stats.dropped_messages); spdlog::info("timed out requests: {}", stats.timed_out_requests); spdlog::info("total requests: {}", stats.total_requests); spdlog::info("total responses: {}", stats.total_responses); spdlog::info("simulator ticks: {}", stats.simulator_ticks); spdlog::info("========================== SUCCESS :) =========================="); /* this is implicit in jthread's dtor srv_thread_1.join(); srv_thread_2.join(); srv_thread_3.join(); */ } int main() { int n_tests = 50; for (int i = 0; i < n_tests; i++) { spdlog::info("========================== NEW SIMULATION {} ==========================", i); spdlog::info("\tTime\t\tTerm\tPort\tRole\t\tMessage\n"); RunSimulation(); } spdlog::info("passed {} tests!", n_tests); return 0; }