From ddddbd4e8101187a2b57b1f9d65178e6f84f8624 Mon Sep 17 00:00:00 2001 From: Lei Rui Date: Mon, 1 Jun 2026 15:29:41 +0800 Subject: [PATCH] feat(sum-subop): Spark BIGINT sum AArch64 SVE on HashAgg Add SumAggregateSparkInt64SubOp (adapter updateGroupsFromDecoded, SVE kernel sveHashAggBatchUpdateGroupSums), DecodedVector hashAgg* layout APIs, env kill switch BOLT_SPARK_SUM_INT64_USE_SUBOP, and unit tests (DuckDB parity, env-off parity, SubOp vs Base, nullable gate, null constant, hashAgg layout modes). Co-authored-by: Old-Li883 Co-authored-by: helloxteen --- .../sparksql/aggregates/CMakeLists.txt | 18 +- .../sparksql/aggregates/SumAggregate.cpp | 36 ++ .../SumAggregateSparkInt64SubOp.cpp | 187 ++++++ .../aggregates/SumAggregateSparkInt64SubOp.h | 80 +++ .../SumAggregateSparkInt64SubOpSve.cpp | 554 ++++++++++++++++++ .../aggregates/tests/SumAggregationTest.cpp | 192 ++++++ bolt/vector/DecodedVector.h | 54 ++ bolt/vector/tests/DecodedVectorTest.cpp | 40 ++ 8 files changed, 1160 insertions(+), 1 deletion(-) create mode 100644 bolt/functions/sparksql/aggregates/SumAggregateSparkInt64SubOp.cpp create mode 100644 bolt/functions/sparksql/aggregates/SumAggregateSparkInt64SubOp.h create mode 100644 bolt/functions/sparksql/aggregates/SumAggregateSparkInt64SubOpSve.cpp diff --git a/bolt/functions/sparksql/aggregates/CMakeLists.txt b/bolt/functions/sparksql/aggregates/CMakeLists.txt index def39424e..89132fe28 100644 --- a/bolt/functions/sparksql/aggregates/CMakeLists.txt +++ b/bolt/functions/sparksql/aggregates/CMakeLists.txt @@ -37,6 +37,12 @@ set(PAAGG_TYPES ) specialize_template(PAAGG_GENERATED_FILES PercentileApproxAggregate.cpp.in ${PAAGG_TYPES}) +set(BOLT_SPARK_SUM_INT64_SUBOP_SVE "") +if(CMAKE_SYSTEM_NAME STREQUAL "Linux" AND + CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64.*|AARCH64.*|arm64.*)") + list(APPEND BOLT_SPARK_SUM_INT64_SUBOP_SVE SumAggregateSparkInt64SubOpSve.cpp) +endif() + bolt_add_library( bolt_functions_spark_aggregates ${PAAGG_GENERATED_FILES} @@ -52,7 +58,8 @@ 
bolt_add_library( PercentileApproxAggregate.cpp Register.cpp RegrReplacementAggregate.cpp - SumAggregate.cpp + SumAggregateSparkInt64SubOp.cpp + ${BOLT_SPARK_SUM_INT64_SUBOP_SVE} SumAggregate.cpp ) @@ -61,6 +68,15 @@ target_link_libraries( fmt::fmt ) +# SumAggregateSparkInt64SubOpSve.cpp uses <arm_sve.h>; default Bolt -march=armv8.3-a +# does not enable the SVE ISA; compile this Linux AArch64 TU with +sve only. +if(CMAKE_SYSTEM_NAME STREQUAL "Linux" AND + CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64.*|AARCH64.*|arm64.*)") + set_source_files_properties( + SumAggregateSparkInt64SubOpSve.cpp + PROPERTIES COMPILE_OPTIONS "-march=armv8-a+sve") +endif() + if(${BOLT_BUILD_TESTING}) add_subdirectory(tests) endif() diff --git a/bolt/functions/sparksql/aggregates/SumAggregate.cpp b/bolt/functions/sparksql/aggregates/SumAggregate.cpp index cd852df3f..6f2ddfbc7 100644 --- a/bolt/functions/sparksql/aggregates/SumAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/SumAggregate.cpp @@ -32,6 +32,11 @@ #include "bolt/functions/lib/aggregates/SumAggregateBase.h" #include "bolt/functions/sparksql/aggregates/DecimalSumAggregate.h" +#include "bolt/functions/sparksql/aggregates/SumAggregateSparkInt64SubOp.h" + +#include +#include + using namespace bytedance::bolt::functions::aggregate; namespace bytedance::bolt::functions::aggregate::sparksql { @@ -39,6 +44,34 @@ namespace { template using SumAggregate = SumAggregateBase; +// Process env rollback for Spark BIGINT sum (non-decimal): default SubOp on. +// Disable with BOLT_SPARK_SUM_INT64_USE_SUBOP=0 / false / no / off (ASCII, +// case-insensitive for the words). Empty or unknown non-false → SubOp on. 
+bool sparkSumInt64UseSubOpFromEnv() { + const char* v = std::getenv("BOLT_SPARK_SUM_INT64_USE_SUBOP"); + if (v == nullptr || *v == '\0') { + return true; + } + if (v[0] == '0' && v[1] == '\0') { + return false; + } + auto eqNoCase = [](const char* a, const char* b) { + while (*a && *b) { + if (std::tolower(static_cast(*a)) != + std::tolower(static_cast(*b))) { + return false; + } + ++a; + ++b; + } + return *a == *b; + }; + if (eqNoCase(v, "false") || eqNoCase(v, "no") || eqNoCase(v, "off")) { + return false; + } + return true; +} + TypePtr getDecimalSumType( const TypePtr& resultType, core::AggregationNode::Step step) { @@ -112,6 +145,9 @@ exec::AggregateRegistrationResult registerSum( resultType, sumType); } } + if (sparkSumInt64UseSubOpFromEnv()) { + return std::make_unique(BIGINT()); + } return std::make_unique>( BIGINT()); } diff --git a/bolt/functions/sparksql/aggregates/SumAggregateSparkInt64SubOp.cpp b/bolt/functions/sparksql/aggregates/SumAggregateSparkInt64SubOp.cpp new file mode 100644 index 000000000..8ca721274 --- /dev/null +++ b/bolt/functions/sparksql/aggregates/SumAggregateSparkInt64SubOp.cpp @@ -0,0 +1,187 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * -------------------------------------------------------------------------- + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + * + * This file has been modified by ByteDance Ltd. 
and/or its affiliates on + * 2025-11-11. + * + * Original file was released under the Apache License 2.0, + * with the full license text available at: + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This modified file is released under the same license. + * -------------------------------------------------------------------------- + */ + +#include "bolt/functions/sparksql/aggregates/SumAggregateSparkInt64SubOp.h" + +#include "bolt/exec/AggregationHook.h" +#include "bolt/functions/lib/CheckedArithmeticImpl.h" +#include "bolt/vector/BaseVector.h" +#include "bolt/vector/DecodedVector.h" +#include "bolt/vector/LazyVector.h" +#include "bolt/vector/VectorEncoding.h" + +#if defined(__aarch64__) && defined(__linux__) +#include +#endif + +namespace bytedance::bolt::functions::aggregate::sparksql { + +namespace { + +static void sparkSumInt64UpdateSingle(int64_t& result, int64_t value) { + if (::bytedance::bolt::functions::aggregate::Overflow) { + result += value; + } else { + CHECK_ADD(result, value); + } +} + +#if defined(__aarch64__) && defined(__linux__) +// SVE is advertised on AT_HWCAP (HWCAP_SVE), not AT_HWCAP2 — see Linux +// arch/arm64/include/uapi/asm/hwcap.h. 
+#ifndef HWCAP_SVE +constexpr unsigned long kBoltHwcapSve = 1UL << 22; +#else +constexpr unsigned long kBoltHwcapSve = HWCAP_SVE; +#endif + +static bool linuxAarch64RuntimeHasSve() { + const unsigned long hwcap = getauxval(AT_HWCAP); + return (hwcap & kBoltHwcapSve) != 0; +} +#else +static bool linuxAarch64RuntimeHasSve() { + return false; +} +#endif + +} // namespace + +SumAggregateSparkInt64SubOp::SumAggregateSparkInt64SubOp(TypePtr resultType) + : Base(std::move(resultType)) {} + +// Raw-input update for Spark sum(bigint). Lazy pushdown inputs go straight to +// Base. When `numNulls_` is non-zero and the overflow gate holds, the input is +// decoded; a batch that may have nulls is offered to the SVE kernel on Linux +// AArch64 with SVE, and any batch the kernel does not take falls back to Base. +void SumAggregateSparkInt64SubOp::addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) { + const auto& arg = args[0]; + + if (mayPushdown && arg->isLazy()) { + Base::addRawInput(groups, rows, args, mayPushdown); + return; + } + + using ::bytedance::bolt::functions::aggregate::Overflow; + if (this->numNulls_ && Overflow) { + DecodedVector decoded(*arg, rows, !mayPushdown); + const auto encoding = decoded.base()->encoding(); + if (mayPushdown && encoding == VectorEncoding::Simple::LAZY && + !arg->type()->isDecimal()) { + bytedance::bolt::aggregate::SimpleCallableHook< + int64_t, + int64_t, + void (*)(int64_t&, int64_t)> + hook( + offset_, + nullByte_, + nullMask_, + groups, + &numNulls_, + sparkSumInt64UpdateSingle); + auto indices = decoded.indices(); + decoded.base()->as()->load( + ::bytedance::bolt::RowSet(indices, arg->size()), &hook); + return; + } + if (decoded.mayHaveNulls()) { +// Same condition as the getauxval guard above: the real kernel definition is +// only built for Linux AArch64. 
+#if defined(__aarch64__) && defined(__linux__) + if (linuxAarch64RuntimeHasSve() && + updateGroupsFromDecoded(groups, rows, decoded)) { + return; + } +#endif + } + Base::addRawInput(groups, rows, args, mayPushdown); + return; + } + + Base::addRawInput(groups, rows, args, mayPushdown); +} + +// Same dispatch as addRawInput, forwarding to Base::addIntermediateResults +// whenever the SVE kernel is not used. +void SumAggregateSparkInt64SubOp::addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) { + const auto& arg = args[0]; + + if (mayPushdown && arg->isLazy()) { + Base::addIntermediateResults(groups, rows, args, 
mayPushdown); + return; + } + + using ::bytedance::bolt::functions::aggregate::Overflow; + if (this->numNulls_ && Overflow) { + DecodedVector decoded(*arg, rows, !mayPushdown); + const auto encoding = decoded.base()->encoding(); + if (mayPushdown && encoding == VectorEncoding::Simple::LAZY && + !arg->type()->isDecimal()) { + bytedance::bolt::aggregate::SimpleCallableHook< + int64_t, + int64_t, + void (*)(int64_t&, int64_t)> + hook( + offset_, + nullByte_, + nullMask_, + groups, + &numNulls_, + sparkSumInt64UpdateSingle); + auto indices = decoded.indices(); + decoded.base()->as()->load( + ::bytedance::bolt::RowSet(indices, arg->size()), &hook); + return; + } + if (decoded.mayHaveNulls()) { +#if defined(__aarch64__) && defined(__linux__) + if (linuxAarch64RuntimeHasSve() && + updateGroupsFromDecoded(groups, rows, decoded)) { + return; + } +#endif + } + Base::addIntermediateResults(groups, rows, args, mayPushdown); + return; + } + + Base::addIntermediateResults(groups, rows, args, mayPushdown); +} + +// Stub for every build that does not compile SumAggregateSparkInt64SubOpSve.cpp. +// CMake adds that TU on Linux AArch64 only, so guarding on __aarch64__ alone +// would leave this member undefined on other AArch64 targets (e.g. macOS arm64). +// Keep this guard in sync with the CMake condition. +#if !(defined(__aarch64__) && defined(__linux__)) +bool SumAggregateSparkInt64SubOp::updateGroupsFromDecoded( + char** /*groups*/, + const SelectivityVector& /*rows*/, + ::bytedance::bolt::DecodedVector& /*decoded*/) { + return false; +} +#endif + +} // namespace bytedance::bolt::functions::aggregate::sparksql diff --git a/bolt/functions/sparksql/aggregates/SumAggregateSparkInt64SubOp.h b/bolt/functions/sparksql/aggregates/SumAggregateSparkInt64SubOp.h new file mode 100644 index 000000000..fb22effc6 --- /dev/null +++ b/bolt/functions/sparksql/aggregates/SumAggregateSparkInt64SubOp.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * -------------------------------------------------------------------------- + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + * + * This file has been modified by ByteDance Ltd. and/or its affiliates on + * 2025-11-11. + * + * Original file was released under the Apache License 2.0, + * with the full license text available at: + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This modified file is released under the same license. + * -------------------------------------------------------------------------- + */ + +#pragma once + +#include "bolt/functions/lib/aggregates/SumAggregateBase.h" + +namespace bytedance::bolt { +class DecodedVector; +} + +namespace bytedance::bolt::functions::aggregate::sparksql { + +/// Spark sum(bigint)->bigint SubOp (default unless `BOLT_SPARK_SUM_INT64_USE_SUBOP` +/// is off; see `SumAggregate.cpp`). When `numNulls_`, Spark overflow gate, and +/// `decoded.mayHaveNulls()` hold, runs the AArch64 SVE batch kernel; otherwise +/// defers to `SumAggregateBase`. 
+class SumAggregateSparkInt64SubOp + : public ::bytedance::bolt::functions::aggregate::SumAggregateBase< + int64_t, + int64_t, + int64_t> { + using Base = ::bytedance::bolt::functions::aggregate::SumAggregateBase< + int64_t, + int64_t, + int64_t>; + + public: + explicit SumAggregateSparkInt64SubOp(TypePtr resultType); + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override; + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool mayPushdown) override; + + private: + /// Applies `sveHashAggBatchUpdateGroupSums` for one decoded batch. + /// Returns true when the SVE path was taken (caller must not invoke Base for + /// this batch). Returns false when the kernel was not applied: stub builds + /// without the SVE TU, an SVE vector length other than 32 bytes, or + /// `hashAggIndicesLayoutMode() == 2`; the caller must then fall back to Base. + bool updateGroupsFromDecoded( + char** groups, + const SelectivityVector& rows, + ::bytedance::bolt::DecodedVector& decoded); +}; + +} // namespace bytedance::bolt::functions::aggregate::sparksql diff --git a/bolt/functions/sparksql/aggregates/SumAggregateSparkInt64SubOpSve.cpp b/bolt/functions/sparksql/aggregates/SumAggregateSparkInt64SubOpSve.cpp new file mode 100644 index 000000000..27079986e --- /dev/null +++ b/bolt/functions/sparksql/aggregates/SumAggregateSparkInt64SubOpSve.cpp @@ -0,0 +1,554 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + * -------------------------------------------------------------------------- + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + * + * This file has been modified by ByteDance Ltd. and/or its affiliates on + * 2025-11-11. + * + * Original file was released under the Apache License 2.0, + * with the full license text available at: + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This modified file is released under the same license. + * -------------------------------------------------------------------------- + * + * AArch64 SVE batch kernel for Spark sum(bigint) HashAgg group updates. + * Compiled only on aarch64 (see aggregates/CMakeLists.txt, `-march=armv8-a+sve`). + */ + +#include "bolt/functions/sparksql/aggregates/SumAggregateSparkInt64SubOp.h" + +#include + +#include "bolt/vector/BaseVector.h" +#include "bolt/vector/DecodedVector.h" +#include "bolt/vector/SelectivityVector.h" + +namespace bytedance::bolt::functions::aggregate::sparksql { + +namespace { + +constexpr uint64_t kSupportedSveVectorBytes = 32; + +template +inline bool isBitSet(const T* bits, uint64_t idx) { + return bits[idx / (sizeof(bits[0]) * 8)] & + (static_cast(1) << (idx & ((sizeof(bits[0]) * 8) - 1))); +} + +inline bool isBitNull(const uint64_t* bits, int32_t index) { + return isBitSet(bits, index) == false; +} +template +constexpr inline T roundUp(T value, U factor) { + return (value + (factor - 1)) / factor * factor; +} + +svbool_t sveDecodedNullMaskForMode( + uint8_t* nulls_, + int32_t index, + int mode, + uint32_t* dic, + int32_t length) { + svbool_t pg; + if (mode == 0) { + pg = svptrue_b8(); + return pg; + } else if (mode == 1) { + __asm__ __volatile__("ldr %0, [%1]" + : "=Upl"(pg) + : "r"(&(nulls_[index])) + : "memory"); + return pg; + } else if (mode == 2) { + if (!isBitNull( + reinterpret_cast(nulls_), + 0)) + { + pg = svptrue_b8(); + } else { + pg = svpfalse(); + } + return pg; + } else if (mode == 3) { + + svuint32_t onc = 
svdup_u32(1); + svuint32_t inv = svindex_u32(0, 1); + svuint32_t pow = svlsl_m(svptrue_b32(), onc, inv); + uint8_t tmpNulls[4] = {0}; + uint32_t* null32ptr = reinterpret_cast(nulls_); + + svuint32_t posv, idxbufv, bufv, offsetv; + svbool_t nullvec, pg1; + + // mode1==3: pack null bits for eight dictionary lanes (chunk 0). + pg1 = svwhilelt_b32(index * 8, length); + posv = svld1(pg1, dic + index * 8); + idxbufv = svlsr_x(pg1, posv, 5); // u32 word index (pos / 32) + bufv = svld1_gather_index(pg1, null32ptr, idxbufv); + offsetv = svand_m(pg1, posv, 0b11111); // bit index within the u32 word + bufv = svlsr_m(pg1, bufv, offsetv); + bufv = svand_m(pg1, bufv, 0x1); + nullvec = svcmpgt(pg1, bufv, 0); + if (__builtin_expect((svptest_any(pg1, nullvec)), 0)) { + uint8_t nullsres = svaddv(nullvec, pow); + tmpNulls[0] = nullsres; + } else { + tmpNulls[0] = 0; + } + + // mode1==3: dictionary null bits (chunk 1). + pg1 = svwhilelt_b32(index * 8 + 8, length); + posv = svld1(pg1, dic + index * 8 + 8); + idxbufv = svlsr_x(pg1, posv, 5); + bufv = svld1_gather_index(pg1, null32ptr, idxbufv); + offsetv = svand_m(pg1, posv, 0b11111); + bufv = svlsr_m(pg1, bufv, offsetv); + bufv = svand_m(pg1, bufv, 0x1); + nullvec = svcmpgt(pg1, bufv, 0); + if (__builtin_expect((svptest_any(pg1, nullvec)), 0)) { + uint8_t nullsres = svaddv(nullvec, pow); + tmpNulls[1] = nullsres; + } else { + tmpNulls[1] = 0; + } + + // mode1==3: dictionary null bits (chunk 2). + pg1 = svwhilelt_b32(index * 8 + 16, length); + posv = svld1(pg1, dic + index * 8 + 16); + idxbufv = svlsr_x(pg1, posv, 5); + bufv = svld1_gather_index(pg1, null32ptr, idxbufv); + offsetv = svand_m(pg1, posv, 0b11111); + bufv = svlsr_m(pg1, bufv, offsetv); + bufv = svand_m(pg1, bufv, 0x1); + nullvec = svcmpgt(pg1, bufv, 0); + if (__builtin_expect((svptest_any(pg1, nullvec)), 0)) { + uint8_t nullsres = svaddv(nullvec, pow); + tmpNulls[2] = nullsres; + } else { + tmpNulls[2] = 0; + } + + // mode1==3: dictionary null bits (chunk 3). 
+ pg1 = svwhilelt_b32(index * 8 + 24, length); + posv = svld1(pg1, dic + index * 8 + 24); + idxbufv = svlsr_x(pg1, posv, 5); + bufv = svld1_gather_index(pg1, null32ptr, idxbufv); + offsetv = svand_m(pg1, posv, 0b11111); + bufv = svlsr_m(pg1, bufv, offsetv); + bufv = svand_m(pg1, bufv, 0x1); + nullvec = svcmpgt(pg1, bufv, 0); + if (__builtin_expect((svptest_any(pg1, nullvec)), 0)) { + uint8_t nullsres = svaddv(nullvec, pow); + tmpNulls[3] = nullsres; + } else { + tmpNulls[3] = 0; + } + + __asm__ __volatile__("ldr %0, [%1]" + : "=Upl"(pg) + : "r"(tmpNulls) + : "memory"); + return pg; + } + // Unknown mode1: inactive predicate. + pg = svpfalse(); + return pg; +} + +inline __attribute__((always_inline)) svbool_t +sveMaskDistinctGroupSlots(svbool_t pg, const svuint64_t val) { + svuint64_t s1 = svext_u64(val, val, 1); + svbool_t mask2 = svcmpeq(svwhilelt_b64(0, 3), val, s1); + + svuint64_t s2 = svext_u64(val, val, 2); + svbool_t mask3 = svcmpeq(svwhilelt_b64(0, 2), val, s2); + svbool_t mask12 = svorr_b_z(pg, mask2, mask3); + + svuint64_t s3 = svext_u64(val, val, 3); + svbool_t mask4 = svcmpeq(svwhilelt_b64(0, 1), val, s3); + + svbool_t mask = svorr_b_z(pg, mask4, mask12); + mask = svnot_b_z(pg, mask); + + return mask; +} + +static bool sveClearGroupNullFlags( + int32_t nullByte, + uint8_t nullMask, + uint64_t* numNulls, + svuint64_t ptr, + svbool_t pg) { + if (*numNulls) { + svint64_t group = + svld1sb_gather_u64base_offset_s64(pg, ptr, nullByte); + svuint8_t group8 = svreinterpret_u8(group); + + svuint8_t tmp = svand_n_u8_z(pg, group8, nullMask); + svbool_t test = svcmpne_n_u8(svptrue_b8(), tmp, 0); + if (svptest_any(svptrue_b8(), test)) { + uint8_t negNull = ~nullMask; + + svuint8_t adjust = svand_n_u8_m(test, group8, negNull); + svst1b_scatter_u64base_offset_s64( + pg, ptr, nullByte, svreinterpret_s64(adjust)); + + int num = svcntp_b8(test, test); + *numNulls -= num; + return true; + } + } + return false; +} + +template +static void sveHashAggBatchUpdateGroupSums( + 
int32_t nullByte, + uint8_t nullMask, + uint64_t* numNulls, + GetPtr&& getAccumPtr, + char** result, + uint64_t* bitmap1, + uint64_t* bitmap2, + int64_t* value, + int32_t begin, + int32_t end, + int mode1, + int mode2, + uint32_t* dic) { + uint8_t* bitmap1_8 = reinterpret_cast(bitmap1); + uint8_t* bitmap2_8 = reinterpret_cast(bitmap2); + + int32_t firstWord = + roundUp(begin, 32) == begin ? begin : roundUp(begin, 32) - 32; + int32_t lastWord = roundUp(end, 32); + svbool_t mask, mask1; + // Process 32 logical rows per iteration; `count` is the row index. + for (int32_t count = firstWord; count + 32 <= lastWord; count += 32) { + int32_t arr8Index = count / 8; + svbool_t mask2; + if (bitmap2_8 != nullptr) { + mask2 = sveDecodedNullMaskForMode(bitmap2_8, arr8Index, mode1, dic, end); + } else { + mask2 = svptrue_b8(); + } + __asm__ __volatile__("ldr %0, [%1]" + : "=Upl"(mask1) + : "r"(&bitmap1_8[arr8Index]) + : "memory"); + mask = svand_b_z(svptrue_b8(), mask1, mask2); + mask = svand_b_z(svptrue_b8(), mask, svwhilelt_b8(count, end)); + if (!svptest_any(svptrue_b8(), mask)) { + continue; + } + + svbool_t mask00 = svunpklo(mask); + svbool_t mask01 = svunpkhi(mask); + if (svptest_any(svptrue_b16(), mask00)) { + svbool_t mask10 = svunpklo(mask00); + if (svptest_any(svptrue_b32(), mask10)) { + svbool_t mask20 = svunpklo(mask10); + svbool_t mask21 = svunpkhi(mask10); + if (svptest_any(svptrue_b64(), mask20)) { + svuint64_t ptr = + svld1(mask20, reinterpret_cast(result + count)); + svbool_t m20 = sveMaskDistinctGroupSlots(mask20, ptr); + sveClearGroupNullFlags(nullByte, nullMask, numNulls, ptr, m20); + uint8_t flag0[4] = {0, 0, 0, 0}; + __asm__ __volatile__("str %1, [%0]": : "r" (&flag0[0]), "Upl" (mask20) : "memory"); + + // mode2==3: gather via dictionary indices; else flat row values. 
+ if (mode2 == 3) { + for (int i = 0; i < 4; i++) { + if (flag0[i] != 0) { + uint32_t dictIndex = dic[count + i]; + int64_t dictValue = value[dictIndex]; + *getAccumPtr(*(result + count + i)) += dictValue; + } + } + } else { + for (int i = 0; i < 4; i++) { + if (flag0[i] != 0) { + *getAccumPtr(*(result + count + i)) += value[count + i]; + } + } + } + } + + if (svptest_any(svptrue_b64(), mask21)) { + svuint64_t ptr = + svld1(mask21, reinterpret_cast(result + count + 4)); + svbool_t m21 = sveMaskDistinctGroupSlots(mask21, ptr); + sveClearGroupNullFlags(nullByte, nullMask, numNulls, ptr, m21); + uint8_t flag1[4] = {0, 0, 0, 0}; + __asm__ __volatile__("str %1, [%0]": : "r" (&flag1[0]), "Upl" (mask21) : "memory"); + + if (mode2 == 3) { + for (int i = 0; i < 4; i++) { + if (flag1[i] != 0) { + uint32_t dictIndex = dic[count + 4 + i]; + int64_t dictValue = value[dictIndex]; + *getAccumPtr(*(result + count + 4 + i)) += dictValue; + } + } + } else { + for (int i = 0; i < 4; i++) { + if (flag1[i] != 0) { + *getAccumPtr(*(result + count + 4 + i)) += value[count + 4 + i]; + } + } + } + } + } + svbool_t mask11 = svunpkhi(mask00); + if (svptest_any(svptrue_b32(), mask11)) { + svbool_t mask22 = svunpklo(mask11); + svbool_t mask23 = svunpkhi(mask11); + if (svptest_any(svptrue_b64(), mask22)) { + svuint64_t ptr = + svld1(mask22, reinterpret_cast(result + count + 8)); + svbool_t m22 = sveMaskDistinctGroupSlots(mask22, ptr); + sveClearGroupNullFlags(nullByte, nullMask, numNulls, ptr, m22); + uint8_t flag2[4] = {0, 0, 0, 0}; + __asm__ __volatile__("str %1, [%0]": : "r" (&flag2[0]), "Upl" (mask22) : "memory"); + + if (mode2 == 3) { + for (int i = 0; i < 4; i++) { + if (flag2[i] != 0) { + uint32_t dictIndex = dic[count + 8 + i]; + int64_t dictValue = value[dictIndex]; + *getAccumPtr(*(result + count + 8 + i)) += dictValue; + } + } + } else { + for (int i = 0; i < 4; i++) { + if (flag2[i] != 0) { + *getAccumPtr(*(result + count + 8 + i)) += value[count + 8 + i]; + } + } + } + } + + if 
(svptest_any(svptrue_b64(), mask23)) { + svuint64_t ptr = + svld1(mask23, reinterpret_cast(result + count + 12)); + svbool_t m23 = sveMaskDistinctGroupSlots(mask23, ptr); + sveClearGroupNullFlags(nullByte, nullMask, numNulls, ptr, m23); + uint8_t flag3[4] = {0, 0, 0, 0}; + __asm__ __volatile__("str %1, [%0]": : "r" (&flag3[0]), "Upl" (mask23) : "memory"); + + if (mode2 == 3) { + for (int i = 0; i < 4; i++) { + if (flag3[i] != 0) { + uint32_t dictIndex = dic[count + 12 + i]; + int64_t dictValue = value[dictIndex]; + *getAccumPtr(*(result + count + 12 + i)) += dictValue; + } + } + } else { + for (int i = 0; i < 4; i++) { + if (flag3[i] != 0) { + *getAccumPtr(*(result + count + 12 + i)) += value[count + 12 + i]; + } + } + } + } + } + } + + svbool_t mask12 = svunpklo(mask01); + + if (svptest_any(svptrue_b16(), mask01)) { + svbool_t mask24 = svunpklo(mask12); + svbool_t mask25 = svunpkhi(mask12); + if (svptest_any(svptrue_b32(), mask12)) { + if (svptest_any(svptrue_b64(), mask24)) { + svuint64_t ptr = + svld1(mask24, reinterpret_cast(result + count + 16)); + svbool_t m24 = sveMaskDistinctGroupSlots(mask24, ptr); + sveClearGroupNullFlags(nullByte, nullMask, numNulls, ptr, m24); + uint8_t flag4[4] = {0, 0, 0, 0}; + __asm__ __volatile__("str %1, [%0]": : "r" (&flag4[0]), "Upl" (mask24) : "memory"); + + if (mode2 == 3) { + for (int i = 0; i < 4; i++) { + if (flag4[i] != 0) { + uint32_t dictIndex = dic[count + 16 + i]; + int64_t dictValue = value[dictIndex]; + *getAccumPtr(*(result + count + 16 + i)) += dictValue; + } + } + } else { + for (int i = 0; i < 4; i++) { + if (flag4[i] != 0) { + *getAccumPtr(*(result + count + 16 + i)) += value[count + 16 + i]; + } + } + } + } + + if (svptest_any(svptrue_b64(), mask25)) { + svuint64_t ptr = + svld1(mask25, reinterpret_cast(result + count + 20)); + svbool_t m25 = sveMaskDistinctGroupSlots(mask25, ptr); + sveClearGroupNullFlags(nullByte, nullMask, numNulls, ptr, m25); + uint8_t flag5[4] = {0, 0, 0, 0}; + __asm__ __volatile__("str %1, 
[%0]": : "r" (&flag5[0]), "Upl" (mask25) : "memory"); + + if (mode2 == 3) { + for (int i = 0; i < 4; i++) { + if (flag5[i] != 0) { + uint32_t dictIndex = dic[count + 20 + i]; + int64_t dictValue = value[dictIndex]; + *getAccumPtr(*(result + count + 20 + i)) += dictValue; + } + } + } else { + for (int i = 0; i < 4; i++) { + if (flag5[i] != 0) { + *getAccumPtr(*(result + count + 20 + i)) += value[count + 20 + i]; + } + } + } + } + } + svbool_t mask13 = svunpkhi(mask01); + + if (svptest_any(svptrue_b32(), mask13)) { + svbool_t mask26 = svunpklo(mask13); + svbool_t mask27 = svunpkhi(mask13); + if (svptest_any(svptrue_b64(), mask26)) { + svuint64_t ptr = + svld1(mask26, reinterpret_cast(result + count + 24)); + svbool_t m26 = sveMaskDistinctGroupSlots(mask26, ptr); + sveClearGroupNullFlags(nullByte, nullMask, numNulls, ptr, m26); + uint8_t flag6[4] = {0, 0, 0, 0}; + __asm__ __volatile__("str %1, [%0]": : "r" (&flag6[0]), "Upl" (mask26) : "memory"); + + if (mode2 == 3) { + for (int i = 0; i < 4; i++) { + if (flag6[i] != 0) { + uint32_t dictIndex = dic[count + 24 + i]; + int64_t dictValue = value[dictIndex]; + *getAccumPtr(*(result + count + 24 + i)) += dictValue; + } + } + } else { + for (int i = 0; i < 4; i++) { + if (flag6[i] != 0) { + *getAccumPtr(*(result + count + 24 + i)) += value[count + 24 + i]; + } + } + } + } + + if (svptest_any(svptrue_b64(), mask27)) { + svuint64_t ptr = + svld1(mask27, reinterpret_cast(result + count + 28)); + svbool_t m27 = sveMaskDistinctGroupSlots(mask27, ptr); + sveClearGroupNullFlags(nullByte, nullMask, numNulls, ptr, m27); + uint8_t flag7[4] = {0, 0, 0, 0}; + __asm__ __volatile__("str %1, [%0]": : "r" (&flag7[0]), "Upl" (mask27) : "memory"); + + if (mode2 == 3) { + for (int i = 0; i < 4; i++) { + if (flag7[i] != 0) { + uint32_t dictIndex = dic[count + 28 + i]; + int64_t dictValue = value[dictIndex]; + *getAccumPtr(*(result + count + 28 + i)) += dictValue; + } + } + } else { + for (int i = 0; i < 4; i++) { + if (flag7[i] != 0) { + 
*getAccumPtr(*(result + count + 28 + i)) += value[count + 28 + i]; + } + } + } + } + } + } + } +} + +} // namespace + +bool SumAggregateSparkInt64SubOp::updateGroupsFromDecoded( + char** groups, + const SelectivityVector& rows, + ::bytedance::bolt::DecodedVector& decoded) { + using ::bytedance::bolt::functions::aggregate::Overflow; + BOLT_DCHECK(numNulls_); + BOLT_DCHECK(Overflow); + BOLT_DCHECK(decoded.mayHaveNulls()); + + // This kernel stores SVE predicates into four-byte scratch buffers and + // processes 32 rows per block. Use it only on 256-bit SVE; other VLs fall + // back to the scalar Base path in the caller. + if (svcntb() != kSupportedSveVectorBytes) { + return false; + } + + const int32_t mode1 = decoded.hashAggNullsLayoutMode(); + const int32_t mode2 = decoded.hashAggIndicesLayoutMode(); + BOLT_DCHECK_GE(mode1, 0); + BOLT_DCHECK_LE(mode1, 3); + BOLT_DCHECK_GE(mode2, 1); + BOLT_DCHECK_LE(mode2, 3); + if (mode2 == 2) { + return false; + } + uint64_t* bitmap2 = decoded.hashAggMutableCombinedNullBits(); + int64_t* valueBuf = reinterpret_cast(decoded.hashAggMutableRawData()); + BOLT_DCHECK_NOT_NULL(valueBuf); + uint32_t* dic = mode2 == 3 + ? reinterpret_cast(decoded.hashAggMutableIndices()) + : nullptr; + if (mode2 == 3) { + BOLT_DCHECK_NOT_NULL(dic); + } + + uint64_t* rowsBits = const_cast(rows.allBits()); + const vector_size_t begin = rows.begin(); + const vector_size_t end = rows.end(); + + auto getAccum = [this](char* group) -> int64_t* { + return this->template value(group); + }; + + sveHashAggBatchUpdateGroupSums( + nullByte_, + nullMask_, + &numNulls_, + getAccum, + groups, + rowsBits, + bitmap2, + valueBuf, + begin, + end, + mode1, + mode2, + dic); + // On aarch64: batch handled by SVE; caller skips Base. Shape gating is upstream + // (`mayHaveNulls`, `numNulls_`, auxv). 
+ return true; +} + +} // namespace bytedance::bolt::functions::aggregate::sparksql diff --git a/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp b/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp index 488f06ae8..408dcebb8 100644 --- a/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp +++ b/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp @@ -33,6 +33,8 @@ #include "bolt/functions/lib/aggregates/tests/SumTestBase.h" #include "bolt/functions/sparksql/aggregates/Register.h" +#include + using bytedance::bolt::exec::test::PlanBuilder; using namespace bytedance::bolt::exec::test; using namespace bytedance::bolt::functions::aggregate::test; @@ -127,6 +129,196 @@ TEST_F(SumAggregationTest, overflow) { SumTestBase::testAggregateOverflow("spark_sum"); } +// DuckDB parity: same SQL/input as reference. Spark sum(bigint) defaults to SubOp; +// `BOLT_SPARK_SUM_INT64_USE_SUBOP=0` selects Base (see env test below). +TEST_F(SumAggregationTest, sumInt64SubOpParity) { +#if !defined(_WIN32) + ::unsetenv("BOLT_SPARK_SUM_INT64_USE_SUBOP"); +#endif + auto globalInput = + makeRowVector({makeFlatVector({7, 11, 13, -5, 2})}); + createDuckDbTable({globalInput}); + testAggregations( + {globalInput}, + {}, + {"spark_sum(c0)"}, + "SELECT sum(c0) FROM tmp", + /*config*/ {}, + /*testWithTableScan*/ false); + + auto groupedInput = makeRowVector( + {makeFlatVector({0, 0, 1, 1, 1}), + makeFlatVector({100, 200, 30, 40, 50})}); + createDuckDbTable({groupedInput}); + testAggregations( + {groupedInput}, + {"c0"}, + {"spark_sum(c1)"}, + "SELECT c0, sum(c1) FROM tmp GROUP BY c0", + {}, + false); +} + +// Same parity as above with SubOp disabled at process start (POSIX setenv). +// Gluten: `spark.executorEnv.BOLT_SPARK_SUM_INT64_USE_SUBOP=0` on executors. 
+TEST_F(SumAggregationTest, sumInt64SubOpEnvOffParity) { +#if defined(_WIN32) + GTEST_SKIP() << "BOLT_SPARK_SUM_INT64_USE_SUBOP uses POSIX setenv/unsetenv"; +#else + constexpr const char* kEnv = "BOLT_SPARK_SUM_INT64_USE_SUBOP"; + ASSERT_EQ(0, ::setenv(kEnv, "0", 1)); + struct UnsetEnv { + const char* key; + ~UnsetEnv() { + ::unsetenv(key); + } + } unset{kEnv}; + + auto globalInput = + makeRowVector({makeFlatVector({7, 11, 13, -5, 2})}); + createDuckDbTable({globalInput}); + testAggregations( + {globalInput}, + {}, + {"spark_sum(c0)"}, + "SELECT sum(c0) FROM tmp", + /*config*/ {}, + /*testWithTableScan*/ false); + + auto groupedInput = makeRowVector( + {makeFlatVector({0, 0, 1, 1, 1}), + makeFlatVector({100, 200, 30, 40, 50})}); + createDuckDbTable({groupedInput}); + testAggregations( + {groupedInput}, + {"c0"}, + {"spark_sum(c1)"}, + "SELECT c0, sum(c1) FROM tmp GROUP BY c0", + {}, + false); +#endif +} + +// Nullable grouped bigint: `decoded.mayHaveNulls()` is true; with table-side null +// groups (`numNulls_`) and Spark overflow gate this hits the SubOp SVE path on +// **Linux AArch64 + SVE auxv**; on other hosts SubOp falls back to `SumAggregateBase` +// and the DuckDB reference is unchanged. +TEST_F(SumAggregationTest, sumInt64SubOpNullableSveGate) { +#if !defined(_WIN32) + ::unsetenv("BOLT_SPARK_SUM_INT64_USE_SUBOP"); +#endif + auto input = makeRowVector({ + makeFlatVector({0, 0, 1, 1, 1, 2}), + makeNullableFlatVector( + {10, std::nullopt, 30, 4, std::nullopt, 100}), + }); + createDuckDbTable({input}); + + testAggregations( + {input}, + {"c0"}, + {"spark_sum(c1)"}, + "SELECT c0, sum(c1) FROM tmp GROUP BY c0", + {}, + false); +} + +// Same nullable grouped input: default SubOp (SVE on Linux aarch64 + auxv) vs +// `BOLT_SPARK_SUM_INT64_USE_SUBOP=0` (Base). Final partial+final results must +// match; catches SVE vs scalar divergence without relying on DuckDB alone. 
+TEST_F(SumAggregationTest, sumInt64SubOpSveMatchesBase) {
+#if defined(_WIN32)
+  GTEST_SKIP() << "BOLT_SPARK_SUM_INT64_USE_SUBOP uses POSIX setenv/unsetenv";
+#else
+  constexpr const char* kEnv = "BOLT_SPARK_SUM_INT64_USE_SUBOP";
+
+  // Two batches so group 1 receives rows (and nulls) from both of them.
+  const std::vector<RowVectorPtr> batches = {
+      makeRowVector({
+          makeFlatVector<int64_t>({0, 0, 1}),
+          makeNullableFlatVector<int64_t>({10, std::nullopt, 30}),
+      }),
+      makeRowVector({
+          makeFlatVector<int64_t>({1, 1, 2}),
+          makeNullableFlatVector<int64_t>({4, std::nullopt, 100}),
+      }),
+  };
+
+  // Snapshot the caller's value before either run toggles it; the destructor
+  // puts it back (or clears the variable again) when the test ends. The value
+  // is copied because a later setenv/unsetenv may invalidate the getenv
+  // pointer.
+  struct RestoreEnv {
+    const char* key;
+    bool hadValue = ::getenv(key) != nullptr;
+    std::string saved = hadValue ? ::getenv(key) : "";
+    ~RestoreEnv() {
+      if (hadValue) {
+        ::setenv(key, saved.c_str(), 1);
+      } else {
+        ::unsetenv(key);
+      }
+    }
+  } restore{kEnv};
+
+  // Runs partial + final spark_sum(c1) grouped by c0 with the switch either
+  // cleared (default selection) or forced to "0", and returns a copy of the
+  // result rows.
+  auto runGroupedSparkSum = [&](bool subOpEnabled) -> RowVectorPtr {
+    if (subOpEnabled) {
+      ::unsetenv(kEnv);
+    } else {
+      // CHECK rather than ASSERT: gtest fatal assertions cannot be used in a
+      // function that returns a value.
+      CHECK_EQ(0, ::setenv(kEnv, "0", 1));
+    }
+
+    PlanBuilder builder(pool());
+    builder.values(batches);
+    builder.partialAggregation({"c0"}, {"spark_sum(c1)"}).finalAggregation();
+    return AssertQueryBuilder(builder.planNode()).copyResults(pool());
+  };
+
+  auto subOpResult = runGroupedSparkSum(true);
+  auto baseResult = runGroupedSparkSum(false);
+  ASSERT_TRUE(assertEqualResults({baseResult}, {subOpResult}));
+#endif
+}
+
+// Null constant bigint (mode2=2, mayHaveNulls): SVE null-mask path vs Base;
+// same env toggle as sumInt64SubOpSveMatchesBase.
+TEST_F(SumAggregationTest, sumInt64SubOpNullConstMatchesBase) {
+#if defined(_WIN32)
+  GTEST_SKIP() << "BOLT_SPARK_SUM_INT64_USE_SUBOP uses POSIX setenv/unsetenv";
+#else
+  constexpr const char* kEnv = "BOLT_SPARK_SUM_INT64_USE_SUBOP";
+
+  // The summed column is an all-null constant in both batches; only the
+  // grouping keys differ.
+  const std::vector<RowVectorPtr> batches = {
+      makeRowVector({
+          makeFlatVector<int64_t>({0, 0, 1}),
+          makeConstant<int64_t>(std::nullopt, 3),
+      }),
+      makeRowVector({
+          makeFlatVector<int64_t>({1, 2, 2}),
+          makeConstant<int64_t>(std::nullopt, 3),
+      }),
+  };
+
+  // Snapshot the caller's value before either run toggles it; the destructor
+  // puts it back (or clears the variable again) when the test ends. The value
+  // is copied because a later setenv/unsetenv may invalidate the getenv
+  // pointer.
+  struct RestoreEnv {
+    const char* key;
+    bool hadValue = ::getenv(key) != nullptr;
+    std::string saved = hadValue ? ::getenv(key) : "";
+    ~RestoreEnv() {
+      if (hadValue) {
+        ::setenv(key, saved.c_str(), 1);
+      } else {
+        ::unsetenv(key);
+      }
+    }
+  } restore{kEnv};
+
+  // Runs partial + final spark_sum(c1) grouped by c0 with the switch either
+  // cleared (default selection) or forced to "0", and returns a copy of the
+  // result rows.
+  auto runGroupedSparkSum = [&](bool subOpEnabled) -> RowVectorPtr {
+    if (subOpEnabled) {
+      ::unsetenv(kEnv);
+    } else {
+      // CHECK rather than ASSERT: gtest fatal assertions cannot be used in a
+      // function that returns a value.
+      CHECK_EQ(0, ::setenv(kEnv, "0", 1));
+    }
+
+    PlanBuilder builder(pool());
+    builder.values(batches);
+    builder.partialAggregation({"c0"}, {"spark_sum(c1)"}).finalAggregation();
+    return AssertQueryBuilder(builder.planNode()).copyResults(pool());
+  };
+
+  auto subOpResult = runGroupedSparkSum(true);
+  auto baseResult = runGroupedSparkSum(false);
+  ASSERT_TRUE(assertEqualResults({baseResult}, {subOpResult}));
+#endif
+}
+
 TEST_F(SumAggregationTest, hookLimits) {
   testHookLimits();
 }
diff --git a/bolt/vector/DecodedVector.h b/bolt/vector/DecodedVector.h
index fa0f7ecf3..8123ade80 100644
--- a/bolt/vector/DecodedVector.h
+++ b/bolt/vector/DecodedVector.h
@@ -243,6 +243,60 @@ class DecodedVector {
     return isConstantMapping_;
   }
 
+  // ---------------------------------------------------------------------------
+  // HashAgg batch update layout (Spark sum int64 SubOp and similar batch paths).
+  //
+  // Layout discriminators for HashAgg batch update kernels (nulls / indices).
+  // ISA-specific paths (e.g. AArch64 SVE) consume these; the API is not SVE-only.
+  // ---------------------------------------------------------------------------
+
+  /// Classifies how the combined null bitmask (`nulls_`) is laid out, so a
+  /// batch kernel knows how to address it:
+  ///   0 - `nulls_` is not set: there is no combined null bitmask.
+  ///   1 - identity mapping or extra nulls: nulls apply to top-level rows.
+  ///   2 - constant mapping.
+  ///   3 - anything else: per-index nulls, addressed through the indices.
+  /// The checks run in this order, so a constant mapping without `nulls_`
+  /// reports 0 rather than 2.
+  int32_t hashAggNullsLayoutMode() const {
+    if (!nulls_) {
+      return 0;
+    }
+    if (isIdentityMapping_ || hasExtraNulls_) {
+      return 1;
+    }
+    if (isConstantMapping_) {
+      return 2;
+    }
+    return 3;
+  }
+
+  /// Classifies the row-to-base index mapping:
+  ///   1 - identity; 2 - constant; 3 - general (dictionary) indices.
+  /// In case 3 the indices buffer is expected to be set already (checked in
+  /// debug builds only).
+  int32_t hashAggIndicesLayoutMode() const {
+    if (isIdentityMapping_) {
+      return 1;
+    }
+    if (isConstantMapping_) {
+      return 2;
+    }
+    BOLT_DCHECK(indices_);
+    return 3;
+  }
+
+  /// Mutable combined null bits (may be nullptr). See `hashAggNullsLayoutMode()`.
+  /// This is `nulls_` with constness cast away.
+  /// NOTE(review): the bits presumably alias a buffer owned elsewhere; confirm
+  /// callers only read through this pointer or otherwise own the buffer.
+  uint64_t* hashAggMutableCombinedNullBits() {
+    return const_cast<uint64_t*>(nulls_);
+  }
+
+  /// Base scalar data buffer; nullptr for complex types. This is `data_` with
+  /// constness cast away.
+  void* hashAggMutableRawData() {
+    return const_cast<void*>(data_);
+  }
+
+  /// Dictionary / general indices buffer; only meaningful when
+  /// `hashAggIndicesLayoutMode() == 3` (call forces `fillInIndices()` when lazy).
+  /// Never returns nullptr in debug builds: the result is DCHECKed.
+  vector_size_t* hashAggMutableIndices() {
+    if (!indices_) {
+      fillInIndices();
+    }
+    BOLT_DCHECK(indices_);
+    return const_cast<vector_size_t*>(indices_);
+  }
+
   /////////////////////////////////////////////////////////////////
   /// BEGIN: Members that must only be used by PeeledEncoding class.
   /// See class comment for more details.
diff --git a/bolt/vector/tests/DecodedVectorTest.cpp b/bolt/vector/tests/DecodedVectorTest.cpp
index 10537900d..af65d3673 100644
--- a/bolt/vector/tests/DecodedVectorTest.cpp
+++ b/bolt/vector/tests/DecodedVectorTest.cpp
@@ -1453,4 +1453,44 @@ TEST_F(DecodedVectorTest, previousIndicesInReUsedDecodedVector) {
   EXPECT_EQ(rawIndices[0], 0);
 }
 
+TEST_F(DecodedVectorTest, hashAggLayoutModes) {
+  // Layout mode discriminators for HashAgg batch paths.
+
+  // Flat, no nulls: no null bitmask, identity indices.
+  auto flat = makeFlatVector<int64_t>(10, [](auto row) { return row * 3; });
+  DecodedVector decodedFlat(*flat);
+  EXPECT_EQ(0, decodedFlat.hashAggNullsLayoutMode());
+  EXPECT_EQ(1, decodedFlat.hashAggIndicesLayoutMode());
+
+  // Non-null constant: still no null bitmask, constant indices.
+  auto constVec = makeConstant<int64_t>(42, 7);
+  DecodedVector decodedConst(*constVec);
+  EXPECT_EQ(0, decodedConst.hashAggNullsLayoutMode());
+  EXPECT_EQ(2, decodedConst.hashAggIndicesLayoutMode());
+
+  // Null constant: constant layout for both nulls and indices.
+  auto nullConst = makeConstant<int64_t>(std::nullopt, 5);
+  DecodedVector decodedNullConst(*nullConst);
+  EXPECT_EQ(2, decodedNullConst.hashAggNullsLayoutMode());
+  EXPECT_EQ(2, decodedNullConst.hashAggIndicesLayoutMode());
+
+  // Dictionary over a null-free base: general indices, no null bitmask.
+  auto indices = makeIndices(5, [](auto row) { return row; });
+  auto dict = BaseVector::wrapInDictionary(nullptr, indices, 5, flat);
+  DecodedVector decodedDict(*dict);
+  EXPECT_EQ(0, decodedDict.hashAggNullsLayoutMode());
+  EXPECT_EQ(3, decodedDict.hashAggIndicesLayoutMode());
+  ASSERT_NE(nullptr, decodedDict.hashAggMutableIndices());
+  ASSERT_NE(nullptr, decodedDict.hashAggMutableRawData());
+
+  // Flat with nulls: null bitmask on top-level rows, identity indices.
+  auto nullableFlat =
+      makeNullableFlatVector<int64_t>({1, std::nullopt, 3, 4, 5});
+  DecodedVector decodedNullable(*nullableFlat);
+  EXPECT_EQ(1, decodedNullable.hashAggNullsLayoutMode());
+  EXPECT_EQ(1, decodedNullable.hashAggIndicesLayoutMode());
+  ASSERT_NE(nullptr, decodedNullable.hashAggMutableCombinedNullBits());
+
+  // Dictionary (without nulls of its own) over a nullable base: the only
+  // combination that reaches the per-index null layout.
+  auto dictOverNullable =
+      BaseVector::wrapInDictionary(nullptr, indices, 5, nullableFlat);
+  DecodedVector decodedDictNullable(*dictOverNullable);
+  EXPECT_EQ(3, decodedDictNullable.hashAggNullsLayoutMode());
+  EXPECT_EQ(3, decodedDictNullable.hashAggIndicesLayoutMode());
+  ASSERT_NE(nullptr, decodedDictNullable.hashAggMutableCombinedNullBits());
+  ASSERT_NE(nullptr, decodedDictNullable.hashAggMutableIndices());
+
+  // NOTE(review): the checks below exercise SelectivityVector only and do not
+  // touch the layout-mode API; consider moving them to its own test.
+  SelectivityVector rows(8, false);
+  rows.setValid(2, true);
+  rows.setValid(5, true);
+  rows.updateBounds();
+  EXPECT_NE(nullptr, rows.allBits());
+  EXPECT_LE(rows.begin(), rows.end());
+}
+
 } // namespace bytedance::bolt::test