From b9357c7fdc31bd39ce255010d8213a240a14d0e4 Mon Sep 17 00:00:00 2001 From: Sean Talts Date: Fri, 12 Jun 2026 13:10:07 -0700 Subject: [PATCH] [XLA:CPU] Optimize StackFrames Proto index access and Dynamically Cap MultiModuleDriver parallel compilation concurrency - Rewrote `StackFrames::IsPrefix` to traverse stack frames directly via `.parent_frame_id()` protobuf indexes rather than allocating full `HloStackFrame` structural copies on every hop, resolving extreme compile-time O(N) memory allocation hotspots during call graph metadata propagation. - Capped `MultiModuleDriver::Compile` parallel submodule compilation to dynamically match the dimensions of `CompileOptions::thread_pool` (or a safe fallback limit of 8 concurrent LLVM compilations) to prevent Out-Of-Memory (OOM) hard freezes and virtual memory thrashing when compiling massively split models (like `torax`) under `FAST_COMPILE`. PiperOrigin-RevId: 931295728 --- xla/hlo/ir/stack_frames.cc | 27 ++++++++++++++-------- xla/service/BUILD | 4 ++++ xla/service/multi_module_driver.cc | 37 +++++++++++++++++++++++++----- 3 files changed, 53 insertions(+), 15 deletions(-) diff --git a/xla/hlo/ir/stack_frames.cc b/xla/hlo/ir/stack_frames.cc index 86377fabfdd88..20290e3e9e367 100644 --- a/xla/hlo/ir/stack_frames.cc +++ b/xla/hlo/ir/stack_frames.cc @@ -117,19 +117,25 @@ HloStackFrame StackFrames::GetStackFrame(StackFrameId id) const { } StackFrameId StackFrames::AddStackFrame(const HloStackFrame& frame) { - auto [file_it, file_inserted] = file_name_to_id_.try_emplace( - std::string(frame.file_name), proto_.file_names_size() + 1); - if (file_inserted) { + FileNameId file_id; + auto file_it = file_name_to_id_.find(frame.file_name); + if (file_it != file_name_to_id_.end()) { + file_id = file_it->second; + } else { + file_id = proto_.file_names_size() + 1; + file_name_to_id_.emplace(std::string(frame.file_name), file_id); proto_.add_file_names(std::string(frame.file_name)); } - FileNameId file_id = file_it->second; - auto [func_it, func_inserted] = function_name_to_id_.try_emplace( - std::string(frame.function_name), proto_.function_names_size() + 1); - if (func_inserted) { + FunctionNameId func_id; + auto func_it = function_name_to_id_.find(frame.function_name); + if (func_it != function_name_to_id_.end()) { + func_id = func_it->second; + } else { + func_id = proto_.function_names_size() + 1; + function_name_to_id_.emplace(std::string(frame.function_name), func_id); proto_.add_function_names(std::string(frame.function_name)); } - FunctionNameId func_id = func_it->second; FileLocationKey loc_key = {file_id, func_id, frame.line, frame.column, frame.end_line, frame.end_column}; @@ -165,7 +171,10 @@ bool StackFrames::IsPrefix(StackFrameId prefix, StackFrameId full) const { if (full == prefix) { return true; } - full = GetStackFrame(full).parent_frame_id; + if (full.value > proto_.stack_frames_size()) { + return false; + } + full = StackFrameId{proto_.stack_frames(full.value - 1).parent_frame_id()}; } return false; } diff --git a/xla/service/BUILD b/xla/service/BUILD index 4952939c0a347..61768c313ee73 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -1871,10 +1871,14 @@ cc_library( "//xla/hlo/transforms:hlo_module_stitcher", "//xla/stream_executor:stream_executor_h", "//xla/tsl/concurrency:executor", + "//xla/tsl/platform:env", "//xla/tsl/platform:status_macros", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", "@tsl//tsl/platform:blocking_counter", + "@tsl//tsl/platform:platform_port", ], ) diff --git a/xla/service/multi_module_driver.cc b/xla/service/multi_module_driver.cc index e8786a1a21fd4..2e393860a9800 100644 --- a/xla/service/multi_module_driver.cc +++ b/xla/service/multi_module_driver.cc @@ -22,8 +22,10 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" @@ -33,7 +35,9 @@ limitations under the License. #include "xla/service/compiler.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/concurrency/executor.h" +#include "xla/tsl/platform/threadpool.h" #include "tsl/platform/blocking_counter.h" +#include "tsl/platform/cpu_info.h" namespace xla { @@ -89,14 +93,35 @@ absl::StatusOr> MultiModuleDriver::Compile( results[i] = compile_fn_(std::move(all_modules[i]), options); } } else { - // Parallel compilation. + // Parallel compilation with concurrency capping to avoid LLVM OOM or + // thrashing. + int max_concurrency = std::max( + 1, options.thread_pool ? options.thread_pool->NumThreads() + : std::min(8, tsl::port::MaxParallelism())); + + absl::Mutex mutex; + int active_compilations = 0; + tsl::BlockingCounter counter(all_modules.size()); for (size_t i = 0; i < all_modules.size(); ++i) { - executor->Execute( - [this, &all_modules, &options, &results, &counter, i]() { - results[i] = compile_fn_(std::move(all_modules[i]), options); - counter.DecrementCount(); - }); + { + absl::MutexLock lock(&mutex); + auto can_compile = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex) { + return active_compilations < max_concurrency; + }; + mutex.Await(absl::Condition(&can_compile)); + active_compilations++; + } + + executor->Execute([this, &all_modules, &options, &results, &counter, + &mutex, &active_compilations, i]() { + results[i] = compile_fn_(std::move(all_modules[i]), options); + { + absl::MutexLock lock(&mutex); + active_compilations--; + } + counter.DecrementCount(); + }); } counter.Wait(); }