diff --git a/include/benchmark/benchmark.h b/include/benchmark/benchmark.h index eec0fc58..14efec13 100644 --- a/include/benchmark/benchmark.h +++ b/include/benchmark/benchmark.h @@ -169,6 +169,7 @@ BENCHMARK(BM_test)->Unit(benchmark::kMillisecond); #include #include #include +#include #include #include #include @@ -303,6 +304,10 @@ BENCHMARK(BM_test)->Unit(benchmark::kMillisecond); namespace benchmark { class BenchmarkReporter; +class State; + +// Define alias of Setup/Teardown callback function type +using callback_function = std::function; // Default number of minimum benchmark running time in seconds. const char kDefaultMinTimeStr[] = "0.5s"; @@ -1157,10 +1162,10 @@ class BENCHMARK_EXPORT Benchmark { // // The callback will be passed a State object, which includes the number // of threads, thread-index, benchmark arguments, etc. - // - // The callback must not be NULL or self-deleting. - Benchmark* Setup(void (*setup)(const benchmark::State&)); - Benchmark* Teardown(void (*teardown)(const benchmark::State&)); + Benchmark* Setup(callback_function&&); + Benchmark* Setup(const callback_function&); + Benchmark* Teardown(callback_function&&); + Benchmark* Teardown(const callback_function&); // Pass this benchmark object to *func, which can customize // the benchmark by calling various methods like Arg, Args, @@ -1309,7 +1314,6 @@ class BENCHMARK_EXPORT Benchmark { std::vector statistics_; std::vector thread_counts_; - typedef void (*callback_function)(const benchmark::State&); callback_function setup_; callback_function teardown_; diff --git a/src/benchmark_api_internal.cc b/src/benchmark_api_internal.cc index 14d4e134..60609d30 100644 --- a/src/benchmark_api_internal.cc +++ b/src/benchmark_api_internal.cc @@ -27,7 +27,9 @@ BenchmarkInstance::BenchmarkInstance(Benchmark* benchmark, int family_idx, min_time_(benchmark_.min_time_), min_warmup_time_(benchmark_.min_warmup_time_), iterations_(benchmark_.iterations_), - threads_(thread_count) { + threads_(thread_count), + setup_(benchmark_.setup_), + teardown_(benchmark_.teardown_) { name_.function_name = benchmark_.name_; size_t arg_i = 0; @@ -84,9 +86,6 @@ BenchmarkInstance::BenchmarkInstance(Benchmark* benchmark, int family_idx, if (!benchmark_.thread_counts_.empty()) { name_.threads = StrFormat("threads:%d", threads_); } - - setup_ = benchmark_.setup_; - teardown_ = benchmark_.teardown_; } State BenchmarkInstance::Run( diff --git a/src/benchmark_api_internal.h b/src/benchmark_api_internal.h index 9287c4eb..82ab71f4 100644 --- a/src/benchmark_api_internal.h +++ b/src/benchmark_api_internal.h @@ -68,9 +68,8 @@ class BenchmarkInstance { IterationCount iterations_; int threads_; // Number of concurrent threads to us - typedef void (*callback_function)(const benchmark::State&); - callback_function setup_ = nullptr; - callback_function teardown_ = nullptr; + callback_function setup_; + callback_function teardown_; }; bool FindBenchmarksInternal(const std::string& re, diff --git a/src/benchmark_register.cc b/src/benchmark_register.cc index 28336a16..8b945404 100644 --- a/src/benchmark_register.cc +++ b/src/benchmark_register.cc @@ -224,9 +224,7 @@ Benchmark::Benchmark(const std::string& name) use_real_time_(false), use_manual_time_(false), complexity_(oNone), - complexity_lambda_(nullptr), - setup_(nullptr), - teardown_(nullptr) { + complexity_lambda_(nullptr) { ComputeStatistics("mean", StatisticsMean); ComputeStatistics("median", StatisticsMedian); ComputeStatistics("stddev", StatisticsStdDev); @@ -337,13 +335,25 @@ Benchmark* Benchmark::Apply(void (*custom_arguments)(Benchmark* benchmark)) { return this; } -Benchmark* Benchmark::Setup(void (*setup)(const benchmark::State&)) { +Benchmark* Benchmark::Setup(callback_function&& setup) { + BM_CHECK(setup != nullptr); + setup_ = std::forward(setup); + return this; +} + +Benchmark* Benchmark::Setup(const callback_function& setup) { BM_CHECK(setup != nullptr); setup_ = setup; return this; } -Benchmark* Benchmark::Teardown(void (*teardown)(const benchmark::State&)) { +Benchmark* Benchmark::Teardown(callback_function&& teardown) { + BM_CHECK(teardown != nullptr); + teardown_ = std::forward(teardown); + return this; +} + +Benchmark* Benchmark::Teardown(const callback_function& teardown) { BM_CHECK(teardown != nullptr); teardown_ = teardown; return this; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 3686e7ee..07784cef 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -232,6 +232,7 @@ if (BENCHMARK_ENABLE_GTEST_TESTS) add_gtest(time_unit_gtest) add_gtest(min_time_parse_gtest) add_gtest(profiler_manager_gtest) + add_gtest(benchmark_setup_teardown_cb_types_gtest) endif(BENCHMARK_ENABLE_GTEST_TESTS) ############################################################################### diff --git a/test/benchmark_setup_teardown_cb_types_gtest.cc b/test/benchmark_setup_teardown_cb_types_gtest.cc new file mode 100644 index 00000000..c5a1a662 --- /dev/null +++ b/test/benchmark_setup_teardown_cb_types_gtest.cc @@ -0,0 +1,126 @@ +#include "benchmark/benchmark.h" +#include "gtest/gtest.h" + +using benchmark::BenchmarkReporter; +using benchmark::callback_function; +using benchmark::ClearRegisteredBenchmarks; +using benchmark::RegisterBenchmark; +using benchmark::RunSpecifiedBenchmarks; +using benchmark::State; +using benchmark::internal::Benchmark; + +static int functor_called = 0; +struct Functor { + void operator()(const benchmark::State& /*unused*/) { functor_called++; } +}; + +class NullReporter : public BenchmarkReporter { + public: + bool ReportContext(const Context& /*context*/) override { return true; } + void ReportRuns(const std::vector& /* report */) override {} +}; + +class BenchmarkTest : public testing::Test { + public: + Benchmark* bm; + NullReporter null_reporter; + + int setup_calls; + int teardown_calls; + + void SetUp() override { + setup_calls = 0; + teardown_calls = 0; + functor_called = 0; + + bm = RegisterBenchmark("BM", [](State& st) { + for (auto _ : st) { + } + }); + bm->Iterations(1); + } + + void TearDown() override { ClearRegisteredBenchmarks(); } +}; + +// Test that Setup/Teardown can correctly take a lambda expressions +TEST_F(BenchmarkTest, LambdaTestCopy) { + auto setup_lambda = [this](const State&) { setup_calls++; }; + auto teardown_lambda = [this](const State&) { teardown_calls++; }; + bm->Setup(setup_lambda); + bm->Teardown(teardown_lambda); + RunSpecifiedBenchmarks(&null_reporter); + EXPECT_EQ(setup_calls, 1); + EXPECT_EQ(teardown_calls, 1); +} + +// Test that Setup/Teardown can correctly take a lambda expressions +TEST_F(BenchmarkTest, LambdaTestMove) { + auto setup_lambda = [this](const State&) { setup_calls++; }; + auto teardown_lambda = [this](const State&) { teardown_calls++; }; + bm->Setup(std::move(setup_lambda)); + bm->Teardown(std::move(teardown_lambda)); + RunSpecifiedBenchmarks(&null_reporter); + EXPECT_EQ(setup_calls, 1); + EXPECT_EQ(teardown_calls, 1); +} + +// Test that Setup/Teardown can correctly take std::function +TEST_F(BenchmarkTest, CallbackFunctionCopy) { + callback_function setup_lambda = [this](const State&) { setup_calls++; }; + callback_function teardown_lambda = [this](const State&) { + teardown_calls++; + }; + bm->Setup(setup_lambda); + bm->Teardown(teardown_lambda); + RunSpecifiedBenchmarks(&null_reporter); + EXPECT_EQ(setup_calls, 1); + EXPECT_EQ(teardown_calls, 1); +} + +// Test that Setup/Teardown can correctly take std::function +TEST_F(BenchmarkTest, CallbackFunctionMove) { + callback_function setup_lambda = [this](const State&) { setup_calls++; }; + callback_function teardown_lambda = [this](const State&) { + teardown_calls++; + }; + bm->Setup(std::move(setup_lambda)); + bm->Teardown(std::move(teardown_lambda)); + RunSpecifiedBenchmarks(&null_reporter); + EXPECT_EQ(setup_calls, 1); + EXPECT_EQ(teardown_calls, 1); +} + +// Test that Setup/Teardown can correctly take functors +TEST_F(BenchmarkTest, FunctorCopy) { + Functor func; + bm->Setup(func); + bm->Teardown(func); + RunSpecifiedBenchmarks(&null_reporter); + EXPECT_EQ(functor_called, 2); +} + +// Test that Setup/Teardown can correctly take functors +TEST_F(BenchmarkTest, FunctorMove) { + Functor func1; + Functor func2; + bm->Setup(std::move(func1)); + bm->Teardown(std::move(func2)); + RunSpecifiedBenchmarks(&null_reporter); + EXPECT_EQ(functor_called, 2); +} + +// Test that Setup/Teardown can not take nullptr +TEST_F(BenchmarkTest, NullptrTest) { +#if GTEST_HAS_DEATH_TEST + // Tests only runnable in debug mode (when BM_CHECK is enabled). +#ifndef NDEBUG +#ifndef TEST_BENCHMARK_LIBRARY_HAS_NO_ASSERTIONS + EXPECT_DEATH(bm->Setup(nullptr), "setup != nullptr"); + EXPECT_DEATH(bm->Teardown(nullptr), "teardown != nullptr"); +#else + GTEST_SKIP() << "Test skipped because BM_CHECK is disabled"; +#endif +#endif +#endif +}