From 7d4b1fc6a4f62ee84ce6c36599fbef54c2e2c13a Mon Sep 17 00:00:00 2001 From: afterincomparableyum <224495379+afterincomparableyum@users.noreply.github.com> Date: Sat, 16 May 2026 21:12:31 -0700 Subject: [PATCH] initial support for celeborn push merged c++ client --- bolt/shuffle/sparksql/Options.h | 6 ++ .../rss/CelebornPartitionWriter.cpp | 28 ++++-- .../rss/NativeCelebornClient.cpp | 22 +++++ .../rss/NativeCelebornClient.h | 3 + .../sparksql/partition_writer/rss/RssClient.h | 8 ++ bolt/shuffle/sparksql/tests/CMakeLists.txt | 1 - .../sparksql/tests/CelebornClientTest.cpp | 99 ++++++++++++++++++- bolt/shuffle/sparksql/tests/MockRssClient.h | 41 +++++++- .../sparksql/tests/ShuffleMatrixTest.cpp | 29 ++++-- .../sparksql/tests/ShuffleTestBase.cpp | 7 +- bolt/shuffle/sparksql/tests/ShuffleTestBase.h | 1 + .../shuffle/sparksql/tests/celeborn/README.md | 4 +- .../tests/celeborn/scripts/run_e2e.sh | 8 +- conanfile.py | 2 +- .../celeborn-cpp-client/all/conandata.yml | 2 +- .../celeborn-cpp-client/all/conanfile.py | 4 +- .../recipes/celeborn-cpp-client/config.yml | 2 +- 17 files changed, 230 insertions(+), 37 deletions(-) diff --git a/bolt/shuffle/sparksql/Options.h b/bolt/shuffle/sparksql/Options.h index 4a624a7df..f27cea59d 100644 --- a/bolt/shuffle/sparksql/Options.h +++ b/bolt/shuffle/sparksql/Options.h @@ -121,6 +121,12 @@ struct PartitionWriterOptions { int32_t pushBufferMaxSize = kDefaultShuffleWriterBufferSize; + // When true, CelebornPartitionWriter routes payloads <= pushBufferMaxSize + // through RssClient::mergePartitionData (client-side batching) and + // larger payloads through pushPartitionData. When false, every payload + // goes through pushPartitionData (legacy behavior). 
+ bool celebornMergeEnabled = true; + int32_t shuffleBufferSize = kDefaultShuffleBatchByteSize; int64_t rowvectorModeCompressionMinColumns = diff --git a/bolt/shuffle/sparksql/partition_writer/rss/CelebornPartitionWriter.cpp b/bolt/shuffle/sparksql/partition_writer/rss/CelebornPartitionWriter.cpp index 34d90423e..6ca16c054 100644 --- a/bolt/shuffle/sparksql/partition_writer/rss/CelebornPartitionWriter.cpp +++ b/bolt/shuffle/sparksql/partition_writer/rss/CelebornPartitionWriter.cpp @@ -92,12 +92,18 @@ arrow::Status CelebornPartitionWriter::evict( RETURN_NOT_OK(payload->serialize(celebornBufferOs.get())); payload = nullptr; // Invalidate payload immediately. - // Push. + // Push. With celebornMergeEnabled, payloads <= pushBufferMaxSize go through + // mergePartitionData so Celeborn coalesces them per worker before sending; + // all others bypass the merge buffer. Matches Gluten's Java client behavior. ARROW_ASSIGN_OR_RAISE(auto buffer, celebornBufferOs->Finish()); - bytesEvicted_[partitionId] += celebornClient_->pushPartitionData( - partitionId, - reinterpret_cast(const_cast(buffer->data())), - buffer->size()); + auto* bufBytes = + reinterpret_cast(const_cast(buffer->data())); + const int64_t bufSize = buffer->size(); + const bool useMerge = + options_.celebornMergeEnabled && bufSize <= options_.pushBufferMaxSize; + bytesEvicted_[partitionId] += useMerge + ? celebornClient_->mergePartitionData(partitionId, bufBytes, bufSize) + : celebornClient_->pushPartitionData(partitionId, bufBytes, bufSize); return arrow::Status::OK(); } @@ -166,10 +172,14 @@ arrow::Status CelebornPartitionWriter::evict( payload = nullptr; // Invalidate payload immediately. 
ARROW_ASSIGN_OR_RAISE(auto buffer, celebornBufferOs->Finish()); - bytesEvicted_[pid] += celebornClient_->pushPartitionData( - pid, - reinterpret_cast(const_cast(buffer->data())), - buffer->size()); + auto* bufBytes = + reinterpret_cast(const_cast(buffer->data())); + const int64_t bufSize = buffer->size(); + const bool useMerge = options_.celebornMergeEnabled && + bufSize <= options_.pushBufferMaxSize; + bytesEvicted_[pid] += useMerge + ? celebornClient_->mergePartitionData(pid, bufBytes, bufSize) + : celebornClient_->pushPartitionData(pid, bufBytes, bufSize); startIndex += slicedNumRows; } while (startIndex < totalRowCount); BOLT_CHECK(startIndex == totalRowCount); diff --git a/bolt/shuffle/sparksql/partition_writer/rss/NativeCelebornClient.cpp b/bolt/shuffle/sparksql/partition_writer/rss/NativeCelebornClient.cpp index f3fed924e..e6d29efd2 100644 --- a/bolt/shuffle/sparksql/partition_writer/rss/NativeCelebornClient.cpp +++ b/bolt/shuffle/sparksql/partition_writer/rss/NativeCelebornClient.cpp @@ -60,6 +60,28 @@ int32_t NativeCelebornClient::pushPartitionData( numPartitions_); } +int32_t NativeCelebornClient::mergePartitionData( + int32_t partitionId, + char* bytes, + int64_t size) { + BOLT_CHECK( + !stopped_, + "Cannot merge data after NativeCelebornClient has been stopped"); + // Celeborn's mergeData buffers per worker-address-pair and auto-flushes + // when the accumulator exceeds celeborn.client.push.buffer.max.size. + // Any remaining batches are drained by mapperEnd() in stop(). 
+ return shuffleClient_->mergeData( + shuffleId_, + mapId_, + attemptId_, + partitionId, + reinterpret_cast(bytes), + 0, + size, + numMappers_, + numPartitions_); +} + void NativeCelebornClient::stop() { BOLT_CHECK( !stopped_, diff --git a/bolt/shuffle/sparksql/partition_writer/rss/NativeCelebornClient.h b/bolt/shuffle/sparksql/partition_writer/rss/NativeCelebornClient.h index 0eafd2189..d40e88cff 100644 --- a/bolt/shuffle/sparksql/partition_writer/rss/NativeCelebornClient.h +++ b/bolt/shuffle/sparksql/partition_writer/rss/NativeCelebornClient.h @@ -40,6 +40,9 @@ class NativeCelebornClient : public RssClient { int32_t pushPartitionData(int32_t partitionId, char* bytes, int64_t size) override; + int32_t mergePartitionData(int32_t partitionId, char* bytes, int64_t size) + override; + void stop() override; private: diff --git a/bolt/shuffle/sparksql/partition_writer/rss/RssClient.h b/bolt/shuffle/sparksql/partition_writer/rss/RssClient.h index a44ac495e..5c1724715 100644 --- a/bolt/shuffle/sparksql/partition_writer/rss/RssClient.h +++ b/bolt/shuffle/sparksql/partition_writer/rss/RssClient.h @@ -40,6 +40,14 @@ class RssClient { virtual int32_t pushPartitionData(int32_t partitionId, char* bytes, int64_t size) = 0; + // Buffer the payload client-side; the underlying client coalesces and + // flushes via pushMergedData / mapperEnd. Default delegates to + // pushPartitionData so backends without merge support keep working. 
+ virtual int32_t + mergePartitionData(int32_t partitionId, char* bytes, int64_t size) { + return pushPartitionData(partitionId, bytes, size); + } + virtual void stop() = 0; }; diff --git a/bolt/shuffle/sparksql/tests/CMakeLists.txt b/bolt/shuffle/sparksql/tests/CMakeLists.txt index 327e96083..dd346192e 100644 --- a/bolt/shuffle/sparksql/tests/CMakeLists.txt +++ b/bolt/shuffle/sparksql/tests/CMakeLists.txt @@ -23,7 +23,6 @@ target_link_libraries( bolt_testutils glog::glog GTest::gtest - GTest::gtest_main ) add_test( diff --git a/bolt/shuffle/sparksql/tests/CelebornClientTest.cpp b/bolt/shuffle/sparksql/tests/CelebornClientTest.cpp index 929d1fde2..d50c759d3 100644 --- a/bolt/shuffle/sparksql/tests/CelebornClientTest.cpp +++ b/bolt/shuffle/sparksql/tests/CelebornClientTest.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -206,6 +207,43 @@ class FakeShuffleClient final : public celeborn::client::ShuffleClient { return 0; } + // Records merged batches in the same payload map as pushData so reader-side + // tests don't have to distinguish; bumps a separate counter callers can + // inspect to verify dispatch. 
+ int mergeData( + int shuffleId, + int mapId, + int attemptId, + int partitionId, + const uint8_t* data, + size_t offset, + size_t length, + int numMappers, + int numPartitions) override { + mergeCallCount_++; + return pushData( + shuffleId, + mapId, + attemptId, + partitionId, + data, + offset, + length, + numMappers, + numPartitions); + } + + void pushMergedData(int, int, int) override { + pushMergedDataCallCount_++; + } + + int64_t mergeCallCount() const { + return mergeCallCount_; + } + int64_t pushMergedDataCallCount() const { + return pushMergedDataCallCount_; + } + void mapperEnd(int, int, int, int) override {} void cleanup(int, int, int) override {} @@ -249,6 +287,9 @@ class FakeShuffleClient final : public celeborn::client::ShuffleClient { return true; } + void excludeFailedFetchLocation(const std::string&, const std::exception&) + override {} + void shutdown() override {} private: @@ -294,7 +335,10 @@ class FakeShuffleClient final : public celeborn::client::ShuffleClient { attemptNumber, startMapIndex, endMapIndex, - needCompression); + needCompression, + std::make_shared< + celeborn::client::CelebornInputStream::FetchExcludedWorkers>(), + this); } std::shared_ptr conf_; @@ -302,6 +346,8 @@ class FakeShuffleClient final : public celeborn::client::ShuffleClient { int knownNumMappers_{0}; int knownNumPartitions_{0}; std::map>> payloadByMapKey_; + int64_t mergeCallCount_{0}; + int64_t pushMergedDataCallCount_{0}; }; // Verifies iterator returns nullptr when no partitions or after close. TEST(CelebornReaderStreamIteratorTest, ReturnsNullWhenEmptyOrClosed) { @@ -536,4 +582,55 @@ TEST(NativeCelebornClientTest, testStopTriggeredTwice) { bytedance::bolt::BoltRuntimeError); } +// Verifies NativeCelebornClient::mergePartitionData routes to ShuffleClient:: +// mergeData (not pushData), and that mergeData calls after stop are rejected. 
+TEST(NativeCelebornClientTest, mergePartitionDataRoutesToMergeData) { + auto client = std::make_shared(); + constexpr int kShuffleId = 16; + constexpr int32_t kPartitionId = 0; + constexpr int kMapId = 0; + constexpr int kAttemptId = 0; + constexpr int kNumMappers = 1; + constexpr int kNumPartitions = 1; + + NativeCelebornClient nativeClient( + client, kShuffleId, kMapId, kAttemptId, kNumMappers, kNumPartitions); + + std::string payload = "small-merged-batch"; + nativeClient.mergePartitionData( + kPartitionId, payload.data(), static_cast(payload.size())); + + EXPECT_EQ(client->mergeCallCount(), 1); + + nativeClient.stop(); + EXPECT_THROW( + nativeClient.mergePartitionData(kPartitionId, payload.data(), 1), + bytedance::bolt::BoltRuntimeError); +} + +// Verifies RssClient's default mergePartitionData delegates to +// pushPartitionData so backends that don't override it keep working unchanged. +TEST(RssClientTest, defaultMergeDelegatesToPush) { + struct PushOnlyClient : public RssClient { + int32_t pushPartitionData(int32_t, char*, int64_t size) override { + pushBytes += size; + return static_cast(size); + } + void stop() override {} + int64_t pushBytes{0}; + }; + + PushOnlyClient client; + char buf[8] = {0}; + EXPECT_EQ(client.mergePartitionData(0, buf, sizeof(buf)), 8); + EXPECT_EQ(client.pushBytes, 8); +} + } // namespace bytedance::bolt::shuffle::sparksql::test + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + // TODO: switch to folly::Init once the folly library is upgraded. + folly::init(&argc, &argv, false); + return RUN_ALL_TESTS(); +} diff --git a/bolt/shuffle/sparksql/tests/MockRssClient.h b/bolt/shuffle/sparksql/tests/MockRssClient.h index 9ae78620e..aa36b8266 100644 --- a/bolt/shuffle/sparksql/tests/MockRssClient.h +++ b/bolt/shuffle/sparksql/tests/MockRssClient.h @@ -27,10 +27,17 @@ class MockRssClient : public RssClient { // Simple implementation to store data int32_t pushPartitionData(int32_t partitionId, char* bytes, int64_t size) 
override { - if (data_.find(partitionId) == data_.end()) { - data_[partitionId] = std::vector(); - } - data_[partitionId].insert(data_[partitionId].end(), bytes, bytes + size); + pushCalls_++; + pushBytes_ += size; + appendData(partitionId, bytes, size); + return size; + } + + int32_t mergePartitionData(int32_t partitionId, char* bytes, int64_t size) + override { + mergeCalls_++; + mergeBytes_ += size; + appendData(partitionId, bytes, size); return size; } @@ -40,9 +47,35 @@ class MockRssClient : public RssClient { return data_; } + int64_t pushCalls() const { + return pushCalls_; + } + int64_t mergeCalls() const { + return mergeCalls_; + } + int64_t pushBytes() const { + return pushBytes_; + } + int64_t mergeBytes() const { + return mergeBytes_; + } + public: // Helper to store data for verification or reading std::map> data_; + + private: + void appendData(int32_t partitionId, char* bytes, int64_t size) { + if (data_.find(partitionId) == data_.end()) { + data_[partitionId] = std::vector(); + } + data_[partitionId].insert(data_[partitionId].end(), bytes, bytes + size); + } + + int64_t pushCalls_{0}; + int64_t mergeCalls_{0}; + int64_t pushBytes_{0}; + int64_t mergeBytes_{0}; }; } // namespace bytedance::bolt::shuffle::sparksql::test diff --git a/bolt/shuffle/sparksql/tests/ShuffleMatrixTest.cpp b/bolt/shuffle/sparksql/tests/ShuffleMatrixTest.cpp index 86415e7db..67f73622b 100644 --- a/bolt/shuffle/sparksql/tests/ShuffleMatrixTest.cpp +++ b/bolt/shuffle/sparksql/tests/ShuffleMatrixTest.cpp @@ -29,21 +29,32 @@ std::vector buildShuffleParams() { const std::vector writerTypes = { PartitionWriterType::kLocal, PartitionWriterType::kCeleborn}; + // Only Celeborn honors celebornMergeEnabled; sweeping both values for + // kLocal would just duplicate test cases. 
+ const std::vector celebornMergeOptions = {true, false}; + for (auto partitioning : partitionings) { for (auto shuffleMode : shuffleModes) { for (auto writerType : writerTypes) { for (auto dataTypeGroup : dataGroups) { for (auto numPartitions : partitionNumbers) { for (auto numMappers : mapperNumbers) { - auto param = ShuffleTestParam{ - partitioning, - shuffleMode, - writerType, - dataTypeGroup, - numPartitions, - numMappers}; - if (param.isSupported()) { - params.push_back(param); + const auto mergeFlags = + writerType == PartitionWriterType::kCeleborn + ? celebornMergeOptions + : std::vector{true}; + for (auto mergeEnabled : mergeFlags) { + auto param = ShuffleTestParam{ + partitioning, + shuffleMode, + writerType, + dataTypeGroup, + numPartitions, + numMappers}; + param.celebornMergeEnabled = mergeEnabled; + if (param.isSupported()) { + params.push_back(param); + } } } } diff --git a/bolt/shuffle/sparksql/tests/ShuffleTestBase.cpp b/bolt/shuffle/sparksql/tests/ShuffleTestBase.cpp index badaa582d..f5be7954b 100644 --- a/bolt/shuffle/sparksql/tests/ShuffleTestBase.cpp +++ b/bolt/shuffle/sparksql/tests/ShuffleTestBase.cpp @@ -161,14 +161,15 @@ std::string ShuffleTestParam::toString() const { auto memStr = fmt::format("{}{}", v, units[u]); return fmt::format( - "{}_{}_{}_{}_M{}_P{}_{}", + "{}_{}_{}_{}_M{}_P{}_{}_merge{}", partitioning, shuffleModeToString(shuffleMode), writerTypeToString(writerType), dataTypeGroupToString(dataTypeGroup), numMappers, numPartitions, - memStr); + memStr, + celebornMergeEnabled ? 
"On" : "Off"); } bool ShuffleTestParam::isSupported() const { @@ -514,6 +515,8 @@ ShuffleRunResult ShuffleTestBase::runShuffle( writerOptions.partitionWriterOptions.numPartitions = param.numPartitions; writerOptions.forceShuffleWriterType = param.shuffleMode; writerOptions.partitionWriterOptions.partitionWriterType = param.writerType; + writerOptions.partitionWriterOptions.celebornMergeEnabled = + param.celebornMergeEnabled; writerOptions.taskAttemptId = memoryManagerHolder->taskAttemptId(); writerOptions.partitionWriterOptions.shuffleBufferSize = param.shuffleBufferSize; diff --git a/bolt/shuffle/sparksql/tests/ShuffleTestBase.h b/bolt/shuffle/sparksql/tests/ShuffleTestBase.h index 204b265ff..d9958fe73 100644 --- a/bolt/shuffle/sparksql/tests/ShuffleTestBase.h +++ b/bolt/shuffle/sparksql/tests/ShuffleTestBase.h @@ -67,6 +67,7 @@ struct ShuffleTestParam { int32_t numBatches = 4; int32_t shuffleBufferSize = kDefaultShuffleWriterBufferSize; bool verifyOutput = true; + bool celebornMergeEnabled = true; std::string toString() const; diff --git a/bolt/shuffle/sparksql/tests/celeborn/README.md b/bolt/shuffle/sparksql/tests/celeborn/README.md index 0f3c481bf..1aa24df86 100644 --- a/bolt/shuffle/sparksql/tests/celeborn/README.md +++ b/bolt/shuffle/sparksql/tests/celeborn/README.md @@ -15,8 +15,8 @@ inside the development container. 
## Environment variables -- `BOLT_CELEBORN_GIT_REPO` (default `https://github.com/apache/celeborn.git`) -- `BOLT_CELEBORN_GIT_REF` (default `81d89f3`, aligned with cpp-client recipe) +- `BOLT_CELEBORN_GIT_REPO` (default `https://github.com/afterincomparableyum/celeborn.git`) +- `BOLT_CELEBORN_GIT_REF` (default `2e13df97aba3e25d80f5562fd4c0c8a3b34beb43`, aligned with cpp-client recipe) - `BOLT_CELEBORN_MASTER_HOST` (default `127.0.0.1`) - `BOLT_CELEBORN_MASTER_PORT` (default `19097`) - `BOLT_CELEBORN_NUM_WORKERS` (default `$(nproc)`, number of worker instances on localhost) diff --git a/bolt/shuffle/sparksql/tests/celeborn/scripts/run_e2e.sh b/bolt/shuffle/sparksql/tests/celeborn/scripts/run_e2e.sh index f6e494239..08f632d33 100755 --- a/bolt/shuffle/sparksql/tests/celeborn/scripts/run_e2e.sh +++ b/bolt/shuffle/sparksql/tests/celeborn/scripts/run_e2e.sh @@ -20,8 +20,8 @@ # Build directory is resolved as _build/. # # Environment variables (all optional): -# BOLT_CELEBORN_GIT_REPO - Celeborn git repo URL (default: https://github.com/apache/celeborn.git) -# BOLT_CELEBORN_GIT_REF - git ref to build (default: 81d89f3) +# BOLT_CELEBORN_GIT_REPO - Celeborn git repo URL (default: https://github.com/afterincomparableyum/celeborn.git) +# BOLT_CELEBORN_GIT_REF - git ref to build (default: 2e13df97aba3e25d80f5562fd4c0c8a3b34beb43) # BOLT_CELEBORN_MASTER_HOST - master bind host (default: 127.0.0.1) # BOLT_CELEBORN_MASTER_PORT - master bind port (default: 19097) # BOLT_CELEBORN_NUM_WORKERS - number of worker instances (default: $(nproc)) @@ -42,8 +42,8 @@ PROJECT_ROOT=$(cd "${SCRIPT_DIR}/../../../../../.." 
&& pwd) RUNTIME_DIR="/tmp/bolt-celeborn-runtime-${USER:-unknown}" CELEBORN_HOME="${RUNTIME_DIR}/celeborn-bin" CELEBORN_SOURCE_HOME="${RUNTIME_DIR}/celeborn-src" -CELEBORN_GIT_REPO=${BOLT_CELEBORN_GIT_REPO:-"https://github.com/apache/celeborn.git"} -CELEBORN_GIT_REF=${BOLT_CELEBORN_GIT_REF:-"81d89f3"} +CELEBORN_GIT_REPO=${BOLT_CELEBORN_GIT_REPO:-"https://github.com/afterincomparableyum/celeborn.git"} +CELEBORN_GIT_REF=${BOLT_CELEBORN_GIT_REF:-"2e13df97aba3e25d80f5562fd4c0c8a3b34beb43"} MASTER_HOST=${BOLT_CELEBORN_MASTER_HOST:-"127.0.0.1"} MASTER_PORT=${BOLT_CELEBORN_MASTER_PORT:-19097} diff --git a/conanfile.py b/conanfile.py index b51f1e6b7..dcbbf52ed 100644 --- a/conanfile.py +++ b/conanfile.py @@ -282,7 +282,7 @@ def requirements(self): self.requires("date/3.0.4-bolt", transitive_headers=True, transitive_libs=True) self.requires("libbacktrace/cci.20210118") if self.options.get_safe("spark_compatible"): - self.requires("celeborn-cpp-client/main-20251212") + self.requires("celeborn-cpp-client/main-20260514") if self.options.get_safe("enable_paimon"): self.requires("paimon-cpp/0.0.3-bolt") if self.options.get_safe("enable_testutil"): diff --git a/scripts/conan/recipes/celeborn-cpp-client/all/conandata.yml b/scripts/conan/recipes/celeborn-cpp-client/all/conandata.yml index c4a2a1433..7d1072663 100644 --- a/scripts/conan/recipes/celeborn-cpp-client/all/conandata.yml +++ b/scripts/conan/recipes/celeborn-cpp-client/all/conandata.yml @@ -13,7 +13,7 @@ # limitations under the License. 
patches: - "main-20251212": + "main-20260514": - patch_file: "patches/cmake.patch" patch_description: "changes in CMakeLists.txt" patch_type: "conan" diff --git a/scripts/conan/recipes/celeborn-cpp-client/all/conanfile.py b/scripts/conan/recipes/celeborn-cpp-client/all/conanfile.py index 8c57f7cb2..7c96075f0 100755 --- a/scripts/conan/recipes/celeborn-cpp-client/all/conanfile.py +++ b/scripts/conan/recipes/celeborn-cpp-client/all/conanfile.py @@ -60,9 +60,9 @@ def build_requirements(self): def source(self): git = Git(self, folder="..") - git.clone("https://github.com/apache/celeborn", target="src") + git.clone("https://github.com/afterincomparableyum/celeborn", target="src") git = Git(self, folder=self.source_folder) - git.checkout("81d89f3") + git.checkout("2e13df97aba3e25d80f5562fd4c0c8a3b34beb43") apply_conandata_patches(self) def export_sources(self): diff --git a/scripts/conan/recipes/celeborn-cpp-client/config.yml b/scripts/conan/recipes/celeborn-cpp-client/config.yml index 31af00263..60e197c1c 100644 --- a/scripts/conan/recipes/celeborn-cpp-client/config.yml +++ b/scripts/conan/recipes/celeborn-cpp-client/config.yml @@ -13,5 +13,5 @@ # limitations under the License. versions: - "main-20251212": + "main-20260514": folder: all