diff --git a/include/cachinglayer/LoadingOverheadTracker.h b/include/cachinglayer/LoadingOverheadTracker.h index ee25823..e23a965 100644 --- a/include/cachinglayer/LoadingOverheadTracker.h +++ b/include/cachinglayer/LoadingOverheadTracker.h @@ -12,9 +12,12 @@ #pragma once #include +#include #include +#include #include #include +#include #include "cachinglayer/Utils.h" #include "log/Log.h" @@ -47,52 +50,115 @@ class LoadingOverheadTracker { // Each Register increments a ref count; call Unregister to decrement. uint64_t Register(const std::string& group, const ResourceUsage& upper_bound) { - std::lock_guard lock(mtx_); - auto it = name_to_handle_.find(group); - if (it != name_to_handle_.end()) { - auto& state = handle_state_[it->second]; - state.ref_count++; - if (state.upper_bound == kUnlimited) { - state.upper_bound = upper_bound; - LOG_INFO("[MCL] LoadingOverheadTracker set UB for group '{}' (handle {}, refs={}): {}", group, - it->second, state.ref_count, upper_bound.ToString()); - } else if (state.upper_bound.memory_bytes < upper_bound.memory_bytes || - state.upper_bound.file_bytes < upper_bound.file_bytes) { + enum class LogEvent { + kNone, + kSetUpperBound, + kMismatch, + kReregister, + kNewGroup, + }; + + uint64_t handle = kInvalidHandle; + uint64_t ref_count = 0; + ResourceUsage old_upper_bound{}; + ResourceUsage effective_upper_bound{}; + LogEvent log_event = LogEvent::kNone; + + auto update_existing_state = [&](uint64_t existing_handle, const std::shared_ptr& state) { + handle = existing_handle; + std::lock_guard state_lock(state->mtx); + + state->ref_count++; + ref_count = state->ref_count; + old_upper_bound = state->upper_bound; + if (state->upper_bound == kUnlimited) { + state->upper_bound = upper_bound; + effective_upper_bound = state->upper_bound; + log_event = LogEvent::kSetUpperBound; + } else if (state->upper_bound.memory_bytes < upper_bound.memory_bytes || + state->upper_bound.file_bytes < upper_bound.file_bytes) { + state->upper_bound.memory_bytes = std::max(state->upper_bound.memory_bytes, upper_bound.memory_bytes); + state->upper_bound.file_bytes = std::max(state->upper_bound.file_bytes, upper_bound.file_bytes); + effective_upper_bound = state->upper_bound; + log_event = LogEvent::kMismatch; + } else { + effective_upper_bound = state->upper_bound; + log_event = LogEvent::kReregister; + } + }; + + { + std::shared_lock map_lock(map_mtx_); + auto it = name_to_handle_.find(group); + if (it != name_to_handle_.end()) { + auto state_it = handle_state_.find(it->second); + AssertInfo(state_it != handle_state_.end(), + "[MCL] LoadingOverheadTracker group '{}' has no state for handle {}", group, it->second); + update_existing_state(it->second, state_it->second); + } + } + + if (handle == kInvalidHandle) { + std::unique_lock map_lock(map_mtx_); + auto it = name_to_handle_.find(group); + if (it != name_to_handle_.end()) { + auto state_it = handle_state_.find(it->second); + AssertInfo(state_it != handle_state_.end(), + "[MCL] LoadingOverheadTracker group '{}' has no state for handle {}", group, it->second); + update_existing_state(it->second, state_it->second); + } else { + handle = next_handle_++; + auto state = std::make_shared(upper_bound, group); + name_to_handle_[group] = handle; + handle_state_[handle] = std::move(state); + ref_count = 1; + effective_upper_bound = upper_bound; + log_event = LogEvent::kNewGroup; + } + } + + switch (log_event) { + case LogEvent::kSetUpperBound: + LOG_INFO("[MCL] LoadingOverheadTracker set UB for group '{}' (handle {}, refs={}): {}", group, handle, + ref_count, effective_upper_bound.ToString()); + break; + case LogEvent::kMismatch: LOG_WARN( "[MCL] LoadingOverheadTracker UB mismatch for group '{}' (handle {}): existing={}, new={}. " - "Taking max per dimension.", - group, it->second, state.upper_bound.ToString(), upper_bound.ToString()); - state.upper_bound.memory_bytes = std::max(state.upper_bound.memory_bytes, upper_bound.memory_bytes); - state.upper_bound.file_bytes = std::max(state.upper_bound.file_bytes, upper_bound.file_bytes); - } else { + "Taking max per dimension, effective={}.", + group, handle, old_upper_bound.ToString(), upper_bound.ToString(), + effective_upper_bound.ToString()); + break; + case LogEvent::kReregister: LOG_DEBUG("[MCL] LoadingOverheadTracker re-registered group '{}' (handle {}, refs={}), UB unchanged", - group, it->second, state.ref_count); - } - return it->second; + group, handle, ref_count); + break; + case LogEvent::kNewGroup: + LOG_INFO("[MCL] LoadingOverheadTracker registered group '{}' (handle {}, refs=1): UB={}", group, handle, + effective_upper_bound.ToString()); + break; + case LogEvent::kNone: + break; } - auto handle = next_handle_++; - name_to_handle_[group] = handle; - handle_state_[handle] = GroupState{upper_bound, {}, {}, 1, group}; - LOG_INFO("[MCL] LoadingOverheadTracker registered group '{}' (handle {}, refs=1): UB={}", group, handle, - upper_bound.ToString()); return handle; } // Called before loading. Returns the delta to reserve from DList for loading overhead. ResourceUsage Reserve(uint64_t handle, const ResourceUsage& loading_overhead) { - std::lock_guard lock(mtx_); + std::shared_lock map_lock(map_mtx_); auto it = handle_state_.find(handle); if (it == handle_state_.end()) { return loading_overhead; } - auto& state = it->second; - state.sum_of_overhead += loading_overhead; - auto target = cappedAmount(state.sum_of_overhead, state.upper_bound); - auto delta = target - state.overhead_reserved; + auto state = it->second; + std::lock_guard state_lock(state->mtx); + state->sum_of_overhead += loading_overhead; + auto target = cappedAmount(state->sum_of_overhead, state->upper_bound); + auto delta = target - state->overhead_reserved; delta.memory_bytes = std::max(delta.memory_bytes, int64_t{0}); delta.file_bytes = std::max(delta.file_bytes, int64_t{0}); - state.overhead_reserved += delta; + state->overhead_reserved += delta; return delta; } @@ -100,44 +166,52 @@ class LoadingOverheadTracker { // Returns the delta to release from DList for loading overhead. ResourceUsage Release(uint64_t handle, const ResourceUsage& loading_overhead) { - std::lock_guard lock(mtx_); + std::shared_lock map_lock(map_mtx_); auto it = handle_state_.find(handle); if (it == handle_state_.end()) { return loading_overhead; } - auto& state = it->second; - state.sum_of_overhead -= loading_overhead; - if (state.sum_of_overhead.memory_bytes < 0) { + auto state = it->second; + std::lock_guard state_lock(state->mtx); + state->sum_of_overhead -= loading_overhead; + if (state->sum_of_overhead.memory_bytes < 0) { LOG_ERROR("[MCL] LoadingOverheadTracker Release handle {}: sum_of_overhead.memory_bytes < 0", handle); - state.sum_of_overhead.memory_bytes = 0; + state->sum_of_overhead.memory_bytes = 0; } - if (state.sum_of_overhead.file_bytes < 0) { + if (state->sum_of_overhead.file_bytes < 0) { LOG_ERROR("[MCL] LoadingOverheadTracker Release handle {}: sum_of_overhead.file_bytes < 0", handle); - state.sum_of_overhead.file_bytes = 0; + state->sum_of_overhead.file_bytes = 0; } - auto target = cappedAmount(state.sum_of_overhead, state.upper_bound); - auto delta = state.overhead_reserved - target; + auto target = cappedAmount(state->sum_of_overhead, state->upper_bound); + auto delta = state->overhead_reserved - target; delta.memory_bytes = std::max(delta.memory_bytes, int64_t{0}); delta.file_bytes = std::max(delta.file_bytes, int64_t{0}); - state.overhead_reserved -= delta; + state->overhead_reserved -= delta; return delta; } bool HasFiniteUpperBound(uint64_t handle) const { - std::lock_guard lock(mtx_); + std::shared_lock map_lock(map_mtx_); auto it = handle_state_.find(handle); - return it != handle_state_.end() && !(it->second.upper_bound == kUnlimited); + if (it == handle_state_.end()) { + return false; + } + auto state = it->second; + std::lock_guard state_lock(state->mtx); + return !(state->upper_bound == kUnlimited); } ResourceUsage GetUpperBound(uint64_t handle) const { - std::lock_guard lock(mtx_); + std::shared_lock map_lock(map_mtx_); auto it = handle_state_.find(handle); if (it == handle_state_.end()) { return kUnlimited; } - return it->second.upper_bound; + auto state = it->second; + std::lock_guard state_lock(state->mtx); + return state->upper_bound; } // Decrement ref count for a group. When ref count reaches 0, the group is @@ -147,34 +221,64 @@ class LoadingOverheadTracker { if (handle == kInvalidHandle) { return; } - std::lock_guard lock(mtx_); - auto it = handle_state_.find(handle); - if (it == handle_state_.end()) { - return; - } - auto& state = it->second; - if (state.ref_count > 0) { - state.ref_count--; + + uint64_t ref_count = 0; + bool should_log_unregister = false; + bool should_log_decrement = false; + bool has_residual = false; + ResourceUsage residual_sum{}; + ResourceUsage residual_reserved{}; + std::string group_name; + + { + std::unique_lock map_lock(map_mtx_); + auto it = handle_state_.find(handle); + if (it == handle_state_.end()) { + return; + } + auto state = it->second; + std::lock_guard state_lock(state->mtx); + + if (state->ref_count > 0) { + state->ref_count--; + } + ref_count = state->ref_count; + if (ref_count > 0) { + group_name = state->group_name; + should_log_decrement = true; + } else { + group_name = state->group_name; + residual_sum = state->sum_of_overhead; + residual_reserved = state->overhead_reserved; + has_residual = residual_sum.AnyGTZero() || residual_reserved.AnyGTZero(); + name_to_handle_.erase(group_name); + handle_state_.erase(it); + should_log_unregister = true; + } } - if (state.ref_count > 0) { - LOG_DEBUG("[MCL] LoadingOverheadTracker handle {} ref_count decremented to {}", handle, state.ref_count); + + if (should_log_decrement) { + LOG_DEBUG("[MCL] LoadingOverheadTracker handle {} ref_count decremented to {}", handle, ref_count); return; } - // ref_count == 0, unconditionally clean up. - // Log error if there are residual reservations — indicates a Reserve/Release pairing bug. - if (state.sum_of_overhead.AnyGTZero() || state.overhead_reserved.AnyGTZero()) { + + if (has_residual) { LOG_ERROR( "[MCL] LoadingOverheadTracker handle {} ref_count=0 with residual reservations: " "sum_of_overhead={}, overhead_reserved={}. Cleaning up anyway to avoid leak.", - handle, state.sum_of_overhead.ToString(), state.overhead_reserved.ToString()); + handle, residual_sum.ToString(), residual_reserved.ToString()); + } + if (should_log_unregister) { + LOG_INFO("[MCL] LoadingOverheadTracker unregistered group '{}' (handle {})", group_name, handle); } - LOG_INFO("[MCL] LoadingOverheadTracker unregistered group '{}' (handle {})", state.group_name, handle); - name_to_handle_.erase(state.group_name); - handle_state_.erase(it); } private: struct GroupState { + GroupState(ResourceUsage ub, std::string group) : upper_bound(ub), ref_count(1), group_name(std::move(group)) { + } + + mutable std::mutex mtx; ResourceUsage upper_bound; ResourceUsage sum_of_overhead; ResourceUsage overhead_reserved; @@ -188,9 +292,9 @@ class LoadingOverheadTracker { std::min(std::max(sum.file_bytes, int64_t{0}), ub.file_bytes)}; } - mutable std::mutex mtx_; + mutable std::shared_mutex map_mtx_; std::unordered_map name_to_handle_; - std::unordered_map handle_state_; + std::unordered_map> handle_state_; uint64_t next_handle_{1}; };