Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ absl::StatusOr<std::vector<RepeatedFlagModifier>> ParseRepeatedEnumModifiers(
namespace {

template <typename T>
static auto FindRepeatedFieldValue(google::protobuf::RepeatedField<int>* list, T value) {
static auto FindRepeatedFieldValue(google::protobuf::RepeatedField<int>* list,
T value) {
for (auto it = list->begin(); it != list->end(); ++it) {
if (*it == value) {
return it;
Expand Down Expand Up @@ -300,6 +301,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_experimental_dynamic_slice_fusion_verify_offsets(false);
opts.set_xla_gpu_nccl_termination_timeout_seconds(-1);
opts.set_xla_gpu_enable_nccl_user_buffers(false);
opts.set_xla_gpu_enable_allocator_spatial_partitioning(true);
opts.set_xla_gpu_experimental_enable_nccl_symmetric_buffers(false);
opts.set_xla_gpu_experimental_enable_nvshmem(false);
opts.set_xla_gpu_enable_nccl_comm_splitting(true);
Expand Down Expand Up @@ -2025,6 +2027,14 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"Enables NCCL User Buffer Registration. collective_memory_size in the "
"allocator config must also be set to a non-zero value that is large "
"enough to meet peak collective memory usage."));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_allocator_spatial_partitioning",
bool_setter_for(
&DebugOptions::set_xla_gpu_enable_allocator_spatial_partitioning),
debug_options->xla_gpu_enable_allocator_spatial_partitioning(),
"Enables spatial partitioning of the GPU BFC allocator so default and "
"collective allocations share one fixed address range. Requires BFC "
"preallocation."));
flag_list->push_back(tsl::Flag(
"xla_gpu_experimental_enable_nccl_symmetric_buffers",
bool_setter_for(
Expand Down Expand Up @@ -3250,7 +3260,6 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
bool_setter_for(&DebugOptions::set_xla_gpu_log_minmax),
debug_options->xla_gpu_log_minmax(),
"If true, log min/max values from kernel outputs."));

flag_list->push_back(tsl::Flag(
"xla_early_exit_with_layouts",
bool_setter_for(&DebugOptions::set_xla_early_exit_with_layouts),
Expand Down
18 changes: 9 additions & 9 deletions xla/pjrt/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ cc_library(
":gpu_metrics",
":se_gpu_pjrt_runtime_abi_version",
":se_gpu_topology_description",
"//xla:debug_options_flags",
"//xla:executable_run_options",
"//xla:future",
"//xla:literal",
Expand All @@ -88,6 +89,7 @@ cc_library(
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
"//xla/hlo/builder:xla_computation",
"//xla/hlo/ir:hlo",
"//xla/pjrt:async_work_runner",
"//xla/pjrt:common_pjrt_client",
"//xla/pjrt:device_event",
Expand Down Expand Up @@ -134,8 +136,10 @@ cc_library(
"//xla/service:transfer_manager",
"//xla/service/gpu:buffer_allocations",
"//xla/service/gpu:gpu_constants",
"//xla/service/gpu:gpu_executable",
"//xla/service/gpu:gpu_executable_run_options",
"//xla/service/gpu:gpu_memory_space_assignment",
"//xla/service/gpu:stream_executor_util",
"//xla/stream_executor:device_address",
"//xla/stream_executor:device_address_allocator",
"//xla/stream_executor:device_description",
Expand All @@ -146,6 +150,9 @@ cc_library(
"//xla/stream_executor:stream",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor:vmm_device_address_allocator",
"//xla/stream_executor/cuda:cuda_compute_capability",
"//xla/stream_executor/cuda:cuda_device_address_vmm_allocator",
"//xla/stream_executor/gpu:gpu_cudamallocasync_allocator",
"//xla/stream_executor/integrations:device_mem_allocator",
"//xla/stream_executor/integrations:tf_allocator_adapter",
"//xla/tsl/concurrency:async_value",
Expand All @@ -155,6 +162,7 @@ cc_library(
"//xla/tsl/framework:bfc_allocator",
"//xla/tsl/framework:device_id",
"//xla/tsl/framework:device_id_impl",
"//xla/tsl/framework:scoped_allocation_trace",
"//xla/tsl/lib/strings:proto_serialization",
"//xla/tsl/platform:env",
"//xla/tsl/platform:errors",
Expand Down Expand Up @@ -183,6 +191,7 @@ cc_library(
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@llvm-project//mlir:IR",
"@local_config_cuda//cuda:cudart_headers",
"@tsl//tsl/platform:casts",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
Expand All @@ -196,25 +205,16 @@ cc_library(
"@tsl//tsl/profiler/lib:traceme",
] + if_cuda_or_rocm([
# keep sorted
"//xla:debug_options_flags",
"//xla/service/gpu:gpu_compiler",
"//xla/service/gpu:gpu_executable",
"//xla/service/gpu:stream_executor_util",
]) + if_cuda([
# keep sorted
"//xla/stream_executor/cuda:cuda_compute_capability",
"//xla/stream_executor/cuda:cuda_device_address_vmm_allocator",
"//xla/stream_executor/gpu:gpu_cudamallocasync_allocator",
"@local_config_cuda//cuda:cuda_headers",
]) + if_rocm([
# keep sorted
"@local_config_rocm//rocm:rocm_headers",
]) + if_sycl([
# keep sorted
"//xla:debug_options_flags",
"//xla/service/gpu:gpu_compiler",
"//xla/service/gpu:gpu_executable",
"//xla/service/gpu:stream_executor_util",
"@local_config_sycl//sycl:sycl_headers",
]),
)
Expand Down
13 changes: 9 additions & 4 deletions xla/pjrt/gpu/gpu_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,12 @@ absl::StatusOr<std::shared_ptr<tsl::BFCAllocator>> CreateBFCAllocator(
se::StreamExecutor* executor, double memory_fraction, bool preallocate,
std::optional<int64_t> gpu_system_memory_size,
const std::vector<tsl::SubAllocator::Visitor>& sub_allocator_alloc_visitors,
const std::vector<tsl::SubAllocator::Visitor>&
sub_allocator_free_visitors) {
const std::vector<tsl::SubAllocator::Visitor>& sub_allocator_free_visitors,
bool enable_spatial_partitioning) {
if (enable_spatial_partitioning && !preallocate) {
return InvalidArgument(
"Spatial partitioning of the BFC allocator requires preallocate=true.");
}
bool enable_unified_memory;
absl::Status status = tsl::ReadBoolFromEnvVar("TF_FORCE_UNIFIED_MEMORY",
false, &enable_unified_memory);
Expand Down Expand Up @@ -164,13 +168,14 @@ absl::StatusOr<std::shared_ptr<tsl::BFCAllocator>> CreateBFCAllocator(

tsl::BFCAllocator::Options opts;
opts.allow_growth = !preallocate;
opts.enable_spatial_partitioning = enable_spatial_partitioning;
return std::make_shared<tsl::BFCAllocator>(
std::move(sub_allocator), allocator_memory,
absl::StrCat("GPU_", device_ordinal, "_bfc"), opts);
}

// Builds a BFCAllocator for all local GPUs that uses collective memory.
absl::StatusOr<std::shared_ptr<tsl::BFCAllocator>> CreateCollectiveBFCAllocator(
absl::StatusOr<std::unique_ptr<tsl::BFCAllocator>> CreateCollectiveBFCAllocator(
se::StreamExecutor* executor, double memory_fraction,
size_t collective_memory_size) {
int device_ordinal = executor->device_ordinal();
Expand Down Expand Up @@ -205,7 +210,7 @@ absl::StatusOr<std::shared_ptr<tsl::BFCAllocator>> CreateCollectiveBFCAllocator(

tsl::BFCAllocator::Options opts;
opts.allow_growth = !preallocate;
return std::make_shared<tsl::BFCAllocator>(
return std::make_unique<tsl::BFCAllocator>(
std::move(sub_allocator), allocator_memory,
absl::StrCat("GPU_collectivememory_", device_ordinal, "_bfc"), opts);
}
Expand Down
9 changes: 6 additions & 3 deletions xla/pjrt/gpu/gpu_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,18 @@ void EnablePeerAccess(absl::Span<se::StreamExecutor* const> executors);
absl::StatusOr<std::unique_ptr<tsl::BFCAllocator>> GetGpuHostAllocator(
se::StreamExecutor* executor);

// Builds a BFCAllocator for all local GPUs.
// Builds a BFCAllocator for all local GPUs. When enable_spatial_partitioning
// is set, the allocator serves collective (upper-end) and default (lower-end)
// requests from one shared address range; this requires preallocate=true.
absl::StatusOr<std::shared_ptr<tsl::BFCAllocator>> CreateBFCAllocator(
se::StreamExecutor* executor, double memory_fraction, bool preallocate,
std::optional<int64_t> gpu_system_memory_size,
const std::vector<tsl::SubAllocator::Visitor>& sub_allocator_alloc_visitors,
const std::vector<tsl::SubAllocator::Visitor>& sub_allocator_free_visitors);
const std::vector<tsl::SubAllocator::Visitor>& sub_allocator_free_visitors,
bool enable_spatial_partitioning = false);

// Builds a BFCAllocator for all local GPUs that uses collective memory.
absl::StatusOr<std::shared_ptr<tsl::BFCAllocator>> CreateCollectiveBFCAllocator(
absl::StatusOr<std::unique_ptr<tsl::BFCAllocator>> CreateCollectiveBFCAllocator(
se::StreamExecutor* executor, double memory_fraction,
size_t collective_memory_size);

Expand Down
96 changes: 77 additions & 19 deletions xla/pjrt/gpu/se_gpu_pjrt_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ limitations under the License.
#include "xla/executable_run_options.h"
#include "xla/future.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/hlo/ir/hlo_input_output_alias_config.h"
#include "xla/layout.h"
#include "xla/pjrt/async_work_runner.h"
#include "xla/pjrt/buffer_sequencing_event.h"
Expand Down Expand Up @@ -118,6 +119,7 @@ limitations under the License.
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "xla/tsl/framework/allocator.h"
#include "xla/tsl/framework/scoped_allocation_trace.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
Expand Down Expand Up @@ -1329,14 +1331,20 @@ GetStreamExecutorGpuDeviceAllocator(
const std::map<int, std::unique_ptr<LocalDeviceState>>&
addressable_devices) {
std::vector<se::MultiDeviceAdapter::AllocatorInfo> allocators;
const DebugOptions& debug_options = xla::GetDebugOptionsFromFlags();
GpuAllocatorConfig::Kind effective_kind = allocator_config.kind;
if (GetDebugOptionsFromFlags().xla_gpu_command_buffer_update_mode() !=
if (debug_options.xla_gpu_command_buffer_update_mode() !=
DebugOptions::ALWAYS_UPDATE &&
effective_kind != GpuAllocatorConfig::Kind::kVmm) {
LOG(WARNING) << "xla_gpu_command_buffer_update_mode requires the "
"VMM allocator. Overriding allocator kind to kVmm.";
effective_kind = GpuAllocatorConfig::Kind::kVmm;
}

// Set when a single preallocated BFC allocator serves both default and
// collective memory via spatial partitioning; suppresses the separate
// collective allocator below.
bool shared_collective_pool = false;
switch (effective_kind) {
case GpuAllocatorConfig::Kind::kCudaAsync: {
for (const auto& ordinal_and_device : addressable_devices) {
Expand All @@ -1356,6 +1364,13 @@ GetStreamExecutorGpuDeviceAllocator(
case GpuAllocatorConfig::Kind::kDefault:
case GpuAllocatorConfig::Kind::kBFC: {
LOG(INFO) << "Using BFC allocator.";
// With the spatial-partitioning flag enabled, preallocation lets one BFC
// allocator over a fixed address range serve both default (lower end) and
// collective (upper end) memory, so no separate collective allocator is
// created. Otherwise, use the separate collective allocator below.
shared_collective_pool =
allocator_config.preallocate &&
debug_options.xla_gpu_enable_allocator_spatial_partitioning();
for (const auto& ordinal_and_device : addressable_devices) {
ASSIGN_OR_RETURN(
auto bfc_allocator,
Expand All @@ -1364,11 +1379,29 @@ GetStreamExecutorGpuDeviceAllocator(
allocator_config.preallocate,
allocator_config.gpu_system_memory_size,
allocator_config.sub_allocator_alloc_visitors,
allocator_config.sub_allocator_free_visitors));
allocator_config.sub_allocator_free_visitors,
/*enable_spatial_partitioning=*/
shared_collective_pool));
allocators.push_back(
{std::move(bfc_allocator),
ordinal_and_device.second->compute_stream(),
{bfc_allocator, ordinal_and_device.second->compute_stream(),
/*memory_space=*/(int)xla::gpu::MemorySpaceColor::kDefault});
if (shared_collective_pool) {
size_t collective_memory_alignment =
tsl::Allocator::kAllocatorAlignment;
if (auto* collectives =
gpu::GpuCollectives::Default(platform->Name())) {
collective_memory_alignment =
collectives->SymmetricMemoryAlignment();
}
allocators.push_back(
{std::move(bfc_allocator),
ordinal_and_device.second->compute_stream(),
/*memory_space=*/(int)xla::gpu::MemorySpaceColor::kCollective,
/*device_ordinal=*/std::nullopt,
/*platform=*/nullptr,
/*min_alignment=*/collective_memory_alignment,
/*allocation_end=*/tsl::AllocationEnd::kUpper});
}
}
break;
}
Expand Down Expand Up @@ -1402,18 +1435,22 @@ GetStreamExecutorGpuDeviceAllocator(
}
}

// Add any additional allocators for alternate memory spaces.
for (const auto& ordinal_and_device : addressable_devices) {
ASSIGN_OR_RETURN(
auto collective_bfc_allocator,
CreateCollectiveBFCAllocator(
ordinal_and_device.second->executor(),
/*memory_fraction=*/1.0 - allocator_config.memory_fraction,
allocator_config.collective_memory_size));
allocators.push_back(
{std::move(collective_bfc_allocator),
ordinal_and_device.second->compute_stream(),
/*memory_space=*/(int)xla::gpu::MemorySpaceColor::kCollective});
// Add a separate collective allocator unless the default BFC allocator
// already serves collective memory from its shared, spatially partitioned
// pool.
if (!shared_collective_pool) {
for (const auto& ordinal_and_device : addressable_devices) {
ASSIGN_OR_RETURN(
auto collective_bfc_allocator,
CreateCollectiveBFCAllocator(
ordinal_and_device.second->executor(),
/*memory_fraction=*/1.0 - allocator_config.memory_fraction,
allocator_config.collective_memory_size));
allocators.push_back(
{std::move(collective_bfc_allocator),
ordinal_and_device.second->compute_stream(),
/*memory_space=*/(int)xla::gpu::MemorySpaceColor::kCollective});
}
}

for (const auto& ordinal_and_device : addressable_devices) {
Expand All @@ -1426,7 +1463,6 @@ GetStreamExecutorGpuDeviceAllocator(
}

#if defined(GOOGLE_CUDA) && CUDA_VERSION >= 11020
const auto& debug_options = xla::GetDebugOptionsFromFlags();
if (debug_options.xla_gpu_temp_buffer_use_separate_color()) {
// Add memory allocator to allocate memory buffers with persistent temp
// memory space color.
Expand Down Expand Up @@ -2003,6 +2039,11 @@ StreamExecutorGpuClient::RunAsync(
"[", device_ordinal, "] GpuExecutable::ExecuteAsyncOnStreamImpl(",
gpu_exec->name(), ")"));

// Attribute all device memory allocation to the gpu executable.
tsl::ScopedAllocationTrace allocation_trace(
"xla.execute",
{{"executable", gpu_exec->name()}, {"device", device_ordinal}});

// GpuExecutable always bound to a single GpuContext during its execution, so
// we activate it once to skip expensive context activations later.
auto activation = executor->Activate();
Expand Down Expand Up @@ -2093,14 +2134,21 @@ StreamExecutorGpuClient::RunAsync(
}
} else {
// Allocate each allocation that might escape, or is the temp buffer.
CHECK(allocation.maybe_live_out() ||
allocation.IsPreallocatedTempBuffer());
bool is_live_out = allocation.maybe_live_out();
bool is_temp_buffer = allocation.IsPreallocatedTempBuffer();
CHECK(is_live_out || is_temp_buffer); // Crash OK

int64_t buffer_size = allocation.size();
if (auto it = allocate_granularity.find(allocation.color());
it != allocate_granularity.end()) {
buffer_size = RoundUpTo(buffer_size, it->second);
}
if (buffer_size > 0) {
tsl::ScopedAllocationTrace allocation_trace(
"xla.buffer", {{"kind", is_temp_buffer ? "temp" : "live_out"},
{"allocation_index", i},
{"requested_bytes", buffer_size},
{"memory_space", allocation.color()}});
ASSIGN_OR_RETURN(
se::ScopedDeviceAddress<uint8_t> owning_buffer,
memory_allocator->Allocate(device_ordinal, buffer_size,
Expand Down Expand Up @@ -2165,6 +2213,16 @@ StreamExecutorGpuClient::RunAsync(
"buffer is not donated; allocating a fresh buffer";
int64_t allocation_size = ShapeUtil::ByteSizeOf(
ShapeUtil::GetSubshape(gpu_exec->result_shape(), index));
const HloInputOutputAliasConfig::Alias& alias =
*output_info.alias_config;
const bool must_alias = alias.must_alias();
tsl::ScopedAllocationTrace copy_protection_trace(
"xla.buffer",
{{"kind", "live_out_copy_protection"},
{"allocation_index", output_info.allocation_index},
{"requested_bytes", allocation_size},
{"memory_space", allocation->color()},
{"alias_kind", must_alias ? "must_alias" : "may_alias"}});
absl::StatusOr<se::ScopedDeviceAddress<uint8_t>> allocated_buffer =
memory_allocator->Allocate(device_ordinal, allocation_size,
/*retry_on_failure=*/true,
Expand Down
Loading
Loading