diff --git a/xla/stream_executor/cuda/cuda_device_address_vmm_allocator.cc b/xla/stream_executor/cuda/cuda_device_address_vmm_allocator.cc index 04960f5c6592d..1c8fd8f9a1819 100644 --- a/xla/stream_executor/cuda/cuda_device_address_vmm_allocator.cc +++ b/xla/stream_executor/cuda/cuda_device_address_vmm_allocator.cc @@ -35,12 +35,12 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_memory_reservation.h" #include "xla/stream_executor/cuda/cuda_raw_memory_allocation.h" #include "xla/stream_executor/cuda/cuda_status.h" +#include "xla/stream_executor/device_address_vmm_allocator.h" #include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/memory_reservation.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/device_address_vmm_allocator.h" namespace stream_executor::gpu { diff --git a/xla/stream_executor/cuda/cuda_device_address_vmm_allocator.h b/xla/stream_executor/cuda/cuda_device_address_vmm_allocator.h index b473d211ebe8d..35b1a760f1693 100644 --- a/xla/stream_executor/cuda/cuda_device_address_vmm_allocator.h +++ b/xla/stream_executor/cuda/cuda_device_address_vmm_allocator.h @@ -24,12 +24,12 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/stream_executor/device_address_vmm_allocator.h" #include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/memory_reservation.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/device_address_vmm_allocator.h" namespace stream_executor::gpu { @@ -101,7 +101,6 @@ class CudaDeviceAddressVmmAllocator : public DeviceAddressVmmAllocator { absl::Status EnqueueDeferredDeallocation(PerDeviceState& state, uint64_t seqno) override; - private: explicit CudaDeviceAddressVmmAllocator(const Platform* platform); }; diff --git a/xla/stream_executor/device_address_vmm_allocator.cc b/xla/stream_executor/device_address_vmm_allocator.cc index c7c661cd80c18..6ee17cd690f38 100644 --- a/xla/stream_executor/device_address_vmm_allocator.cc +++ b/xla/stream_executor/device_address_vmm_allocator.cc @@ -15,13 +15,16 @@ limitations under the License. #include "xla/stream_executor/device_address_vmm_allocator.h" +#include #include #include +#include #include #include #include #include +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" @@ -38,14 +41,16 @@ limitations under the License. #include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/memory_reservation.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/tsl/platform/statusor.h" namespace stream_executor { namespace { + thread_local const xla::DeviceAssignment* current_device_assignment = nullptr; + } // namespace DeviceAddressVmmAllocator::DeviceAssignmentScope::DeviceAssignmentScope( @@ -121,7 +126,8 @@ DeviceAddressVmmAllocator::~DeviceAddressVmmAllocator() { absl::Status status = SynchronizeAllPendingOperations(); CHECK(status.ok()) << status; - for (auto& [ordinal, state] : per_device_) { + for (auto& device : per_device_) { + auto& state = device.second; // Free platform-specific per-device resources (e.g. pinned timeline). if (state->destroy_fn) { state->destroy_fn(); @@ -130,12 +136,14 @@ DeviceAddressVmmAllocator::~DeviceAddressVmmAllocator() { } absl::Status DeviceAddressVmmAllocator::SynchronizeAllPendingOperations() { - for (auto& [ordinal, state] : per_device_) { - RETURN_IF_ERROR(SynchronizePendingOperations(ordinal)); + for (auto& device : per_device_) { + RETURN_IF_ERROR(SynchronizePendingOperations(device.first)); } return absl::OkStatus(); } +// Common helpers and accessors. + absl::StatusOr DeviceAddressVmmAllocator::GetPerDeviceState(int device_ordinal) const { auto it = per_device_.find(device_ordinal); @@ -148,134 +156,99 @@ DeviceAddressVmmAllocator::GetPerDeviceState(int device_ordinal) const { return it->second.get(); } -absl::StatusOr -DeviceAddressVmmAllocator::ValidateReservationRange( - MemoryReservation* reservation, uint64_t reservation_offset, - uint64_t size) const { - if (reservation == nullptr) { - return absl::InvalidArgumentError("reservation must not be null"); - } - - DeviceAddressBase address = reservation->address(); - if (reservation_offset > address.size() || - size > address.size() - reservation_offset) { - return absl::InvalidArgumentError(absl::StrFormat( - "reservation range is out of bounds: offset=%uB, size=%uB, " - "reservation_size=%uB", - reservation_offset, size, address.size())); +uint64_t DeviceAddressVmmAllocator::RoundUpToGranularity( + const PerDeviceState& state, uint64_t size) const { + if (state.allocation_granularity == 0) { + return size; } - - return address.GetByteSlice(reservation_offset, size); + return ((size + state.allocation_granularity - 1) / + state.allocation_granularity) * + state.allocation_granularity; } -void DeviceAddressVmmAllocator::ProcessCompletedPendingDeallocations( - PerDeviceState& state) { - // Single atomic read covers all entries whose seqno is <= completed. - uint64_t completed = LoadTimeline(state.pinned_timeline); - while (!state.pending_deallocations.empty()) { - if (state.pending_deallocations.front().seqno > completed) { - break; - } - if (state.pending_deallocations.front().kind != - PendingDeallocationKind::kMap) { - DoDeallocate(state, state.pending_deallocations.front().addr); - } else { - DoUnMap(state, state.pending_deallocations.front().addr); - } - state.pending_deallocations.pop_front(); - } +absl::StatusOr DeviceAddressVmmAllocator::GetStream( + int device_ordinal) { + ASSIGN_OR_RETURN(auto state, GetPerDeviceState(device_ordinal)); + return state->stream; } -void DeviceAddressVmmAllocator::WaitPendingDeallocationsToComplete( - PerDeviceState& state, uint64_t size) { - if (state.pending_deallocations.empty()) { - return; +absl::Status DeviceAddressVmmAllocator::SynchronizePendingOperations( + int device_ordinal) { + ASSIGN_OR_RETURN(auto state, GetPerDeviceState(device_ordinal)); + absl::MutexLock lock(state->mu); + if (state->pending_deallocations.empty()) { + return absl::OkStatus(); } + return WaitAndDrainPendingDeallocationsUntilSeqno( + *state, state->pending_deallocations.back().seqno); +} - uint64_t accumulated_size = 0; - size_t count_to_wait = 0; - uint64_t rounded_size = RoundUpToGranularity(state, size); - uint64_t target_seqno = 0; - - // Target 1.1x the requested size to provide some headroom. - uint64_t target_size = rounded_size + rounded_size / 10; +absl::StatusOr DeviceAddressVmmAllocator::GetStreamExecutor( + int device_ordinal) const { + ASSIGN_OR_RETURN(auto state, GetPerDeviceState(device_ordinal)); + return state->executor; +} - for (const auto& pending : state.pending_deallocations) { - if (pending.kind != PendingDeallocationKind::kMap) { - accumulated_size += pending.reclaimable_bytes; - } - target_seqno = pending.seqno; - ++count_to_wait; - if (accumulated_size >= target_size) { - break; - } +MemoryAllocation* DeviceAddressVmmAllocator::GetRawAllocation( + int device_ordinal, DeviceAddressBase addr) const { + absl::StatusOr state_or = GetPerDeviceState(device_ordinal); + if (!state_or.ok()) { + return nullptr; } + PerDeviceState* state = *state_or; + absl::MutexLock lock(state->mu); - // Move selected entries out of the deque while holding the lock, so no - // other thread can observe or free them. - std::vector selected; - selected.reserve(count_to_wait); - for (size_t i = 0; i < count_to_wait; ++i) { - selected.push_back(std::move(state.pending_deallocations.front())); - state.pending_deallocations.pop_front(); + // Allocator addresses are keyed directly by their VA. Stale records remain in + // this map until deferred teardown completes, so require both active state + // and an exact address-range match before exposing the backing allocation. + auto allocation_it = state->records_by_allocator_address.find(addr.opaque()); + if (allocation_it != state->records_by_allocator_address.end() && + allocation_it->second->allocator_active && + allocation_it->second->allocator_address.IsSameAs(addr)) { + return allocation_it->second->raw_allocation.get(); } - // Release the lock before spin-waiting to avoid stalling other threads for - // potentially milliseconds while the GPU drains its work queue. - state.mu.unlock(); - - // Poll until the GPU writes a timeline value >= target_seqno. - // Since timeline values are written in stream order, this guarantees all - // earlier pending deallocations have also completed. - while (LoadTimeline(state.pinned_timeline) < target_seqno) { - absl::SleepFor(kGpuTimelinePollInterval); + // Reservation aliases created by Map() or by Allocate(..., + // return_reservation_address=false) are tracked in a separate active-only + // index. Stale or already-unmapped aliases intentionally return nullptr. + auto reservation_it = state->active_reservation_records.find(addr.opaque()); + if (reservation_it != state->active_reservation_records.end()) { + return reservation_it->second->raw_allocation.get(); } + return nullptr; +} - // Reacquire the lock before modifying the maps. - state.mu.lock(); - - for (auto& item : selected) { - if (item.kind != PendingDeallocationKind::kMap) { - DoDeallocate(state, item.addr); - } else { - DoUnMap(state, item.addr); - } +MemoryReservation* DeviceAddressVmmAllocator::GetReservation( + int device_ordinal, DeviceAddressBase addr) const { + absl::StatusOr state_or = GetPerDeviceState(device_ordinal); + if (!state_or.ok()) { + return nullptr; } -} + PerDeviceState* state = *state_or; + absl::MutexLock lock(state->mu); -void DeviceAddressVmmAllocator::DoDeallocate(PerDeviceState& state, - DeviceAddressBase mem) { - VLOG(3) << absl::StreamFormat( - "Actually freeing virtual address %p (size=%uB) on device ordinal %d", - mem.opaque(), mem.size(), state.executor->device_ordinal()); + auto allocation_it = state->records_by_allocator_address.find(addr.opaque()); + if (allocation_it != state->records_by_allocator_address.end() && + allocation_it->second->allocator_active && + allocation_it->second->allocator_address.IsSameAs(addr)) { + return allocation_it->second->allocator_address_reservation.get(); + } - auto record_it = state.records_by_allocator_address.find(mem.opaque()); - CHECK(record_it != state.records_by_allocator_address.end()); - CHECK(record_it->second->allocator_stale); - CHECK(record_it->second->allocator_address.IsSameAs(mem)); - AllocationRecord& record = *record_it->second; - CHECK(!record.allocator_active); - record.allocator_address_mapping.reset(); - record.allocator_address_reservation.reset(); + return nullptr; +} - if (record.raw_allocation != nullptr) { - uint64_t rounded_size = - RoundUpToGranularity(state, record.raw_allocation->address().size()); - DCHECK_GE(state.pa_allocated, rounded_size); - state.pa_allocated -= rounded_size; +uint64_t DeviceAddressVmmAllocator::GetAllocationGranularity( + StreamExecutor* executor) const { + absl::StatusOr state_or = + GetPerDeviceState(executor->device_ordinal()); + if (!state_or.ok()) { + return 0; } - record.raw_allocation.reset(); - CHECK_EQ(state.records_by_allocator_address.erase(mem.opaque()), 1); + PerDeviceState* state = *state_or; + return state->allocation_granularity; } -void DeviceAddressVmmAllocator::DoUnMap(PerDeviceState& state, - DeviceAddressBase mem) { - VLOG(3) << absl::StreamFormat( - "Actually unmapping reservation address %p (size=%uB) on device ordinal " - "%d", - mem.opaque(), mem.size(), state.executor->device_ordinal()); - state.stale_reservation_mappings.erase(mem.opaque()); -} +// Allocate helpers. void* DeviceAddressVmmAllocator::TrackAllocatorAddressMappedAllocation( PerDeviceState& state, PendingDeallocationKind kind, @@ -300,39 +273,219 @@ void* DeviceAddressVmmAllocator::TrackAllocatorAddressMappedAllocation( return va_ptr; } -absl::StatusOr DeviceAddressVmmAllocator::AllocateWithBudget( - PerDeviceState& state, uint64_t size, bool multi_device) { - uint64_t rounded_size = RoundUpToGranularity(state, size); - if (state.pa_allocated + rounded_size > state.pa_budget) { - return absl::ResourceExhaustedError(absl::StrFormat( - "Not enough PA budget for allocation: pa_allocated=%uB, " - "rounded_size=%uB, pa_budget=%uB", - state.pa_allocated, rounded_size, state.pa_budget)); +// Shared pending-reclaim retry flow: +// +// TryWithPendingReclaim(reclaim_size, try_reuse, try_fresh) +// │ +// ▼ +// ┌─────────────────────────────────┐ +// │ try_reuse() │──found──► return reused address +// └─────────────────────────────────┘ +// │ not found +// ▼ +// ┌─────────────────────────────────┐ +// │ try_fresh() │──OK──► return fresh address +// └─────────────────────────────────┘ +// │ ResourceExhausted +// ▼ +// ┌─────────────────────────────────┐ +// │ Process completed pending │ +// │ operations │ +// └─────────────────────────────────┘ +// │ +// ▼ +// ┌─────────────────────────────────┐ +// │ try_fresh() │──OK──► return fresh address +// └─────────────────────────────────┘ +// │ ResourceExhausted +// ▼ +// ┌─────────────────────────────────┐ +// │ Wait for pending operations │ +// │ to reclaim enough memory │ +// └─────────────────────────────────┘ +// │ +// ▼ +// ┌─────────────────────────────────┐ +// │ try_fresh() │──OK──► return fresh address +// └─────────────────────────────────┘ +// │ failed +// ▼ +// return error +template +absl::StatusOr +DeviceAddressVmmAllocator::TryWithPendingReclaim(PerDeviceState& state, + uint64_t reclaim_size, + TryReuseFn try_reuse, + TryFreshFn try_fresh) { + // First try to reactivate a compatible pending deallocation without waiting. + // Reuse is stream-order safe and avoids both a fresh VMM allocation and any + // host-side wait for the GPU timeline. + ASSIGN_OR_RETURN(std::optional reused, try_reuse()); + if (reused.has_value()) { + return *reused; + } + + // If no pending entry matches, try the normal fresh allocation path. Most + // calls should finish here; the reclaim paths below are only for PA budget + // pressure or allocator-level allocation failures. + absl::StatusOr result = try_fresh(); + + if (absl::IsResourceExhausted(result.status())) { + // A ResourceExhausted error may be stale: some pending deallocations can + // already be past their stream timeline point. Complete ready allocator + // deallocations first, without blocking for later pending work and without + // destroying unrelated stale reservation mappings that may be reused. + CompleteReadyAllocatorDeallocationsForReclaim( + state, LoadTimeline(state.pinned_timeline)); + result = try_fresh(); + } + + if (absl::IsResourceExhausted(result.status())) { + // If completed pending work was not enough, wait until enough queued frees + // should be reclaimable for this request, then retry once more. This is the + // only path that may block while the GPU drains earlier stream work. + // Select enough pending allocator-address deallocations to cover this + // request, then wait for the selected tail seqno to become safe. Unrelated + // kMap entries do not own physical memory, so leave them stale and + // reusable. + if (!state.pending_deallocations.empty()) { + uint64_t accumulated_size = 0; + uint64_t rounded_size = RoundUpToGranularity(state, reclaim_size); + uint64_t target_seqno = 0; + std::vector selected; + + // Target 1.1x the requested size to provide some headroom. + uint64_t target_size = rounded_size + rounded_size / 10; + + for (const PendingDeallocation& pending : state.pending_deallocations) { + if (pending.kind == PendingDeallocationKind::kMap) { + continue; + } + auto record_it = + state.records_by_allocator_address.find(pending.addr.opaque()); + CHECK(record_it != state.records_by_allocator_address.end()); + CHECK(record_it->second->allocator_stale); + CHECK(record_it->second->allocator_address.IsSameAs(pending.addr)); + CHECK(record_it->second->raw_allocation != nullptr); + accumulated_size += RoundUpToGranularity( + state, record_it->second->raw_allocation->address().size()); + target_seqno = std::max(target_seqno, pending.seqno); + selected.push_back( + PendingDeallocationKey{pending.kind, pending.seqno, pending.addr}); + if (accumulated_size >= target_size) { + break; + } + } + + if (!selected.empty()) { + RETURN_IF_ERROR(WaitUntilSeqno(state, target_seqno)); + for (const PendingDeallocationKey& key : selected) { + CompletePendingDeallocationByKey(state, key); + } + } + } + result = try_fresh(); } - // Create physical memory allocation (e.g. cuMemCreate). - ASSIGN_OR_RETURN(auto raw_alloc, CreateAllocation(state.executor, size)); - const uint64_t padded_size = raw_alloc->address().size(); + return result; +} - // Reserve virtual address range (e.g. cuMemAddressReserve). - ASSIGN_OR_RETURN(auto reservation, CreateReservation(state.executor, size)); +// Allocate() reuses pending kAllocate entries, otherwise tries a fresh +// allocator-address mapping. +absl::StatusOr> +DeviceAddressVmmAllocator::Allocate(int device_ordinal, uint64_t size, + bool /*retry_on_failure*/, + int64_t /*memory_space*/) { + if (size == 0) { + return ScopedDeviceAddress(DeviceAddressBase(), device_ordinal, + this); + } - // Map physical memory into the virtual address range and enable access. - ASSIGN_OR_RETURN( - auto scoped_mapping, - reservation->MapTo(/*reservation_offset=*/0, /*allocation_offset=*/0, - padded_size, *raw_alloc)); + ASSIGN_OR_RETURN(auto state, GetPerDeviceState(device_ordinal)); + const bool multi_device = CurrentMultiDevice(); - auto shared_raw = std::shared_ptr(std::move(raw_alloc)); - DeviceAddressBase allocator_address(reservation->address().opaque(), size); - void* va_ptr = TrackAllocatorAddressMappedAllocation( - state, PendingDeallocationKind::kAllocate, allocator_address, - std::move(shared_raw), std::move(reservation), std::move(scoped_mapping), - rounded_size, multi_device); - // Return the original requested size, not the padded size. - return DeviceAddressBase(va_ptr, size); + absl::MutexLock lock(state->mu); + auto try_reuse = [&]() ABSL_NO_THREAD_SAFETY_ANALYSIS + -> absl::StatusOr> { + uint64_t rounded_size = RoundUpToGranularity(*state, size); + for (auto it = state->pending_deallocations.begin(); + it != state->pending_deallocations.end(); ++it) { + if (it->kind != PendingDeallocationKind::kAllocate) { + continue; + } + auto record_it = + state->records_by_allocator_address.find(it->addr.opaque()); + CHECK(record_it != state->records_by_allocator_address.end()); + AllocationRecord& record = *record_it->second; + CHECK(record.allocator_stale); + CHECK(record.allocator_address.IsSameAs(it->addr)); + if (record.multi_device != multi_device) { + continue; + } + if (RoundUpToGranularity(*state, record.allocator_address.size()) != + rounded_size) { + continue; + } + + DeviceAddressBase reused_mem(record.allocator_address.opaque(), size); + MoveAllocatorRecordToActive(*state, record, size); + ErasePendingDeallocationAt(*state, it); + + return std::optional(reused_mem); + } + + return std::optional(); + }; + auto try_fresh = + [&]() + ABSL_NO_THREAD_SAFETY_ANALYSIS -> absl::StatusOr { + uint64_t rounded_size = RoundUpToGranularity(*state, size); + if (state->pa_allocated + rounded_size > state->pa_budget) { + return absl::StatusOr( + absl::ResourceExhaustedError(absl::StrFormat( + "Not enough PA budget for allocation: pa_allocated=%uB, " + "rounded_size=%uB, pa_budget=%uB", + state->pa_allocated, rounded_size, state->pa_budget))); + } + + ASSIGN_OR_RETURN(auto raw_alloc, CreateAllocation(state->executor, size)); + const uint64_t padded_size = raw_alloc->address().size(); + + ASSIGN_OR_RETURN(auto reservation, + CreateReservation(state->executor, size)); + + ASSIGN_OR_RETURN( + auto scoped_mapping, + reservation->MapTo(/*reservation_offset=*/0, /*allocation_offset=*/0, + padded_size, *raw_alloc)); + + auto shared_raw = std::shared_ptr(std::move(raw_alloc)); + DeviceAddressBase allocator_address(reservation->address().opaque(), size); + void* va_ptr = TrackAllocatorAddressMappedAllocation( + *state, PendingDeallocationKind::kAllocate, allocator_address, + std::move(shared_raw), std::move(reservation), + std::move(scoped_mapping), rounded_size, multi_device); + // Return the original requested size, not the padded size. + return absl::StatusOr(DeviceAddressBase(va_ptr, size)); + }; + + absl::StatusOr result = + TryWithPendingReclaim(*state, size, try_reuse, try_fresh); + + if (!result.ok()) { + return result.status(); + } + + VLOG(3) << absl::StreamFormat( + "Allocated virtual address %p (%uB) on device ordinal %d", + result->opaque(), size, device_ordinal); + + return ScopedDeviceAddress(*result, device_ordinal, this); } +// Mapped Allocate() creates fresh physical memory and maps it into the caller +// reservation. It keeps the same externally visible ownership model as the +// previous map-based bookkeeping, but records the lifetime in AllocationRecord. absl::StatusOr> DeviceAddressVmmAllocator::Allocate( int device_ordinal, uint64_t allocation_size, bool /*retry_on_failure*/, @@ -348,19 +501,18 @@ DeviceAddressVmmAllocator::Allocate( return ScopedDeviceAddress(DeviceAddressBase(), device_ordinal, this); } - ASSIGN_OR_RETURN( - DeviceAddressBase reservation_address, - ValidateReservationRange(reservation, reservation_offset, mapping_size)); ASSIGN_OR_RETURN(auto state, GetPerDeviceState(device_ordinal)); - const bool multi_device = CurrentMultiDevice(); + ASSIGN_OR_RETURN( + DeviceAddressBase reservation_address, + ValidateReservationRange(reservation, reservation_offset, mapping_size)); + absl::MutexLock lock(state->mu); - if (state->active_reservation_mappings.contains( - reservation_address.opaque()) || - state->stale_reservation_mappings.contains( + if (state->active_reservation_records.contains( reservation_address.opaque()) || + state->stale_reservation_records.contains(reservation_address.opaque()) || state->records_by_allocator_address.contains( reservation_address.opaque())) { return absl::FailedPreconditionError( @@ -399,116 +551,38 @@ DeviceAddressVmmAllocator::Allocate( this); } - ASSIGN_OR_RETURN(auto allocator_reservation, + ASSIGN_OR_RETURN(auto allocator_address_reservation, CreateReservation(state->executor, allocation_size)); - ASSIGN_OR_RETURN(auto allocator_mapping, - allocator_reservation->MapTo( + ASSIGN_OR_RETURN(auto allocator_address_mapping, + allocator_address_reservation->MapTo( /*reservation_offset=*/0, /*allocation_offset=*/0, padded_size, *shared_raw)); - DeviceAddressBase allocator_address(allocator_reservation->address().opaque(), - allocation_size); - TrackAllocatorAddressMappedAllocation( - *state, PendingDeallocationKind::kAllocateAndMapReturnNewAddr, - allocator_address, std::move(shared_raw), std::move(allocator_reservation), - std::move(allocator_mapping), rounded_size, multi_device); - state->active_reservation_mappings.emplace( - reservation_address.opaque(), - ReservationMapping{allocator_address, reservation_address, reservation, - reservation_offset, mapping_size, - std::move(reservation_mapping)}); - - return ScopedDeviceAddress(allocator_address, device_ordinal, this); -} - -// Allocation flow with retry: -// -// Allocate(device_ordinal, size) -// │ -// ▼ -// ┌─────────────────────────────────┐ -// │ Reuse pending deallocation │──found──► return -// │ with matching size? │ -// └─────────────────────────────────┘ -// │ not found -// ▼ -// ┌─────────────────────────────────┐ -// │ Allocate new physical + │──OK──► return -// │ virtual memory │ -// └─────────────────────────────────┘ -// │ failed -// ▼ -// ┌─────────────────────────────────┐ -// │ Free any GPU-completed │ -// │ pending deallocations │ -// │ (non-blocking) │ -// └─────────────────────────────────┘ -// │ -// ▼ -// ┌─────────────────────────────────┐ -// │ Allocate new physical + │──OK──► return -// │ virtual memory │ -// └─────────────────────────────────┘ -// │ failed -// ▼ -// ┌─────────────────────────────────┐ -// │ Block until GPU frees │ -// │ enough pending memory │ -// └─────────────────────────────────┘ -// │ -// ▼ -// ┌─────────────────────────────────┐ -// │ Allocate new physical + │──OK──► return -// │ virtual memory │ -// └─────────────────────────────────┘ -// │ failed -// ▼ -// ResourceExhaustedError -absl::StatusOr> -DeviceAddressVmmAllocator::Allocate(int device_ordinal, uint64_t size, - bool /*retry_on_failure*/, - int64_t /*memory_space*/) { - if (size == 0) { - return ScopedDeviceAddress(DeviceAddressBase(), device_ordinal, - this); - } - - ASSIGN_OR_RETURN(auto state, GetPerDeviceState(device_ordinal)); - - const bool multi_device = CurrentMultiDevice(); - - absl::MutexLock lock(state->mu); - // Try to reuse a completed pending deallocation with matching size. - std::optional reused = - TryReusePendingDeallocation(*state, size, multi_device); - if (reused.has_value()) { - return ScopedDeviceAddress(*reused, device_ordinal, this); - } - - absl::StatusOr result = - AllocateWithBudget(*state, size, multi_device); - - // If allocation failed (e.g., out of memory), try processing pending - // deallocations to free memory, then retry. - if (!result.ok()) { - ProcessCompletedPendingDeallocations(*state); - result = AllocateWithBudget(*state, size, multi_device); - } - - if (!result.ok()) { - WaitPendingDeallocationsToComplete(*state, size); - result = AllocateWithBudget(*state, size, multi_device); - } - - if (!result.ok()) { - return result.status(); - } - - VLOG(3) << absl::StreamFormat( - "Allocated virtual address %p (%uB) on device ordinal %d", - result->opaque(), size, device_ordinal); + void* allocator_va = allocator_address_reservation->address().opaque(); + auto record = std::make_unique(); + DeviceAddressBase allocator_address(allocator_va, allocation_size); + record->kind = PendingDeallocationKind::kAllocateAndMapReturnNewAddr; + record->allocator_address = allocator_address; + record->raw_allocation = std::move(shared_raw); + record->multi_device = multi_device; + record->allocator_address_reservation = + std::move(allocator_address_reservation); + record->allocator_address_mapping.emplace( + std::move(allocator_address_mapping)); + record->reservation_address = reservation_address; + record->reservation_address_mapping.emplace(std::move(reservation_mapping)); + record->allocator_active = true; + record->reservation_active = true; + AllocationRecord* record_ptr = record.get(); + auto record_insert = state->records_by_allocator_address.emplace( + allocator_va, std::move(record)); + CHECK(record_insert.second); + auto reservation_insert = state->active_reservation_records.emplace( + reservation_address.opaque(), record_ptr); + CHECK(reservation_insert.second); + state->pa_allocated += rounded_size; - return ScopedDeviceAddress(*result, device_ordinal, this); + return ScopedDeviceAddress(allocator_address, device_ordinal, this); } absl::Status DeviceAddressVmmAllocator::Deallocate(int device_ordinal, @@ -525,306 +599,408 @@ absl::Status DeviceAddressVmmAllocator::Deallocate(int device_ordinal, if (record_it == state->records_by_allocator_address.end() || !record_it->second->allocator_active || !record_it->second->allocator_address.IsSameAs(mem)) { - if (state->active_reservation_mappings.contains(mem.opaque()) || - state->stale_reservation_mappings.contains(mem.opaque())) { - return absl::InvalidArgumentError( - "DeviceAddressVmmAllocator::Deallocate does not accept reservation " - "alias addresses; use UnMap instead"); - } - return absl::InvalidArgumentError(absl::StrFormat( - "DeviceAddressVmmAllocator::Deallocate received an unknown address %p", + return absl::NotFoundError(absl::StrFormat( + "virtual address %p is not an active allocator address returned by " + "Allocate()", mem.opaque())); } AllocationRecord& record = *record_it->second; - - for (const auto& [_, mapping] : state->active_reservation_mappings) { - if (mapping.allocator_address.IsSameAs(mem)) { - return absl::FailedPreconditionError( - "DeviceAddressVmmAllocator::Deallocate requires active reservation " - "aliases to be released with UnMap first"); - } + CHECK(!state->active_reservation_records.contains(mem.opaque())); + if (record.reservation_active) { + CHECK(record.reservation_address.has_value()); + return absl::FailedPreconditionError(absl::StrFormat( + "Deallocate() requires the active reservation alias at virtual address " + "%p (%uB) to be released with UnMap() first", + record.reservation_address->opaque(), + record.reservation_address->size())); } VLOG(3) << absl::StreamFormat( "Queueing deferred deallocation for virtual address %p (size=%uB) " "on device ordinal %d", - mem.opaque(), mem.size(), device_ordinal); + mem.opaque(), mem.size(), state->executor->device_ordinal()); + + const uint64_t reclaimable_bytes = + RoundUpToGranularity(*state, record.raw_allocation->address().size()); // Assign the next sequence number and enqueue a GPU write to the pinned // timeline when the stream reaches this point. The CPU polls the timeline // value to know when it is safe to free the memory. uint64_t seqno = state->next_seqno++; RETURN_IF_ERROR(EnqueueDeferredDeallocation(*state, seqno)); - + // Move the returned allocator address out of active ownership and keep its + // mapping alive as stale state until the stream reaches `seqno`. CHECK(record.allocator_active); CHECK(!record.allocator_stale); CHECK(record.allocator_address_mapping.has_value()); + void* allocator_va = record.allocator_address.opaque(); + auto allocator_record_it = + state->records_by_allocator_address.find(allocator_va); + CHECK(allocator_record_it != state->records_by_allocator_address.end()); + CHECK_EQ(allocator_record_it->second.get(), &record); record.allocator_active = false; record.allocator_stale = true; record.allocator_stale_seqno = seqno; - const uint64_t reclaimable_bytes = - RoundUpToGranularity(*state, record.raw_allocation->address().size()); - state->pending_deallocations.push_back( - {record.kind, seqno, record.allocator_address, reclaimable_bytes}); - + state->pending_deallocations.push_back(PendingDeallocation{ + record.kind, seqno, record.allocator_address, reclaimable_bytes}); return absl::OkStatus(); } +// Map helpers. + +absl::StatusOr +DeviceAddressVmmAllocator::ValidateReservationRange( + MemoryReservation* reservation, uint64_t reservation_offset, + uint64_t size) const { + if (reservation == nullptr) { + return absl::InvalidArgumentError("reservation must not be null"); + } + + DeviceAddressBase address = reservation->address(); + if (reservation_offset > address.size() || + size > address.size() - reservation_offset) { + return absl::InvalidArgumentError(absl::StrFormat( + "reservation range is out of bounds: offset=%uB, size=%uB, " + "reservation_size=%uB", + reservation_offset, size, address.size())); + } + + return address.GetByteSlice(reservation_offset, size); +} + absl::Status DeviceAddressVmmAllocator::Map(int device_ordinal, DeviceAddressBase addr, MemoryReservation* reservation, uint64_t reservation_offset, uint64_t size) { + ASSIGN_OR_RETURN(auto state, GetPerDeviceState(device_ordinal)); if (size == 0) { return absl::OkStatus(); } if (addr.is_null()) { - return absl::InvalidArgumentError( - "DeviceAddressVmmAllocator::Map requires a non-null source address"); + return absl::InvalidArgumentError("addr must not be null"); } + + // Map() does not allocate a VA range. It maps the physical allocation backing + // `addr` into the caller-owned reservation slice, so validate the slice + // before taking the allocator lock. ASSIGN_OR_RETURN( DeviceAddressBase reservation_address, ValidateReservationRange(reservation, reservation_offset, size)); - ASSIGN_OR_RETURN(auto state, GetPerDeviceState(device_ordinal)); - absl::MutexLock lock(state->mu); - auto record_it = state->records_by_allocator_address.find(addr.opaque()); - if (record_it == state->records_by_allocator_address.end() || - !record_it->second->allocator_active || - !record_it->second->allocator_address.IsSameAs(addr)) { - return absl::InvalidArgumentError(absl::StrFormat( - "DeviceAddressVmmAllocator::Map received an unknown allocator address " - "%p", - addr.opaque())); - } - AllocationRecord& record = *record_it->second; - MemoryAllocation* raw_allocation = record.raw_allocation.get(); + auto resolve_source_record = + [&]() + ABSL_NO_THREAD_SAFETY_ANALYSIS -> absl::StatusOr { + auto allocation_it = + state->records_by_allocator_address.find(addr.opaque()); + if (allocation_it == state->records_by_allocator_address.end() || + !allocation_it->second->allocator_active || + !allocation_it->second->allocator_address.IsSameAs(addr)) { + return absl::NotFoundError(absl::StrFormat( + "addr %p is not an active allocator address, when trying to " + "do map of VA reservation to existing physical allocation, we " + "requires the buffer being mapped to is being allocated through " + "DeviceAddressVmmAllocator, check the allocator type for the " + "buffer.", + addr.opaque())); + } + return allocation_it->second.get(); + }; + + // Resolve the source address to the raw physical allocation that is currently + // mapped there. Any active allocator address returned by this allocator is + // accepted as a Map() source. + ASSIGN_OR_RETURN(AllocationRecord * source_record, resolve_source_record()); + MemoryAllocation* raw_allocation = source_record->raw_allocation.get(); if (size > raw_allocation->address().size()) { return absl::InvalidArgumentError(absl::StrFormat( - "DeviceAddressVmmAllocator::Map size %u exceeds raw allocation size " - "%u", + "mapping size must not exceed physical allocation size: " + "mapping_size=%uB, allocation_size=%uB", size, raw_allocation->address().size())); } - if (state->active_reservation_mappings.contains( + if (state->active_reservation_records.contains( reservation_address.opaque()) || - state->stale_reservation_mappings.contains( - reservation_address.opaque())) { + state->stale_reservation_records.contains(reservation_address.opaque())) { return absl::FailedPreconditionError( "Reservation address is already tracked by this allocator"); } - for (const auto& [_, mapping] : state->active_reservation_mappings) { - if (mapping.allocator_address.IsSameAs(addr)) { - return absl::FailedPreconditionError( - "Allocator address already has an active reservation alias"); - } + if (source_record->reservation_active) { + return absl::FailedPreconditionError( + "Allocator address already has an active reservation alias"); } - - ASSIGN_OR_RETURN( - MemoryReservation::ScopedMapping scoped_mapping, - reservation->MapTo(reservation_offset, /*allocation_offset=*/0, size, - *raw_allocation)); - state->active_reservation_mappings.emplace( - reservation_address.opaque(), - ReservationMapping{addr, reservation_address, reservation, - reservation_offset, size, std::move(scoped_mapping)}); + if (source_record->reservation_stale) { + return absl::FailedPreconditionError( + "Allocator address already has a pending reservation alias"); + } + + // Install the reservation address mapping to the raw physical allocation. The + // allocation_offset is zero because Map() aliases the beginning of + // the source allocation; callers pass the target VA location through + // `reservation_offset`. + ASSIGN_OR_RETURN(auto mapping, reservation->MapTo(reservation_offset, + /*allocation_offset=*/0, + size, *raw_allocation)); + DeviceAddressBase mapped = mapping.mapped_address(); + // The reservation slice was computed before locking. Verify the platform + // returned the exact reservation address before recording allocator + // bookkeeping. + if (!mapped.IsSameAs(reservation_address)) { + return absl::InternalError(absl::StrFormat( + "Map() mapped unexpected virtual address: expected=%p, actual=%p", + reservation_address.opaque(), mapped.opaque())); + } + // Track this as a Map()-owned reservation alias. This only updates the + // reservation-address index; no new physical allocation is created, so + // pa_allocated does not change. + CHECK(!source_record->reservation_active); + CHECK(!source_record->reservation_stale); + source_record->reservation_address = mapped; + source_record->reservation_address_mapping.emplace(std::move(mapping)); + source_record->reservation_active = true; + auto mapping_insert_result = + state->active_reservation_records.emplace(mapped.opaque(), source_record); + CHECK(mapping_insert_result.second); return absl::OkStatus(); } -absl::Status DeviceAddressVmmAllocator::UnMap(int device_ordinal, - MemoryReservation* reservation, - uint64_t reservation_offset, - uint64_t size) { - if (size == 0) { - return absl::OkStatus(); - } - ASSIGN_OR_RETURN( - DeviceAddressBase reservation_address, - ValidateReservationRange(reservation, reservation_offset, size)); - - ASSIGN_OR_RETURN(auto state, GetPerDeviceState(device_ordinal)); - - absl::MutexLock lock(state->mu); - auto it = - state->active_reservation_mappings.find(reservation_address.opaque()); - if (it == state->active_reservation_mappings.end()) { - return absl::InvalidArgumentError( - "DeviceAddressVmmAllocator::UnMap received an untracked reservation " - "address"); - } - if (it->second.reservation != reservation || - it->second.reservation_offset != reservation_offset || - it->second.size != size) { - return absl::InvalidArgumentError( - "DeviceAddressVmmAllocator::UnMap requires the same full reservation " - "range passed to Map"); - } - - uint64_t seqno = state->next_seqno++; - RETURN_IF_ERROR(EnqueueDeferredDeallocation(*state, seqno)); +// UnMap/deferred teardown helpers. - ReservationMapping mapping = std::move(it->second); - state->active_reservation_mappings.erase(it); - state->stale_reservation_mappings.emplace(reservation_address.opaque(), - std::move(mapping)); - state->pending_deallocations.push_back( - {PendingDeallocationKind::kMap, seqno, reservation_address, - /*reclaimable_bytes=*/0}); - return absl::OkStatus(); +void DeviceAddressVmmAllocator::ErasePendingDeallocationAt( + PerDeviceState& state, std::deque::iterator it) { + CHECK(it != state.pending_deallocations.end()); + state.pending_deallocations.erase(it); } -absl::StatusOr DeviceAddressVmmAllocator::GetStream( - int device_ordinal) { - ASSIGN_OR_RETURN(auto state, GetPerDeviceState(device_ordinal)); - return state->stream; +void DeviceAddressVmmAllocator::MoveAllocatorRecordToActive( + PerDeviceState& state, AllocationRecord& record, uint64_t new_size) { + CHECK(!record.allocator_active); + CHECK(record.allocator_stale); + CHECK(record.allocator_address_mapping.has_value()); + void* allocator_va = record.allocator_address.opaque(); + auto record_it = state.records_by_allocator_address.find(allocator_va); + CHECK(record_it != state.records_by_allocator_address.end()); + CHECK_EQ(record_it->second.get(), &record); + record.allocator_address = DeviceAddressBase(allocator_va, new_size); + record.allocator_active = true; + record.allocator_stale = false; + record.allocator_stale_seqno = 0; } -absl::Status DeviceAddressVmmAllocator::SynchronizePendingOperations( - int device_ordinal) { - ASSIGN_OR_RETURN(auto state, GetPerDeviceState(device_ordinal)); +void DeviceAddressVmmAllocator::MoveReservationRecordToStale( + PerDeviceState& state, AllocationRecord& record, uint64_t seqno) { + CHECK(record.reservation_active); + CHECK(!record.reservation_stale); + CHECK(record.reservation_address.has_value()); + CHECK(record.reservation_address_mapping.has_value()); + void* reservation_va = record.reservation_address->opaque(); + CHECK_EQ(state.active_reservation_records.erase(reservation_va), 1); + auto insert_result = + state.stale_reservation_records.emplace(reservation_va, &record); + CHECK(insert_result.second); + record.reservation_active = false; + record.reservation_stale = true; + record.reservation_stale_seqno = seqno; +} - uint64_t target_seqno; - { - absl::MutexLock lock(state->mu); - if (state->pending_deallocations.empty()) { - return absl::OkStatus(); - } - target_seqno = state->pending_deallocations.back().seqno; +void DeviceAddressVmmAllocator::CompleteStaleReservationMapping( + PerDeviceState& state, AllocationRecord& record) { + if (!record.reservation_stale) { + return; } + CHECK(!record.reservation_active); + CHECK(record.reservation_address.has_value()); + void* reservation_va = record.reservation_address->opaque(); + auto stale_it = state.stale_reservation_records.find(reservation_va); + if (stale_it != state.stale_reservation_records.end()) { + CHECK_EQ(stale_it->second, &record); + state.stale_reservation_records.erase(stale_it); + } + record.reservation_address_mapping.reset(); + record.reservation_address.reset(); + record.reservation_stale = false; + record.reservation_stale_seqno = 0; +} - while (LoadTimeline(state->pinned_timeline) < target_seqno) { - absl::SleepFor(kGpuTimelinePollInterval); - } +absl::Status DeviceAddressVmmAllocator::WaitUntilSeqno(PerDeviceState& state, + uint64_t target_seqno) { + // Release the lock before spin-waiting to avoid stalling other threads for + // potentially milliseconds while the GPU drains its work queue. + state.mu.unlock(); - { - absl::MutexLock lock(state->mu); - while (!state->pending_deallocations.empty() && - state->pending_deallocations.front().seqno <= target_seqno) { - if (state->pending_deallocations.front().kind != - PendingDeallocationKind::kMap) { - DoDeallocate(*state, state->pending_deallocations.front().addr); - } else { - DoUnMap(*state, state->pending_deallocations.front().addr); - } - state->pending_deallocations.pop_front(); - } + // Poll until the GPU writes a timeline value >= target_seqno. + // Since timeline values are written in stream order, this guarantees all + // selected pending operations have completed. + while (LoadTimeline(state.pinned_timeline) < target_seqno) { + absl::SleepFor(kGpuTimelinePollInterval); } + state.mu.lock(); return absl::OkStatus(); } -absl::StatusOr DeviceAddressVmmAllocator::GetStreamExecutor( - int device_ordinal) const { - ASSIGN_OR_RETURN(auto state, GetPerDeviceState(device_ordinal)); - return state->executor; +absl::Status +DeviceAddressVmmAllocator::WaitAndDrainPendingDeallocationsUntilSeqno( + PerDeviceState& state, uint64_t target_seqno) { + RETURN_IF_ERROR(WaitUntilSeqno(state, target_seqno)); + while (!state.pending_deallocations.empty() && + state.pending_deallocations.front().seqno <= target_seqno) { + PendingDeallocation pending = state.pending_deallocations.front(); + state.pending_deallocations.pop_front(); + CompletePendingDeallocation(state, pending); + } + return absl::OkStatus(); } -MemoryAllocation* DeviceAddressVmmAllocator::GetRawAllocation( - int device_ordinal, DeviceAddressBase addr) const { - absl::StatusOr state_or = GetPerDeviceState(device_ordinal); - if (!state_or.ok()) { - return nullptr; +void DeviceAddressVmmAllocator::CompleteReadyAllocatorDeallocationsForReclaim( + PerDeviceState& state, uint64_t completed_seqno) { + std::vector selected; + for (const PendingDeallocation& pending : state.pending_deallocations) { + if (pending.seqno > completed_seqno || + pending.kind == PendingDeallocationKind::kMap) { + continue; + } + selected.push_back( + PendingDeallocationKey{pending.kind, pending.seqno, pending.addr}); } - PerDeviceState* state = *state_or; - absl::MutexLock lock(state->mu); - auto it = state->records_by_allocator_address.find(addr.opaque()); - if (it == state->records_by_allocator_address.end() || - !it->second->allocator_active || - !it->second->allocator_address.IsSameAs(addr)) { - return nullptr; + for (const PendingDeallocationKey& key : selected) { + CompletePendingDeallocationByKey(state, key); } - return it->second->raw_allocation.get(); } -MemoryReservation* DeviceAddressVmmAllocator::GetReservation( - int device_ordinal, DeviceAddressBase addr) const { - absl::StatusOr state_or = GetPerDeviceState(device_ordinal); - if (!state_or.ok()) { - return nullptr; - } - PerDeviceState* state = *state_or; - absl::MutexLock lock(state->mu); - auto it = state->records_by_allocator_address.find(addr.opaque()); - if (it == state->records_by_allocator_address.end() || - !it->second->allocator_active || - !it->second->allocator_address.IsSameAs(addr)) { - return nullptr; +bool DeviceAddressVmmAllocator::CompletePendingDeallocationByKey( + PerDeviceState& state, const PendingDeallocationKey& key) { + for (auto it = state.pending_deallocations.begin(); + it != state.pending_deallocations.end(); ++it) { + if (it->kind == key.kind && it->seqno == key.seqno && + it->addr.IsSameAs(key.addr)) { + PendingDeallocation pending = *it; + state.pending_deallocations.erase(it); + CompletePendingDeallocation(state, pending); + return true; + } } - return it->second->allocator_address_reservation.get(); + return false; } -uint64_t DeviceAddressVmmAllocator::GetAllocationGranularity( - StreamExecutor* executor) const { - absl::StatusOr state_or = - GetPerDeviceState(executor->device_ordinal()); - if (!state_or.ok()) { - return 0; +void DeviceAddressVmmAllocator::CompletePendingDeallocation( + PerDeviceState& state, const PendingDeallocation& pending) { + if (pending.kind == PendingDeallocationKind::kMap) { + auto record_it = + state.stale_reservation_records.find(pending.addr.opaque()); + CHECK(record_it != state.stale_reservation_records.end()); + CHECK_EQ(record_it->second->reservation_stale_seqno, pending.seqno); + CompleteStaleReservationMapping(state, *record_it->second); + return; } - PerDeviceState* state = *state_or; - return state->allocation_granularity; -} -void DeviceAddressVmmAllocator::MoveAllocatorRecordToActive( - PerDeviceState& state, AllocationRecord& record, uint64_t new_size) { - CHECK(!record.allocator_active); - CHECK(record.allocator_stale); - CHECK(record.allocator_address_mapping.has_value()); - void* allocator_va = record.allocator_address.opaque(); - auto record_it = state.records_by_allocator_address.find(allocator_va); + auto record_it = + state.records_by_allocator_address.find(pending.addr.opaque()); CHECK(record_it != state.records_by_allocator_address.end()); - CHECK_EQ(record_it->second.get(), &record); - record.allocator_address = DeviceAddressBase(allocator_va, new_size); - record.allocator_active = true; - record.allocator_stale = false; - record.allocator_stale_seqno = 0; -} - -std::optional -DeviceAddressVmmAllocator::TryReusePendingDeallocation(PerDeviceState& state, - uint64_t size, - bool multi_device) { - uint64_t rounded_size = RoundUpToGranularity(state, size); - for (auto it = state.pending_deallocations.begin(); - it != state.pending_deallocations.end(); ++it) { - if (it->kind != PendingDeallocationKind::kAllocate) { - continue; - } - auto record_it = state.records_by_allocator_address.find(it->addr.opaque()); - CHECK(record_it != state.records_by_allocator_address.end()); - AllocationRecord& record = *record_it->second; - CHECK(record.allocator_stale); - CHECK(record.allocator_address.IsSameAs(it->addr)); - if (record.multi_device != multi_device) { - continue; - } - if (RoundUpToGranularity(state, record.allocator_address.size()) != - rounded_size) { - continue; + CHECK(record_it->second->allocator_stale); + CHECK(record_it->second->allocator_address.IsSameAs(pending.addr)); + CHECK_EQ(record_it->second->kind, pending.kind); + CHECK_EQ(record_it->second->allocator_stale_seqno, pending.seqno); + // Complete allocator-address teardown. If this allocation still has an + // explicitly unmapped stale reservation alias, drop that mapping first, then + // release allocator VA state and physical allocation accounting. + AllocationRecord& record = *record_it->second; + CHECK(!record.allocator_active); + CHECK(!record.reservation_active); + if (record.reservation_stale) { + CHECK(record.reservation_address.has_value()); + PendingDeallocationKey reservation_key{PendingDeallocationKind::kMap, + record.reservation_stale_seqno, + *record.reservation_address}; + for (auto it = state.pending_deallocations.begin(); + it != state.pending_deallocations.end(); ++it) { + if (it->kind == reservation_key.kind && + it->seqno == reservation_key.seqno && + it->addr.IsSameAs(reservation_key.addr)) { + state.pending_deallocations.erase(it); + break; + } } + CompleteStaleReservationMapping(state, record); + } + void* allocator_va = record.allocator_address.opaque(); + auto owning_record_it = state.records_by_allocator_address.find(allocator_va); + CHECK(owning_record_it != state.records_by_allocator_address.end()); + CHECK_EQ(owning_record_it->second.get(), &record); + record.allocator_address_mapping.reset(); + record.allocator_address_reservation.reset(); - DeviceAddressBase reused_mem(record.allocator_address.opaque(), size); - VLOG(3) << absl::StreamFormat( - "Reusing pending deallocation: address=%p original_size=%uB " - "new_size=%uB rounded_size=%uB device=%d", - reused_mem.opaque(), record.allocator_address.size(), size, rounded_size, - state.executor->device_ordinal()); - MoveAllocatorRecordToActive(state, record, size); - state.pending_deallocations.erase(it); + if (record.raw_allocation != nullptr) { + uint64_t released_size = + RoundUpToGranularity(state, record.raw_allocation->address().size()); + DCHECK_GE(state.pa_allocated, released_size); + state.pa_allocated -= released_size; + } + record.raw_allocation.reset(); + CHECK_EQ(state.records_by_allocator_address.erase(allocator_va), 1); +} - return reused_mem; +absl::Status DeviceAddressVmmAllocator::UnMap(int device_ordinal, + MemoryReservation* reservation, + uint64_t reservation_offset, + uint64_t size) { + ASSIGN_OR_RETURN(auto state, GetPerDeviceState(device_ordinal)); + if (size == 0) { + return absl::OkStatus(); } - return std::nullopt; -} + // Map() and Allocate(..., return_reservation_address=false) record + // reservation mappings by the mapped reservation VA. Reconstruct the same + // reservation slice here so callers do not need to hold a ScopedMapping. + ASSIGN_OR_RETURN( + DeviceAddressBase reservation_address, + ValidateReservationRange(reservation, reservation_offset, size)); -uint64_t DeviceAddressVmmAllocator::RoundUpToGranularity( - const PerDeviceState& state, uint64_t size) const { - if (state.allocation_granularity == 0) { - return size; + absl::MutexLock lock(state->mu); + // UnMap() only accepts the exact active reservation range previously created + // by Map() or Allocate(..., return_reservation_address=false). Allocator + // addresses and subranges are not valid UnMap() inputs. + auto active_it = + state->active_reservation_records.find(reservation_address.opaque()); + if (active_it == state->active_reservation_records.end()) { + auto stale_it = + state->stale_reservation_records.find(reservation_address.opaque()); + if (stale_it != state->stale_reservation_records.end()) { + CHECK(stale_it->second->reservation_address.has_value()); + } + if (stale_it != state->stale_reservation_records.end() && + stale_it->second->reservation_address->IsSameAs(reservation_address)) { + return absl::FailedPreconditionError(absl::StrFormat( + "reservation range at virtual address %p (%uB) is already pending " + "UnMap()", + reservation_address.opaque(), reservation_address.size())); + } + return absl::NotFoundError(absl::StrFormat( + "UnMap() requires an exact active reservation range created by Map() " + "or Allocate(..., return_reservation_address=false): virtual address " + "%p (%uB)", + reservation_address.opaque(), reservation_address.size())); + } + AllocationRecord* record = active_it->second; + CHECK(record->reservation_active); + CHECK(!record->reservation_stale); + CHECK(record->reservation_address.has_value()); + if (!record->reservation_address->IsSameAs(reservation_address)) { + return absl::InvalidArgumentError( + "DeviceAddressVmmAllocator::UnMap requires the same full reservation " + "range passed to Map"); } - return ((size + state.allocation_granularity - 1) / - state.allocation_granularity) * - state.allocation_granularity; + CHECK(record->reservation_address_mapping.has_value()); + CHECK(record->reservation_address_mapping->mapped_address().IsSameAs( + reservation_address)); + + uint64_t seqno = state->next_seqno++; + RETURN_IF_ERROR(EnqueueDeferredDeallocation(*state, seqno)); + MoveReservationRecordToStale(*state, *record, seqno); + state->pending_deallocations.push_back( + PendingDeallocation{PendingDeallocationKind::kMap, seqno, + reservation_address, /*reclaimable_bytes=*/0}); + return absl::OkStatus(); } } // namespace stream_executor diff --git a/xla/stream_executor/device_address_vmm_allocator.h b/xla/stream_executor/device_address_vmm_allocator.h index b0883b9a50667..62676718d3532 100644 --- a/xla/stream_executor/device_address_vmm_allocator.h +++ b/xla/stream_executor/device_address_vmm_allocator.h @@ -108,24 +108,28 @@ namespace stream_executor { // only reservation addresses created by Map() or by // Allocate(..., return_reservation_address=false). Passing an allocator address // to UnMap(), or a reservation address to Deallocate(), is an error. -// Each allocator address may have at most one active reservation-address alias. +// Each allocator address may have at most one active or stale +// reservation-address alias. // // Deallocate() and UnMap() are stream-ordered deferred operations. The -// allocator records a per-device sequence number for the affected address and -// keeps the old mapping or allocation alive until the stream reaches that -// sequence number, so kernels already submitted to the stream can keep using -// the old VA. When the sequence completes, dropping the ScopedMapping objects -// performs the real unmap, then the allocator releases any owned reservation -// and raw physical memory. +// allocator assigns the affected address record a per-device sequence number, +// moves it from active tracking to stale tracking, and appends a pending entry +// with the operation kind, sequence number, and address. The stale +// AllocationRecord keeps the raw allocation, any allocator-owned reservation, +// and ScopedMapping objects alive until the stream reaches that sequence +// number, so kernels already submitted to the stream can keep using the old VA. +// When the sequence completes, dropping the ScopedMapping objects performs the +// real unmap, then the allocator releases any owned reservation and raw +// physical memory. // -// Concrete subclasses implement the platform-specific virtual methods +// Each registered device has independent state protected by its own mutex, so +// operations on different devices can proceed in parallel. The per-device map +// is populated at construction time and is not modified afterward. Concrete +// subclasses implement the platform-specific virtual methods // (InitializeDeviceState, CreateAllocation, CreateReservation, // EnqueueDeferredDeallocation) and expose platform-specific Create() factories. // Subclasses must also set PerDeviceState::destroy_fn in InitializeDeviceState // to release platform-specific resources such as pinned timeline memory. -// -// This allocator is thread-safe for concurrent use by multiple threads across -// any registered devices. class DeviceAddressVmmAllocator : public DeviceAddressAllocator { public: // Per-device configuration supplied at construction. @@ -221,7 +225,8 @@ class DeviceAddressVmmAllocator : public DeviceAddressAllocator { // range until all previously enqueued work on the allocator stream has // completed. // The caller must pass the same full reservation range that created the - // mapping. The reservation-derived allocator address returned by + // mapping. + // The reservation-derived allocator address returned by // Allocate(..., return_reservation_address=true) is not a reservation // address for this API and must be released with Deallocate() instead. // @@ -285,32 +290,14 @@ class DeviceAddressVmmAllocator : public DeviceAddressAllocator { kMap, }; - struct PendingDeallocation { - PendingDeallocationKind kind = PendingDeallocationKind::kAllocate; - // GPU stream sequence number recorded at deallocation time. When the - // pinned_timeline value reaches this seqno, the memory is safe to free. - uint64_t seqno = 0; - // Allocator address for allocation deallocations; reservation address for - // kMap. - DeviceAddressBase addr; - // Rounded physical-allocation bytes that become reclaimable when this - // pending operation completes. kMap entries do not own physical memory and - // therefore use zero. - uint64_t reclaimable_bytes = 0; - }; - - struct ReservationMapping { - DeviceAddressBase allocator_address; - DeviceAddressBase reservation_address; - MemoryReservation* reservation = nullptr; - uint64_t reservation_offset = 0; - uint64_t size = 0; - MemoryReservation::ScopedMapping scoped_mapping; - }; - - // Lifetime record for one raw physical allocation. The initial migration adds - // this next to the existing maps; later changes move allocator and reservation - // address ownership into this record. + // Lifetime record for one raw physical allocation. + // + // The record is owned by records_by_allocator_address while either the + // allocator address or a reservation-address alias is active, stale, or + // pending completion. Active indexes are callable by public APIs. Stale + // indexes are no longer callable by users, but still keep mappings alive + // until the stream-ordered deferred operation completes or a later Allocate() + // or Map() reuses them. struct AllocationRecord { PendingDeallocationKind kind = PendingDeallocationKind::kAllocate; DeviceAddressBase allocator_address; @@ -327,14 +314,55 @@ class DeviceAddressVmmAllocator : public DeviceAddressAllocator { std::optional reservation_address; std::optional reservation_address_mapping; + // Allocator address state. Every live AllocationRecord has an allocator + // address, and exactly one of allocator_active/allocator_stale is true + // until the deferred allocator-address deallocation completes and the + // record is destroyed. bool allocator_active = false; bool allocator_stale = false; + + // Reservation-address alias state. A record may have no reservation alias; + // in that case reservation_address/reservation_address_mapping are empty + // and both flags are false. If a reservation alias exists, exactly one of + // reservation_active/reservation_stale is true until the deferred unmap + // completes or the alias is reactivated. bool reservation_active = false; bool reservation_stale = false; + + // Valid only while the corresponding address is stale. These seqnos + // identify the stream point at which it is safe to destroy the stale + // mapping and, for allocator addresses, release the raw physical + // allocation. uint64_t allocator_stale_seqno = 0; uint64_t reservation_stale_seqno = 0; }; + // Queue entry for a stream-ordered deferred operation. The heavy resources + // live in AllocationRecord; this entry only says which stale address becomes + // safe to complete when the GPU timeline reaches `seqno`. + struct PendingDeallocation { + PendingDeallocationKind kind = PendingDeallocationKind::kAllocate; + // GPU stream sequence number recorded at deallocation time. When the + // pinned_timeline value reaches this seqno, the memory is safe to free. + uint64_t seqno = 0; + // Allocator address for allocation deallocations; reservation address for + // kMap. + DeviceAddressBase addr; + // Rounded physical-allocation bytes that become reclaimable when this + // pending operation completes. kMap entries do not own physical memory and + // therefore use zero. + uint64_t reclaimable_bytes = 0; + }; + + // Stable identity for a pending operation. Iterators into + // pending_deallocations must not be kept across waits because + // WaitUntilSeqno() releases state.mu. + struct PendingDeallocationKey { + PendingDeallocationKind kind = PendingDeallocationKind::kAllocate; + uint64_t seqno = 0; + DeviceAddressBase addr; + }; + struct PerDeviceState { StreamExecutor* executor; Stream* stream; @@ -366,11 +394,21 @@ class DeviceAddressVmmAllocator : public DeviceAddressAllocator { // Monotonically increasing counter for timeline sequence numbers. uint64_t next_seqno ABSL_GUARDED_BY(mu) = 1; std::deque pending_deallocations ABSL_GUARDED_BY(mu); + // Owns AllocationRecord objects. Key is the allocator address pointer + // (`AllocationRecord::allocator_address.opaque()`), including the + // reservation-derived allocator address returned by + // Allocate(..., return_reservation_address=true). Allocator-address + // active/stale state is stored in + // AllocationRecord::allocator_active/allocator_stale. absl::flat_hash_map> records_by_allocator_address ABSL_GUARDED_BY(mu); - absl::flat_hash_map active_reservation_mappings + + // Active/stale reservation-address indexes. Keys are reservation alias + // pointers (`AllocationRecord::reservation_address->opaque()`) created by + // Map() or by Allocate(..., return_reservation_address=false). + absl::flat_hash_map active_reservation_records ABSL_GUARDED_BY(mu); - absl::flat_hash_map stale_reservation_mappings + absl::flat_hash_map stale_reservation_records ABSL_GUARDED_BY(mu); }; @@ -385,7 +423,7 @@ class DeviceAddressVmmAllocator : public DeviceAddressAllocator { static absl::Status PopulateDevices(DeviceAddressVmmAllocator* allocator, absl::Span devices); - // Drains pending stream-ordered allocator operations for all devices. + // Drains all pending operations for all devices. absl::Status SynchronizeAllPendingOperations(); // Validates device capabilities and initializes timeline fields @@ -402,16 +440,25 @@ class DeviceAddressVmmAllocator : public DeviceAddressAllocator { uint64_t seqno) = 0; private: + // Common helpers. + // Returns pointer into per_device_ map, or NotFound if device_ordinal is not // registered. No lock needed: per_device_ is read-only after construction. absl::StatusOr GetPerDeviceState(int device_ordinal) const; - // Validates a caller-owned reservation slice and returns the corresponding - // DeviceAddressBase. - absl::StatusOr ValidateReservationRange( - MemoryReservation* reservation, uint64_t reservation_offset, - uint64_t size) const; + // Round up size to the device's allocation granularity. + uint64_t RoundUpToGranularity(const PerDeviceState& state, + uint64_t size) const; + // True iff the calling thread is inside a multi-device DeviceAssignmentScope. + static bool CurrentMultiDevice(); + + // Allocate helpers. + + // Records a raw allocation mapped at an owning allocator address. Takes + // ownership of `reservation` when the allocator address was allocator-owned; + // reservation-backed returned addresses pass nullptr here. Charges + // `allocated_size` to the PA budget and returns the allocator VA pointer. void* TrackAllocatorAddressMappedAllocation( PerDeviceState& state, PendingDeallocationKind kind, DeviceAddressBase allocator_address, @@ -420,53 +467,83 @@ class DeviceAddressVmmAllocator : public DeviceAddressAllocator { MemoryReservation::ScopedMapping mapping, uint64_t allocated_size, bool multi_device) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state.mu); - absl::StatusOr AllocateWithBudget(PerDeviceState& state, - uint64_t size, - bool multi_device) + // Shared allocation retry policy. First calls `try_reuse` to reactivate + // compatible pending state without blocking, then calls `try_fresh`. On + // ResourceExhausted, it completes ready pending entries and, if needed, waits + // for enough pending frees to reclaim approximately `reclaim_size` bytes. + template + absl::StatusOr TryWithPendingReclaim(PerDeviceState& state, + uint64_t reclaim_size, + TryReuseFn try_reuse, + TryFreshFn try_fresh) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(state.mu); + + // Map helpers. + + // Validates a caller-owned reservation slice and returns the corresponding + // DeviceAddressBase. Rejects null reservations and out-of-bounds + // offset/size pairs before any allocator bookkeeping is mutated. + absl::StatusOr ValidateReservationRange( + MemoryReservation* reservation, uint64_t reservation_offset, + uint64_t size) const; + + // UnMap/deferred teardown helpers. + + // Removes a pending entry when a stale record is reused. + void ErasePendingDeallocationAt(PerDeviceState& state, + std::deque::iterator it) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state.mu); - // Process any pending deallocations whose timeline sequence numbers have - // been passed by the GPU. - void ProcessCompletedPendingDeallocations(PerDeviceState& state) + void MoveAllocatorRecordToActive(PerDeviceState& state, + AllocationRecord& record, uint64_t new_size) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state.mu); - // Wait for enough pending deallocations to complete to free at least 'size' - // bytes. Selects deallocations from the front of the queue until their - // cumulative size meets or exceeds the requested size, then spin-waits on - // the GPU timeline counter and performs the deallocations. - // Temporarily releases and reacquires state.mu around the blocking wait. - void WaitPendingDeallocationsToComplete(PerDeviceState& state, uint64_t size) + void MoveReservationRecordToStale(PerDeviceState& state, + AllocationRecord& record, uint64_t seqno) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state.mu); - // Actually perform the synchronous deallocation. - void DoDeallocate(PerDeviceState& state, DeviceAddressBase mem) + void CompleteStaleReservationMapping(PerDeviceState& state, + AllocationRecord& record) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state.mu); - // Actually perform the synchronous unmap for a stale reservation alias. - void DoUnMap(PerDeviceState& state, DeviceAddressBase mem) + // Waits for the device timeline to reach `target_seqno`. Temporarily releases + // and reacquires state.mu around the blocking wait. This does not complete + // pending entries by itself. + absl::Status WaitUntilSeqno(PerDeviceState& state, uint64_t target_seqno) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state.mu); - void MoveAllocatorRecordToActive(PerDeviceState& state, - AllocationRecord& record, uint64_t new_size) + // Waits for pending operations through `target_seqno`, then completes all + // still-pending operations up to that sequence. Used only when preserving + // stale mappings for future reuse is no longer useful. + absl::Status WaitAndDrainPendingDeallocationsUntilSeqno(PerDeviceState& state, + uint64_t target_seqno) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state.mu); - // Try to reuse a pending deallocation with matching rounded size. - // Returns the reused address if found, or std::nullopt if no match. - // Reuse is safe because any new work submitted after Allocate() returns is - // enqueued on the same stream after the recorded deallocation event, so GPU - // stream ordering guarantees the old work finishes before the new work runs. - std::optional TryReusePendingDeallocation( - PerDeviceState& state, uint64_t size, bool multi_device) + // Completes ready allocator-address deallocations for PA reclaim while + // leaving unrelated kMap entries stale and reusable. + void CompleteReadyAllocatorDeallocationsForReclaim(PerDeviceState& state, + uint64_t completed_seqno) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state.mu); - // True iff the calling thread is inside a multi-device DeviceAssignmentScope. - static bool CurrentMultiDevice(); + // Completes a pending operation whose stream sequence has passed by dropping + // its ScopedMappings, allocator-owned reservation, and raw allocation + // reference. This is where VA unmap, reservation release, and PA budget + // accounting happen. + void CompletePendingDeallocation(PerDeviceState& state, + const PendingDeallocation& pending) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(state.mu); - // Round up size to the device's allocation granularity. - uint64_t RoundUpToGranularity(const PerDeviceState& state, - uint64_t size) const; + // Finds, erases, and completes the selected pending entry if it is still + // present. Returns false if another thread already reused or completed it + // while state.mu was released. + bool CompletePendingDeallocationByKey(PerDeviceState& state, + const PendingDeallocationKey& key) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(state.mu); - // Populated at construction; never modified. Safe to read without a lock. + // Device ordinal -> per-device allocator state. Populated at construction by + // PopulateDevices() and never modified afterward, so map lookup is safe + // without an allocator-wide lock. Each PerDeviceState owns its own mutex for + // mutable allocation and pending-deallocation state. absl::flat_hash_map> per_device_; };