// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // License, and you may not use this file except in compliance with the Business Source License. // // As of the Change Date specified in that file, in accordance with // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. #include <thread> #include "gmock/gmock.h" #include "gtest/gtest.h" #include "rpc/client.hpp" #include "rpc/client_pool.hpp" #include "rpc/messages.hpp" #include "rpc/server.hpp" #include "utils/on_scope_exit.hpp" #include "utils/timer.hpp" #include "rpc_messages.hpp" using namespace memgraph::rpc; using namespace std::literals::chrono_literals; namespace memgraph::slk { void Save(const SumReq &sum, Builder *builder) { Save(sum.x, builder); Save(sum.y, builder); } void Load(SumReq *sum, Reader *reader) { Load(&sum->x, reader); Load(&sum->y, reader); } void Save(const SumRes &res, Builder *builder) { Save(res.sum, builder); } void Load(SumRes *res, Reader *reader) { Load(&res->sum, reader); } void Save(const EchoMessage &echo, Builder *builder) { Save(echo.data, builder); } void Load(EchoMessage *echo, Reader *reader) { Load(&echo->data, reader); } } // namespace memgraph::slk void SumReq::Load(SumReq *obj, memgraph::slk::Reader *reader) { memgraph::slk::Load(obj, reader); } void SumReq::Save(const SumReq &obj, memgraph::slk::Builder *builder) { memgraph::slk::Save(obj, builder); } void SumRes::Load(SumRes *obj, memgraph::slk::Reader *reader) { memgraph::slk::Load(obj, reader); } void SumRes::Save(const SumRes &obj, memgraph::slk::Builder *builder) { memgraph::slk::Save(obj, builder); } void EchoMessage::Load(EchoMessage *obj, memgraph::slk::Reader *reader) { memgraph::slk::Load(obj, reader); } void EchoMessage::Save(const EchoMessage &obj, memgraph::slk::Builder *builder) { memgraph::slk::Save(obj, builder); } TEST(Rpc, Call) { memgraph::communication::ServerContext server_context; Server server({"127.0.0.1", 0}, &server_context); auto const on_exit = memgraph::utils::OnScopeExit{[&] { server.Shutdown(); server.AwaitShutdown(); }}; server.Register<Sum>([](auto *req_reader, auto *res_builder) { SumReq req; memgraph::slk::Load(&req, req_reader); SumRes res(req.x + req.y); memgraph::slk::Save(res, res_builder); }); ASSERT_TRUE(server.Start()); std::this_thread::sleep_for(100ms); memgraph::communication::ClientContext client_context; Client client(server.endpoint(), &client_context); auto sum = client.Call<Sum>(10, 20); EXPECT_EQ(sum.sum, 30); } TEST(Rpc, Abort) { memgraph::communication::ServerContext server_context; Server server({"127.0.0.1", 0}, &server_context); server.Register<Sum>([](auto *req_reader, auto *res_builder) { SumReq req; memgraph::slk::Load(&req, req_reader); std::this_thread::sleep_for(500ms); SumRes res(req.x + req.y); memgraph::slk::Save(res, res_builder); }); ASSERT_TRUE(server.Start()); std::this_thread::sleep_for(100ms); memgraph::communication::ClientContext client_context; Client client(server.endpoint(), &client_context); std::thread thread([&client]() { std::this_thread::sleep_for(100ms); spdlog::info("Shutting down the connection!"); client.Abort(); }); memgraph::utils::Timer timer; EXPECT_THROW(client.Call<Sum>(10, 20), RpcFailedException); EXPECT_LT(timer.Elapsed(), 200ms); thread.join(); server.Shutdown(); server.AwaitShutdown(); } TEST(Rpc, ClientPool) { memgraph::communication::ServerContext server_context; Server server({"127.0.0.1", 0}, &server_context); server.Register<Sum>([](const auto &req_reader, auto *res_builder) { SumReq req; Load(&req, req_reader); std::this_thread::sleep_for(100ms); SumRes res(req.x + req.y); Save(res, res_builder); }); ASSERT_TRUE(server.Start()); std::this_thread::sleep_for(100ms); memgraph::communication::ClientContext client_context; Client client(server.endpoint(), &client_context); // These calls should take more than 400ms because we're using a regular // client auto get_sum_client = [&client](int x, int y) { auto sum = client.Call<Sum>(x, y); EXPECT_EQ(sum.sum, x + y); }; memgraph::utils::Timer t1; std::vector<std::thread> threads; for (int i = 0; i < 4; ++i) { threads.emplace_back(get_sum_client, 2 * i, 2 * i + 1); } for (int i = 0; i < 4; ++i) { threads[i].join(); } threads.clear(); EXPECT_GE(t1.Elapsed(), 400ms); memgraph::communication::ClientContext pool_context; ClientPool pool(server.endpoint(), &pool_context); // These calls shouldn't take much more that 100ms because they execute in // parallel auto get_sum = [&pool](int x, int y) { auto sum = pool.Call<Sum>(x, y); EXPECT_EQ(sum.sum, x + y); }; memgraph::utils::Timer t2; for (int i = 0; i < 4; ++i) { threads.emplace_back(get_sum, 2 * i, 2 * i + 1); } for (int i = 0; i < 4; ++i) { threads[i].join(); } EXPECT_LE(t2.Elapsed(), 200ms); server.Shutdown(); server.AwaitShutdown(); } TEST(Rpc, LargeMessage) { memgraph::communication::ServerContext server_context; Server server({"127.0.0.1", 0}, &server_context); server.Register<Echo>([](auto *req_reader, auto *res_builder) { EchoMessage res; memgraph::slk::Load(&res, req_reader); memgraph::slk::Save(res, res_builder); }); ASSERT_TRUE(server.Start()); std::this_thread::sleep_for(100ms); std::string testdata(100000, 'a'); memgraph::communication::ClientContext client_context; Client client(server.endpoint(), &client_context); auto echo = client.Call<Echo>(testdata); EXPECT_EQ(echo.data, testdata); server.Shutdown(); server.AwaitShutdown(); } TEST(Rpc, JumboMessage) { memgraph::communication::ServerContext server_context; Server server({"127.0.0.1", 0}, &server_context); server.Register<Echo>([](auto *req_reader, auto *res_builder) { EchoMessage res; memgraph::slk::Load(&res, req_reader); memgraph::slk::Save(res, res_builder); }); ASSERT_TRUE(server.Start()); std::this_thread::sleep_for(100ms); // NOLINTNEXTLINE (bugprone-string-constructor) std::string testdata(10000000, 'a'); memgraph::communication::ClientContext client_context; Client client(server.endpoint(), &client_context); auto echo = client.Call<Echo>(testdata); EXPECT_EQ(echo.data, testdata); server.Shutdown(); server.AwaitShutdown(); } TEST(Rpc, Stream) { memgraph::communication::ServerContext server_context; Server server({"127.0.0.1", 0}, &server_context); server.Register<Echo>([](auto *req_reader, auto *res_builder) { EchoMessage req; memgraph::slk::Load(&req, req_reader); std::string payload; memgraph::slk::Load(&payload, req_reader); EchoMessage res(req.data + payload); memgraph::slk::Save(res, res_builder); }); ASSERT_TRUE(server.Start()); std::this_thread::sleep_for(100ms); memgraph::communication::ClientContext client_context; Client client(server.endpoint(), &client_context); auto stream = client.Stream<Echo>("hello"); memgraph::slk::Save("world", stream.GetBuilder()); auto echo = stream.AwaitResponse(); EXPECT_EQ(echo.data, "helloworld"); server.Shutdown(); server.AwaitShutdown(); } TEST(Rpc, StreamLarge) { memgraph::communication::ServerContext server_context; Server server({"127.0.0.1", 0}, &server_context); server.Register<Echo>([](auto *req_reader, auto *res_builder) { EchoMessage req; memgraph::slk::Load(&req, req_reader); std::string payload; memgraph::slk::Load(&payload, req_reader); EchoMessage res(req.data + payload); memgraph::slk::Save(res, res_builder); }); ASSERT_TRUE(server.Start()); std::this_thread::sleep_for(100ms); std::string testdata1(50000, 'a'); std::string testdata2(50000, 'b'); memgraph::communication::ClientContext client_context; Client client(server.endpoint(), &client_context); auto stream = client.Stream<Echo>(testdata1); memgraph::slk::Save(testdata2, stream.GetBuilder()); auto echo = stream.AwaitResponse(); EXPECT_EQ(echo.data, testdata1 + testdata2); server.Shutdown(); server.AwaitShutdown(); } TEST(Rpc, StreamJumbo) { memgraph::communication::ServerContext server_context; Server server({"127.0.0.1", 0}, &server_context); server.Register<Echo>([](auto *req_reader, auto *res_builder) { EchoMessage req; memgraph::slk::Load(&req, req_reader); std::string payload; memgraph::slk::Load(&payload, req_reader); EchoMessage res(req.data + payload); memgraph::slk::Save(res, res_builder); }); ASSERT_TRUE(server.Start()); std::this_thread::sleep_for(100ms); // NOLINTNEXTLINE (bugprone-string-constructor) std::string testdata1(5000000, 'a'); // NOLINTNEXTLINE (bugprone-string-constructor) std::string testdata2(5000000, 'b'); memgraph::communication::ClientContext client_context; Client client(server.endpoint(), &client_context); auto stream = client.Stream<Echo>(testdata1); memgraph::slk::Save(testdata2, stream.GetBuilder()); auto echo = stream.AwaitResponse(); EXPECT_EQ(echo.data, testdata1 + testdata2); server.Shutdown(); server.AwaitShutdown(); }