Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions bolt/shuffle/sparksql/Options.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ struct PartitionWriterOptions {

int32_t pushBufferMaxSize = kDefaultShuffleWriterBufferSize;

// When true, CelebornPartitionWriter routes payloads <= pushBufferMaxSize
// through RssClient::mergePartitionData (Celeborn-side batching) and
// larger payloads through pushPartitionData. When false, every payload
// goes through pushPartitionData (legacy behavior).
bool celebornMergeEnabled = true;

int32_t shuffleBufferSize = kDefaultShuffleBatchByteSize;

int64_t rowvectorModeCompressionMinColumns =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,18 @@ arrow::Status CelebornPartitionWriter::evict(
RETURN_NOT_OK(payload->serialize(celebornBufferOs.get()));
payload = nullptr; // Invalidate payload immediately.

// Push.
// Push. Small payloads go through mergeData so Celeborn coalesces them
// per worker before hitting the wire; large payloads bypass the merge
// buffer and push directly. Matches Gluten's Java client behavior.
ARROW_ASSIGN_OR_RAISE(auto buffer, celebornBufferOs->Finish());
bytesEvicted_[partitionId] += celebornClient_->pushPartitionData(
partitionId,
reinterpret_cast<char*>(const_cast<uint8_t*>(buffer->data())),
buffer->size());
auto* bufBytes =
reinterpret_cast<char*>(const_cast<uint8_t*>(buffer->data()));
const int64_t bufSize = buffer->size();
const bool useMerge =
options_.celebornMergeEnabled && bufSize <= options_.pushBufferMaxSize;
bytesEvicted_[partitionId] += useMerge
? celebornClient_->mergePartitionData(partitionId, bufBytes, bufSize)
: celebornClient_->pushPartitionData(partitionId, bufBytes, bufSize);
return arrow::Status::OK();
}

Expand Down Expand Up @@ -166,10 +172,14 @@ arrow::Status CelebornPartitionWriter::evict(
payload = nullptr; // Invalidate payload immediately.

ARROW_ASSIGN_OR_RAISE(auto buffer, celebornBufferOs->Finish());
bytesEvicted_[pid] += celebornClient_->pushPartitionData(
pid,
reinterpret_cast<char*>(const_cast<uint8_t*>(buffer->data())),
buffer->size());
auto* bufBytes =
reinterpret_cast<char*>(const_cast<uint8_t*>(buffer->data()));
const int64_t bufSize = buffer->size();
const bool useMerge = options_.celebornMergeEnabled &&
bufSize <= options_.pushBufferMaxSize;
bytesEvicted_[pid] += useMerge
? celebornClient_->mergePartitionData(pid, bufBytes, bufSize)
: celebornClient_->pushPartitionData(pid, bufBytes, bufSize);
startIndex += slicedNumRows;
} while (startIndex < totalRowCount);
BOLT_CHECK(startIndex == totalRowCount);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,28 @@ int32_t NativeCelebornClient::pushPartitionData(
numPartitions_);
}

int32_t NativeCelebornClient::mergePartitionData(
int32_t partitionId,
char* bytes,
int64_t size) {
BOLT_CHECK(
!stopped_,
"Cannot merge data after NativeCelebornClient has been stopped");
// Celeborn's mergeData buffers per worker-address-pair and auto-flushes
// when the accumulator exceeds celeborn.client.push.buffer.max.size.
// Any remaining batches are drained by mapperEnd() in stop().
return shuffleClient_->mergeData(
shuffleId_,
mapId_,
attemptId_,
partitionId,
reinterpret_cast<const uint8_t*>(bytes),
0,
size,
numMappers_,
numPartitions_);
}

void NativeCelebornClient::stop() {
BOLT_CHECK(
!stopped_,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ class NativeCelebornClient : public RssClient {
int32_t pushPartitionData(int32_t partitionId, char* bytes, int64_t size)
override;

int32_t mergePartitionData(int32_t partitionId, char* bytes, int64_t size)
override;

void stop() override;

private:
Expand Down
8 changes: 8 additions & 0 deletions bolt/shuffle/sparksql/partition_writer/rss/RssClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ class RssClient {
virtual int32_t
pushPartitionData(int32_t partitionId, char* bytes, int64_t size) = 0;

// Buffer the payload client-side; the underlying client coalesces and
// flushes via pushMergedData / mapperEnd. Default delegates to
// pushPartitionData so backends without merge support keep working.
virtual int32_t
mergePartitionData(int32_t partitionId, char* bytes, int64_t size) {
return pushPartitionData(partitionId, bytes, size);
}

virtual void stop() = 0;
};

Expand Down
1 change: 0 additions & 1 deletion bolt/shuffle/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ target_link_libraries(
bolt_testutils
glog::glog
GTest::gtest
GTest::gtest_main
)

add_test(
Expand Down
99 changes: 98 additions & 1 deletion bolt/shuffle/sparksql/tests/CelebornClientTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <arrow/buffer.h>
#include <arrow/io/api.h>
#include <arrow/memory_pool.h>
#include <folly/init/Init.h>
#include <gtest/gtest.h>

#include <celeborn/client/ShuffleClient.h>
Expand Down Expand Up @@ -206,6 +207,43 @@ class FakeShuffleClient final : public celeborn::client::ShuffleClient {
return 0;
}

// Records merged batches in the same payload map as pushData so reader-side
// tests don't have to distinguish; bumps a separate counter callers can
// inspect to verify dispatch.
int mergeData(
int shuffleId,
int mapId,
int attemptId,
int partitionId,
const uint8_t* data,
size_t offset,
size_t length,
int numMappers,
int numPartitions) override {
mergeCallCount_++;
return pushData(
shuffleId,
mapId,
attemptId,
partitionId,
data,
offset,
length,
numMappers,
numPartitions);
}

void pushMergedData(int, int, int) override {
pushMergedDataCallCount_++;
}

int64_t mergeCallCount() const {
return mergeCallCount_;
}
int64_t pushMergedDataCallCount() const {
return pushMergedDataCallCount_;
}

void mapperEnd(int, int, int, int) override {}

void cleanup(int, int, int) override {}
Expand Down Expand Up @@ -249,6 +287,9 @@ class FakeShuffleClient final : public celeborn::client::ShuffleClient {
return true;
}

void excludeFailedFetchLocation(const std::string&, const std::exception&)
override {}

void shutdown() override {}

private:
Expand Down Expand Up @@ -294,14 +335,19 @@ class FakeShuffleClient final : public celeborn::client::ShuffleClient {
attemptNumber,
startMapIndex,
endMapIndex,
needCompression);
needCompression,
std::make_shared<
celeborn::client::CelebornInputStream::FetchExcludedWorkers>(),
this);
}

std::shared_ptr<celeborn::conf::CelebornConf> conf_;
bool hasKnownCounts_{false};
int knownNumMappers_{0};
int knownNumPartitions_{0};
std::map<MapKey, std::vector<std::vector<uint8_t>>> payloadByMapKey_;
int64_t mergeCallCount_{0};
int64_t pushMergedDataCallCount_{0};
};
// Verifies iterator returns nullptr when no partitions or after close.
TEST(CelebornReaderStreamIteratorTest, ReturnsNullWhenEmptyOrClosed) {
Expand Down Expand Up @@ -536,4 +582,55 @@ TEST(NativeCelebornClientTest, testStopTriggeredTwice) {
bytedance::bolt::BoltRuntimeError);
}

// Verifies NativeCelebornClient::mergePartitionData routes to ShuffleClient::
// mergeData (not pushData), and that mergeData calls after stop are rejected.
TEST(NativeCelebornClientTest, mergePartitionDataRoutesToMergeData) {
auto client = std::make_shared<FakeShuffleClient>();
constexpr int kShuffleId = 16;
constexpr int32_t kPartitionId = 0;
constexpr int kMapId = 0;
constexpr int kAttemptId = 0;
constexpr int kNumMappers = 1;
constexpr int kNumPartitions = 1;

NativeCelebornClient nativeClient(
client, kShuffleId, kMapId, kAttemptId, kNumMappers, kNumPartitions);

std::string payload = "small-merged-batch";
nativeClient.mergePartitionData(
kPartitionId, payload.data(), static_cast<int64_t>(payload.size()));

EXPECT_EQ(client->mergeCallCount(), 1);

nativeClient.stop();
EXPECT_THROW(
nativeClient.mergePartitionData(kPartitionId, payload.data(), 1),
bytedance::bolt::BoltRuntimeError);
}

// Verifies RssClient's default mergePartitionData delegates to
// pushPartitionData so backends that don't override it keep working unchanged.
TEST(RssClientTest, defaultMergeDelegatesToPush) {
struct PushOnlyClient : public RssClient {
int32_t pushPartitionData(int32_t, char*, int64_t size) override {
pushBytes += size;
return static_cast<int32_t>(size);
}
void stop() override {}
int64_t pushBytes{0};
};

PushOnlyClient client;
char buf[8] = {0};
EXPECT_EQ(client.mergePartitionData(0, buf, sizeof(buf)), 8);
EXPECT_EQ(client.pushBytes, 8);
}

} // namespace bytedance::bolt::shuffle::sparksql::test

int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
// todo: use folly::Init init after upgrade folly lib
folly::init(&argc, &argv, false);
return RUN_ALL_TESTS();
}
41 changes: 37 additions & 4 deletions bolt/shuffle/sparksql/tests/MockRssClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,17 @@ class MockRssClient : public RssClient {
// Simple implementation to store data
int32_t pushPartitionData(int32_t partitionId, char* bytes, int64_t size)
override {
if (data_.find(partitionId) == data_.end()) {
data_[partitionId] = std::vector<char>();
}
data_[partitionId].insert(data_[partitionId].end(), bytes, bytes + size);
pushCalls_++;
pushBytes_ += size;
appendData(partitionId, bytes, size);
return size;
}

int32_t mergePartitionData(int32_t partitionId, char* bytes, int64_t size)
override {
mergeCalls_++;
mergeBytes_ += size;
appendData(partitionId, bytes, size);
return size;
}

Expand All @@ -40,9 +47,35 @@ class MockRssClient : public RssClient {
return data_;
}

int64_t pushCalls() const {
return pushCalls_;
}
int64_t mergeCalls() const {
return mergeCalls_;
}
int64_t pushBytes() const {
return pushBytes_;
}
int64_t mergeBytes() const {
return mergeBytes_;
}

public:
// Helper to store data for verification or reading
std::map<int32_t, std::vector<char>> data_;

private:
void appendData(int32_t partitionId, char* bytes, int64_t size) {
if (data_.find(partitionId) == data_.end()) {
data_[partitionId] = std::vector<char>();
}
data_[partitionId].insert(data_[partitionId].end(), bytes, bytes + size);
}

int64_t pushCalls_{0};
int64_t mergeCalls_{0};
int64_t pushBytes_{0};
int64_t mergeBytes_{0};
};

} // namespace bytedance::bolt::shuffle::sparksql::test
29 changes: 20 additions & 9 deletions bolt/shuffle/sparksql/tests/ShuffleMatrixTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,32 @@ std::vector<ShuffleTestParam> buildShuffleParams() {
const std::vector<PartitionWriterType> writerTypes = {
PartitionWriterType::kLocal, PartitionWriterType::kCeleborn};

// Only Celeborn honors celebornMergeEnabled; sweeping both values for
// kLocal would just duplicate test cases.
const std::vector<bool> celebornMergeOptions = {true, false};

for (auto partitioning : partitionings) {
for (auto shuffleMode : shuffleModes) {
for (auto writerType : writerTypes) {
for (auto dataTypeGroup : dataGroups) {
for (auto numPartitions : partitionNumbers) {
for (auto numMappers : mapperNumbers) {
auto param = ShuffleTestParam{
partitioning,
shuffleMode,
writerType,
dataTypeGroup,
numPartitions,
numMappers};
if (param.isSupported()) {
params.push_back(param);
const auto mergeFlags =
writerType == PartitionWriterType::kCeleborn
? celebornMergeOptions
: std::vector<bool>{true};
for (auto mergeEnabled : mergeFlags) {
auto param = ShuffleTestParam{
partitioning,
shuffleMode,
writerType,
dataTypeGroup,
numPartitions,
numMappers};
param.celebornMergeEnabled = mergeEnabled;
if (param.isSupported()) {
params.push_back(param);
}
}
}
}
Expand Down
7 changes: 5 additions & 2 deletions bolt/shuffle/sparksql/tests/ShuffleTestBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,15 @@ std::string ShuffleTestParam::toString() const {
auto memStr = fmt::format("{}{}", v, units[u]);

return fmt::format(
"{}_{}_{}_{}_M{}_P{}_{}",
"{}_{}_{}_{}_M{}_P{}_{}_merge{}",
partitioning,
shuffleModeToString(shuffleMode),
writerTypeToString(writerType),
dataTypeGroupToString(dataTypeGroup),
numMappers,
numPartitions,
memStr);
memStr,
celebornMergeEnabled ? "On" : "Off");
}

bool ShuffleTestParam::isSupported() const {
Expand Down Expand Up @@ -514,6 +515,8 @@ ShuffleRunResult ShuffleTestBase::runShuffle(
writerOptions.partitionWriterOptions.numPartitions = param.numPartitions;
writerOptions.forceShuffleWriterType = param.shuffleMode;
writerOptions.partitionWriterOptions.partitionWriterType = param.writerType;
writerOptions.partitionWriterOptions.celebornMergeEnabled =
param.celebornMergeEnabled;
writerOptions.taskAttemptId = memoryManagerHolder->taskAttemptId();
writerOptions.partitionWriterOptions.shuffleBufferSize =
param.shuffleBufferSize;
Expand Down
1 change: 1 addition & 0 deletions bolt/shuffle/sparksql/tests/ShuffleTestBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ struct ShuffleTestParam {
int32_t numBatches = 4;
int32_t shuffleBufferSize = kDefaultShuffleWriterBufferSize;
bool verifyOutput = true;
bool celebornMergeEnabled = true;

std::string toString() const;

Expand Down
Loading
Loading