From 90fdb039e0e3868a1ccb0e6e61ced3bf920c2b3a Mon Sep 17 00:00:00 2001 From: Sean Talts Date: Fri, 12 Jun 2026 09:23:32 -0700 Subject: [PATCH] [XLA:CPU] Release memory after compilation in HLO benchmarks. Compiling HLOs can use a significant amount of memory. To avoid OOMs when running benchmarks with many iterations, this change releases the memory associated with the compiled executable and purges per-thread caches to the OS after each compilation, outside of the timed benchmark loop. PiperOrigin-RevId: 931186564 --- xla/backends/cpu/benchmarks/BUILD | 1 + xla/backends/cpu/benchmarks/hlo_benchmark_runner.cc | 11 +++++++++++ 2 files changed, 12 insertions(+) diff --git a/xla/backends/cpu/benchmarks/BUILD b/xla/backends/cpu/benchmarks/BUILD index 882441f01f034..8a9daee23d262 100644 --- a/xla/backends/cpu/benchmarks/BUILD +++ b/xla/backends/cpu/benchmarks/BUILD @@ -84,6 +84,7 @@ cc_library( "@com_google_absl//absl/types:span", "@tsl//tsl/platform:casts", "@tsl//tsl/platform:path", + "@tsl//tsl/platform:platform_port", ], ) diff --git a/xla/backends/cpu/benchmarks/hlo_benchmark_runner.cc b/xla/backends/cpu/benchmarks/hlo_benchmark_runner.cc index 0df93e4dea916..6e705e48218d9 100644 --- a/xla/backends/cpu/benchmarks/hlo_benchmark_runner.cc +++ b/xla/backends/cpu/benchmarks/hlo_benchmark_runner.cc @@ -59,6 +59,7 @@ limitations under the License. #include "xla/tsl/platform/threadpool.h" #include "xla/util.h" #include "tsl/platform/casts.h" +#include "tsl/platform/mem.h" #include "tsl/platform/path.h" namespace xla::cpu { @@ -354,6 +355,16 @@ absl::Status CompileHloBenchmark(benchmark::State& state, ASSIGN_OR_RETURN(std::unique_ptr executable, client->CompileAndLoad(computation, compile_options)); tsl::testing::DoNotOptimize(executable); + + // Compiling uses a lot of memory; running it repeatedly in a benchmark + // loop would trigger lots of allocation/deallocation. So we only compile + // once at the beginning of each iteration, and then release the memory + // immediately to prevent OOMs. + executable.reset(); + state.PauseTiming(); + // Purge all per-thread caches to OS. + tsl::port::MallocExtension_ReleaseToSystem(static_cast(-1)); + state.ResumeTiming(); } return absl::OkStatus();