diff --git a/include/benchmark/benchmark.h b/include/benchmark/benchmark.h index 340cbc1e..a7ac9649 100644 --- a/include/benchmark/benchmark.h +++ b/include/benchmark/benchmark.h @@ -172,6 +172,7 @@ BENCHMARK(BM_test)->Unit(benchmark::kMillisecond); #include +#include #include #include #include @@ -429,16 +430,18 @@ class State { // Returns true if the benchmark should continue through another iteration. // NOTE: A benchmark may not return from the test until KeepRunning() has // returned false. - bool KeepRunning() { - if (BENCHMARK_BUILTIN_EXPECT(!started_, false)) { - StartKeepRunning(); - } - bool const res = (--total_iterations_ != 0); - if (BENCHMARK_BUILTIN_EXPECT(!res, false)) { - FinishKeepRunning(); - } - return res; - } + bool KeepRunning(); + + // Returns true iff the benchmark should run n more iterations. + // NOTE: A benchmark must not return from the test until KeepRunningBatch() + // has returned false. + // NOTE: KeepRunningBatch() may overshoot by up to 'n' iterations. + // + // Intended usage: + // while (state.KeepRunningBatch(1000)) { + // // process 1000 elements + // } + bool KeepRunningBatch(size_t n); // REQUIRES: timer is running and 'SkipWithError(...)' has not been called // by the current thread. @@ -565,12 +568,16 @@ class State { int range_y() const { return range(1); } BENCHMARK_ALWAYS_INLINE - size_t iterations() const { return (max_iterations - total_iterations_) + 1; } + size_t iterations() const { + return (max_iterations - total_iterations_ + batch_leftover_); + } private: bool started_; bool finished_; + // When total_iterations_ is 0, KeepRunning() and friends will return false. size_t total_iterations_; + // May be larger than max_iterations. std::vector range_; @@ -581,6 +588,11 @@ class State { bool error_occurred_; + // When using KeepRunningBatch(), batch_leftover_ holds the number of + // iterations beyond max_iters that were run. Used to track + // completed_iterations_ accurately. + size_t batch_leftover_; + public: // Container for user-defined counters. UserCounters counters; @@ -603,6 +615,50 @@ class State { BENCHMARK_DISALLOW_COPY_AND_ASSIGN(State); }; +inline BENCHMARK_ALWAYS_INLINE +bool State::KeepRunning() { + // total_iterations_ is set to 0 by the constructor, and always set to a + // nonzero value by StartKepRunning(). + if (BENCHMARK_BUILTIN_EXPECT(total_iterations_ != 0, true)) { + --total_iterations_; + return true; + } + if (!started_) { + StartKeepRunning(); + if (!error_occurred_) { + // max_iterations > 0. The first iteration is always valid. + --total_iterations_; + return true; + } + } + FinishKeepRunning(); + return false; +} + +inline BENCHMARK_ALWAYS_INLINE +bool State::KeepRunningBatch(size_t n) { + // total_iterations_ is set to 0 by the constructor, and always set to a + // nonzero value by StartKepRunning(). + if (BENCHMARK_BUILTIN_EXPECT(total_iterations_ >= n, true)) { + total_iterations_ -= n; + return true; + } + if (!started_) { + StartKeepRunning(); + if (!error_occurred_ && total_iterations_ >= n) { + total_iterations_-= n; + return true; + } + } + if (total_iterations_ != 0) { + batch_leftover_ = n - total_iterations_; + total_iterations_ = 0; + return true; + } + FinishKeepRunning(); + return false; +} + struct State::StateIterator { struct BENCHMARK_UNUSED Value {}; typedef std::forward_iterator_tag iterator_category; diff --git a/src/benchmark.cc b/src/benchmark.cc index 1a7d2182..8879204a 100644 --- a/src/benchmark.cc +++ b/src/benchmark.cc @@ -268,7 +268,7 @@ void RunInThread(const benchmark::internal::Benchmark::Instance* b, internal::ThreadTimer timer; State st(iters, b->arg, thread_id, b->threads, &timer, manager); b->benchmark->Run(st); - CHECK(st.iterations() == st.max_iterations) + CHECK(st.iterations() >= st.max_iterations) << "Benchmark returned before State::KeepRunning() returned false!"; { MutexLock l(manager->GetBenchmarkMutex()); @@ -399,12 +399,13 @@ State::State(size_t max_iters, const std::vector& ranges, int thread_i, internal::ThreadManager* manager) : started_(false), finished_(false), - total_iterations_(max_iters + 1), + total_iterations_(0), range_(ranges), bytes_processed_(0), items_processed_(0), complexity_n_(0), error_occurred_(false), + batch_leftover_(0), counters(), thread_index(thread_i), threads(n_threads), @@ -412,7 +413,6 @@ State::State(size_t max_iters, const std::vector& ranges, int thread_i, timer_(timer), manager_(manager) { CHECK(max_iterations != 0) << "At least one iteration must be run"; - CHECK(total_iterations_ != 0) << "max iterations wrapped around"; CHECK_LT(thread_index, threads) << "thread_index must be less than threads"; } @@ -437,7 +437,7 @@ void State::SkipWithError(const char* msg) { manager_->results.has_error_ = true; } } - total_iterations_ = 1; + total_iterations_ = 0; if (timer_->running()) timer_->StopTimer(); } @@ -453,6 +453,7 @@ void State::SetLabel(const char* label) { void State::StartKeepRunning() { CHECK(!started_ && !finished_); started_ = true; + total_iterations_ = error_occurred_ ? 0 : max_iterations; manager_->StartStopBarrier(); if (!error_occurred_) ResumeTiming(); } @@ -462,8 +463,8 @@ void State::FinishKeepRunning() { if (!error_occurred_) { PauseTiming(); } - // Total iterations has now wrapped around zero. Fix this. - total_iterations_ = 1; + // Total iterations has now wrapped around past 0. Fix this. + total_iterations_ = 0; finished_ = true; manager_->StartStopBarrier(); } diff --git a/test/basic_test.cc b/test/basic_test.cc index 3348781c..12579c09 100644 --- a/test/basic_test.cc +++ b/test/basic_test.cc @@ -102,10 +102,21 @@ void BM_KeepRunning(benchmark::State& state) { while (state.KeepRunning()) { ++iter_count; } - assert(iter_count == state.max_iterations); + assert(iter_count == state.iterations()); } BENCHMARK(BM_KeepRunning); +void BM_KeepRunningBatch(benchmark::State& state) { + // Choose a prime batch size to avoid evenly dividing max_iterations. + const size_t batch_size = 101; + size_t iter_count = 0; + while (state.KeepRunningBatch(batch_size)) { + iter_count += batch_size; + } + assert(state.iterations() == iter_count); +} +BENCHMARK(BM_KeepRunningBatch); + void BM_RangedFor(benchmark::State& state) { size_t iter_count = 0; for (auto _ : state) { diff --git a/test/skip_with_error_test.cc b/test/skip_with_error_test.cc index 0c2f3481..8d2c342a 100644 --- a/test/skip_with_error_test.cc +++ b/test/skip_with_error_test.cc @@ -70,6 +70,16 @@ void BM_error_before_running(benchmark::State& state) { BENCHMARK(BM_error_before_running); ADD_CASES("BM_error_before_running", {{"", true, "error message"}}); + +void BM_error_before_running_batch(benchmark::State& state) { + state.SkipWithError("error message"); + while (state.KeepRunningBatch(17)) { + assert(false); + } +} +BENCHMARK(BM_error_before_running_batch); +ADD_CASES("BM_error_before_running_batch", {{"", true, "error message"}}); + void BM_error_before_running_range_for(benchmark::State& state) { state.SkipWithError("error message"); for (auto _ : state) {