Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion bolt/functions/sparksql/aggregates/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ set(PAAGG_TYPES
)
specialize_template(PAAGG_GENERATED_FILES PercentileApproxAggregate.cpp.in ${PAAGG_TYPES})

# Optional extra source holding the SVE kernel for the Spark BIGINT sum SubOp.
# It is only part of the library on Linux AArCH64-family processors; on every
# other platform the variable is empty and contributes no source file.
if(CMAKE_SYSTEM_NAME STREQUAL "Linux"
   AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64.*|AARCH64.*|arm64.*)")
  set(BOLT_SPARK_SUM_INT64_SUBOP_SVE SumAggregateSparkInt64SubOpSve.cpp)
else()
  set(BOLT_SPARK_SUM_INT64_SUBOP_SVE "")
endif()

bolt_add_library(
bolt_functions_spark_aggregates
${PAAGG_GENERATED_FILES}
Expand All @@ -52,7 +58,8 @@ bolt_add_library(
PercentileApproxAggregate.cpp
Register.cpp
RegrReplacementAggregate.cpp
SumAggregate.cpp
SumAggregateSparkInt64SubOp.cpp
${BOLT_SPARK_SUM_INT64_SUBOP_SVE}
SumAggregate.cpp
)

Expand All @@ -61,6 +68,15 @@ target_link_libraries(
fmt::fmt
)

# SumAggregateSparkInt64SubOpSve.cpp includes <arm_sve.h>. Bolt's default
# -march=armv8.3-a does not enable the SVE ISA, so on Linux AArch64 this single
# translation unit (and only this one) is compiled with +sve.
if(CMAKE_SYSTEM_NAME STREQUAL "Linux"
   AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64.*|AARCH64.*|arm64.*)")
  set_property(
    SOURCE SumAggregateSparkInt64SubOpSve.cpp
    PROPERTY COMPILE_OPTIONS "-march=armv8-a+sve")
endif()

# Descend into the tests subdirectory only when BOLT_BUILD_TESTING evaluates
# to a true value.
if(${BOLT_BUILD_TESTING})
add_subdirectory(tests)
endif()
36 changes: 36 additions & 0 deletions bolt/functions/sparksql/aggregates/SumAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,46 @@

#include "bolt/functions/lib/aggregates/SumAggregateBase.h"
#include "bolt/functions/sparksql/aggregates/DecimalSumAggregate.h"
#include "bolt/functions/sparksql/aggregates/SumAggregateSparkInt64SubOp.h"

#include <cctype>
#include <cstdlib>

using namespace bytedance::bolt::functions::aggregate;
namespace bytedance::bolt::functions::aggregate::sparksql {

namespace {
// File-local shorthand for the shared SumAggregateBase implementation, keeping
// the same three type parameters in the same order.
template <typename InputT, typename AccumulatorT, typename OutputT>
using SumAggregate = SumAggregateBase<InputT, AccumulatorT, OutputT>;

// Rollback switch for the Spark BIGINT (non-decimal) sum SubOp, read from the
// process environment variable BOLT_SPARK_SUM_INT64_USE_SUBOP.
//
// Returns false (SubOp disabled) only when the value is exactly "0", or equals
// "false", "no" or "off" ignoring ASCII case. An unset variable, an empty
// value, or any other text -> true (SubOp enabled).
bool sparkSumInt64UseSubOpFromEnv() {
  const char* const raw = std::getenv("BOLT_SPARK_SUM_INT64_USE_SUBOP");
  if (raw == nullptr || raw[0] == '\0') {
    return true;
  }

  // Case-insensitive equality of two NUL-terminated strings. The cast to
  // unsigned char keeps std::tolower well-defined for bytes >= 0x80.
  const auto sameIgnoringCase = [](const char* lhs, const char* rhs) {
    for (; *lhs != '\0' && *rhs != '\0'; ++lhs, ++rhs) {
      const int l = std::tolower(static_cast<unsigned char>(*lhs));
      const int r = std::tolower(static_cast<unsigned char>(*rhs));
      if (l != r) {
        return false;
      }
    }
    // Equal only if both strings ended at the same position.
    return *lhs == *rhs;
  };

  // "0" contains no letters, so the case-insensitive compare is an exact one.
  const char* const disabledSpellings[] = {"0", "false", "no", "off"};
  for (const char* spelling : disabledSpellings) {
    if (sameIgnoringCase(raw, spelling)) {
      return false;
    }
  }
  return true;
}

TypePtr getDecimalSumType(
const TypePtr& resultType,
core::AggregationNode::Step step) {
Expand Down Expand Up @@ -112,6 +145,9 @@ exec::AggregateRegistrationResult registerSum(
resultType, sumType);
}
}
if (sparkSumInt64UseSubOpFromEnv()) {
return std::make_unique<SumAggregateSparkInt64SubOp>(BIGINT());
}
return std::make_unique<SumAggregate<int64_t, int64_t, int64_t>>(
BIGINT());
}
Expand Down
187 changes: 187 additions & 0 deletions bolt/functions/sparksql/aggregates/SumAggregateSparkInt64SubOp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* --------------------------------------------------------------------------
* Copyright (c) ByteDance Ltd. and/or its affiliates.
* SPDX-License-Identifier: Apache-2.0
*
* This file has been modified by ByteDance Ltd. and/or its affiliates on
* 2025-11-11.
*
* Original file was released under the Apache License 2.0,
* with the full license text available at:
* http://www.apache.org/licenses/LICENSE-2.0
*
* This modified file is released under the same license.
* --------------------------------------------------------------------------
*/

#include "bolt/functions/sparksql/aggregates/SumAggregateSparkInt64SubOp.h"

#include "bolt/exec/AggregationHook.h"
#include "bolt/functions/lib/CheckedArithmeticImpl.h"
#include "bolt/vector/BaseVector.h"
#include "bolt/vector/DecodedVector.h"
#include "bolt/vector/LazyVector.h"
#include "bolt/vector/VectorEncoding.h"

#if defined(__aarch64__) && defined(__linux__)
#include <sys/auxv.h>
#endif

namespace bytedance::bolt::functions::aggregate::sparksql {

namespace {

// Folds one input `value` into the running sum `result`.
// When the `Overflow` flag is set the value is added with a plain `+=`;
// otherwise the addition goes through CHECK_ADD (presumably the
// overflow-checked add from CheckedArithmeticImpl.h -- confirm there).
static void sparkSumInt64UpdateSingle(int64_t& result, int64_t value) {
  if (!::bytedance::bolt::functions::aggregate::Overflow) {
    CHECK_ADD(result, value);
    return;
  }
  result += value;
}

#if defined(__aarch64__) && defined(__linux__)
// Bit in getauxval(AT_HWCAP) that advertises SVE. SVE is reported on AT_HWCAP
// (HWCAP_SVE), not AT_HWCAP2 -- see Linux arch/arm64/include/uapi/asm/hwcap.h.
// The literal fallback covers toolchain headers that do not define HWCAP_SVE.
#if defined(HWCAP_SVE)
constexpr unsigned long kBoltHwcapSve = HWCAP_SVE;
#else
constexpr unsigned long kBoltHwcapSve = 1UL << 22;
#endif
#endif

// Reports whether the running CPU/kernel advertises SVE.
// On Linux AArch64 this queries the auxiliary vector; on every other target it
// is a constant `false`.
//
// [[maybe_unused]]: the only callers in this file sit behind
// `#if defined(__aarch64__)`, so on other architectures the function has no
// callers and would otherwise trigger -Wunused-function.
[[maybe_unused]] static bool linuxAarch64RuntimeHasSve() {
#if defined(__aarch64__) && defined(__linux__)
  return (getauxval(AT_HWCAP) & kBoltHwcapSve) != 0;
#else
  return false;
#endif
}

} // namespace

// Constructs the aggregate by forwarding `resultType` (moved, not copied) to
// the `Base` constructor; no additional state is initialized here.
SumAggregateSparkInt64SubOp::SumAggregateSparkInt64SubOp(TypePtr resultType)
: Base(std::move(resultType)) {}

void SumAggregateSparkInt64SubOp::addRawInput(
char** groups,
const SelectivityVector& rows,
const std::vector<VectorPtr>& args,
bool mayPushdown) {
const auto& arg = args[0];

if (mayPushdown && arg->isLazy()) {
Base::addRawInput(groups, rows, args, mayPushdown);
return;
}

using ::bytedance::bolt::functions::aggregate::Overflow;
if (this->numNulls_ && Overflow) {
DecodedVector decoded(*arg, rows, !mayPushdown);
const auto encoding = decoded.base()->encoding();
if (mayPushdown && encoding == VectorEncoding::Simple::LAZY &&
!arg->type()->isDecimal()) {
bytedance::bolt::aggregate::SimpleCallableHook<
int64_t,
int64_t,
void (*)(int64_t&, int64_t)>
hook(
offset_,
nullByte_,
nullMask_,
groups,
&numNulls_,
sparkSumInt64UpdateSingle);
auto indices = decoded.indices();
decoded.base()->as<const LazyVector>()->load(
::bytedance::bolt::RowSet(indices, arg->size()), &hook);
return;
}
if (decoded.mayHaveNulls()) {
#if defined(__aarch64__)
if (linuxAarch64RuntimeHasSve() &&
updateGroupsFromDecoded(groups, rows, decoded)) {
return;
}
#endif
}
Base::addRawInput(groups, rows, args, mayPushdown);
return;
}

Base::addRawInput(groups, rows, args, mayPushdown);
}

void SumAggregateSparkInt64SubOp::addIntermediateResults(
char** groups,
const SelectivityVector& rows,
const std::vector<VectorPtr>& args,
bool mayPushdown) {
const auto& arg = args[0];

if (mayPushdown && arg->isLazy()) {
Base::addIntermediateResults(groups, rows, args, mayPushdown);
return;
}

using ::bytedance::bolt::functions::aggregate::Overflow;
if (this->numNulls_ && Overflow) {
DecodedVector decoded(*arg, rows, !mayPushdown);
const auto encoding = decoded.base()->encoding();
if (mayPushdown && encoding == VectorEncoding::Simple::LAZY &&
!arg->type()->isDecimal()) {
bytedance::bolt::aggregate::SimpleCallableHook<
int64_t,
int64_t,
void (*)(int64_t&, int64_t)>
hook(
offset_,
nullByte_,
nullMask_,
groups,
&numNulls_,
sparkSumInt64UpdateSingle);
auto indices = decoded.indices();
decoded.base()->as<const LazyVector>()->load(
::bytedance::bolt::RowSet(indices, arg->size()), &hook);
return;
}
if (decoded.mayHaveNulls()) {
#if defined(__aarch64__)
if (linuxAarch64RuntimeHasSve() &&
updateGroupsFromDecoded(groups, rows, decoded)) {
return;
}
#endif
}
Base::addIntermediateResults(groups, rows, args, mayPushdown);
return;
}

Base::addIntermediateResults(groups, rows, args, mayPushdown);
}

// Fallback definition of updateGroupsFromDecoded for every build that does not
// compile SumAggregateSparkInt64SubOpSve.cpp.
//
// CMakeLists.txt adds the SVE translation unit only when the target is Linux
// on an AArch64-family processor, so the stub must be present on all other
// targets -- including non-Linux AArch64 such as macOS arm64, where
// `__aarch64__` is defined but the SVE source is not built. Guarding on
// `__aarch64__` alone would leave the member undefined there and break the
// link.
//
// Always returns false, telling the caller that the batch was not consumed and
// must be handled by the base class.
#if !(defined(__aarch64__) && defined(__linux__))
bool SumAggregateSparkInt64SubOp::updateGroupsFromDecoded(
    char** /*groups*/,
    const SelectivityVector& /*rows*/,
    ::bytedance::bolt::DecodedVector& /*decoded*/) {
  return false;
}
#endif

} // namespace bytedance::bolt::functions::aggregate::sparksql
80 changes: 80 additions & 0 deletions bolt/functions/sparksql/aggregates/SumAggregateSparkInt64SubOp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* --------------------------------------------------------------------------
* Copyright (c) ByteDance Ltd. and/or its affiliates.
* SPDX-License-Identifier: Apache-2.0
*
* This file has been modified by ByteDance Ltd. and/or its affiliates on
* 2025-11-11.
*
* Original file was released under the Apache License 2.0,
* with the full license text available at:
* http://www.apache.org/licenses/LICENSE-2.0
*
* This modified file is released under the same license.
* --------------------------------------------------------------------------
*/

#pragma once

#include "bolt/functions/lib/aggregates/SumAggregateBase.h"

namespace bytedance::bolt {
class DecodedVector;
}

namespace bytedance::bolt::functions::aggregate::sparksql {

/// Spark `sum(bigint) -> bigint` aggregate "SubOp" variant, layered on
/// `SumAggregateBase<int64_t, int64_t, int64_t>`.
///
/// `SumAggregate.cpp` selects this class by default and falls back to the
/// plain `SumAggregateBase` instantiation when the environment variable
/// `BOLT_SPARK_SUM_INT64_USE_SUBOP` holds an "off" value.
///
/// Both update entry points forward to the base class unless `numNulls_` is
/// non-zero and the `Overflow` gate is set. Inside that gate, a decoded input
/// for which `decoded.mayHaveNulls()` holds is offered to the AArch64 SVE
/// batch kernel when the runtime reports SVE support; any batch the kernel
/// does not take is handled by `SumAggregateBase`.
class SumAggregateSparkInt64SubOp
: public ::bytedance::bolt::functions::aggregate::SumAggregateBase<
int64_t,
int64_t,
int64_t> {
// Shorthand for the base class, used for delegation in the .cpp file.
using Base = ::bytedance::bolt::functions::aggregate::SumAggregateBase<
int64_t,
int64_t,
int64_t>;

public:
/// @param resultType Result type, forwarded unchanged to `Base`.
explicit SumAggregateSparkInt64SubOp(TypePtr resultType);

/// Updates the accumulators in `groups` from the raw input `args[0]` for the
/// selected `rows`. Only `args[0]` is read.
/// @param mayPushdown When true and the input is (or wraps) a lazy vector,
/// the update may be pushed into the lazy load instead of decoding first.
void addRawInput(
char** groups,
const SelectivityVector& rows,
const std::vector<VectorPtr>& args,
bool mayPushdown) override;

/// Same contract as `addRawInput`, but delegates to
/// `Base::addIntermediateResults` for the batches it does not handle itself.
void addIntermediateResults(
char** groups,
const SelectivityVector& rows,
const std::vector<VectorPtr>& args,
bool mayPushdown) override;

private:
/// Applies `sveHashAggBatchUpdateGroupSums` for one decoded batch (the kernel
/// lives outside this header -- see the SVE translation unit).
/// Returns true when the SVE path was taken (caller must not invoke Base for
/// this batch). The stub definition in SumAggregateSparkInt64SubOp.cpp, used
/// by builds without the SVE kernel, always returns false.
bool updateGroupsFromDecoded(
char** groups,
const SelectivityVector& rows,
::bytedance::bolt::DecodedVector& decoded);
};

} // namespace bytedance::bolt::functions::aggregate::sparksql
Loading