mirror of
https://github.com/google/benchmark.git
synced 2025-01-13 13:20:13 +08:00
Update benchmark Python bindings for nanobind 2.0, and update to nanobind 2.0. (#1817)
Incorporates the nanobind_bazel change from https://github.com/google/benchmark/pull/1795. nanobind 2.0 reworked the nanobind::enum_ class so it uses a real Python enum or intenum rather than its previous hand-rolled implementation. https://nanobind.readthedocs.io/en/latest/changelog.html#version-2-0-0-may-23-2024 As a consequence of that change, nanobind now checks when casting an integer to a enum value that the integer corresponds to a valid enum. Counter::Flags is a bitmask, and many combinations are not valid enum members. This change: a) sets nb::is_arithmetic(), which means Counter::Flags becomes an IntEnum that can be freely cast to an integer. b) defines the | operator for flags to return an integer, not an enum, avoiding the error. c) changes Counter's constructor to accept an int, not a Counter::Flags enum. Since Counter::Flags is an IntEnum now, it can be freely coerced to an int. If https://github.com/wjakob/nanobind/pull/599 is merged into nanobind, then we can perhaps use a flag enum here instead.
This commit is contained in:
parent
a6ad7fbbdc
commit
64b5d8cd11
@ -38,4 +38,4 @@ use_repo(pip, "tools_pip_deps")
|
||||
|
||||
# -- bazel_dep definitions -- #
|
||||
|
||||
bazel_dep(name = "nanobind_bazel", version = "1.0.0", dev_dependency = True)
|
||||
bazel_dep(name = "nanobind_bazel", version = "2.0.0", dev_dependency = True)
|
||||
|
@ -118,7 +118,7 @@ NB_MODULE(_benchmark, m) {
|
||||
using benchmark::Counter;
|
||||
nb::class_<Counter> py_counter(m, "Counter");
|
||||
|
||||
nb::enum_<Counter::Flags>(py_counter, "Flags")
|
||||
nb::enum_<Counter::Flags>(py_counter, "Flags", nb::is_arithmetic())
|
||||
.value("kDefaults", Counter::Flags::kDefaults)
|
||||
.value("kIsRate", Counter::Flags::kIsRate)
|
||||
.value("kAvgThreads", Counter::Flags::kAvgThreads)
|
||||
@ -130,7 +130,9 @@ NB_MODULE(_benchmark, m) {
|
||||
.value("kAvgIterationsRate", Counter::Flags::kAvgIterationsRate)
|
||||
.value("kInvert", Counter::Flags::kInvert)
|
||||
.export_values()
|
||||
.def(nb::self | nb::self);
|
||||
.def("__or__", [](Counter::Flags a, Counter::Flags b) {
|
||||
return static_cast<int>(a) | static_cast<int>(b);
|
||||
});
|
||||
|
||||
nb::enum_<Counter::OneK>(py_counter, "OneK")
|
||||
.value("kIs1000", Counter::OneK::kIs1000)
|
||||
@ -138,10 +140,15 @@ NB_MODULE(_benchmark, m) {
|
||||
.export_values();
|
||||
|
||||
py_counter
|
||||
.def(nb::init<double, Counter::Flags, Counter::OneK>(),
|
||||
nb::arg("value") = 0., nb::arg("flags") = Counter::kDefaults,
|
||||
nb::arg("k") = Counter::kIs1000)
|
||||
.def("__init__", ([](Counter *c, double value) { new (c) Counter(value); }))
|
||||
.def(
|
||||
"__init__",
|
||||
[](Counter* c, double value, int flags, Counter::OneK oneK) {
|
||||
new (c) Counter(value, static_cast<Counter::Flags>(flags), oneK);
|
||||
},
|
||||
nb::arg("value") = 0., nb::arg("flags") = Counter::kDefaults,
|
||||
nb::arg("k") = Counter::kIs1000)
|
||||
.def("__init__",
|
||||
([](Counter* c, double value) { new (c) Counter(value); }))
|
||||
.def_rw("value", &Counter::value)
|
||||
.def_rw("flags", &Counter::flags)
|
||||
.def_rw("oneK", &Counter::oneK)
|
||||
|
Loading…
Reference in New Issue
Block a user