Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions xla/hlo/ir/stack_frames.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,19 +117,25 @@ HloStackFrame StackFrames::GetStackFrame(StackFrameId id) const {
}

StackFrameId StackFrames::AddStackFrame(const HloStackFrame& frame) {
auto [file_it, file_inserted] = file_name_to_id_.try_emplace(
std::string(frame.file_name), proto_.file_names_size() + 1);
if (file_inserted) {
FileNameId file_id;
auto file_it = file_name_to_id_.find(frame.file_name);
if (file_it != file_name_to_id_.end()) {
file_id = file_it->second;
} else {
file_id = proto_.file_names_size() + 1;
file_name_to_id_.emplace(std::string(frame.file_name), file_id);
proto_.add_file_names(std::string(frame.file_name));
}
FileNameId file_id = file_it->second;

auto [func_it, func_inserted] = function_name_to_id_.try_emplace(
std::string(frame.function_name), proto_.function_names_size() + 1);
if (func_inserted) {
FunctionNameId func_id;
auto func_it = function_name_to_id_.find(frame.function_name);
if (func_it != function_name_to_id_.end()) {
func_id = func_it->second;
} else {
func_id = proto_.function_names_size() + 1;
function_name_to_id_.emplace(std::string(frame.function_name), func_id);
proto_.add_function_names(std::string(frame.function_name));
}
FunctionNameId func_id = func_it->second;

FileLocationKey loc_key = {file_id, func_id, frame.line,
frame.column, frame.end_line, frame.end_column};
Expand Down Expand Up @@ -165,7 +171,10 @@ bool StackFrames::IsPrefix(StackFrameId prefix, StackFrameId full) const {
if (full == prefix) {
return true;
}
full = GetStackFrame(full).parent_frame_id;
if (full.value > proto_.stack_frames_size()) {
return false;
}
full = StackFrameId{proto_.stack_frames(full.value - 1).parent_frame_id()};
}
return false;
}
Expand Down
4 changes: 4 additions & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1871,10 +1871,14 @@ cc_library(
"//xla/hlo/transforms:hlo_module_stitcher",
"//xla/stream_executor:stream_executor_h",
"//xla/tsl/concurrency:executor",
"//xla/tsl/platform:env",
"//xla/tsl/platform:status_macros",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
"@tsl//tsl/platform:blocking_counter",
"@tsl//tsl/platform:platform_port",
],
)

Expand Down
37 changes: 31 additions & 6 deletions xla/service/multi_module_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "xla/tsl/platform/status_macros.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_module.h"
Expand All @@ -33,7 +35,9 @@ limitations under the License.
#include "xla/service/compiler.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/tsl/concurrency/executor.h"
#include "xla/tsl/platform/threadpool.h"
#include "tsl/platform/blocking_counter.h"
#include "tsl/platform/cpu_info.h"

namespace xla {

Expand Down Expand Up @@ -89,14 +93,35 @@ absl::StatusOr<std::unique_ptr<HloModule>> MultiModuleDriver::Compile(
results[i] = compile_fn_(std::move(all_modules[i]), options);
}
} else {
// Parallel compilation.
// Parallel compilation with concurrency capping to avoid LLVM OOM or
// thrashing.
int max_concurrency = std::max<int>(
1, options.thread_pool ? options.thread_pool->NumThreads()
: std::min<int>(8, tsl::port::MaxParallelism()));

absl::Mutex mutex;
int active_compilations = 0;

tsl::BlockingCounter counter(all_modules.size());
for (size_t i = 0; i < all_modules.size(); ++i) {
executor->Execute(
[this, &all_modules, &options, &results, &counter, i]() {
results[i] = compile_fn_(std::move(all_modules[i]), options);
counter.DecrementCount();
});
{
absl::MutexLock lock(&mutex);
auto can_compile = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex) {
return active_compilations < max_concurrency;
};
mutex.Await(absl::Condition(&can_compile));
active_compilations++;
}

executor->Execute([this, &all_modules, &options, &results, &counter,
&mutex, &active_compilations, i]() {
results[i] = compile_fn_(std::move(all_modules[i]), options);
{
absl::MutexLock lock(&mutex);
active_compilations--;
}
counter.DecrementCount();
});
}
counter.Wait();
}
Expand Down
Loading