Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,17 @@ class SynchronizedShardedMap {
std::size_t numShards,
std::size_t block_size,
std::size_t block_alignment,
std::size_t blocks_per_chunk = 8192)
std::size_t blocks_per_chunk = 8192,
bool enable_dirty_tracking = false)
: shards_(numShards), mempools_(numShards) {
// Init mempools_
for (auto& pool : mempools_) {
pool = std::make_unique<PoolType>(
block_size, block_alignment, blocks_per_chunk);
block_size,
block_alignment,
blocks_per_chunk,
std::pmr::new_delete_resource(),
enable_dirty_tracking);
}
}

Expand Down
78 changes: 75 additions & 3 deletions fbgemm_gpu/src/dram_kv_embedding_cache/fixed_block_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <c10/util/bit_cast.h>

#include <array>
#include <bit>
#include <cassert>
#include <cmath>
Expand All @@ -19,12 +20,33 @@
#include <mutex>
#include <numeric>
#include <stdexcept>
#include <unordered_set>
#include <vector>

namespace kv_mem {
static constexpr uint32_t kMaxInt31Counter = 2147483647;

class FixedBlockPool : public std::pmr::memory_resource {
private:
struct DirtyTrackerStripe {
std::mutex mutex;
std::unordered_set<const void*> dirty_blocks;
};

static constexpr size_t kDirtyTrackerStripeCount = 64;

// Per-pool dirty tracker. Striped to reduce contention between threads
// operating on the same pool. It lives and dies with the pool, so destroying
// the pool automatically discards all dirty state for its blocks (no manual
// cleanup needed) and block addresses can never alias another pool's blocks.
std::array<DirtyTrackerStripe, kDirtyTrackerStripeCount>
dirty_tracker_stripes_;

static size_t get_dirty_tracker_stripe_index(const void* block) {
auto block_address = reinterpret_cast<uintptr_t>(block);
return (block_address >> 6) % kDirtyTrackerStripeCount;
}

public:
// Chunk metadata
struct ChunkInfo {
Expand Down Expand Up @@ -63,6 +85,36 @@ class FixedBlockPool : public std::pmr::memory_resource {
reinterpret_cast<MetaHeader*>(block)->used = used;
}

// A block is "dirty" when its in-memory data has NOT yet been persisted to
// SSD; "clean" means it is in sync with SSD.
bool get_dirty(const void* block) {
if (!enable_dirty_tracking_) {
return false;
}

auto& stripe =
dirty_tracker_stripes_[get_dirty_tracker_stripe_index(block)];
std::lock_guard<std::mutex> lock(stripe.mutex);
return stripe.dirty_blocks.find(block) != stripe.dirty_blocks.end();
}
void set_dirty(const void* block, bool dirty) {
if (!enable_dirty_tracking_) {
return;
}

auto& stripe =
dirty_tracker_stripes_[get_dirty_tracker_stripe_index(block)];
std::lock_guard<std::mutex> lock(stripe.mutex);
if (dirty) {
stripe.dirty_blocks.insert(block);
} else {
stripe.dirty_blocks.erase(block);
}
}
void clear_dirty(const void* block) {
set_dirty(block, false);
}

// Score operations
static uint32_t get_count(const void* block) {
return reinterpret_cast<const MetaHeader*>(block)->count;
Expand Down Expand Up @@ -190,13 +242,16 @@ class FixedBlockPool : public std::pmr::memory_resource {
std::size_t block_size, // Size of each memory block
std::size_t block_alignment, // Memory block alignment requirement
std::size_t blocks_per_chunk = 8192, // Number of blocks per chunk
std::pmr::memory_resource* upstream = std::pmr::new_delete_resource())
std::pmr::memory_resource* upstream = std::pmr::new_delete_resource(),
// Dirty-bit tracking for DRAM_SSD (default off, avoids mutex overhead)
bool enable_dirty_tracking = false)
// Minimum block size is 8 bytes
: block_size_(std::max(block_size, sizeof(void*))),
block_alignment_(block_alignment),
blocks_per_chunk_(blocks_per_chunk),
upstream_(upstream),
chunks_(upstream) {
chunks_(upstream),
enable_dirty_tracking_(enable_dirty_tracking) {
// Validate minimum data size, whether it's less than 8 bytes
// half type, 2 bytes, minimum embedding length 4
// float type, 4 bytes, minimum embedding length 2
Expand All @@ -223,7 +278,9 @@ class FixedBlockPool : public std::pmr::memory_resource {
}
}

// Release all allocated memory during destruction
// Release all allocated memory during destruction. The per-pool dirty
// tracker (dirty_tracker_stripes_) is destroyed automatically with the pool,
// so no explicit dirty cleanup is needed here.
~FixedBlockPool() override {
std::lock_guard<std::mutex> guard(chunks_mutex_);
for (auto&& chunk : chunks_) {
Expand Down Expand Up @@ -273,6 +330,9 @@ class FixedBlockPool : public std::pmr::memory_resource {
[[nodiscard]] std::size_t get_blocks_per_chunk() const noexcept {
return blocks_per_chunk_;
}
[[nodiscard]] bool is_dirty_tracking_enabled() const noexcept {
return enable_dirty_tracking_;
}
[[nodiscard]] std::size_t get_aligned_block_size() const noexcept {
return (block_size_ + block_alignment_ - 1) / block_alignment_ *
block_alignment_;
Expand Down Expand Up @@ -302,6 +362,13 @@ class FixedBlockPool : public std::pmr::memory_resource {
set_used(result, true);
set_count(result, 0);
update_timestamp(result);
if (enable_dirty_tracking_) {
// Intentional: a freshly allocated block holds new data that has not yet
// been persisted to SSD, so it must be marked dirty to be picked up by
// the SSD flush protocol. Callers should not assume "allocate" returns
// clean memory.
set_dirty(result, true);
}
return result;
}

Expand All @@ -310,6 +377,9 @@ class FixedBlockPool : public std::pmr::memory_resource {
void* p,
std::size_t bytes [[maybe_unused]],
std::size_t alignment [[maybe_unused]]) override {
if (enable_dirty_tracking_) {
clear_dirty(p);
}
// Insert memory block back to the head of free list
*static_cast<void**>(p) = free_list_;
free_list_ = p;
Expand Down Expand Up @@ -355,6 +425,8 @@ class FixedBlockPool : public std::pmr::memory_resource {
std::pmr::vector<ChunkInfo> chunks_; // Records of all allocated chunks
void* free_list_ = nullptr; // Free block list head pointer
mutable std::mutex chunks_mutex_; // Mutex for chunks_
// Gate for out-of-band dirty-bit tracking (enabled only for DRAM_SSD).
const bool enable_dirty_tracking_;

private:
// block pool lock, only used on the inference side to guard in-place update
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,14 @@ class InferenceFixedBlockPool : public FixedBlockPool {
std::size_t block_size,
std::size_t block_alignment,
std::size_t blocks_per_chunk = 8192,
std::pmr::memory_resource* upstream = std::pmr::new_delete_resource())
std::pmr::memory_resource* upstream = std::pmr::new_delete_resource(),
bool enable_dirty_tracking = false)
: FixedBlockPool(
block_size,
block_alignment,
blocks_per_chunk,
upstream) {}
upstream,
enable_dirty_tracking) {}

// Get block by index (used for eviction traversal)
// Uses InferenceFixedBlockPool::get_used for 12-byte header layout
Expand Down
53 changes: 53 additions & 0 deletions fbgemm_gpu/test/dram_kv_embedding_cache/fixed_block_pool_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,4 +421,57 @@ TEST(FixedBlockPool, DataIntegrity) {
pool.deallocate_t<float>(block);
}

TEST(FixedBlockPool, DefaultDirtyTrackingGating) {
// Dirty-bit tracking is opt-in (only the DRAM_SSD backend needs it).
constexpr int dim = 4;
size_t block_size = FixedBlockPool::calculate_block_size<float>(dim);
size_t alignment = FixedBlockPool::calculate_block_alignment<float>();

// Default pool: tracking disabled.
FixedBlockPool default_pool(block_size, alignment, 1024);
EXPECT_FALSE(default_pool.is_dirty_tracking_enabled());

// After allocation: a freshly allocated block is always clean for disabled
// dirty tracking.
auto* block = default_pool.allocate_t<float>();
ASSERT_NE(block, nullptr);
EXPECT_FALSE(default_pool.get_dirty(block));

default_pool.clear_dirty(block);
EXPECT_FALSE(default_pool.get_dirty(block));
}

TEST(FixedBlockPool, DirtyTrackingGating) {
// Dirty-bit tracking is opt-in (only the DRAM_SSD backend needs it).
constexpr int dim = 4;
size_t block_size = FixedBlockPool::calculate_block_size<float>(dim);
size_t alignment = FixedBlockPool::calculate_block_alignment<float>();

// Pool with tracking explicitly enabled.
FixedBlockPool tracked_pool(
block_size,
alignment,
1024,
std::pmr::new_delete_resource(),
/*enable_dirty_tracking=*/true);
EXPECT_TRUE(tracked_pool.is_dirty_tracking_enabled());

// After allocation: a freshly allocated block is dirty (its data has not
// yet been persisted to SSD).
auto* block = tracked_pool.allocate_t<float>();
ASSERT_NE(block, nullptr);
EXPECT_TRUE(tracked_pool.get_dirty(block));

// After operations: clearing (e.g. once flushed to SSD) marks it clean,
// and setting it again marks it dirty.
tracked_pool.clear_dirty(block);
EXPECT_FALSE(tracked_pool.get_dirty(block));
tracked_pool.set_dirty(block, true);
EXPECT_TRUE(tracked_pool.get_dirty(block));

// After deallocation: the block's dirty state is cleared.
tracked_pool.deallocate_t<float>(block);
EXPECT_FALSE(tracked_pool.get_dirty(block));
}

} // namespace kv_mem
Loading