diff --git a/CMakeLists.txt b/CMakeLists.txt index 0c72259f3..e097acf03 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -458,6 +458,16 @@ if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsigned-char") endif() +# Keep frame pointers and avoid sibling-call optimization so perf can fully +# unwind stacks with the cheap frame-pointer call-graph (no DWARF needed). +# Off by default; enable with -DBOLT_ENABLE_FRAME_POINTER=ON when profiling. +option(BOLT_ENABLE_FRAME_POINTER + "Preserve frame pointers for perf stack unwinding" OFF) +if(BOLT_ENABLE_FRAME_POINTER) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer -mno-omit-leaf-frame-pointer -fno-optimize-sibling-calls") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fno-omit-frame-pointer -mno-omit-leaf-frame-pointer -fno-optimize-sibling-calls") +endif() + # Under Ninja, we are able to designate certain targets large enough to require restricted # parallelism. if("${MAX_HIGH_MEM_JOBS}") diff --git a/Makefile b/Makefile index 4b2d1a8d5..3817355d2 100644 --- a/Makefile +++ b/Makefile @@ -83,6 +83,11 @@ BUILD_TYPE=Release BOLT_BUILD_BENCHMARKS ?= "OFF" # Control whether to build tests with coverage instrumentation BOLT_BUILD_TESTING_WITH_COVERAGE ?= "OFF" +# Control whether to keep frame pointers (for perf stack unwinding) +BOLT_ENABLE_FRAME_POINTER ?= "OFF" +# Control whether to report JIT symbols to Intel VTune (jitprofiling) +BOLT_ENABLE_VTUNE_JIT ?= "OFF" +VTUNE_SDK_DIR ?= # ----------------------------------------------------------------- # TODO: remove `BUILD_USER` and `BUILD_CHANNEL` @@ -204,6 +209,9 @@ conan_build: conan_install NUM_THREADS=$(NUM_THREADS) \ BOLT_BUILD_BENCHMARKS=${BOLT_BUILD_BENCHMARKS} \ BOLT_BUILD_TESTING_WITH_COVERAGE=${BOLT_BUILD_TESTING_WITH_COVERAGE} \ + BOLT_ENABLE_FRAME_POINTER=${BOLT_ENABLE_FRAME_POINTER} \ + BOLT_ENABLE_VTUNE_JIT=${BOLT_ENABLE_VTUNE_JIT} \ + VTUNE_SDK_DIR=${VTUNE_SDK_DIR} \ conan build ../.. --name=bolt --version=${BUILD_VERSION} --user=${BUILD_USER} --channel=${BUILD_CHANNEL} \ -s llvm-core/*:build_type=Release \ -s "&:build_type=${BUILD_TYPE}" \ @@ -306,6 +314,17 @@ benchmarks-build-spark: benchmarks-build-relwithdebinfo: $(MAKE) conan_build BUILD_TYPE=RelWithDebInfo BOLT_BUILD_BENCHMARKS="ON" CONAN_CONFIG=" -c bolt/*:tools.build:skip_test=False" CONAN_OPTIONS="-o bolt/*:spark_compatible=False -o bolt/*:enable_testutil=True -o bolt/*:enable_perf=True" +# Same as benchmarks-build-spark but keeps frame pointers so perf can unwind +# stacks with the cheap frame-pointer call-graph (no DWARF needed). +benchmarks-build-spark-profile: + $(MAKE) conan_build BUILD_TYPE=Release BOLT_BUILD_BENCHMARKS="ON" BOLT_ENABLE_FRAME_POINTER="ON" CONAN_CONFIG=" -c bolt/*:tools.build:skip_test=False" CONAN_OPTIONS="-o bolt/*:spark_compatible=True -o bolt/*:enable_testutil=True -o bolt/*:enable_perf=True" + +# Same as benchmarks-build-spark-profile but also reports JIT symbols to Intel +# VTune (libjitprofiling). Override the SDK path with VTUNE_SDK_DIR=... if it is +# not under the default /opt/intel/oneapi/vtune/2023.2.0/sdk. +benchmarks-build-spark-vtune: + $(MAKE) conan_build BUILD_TYPE=Release BOLT_BUILD_BENCHMARKS="ON" BOLT_ENABLE_FRAME_POINTER="ON" BOLT_ENABLE_VTUNE_JIT="ON" VTUNE_SDK_DIR="${VTUNE_SDK_DIR}" CONAN_CONFIG=" -c bolt/*:tools.build:skip_test=False" CONAN_OPTIONS="-o bolt/*:spark_compatible=True -o bolt/*:enable_testutil=True -o bolt/*:enable_perf=True" + unittest_debug: unittest unittest: debug_with_test ctest --test-dir $(BUILD_BASE_DIR)/Debug --timeout 7200 -j $(NUM_THREADS) --output-on-failure diff --git a/bolt/common/base/AggregationStats.h b/bolt/common/base/AggregationStats.h index dc07b70b0..31c1a8fc9 100644 --- a/bolt/common/base/AggregationStats.h +++ b/bolt/common/base/AggregationStats.h @@ -28,5 +28,12 @@ struct AggregationStats { uint64_t aggOutputTimeNs{0}; uint64_t aggProbeBypassTimeNs{0}; uint64_t aggProbeBypassCount{0}; + // Hash aggregation JIT fine-grained timing. + // One-time codegen (LLVM compile) time for the JIT plan. + uint64_t aggJitCodegenTimeNs{0}; + // JIT-executed part of the agg function update time. + uint64_t aggFunctionJitTimeNs{0}; + // JIT-executed part of the extracting groups time. + uint64_t aggExtractGroupsJitTimeNs{0}; }; } // namespace bytedance::bolt::common diff --git a/bolt/core/QueryConfig.h b/bolt/core/QueryConfig.h index db5566da4..08b646c1b 100644 --- a/bolt/core/QueryConfig.h +++ b/bolt/core/QueryConfig.h @@ -652,6 +652,12 @@ class QueryConfig { */ static constexpr const char* kJitLevel = "jit.level"; + static constexpr const char* kHashAggrJitEnabled = "hashaggr.jit.enabled"; + static constexpr const char* kHashAggrJitMinFuseWidth = + "hashaggr.jit.min_fuse_width"; + static constexpr const char* kHashAggrJitMaxFuseWidth = + "hashaggr.jit.max_fuse_width"; + // expired, to deleted later static constexpr const char* kBoltJitEnabled = "bolt.jit.enabled"; // For morsel-driven Bolt @@ -1606,6 +1612,18 @@ class QueryConfig { return flag & 1; } + bool enableHashAggrJit() const { + return get(kHashAggrJitEnabled, true); + } + + int32_t hashAggrJitMinFuseWidth() const { + return get(kHashAggrJitMinFuseWidth, 1); + } + + int32_t hashAggrJitMaxFuseWidth() const { + return get(kHashAggrJitMaxFuseWidth, 16); + } + int exceptionTraceLevel() const { return get(kExceptionTraceLevel, 1); } diff --git a/bolt/exec/Aggregate.cpp b/bolt/exec/Aggregate.cpp index c8f586b9a..824f15077 100644 --- a/bolt/exec/Aggregate.cpp +++ b/bolt/exec/Aggregate.cpp @@ -321,4 +321,27 @@ void Aggregate::clearInternal() { numNulls_ = 0; } +#ifdef ENABLE_BOLT_JIT +bool Aggregate::supportsHashAggrJit( + const jit::HashAggrJitPlanContext& /*context*/) const { + return false; +} + +std::optional Aggregate::createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& /*context*/) const { + return std::nullopt; +} + +jit::HashAggrJitSlot Aggregate::createHashAggrJitSlot( + int32_t aggregateIndex, + const jit::HashAggrJitDescriptor& descriptor) const { + return jit::HashAggrJitSlot{ + .aggregateIndex = aggregateIndex, + .offset = accumulatorOffset(), + .nullByte = accumulatorNullByte(), + .nullMask = accumulatorNullMask(), + .desc = descriptor}; +} +#endif + } // namespace bytedance::bolt::exec diff --git a/bolt/exec/Aggregate.h b/bolt/exec/Aggregate.h index 433abd8e5..d6ec0c88f 100644 --- a/bolt/exec/Aggregate.h +++ b/bolt/exec/Aggregate.h @@ -33,12 +33,17 @@ #include #include +#include + #include "bolt/common/memory/HashStringAllocator.h" #include "bolt/core/PlanNode.h" #include "bolt/core/QueryConfig.h" #include "bolt/exec/AggregateUtil.h" #include "bolt/expression/FunctionSignature.h" #include "bolt/functions/InlineFlatten.h" +#ifdef ENABLE_BOLT_JIT +#include "bolt/jit/aggregation/HashAggrJitTypes.h" +#endif #include "bolt/vector/BaseVector.h" namespace bytedance::bolt::core { class ExpressionEvaluator; @@ -66,6 +71,18 @@ class Aggregate { return resultType_; } + int32_t accumulatorOffset() const { + return offset_; + } + + int32_t accumulatorNullByte() const { + return nullByte_; + } + + uint8_t accumulatorNullMask() const { + return nullMask_; + } + // Returns the fixed number of bytes the accumulator takes on a group // row. Variable width accumulators will reference the variable // width part of the state from the fixed part. @@ -100,6 +117,18 @@ class Aggregate { return false; } +#ifdef ENABLE_BOLT_JIT + virtual bool supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const; + + virtual std::optional createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const; + + jit::HashAggrJitSlot createHashAggrJitSlot( + int32_t aggregateIndex, + const jit::HashAggrJitDescriptor& descriptor) const; +#endif + void setAllocator(HashStringAllocator* allocator) { setAllocatorInternal(allocator); pool_ = allocator->pool(); @@ -133,6 +162,10 @@ class Aggregate { setOffsetsInternal(offset, nullByte, nullMask, rowSizeOffset); } + void markNullCountUnknown() { + numNulls_ = std::nullopt; + } + // Initializes null flags and accumulators for newly encountered groups. This // function should be called only once for each group. // @@ -360,7 +393,15 @@ class Aggregate { } bool isNull(char* group) const { - return numNulls_ && (group[nullByte_] & nullMask_); + return mayHaveNulls() && (group[nullByte_] & nullMask_); + } + + bool hasNoNulls() const { + return numNulls_.has_value() && *numNulls_ == 0; + } + + bool mayHaveNulls() const { + return !numNulls_.has_value() || *numNulls_ > 0; } // Sets null flag for all specified groups to true. @@ -369,7 +410,9 @@ class Aggregate { for (auto i : indices) { groups[i][nullByte_] |= nullMask_; } - numNulls_ += indices.size(); + if (numNulls_.has_value()) { + *numNulls_ += indices.size(); + } } inline bool setNull(char* group) { @@ -377,18 +420,23 @@ class Aggregate { return false; } group[nullByte_] |= nullMask_; - ++numNulls_; + if (numNulls_.has_value()) { + ++*numNulls_; + } return true; } inline bool clearNull(char* group) { - if (numNulls_) { - uint8_t mask = group[nullByte_]; - if (mask & nullMask_) { - group[nullByte_] = mask & ~nullMask_; - --numNulls_; - return true; + if (!mayHaveNulls()) { + return false; + } + uint8_t mask = group[nullByte_]; + if (mask & nullMask_) { + group[nullByte_] = mask & ~nullMask_; + if (numNulls_.has_value()) { + --*numNulls_; } + return true; } return false; } @@ -449,9 +497,11 @@ class Aggregate { int32_t rowSizeOffset_ = 0; // Number of null accumulators in the current state of the aggregation - // operator for this aggregate. If 0, clearing the null as part of update - // is not needed. - uint64_t numNulls_ = 0; + // operator for this aggregate. + // - 0 => known that no group is null + // - N > 0 => known exact null count + // - nullopt => unknown; must rely on per-group null bit + std::optional numNulls_{0}; HashStringAllocator* allocator_{nullptr}; memory::MemoryPool* pool_{nullptr}; std::shared_ptr expressionEvaluator_{nullptr}; diff --git a/bolt/exec/AggregateCompanionAdapter.cpp b/bolt/exec/AggregateCompanionAdapter.cpp index ab082babf..5d185e127 100644 --- a/bolt/exec/AggregateCompanionAdapter.cpp +++ b/bolt/exec/AggregateCompanionAdapter.cpp @@ -35,6 +35,7 @@ #include "bolt/exec/RowContainer.h" #include "bolt/expression/SignatureBinder.h" #include "bolt/functions/lib/aggregates/AggregateToIntermediate.h" + namespace bytedance::bolt::exec { void AggregateCompanionFunctionBase::setOffsetsInternal( @@ -42,6 +43,7 @@ void AggregateCompanionFunctionBase::setOffsetsInternal( int32_t nullByte, uint8_t nullMask, int32_t rowSizeOffset) { + Aggregate::setOffsetsInternal(offset, nullByte, nullMask, rowSizeOffset); fn_->setOffsets(offset, nullByte, nullMask, rowSizeOffset); } @@ -65,6 +67,19 @@ bool AggregateCompanionFunctionBase::supportsToIntermediate() const { return fn_->supportsToIntermediate(); } +#ifdef ENABLE_BOLT_JIT +bool AggregateCompanionFunctionBase::supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const { + return fn_->supportsHashAggrJit(rewriteHashAggrJitContext(context)); +} + +std::optional +AggregateCompanionFunctionBase::createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const { + return fn_->createHashAggrJitDescriptor(rewriteHashAggrJitContext(context)); +} +#endif + bool AggregateCompanionFunctionBase::supportAccumulatorSerde() const { return fn_->supportAccumulatorSerde(); } diff --git a/bolt/exec/AggregateCompanionAdapter.h b/bolt/exec/AggregateCompanionAdapter.h index 57586c8b7..6337f55b4 100644 --- a/bolt/exec/AggregateCompanionAdapter.h +++ b/bolt/exec/AggregateCompanionAdapter.h @@ -33,6 +33,7 @@ #include "bolt/common/memory/HashStringAllocator.h" #include "bolt/exec/Aggregate.h" #include "bolt/expression/VectorFunction.h" + namespace bytedance::bolt::exec { class AggregateCompanionFunctionBase : public Aggregate { @@ -52,6 +53,23 @@ class AggregateCompanionFunctionBase : public Aggregate { bool supportsToIntermediate() const override final; +#ifdef ENABLE_BOLT_JIT + // Companion functions expose a different aggregation stage than the operator + // step implies (e.g. a merge companion runs at the kSingle step but consumes + // intermediate state and emits intermediate state). Each companion rewrites + // the planning context into the stage it actually represents by flipping the + // isRawInput/isPartialOutput flags; the absolute types in the context are + // stage-agnostic, so the derived input/output views follow automatically. + virtual jit::HashAggrJitPlanContext rewriteHashAggrJitContext( + const jit::HashAggrJitPlanContext& context) const = 0; + + bool supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const override; + + std::optional createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const override; +#endif + bool supportAccumulatorSerde() const override final; uint32_t getAccumulatorSerializeSize(char* group) const override final; @@ -124,6 +142,18 @@ struct AggregateCompanionAdapter { const TypePtr& resultType) : AggregateCompanionFunctionBase{std::move(fn), resultType} {} +#ifdef ENABLE_BOLT_JIT + // Partial companion: reads raw input and emits intermediate state. + jit::HashAggrJitPlanContext rewriteHashAggrJitContext( + const jit::HashAggrJitPlanContext& context) const override { + auto rewritten = context; + rewritten.isRawInput = true; + rewritten.isPartialOutput = true; + rewritten.resultType = fn_->resultType(); + return rewritten; + } +#endif + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) override; }; @@ -139,7 +169,7 @@ struct AggregateCompanionAdapter { void toIntermediate( const SelectivityVector& rows, std::vector& args, - VectorPtr& result) const override final; + VectorPtr& result) const final; void addRawInput( char** groups, @@ -153,6 +183,18 @@ struct AggregateCompanionAdapter { const std::vector& args, bool mayPushdown) override; +#ifdef ENABLE_BOLT_JIT + // Merge companion: reads intermediate state and emits intermediate state. + jit::HashAggrJitPlanContext rewriteHashAggrJitContext( + const jit::HashAggrJitPlanContext& context) const override { + auto rewritten = context; + rewritten.isRawInput = false; + rewritten.isPartialOutput = true; + rewritten.resultType = fn_->resultType(); + return rewritten; + } +#endif + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) override; }; @@ -164,6 +206,19 @@ struct AggregateCompanionAdapter { const TypePtr& resultType) : MergeFunction{std::move(fn), resultType} {} +#ifdef ENABLE_BOLT_JIT + // Merge-extract companion: reads intermediate state and emits the final + // result. + jit::HashAggrJitPlanContext rewriteHashAggrJitContext( + const jit::HashAggrJitPlanContext& context) const override { + auto rewritten = context; + rewritten.isRawInput = false; + rewritten.isPartialOutput = false; + rewritten.resultType = fn_->resultType(); + return rewritten; + } +#endif + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) override; }; diff --git a/bolt/exec/AggregateInfo.cpp b/bolt/exec/AggregateInfo.cpp index 87ee2b655..0561087b6 100644 --- a/bolt/exec/AggregateInfo.cpp +++ b/bolt/exec/AggregateInfo.cpp @@ -71,6 +71,8 @@ std::vector toAggregateInfo( for (auto i = 0; i < numAggregates; i++) { const auto& aggregate = aggregationNode.aggregates()[i]; AggregateInfo info; + info.name = aggregate.call->name(); + info.rawInputTypes = aggregate.rawInputTypes; // Populate input. auto& channels = info.inputs; auto& constants = info.constantInputs; diff --git a/bolt/exec/AggregateInfo.h b/bolt/exec/AggregateInfo.h index d1c4d2ff9..e124a6984 100644 --- a/bolt/exec/AggregateInfo.h +++ b/bolt/exec/AggregateInfo.h @@ -41,6 +41,14 @@ struct AggregateInfo { /// Instance of the Aggregate class. std::unique_ptr function; + /// Function name used to create the aggregate. Kept for optional fast paths + /// such as aggregation JIT planning. + std::string name; + + /// Raw input types from the plan. Kept to identify JIT-able numeric variants + /// without depending on concrete Aggregate subclasses. + std::vector rawInputTypes; + /// Indices of the input columns in the input RowVector. std::vector inputs; diff --git a/bolt/exec/AggregationHook.h b/bolt/exec/AggregationHook.h index b1f25b1d0..390c3fb18 100644 --- a/bolt/exec/AggregationHook.h +++ b/bolt/exec/AggregationHook.h @@ -30,6 +30,8 @@ #pragma once +#include + #include "bolt/common/base/CheckedArithmetic.h" #include "bolt/common/base/Range.h" #include "bolt/vector/LazyVector.h" @@ -60,7 +62,7 @@ class AggregationHook : public ValueHook { int32_t nullByte, uint8_t nullMask, char** groups, - uint64_t* numNulls) + std::optional& numNulls) : offset_(offset), nullByte_(nullByte), nullMask_(nullMask), @@ -91,11 +93,13 @@ class AggregationHook : public ValueHook { } inline bool clearNull(char* group) { - if (*numNulls_) { + if (!numNulls_.has_value() || *numNulls_ > 0) { uint8_t mask = group[nullByte_]; if (mask & nullMask_) { group[nullByte_] = mask & clearNullMask_; - --*numNulls_; + if (numNulls_.has_value()) { + --*numNulls_; + } return true; } } @@ -108,7 +112,7 @@ class AggregationHook : public ValueHook { const uint8_t nullMask_; const uint8_t clearNullMask_; char* const* const groups_; - uint64_t* numNulls_; + std::optional& numNulls_; }; namespace { @@ -132,7 +136,7 @@ class SumHook final : public AggregationHook { int32_t nullByte, uint8_t nullMask, char** groups, - uint64_t* numNulls) + std::optional& numNulls) : AggregationHook(offset, nullByte, nullMask, groups, numNulls) {} Kind kind() const override { @@ -174,7 +178,7 @@ class SimpleCallableHook final : public AggregationHook { int32_t nullByte, uint8_t nullMask, char** groups, - uint64_t* numNulls, + std::optional& numNulls, UpdateSingleValue updateSingleValue) : AggregationHook(offset, nullByte, nullMask, groups, numNulls), updateSingleValue_(updateSingleValue) {} @@ -203,7 +207,7 @@ class MinMaxHook final : public AggregationHook { int32_t nullByte, uint8_t nullMask, char** groups, - uint64_t* numNulls) + std::optional& numNulls) : AggregationHook(offset, nullByte, nullMask, groups, numNulls) {} Kind kind() const override { diff --git a/bolt/exec/CMakeLists.txt b/bolt/exec/CMakeLists.txt index 7f1a6421d..0b3221240 100644 --- a/bolt/exec/CMakeLists.txt +++ b/bolt/exec/CMakeLists.txt @@ -87,6 +87,8 @@ bolt_add_library( ProbeOperatorState.cpp PushBasedEvent.cpp RowContainer.cpp + ../jit/aggregation/runtime/HashAggrRuntime.cpp + ../jit/aggregation/runtime/HashAggrDecimalRuntime.cpp RowNumber.cpp RowsStreamingWindowBuild.cpp RowsStreamingWindowPartition.cpp diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 2c6e5bad4..0a10781e1 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -29,6 +29,8 @@ */ #include "bolt/exec/GroupingSet.h" +#include +#include #include "bolt/common/base/Exceptions.h" #include "bolt/common/base/SpillConfig.h" #include "bolt/common/testutil/TestValue.h" @@ -36,8 +38,13 @@ #include "bolt/exec/ContainerRow2RowSerde.h" #include "bolt/exec/OperatorUtils.h" #include "bolt/exec/RowToColumnVector.h" +#ifdef ENABLE_BOLT_JIT +#include +#include "bolt/jit/aggregation/HashAggrJit.h" +#endif #include "bolt/type/Type.h" #include "bolt/vector/ComplexVector.h" +#include "bolt/vector/FlatVector.h" using bytedance::bolt::common::testutil::TestValue; namespace bytedance::bolt::exec { @@ -59,6 +66,217 @@ bool areAllLazyNotLoaded(const std::vector& vectors) { }); } +#ifdef ENABLE_BOLT_JIT +std::string hashAggrJitTypeName(const TypePtr& type) { + return type == nullptr ? "null" : type->toString(); +} + +std::optional hashAggrJitOutputValueKind( + const BaseVector* vector) { + const auto& type = vector->type(); + if (type->isShortDecimal()) { + return jit::HashAggrJitValueKind::Int64; + } + if (type->isLongDecimal()) { + return jit::HashAggrJitValueKind::Int128; + } + return jit::hashAggrJitValueKind(type->kind()); +} + +void* hashAggrJitRawOutputData( + BaseVector* vector, + jit::HashAggrJitValueKind kind) { + switch (kind) { + case jit::HashAggrJitValueKind::Bool: + return const_cast(vector->valuesAsVoid()); + case jit::HashAggrJitValueKind::Int8: + return vector->asUnchecked>()->mutableRawValues(); + case jit::HashAggrJitValueKind::Int16: + return vector->asUnchecked>()->mutableRawValues(); + case jit::HashAggrJitValueKind::Int32: + return vector->asUnchecked>()->mutableRawValues(); + case jit::HashAggrJitValueKind::Int64: + return vector->asUnchecked>()->mutableRawValues(); + case jit::HashAggrJitValueKind::Float: + return vector->asUnchecked>()->mutableRawValues(); + case jit::HashAggrJitValueKind::Double: + return vector->asUnchecked>()->mutableRawValues(); + case jit::HashAggrJitValueKind::Int128: + return vector->asUnchecked>()->mutableRawValues(); + } + return nullptr; +} + +std::optional hashAggrJitInputValueKind( + const BaseVector* vector) { + const auto& type = vector->type(); + if (type->isShortDecimal()) { + return jit::HashAggrJitValueKind::Int64; + } + if (type->isLongDecimal()) { + return jit::HashAggrJitValueKind::Int128; + } + return jit::hashAggrJitValueKind(type->kind()); +} + +const void* hashAggrJitRawInputData( + const BaseVector* vector, + jit::HashAggrJitValueKind kind) { + switch (kind) { + case jit::HashAggrJitValueKind::Bool: + return vector->valuesAsVoid(); + case jit::HashAggrJitValueKind::Int8: + return vector->asUnchecked>()->rawValues(); + case jit::HashAggrJitValueKind::Int16: + return vector->asUnchecked>()->rawValues(); + case jit::HashAggrJitValueKind::Int32: + return vector->asUnchecked>()->rawValues(); + case jit::HashAggrJitValueKind::Int64: + return vector->asUnchecked>()->rawValues(); + case jit::HashAggrJitValueKind::Int128: + return vector->asUnchecked>()->rawValues(); + case jit::HashAggrJitValueKind::Float: + return vector->asUnchecked>()->rawValues(); + case jit::HashAggrJitValueKind::Double: + return vector->asUnchecked>()->rawValues(); + } + return nullptr; +} + +bool fillHashAggrJitRowInputRuntime( + jit::HashAggrJitInputRuntime& input, + std::vector& children, + std::vector& childPtrs, + DecodedVector& decoded, + const SelectivityVector& rows, + const jit::HashAggrJitSlot& slot) { + // The raw row-field fast path applies to merge inputs whose intermediate + // representation is ROW(field0, field1): avg's ROW(sum, count) and decimal + // sum's ROW(sum, isEmpty). For both, field0 is the running sum read on the + // hot merge path; populating its raw pointer lets generated IR load it + // directly instead of calling the per-row jit_GetDecodedRowField* helper + // (which rebuilds a field DecodedVector on every call). + if (slot.desc.inputShape() != jit::HashAggrJitRuntimeShape::Row) { + return false; + } + const auto* base = decoded.base(); + if (base == nullptr || base->encoding() != VectorEncoding::Simple::ROW) { + return false; + } + const auto* rowVector = base->asUnchecked(); + if (rowVector->childrenSize() == 0) { + return false; + } + const auto numChildren = rowVector->childrenSize(); + children.resize(numChildren); + childPtrs.resize(numChildren); + for (auto field = 0; field < numChildren; ++field) { + const auto& childVector = rowVector->childAt(field); + if (childVector->encoding() != VectorEncoding::Simple::FLAT) { + return false; + } + const auto kind = hashAggrJitInputValueKind(childVector.get()); + if (!kind.has_value()) { + return false; + } + children[field] = jit::HashAggrJitScalarInputRuntime{ + .values = hashAggrJitRawInputData(childVector.get(), *kind), + .indices = decoded.indices(), + .nulls = childVector->rawNulls()}; + childPtrs[field] = &children[field]; + } + input.row = jit::HashAggrJitRowInputRuntime{ + .nulls = decoded.nulls(&rows), + .children = childPtrs.data(), + .numChildren = static_cast(numChildren)}; + return true; +} + +// Fills the raw child pointers for a ROW output runtime. Returns false when any +// ROW child is not FLAT (e.g. dictionary/constant wrapped), in which case the +// JIT fast path is not applicable and the caller must fall back to the non-JIT +// extract path. +bool fillHashAggrJitRowOutputRuntime( + jit::HashAggrJitOutputRuntime& output, + std::vector& children, + std::vector& childPtrs, + BaseVector* vector) { + auto* rowVector = vector->asUnchecked(); + if (rowVector->childrenSize() == 0) { + return false; + } + const auto numChildren = rowVector->childrenSize(); + children.resize(numChildren); + childPtrs.resize(numChildren); + for (auto field = 0; field < numChildren; ++field) { + auto& childVector = rowVector->childAt(field); + if (childVector->encoding() != VectorEncoding::Simple::FLAT) { + return false; + } + const auto kind = hashAggrJitOutputValueKind(childVector.get()); + if (!kind.has_value()) { + return false; + } + children[field] = jit::HashAggrJitScalarOutputRuntime{ + .values = hashAggrJitRawOutputData(childVector.get(), *kind), + .nulls = childVector->mutableRawNulls(), + .vector = childVector.get()}; + childPtrs[field] = &children[field]; + } + output.row = jit::HashAggrJitRowOutputRuntime{ + .nulls = vector->mutableRawNulls(), + .children = childPtrs.data(), + .numChildren = static_cast(numChildren), + .vector = vector}; + return true; +} + +void resetHashAggrJitOutputDataDependentFlags( + BaseVector* vector, + const jit::HashAggrJitSlot& slot) { + vector->resetDataDependentFlags(nullptr); + if (slot.desc.outputShape() != jit::HashAggrJitRuntimeShape::Row) { + return; + } + + auto* rowVector = vector->asUnchecked(); + for (auto i = 0; i < rowVector->childrenSize(); ++i) { + rowVector->childAt(i)->resetDataDependentFlags(nullptr); + } +} + +#endif + +std::optional makeHashAggrJitSlot( + int32_t aggregateIndex, + const AggregateInfo& aggregate, + bool isRawInput, + bool isPartialOutput) { + if (aggregate.distinct || aggregate.mask.has_value() || + !aggregate.sortingKeys.empty()) { + return std::nullopt; + } + + // Fill the stage-agnostic absolute types and let the context derive the + // stage-specific input/output view from the flags. Companion functions may + // later flip the flags (via rewriteHashAggrJitContext) without needing to + // re-pick types here. + const jit::HashAggrJitPlanContext context{ + .isRawInput = isRawInput, + .isPartialOutput = isPartialOutput, + .rawInputTypes = aggregate.rawInputTypes, + .intermediateType = aggregate.intermediateType, + .resultType = aggregate.function->resultType()}; + if (!aggregate.function->supportsHashAggrJit(context)) { + return std::nullopt; + } + auto descriptor = aggregate.function->createHashAggrJitDescriptor(context); + if (!descriptor.has_value() || descriptor->ops == nullptr) { + return std::nullopt; + } + return aggregate.function->createHashAggrJitSlot(aggregateIndex, *descriptor); +} + } // namespace GroupingSet::GroupingSet( @@ -140,6 +358,10 @@ GroupingSet::GroupingSet( } GroupingSet::~GroupingSet() { +#ifdef ENABLE_BOLT_JIT + // Ensure no background compilation task still references our chunks. + waitForHashAggrJitCompilation(); +#endif if (isGlobal_) { destroyGlobalAggregations(); } @@ -289,7 +511,14 @@ void GroupingSet::addInputForActiveRows( NanosecondTimer funcTimer(&stats_.aggFunctionTimeNs); auto* groups = lookup_->hits.data(); auto& newGroups = lookup_->newGroups; + std::vector jitExecuted; +#ifdef ENABLE_BOLT_JIT + runHashAggrJitAddChunks(groups, newGroups, input, mayPushdown, jitExecuted); +#endif for (auto i = 0; i < aggregates_.size(); ++i) { + if (!jitExecuted.empty() && jitExecuted[i]) { + continue; + } if (!aggregates_[i].sortingKeys.empty()) { continue; } @@ -419,6 +648,9 @@ void GroupingSet::createHashTable() { RowContainer& rows = *table_->rows(); initializeAggregates(aggregates_, rows, false); +#ifdef ENABLE_BOLT_JIT + maybeCreateHashAggrJitPlan(); +#endif auto numColumns = rows.keyTypes().size() + aggregates_.size(); @@ -771,6 +1003,381 @@ const SelectivityVector& GroupingSet::getSelectivityVector( return *rows; } +#ifdef ENABLE_BOLT_JIT +void GroupingSet::waitForHashAggrJitCompilation() { + for (auto& future : hashAggrJitCompileFutures_) { + if (future.valid()) { + stats_.aggJitCodegenTimeNs += std::move(future).get(); + } + } + hashAggrJitCompileFutures_.clear(); +} + +void GroupingSet::maybeCreateHashAggrJitPlan() { + // Wait for any background compilation tasks from a previous plan before + // tearing down the chunks they reference. + waitForHashAggrJitCompilation(); + hashAggrJitChunks_.clear(); + if (!queryConfig_.enableHashAggrJit() || isGlobal_ || ignoreNullKeys_) { + LOG(INFO) << "HashAggrJit plan disabled: enableHashAggrJit=" + << queryConfig_.enableHashAggrJit() << " isGlobal=" << isGlobal_ + << " ignoreNullKeys=" << ignoreNullKeys_; + return; + } + + const auto minFuseWidth = + std::max(1, queryConfig_.hashAggrJitMinFuseWidth()); + const auto maxFuseWidth = + std::max(1, queryConfig_.hashAggrJitMaxFuseWidth()); + const auto minChunkWidth = minFuseWidth; + LOG(INFO) << "HashAggrJit planning starts: isRawInput=" << isRawInput_ + << " isPartial=" << isPartial_ + << " aggregates=" << aggregates_.size() + << " minFuseWidth=" << minFuseWidth + << " maxFuseWidth=" << maxFuseWidth + << " minChunkWidth=" << minChunkWidth; + std::vector currentChunkSlots; + currentChunkSlots.reserve(maxFuseWidth); + + auto flushChunk = [&]() { + if (currentChunkSlots.size() < minChunkWidth) { + if (!currentChunkSlots.empty()) { + LOG(INFO) << "HashAggrJit discard chunk candidate due to width " + << currentChunkSlots.size() << " < " << minChunkWidth + << "."; + } + currentChunkSlots.clear(); + return; + } + // Register the chunk as not-ready and compile it in the background. The + // query thread falls back to the non-JIT path for slots whose chunk is not + // yet ready (see runHashAggrJitAddChunks / runHashAggrJitExtractChunks), + // then switches to JIT once compilation completes. Submitting all chunks + // up front lets the global CPU executor materialize them in parallel. + hashAggrJitChunks_.push_back( + std::make_unique(std::move(currentChunkSlots))); + auto* chunk = hashAggrJitChunks_.back().get(); + LOG(INFO) << "HashAggrJit formed chunk (compiling in background): " + << chunk->getDescription(); + hashAggrJitCompileFutures_.push_back( + folly::via(folly::getGlobalCPUExecutor().get(), [chunk]() -> uint64_t { + const auto start = std::chrono::steady_clock::now(); + if (!chunk->codegen()) { + LOG(INFO) << "HashAggrJit chunk codegen failed for chunk " + << chunk->functionName(); + } + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + })); + currentChunkSlots.clear(); + currentChunkSlots.reserve(maxFuseWidth); + }; + + for (auto i = 0; i < aggregates_.size(); ++i) { + auto slot = makeHashAggrJitSlot(i, aggregates_[i], isRawInput_, isPartial_); + if (!slot.has_value()) { + LOG(INFO) << "HashAggrJit aggregate is not JIT-able: agg#" << i << "(" + << aggregates_[i].name << ") isRawInput=" << isRawInput_ + << " isPartialOutput=" << isPartial_ << " inputTypes=[" + << [&]() { + std::ostringstream out; + if (isRawInput_) { + for (size_t j = 0; j < aggregates_[i].rawInputTypes.size(); ++j) { + if (j > 0) { + out << ", "; + } + out << hashAggrJitTypeName(aggregates_[i].rawInputTypes[j]); + } + } else { + out << hashAggrJitTypeName(aggregates_[i].intermediateType); + } + return out.str(); + }() + << "] distinct=" << aggregates_[i].distinct + << " mask=" << aggregates_[i].mask.has_value() + << " sortingKeys=" << aggregates_[i].sortingKeys.size() + << " inputs=" << aggregates_[i].inputs.size() + << " outputType=" + << hashAggrJitTypeName( + isPartial_ ? aggregates_[i].intermediateType + : aggregates_[i].function->resultType()); + flushChunk(); + continue; + } + + if (currentChunkSlots.size() >= maxFuseWidth) { + flushChunk(); + } + VLOG(1) << "HashAggrJit aggregate is JIT-able: " + << slot->getDescription(); + currentChunkSlots.push_back(*slot); + } + + flushChunk(); + LOG(INFO) << "HashAggrJit planning finished: totalChunks=" + << hashAggrJitChunks_.size(); +} + +void GroupingSet::runHashAggrJitAddChunks( + char** groups, + folly::Range newGroups, + const RowVectorPtr& input, + bool mayPushdown, + std::vector& jitExecuted) { + if (hashAggrJitChunks_.empty()) { + return; + } + + if (hasSpilled() || bypassProbeHT_ || supportRowBasedOutput_ || + !activeRows_.isAllSelected()) { + LOG_FIRST_N(INFO, 10) + << "HashAggrJit add skipped: chunks=" << hashAggrJitChunks_.size() + << " hasSpilled=" << hasSpilled() + << " bypassProbeHT=" << bypassProbeHT_ + << " supportRowBasedOutput=" << supportRowBasedOutput_ + << " activeRowsAllSelected=" << activeRows_.isAllSelected(); + return; + } + + jitExecuted.assign(aggregates_.size(), 0); + std::vector decoded; + std::vector inputRuntimes; + std::vector> rowChildren; + std::vector> + rowChildPtrs; + // Keeps input vectors alive for the DecodedVector buffers referenced by JIT + // during addDense. + std::vector inputVectors; + std::vector inputRuntimePtrs; + std::vector newGroupPtrs; + for (auto& chunkPtr : hashAggrJitChunks_) { + auto& chunk = *chunkPtr; + if (!chunk.isCodegenReady()) { + LOG(INFO) << "HashAggrJit chunk is not codegen-ready, skip add: " + << chunk.getDescription(); + continue; + } + + const auto numSlots = chunk.slots().size(); + decoded.resize(numSlots); + inputRuntimes.resize(numSlots); + rowChildren.resize(numSlots); + rowChildPtrs.resize(numSlots); + inputVectors.assign(numSlots, nullptr); + inputRuntimePtrs.assign(numSlots, nullptr); + + bool canRunChunk = true; + std::string skipReason; + bool inputsMayHaveNulls = false; + for (auto slotIndex = 0; slotIndex < numSlots; ++slotIndex) { + const auto& slot = chunk.slots()[slotIndex]; + const auto& aggregate = aggregates_[slot.aggregateIndex]; + if (aggregate.mask.has_value() || aggregate.distinct || + !aggregate.sortingKeys.empty()) { + canRunChunk = false; + skipReason = "mask/distinct/sortingKeys not supported"; + break; + } + const auto& rows = getSelectivityVector(slot.aggregateIndex); + if (&rows != &activeRows_ || !rows.hasSelections()) { + canRunChunk = false; + skipReason = "selectivity vector is not dense activeRows or has no selections"; + break; + } + if (slot.desc.isCountStar()) { + continue; + } + if (aggregate.inputs.size() != 1) { + canRunChunk = false; + skipReason = "input count is not 1 for non-count(*) slot"; + break; + } + + VectorPtr arg; + if (aggregate.inputs[0] == kConstantChannel) { + arg = BaseVector::wrapInConstant(input->size(), 0, aggregate.constantInputs[0]); + } else { + arg = input->childAt(aggregate.inputs[0]); + } + if (mayPushdown && mayPushdown_[slot.aggregateIndex] && isLazyNotLoaded(*arg)) { + canRunChunk = false; + skipReason = "lazy input with pushdown enabled"; + break; + } + inputVectors[slotIndex] = arg; + decoded[slotIndex].decode(*arg, activeRows_); + const bool usesRowInputRuntime = + slot.desc.inputShape() == jit::HashAggrJitRuntimeShape::Row; + if (usesRowInputRuntime) { + if (!fillHashAggrJitRowInputRuntime( + inputRuntimes[slotIndex], + rowChildren[slotIndex], + rowChildPtrs[slotIndex], + decoded[slotIndex], + activeRows_, + slot)) { + canRunChunk = false; + skipReason = "ROW input runtime requires flat scalar row children"; + break; + } + } else { + inputRuntimes[slotIndex].scalar = jit::HashAggrJitScalarInputRuntime{ + .values = decoded[slotIndex].dataAsVoid(), + .indices = decoded[slotIndex].indices(), + .nulls = decoded[slotIndex].nulls(&activeRows_)}; + } + inputsMayHaveNulls = inputsMayHaveNulls || decoded[slotIndex].mayHaveNulls(); + inputRuntimePtrs[slotIndex] = + reinterpret_cast(&inputRuntimes[slotIndex]); + } + + if (!canRunChunk) { + LOG(INFO) << "HashAggrJit chunk cannot run add path, fallback to non-JIT: " + << chunk.getDescription() << " reason=" << skipReason; + continue; + } + + if (!newGroups.empty()) { + newGroupPtrs.resize(newGroups.size()); + for (auto i = 0; i < newGroups.size(); ++i) { + newGroupPtrs[i] = groups[newGroups[i]]; + } + NanosecondTimer jitTimer(&stats_.aggFunctionJitTimeNs); + chunk.init(newGroupPtrs.data(), newGroups.size()); + } + + { + NanosecondTimer jitTimer(&stats_.aggFunctionJitTimeNs); + chunk.addDense( + groups, + activeRows_.end(), + inputRuntimePtrs.data(), + inputsMayHaveNulls); + } + for (const auto& slot : chunk.slots()) { + aggregates_[slot.aggregateIndex].function->markNullCountUnknown(); + jitExecuted[slot.aggregateIndex] = 1; + } + } +} + +void GroupingSet::runHashAggrJitExtractChunks( + folly::Range groups, + const RowVectorPtr& result, + int32_t aggregateOutputOffset, + std::vector& jitExtracted) { + if (hashAggrJitChunks_.empty()) { + return; + } + + if (groups.empty() || hasSpilled() || supportRowBasedOutput_) { + LOG(INFO) << "HashAggrJit extract skipped: chunks=" << hashAggrJitChunks_.size() + << " groups=" << groups.size() << " hasSpilled=" << hasSpilled() + << " supportRowBasedOutput=" << supportRowBasedOutput_; + return; + } + + jitExtracted.assign(aggregates_.size(), 0); + std::vector outputRuntimes; + std::vector> + rowOutputChildren; + std::vector> + rowOutputChildPtrs; + std::vector resultPtrs; + for (auto& chunkPtr : hashAggrJitChunks_) { + auto& chunk = *chunkPtr; + if (!chunk.isCodegenReady()) { + // Background compilation not finished yet; leave these aggregates for the + // non-JIT extract path (jitExtracted stays 0 for their slots). + LOG(INFO) << "HashAggrJit chunk is not codegen-ready, skip extract: " + << chunk.getDescription(); + continue; + } + const auto numSlots = chunk.slots().size(); + outputRuntimes.assign(numSlots, jit::HashAggrJitOutputRuntime{}); + rowOutputChildren.resize(numSlots); + rowOutputChildPtrs.resize(numSlots); + resultPtrs.assign(numSlots, nullptr); + bool canRunChunk = true; + std::string skipReason; + for (auto slotIndex = 0; slotIndex < numSlots; ++slotIndex) { + const auto& slot = chunk.slots()[slotIndex]; + const auto& aggregate = aggregates_[slot.aggregateIndex]; + if (aggregate.distinct || aggregate.mask.has_value() || + !aggregate.sortingKeys.empty()) { + canRunChunk = false; + skipReason = "distinct/mask/sortingKeys not supported"; + break; + } + auto& aggregateVector = + result->childAt(slot.aggregateIndex + aggregateOutputOffset); + const auto expectedEncoding = + slot.desc.outputShape() == jit::HashAggrJitRuntimeShape::Row + ? VectorEncoding::Simple::ROW + : VectorEncoding::Simple::FLAT; + if (aggregateVector->encoding() != expectedEncoding) { + canRunChunk = false; + skipReason = "unexpected result vector encoding"; + break; + } + // Prepare stable raw output pointers after resizing. The JIT extract + // function still receives char** for ABI compatibility, but each element + // now points to HashAggrJitOutputRuntime rather than BaseVector directly. + aggregateVector->resize(groups.size()); + if (aggregateVector->encoding() == VectorEncoding::Simple::FLAT) { + // Derive the raw values kind from the output vector's actual storage + // type rather than slot.desc.accumulatorKind. They can differ: e.g. + // decimal avg's accumulatorKind is Int128 while its final result is a + // short decimal (FlatVector). Using accumulatorKind here would + // reinterpret an int64 buffer as int128 and corrupt the heap. + const auto outputKind = + hashAggrJitOutputValueKind(aggregateVector.get()); + if (!outputKind.has_value()) { + canRunChunk = false; + skipReason = "unsupported scalar output value kind"; + break; + } + outputRuntimes[slotIndex].scalar = + jit::HashAggrJitScalarOutputRuntime{ + .values = hashAggrJitRawOutputData( + aggregateVector.get(), *outputKind), + .nulls = aggregateVector->mutableRawNulls(), + .vector = aggregateVector.get()}; + } else if ( + aggregateVector->encoding() == VectorEncoding::Simple::ROW && + slot.desc.outputShape() == jit::HashAggrJitRuntimeShape::Row) { + if (!fillHashAggrJitRowOutputRuntime( + outputRuntimes[slotIndex], + rowOutputChildren[slotIndex], + rowOutputChildPtrs[slotIndex], + aggregateVector.get())) { + canRunChunk = false; + skipReason = "ROW output runtime requires flat scalar row children"; + break; + } + } + resultPtrs[slotIndex] = reinterpret_cast(&outputRuntimes[slotIndex]); + } + if (!canRunChunk) { + LOG(INFO) << "HashAggrJit chunk cannot run extract path, fallback to non-JIT: " + << chunk.getDescription() << " reason=" << skipReason; + continue; + } + { + NanosecondTimer jitTimer(&stats_.aggExtractGroupsJitTimeNs); + chunk.extract(groups.data(), groups.size(), resultPtrs.data()); + } + for (const auto& slot : chunk.slots()) { + resetHashAggrJitOutputDataDependentFlags( + result->childAt(slot.aggregateIndex + aggregateOutputOffset).get(), + slot); + jitExtracted[slot.aggregateIndex] = 1; + } + } +} +#endif + bool GroupingSet::getOutput( int32_t maxOutputRows, int32_t maxOutputBytes, @@ -881,7 +1488,14 @@ void GroupingSet::extractGroups( rows.extractColumn(groups.data(), groups.size(), i, keyVector); } } + std::vector jitExtracted; +#ifdef ENABLE_BOLT_JIT + runHashAggrJitExtractChunks(groups, result, totalKeys, jitExtracted); +#endif for (int32_t i = 0; i < aggregates_.size(); ++i) { + if (!jitExtracted.empty() && jitExtracted[i]) { + continue; + } if (!aggregates_[i].sortingKeys.empty()) { continue; } diff --git a/bolt/exec/GroupingSet.h b/bolt/exec/GroupingSet.h index 3927165e6..e49837242 100644 --- a/bolt/exec/GroupingSet.h +++ b/bolt/exec/GroupingSet.h @@ -39,6 +39,11 @@ #include "bolt/exec/Spiller.h" #include "bolt/exec/TreeOfLosers.h" #include "bolt/exec/VectorHasher.h" +#ifdef ENABLE_BOLT_JIT +#include +#include "bolt/jit/aggregation/HashAggrJit.h" +#endif +#include "bolt/vector/DecodedVector.h" namespace bytedance::bolt::exec { class GroupingSet { @@ -211,6 +216,9 @@ class GroupingSet { } common::AggregationStats getRuntimeStats() { +#ifdef ENABLE_BOLT_JIT + waitForHashAggrJitCompilation(); +#endif return stats_; } @@ -282,6 +290,24 @@ class GroupingSet { // index for this aggregation), otherwise it returns reference to activeRows_. const SelectivityVector& getSelectivityVector(size_t aggregateIndex) const; +#ifdef ENABLE_BOLT_JIT + void maybeCreateHashAggrJitPlan(); + // Blocks until all outstanding background JIT compilation tasks finish, so the + // chunks they reference can be safely destroyed or replaced. + void waitForHashAggrJitCompilation(); + void runHashAggrJitAddChunks( + char** groups, + folly::Range newGroups, + const RowVectorPtr& input, + bool mayPushdown, + std::vector& jitExecuted); + void runHashAggrJitExtractChunks( + folly::Range groups, + const RowVectorPtr& result, + int32_t aggregateOutputOffset, + std::vector& jitExtracted); +#endif + // Checks if input will fit in the existing memory and increases reservation // if not. If reservation cannot be increased, spills enough to make 'input' // fit. @@ -442,6 +468,18 @@ class GroupingSet { std::unique_ptr sortedAggregations_; std::vector> distinctAggregations_; +#ifdef ENABLE_BOLT_JIT + // unique_ptr gives each chunk a stable address so background compilation + // tasks can safely hold a raw pointer; HashAggrJitChunk is also non-movable + // (holds a std::atomic). Chunks start not-ready and flip ready once their + // background codegen completes. + std::vector> hashAggrJitChunks_; + // Outstanding background JIT compilation tasks for hashAggrJitChunks_. Each + // future returns the chunk's codegen time in nanoseconds so it can be + // aggregated into stats_.aggJitCodegenTimeNs on the query thread. + std::vector> hashAggrJitCompileFutures_; +#endif + // True if any aggregate accumulator allocates memory outside RowContainer's // HashStringAllocator (e.g. directly from MemoryPool). In that case, // RowContainer::estimateRowSize() can under-estimate the actual per-group diff --git a/bolt/exec/Operator.cpp b/bolt/exec/Operator.cpp index e599e88ea..8ec7c8691 100644 --- a/bolt/exec/Operator.cpp +++ b/bolt/exec/Operator.cpp @@ -833,6 +833,25 @@ void Operator::recordGroupingSetStats(const common::AggregationStats& stats) { "aggProbeBypassCount", RuntimeCounter{(int64_t)stats.aggProbeBypassCount}); } + if (stats.aggJitCodegenTimeNs) { + lockedStats->addRuntimeStat( + "aggJitCodegenTimeNs", + RuntimeCounter{ + (int64_t)stats.aggJitCodegenTimeNs, RuntimeCounter::Unit::kNanos}); + } + if (stats.aggFunctionJitTimeNs) { + lockedStats->addRuntimeStat( + "aggFunctionJitTimeNs", + RuntimeCounter{ + (int64_t)stats.aggFunctionJitTimeNs, RuntimeCounter::Unit::kNanos}); + } + if (stats.aggExtractGroupsJitTimeNs) { + lockedStats->addRuntimeStat( + "aggExtractGroupsJitTimeNs", + RuntimeCounter{ + (int64_t)stats.aggExtractGroupsJitTimeNs, + RuntimeCounter::Unit::kNanos}); + } } void Operator::recordHashBuildSpillStats( diff --git a/bolt/exec/RowContainer.cpp b/bolt/exec/RowContainer.cpp index 9412a2dbd..5656e447f 100644 --- a/bolt/exec/RowContainer.cpp +++ b/bolt/exec/RowContainer.cpp @@ -37,6 +37,7 @@ #include "bolt/type/StringView.h" #include "bolt/type/Timestamp.h" #include "bolt/vector/DecodedVector.h" +#include "bolt/vector/FlatVector.h" #include "bolt/common/memory/ByteStream.h" #include "bolt/common/memory/RawVector.h" @@ -1746,6 +1747,52 @@ __attribute__((__visibility__("default"))) double jit_GetDecodedValueDouble( return reinterpret_cast(vec) ->valueAt(index); } + +__attribute__((__visibility__("default"))) double +jit_GetDecodedRowFieldDouble(char* vec, int32_t index, int32_t field) { + auto* decoded = reinterpret_cast(vec); + auto* rowVector = decoded->base()->as(); + bytedance::bolt::DecodedVector fieldDecoded(*rowVector->childAt(field)); + return fieldDecoded.valueAt(decoded->index(index)); +} + +__attribute__((__visibility__("default"))) int8_t jit_GetDecodedRowFieldI8( + char* vec, + int32_t index, + int32_t field) { + auto* decoded = reinterpret_cast(vec); + auto* rowVector = decoded->base()->as(); + bytedance::bolt::DecodedVector fieldDecoded(*rowVector->childAt(field)); + return fieldDecoded.valueAt(decoded->index(index)); +} + +__attribute__((__visibility__("default"))) int64_t jit_GetDecodedRowFieldI64( + char* vec, + int32_t index, + int32_t field) { + auto* decoded = reinterpret_cast(vec); + auto* rowVector = decoded->base()->as(); + bytedance::bolt::DecodedVector fieldDecoded(*rowVector->childAt(field)); + return fieldDecoded.valueAt(decoded->index(index)); +} + +__attribute__((__visibility__("default"))) bytedance::bolt::int128_t +jit_GetDecodedRowFieldI128(char* vec, int32_t index, int32_t field) { + auto* decoded = reinterpret_cast(vec); + auto* rowVector = decoded->base()->as(); + bytedance::bolt::DecodedVector fieldDecoded(*rowVector->childAt(field)); + return fieldDecoded.valueAt(decoded->index(index)); +} + +__attribute__((__visibility__("default"))) int8_t jit_GetDecodedRowFieldIsNull( + char* vec, + int32_t index, + int32_t field) { + auto* decoded = reinterpret_cast(vec); + auto* rowVector = decoded->base()->as(); + bytedance::bolt::DecodedVector fieldDecoded(*rowVector->childAt(field)); + return fieldDecoded.isNullAt(decoded->index(index)); +} // get decoded value string __attribute__((__visibility__("default"))) const char* jit_GetDecodedValueStringView(char* vec, int32_t index) { diff --git a/bolt/exec/benchmarks/CMakeLists.txt b/bolt/exec/benchmarks/CMakeLists.txt index dfeef8cd5..de2d6a864 100644 --- a/bolt/exec/benchmarks/CMakeLists.txt +++ b/bolt/exec/benchmarks/CMakeLists.txt @@ -70,6 +70,15 @@ target_link_libraries( GTest::gtest_main ) +add_executable(bolt_hashaggr_jit_benchmark HashAggrJitBenchmark.cpp) + +target_link_libraries( + bolt_hashaggr_jit_benchmark PRIVATE + bolt_testutils + ${FOLLY_BENCHMARK} + GTest::gtest_main +) + add_executable(bolt_sort_random_data_benchmark SortRandomDataBenchmark.cpp) add_executable(bolt_sort_window_benchmark SortWindowBenchmark.cpp) diff --git a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp new file mode 100644 index 000000000..a016d646d --- /dev/null +++ b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp @@ -0,0 +1,402 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +#include + +#include "bolt/core/QueryConfig.h" +#include "bolt/exec/tests/utils/AssertQueryBuilder.h" +#include "bolt/exec/tests/utils/PlanBuilder.h" +#include "bolt/functions/prestosql/aggregates/RegisterAggregateFunctions.h" +#include "bolt/functions/sparksql/aggregates/Register.h" +#include "bolt/parse/TypeResolver.h" +#include "bolt/vector/tests/utils/VectorTestBase.h" + +using namespace bytedance::bolt; +using namespace bytedance::bolt::exec; +using namespace bytedance::bolt::test; + +DEFINE_int32(hashaggr_jit_benchmark_batches, 20, "Number of input batches."); +DEFINE_int32(hashaggr_jit_benchmark_batch_size, 10000, "Rows per input batch."); +DEFINE_int32(hashaggr_jit_benchmark_groups, 10000, "Number of distinct groups."); + +namespace { + +struct HashAggrJitBenchmarkCase { + std::shared_ptr plan; + int32_t minFuseWidth{4}; + int32_t maxFuseWidth{16}; +}; + +enum class AggregationPlanKind { + Single, + Partial, + PartialFinal, +}; + +class HashAggrJitBenchmark : public VectorTestBase { + public: + void addBenchmark(const std::string& name, int32_t width) { + auto rows = makeRows(width); + std::vector sums; + std::vector avgs; + std::vector mins; + std::vector maxs; + std::vector counts; + sums.reserve(width); + avgs.reserve(width); + mins.reserve(width); + maxs.reserve(width); + counts.reserve(width); + for (auto i = 0; i < width; ++i) { + sums.push_back(fmt::format("spark_sum(c{})", i + 1)); + avgs.push_back(fmt::format("spark_avg(c{})", i + 1)); + mins.push_back(fmt::format("min(c{})", i + 1)); + maxs.push_back(fmt::format("max(c{})", i + 1)); + counts.push_back(fmt::format("count(c{})", i + 1)); + } + + // addCase(name + "_sum", rows, sums); + // addCase(name + "_avg", rows, avgs); + // addCase(name + "_min", rows, mins); + // addCase(name + "_max", rows, maxs); + // addCase(name + "_count", rows, counts); + addCase(name + "_merge_sum", rows, sums, AggregationPlanKind::PartialFinal); + addCase(name + "_merge_avg", rows, avgs, AggregationPlanKind::PartialFinal); + addCase(name + "_merge_min", rows, mins, AggregationPlanKind::PartialFinal); + addCase(name + "_merge_max", rows, maxs, AggregationPlanKind::PartialFinal); + addCase( + name + "_merge_count", rows, counts, AggregationPlanKind::PartialFinal); + } + + void addDecimalBenchmark(const std::string& name, int32_t width) { + auto rows = makeDecimalRows(width); + std::vector sums; + std::vector avgs; + for (auto i = 0; i < width; ++i) { + sums.push_back(fmt::format("spark_sum(c{})", i + 1)); + avgs.push_back(fmt::format("spark_avg(c{})", i + 1)); + } + addCase( + name + "_merge_decimal_sum", + rows, + sums, + AggregationPlanKind::PartialFinal); + addCase( + name + "_merge_decimal_avg", + rows, + avgs, + AggregationPlanKind::PartialFinal); + } + + void addFloatingPointMinMaxBenchmark(const std::string& name, int32_t width) { + auto rows = makeDoubleRows(width); + std::vector mins; + std::vector maxs; + for (auto i = 0; i < width; ++i) { + mins.push_back(fmt::format("min(c{})", i + 1)); + maxs.push_back(fmt::format("max(c{})", i + 1)); + } + // addCase(name + "_double_min", rows, mins); + // addCase(name + "_double_max", rows, maxs); + addCase( + name + "_merge_double_min", + rows, + mins, + AggregationPlanKind::PartialFinal); + addCase( + name + "_merge_double_max", + rows, + maxs, + AggregationPlanKind::PartialFinal); + } + + void addHighCardinalityExtractBenchmark(const std::string& name, int32_t width) { + auto rows = makeHighCardinalityRows(width); + std::vector avgs; + std::vector sums; + avgs.reserve(width); + sums.reserve(width); + for (auto i = 0; i < width; ++i) { + avgs.push_back(fmt::format("spark_avg(c{})", i + 1)); + sums.push_back(fmt::format("spark_sum(c{})", i + 1)); + } + addCase(name + "_partial_avg", rows, avgs, AggregationPlanKind::Partial); + addCase(name + "_partial_sum", rows, sums, AggregationPlanKind::Partial); + } + + void addHighCardinalityMergeBenchmark(const std::string& name, int32_t width) { + auto rows = makeHighCardinalityRows(width); + std::vector sums; + std::vector avgs; + std::vector mins; + std::vector maxs; + std::vector counts; + sums.reserve(width); + avgs.reserve(width); + mins.reserve(width); + maxs.reserve(width); + counts.reserve(width); + for (auto i = 0; i < width; ++i) { + sums.push_back(fmt::format("spark_sum(c{})", i + 1)); + avgs.push_back(fmt::format("spark_avg(c{})", i + 1)); + mins.push_back(fmt::format("min(c{})", i + 1)); + maxs.push_back(fmt::format("max(c{})", i + 1)); + counts.push_back(fmt::format("count(c{})", i + 1)); + } + + // addCase(name + "_sum", rows, sums); + // addCase(name + "_avg", rows, avgs); + // addCase(name + "_min", rows, mins); + // addCase(name + "_max", rows, maxs); + // addCase(name + "_count", rows, counts); + + addCase(name + "_merge_sum", rows, sums, AggregationPlanKind::PartialFinal); + addCase(name + "_merge_avg", rows, avgs, AggregationPlanKind::PartialFinal); + addCase(name + "_merge_min", rows, mins, AggregationPlanKind::PartialFinal); + addCase(name + "_merge_max", rows, maxs, AggregationPlanKind::PartialFinal); + addCase( + name + "_merge_count", rows, counts, AggregationPlanKind::PartialFinal); + } + + void addFixedAggregateCountFuseWidthBenchmark( + const std::string& name, + int32_t aggregateCount) { + auto rows = makeRows(aggregateCount); + std::vector sums; + sums.reserve(aggregateCount); + for (auto i = 0; i < aggregateCount; ++i) { + sums.push_back(fmt::format("spark_sum(c{})", i + 1)); + } + + for (const auto fuseWidth : {4, 8, 16, 32}) { + addCase( + fmt::format("{}_aggs{}_fuse_width{}", name, aggregateCount, fuseWidth), + rows, + sums, + AggregationPlanKind::PartialFinal, + fuseWidth, + fuseWidth); + } + } + + private: + std::vector makeRows(int32_t width) { + std::vector names; + std::vector children; + names.reserve(width + 1); + children.reserve(width + 1); + names.push_back("c0"); + children.push_back(makeFlatVector( + FLAGS_hashaggr_jit_benchmark_batch_size, + [](vector_size_t row) { return row % FLAGS_hashaggr_jit_benchmark_groups; })); + + for (auto column = 0; column < width; ++column) { + names.push_back(fmt::format("c{}", column + 1)); + children.push_back(makeFlatVector( + FLAGS_hashaggr_jit_benchmark_batch_size, + [column](vector_size_t row) { + return static_cast((row + 17 * column) & 0xffff); + })); + } + + auto batch = makeRowVector(names, children); + std::vector rows; + rows.reserve(FLAGS_hashaggr_jit_benchmark_batches); + for (auto i = 0; i < FLAGS_hashaggr_jit_benchmark_batches; ++i) { + rows.push_back(batch); + } + return rows; + } + + std::vector makeDecimalRows(int32_t width) { + std::vector names; + std::vector children; + names.push_back("c0"); + children.push_back(makeFlatVector( + FLAGS_hashaggr_jit_benchmark_batch_size, + [](vector_size_t row) { return row % FLAGS_hashaggr_jit_benchmark_groups; })); + for (auto column = 0; column < width; ++column) { + names.push_back(fmt::format("c{}", column + 1)); + children.push_back(makeFlatVector( + FLAGS_hashaggr_jit_benchmark_batch_size, + [column](vector_size_t row) { + return static_cast((row + 13 * column) % 100000); + }, + nullptr, + DECIMAL(12, 2))); + } + auto batch = makeRowVector(names, children); + return std::vector(FLAGS_hashaggr_jit_benchmark_batches, batch); + } + + std::vector makeDoubleRows(int32_t width) { + std::vector names; + std::vector children; + names.push_back("c0"); + children.push_back(makeFlatVector( + FLAGS_hashaggr_jit_benchmark_batch_size, + [](vector_size_t row) { return row % FLAGS_hashaggr_jit_benchmark_groups; })); + for (auto column = 0; column < width; ++column) { + names.push_back(fmt::format("c{}", column + 1)); + children.push_back(makeFlatVector( + FLAGS_hashaggr_jit_benchmark_batch_size, + [column](vector_size_t row) { + if ((row + column) % 997 == 0) { + return std::numeric_limits::quiet_NaN(); + } + return static_cast((row + 17 * column) & 0xffff); + })); + } + auto batch = makeRowVector(names, children); + return std::vector(FLAGS_hashaggr_jit_benchmark_batches, batch); + } + + std::vector makeHighCardinalityRows(int32_t width) { + std::vector rows; + rows.reserve(FLAGS_hashaggr_jit_benchmark_batches); + for (auto batchIndex = 0; batchIndex < FLAGS_hashaggr_jit_benchmark_batches; ++batchIndex) { + std::vector names; + std::vector children; + names.reserve(width + 1); + children.reserve(width + 1); + names.push_back("c0"); + children.push_back(makeFlatVector( + FLAGS_hashaggr_jit_benchmark_batch_size, + [batchIndex](vector_size_t row) { + return static_cast(batchIndex) * + FLAGS_hashaggr_jit_benchmark_batch_size + + row; + })); + for (auto column = 0; column < width; ++column) { + names.push_back(fmt::format("c{}", column + 1)); + children.push_back(makeFlatVector( + FLAGS_hashaggr_jit_benchmark_batch_size, + [batchIndex, column](vector_size_t row) { + return static_cast(batchIndex * 97 + column * 17 + row); + })); + } + rows.push_back(makeRowVector(names, children)); + } + return rows; + } + + std::shared_ptr makePlan( + const std::vector& rows, + const std::vector& aggregates, + AggregationPlanKind planKind = AggregationPlanKind::Single) { + exec::test::PlanBuilder builder; + builder.values(rows); + switch (planKind) { + case AggregationPlanKind::Single: + builder.singleAggregation({"c0"}, aggregates); + break; + case AggregationPlanKind::Partial: + builder.partialAggregation({"c0"}, aggregates); + break; + case AggregationPlanKind::PartialFinal: + builder.partialAggregation({"c0"}, aggregates).finalAggregation(); + break; + } + return builder.planNode(); + } + + void run( + const std::shared_ptr& plan, + bool enableJit, + int32_t minFuseWidth = 4, + int32_t maxFuseWidth = 16) { + exec::test::AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, enableJit ? "true" : "false") + .config( + core::QueryConfig::kHashAggrJitMinFuseWidth, + std::to_string(minFuseWidth)) + .config( + core::QueryConfig::kHashAggrJitMaxFuseWidth, + std::to_string(maxFuseWidth)) + .copyResults(pool_.get()); + } + + void addCase( + const std::string& name, + const std::vector& rows, + const std::vector& aggregates, + AggregationPlanKind planKind = AggregationPlanKind::Single, + int32_t minFuseWidth = 4, + int32_t maxFuseWidth = 16) { + auto testCase = std::make_unique(); + testCase->plan = makePlan(rows, aggregates, planKind); + testCase->minFuseWidth = minFuseWidth; + testCase->maxFuseWidth = maxFuseWidth; + // Warm-up intentionally disabled: we want each sample to include one-time + // plan setup / JIT compilation so the async/parallel-compile optimization + // (first-batch latency) is reflected in the measurement. + auto* testCasePtr = testCase.get(); + folly::addBenchmark(__FILE__, name + "_nojit", [this, testCasePtr]() { + run( + testCasePtr->plan, + false, + testCasePtr->minFuseWidth, + testCasePtr->maxFuseWidth); + return 1; + }); + folly::addBenchmark(__FILE__, name + "_jit", [this, testCasePtr]() { + run( + testCasePtr->plan, + true, + testCasePtr->minFuseWidth, + testCasePtr->maxFuseWidth); + return 1; + }); + folly::addBenchmark(__FILE__, "-", []() { return 0; }); + cases_.push_back(std::move(testCase)); + } + + std::vector> cases_; +}; + +} // namespace + +int main(int argc, char** argv) { + folly::init(&argc, &argv); + memory::initializeMemoryManager(memory::MemoryManager::Options{}); + aggregate::prestosql::registerAllAggregateFunctions(); + functions::aggregate::sparksql::registerAggregateFunctions("spark_", false); + parse::registerTypeResolver(); + + HashAggrJitBenchmark benchmark; + benchmark.addBenchmark("width4", 4); + benchmark.addBenchmark("width8", 8); + benchmark.addBenchmark("width16", 16); + benchmark.addBenchmark("width32", 32); + + benchmark.addHighCardinalityMergeBenchmark("width4_high_card", 4); + benchmark.addHighCardinalityMergeBenchmark("width8_high_card", 8); + benchmark.addHighCardinalityMergeBenchmark("width16_high_card", 16); + benchmark.addHighCardinalityMergeBenchmark("width32_high_card", 32); + + benchmark.addDecimalBenchmark("width4", 4); + benchmark.addDecimalBenchmark("width8", 8); + benchmark.addDecimalBenchmark("width16", 16); + benchmark.addDecimalBenchmark("width32", 32); + + benchmark.addFloatingPointMinMaxBenchmark("width4", 4); + benchmark.addFloatingPointMinMaxBenchmark("width8", 8); + benchmark.addFloatingPointMinMaxBenchmark("width16", 16); + benchmark.addFloatingPointMinMaxBenchmark("width32", 32); + + benchmark.addFixedAggregateCountFuseWidthBenchmark( + "fixed_aggregate_count", 256); + + // benchmark.addHighCardinalityExtractBenchmark("width4_high_card", 4); + // benchmark.addHighCardinalityExtractBenchmark("width8_high_card", 8); + // benchmark.addHighCardinalityExtractBenchmark("width16_high_card", 16); + // benchmark.addHighCardinalityExtractBenchmark("width32_high_card", 32); + + folly::runBenchmarks(); + return 0; +} diff --git a/bolt/exec/tests/AggregationTest.cpp b/bolt/exec/tests/AggregationTest.cpp index 04d819a0c..bc4fc2ef1 100644 --- a/bolt/exec/tests/AggregationTest.cpp +++ b/bolt/exec/tests/AggregationTest.cpp @@ -814,6 +814,24 @@ TEST_P(AggregationTest, setNull) { EXPECT_TRUE(aggregate.isNullTest(&group)); } +TEST_P(AggregationTest, isNullUsesBitWhenNullCountUnknown) { + AggregateFunc aggregate(BIGINT()); + int32_t nullOffset = 0; + aggregate.setOffsets( + 0, + RowContainer::nullByte(nullOffset), + RowContainer::nullMask(nullOffset), + 0); + + char group{0}; + EXPECT_TRUE(aggregate.setNullTest(&group)); + aggregate.markNullCountUnknown(); + + EXPECT_TRUE(aggregate.isNullTest(&group)); + EXPECT_TRUE(aggregate.clearNullTest(&group)); + EXPECT_FALSE(aggregate.isNullTest(&group)); +} + TEST_P(AggregationTest, hashmodes) { rng_.seed(1); auto rowType = diff --git a/bolt/functions/lib/aggregates/AverageAggregateBase.h b/bolt/functions/lib/aggregates/AverageAggregateBase.h index f8869255f..d24e1825c 100644 --- a/bolt/functions/lib/aggregates/AverageAggregateBase.h +++ b/bolt/functions/lib/aggregates/AverageAggregateBase.h @@ -33,6 +33,7 @@ #include "bolt/exec/Aggregate.h" #include "bolt/functions/lib/aggregates/AggregateToIntermediate.h" #include "bolt/functions/lib/aggregates/DecimalAggregate.h" +#include "bolt/functions/lib/aggregates/SumCount.h" #include "bolt/type/DecimalUtil.h" #include "bolt/vector/ComplexVector.h" #include "bolt/vector/DecodedVector.h" @@ -77,12 +78,6 @@ const SelectivityVector* getBaseRows( return baseRows; } -template -struct SumCount { - TSum sum{0}; - int64_t count{0}; -}; - } // namespace /// Partial aggregation produces a pair of sum and count. @@ -213,7 +208,7 @@ class AverageAggregateBase : public exec::Aggregate { groups[i], TAccumulator(decodedRaw_.valueAt(i))); }, nulls); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + } else if (exec::Aggregate::hasNoNulls() && decodedRaw_.isIdentityMapping()) { auto data = decodedRaw_.data(); rows.applyToSelected([&](vector_size_t i) { updateNonNullValue(groups[i], data[i]); @@ -256,7 +251,7 @@ class AverageAggregateBase : public exec::Aggregate { group, TAccumulator(decodedRaw_.valueAt(i))); }, nulls); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + } else if (exec::Aggregate::hasNoNulls() && decodedRaw_.isIdentityMapping()) { const TInput* data = decodedRaw_.data(); TAccumulator totalSum(0); rows.applyToSelected([&](vector_size_t i) { totalSum += data[i]; }); diff --git a/bolt/functions/lib/aggregates/CentralMomentsAggregatesBase.h b/bolt/functions/lib/aggregates/CentralMomentsAggregatesBase.h index 4e9445d77..b639784d0 100644 --- a/bolt/functions/lib/aggregates/CentralMomentsAggregatesBase.h +++ b/bolt/functions/lib/aggregates/CentralMomentsAggregatesBase.h @@ -263,7 +263,7 @@ class CentralMomentsAggregatesBase : public exec::Aggregate { } updateNonNullValue(groups[i], decodedRaw_.valueAt(i)); }); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + } else if (exec::Aggregate::hasNoNulls() && decodedRaw_.isIdentityMapping()) { auto data = decodedRaw_.data(); rows.applyToSelected([&](vector_size_t i) { updateNonNullValue(groups[i], data[i]); @@ -294,7 +294,7 @@ class CentralMomentsAggregatesBase : public exec::Aggregate { updateNonNullValue(group, decodedRaw_.valueAt(i)); } }); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + } else if (exec::Aggregate::hasNoNulls() && decodedRaw_.isIdentityMapping()) { auto data = decodedRaw_.data(); CentralMomentsAccumulator accData; rows.applyToSelected([&](vector_size_t i) { accData.update(data[i]); }); diff --git a/bolt/functions/lib/aggregates/DecimalAccumulatorLayout.h b/bolt/functions/lib/aggregates/DecimalAccumulatorLayout.h new file mode 100644 index 000000000..8dda85b1a --- /dev/null +++ b/bolt/functions/lib/aggregates/DecimalAccumulatorLayout.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * -------------------------------------------------------------------------- + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + * + * This modified file is released under the same license. + * -------------------------------------------------------------------------- + */ + +#pragma once + +#include + +#include "bolt/type/HugeInt.h" + +// Single source of truth for the in-memory layout of the decimal sum / decimal +// avg accumulators. The non-JIT accumulators (DecimalSum, +// LongDecimalWithOverflowState) derive from these POD layout bases, and the +// hash-aggregation JIT codegen aliases the same bases to compute field offsets +// (via offsetof) instead of mirroring the layout. A change here is therefore +// picked up by both paths automatically. +// +// Keep this header dependency-free (header-only PODs, no .cpp / no external +// symbols) so it can be included by both the non-JIT aggregates and the JIT +// module (bolt_thrustjit) without introducing inter-library link dependencies. +// +// IMPORTANT: classes deriving from these layouts must add *no* non-static data +// members (methods only), otherwise they stop being standard-layout and the +// offsets the JIT relies on become undefined. Each derived type guards this +// with static_assert(std::is_standard_layout_v<...>). +namespace bytedance::bolt::functions::aggregate { + +// Layout of the decimal sum accumulator: ROW(sum, isEmpty) with an overflow +// counter. Field order is part of the contract shared with the JIT path. +struct DecimalSumAccumulatorLayout { + int128_t sum{0}; + int64_t overflow{0}; + bool isEmpty{true}; +}; + +// Layout of the long-decimal-with-overflow accumulator (used by decimal avg): +// running sum, count of rows, and net overflow counter. NOTE: this is the +// in-memory layout {sum, count, overflow}, which differs from the serialized +// byte order {count, overflow, sum}; the JIT path reads memory, not serialized +// form, so it aligns with this layout. +struct LongDecimalWithOverflowLayout { + int128_t sum{0}; + int64_t count{0}; + int64_t overflow{0}; +}; + +} // namespace bytedance::bolt::functions::aggregate diff --git a/bolt/functions/lib/aggregates/DecimalAggregate.h b/bolt/functions/lib/aggregates/DecimalAggregate.h index 60466802d..9f34c4eec 100644 --- a/bolt/functions/lib/aggregates/DecimalAggregate.h +++ b/bolt/functions/lib/aggregates/DecimalAggregate.h @@ -32,6 +32,7 @@ #include "bolt/common/base/IOUtils.h" #include "bolt/exec/Aggregate.h" +#include "bolt/functions/lib/aggregates/DecimalAccumulatorLayout.h" #include "bolt/type/HugeInt.h" #include "bolt/vector/FlatVector.h" namespace bytedance::bolt::functions::aggregate { @@ -42,7 +43,7 @@ namespace bytedance::bolt::functions::aggregate { * COUNT: Total number of rows so far. * OVERFLOW: Total count of net overflow or underflow so far. */ -struct LongDecimalWithOverflowState { +struct LongDecimalWithOverflowState : LongDecimalWithOverflowLayout { public: void mergeWith(const StringView& serializedData) { BOLT_CHECK_EQ(serializedData.size(), serializedSize()); @@ -75,12 +76,10 @@ struct LongDecimalWithOverflowState { static constexpr size_t serializedSize() { return sizeof(int64_t) * 4; } - - int128_t sum{0}; - int64_t count{0}; - int64_t overflow{0}; }; +static_assert(std::is_standard_layout_v); + template class DecimalAggregate : public exec::Aggregate { public: @@ -133,7 +132,7 @@ class DecimalAggregate : public exec::Aggregate { groups[i], TResultType(decodedRaw_.valueAt(i))); }, nulls); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + } else if (exec::Aggregate::hasNoNulls() && decodedRaw_.isIdentityMapping()) { auto data = decodedRaw_.data(); rows.applyToSelected([&](vector_size_t i) { updateNonNullValue(groups[i], TResultType(data[i])); @@ -179,7 +178,7 @@ class DecimalAggregate : public exec::Aggregate { group, TResultType(decodedRaw_.valueAt(i))); }, nulls); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + } else if (exec::Aggregate::hasNoNulls() && decodedRaw_.isIdentityMapping()) { const TInputType* data = decodedRaw_.data(); LongDecimalWithOverflowState accumulator; rows.applyToSelected([&](vector_size_t i) { diff --git a/bolt/functions/lib/aggregates/SimpleNumericAggregate.h b/bolt/functions/lib/aggregates/SimpleNumericAggregate.h index ce71fcfda..1a4eb1c37 100644 --- a/bolt/functions/lib/aggregates/SimpleNumericAggregate.h +++ b/bolt/functions/lib/aggregates/SimpleNumericAggregate.h @@ -116,7 +116,7 @@ class SimpleNumericAggregate : public exec::Aggregate { exec::Aggregate::nullByte_, exec::Aggregate::nullMask_, groups, - &this->exec::Aggregate::numNulls_, + this->exec::Aggregate::numNulls_, updateSingleValue); auto indices = decoded.indices(); @@ -236,7 +236,7 @@ class SimpleNumericAggregate : public exec::Aggregate { exec::Aggregate::nullByte_, exec::Aggregate::nullMask_, groups, - &this->exec::Aggregate::numNulls_); + this->exec::Aggregate::numNulls_); // The decoded vector does not really keep the info from the 'rows', except // for the 'upper bound' of it. In case not all rows are selected we need to // generate proper indices, which we 'indirect' through the ones we got from diff --git a/bolt/functions/lib/aggregates/SumAggregateBase.h b/bolt/functions/lib/aggregates/SumAggregateBase.h index acafb3766..89c978b49 100644 --- a/bolt/functions/lib/aggregates/SumAggregateBase.h +++ b/bolt/functions/lib/aggregates/SumAggregateBase.h @@ -173,7 +173,7 @@ class SumAggregateBase return; } - if (exec::Aggregate::numNulls_) { + if (exec::Aggregate::mayHaveNulls()) { BaseAggregate::template updateGroups( groups, rows, arg, &updateSingleValue, false); } else { diff --git a/bolt/functions/lib/aggregates/SumCount.h b/bolt/functions/lib/aggregates/SumCount.h new file mode 100644 index 000000000..e7d41c521 --- /dev/null +++ b/bolt/functions/lib/aggregates/SumCount.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * -------------------------------------------------------------------------- + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + * + * This modified file is released under the same license. + * -------------------------------------------------------------------------- + */ + +#pragma once + +#include + +namespace bytedance::bolt::functions::aggregate { + +// Intermediate accumulator layout for AVG: a running sum and a count of +// non-null inputs. This is the single source of truth for the AVG intermediate +// memory layout. The hash-aggregation JIT codegen derives its field offsets +// from this struct (via offsetof) instead of mirroring the layout, so any +// change here is automatically picked up by the JIT path. +// +// Keep this header dependency-free (header-only template, no .cpp / no external +// symbols) so it can be included by both the non-JIT aggregates and the JIT +// module (bolt_thrustjit) without introducing inter-library link dependencies. +template +struct SumCount { + TSum sum{0}; + int64_t count{0}; +}; + +} // namespace bytedance::bolt::functions::aggregate diff --git a/bolt/functions/lib/aggregates/tests/SumTestBase.h b/bolt/functions/lib/aggregates/tests/SumTestBase.h index 564d7770c..e3a9a0ffa 100644 --- a/bolt/functions/lib/aggregates/tests/SumTestBase.h +++ b/bolt/functions/lib/aggregates/tests/SumTestBase.h @@ -53,13 +53,13 @@ void testHookLimits(bool expectOverflow = false) { sumRow.sum = 0; ResultType expected = 0; char* row = reinterpret_cast(&sumRow); - uint64_t numNulls = 0; + std::optional numNulls = 0; bytedance::bolt::aggregate::SumHook hook( offsetof(SumRow, sum), offsetof(SumRow, nulls), 0, &row, - &numNulls); + numNulls); // Adding limit should not overflow. ASSERT_NO_THROW(hook.addValue(0, &limit)); diff --git a/bolt/functions/prestosql/aggregates/CountAggregate.cpp b/bolt/functions/prestosql/aggregates/CountAggregate.cpp index da49ba53e..62c6f9860 100644 --- a/bolt/functions/prestosql/aggregates/CountAggregate.cpp +++ b/bolt/functions/prestosql/aggregates/CountAggregate.cpp @@ -53,6 +53,49 @@ class CountAggregate : public SimpleNumericAggregate { return true; } +#ifdef ENABLE_BOLT_JIT + bool supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const override { + const auto inputTypes = context.inputTypes(); + if (context.isRawInput) { + if (context.isCountStar()) { + return true; + } + if (inputTypes.size() != 1 || inputTypes[0] == nullptr) { + return false; + } + const auto& inputType = inputTypes[0]; + return !inputType->isRow() && + (inputType->isDecimal() || + jit::isHashAggrJitSupportedType(inputType->kind())); + } + return inputTypes.size() == 1 && inputTypes[0] != nullptr && + inputTypes[0]->kind() == TypeKind::BIGINT; + } + + std::optional createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const override { + if (!supportsHashAggrJit(context)) { + return std::nullopt; + } + auto inputKind = jit::HashAggrJitValueKind::Int64; + if (!context.isCountStar()) { + auto maybeInputKind = + jit::hashAggrJitValueKind(context.inputTypes()[0]->kind()); + if (!maybeInputKind.has_value()) { + return std::nullopt; + } + inputKind = *maybeInputKind; + } + return jit::HashAggrJitDescriptor{ + .kind = jit::HashAggrJitKind::Count, + .rawInputKind = inputKind, + .accumulatorKind = jit::HashAggrJitValueKind::Int64, + .context = context, + .ops = jit::getCountOps()}; + } +#endif + void toIntermediate( const SelectivityVector& rows, std::vector& args, @@ -73,15 +116,22 @@ class CountAggregate : public SimpleNumericAggregate { folly::Range indices) override { for (auto i : indices) { // result of count is never null + groups[i][nullByte_] &= ~nullMask_; *value(groups[i]) = (int64_t)0; } } FLATTEN void extractValues(char** groups, int32_t numGroups, VectorPtr* result) override { - BaseAggregate::doExtractValues(groups, numGroups, result, [&](char* group) { - return *value(group); - }); + auto* vector = (*result)->as>(); + BOLT_CHECK(vector); + vector->resize(numGroups); + vector->clearAllNulls(); + + auto* rawValues = vector->mutableRawValues(); + for (vector_size_t i = 0; i < numGroups; ++i) { + rawValues[i] = *value(groups[i]); + } } void addRawInput( diff --git a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp index 3fa9ed216..792cc82c1 100644 --- a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -52,6 +52,42 @@ class MinMaxAggregate : public SimpleNumericAggregate { public: explicit MinMaxAggregate(TypePtr resultType) : BaseAggregate(resultType) {} +#ifdef ENABLE_BOLT_JIT + bool supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const override { + const auto inputTypes = context.inputTypes(); + if (inputTypes.size() != 1 || inputTypes[0] == nullptr) { + return false; + } + const auto& inputType = inputTypes[0]; + return !inputType->isRow() && + (inputType->isDecimal() || + jit::isHashAggrJitSupportedType(inputType->kind()) || + inputType->kind() == TypeKind::HUGEINT); + } + + std::optional createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const override { + if (!supportsHashAggrJit(context)) { + return std::nullopt; + } + auto inputKind = jit::hashAggrJitValueKind(context.inputTypes()[0]->kind()); + if (!inputKind.has_value()) { + return std::nullopt; + } + return jit::HashAggrJitDescriptor{ + .kind = jitKind(), + .rawInputKind = *inputKind, + .accumulatorKind = *inputKind, + .context = context, + .ops = jit::getMinMaxOps()}; + } + + protected: + virtual jit::HashAggrJitKind jitKind() const = 0; + public: +#endif + int32_t accumulatorFixedWidthSize() const override { return sizeof(T); } @@ -135,6 +171,15 @@ class MaxAggregate : public MinMaxAggregate { public: explicit MaxAggregate(TypePtr resultType) : MinMaxAggregate(resultType) {} +#ifdef ENABLE_BOLT_JIT + protected: + jit::HashAggrJitKind jitKind() const override { + return jit::HashAggrJitKind::Max; + } + + public: +#endif + void initializeNewGroups( char** groups, folly::Range indices) final { @@ -225,6 +270,15 @@ class MinAggregate : public MinMaxAggregate { public: explicit MinAggregate(TypePtr resultType) : MinMaxAggregate(resultType) {} +#ifdef ENABLE_BOLT_JIT + protected: + jit::HashAggrJitKind jitKind() const override { + return jit::HashAggrJitKind::Min; + } + + public: +#endif + void initializeNewGroups( char** groups, folly::Range indices) final { diff --git a/bolt/functions/prestosql/aggregates/VarianceAggregates.cpp b/bolt/functions/prestosql/aggregates/VarianceAggregates.cpp index b48e44a8d..5e1f08d05 100644 --- a/bolt/functions/prestosql/aggregates/VarianceAggregates.cpp +++ b/bolt/functions/prestosql/aggregates/VarianceAggregates.cpp @@ -252,7 +252,7 @@ class VarianceAggregate : public exec::Aggregate { updateNonNullValue(groups[i], decodedRaw_.valueAt(i)); }, nulls); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + } else if (exec::Aggregate::hasNoNulls() && decodedRaw_.isIdentityMapping()) { auto data = decodedRaw_.data(); rows.applyToSelected([&](vector_size_t i) { updateNonNullValue(groups[i], data[i]); @@ -294,7 +294,7 @@ class VarianceAggregate : public exec::Aggregate { updateNonNullValue(group, decodedRaw_.valueAt(i)); }, nulls); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + } else if (exec::Aggregate::hasNoNulls() && decodedRaw_.isIdentityMapping()) { const T* data = decodedRaw_.data(); VarianceAccumulator accData; rows.applyToSelected([&](vector_size_t i) { accData.update(data[i]); }); diff --git a/bolt/functions/prestosql/aggregates/tests/CountAggregationTest.cpp b/bolt/functions/prestosql/aggregates/tests/CountAggregationTest.cpp index 78da1704e..dc3ec7fa8 100644 --- a/bolt/functions/prestosql/aggregates/tests/CountAggregationTest.cpp +++ b/bolt/functions/prestosql/aggregates/tests/CountAggregationTest.cpp @@ -94,6 +94,89 @@ TEST_F(CountAggregationTest, count) { "SELECT c0 % 10, count(c7) FROM tmp GROUP BY 1"); } +TEST_F(CountAggregationTest, hashAggrJitBooleanCount) { + auto data = makeRowVector( + {makeFlatVector(256, [](auto row) { return row % 11; }), + makeFlatVector( + 256, + [](auto row) { return row % 3 == 0; }, + [](auto row) { return row % 7 == 0; })}); + + auto singlePlan = PlanBuilder() + .values({data}) + .singleAggregation({"c0"}, {"count(c1)"}) + .planNode(); + + auto singleNoJit = AssertQueryBuilder(singlePlan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto singleJit = AssertQueryBuilder(singlePlan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .copyResults(pool()); + assertEqualResults({singleNoJit}, {singleJit}); + + auto partialFinalPlan = PlanBuilder() + .values({data}) + .partialAggregation({"c0"}, {"count(c1)"}) + .finalAggregation() + .planNode(); + + auto partialFinalNoJit = AssertQueryBuilder(partialFinalPlan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto partialFinalJit = AssertQueryBuilder(partialFinalPlan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .copyResults(pool()); + assertEqualResults({partialFinalNoJit}, {partialFinalJit}); +} + +TEST_F(CountAggregationTest, hashAggrJitDecimalCount) { + auto shortDecimal = makeRowVector( + {makeFlatVector(256, [](auto row) { return row % 11; }), + makeFlatVector( + 256, + [](auto row) { return (row * 7) % 101; }, + [](auto row) { return row % 7 == 0; }, + DECIMAL(18, 3))}); + auto longDecimal = makeRowVector( + {makeFlatVector(256, [](auto row) { return row % 11; }), + makeFlatVector( + 256, + [](auto row) { return HugeInt::build(row % 9, (row * 31) % 997); }, + [](auto row) { return row % 7 == 0; }, + DECIMAL(38, 19))}); + + for (const auto& data : {shortDecimal, longDecimal}) { + for (const auto& makePlan : + std::vector>{ + [&]() { + return PlanBuilder() + .values({data}) + .singleAggregation({"c0"}, {"count(c1)"}) + .planNode(); + }, + [&]() { + return PlanBuilder() + .values({data}) + .partialAggregation({"c0"}, {"count(c1)"}) + .finalAggregation() + .planNode(); + }}) { + auto plan = makePlan(); + auto noJit = AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto jit = AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .copyResults(pool()); + assertEqualResults({noJit}, {jit}); + } + } +} + TEST_F(CountAggregationTest, mask) { std::vector data; // Make batches where some batches have mask all true, some half and half and diff --git a/bolt/functions/prestosql/aggregates/tests/MinMaxTest.cpp b/bolt/functions/prestosql/aggregates/tests/MinMaxTest.cpp index cbc68fe16..652ad1f9e 100644 --- a/bolt/functions/prestosql/aggregates/tests/MinMaxTest.cpp +++ b/bolt/functions/prestosql/aggregates/tests/MinMaxTest.cpp @@ -30,6 +30,8 @@ #include #include +#include "bolt/exec/tests/utils/AssertQueryBuilder.h" +#include "bolt/exec/tests/utils/PlanBuilder.h" #include "bolt/common/base/tests/GTestUtils.h" #include "bolt/functions/lib/aggregates/tests/utils/AggregationTestBase.h" #include "bolt/vector/fuzzer/VectorFuzzer.h" @@ -220,6 +222,115 @@ TEST_F(MinMaxTest, minBoolean) { doTest(min, BOOLEAN()); } +TEST_F(MinMaxTest, hashAggrJitBooleanMinMax) { + auto data = makeRowVector( + {makeFlatVector(256, [](auto row) { return row % 13; }), + makeFlatVector( + 256, + [](auto row) { return row % 5 < 2; }, + [](auto row) { return row % 11 == 0; })}); + + for (const auto& makePlan : + std::vector>{ + [&]() { + return PlanBuilder() + .values({data}) + .singleAggregation({"c0"}, {"min(c1)", "max(c1)"}) + .planNode(); + }, + [&]() { + return PlanBuilder() + .values({data}) + .partialAggregation({"c0"}, {"min(c1)", "max(c1)"}) + .finalAggregation() + .planNode(); + }}) { + auto plan = makePlan(); + auto noJit = AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto jit = AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .copyResults(pool()); + assertEqualResults({noJit}, {jit}); + } +} + +TEST_F(MinMaxTest, hashAggrJitShortDecimalMinMax) { + auto data = makeRowVector( + {makeFlatVector(256, [](auto row) { return row % 13; }), + makeFlatVector( + 256, + [](auto row) { return (row * 7) % 101; }, + [](auto row) { return row % 11 == 0; }, + DECIMAL(18, 3))}); + + for (const auto& makePlan : + std::vector>{ + [&]() { + return PlanBuilder() + .values({data}) + .singleAggregation({"c0"}, {"min(c1)", "max(c1)"}) + .planNode(); + }, + [&]() { + return PlanBuilder() + .values({data}) + .partialAggregation({"c0"}, {"min(c1)", "max(c1)"}) + .finalAggregation() + .planNode(); + }}) { + auto plan = makePlan(); + auto noJit = AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto jit = AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .copyResults(pool()); + assertEqualResults({noJit}, {jit}); + } +} + +TEST_F(MinMaxTest, hashAggrJitLongDecimalMinMax) { + auto data = makeRowVector( + {makeFlatVector(256, [](auto row) { return row % 13; }), + makeFlatVector( + 256, + [](auto row) { + return HugeInt::build(row % 9, (row * 31) % 997); + }, + [](auto row) { return row % 11 == 0; }, + DECIMAL(38, 19))}); + + for (const auto& makePlan : + std::vector>{ + [&]() { + return PlanBuilder() + .values({data}) + .singleAggregation({"c0"}, {"min(c1)", "max(c1)"}) + .planNode(); + }, + [&]() { + return PlanBuilder() + .values({data}) + .partialAggregation({"c0"}, {"min(c1)", "max(c1)"}) + .finalAggregation() + .planNode(); + }}) { + auto plan = makePlan(); + auto noJit = AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto jit = AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .copyResults(pool()); + assertEqualResults({noJit}, {jit}); + } +} + TEST_F(MinMaxTest, constVarchar) { // Create two batches of the source data for the aggregation: // Column c0 with 1K of "apple" and 1K of "banana". diff --git a/bolt/functions/prestosql/aggregates/tests/SumTest.cpp b/bolt/functions/prestosql/aggregates/tests/SumTest.cpp index 13c880c06..fc73946cc 100644 --- a/bolt/functions/prestosql/aggregates/tests/SumTest.cpp +++ b/bolt/functions/prestosql/aggregates/tests/SumTest.cpp @@ -446,18 +446,19 @@ TEST_F(SumTest, hook) { sumRow.sum = 0; char* row = reinterpret_cast(&sumRow); - uint64_t numNulls = 1; + std::optional numNulls = 1; aggregate::SumHook hook( offsetof(SumRow, sum), offsetof(SumRow, nulls), 1, &row, - &numNulls); + numNulls); int64_t value = 11; hook.addValue(0, &value); EXPECT_EQ(0, sumRow.nulls); - EXPECT_EQ(0, numNulls); + ASSERT_TRUE(numNulls.has_value()); + EXPECT_EQ(0, *numNulls); EXPECT_EQ(value, sumRow.sum); } diff --git a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp index f1225b0f1..c9d4ec9e0 100644 --- a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp @@ -31,6 +31,7 @@ #include "bolt/functions/sparksql/aggregates/AverageAggregate.h" #include "bolt/functions/lib/aggregates/AverageAggregateBase.h" #include "bolt/functions/sparksql/DecimalUtil.h" + using namespace bytedance::bolt::functions::aggregate; namespace bytedance::bolt::functions::aggregate::sparksql { namespace { @@ -42,6 +43,54 @@ class AverageAggregate explicit AverageAggregate(TypePtr resultType) : AverageAggregateBase(resultType) {} +#ifdef ENABLE_BOLT_JIT + bool supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const override { + const auto inputTypes = context.inputTypes(); + if (inputTypes.size() != 1 || !inputTypes[0]) { + return false; + } + const auto& inputType = inputTypes[0]; + if (context.isRawInput) { + if (inputType->isDecimal()) { + return false; + } + return jit::isHashAggrJitSupportedType(inputType->kind()) || + inputType->kind() == TypeKind::HUGEINT; + } + return inputType->isRow() && inputType->size() == 2 && + inputType->childAt(1)->kind() == TypeKind::BIGINT && + inputType->childAt(0)->kind() == TypeKind::DOUBLE; + } + + std::optional createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const override { + if (!supportsHashAggrJit(context)) { + return std::nullopt; + } + + if (!context.isRawInput) { + return jit::HashAggrJitDescriptor{ + .kind = jit::HashAggrJitKind::Avg, + .rawInputKind = jit::HashAggrJitValueKind::Double, + .accumulatorKind = jit::HashAggrJitValueKind::Double, + .context = context, + .ops = jit::getAvgOps()}; + } + + auto inputKind = jit::hashAggrJitValueKind(context.inputTypes()[0]->kind()); + if (!inputKind.has_value()) { + return std::nullopt; + } + return jit::HashAggrJitDescriptor{ + .kind = jit::HashAggrJitKind::Avg, + .rawInputKind = *inputKind, + .accumulatorKind = jit::HashAggrJitValueKind::Double, + .context = context, + .ops = jit::getAvgOps()}; + } +#endif + void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) override { auto rowVector = (*result)->as(); @@ -93,6 +142,42 @@ class DecimalAverageAggregate : public DecimalAggregate { explicit DecimalAverageAggregate(TypePtr resultType, TypePtr sumType) : DecimalAggregate(resultType), sumType_(sumType) {} +#ifdef ENABLE_BOLT_JIT + bool supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const override { + const auto inputTypes = context.inputTypes(); + if (inputTypes.size() != 1 || !inputTypes[0]) { + return false; + } + const auto& inputType = inputTypes[0]; + if (context.isRawInput) { + return inputType->isDecimal(); + } + return inputType->isRow() && inputType->size() == 2 && + inputType->childAt(0)->isDecimal() && + inputType->childAt(1)->kind() == TypeKind::BIGINT; + } + + std::optional createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const override { + if (!supportsHashAggrJit(context)) { + return std::nullopt; + } + const auto inputTypes = context.inputTypes(); + const auto& inputType = inputTypes[0]; + const auto& valueType = + context.isRawInput ? inputType : inputType->childAt(0); + return jit::HashAggrJitDescriptor{ + .kind = jit::HashAggrJitKind::DecimalAvg, + .rawInputKind = valueType->isShortDecimal() + ? jit::HashAggrJitValueKind::Int64 + : jit::HashAggrJitValueKind::Int128, + .accumulatorKind = jit::HashAggrJitValueKind::Int128, + .context = context, + .ops = jit::getDecimalAvgOps()}; + } +#endif + void addIntermediateResults( char** groups, const SelectivityVector& rows, diff --git a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h index d8d8d5e3d..9a310f23c 100644 --- a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h +++ b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -31,14 +31,11 @@ #pragma once #include "bolt/exec/Aggregate.h" #include "bolt/expression/FunctionSignature.h" +#include "bolt/functions/lib/aggregates/DecimalAccumulatorLayout.h" #include "bolt/vector/FlatVector.h" namespace bytedance::bolt::functions::aggregate::sparksql { -struct DecimalSum { - int128_t sum{0}; - int64_t overflow{0}; - bool isEmpty{true}; - +struct DecimalSum : DecimalSumAccumulatorLayout { void mergeWith(const DecimalSum& other) { this->overflow += other.overflow; this->overflow += @@ -47,6 +44,8 @@ struct DecimalSum { } }; +static_assert(std::is_standard_layout_v); + template class DecimalSumAggregate : public exec::Aggregate { public: @@ -61,6 +60,43 @@ class DecimalSumAggregate : public exec::Aggregate { return alignof(DecimalSum); } +#ifdef ENABLE_BOLT_JIT + bool supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const override { + const auto inputTypes = context.inputTypes(); + if (inputTypes.size() != 1 || !inputTypes[0]) { + return false; + } + const auto& inputType = inputTypes[0]; + if (context.isRawInput) { + return inputType->isDecimal() && + (inputType->isShortDecimal() || inputType->isLongDecimal()); + } + return inputType->isRow() && inputType->size() == 2 && + inputType->childAt(0)->isDecimal() && + inputType->childAt(1)->kind() == TypeKind::BOOLEAN; + } + + std::optional createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const override { + if (!supportsHashAggrJit(context)) { + return std::nullopt; + } + const auto inputTypes = context.inputTypes(); + const auto& inputType = inputTypes[0]; + const auto& valueType = + context.isRawInput ? inputType : inputType->childAt(0); + return jit::HashAggrJitDescriptor{ + .kind = jit::HashAggrJitKind::DecimalSum, + .rawInputKind = valueType->isShortDecimal() + ? jit::HashAggrJitValueKind::Int64 + : jit::HashAggrJitValueKind::Int128, + .accumulatorKind = jit::HashAggrJitValueKind::Int128, + .context = context, + .ops = jit::getDecimalSumOps()}; + } +#endif + void initializeNewGroups( char** groups, folly::Range indices) override { @@ -195,7 +231,7 @@ class DecimalSumAggregate : public exec::Aggregate { groups[i], decodedRaw_.valueAt(i), false); }, nulls); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + } else if (exec::Aggregate::hasNoNulls() && decodedRaw_.isIdentityMapping()) { auto data = decodedRaw_.data(); rows.applyToSelected([&](vector_size_t i) { updateNonNullValue(groups[i], data[i], false); @@ -237,7 +273,7 @@ class DecimalSumAggregate : public exec::Aggregate { group, decodedRaw_.valueAt(i), false); }, nulls); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + } else if (exec::Aggregate::hasNoNulls() && decodedRaw_.isIdentityMapping()) { auto data = decodedRaw_.data(); DecimalSum decimalSum; rows.applyToSelected([&](vector_size_t i) { diff --git a/bolt/functions/sparksql/aggregates/SumAggregate.cpp b/bolt/functions/sparksql/aggregates/SumAggregate.cpp index cd852df3f..acef1c147 100644 --- a/bolt/functions/sparksql/aggregates/SumAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/SumAggregate.cpp @@ -32,12 +32,58 @@ #include "bolt/functions/lib/aggregates/SumAggregateBase.h" #include "bolt/functions/sparksql/aggregates/DecimalSumAggregate.h" + using namespace bytedance::bolt::functions::aggregate; namespace bytedance::bolt::functions::aggregate::sparksql { namespace { template -using SumAggregate = SumAggregateBase; +class SumAggregate : public SumAggregateBase { + public: + explicit SumAggregate(TypePtr resultType) + : SumAggregateBase(resultType) {} + +#ifdef ENABLE_BOLT_JIT + bool supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const override { + const auto inputTypes = context.inputTypes(); + if (inputTypes.size() != 1 || !inputTypes[0]) { + return false; + } + const auto& inputType = inputTypes[0]; + if (inputType->isRow() || inputType->isDecimal()) { + return false; + } + return jit::isHashAggrJitSupportedType(inputType->kind()) || + inputType->kind() == TypeKind::HUGEINT; + } + + std::optional createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const override { + if (!supportsHashAggrJit(context)) { + return std::nullopt; + } + + auto inputKind = jit::hashAggrJitValueKind(context.inputTypes()[0]->kind()); + if (!inputKind.has_value()) { + return std::nullopt; + } + + const auto accumulatorKind = + (*inputKind == jit::HashAggrJitValueKind::Float || + *inputKind == jit::HashAggrJitValueKind::Double) + ? jit::HashAggrJitValueKind::Double + : jit::HashAggrJitValueKind::Int64; + + return jit::HashAggrJitDescriptor{ + .kind = jit::HashAggrJitKind::Sum, + .rawInputKind = *inputKind, + .accumulatorKind = accumulatorKind, + .context = context, + .ops = jit::getSumOps()}; + } +#endif +}; TypePtr getDecimalSumType( const TypePtr& resultType, diff --git a/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp b/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp index 488f06ae8..6c8a45f59 100644 --- a/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp +++ b/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp @@ -30,9 +30,12 @@ #include "bolt/exec/tests/utils/AssertQueryBuilder.h" #include "bolt/exec/tests/utils/PlanBuilder.h" +#include "bolt/exec/tests/utils/QueryAssertions.h" #include "bolt/functions/lib/aggregates/tests/SumTestBase.h" #include "bolt/functions/sparksql/aggregates/Register.h" +#include + using bytedance::bolt::exec::test::PlanBuilder; using namespace bytedance::bolt::exec::test; using namespace bytedance::bolt::functions::aggregate::test; @@ -131,6 +134,156 @@ TEST_F(SumAggregationTest, hookLimits) { testHookLimits(); } +TEST_F(SumAggregationTest, hashAggrJitDecimalSumAndFloatingMinMax) { + auto input = makeRowVector( + {makeFlatVector(256, [](auto row) { return row % 8; }), + makeFlatVector( + 256, [](auto row) { return row * 100; }, nullptr, DECIMAL(12, 2)), + makeFlatVector(256, [](auto row) { + return row % 31 == 0 ? std::numeric_limits::quiet_NaN() + : static_cast(row); + }), + makeFlatVector(256, [](auto row) { + return row % 37 == 0 ? std::numeric_limits::quiet_NaN() + : static_cast(1000 - row); + })}); + + auto plan = PlanBuilder(pool()) + .values({input}) + .singleAggregation( + {"c0"}, + {"spark_sum(c1)", "min(c2)", "max(c3)"}) + .planNode(); + + auto noJit = AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto jit = AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .copyResults(pool()); + assertEqualResults({noJit}, {jit}); +} + +TEST_F(SumAggregationTest, hashAggrJitMergeAndExtract) { + auto input = makeRowVector( + {makeFlatVector(512, [](auto row) { return row % 16; }), + makeFlatVector(512, [](auto row) { return row; }), + makeFlatVector(512, [](auto row) { return 1000 - row; }), + makeFlatVector( + 512, + [](auto row) { return row; }, + [](auto row) { return row % 7 == 0; })}); + + auto plan = PlanBuilder(pool()) + .values({input}) + .partialAggregation( + {"c0"}, + {"spark_sum(c1)", "spark_avg(c1)", "min(c2)", "count(c3)"}) + .finalAggregation() + .planNode(); + + auto noJit = AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto jit = AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .copyResults(pool()); + assertEqualResults({noJit}, {jit}); +} + +TEST_F(SumAggregationTest, hashAggrJitPartialAvgExtractAccumulators) { + auto input = makeRowVector( + {makeFlatVector(2048, [](auto row) { return row; }), + makeFlatVector(2048, [](auto row) { return row * 3; }), + makeFlatVector(2048, [](auto row) { return row * 7; })}); + + auto plan = PlanBuilder(pool()) + .values({input}) + .partialAggregation( + {"c0"}, {"spark_avg(c1)", "spark_avg(c2)"}) + .planNode(); + + auto noJit = AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto jit = AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .copyResults(pool()); + assertEqualResults({noJit}, {jit}); +} + +TEST_F(SumAggregationTest, hashAggrJitAllNullGroup) { + // Repro: one group (c0 == 1) has all-null sum input. Spark sum over an + // all-null group must yield null, not 0. Partial+final two-stage plan. + auto input = makeRowVector( + {makeFlatVector(8, [](auto row) { return row % 2; }), + makeFlatVector( + 8, + [](auto row) { return static_cast(row); }, + // c0 == 1 rows (odd rows) are all null. + [](auto row) { return row % 2 == 1; })}); + + auto plan = PlanBuilder(pool()) + .values({input}) + .partialAggregation({"c0"}, {"spark_sum(c1)"}) + .finalAggregation() + .planNode(); + + auto noJit = AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto jit = AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .copyResults(pool()); + assertEqualResults({noJit}, {jit}); +} + +TEST_F(SumAggregationTest, hashAggrJitSplitsContiguousSegments) { + auto input = makeRowVector( + {makeFlatVector(512, [](auto row) { return row % 16; }), + makeFlatVector(512, [](auto row) { return row; }), + makeFlatVector(512, [](auto row) { return row * 2; }), + makeFlatVector(512, [](auto row) { return 1000 - row; }), + makeFlatVector(512, [](auto row) { return row * 5; }), + makeFlatVector(512, [](auto row) { return row * 7; }), + makeFlatVector(512, [](auto row) { return row * 11; }), + makeFlatVector(512, [](auto row) { return row * 13; }), + makeFlatVector(512, [](auto row) { return row % 9; }), + makeFlatVector(512, [](auto row) { return row % 17; })}); + + auto plan = PlanBuilder(pool()) + .values({input}) + .singleAggregation( + {"c0"}, + {"min(c1)", + "max(c2)", + "spark_sum(c3)", + "spark_avg(c4)", + "min(c5)", + "max(c6)", + "spark_sum(c7)", + "spark_avg(c8)", + "spark_collect_list(c1)", + "spark_collect_list(c2)", + "min(c9)", + "max(c9)"}) + .planNode(); + + auto noJit = AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto jit = AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .config(core::QueryConfig::kHashAggrJitMaxFuseWidth, "4") + .copyResults(pool()); + assertEqualResults({noJit}, {jit}); +} + TEST_F(SumAggregationTest, decimalSum) { std::vector> shortDecimalRawVector; std::vector> longDecimalRawVector; diff --git a/bolt/jit/CMakeLists.txt b/bolt/jit/CMakeLists.txt index cade0033d..1f4473fb3 100644 --- a/bolt/jit/CMakeLists.txt +++ b/bolt/jit/CMakeLists.txt @@ -16,12 +16,43 @@ bolt_add_library( bolt_thrustjit CompiledModule.cpp ThrustJITv2.cpp + aggregation/HashAggrJit.cpp + aggregation/ops/CountOps.cpp + aggregation/ops/MinMaxOps.cpp + aggregation/ops/SumOps.cpp + aggregation/ops/AvgOps.cpp + aggregation/ops/DecimalSumOps.cpp + aggregation/ops/DecimalAvgOps.cpp RowContainer/RowContainerCodeGenerator.cpp RowContainer/RowEqVectorsCodeGenerator.cpp ) target_link_libraries(bolt_thrustjit PUBLIC llvm-core::llvm-core date::date fmt::fmt Folly::folly) +# Optional: report JIT-generated symbols to Intel VTune via the JIT Profiling +# API (libjitprofiling). Enable with -DBOLT_ENABLE_VTUNE_JIT=ON and point +# VTUNE_SDK_DIR at the VTune SDK (e.g. /opt/intel/oneapi/vtune//sdk). +# Runtime reporting is further gated by the BOLT_JIT_VTUNE env var. +if(BOLT_ENABLE_VTUNE_JIT) + if(NOT VTUNE_SDK_DIR) + set(VTUNE_SDK_DIR "/opt/intel/oneapi/vtune/2023.2.0/sdk") + endif() + find_path(VTUNE_JIT_INCLUDE_DIR jitprofiling.h + HINTS "${VTUNE_SDK_DIR}/include") + find_library(VTUNE_JIT_LIBRARY jitprofiling + HINTS "${VTUNE_SDK_DIR}/lib64" "${VTUNE_SDK_DIR}/lib32") + if(NOT VTUNE_JIT_INCLUDE_DIR OR NOT VTUNE_JIT_LIBRARY) + message(FATAL_ERROR + "BOLT_ENABLE_VTUNE_JIT=ON but jitprofiling.h/libjitprofiling not found " + "under VTUNE_SDK_DIR=${VTUNE_SDK_DIR}") + endif() + target_compile_definitions(bolt_thrustjit PRIVATE BOLT_ENABLE_VTUNE_JIT) + target_include_directories(bolt_thrustjit PRIVATE "${VTUNE_JIT_INCLUDE_DIR}") + # libjitprofiling needs dl/pthread at link time. + target_link_libraries(bolt_thrustjit PRIVATE "${VTUNE_JIT_LIBRARY}" ${CMAKE_DL_LIBS}) + message(STATUS "bolt_thrustjit: VTune JIT profiling enabled (${VTUNE_JIT_LIBRARY})") +endif() + target_compile_options( bolt_thrustjit PRIVATE $<$:-Werror=return-type> ) diff --git a/bolt/jit/ThrustJITv2.cpp b/bolt/jit/ThrustJITv2.cpp index 54d8e4a7f..03e43e019 100644 --- a/bolt/jit/ThrustJITv2.cpp +++ b/bolt/jit/ThrustJITv2.cpp @@ -21,16 +21,137 @@ #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" #include "llvm/ExecutionEngine/SectionMemoryManager.h" +#include "llvm/Object/SymbolSize.h" #include "llvm/Support/DynamicLibrary.h" +#include "llvm/Support/Process.h" #include "llvm/Support/TargetSelect.h" #include #include +#include +#include +#include #include +#include +#include #include +#ifdef BOLT_ENABLE_VTUNE_JIT +#include +#endif + namespace bytedance::bolt::jit { +namespace { + +// Returns whether perf-map emission is enabled. The conan LLVM package is built +// without LLVM_USE_PERF, so llvm::JITEventListener::createPerfJITEventListener() +// is a no-op stub that returns nullptr. We therefore emit the perf map file +// ourselves. Controlled by the BOLT_JIT_PERF env var to avoid file IO overhead +// in normal runs. +bool perfMapEnabled() { + static const bool enabled = std::getenv("BOLT_JIT_PERF") != nullptr; + return enabled; +} + +// Appends symbols of a freshly loaded JIT object to /tmp/perf-.map so that +// `perf report` can resolve JIT-generated machine code to function names. +// Format per line: " ". +void appendPerfMap( + const llvm::object::ObjectFile& obj, + const llvm::RuntimeDyld::LoadedObjectInfo& loadedInfo) { + // Use the relocated/loaded object so symbol addresses are the runtime ones. + auto debugObjOwner = loadedInfo.getObjectForDebug(obj); + const llvm::object::ObjectFile& debugObj = *debugObjOwner.getBinary(); + + static std::mutex perfMapMutex; + std::lock_guard guard(perfMapMutex); + + std::string path = + "/tmp/perf-" + std::to_string(llvm::sys::Process::getProcessId()) + + ".map"; + std::FILE* file = std::fopen(path.c_str(), "a"); + if (file == nullptr) { + return; + } + + for (const auto& [symbol, size] : + llvm::object::computeSymbolSizes(debugObj)) { + auto typeOr = symbol.getType(); + if (!typeOr || *typeOr != llvm::object::SymbolRef::ST_Function) { + continue; + } + auto addrOr = symbol.getAddress(); + auto nameOr = symbol.getName(); + if (!addrOr || !nameOr || size == 0) { + llvm::consumeError(addrOr.takeError()); + llvm::consumeError(nameOr.takeError()); + continue; + } + std::fprintf( + file, + "%llx %llx %.*s\n", + static_cast(*addrOr), + static_cast(size), + static_cast(nameOr->size()), + nameOr->data()); + } + std::fclose(file); +} + +#ifdef BOLT_ENABLE_VTUNE_JIT +// Returns whether VTune JIT symbol reporting is enabled. Unlike perf's +// /tmp/perf-.map (which VTune does not read), VTune resolves JIT code only +// when the process actively reports each function via the Intel JIT Profiling +// API (iJIT_NotifyEvent). Controlled by the BOLT_JIT_VTUNE env var to avoid the +// reporting overhead in normal runs. +bool vtuneJitEnabled() { + static const bool enabled = std::getenv("BOLT_JIT_VTUNE") != nullptr; + return enabled; +} + +// Reports symbols of a freshly loaded JIT object to VTune through the Intel JIT +// Profiling API so that VTune can attribute samples in JIT-generated machine +// code to function names instead of "outside any known module". +void notifyVTune( + const llvm::object::ObjectFile& obj, + const llvm::RuntimeDyld::LoadedObjectInfo& loadedInfo) { + // Use the relocated/loaded object so symbol addresses are the runtime ones. + auto debugObjOwner = loadedInfo.getObjectForDebug(obj); + const llvm::object::ObjectFile& debugObj = *debugObjOwner.getBinary(); + + static std::mutex vtuneMutex; + std::lock_guard guard(vtuneMutex); + + for (const auto& [symbol, size] : + llvm::object::computeSymbolSizes(debugObj)) { + auto typeOr = symbol.getType(); + if (!typeOr || *typeOr != llvm::object::SymbolRef::ST_Function) { + continue; + } + auto addrOr = symbol.getAddress(); + auto nameOr = symbol.getName(); + if (!addrOr || !nameOr || size == 0) { + llvm::consumeError(addrOr.takeError()); + llvm::consumeError(nameOr.takeError()); + continue; + } + // iJIT_Method_Load.method_name is a non-const char*; keep a stable owning + // copy for the duration of the iJIT_NotifyEvent call. + std::string name = nameOr->str(); + iJIT_Method_Load jit_method = {}; + jit_method.method_id = iJIT_GetNewMethodID(); + jit_method.method_name = name.data(); + jit_method.method_load_address = reinterpret_cast(*addrOr); + jit_method.method_size = static_cast(size); + iJIT_NotifyEvent( + iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, static_cast(&jit_method)); + } +} +#endif // BOLT_ENABLE_VTUNE_JIT + +} // namespace + llvm::Expected> ThrustJITv2::Create() { static std::once_flag llvmTargetInitialized; std::call_once(llvmTargetInitialized, []() { @@ -59,7 +180,16 @@ llvm::Expected> ThrustJITv2::Create() { [tracker]( llvm::orc::MaterializationResponsibility& mr, const llvm::object::ObjectFile& obj, - const llvm::RuntimeDyld::LoadedObjectInfo&) { + const llvm::RuntimeDyld::LoadedObjectInfo& + loadedInfo) { + if (perfMapEnabled()) { + appendPerfMap(obj, loadedInfo); + } +#ifdef BOLT_ENABLE_VTUNE_JIT + if (vtuneJitEnabled()) { + notifyVTune(obj, loadedInfo); + } +#endif llvm::orc::ResourceKey resourceKey = 0; if (auto err = mr.withResourceKeyDo( [&](llvm::orc::ResourceKey key) { diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp new file mode 100644 index 000000000..eb1387ffe --- /dev/null +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -0,0 +1,1254 @@ +#ifdef ENABLE_BOLT_JIT + +#include "bolt/jit/aggregation/HashAggrJit.h" + +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include "bolt/common/base/BitUtil.h" +#include "bolt/common/base/Exceptions.h" +#include "bolt/jit/ThrustJITv2.h" + +extern "C" { + +using bytedance::bolt::jit::HashAggrJitRowInputRuntime; +using bytedance::bolt::jit::HashAggrJitRowOutputRuntime; +using bytedance::bolt::jit::HashAggrJitScalarInputRuntime; +using bytedance::bolt::jit::HashAggrJitScalarOutputRuntime; + +namespace { + +void logHashAggrJitFunctionIR( + const llvm::Module& module, + const std::string& moduleKey, + llvm::StringRef functionName, + llvm::StringRef stage, + bool hasError) { + const auto* function = module.getFunction(functionName); + if (function == nullptr) { + VLOG(1) << "HashAggrJit generated LLVM IR for chunk " << moduleKey + << " stage=" << stage.str() << " function=" << functionName.str() + << " error=" << hasError << ": "; + return; + } + std::string ir; + llvm::raw_string_ostream out(ir); + function->print(out); + out.flush(); + VLOG(1) << "HashAggrJit generated LLVM IR for chunk " << moduleKey + << " stage=" << stage.str() << " function=" << functionName.str() + << " error=" << hasError << ":\n" + << ir; +} + +constexpr uint64_t kScalarInputValuesOffset = + offsetof(HashAggrJitScalarInputRuntime, values); +constexpr uint64_t kScalarInputIndicesOffset = + offsetof(HashAggrJitScalarInputRuntime, indices); +constexpr uint64_t kScalarInputNullsOffset = + offsetof(HashAggrJitScalarInputRuntime, nulls); +constexpr uint64_t kRowInputNullsOffset = + offsetof(HashAggrJitRowInputRuntime, nulls); +constexpr uint64_t kRowInputChildrenOffset = + offsetof(HashAggrJitRowInputRuntime, children); + +constexpr uint64_t kScalarOutputValuesOffset = + offsetof(HashAggrJitScalarOutputRuntime, values); +constexpr uint64_t kScalarOutputNullsOffset = + offsetof(HashAggrJitScalarOutputRuntime, nulls); +constexpr uint64_t kScalarOutputVectorOffset = + offsetof(HashAggrJitScalarOutputRuntime, vector); +constexpr uint64_t kRowOutputNullsOffset = + offsetof(HashAggrJitRowOutputRuntime, nulls); +constexpr uint64_t kRowOutputChildrenOffset = + offsetof(HashAggrJitRowOutputRuntime, children); +constexpr uint64_t kRowOutputVectorOffset = + offsetof(HashAggrJitRowOutputRuntime, vector); + +} // namespace + +// Link anchor: the JIT extract/output runtime helpers live in separate +// translation units (HashAggrRuntime.cpp / HashAggrDecimalRuntime.cpp) and are +// only ever looked up by name through the ORC JIT global symbol table, never +// referenced at C++ link time. Without an explicit reference the linker would +// drop those objects from the final executable and the JIT would fail to +// resolve the symbols. Referencing one symbol per object forces the whole +// object (and thus every helper it defines) to be retained. This TU is always +// pulled in by any JIT user (HashAggrJitChunk), so the anchor propagates. +void jit_HashAggrResizeVector(char* vector, int32_t size); +void jit_HashAggrExtractFinalShortDecimalSum( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + int32_t scale, + int32_t accumulatorIsNull); + +[[maybe_unused]] __attribute__((used)) const void* const + kHashAggrRuntimeLinkAnchors[] = { + reinterpret_cast(&jit_HashAggrResizeVector), + reinterpret_cast(&jit_HashAggrExtractFinalShortDecimalSum)}; + +} // extern "C" + +namespace bytedance::bolt::jit { +namespace { + +llvm::FunctionCallee declareFunction( + llvm::Module& module, + llvm::StringRef name, + llvm::Type* returnType, + llvm::ArrayRef argTypes) { + return module.getOrInsertFunction( + name, llvm::FunctionType::get(returnType, argTypes, false)); +} + +void ensureBuiltinDeclarations(llvm::Module& module) { + auto& context = module.getContext(); + auto* i32Ty = llvm::Type::getInt32Ty(context); + auto* voidTy = llvm::Type::getVoidTy(context); + auto* i8PtrTy = llvm::PointerType::get(context, 0); + + declareFunction(module, "jit_HashAggrResizeVector", voidTy, {i8PtrTy, i32Ty}); + // Decimal extract helpers. + // Sum: (vector, row, group, offset, precision, scale, accumulatorIsNull). + declareFunction( + module, + "jit_HashAggrExtractFinalShortDecimalSum", + voidTy, + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i32Ty}); + declareFunction( + module, + "jit_HashAggrExtractFinalLongDecimalSum", + voidTy, + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i32Ty}); + declareFunction( + module, + "jit_HashAggrExtractPartialShortDecimalSum", + voidTy, + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i32Ty}); + declareFunction( + module, + "jit_HashAggrExtractPartialLongDecimalSum", + voidTy, + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i32Ty}); + + // Avg: (vector, row, group, offset, precision, scale, resultPrecision, + // resultScale). + declareFunction( + module, + "jit_HashAggrExtractFinalShortDecimalAvg", + voidTy, + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i32Ty, i32Ty}); + declareFunction( + module, + "jit_HashAggrExtractFinalLongDecimalAvg", + voidTy, + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i32Ty, i32Ty}); + declareFunction( + module, + "jit_HashAggrExtractPartialShortDecimalAvg", + voidTy, + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i32Ty, i32Ty}); + declareFunction( + module, + "jit_HashAggrExtractPartialLongDecimalAvg", + voidTy, + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i32Ty, i32Ty}); +} + +llvm::Type* llvmType(llvm::IRBuilder<>& builder, HashAggrJitValueKind kind) { + switch (kind) { + case HashAggrJitValueKind::Bool: + case HashAggrJitValueKind::Int8: + return builder.getInt8Ty(); + case HashAggrJitValueKind::Int16: + return builder.getInt16Ty(); + case HashAggrJitValueKind::Int32: + return builder.getInt32Ty(); + case HashAggrJitValueKind::Int64: + return builder.getInt64Ty(); + case HashAggrJitValueKind::Int128: + return builder.getInt128Ty(); + case HashAggrJitValueKind::Float: + return builder.getFloatTy(); + case HashAggrJitValueKind::Double: + return builder.getDoubleTy(); + } + return builder.getInt64Ty(); +} + +bool isFloatKind(HashAggrJitValueKind kind) { + return kind == HashAggrJitValueKind::Float || + kind == HashAggrJitValueKind::Double; +} + +bool supportsRawFlatOutput(HashAggrJitValueKind kind) { + switch (kind) { + case HashAggrJitValueKind::Bool: + case HashAggrJitValueKind::Int8: + case HashAggrJitValueKind::Int16: + case HashAggrJitValueKind::Int32: + case HashAggrJitValueKind::Int64: + case HashAggrJitValueKind::Int128: + case HashAggrJitValueKind::Float: + case HashAggrJitValueKind::Double: + return true; + } + return false; +} + +llvm::Value* loadPointerField( + llvm::IRBuilder<>& builder, + llvm::Value* descriptor, + uint64_t offset, + llvm::Type* pointerType, + llvm::StringRef name); + +llvm::Value* loadScalarOutputValues( + llvm::IRBuilder<>& builder, + llvm::Value* output) { + auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); + return loadPointerField( + builder, output, kScalarOutputValuesOffset, i8PtrTy, "output_values"); +} + +llvm::Value* loadScalarOutputNulls( + llvm::IRBuilder<>& builder, + llvm::Value* output) { + return loadPointerField( + builder, + output, + kScalarOutputNullsOffset, + builder.getInt64Ty()->getPointerTo(), + "output_nulls"); +} + +llvm::Value* loadScalarOutputVector( + llvm::IRBuilder<>& builder, + llvm::Value* output) { + auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); + return loadPointerField( + builder, output, kScalarOutputVectorOffset, i8PtrTy, "output_vector"); +} + +llvm::Value* loadPointerField( + llvm::IRBuilder<>& builder, + llvm::Value* descriptor, + uint64_t offset, + llvm::Type* pointerType, + llvm::StringRef name) { + auto* fieldAddr = builder.CreateConstInBoundsGEP1_64( + builder.getInt8Ty(), descriptor, offset); + auto* fieldPtrPtr = builder.CreatePointerCast(fieldAddr, pointerType->getPointerTo()); + return builder.CreateLoad(pointerType, fieldPtrPtr, name); +} + +llvm::Value* loadScalarInputIndex( + llvm::IRBuilder<>& builder, + llvm::Value* input, + llvm::Value* row) { + auto* i32Ty = builder.getInt32Ty(); + auto* indices = loadPointerField( + builder, + input, + kScalarInputIndicesOffset, + i32Ty->getPointerTo(), + "input_indices"); + return builder.CreateLoad(i32Ty, builder.CreateInBoundsGEP(i32Ty, indices, row)); +} + +llvm::Value* loadScalarInputValues( + llvm::IRBuilder<>& builder, + llvm::Value* input) { + auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); + return loadPointerField( + builder, input, kScalarInputValuesOffset, i8PtrTy, "input_values"); +} + +llvm::Value* loadScalarInputNulls( + llvm::IRBuilder<>& builder, + llvm::Value* input) { + return loadPointerField( + builder, + input, + kScalarInputNullsOffset, + builder.getInt64Ty()->getPointerTo(), + "input_nulls"); +} + +llvm::Value* loadRowInputNulls(llvm::IRBuilder<>& builder, llvm::Value* input) { + return loadPointerField( + builder, + input, + kRowInputNullsOffset, + builder.getInt64Ty()->getPointerTo(), + "row_input_nulls"); +} + +llvm::Value* loadRowInputChild( + llvm::IRBuilder<>& builder, + llvm::Value* input, + int32_t field) { + auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); + auto* children = loadPointerField( + builder, + input, + kRowInputChildrenOffset, + i8PtrTy->getPointerTo(), + "row_input_children"); + auto* childAddr = + builder.CreateConstInBoundsGEP1_64(i8PtrTy, children, field); + return builder.CreateLoad(i8PtrTy, childAddr, "row_input_child"); +} + +llvm::Value* loadScalarInputValue( + llvm::IRBuilder<>& builder, + llvm::Value* input, + llvm::Value* row, + HashAggrJitValueKind kind) { + auto* values = loadScalarInputValues(builder, input); + auto* index = loadScalarInputIndex(builder, input, row); + + if (kind == HashAggrJitValueKind::Bool) { + auto* wordTy = builder.getInt64Ty(); + auto* wordIndex = builder.CreateLShr(index, builder.getInt32(6)); + auto* bitIndex = builder.CreateAnd(index, builder.getInt32(63)); + auto* words = builder.CreatePointerCast(values, wordTy->getPointerTo()); + auto* word = builder.CreateLoad( + wordTy, + builder.CreateInBoundsGEP( + wordTy, + words, + builder.CreateZExt(wordIndex, builder.getInt64Ty()))); + auto* shifted = + builder.CreateLShr(word, builder.CreateZExt(bitIndex, wordTy)); + return builder.CreateZExt( + builder.CreateICmpNE( + builder.CreateAnd(shifted, builder.getInt64(1)), + builder.getInt64(0)), + builder.getInt8Ty()); + } + + auto* type = llvmType(builder, kind); + auto* typedValues = builder.CreatePointerCast(values, type->getPointerTo()); + auto* valueAddr = builder.CreateInBoundsGEP( + type, typedValues, builder.CreateZExt(index, builder.getInt64Ty())); + auto* load = builder.CreateLoad(type, valueAddr); + load->setAlignment(llvm::Align(1)); + return load; +} + +llvm::Value* loadRowOutputNulls(llvm::IRBuilder<>& builder, llvm::Value* output) { + return loadPointerField( + builder, + output, + kRowOutputNullsOffset, + builder.getInt64Ty()->getPointerTo(), + "row_output_nulls"); +} + +llvm::Value* loadRowOutputVector( + llvm::IRBuilder<>& builder, + llvm::Value* output) { + auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); + return loadPointerField( + builder, output, kRowOutputVectorOffset, i8PtrTy, "row_output_vector"); +} + +llvm::Value* loadRowOutputChild( + llvm::IRBuilder<>& builder, + llvm::Value* output, + int32_t field) { + auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); + auto* children = loadPointerField( + builder, + output, + kRowOutputChildrenOffset, + i8PtrTy->getPointerTo(), + "row_output_children"); + auto* childAddr = + builder.CreateConstInBoundsGEP1_64(i8PtrTy, children, field); + return builder.CreateLoad(i8PtrTy, childAddr, "row_output_child"); +} + +void emitOutputNullBit( + llvm::IRBuilder<>& builder, + llvm::Value* nulls, + llvm::Value* row, + llvm::Value* isNull) { + auto* i64Ty = builder.getInt64Ty(); + auto* wordIndex = builder.CreateLShr(row, builder.getInt32(6)); + auto* bitIndex = builder.CreateAnd(row, builder.getInt32(63)); + auto* wordAddr = builder.CreateInBoundsGEP( + i64Ty, nulls, builder.CreateZExt(wordIndex, builder.getInt64Ty())); + auto* word = builder.CreateLoad(i64Ty, wordAddr); + auto* mask = builder.CreateShl( + builder.getInt64(1), builder.CreateZExt(bitIndex, builder.getInt64Ty())); + auto* notNullWord = builder.CreateOr(word, mask); + auto* nullWord = builder.CreateAnd(word, builder.CreateNot(mask)); + auto* isNullBool = isNull->getType()->isIntegerTy(1) + ? isNull + : builder.CreateICmpNE(isNull, builder.getInt8(0)); + builder.CreateStore( + builder.CreateSelect(isNullBool, nullWord, notNullWord), wordAddr); +} + +// Writes a scalar value into a flat output values buffer at 'row'. Bool uses +// bit-packed storage (one bit per row in a uint64 word array), written with the +// same word/mask/select pattern as the null bitmap; all other kinds are +// fixed-width and written with a direct store. +void emitFlatScalarValue( + llvm::IRBuilder<>& builder, + llvm::Value* values, + llvm::Value* row, + HashAggrJitValueKind kind, + llvm::Value* value) { + if (kind == HashAggrJitValueKind::Bool) { + auto* bit = value->getType()->isIntegerTy(1) + ? value + : builder.CreateICmpNE(value, builder.getInt8(0)); + auto* wordTy = builder.getInt64Ty(); + auto* wordIndex = builder.CreateLShr(row, builder.getInt32(6)); + auto* bitIndex = builder.CreateAnd(row, builder.getInt32(63)); + auto* words = builder.CreatePointerCast(values, wordTy->getPointerTo()); + auto* wordAddr = builder.CreateInBoundsGEP( + wordTy, words, builder.CreateZExt(wordIndex, builder.getInt64Ty())); + auto* word = builder.CreateLoad(wordTy, wordAddr); + auto* mask = builder.CreateShl( + builder.getInt64(1), builder.CreateZExt(bitIndex, builder.getInt64Ty())); + auto* trueWord = builder.CreateOr(word, mask); + auto* falseWord = builder.CreateAnd(word, builder.CreateNot(mask)); + builder.CreateStore(builder.CreateSelect(bit, trueWord, falseWord), wordAddr); + return; + } + auto* type = llvmType(builder, kind); + auto* typedValues = builder.CreatePointerCast(values, type->getPointerTo()); + auto* valueAddr = builder.CreateInBoundsGEP( + type, typedValues, builder.CreateZExt(row, builder.getInt64Ty())); + auto* store = builder.CreateStore(value, valueAddr); + store->setAlignment(llvm::Align(1)); +} + +llvm::LoadInst* loadValue( + llvm::IRBuilder<>& builder, + llvm::Value* row, + llvm::Type* type, + int32_t offset) { + auto* addr = builder.CreateConstInBoundsGEP1_64( + builder.getInt8Ty(), row, static_cast(offset)); + auto* castAddr = builder.CreatePointerCast(addr, type->getPointerTo()); + auto* load = builder.CreateLoad(type, castAddr); + load->setAlignment(llvm::Align(1)); + return load; +} + +void storeValue( + llvm::IRBuilder<>& builder, + llvm::Value* row, + llvm::Type* type, + int32_t offset, + llvm::Value* value) { + auto* addr = builder.CreateConstInBoundsGEP1_64( + builder.getInt8Ty(), row, static_cast(offset)); + auto* castAddr = builder.CreatePointerCast(addr, type->getPointerTo()); + auto* store = builder.CreateStore(value, castAddr); + store->setAlignment(llvm::Align(1)); +} + +llvm::Value* castValue( + llvm::IRBuilder<>& builder, + llvm::Value* value, + HashAggrJitValueKind from, + HashAggrJitValueKind to) { + if (from == to) { + return value; + } + auto* toType = llvmType(builder, to); + if (isFloatKind(from) && isFloatKind(to)) { + return builder.CreateFPCast(value, toType); + } + if (!isFloatKind(from) && isFloatKind(to)) { + return builder.CreateSIToFP(value, toType); + } + if (isFloatKind(from) && !isFloatKind(to)) { + return builder.CreateFPToSI(value, toType); + } + return builder.CreateSExtOrTrunc(value, toType); +} + +llvm::Value* isAccumulatorNull( + llvm::IRBuilder<>& builder, + llvm::Value* group, + const HashAggrJitSlot& slot) { + auto* byte = loadValue(builder, group, builder.getInt8Ty(), slot.nullByte); + auto* mask = llvm::ConstantInt::get(builder.getInt8Ty(), slot.nullMask); + return builder.CreateICmpNE( + builder.CreateAnd(byte, mask), builder.getInt8(0)); +} + +void clearAccumulatorNull( + llvm::IRBuilder<>& builder, + llvm::Value* group, + const HashAggrJitSlot& slot) { + auto* byte = loadValue(builder, group, builder.getInt8Ty(), slot.nullByte); + auto* mask = llvm::ConstantInt::get( + builder.getInt8Ty(), static_cast(~slot.nullMask)); + storeValue( + builder, + group, + builder.getInt8Ty(), + slot.nullByte, + builder.CreateAnd(byte, mask)); +} + +void setAccumulatorNull( + llvm::IRBuilder<>& builder, + llvm::Value* group, + const HashAggrJitSlot& slot) { + auto* byte = loadValue(builder, group, builder.getInt8Ty(), slot.nullByte); + auto* mask = llvm::ConstantInt::get(builder.getInt8Ty(), slot.nullMask); + storeValue( + builder, + group, + builder.getInt8Ty(), + slot.nullByte, + builder.CreateOr(byte, mask)); +} + +llvm::Value* +isInputNull(llvm::IRBuilder<>& builder, llvm::Value* nulls, llvm::Value* row) { + auto* i64Ty = builder.getInt64Ty(); + auto* nullWords = builder.CreatePointerCast(nulls, i64Ty->getPointerTo()); + auto* wordIndex = builder.CreateLShr(row, builder.getInt32(6)); + auto* bitIndex = builder.CreateAnd(row, builder.getInt32(63)); + auto* word = builder.CreateLoad( + i64Ty, + builder.CreateInBoundsGEP( + i64Ty, nullWords, builder.CreateZExt(wordIndex, builder.getInt64Ty()))); + auto* shifted = builder.CreateLShr(word, builder.CreateZExt(bitIndex, i64Ty)); + return builder.CreateICmpEQ( + builder.CreateAnd(shifted, builder.getInt64(1)), builder.getInt64(0)); +} + +} // namespace + +HashAggrJitCodegen::HashAggrJitCodegen(llvm::Module& module) : module_(module) { + ensureBuiltinDeclarations(module_); +} + +llvm::Type* HashAggrJitCodegen::llvmType(HashAggrJitValueKind kind) const { + // Qualified call: the unqualified name would resolve to this member (name + // hiding stops lookup at class scope), so qualify to reach the file-local + // free function. + return bytedance::bolt::jit::llvmType(builder(), kind); +} + +llvm::Value* HashAggrJitCodegen::isInputNull( + llvm::Value* nulls, + llvm::Value* row) const { + return bytedance::bolt::jit::isInputNull(builder(), nulls, row); +} + +llvm::Value* HashAggrJitCodegen::isAccumulatorNull( + llvm::Value* group, + const HashAggrJitSlot& slot) const { + return bytedance::bolt::jit::isAccumulatorNull(builder(), group, slot); +} + +void HashAggrJitCodegen::clearAccumulatorNull( + llvm::Value* group, + const HashAggrJitSlot& slot) const { + bytedance::bolt::jit::clearAccumulatorNull(builder(), group, slot); +} + +void HashAggrJitCodegen::setAccumulatorNull( + llvm::Value* group, + const HashAggrJitSlot& slot) const { + bytedance::bolt::jit::setAccumulatorNull(builder(), group, slot); +} + +llvm::LoadInst* HashAggrJitCodegen::loadValue( + llvm::Value* row, + llvm::Type* type, + int32_t offset) const { + return bytedance::bolt::jit::loadValue(builder(), row, type, offset); +} + +void HashAggrJitCodegen::storeValue( + llvm::Value* row, + llvm::Type* type, + int32_t offset, + llvm::Value* value) const { + bytedance::bolt::jit::storeValue(builder(), row, type, offset, value); +} + +llvm::Value* HashAggrJitCodegen::castValue( + llvm::Value* value, + HashAggrJitValueKind from, + HashAggrJitValueKind to) const { + return bytedance::bolt::jit::castValue(builder(), value, from, to); +} + +bool HashAggrJitCodegen::isFloatKind(HashAggrJitValueKind kind) const { + return bytedance::bolt::jit::isFloatKind(kind); +} + +ScalarInputAdapterCodegen::ScalarInputAdapterCodegen( + HashAggrJitCodegen& codegen, + llvm::Value* input) + : codegen_(codegen), input_(input) {} + +llvm::StructType* ScalarInputAdapterCodegen::irRowType( + HashAggrJitValueKind kind) const { + return IRRow::getType(codegen_.builder(), codegen_.llvmType(kind)); +} + +llvm::Value* ScalarInputAdapterCodegen::read( + llvm::Value* row, + HashAggrJitValueKind kind) const { + auto* value = loadScalarInputValue( + codegen_.builder(), input_, row, kind); + // add_dense emits the top-level null guard before invoking aggregate ops. + // Therefore rows reaching ops are non-null; keep the IRRow contract explicit + // without duplicating the null bitmap check in every aggregate. + return IRRow::pack(codegen_.builder(), value, codegen_.builder().getFalse()); +} + +llvm::Value* ScalarInputAdapterCodegen::loadNulls() const { + return loadScalarInputNulls( + codegen_.builder(), input_); +} + +llvm::Value* ScalarInputAdapterCodegen::isNull(llvm::Value* row) const { + return codegen_.isInputNull(loadNulls(), row); +} + +llvm::Value* ScalarInputAdapterCodegen::readRowField( + llvm::Value*, + int32_t, + HashAggrJitValueKind) const { + BOLT_UNSUPPORTED("ScalarInputAdapterCodegen does not support ROW field load"); +} + +llvm::Value* ScalarInputAdapterCodegen::readRowFieldValue( + llvm::Value*, + int32_t, + HashAggrJitValueKind) const { + BOLT_UNSUPPORTED("ScalarInputAdapterCodegen does not support ROW field load"); +} + +RowInputAdapterCodegen::RowInputAdapterCodegen( + HashAggrJitCodegen& codegen, + llvm::Value* input) + : codegen_(codegen), input_(input) {} + +llvm::Value* RowInputAdapterCodegen::loadChild(int32_t field) const { + return loadRowInputChild( + codegen_.builder(), input_, field); +} + +llvm::StructType* RowInputAdapterCodegen::irRowType( + HashAggrJitValueKind kind) const { + return IRRow::getType(codegen_.builder(), codegen_.llvmType(kind)); +} + +llvm::Value* RowInputAdapterCodegen::read(llvm::Value*, HashAggrJitValueKind) + const { + BOLT_UNSUPPORTED("RowInputAdapterCodegen does not support scalar loadValue"); +} + +llvm::Value* RowInputAdapterCodegen::loadNulls() const { + return loadRowInputNulls(codegen_.builder(), input_); +} + +llvm::Value* RowInputAdapterCodegen::isNull(llvm::Value* row) const { + return codegen_.isInputNull(loadNulls(), row); +} + +llvm::Value* RowInputAdapterCodegen::readRowField( + llvm::Value* row, + int32_t field, + HashAggrJitValueKind kind) const { + auto* child = loadChild(field); + auto* value = loadScalarInputValue( + codegen_.builder(), child, row, kind); + return IRRow::pack(codegen_.builder(), value, isRowFieldNull(row, field)); +} + +llvm::Value* RowInputAdapterCodegen::readRowFieldValue( + llvm::Value* row, + int32_t field, + HashAggrJitValueKind kind) const { + // Skips the per-field null check CFG: returns only the raw field value. + // Valid when the field is guaranteed non-null on this path (its null bit is + // not consumed by the aggregate's merge semantics). + auto* child = loadChild(field); + return loadScalarInputValue( + codegen_.builder(), child, row, kind); +} + +llvm::Value* RowInputAdapterCodegen::isRowFieldNull( + llvm::Value* row, + int32_t field) const { + auto* child = loadChild(field); + auto* nulls = + loadScalarInputNulls(codegen_.builder(), child); + auto* hasNulls = codegen_.builder().CreateICmpNE( + nulls, + llvm::ConstantPointerNull::get( + codegen_.builder().getInt64Ty()->getPointerTo())); + auto* function = codegen_.builder().GetInsertBlock()->getParent(); + auto* nullCheckBlock = llvm::BasicBlock::Create( + codegen_.module().getContext(), "row_field_null_check", function); + auto* doneBlock = llvm::BasicBlock::Create( + codegen_.module().getContext(), "row_field_null_done", function); + codegen_.builder().CreateCondBr(hasNulls, nullCheckBlock, doneBlock); + auto* noNullsEnd = codegen_.builder().GetInsertBlock(); + + codegen_.builder().SetInsertPoint(nullCheckBlock); + auto* index = loadScalarInputIndex( + codegen_.builder(), child, row); + auto* isNull = codegen_.isInputNull(nulls, index); + codegen_.builder().CreateBr(doneBlock); + auto* nullCheckEnd = codegen_.builder().GetInsertBlock(); + + codegen_.builder().SetInsertPoint(doneBlock); + auto* result = codegen_.builder().CreatePHI( + codegen_.builder().getInt1Ty(), 2, "row_field_is_null"); + result->addIncoming(codegen_.builder().getFalse(), noNullsEnd); + result->addIncoming(isNull, nullCheckEnd); + return result; +} + +ScalarOutputAdapterCodegen::ScalarOutputAdapterCodegen( + HashAggrJitCodegen& codegen, + llvm::Value* output) + : codegen_(codegen), output_(output) {} + +llvm::Value* ScalarOutputAdapterCodegen::vector() const { + return loadScalarOutputVector( + codegen_.builder(), output_); +} + +void ScalarOutputAdapterCodegen::resize(llvm::Value* size) const { + codegen_.builder().CreateCall( + codegen_.module().getFunction("jit_HashAggrResizeVector"), + {vector(), size}); +} + +void ScalarOutputAdapterCodegen::write( + llvm::Value* row, + HashAggrJitValueKind kind, + llvm::Value* irRow) const { + auto& builder = codegen_.builder(); + BOLT_CHECK( + supportsRawFlatOutput(kind), + "Unsupported raw flat scalar output kind for HashAggrJit"); + auto* value = IRRow::getValue(builder, irRow); + auto* isNull = IRRow::getIsNull(builder, irRow); + auto* values = loadScalarOutputValues(builder, output_); + emitFlatScalarValue(builder, values, row, kind, value); + auto* nulls = loadScalarOutputNulls(builder, output_); + emitOutputNullBit(builder, nulls, row, isNull); +} + +void ScalarOutputAdapterCodegen::writeField( + llvm::Value*, + int32_t, + HashAggrJitValueKind, + llvm::Value*) const { + BOLT_UNSUPPORTED("ScalarOutputAdapterCodegen does not support ROW field write"); +} + +void ScalarOutputAdapterCodegen::writeNull( + llvm::Value* row, + llvm::Value* isNull) const { + auto* nulls = loadScalarOutputNulls( + codegen_.builder(), output_); + emitOutputNullBit( + codegen_.builder(), nulls, row, isNull); +} + +RowOutputAdapterCodegen::RowOutputAdapterCodegen( + HashAggrJitCodegen& codegen, + llvm::Value* output) + : codegen_(codegen), output_(output) {} + +llvm::Value* RowOutputAdapterCodegen::loadChild(int32_t field) const { + return loadRowOutputChild( + codegen_.builder(), output_, field); +} + +llvm::Value* RowOutputAdapterCodegen::vector() const { + return loadRowOutputVector(codegen_.builder(), output_); +} + +void RowOutputAdapterCodegen::resize(llvm::Value* size) const { + codegen_.builder().CreateCall( + codegen_.module().getFunction("jit_HashAggrResizeVector"), + {vector(), size}); +} + +void RowOutputAdapterCodegen::write( + llvm::Value*, + HashAggrJitValueKind, + llvm::Value*) const { + BOLT_UNSUPPORTED("RowOutputAdapterCodegen does not support scalar write"); +} + +void RowOutputAdapterCodegen::writeField( + llvm::Value* row, + int32_t field, + HashAggrJitValueKind kind, + llvm::Value* irRow) const { + auto& builder = codegen_.builder(); + BOLT_CHECK( + supportsRawFlatOutput(kind), + "Unsupported raw ROW output field kind for HashAggrJit"); + auto* child = loadChild(field); + auto* value = IRRow::getValue(builder, irRow); + auto* isNull = IRRow::getIsNull(builder, irRow); + auto* values = loadScalarOutputValues(builder, child); + emitFlatScalarValue(builder, values, row, kind, value); + auto* nulls = loadScalarOutputNulls(builder, child); + emitOutputNullBit(builder, nulls, row, isNull); +} + +void RowOutputAdapterCodegen::writeNull( + llvm::Value* row, + llvm::Value* isNull) const { + auto* nulls = loadRowOutputNulls( + codegen_.builder(), output_); + emitOutputNullBit( + codegen_.builder(), nulls, row, isNull); +} + +namespace { + +char hashAggrJitRuntimeShapeName(HashAggrJitRuntimeShape shape) { + switch (shape) { + case HashAggrJitRuntimeShape::Scalar: + return 's'; + case HashAggrJitRuntimeShape::Row: + return 'r'; + } + return 'u'; +} + +bool usesRowInputRuntime(const HashAggrJitSlot& slot) { + return slot.desc.inputShape() == HashAggrJitRuntimeShape::Row; +} + +bool usesRowOutputRuntime(const HashAggrJitSlot& slot) { + return slot.desc.outputShape() == HashAggrJitRuntimeShape::Row; +} + +bool genAddDenseIR( + llvm::Module& module, + const std::string& fn, + const std::vector& slots, + bool checkInputNulls); + +bool genInitIR( + llvm::Module& module, + const std::string& fn, + const std::vector& slots) { + auto& context = module.getContext(); + llvm::IRBuilder<> builder(context); + HashAggrJitCodegen codegen(module); + codegen.setBuilder(&builder); + auto* voidTy = builder.getVoidTy(); + auto* i8PtrTy = llvm::PointerType::get(context, 0); + auto* i8PtrPtrTy = i8PtrTy->getPointerTo(); + auto* i32Ty = builder.getInt32Ty(); + auto* funcTy = llvm::FunctionType::get(voidTy, {i8PtrPtrTy, i32Ty}, false); + auto* func = llvm::Function::Create( + funcTy, llvm::Function::ExternalLinkage, fn, module); + auto argIt = func->arg_begin(); + llvm::Value* newGroups = &*argIt++; + newGroups->setName("new_groups"); + llvm::Value* numNewGroups = &*argIt++; + numNewGroups->setName("num_new_groups"); + + auto* entry = llvm::BasicBlock::Create(context, "entry", func); + auto* loop = llvm::BasicBlock::Create(context, "loop", func); + auto* end = llvm::BasicBlock::Create(context, "end", func); + builder.SetInsertPoint(entry); + builder.CreateCondBr( + builder.CreateICmpSLE(numNewGroups, builder.getInt32(0)), end, loop); + + builder.SetInsertPoint(loop); + auto* index = builder.CreatePHI(i32Ty, 2, "idx"); + index->addIncoming(builder.getInt32(0), entry); + auto* groupAddr = builder.CreateInBoundsGEP(i8PtrTy, newGroups, index); + auto* group = builder.CreateLoad(i8PtrTy, groupAddr); + + for (const auto& slot : slots) { + if (slot.desc.ops == nullptr || slot.desc.ops->initGroup == nullptr) { + return false; + } + slot.desc.ops->initGroup(codegen, group, slot); + } + + auto* next = builder.CreateAdd(index, builder.getInt32(1)); + index->addIncoming(next, builder.GetInsertBlock()); + builder.CreateCondBr(builder.CreateICmpSLT(next, numNewGroups), loop, end); + + builder.SetInsertPoint(end); + builder.CreateRetVoid(); + + return true; +} + +bool genAddDenseIR( + llvm::Module& module, + const std::string& fn, + const std::vector& slots, + bool checkInputNulls) { + auto& context = module.getContext(); + llvm::IRBuilder<> builder(context); + HashAggrJitCodegen codegen(module); + codegen.setBuilder(&builder); + auto* voidTy = builder.getVoidTy(); + auto* i8PtrTy = llvm::PointerType::get(context, 0); + auto* i8PtrPtrTy = i8PtrTy->getPointerTo(); + auto* i32Ty = builder.getInt32Ty(); + auto* funcTy = llvm::FunctionType::get(voidTy, {i8PtrPtrTy, i32Ty, i8PtrPtrTy}, false); + auto* func = llvm::Function::Create(funcTy, llvm::Function::ExternalLinkage, fn, module); + auto argIt = func->arg_begin(); + llvm::Value* groups = &*argIt++; + groups->setName("groups"); + llvm::Value* numRows = &*argIt++; + numRows->setName("num_rows"); + llvm::Value* inputRuntimes = &*argIt++; + inputRuntimes->setName("input_runtimes"); + + auto* entry = llvm::BasicBlock::Create(context, "entry", func); + auto* loop = llvm::BasicBlock::Create(context, "loop", func); + auto* end = llvm::BasicBlock::Create(context, "end", func); + builder.SetInsertPoint(entry); + builder.CreateCondBr(builder.CreateICmpSLE(numRows, builder.getInt32(0)), end, loop); + + builder.SetInsertPoint(loop); + auto* row = builder.CreatePHI(i32Ty, 2, "row"); + row->addIncoming(builder.getInt32(0), entry); + auto* groupAddr = builder.CreateInBoundsGEP(i8PtrTy, groups, row); + auto* group = builder.CreateLoad(i8PtrTy, groupAddr); + + for (auto i = 0; i < slots.size(); ++i) { + const auto& slot = slots[i]; + auto* updateBlock = llvm::BasicBlock::Create(context, "slot_update", func, end); + auto* nextBlock = llvm::BasicBlock::Create(context, "slot_next", func, end); + auto* inputAddr = + builder.CreateConstInBoundsGEP1_64(i8PtrTy, inputRuntimes, i); + auto* inputRuntime = builder.CreateLoad(i8PtrTy, inputAddr); + std::unique_ptr input; + if (usesRowInputRuntime(slot)) { + input = std::make_unique(codegen, inputRuntime); + } else { + input = + std::make_unique(codegen, inputRuntime); + } + if (checkInputNulls && !slot.desc.isCountStar()) { + auto* nulls = input->loadNulls(); + auto* nullCheckBlock = + llvm::BasicBlock::Create(context, "slot_null_check", func, end); + auto* hasNulls = builder.CreateICmpNE( + nulls, + llvm::ConstantPointerNull::get(builder.getInt64Ty()->getPointerTo())); + builder.CreateCondBr(hasNulls, nullCheckBlock, updateBlock); + + builder.SetInsertPoint(nullCheckBlock); + auto* isNull = codegen.isInputNull(nulls, row); + builder.CreateCondBr(isNull, nextBlock, updateBlock); + } else { + builder.CreateBr(updateBlock); + } + + builder.SetInsertPoint(updateBlock); + if (slot.desc.ops == nullptr) { + return false; + } + auto* addFn = !slot.desc.isRawInput() + ? slot.desc.ops->addIntermediateResults + : slot.desc.ops->addRawInput; + if (addFn == nullptr) { + return false; + } + addFn(codegen, group, *input, row, slot, nextBlock); + builder.CreateBr(nextBlock); + builder.SetInsertPoint(nextBlock); + } + + auto* next = builder.CreateAdd(row, builder.getInt32(1)); + row->addIncoming(next, builder.GetInsertBlock()); + builder.CreateCondBr(builder.CreateICmpSLT(next, numRows), loop, end); + + builder.SetInsertPoint(end); + builder.CreateRetVoid(); + + return true; +} + +bool genExtractIR( + llvm::Module& module, + const std::string& fn, + const std::vector& slots) { + auto& context = module.getContext(); + llvm::IRBuilder<> builder(context); + HashAggrJitCodegen codegen(module); + codegen.setBuilder(&builder); + auto* voidTy = builder.getVoidTy(); + auto* i8PtrTy = llvm::PointerType::get(context, 0); + auto* i8PtrPtrTy = i8PtrTy->getPointerTo(); + auto* i32Ty = builder.getInt32Ty(); + auto* funcTy = llvm::FunctionType::get(voidTy, {i8PtrPtrTy, i32Ty, i8PtrPtrTy}, false); + auto* func = llvm::Function::Create(funcTy, llvm::Function::ExternalLinkage, fn, module); + auto argIt = func->arg_begin(); + llvm::Value* groups = &*argIt++; + groups->setName("groups"); + llvm::Value* numGroups = &*argIt++; + numGroups->setName("num_groups"); + llvm::Value* resultVectors = &*argIt++; + resultVectors->setName("result_vectors"); + + auto* entry = llvm::BasicBlock::Create(context, "entry", func); + auto* loop = llvm::BasicBlock::Create(context, "loop", func); + auto* end = llvm::BasicBlock::Create(context, "end", func); + builder.SetInsertPoint(entry); + builder.CreateCondBr(builder.CreateICmpSLE(numGroups, builder.getInt32(0)), end, loop); + + builder.SetInsertPoint(loop); + auto* row = builder.CreatePHI(i32Ty, 2, "row"); + row->addIncoming(builder.getInt32(0), entry); + auto* groupAddr = builder.CreateInBoundsGEP(i8PtrTy, groups, row); + auto* group = builder.CreateLoad(i8PtrTy, groupAddr); + + for (auto i = 0; i < slots.size(); ++i) { + const auto& slot = slots[i]; + if (slot.desc.ops == nullptr) { + continue; + } + auto* outputAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, resultVectors, i); + auto* outputRuntime = builder.CreateLoad(i8PtrTy, outputAddr); + std::unique_ptr output; + if (usesRowOutputRuntime(slot)) { + output = std::make_unique(codegen, outputRuntime); + } else { + output = + std::make_unique(codegen, outputRuntime); + } + auto* extractFn = slot.desc.context.isPartialOutput + ? slot.desc.ops->extractAccumulators + : slot.desc.ops->extractResults; + if (extractFn == nullptr) { + return false; + } + extractFn( + codegen, group, slot, HashAggrJitExtractTarget{*output, row}); + } + + auto* next = builder.CreateAdd(row, builder.getInt32(1)); + row->addIncoming(next, builder.GetInsertBlock()); + builder.CreateCondBr(builder.CreateICmpSLT(next, numGroups), loop, end); + + builder.SetInsertPoint(end); + builder.CreateRetVoid(); + + return true; +} + +} // namespace + +std::string HashAggrJitSlot::getDescription() const { + const auto inputTypes = desc.context.inputTypes(); + std::ostringstream inputs; + inputs << "["; + for (size_t i = 0; i < inputTypes.size(); ++i) { + if (i > 0) { + inputs << ","; + } + inputs << inputTypes[i]->toString(); + } + inputs << "]"; + + // NOTE: nullByte/nullMask MUST be part of the description because they are + // baked into the generated IR as compile-time constants (see + // clearAccumulatorNull/setAccumulatorNull). The description drives the JIT + // module cache key (functionName_); omitting them lets two slots that share + // the same aggregate semantics and accumulator offset but have a different + // null-bit layout reuse the same compiled code, which would read-modify-write + // the wrong null bit and corrupt a neighboring grouping key's null flag. + return fmt::format( + "{}_raw{}_partial{}({})->{}@{}@nb{}@nm{}", + hashAggrJitKindName(desc.kind), + desc.context.isRawInput, + desc.context.isPartialOutput, + inputs.str(), + desc.context.outputType()->toString(), + offset, + nullByte, + nullMask); +} + +HashAggrJitChunk::HashAggrJitChunk(std::vector slots) + : slots_(std::move(slots)) { + const auto description = getDescription(); + functionName_ = fmt::format( + "jit_hashaggr_v2_n{}_h{:016x}", + slots_.size(), + bits::hashBytes(1, description.data(), description.size())); +} + +std::string HashAggrJitChunk::getDescription() const { + std::ostringstream out; + for (const auto& slot : slots_) { + out << slot.getDescription() << ";"; + } + return out.str(); +} + +std::string hashAggrJitKindName(HashAggrJitKind kind) { + switch (kind) { + case HashAggrJitKind::Count: + return "count"; + case HashAggrJitKind::Sum: + return "sum"; + case HashAggrJitKind::DecimalSum: + return "decimal_sum"; + case HashAggrJitKind::Min: + return "min"; + case HashAggrJitKind::Max: + return "max"; + case HashAggrJitKind::Avg: + return "avg"; + case HashAggrJitKind::DecimalAvg: + return "decimal_avg"; + } + return "unknown"; +} + +std::string hashAggrJitValueKindName(HashAggrJitValueKind kind) { + switch (kind) { + case HashAggrJitValueKind::Bool: + return "bool"; + case HashAggrJitValueKind::Int8: + return "i8"; + case HashAggrJitValueKind::Int16: + return "i16"; + case HashAggrJitValueKind::Int32: + return "i32"; + case HashAggrJitValueKind::Int64: + return "i64"; + case HashAggrJitValueKind::Int128: + return "i128"; + case HashAggrJitValueKind::Float: + return "f32"; + case HashAggrJitValueKind::Double: + return "f64"; + } + return "unknown"; +} + +std::optional hashAggrJitValueKind(TypeKind kind) { + switch (kind) { + case TypeKind::BOOLEAN: + return HashAggrJitValueKind::Bool; + case TypeKind::TINYINT: + return HashAggrJitValueKind::Int8; + case TypeKind::SMALLINT: + return HashAggrJitValueKind::Int16; + case TypeKind::INTEGER: + return HashAggrJitValueKind::Int32; + case TypeKind::BIGINT: + return HashAggrJitValueKind::Int64; + case TypeKind::HUGEINT: + return HashAggrJitValueKind::Int128; + case TypeKind::REAL: + return HashAggrJitValueKind::Float; + case TypeKind::DOUBLE: + return HashAggrJitValueKind::Double; + default: + return std::nullopt; + } +} + +bool isHashAggrJitSupportedType(TypeKind kind) { + switch (kind) { + case TypeKind::BOOLEAN: + case TypeKind::TINYINT: + case TypeKind::SMALLINT: + case TypeKind::INTEGER: + case TypeKind::BIGINT: + case TypeKind::REAL: + case TypeKind::DOUBLE: + return true; + default: + return false; + } +} + +bool HashAggrJitChunk::codegen() { + if (ready_.load(std::memory_order_acquire)) { + return true; + } + auto* jit = ThrustJITv2::getInstance(); + if (jit == nullptr) { + return false; + } + const auto& moduleKey = functionName_; + const auto initFn = functionName_ + "_init"; + const auto addFn = functionName_ + "_add_dense"; + const auto addNoNullFn = functionName_ + "_add_dense_no_null"; + const auto extractFn = functionName_ + "_extract"; + module_ = jit->CompileModule( + [&](llvm::Module& module) { + const bool ok = genInitIR(module, initFn, slots_) && + genAddDenseIR(module, addFn, slots_, true) && + genAddDenseIR(module, addNoNullFn, slots_, false) && + genExtractIR(module, extractFn, slots_); + const bool hasError = !ok; + logHashAggrJitFunctionIR(module, moduleKey, initFn, "init", hasError); + logHashAggrJitFunctionIR(module, moduleKey, addFn, "add_dense", hasError); + logHashAggrJitFunctionIR( + module, + moduleKey, + addNoNullFn, + "add_dense_no_null", + hasError); + logHashAggrJitFunctionIR( + module, moduleKey, extractFn, "extract", hasError); + return hasError; + }, + moduleKey); + if (!module_) { + return false; + } + init_ = reinterpret_cast(module_->getFuncPtr(initFn)); + addDense_ = reinterpret_cast(module_->getFuncPtr(addFn)); + addDenseNoNull_ = reinterpret_cast( + module_->getFuncPtr(addNoNullFn)); + extract_ = reinterpret_cast(module_->getFuncPtr(extractFn)); + if (init_ == nullptr || addDense_ == nullptr || addDenseNoNull_ == nullptr || + extract_ == nullptr) { + return false; + } + // Publish all function pointers before flipping ready_ so the query thread + // observing isCodegenReady()==true also sees fully-initialized pointers. + ready_.store(true, std::memory_order_release); + return true; +} + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h new file mode 100644 index 000000000..148070882 --- /dev/null +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -0,0 +1,334 @@ +#pragma once + +#ifdef ENABLE_BOLT_JIT + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "bolt/jit/CompiledModule.h" +#include "bolt/jit/aggregation/HashAggrJitTypes.h" +#include "bolt/type/Type.h" + +namespace bytedance::bolt::jit { + +class HashAggrJitCodegen; +class InputAdapterCodegen; +class OutputAdapterCodegen; +struct HashAggrJitExtractTarget; + +struct HashAggrJitOps { + using CreateFn = + void (*)(HashAggrJitCodegen&, llvm::Value* group, const HashAggrJitSlot&); + using AddFn = void (*)( + HashAggrJitCodegen&, + llvm::Value* group, + const InputAdapterCodegen& input, + llvm::Value* row, + const HashAggrJitSlot&, + llvm::BasicBlock* nextBlock); + using ExtractFn = void (*)( + HashAggrJitCodegen&, + llvm::Value* group, + const HashAggrJitSlot&, + const HashAggrJitExtractTarget&); + + CreateFn initGroup; + AddFn addRawInput; + AddFn addIntermediateResults; + // Writes the intermediate (partial) accumulator state to the output, mirroring + // the non-JIT extractAccumulators path. + ExtractFn extractAccumulators; + // Writes the final aggregate result to the output, mirroring the non-JIT + // extractValues/extractResults path. + ExtractFn extractResults; +}; + +struct HashAggrJitExtractTarget { + // Codegen adapter for the destination output vector to write results into. + const OutputAdapterCodegen& output; + // The target row index (runtime llvm::Value) to write the extracted result. + llvm::Value* row; +}; + +class IRRow { + public: + // Framework-level invariant: IRRow = {T, i1}. 'valueType' is owned by + // aggregate semantics; the null bit is always field 1. + static llvm::StructType* getType( + llvm::IRBuilder<>& builder, + llvm::Type* valueType) { + return llvm::StructType::get(valueType, builder.getInt1Ty()); + } + + static llvm::Value* getValue(llvm::IRBuilder<>& builder, llvm::Value* row) { + return builder.CreateExtractValue(row, {0}); + } + + static llvm::Value* getIsNull(llvm::IRBuilder<>& builder, llvm::Value* row) { + return builder.CreateExtractValue(row, {1}); + } + + static llvm::Value* + pack(llvm::IRBuilder<>& builder, llvm::Value* value, llvm::Value* isNull) { + auto* rowType = getType(builder, value->getType()); + auto* withValue = + builder.CreateInsertValue(llvm::UndefValue::get(rowType), value, {0}); + return builder.CreateInsertValue(withValue, isNull, {1}); + } + + static llvm::Value* + withValue(llvm::IRBuilder<>& builder, llvm::Value* row, llvm::Value* value) { + return builder.CreateInsertValue(row, value, {0}); + } + + static llvm::Value* withIsNull( + llvm::IRBuilder<>& builder, + llvm::Value* row, + llvm::Value* isNull) { + return builder.CreateInsertValue(row, isNull, {1}); + } + + // Nested value access for aggregate-owned composite payloads, e.g. + // IRRow<{double, i64}> = {{double, i64}, i1}. + static llvm::Value* + getValueField(llvm::IRBuilder<>& builder, llvm::Value* row, unsigned field) { + return builder.CreateExtractValue(row, {0, field}); + } +}; + +class InputAdapterCodegen { + public: + virtual ~InputAdapterCodegen() = default; + + virtual llvm::StructType* irRowType(HashAggrJitValueKind kind) const = 0; + virtual llvm::Value* read(llvm::Value* row, HashAggrJitValueKind kind) + const = 0; + virtual llvm::Value* loadNulls() const = 0; + virtual llvm::Value* isNull(llvm::Value* row) const = 0; + virtual llvm::Value* readRowField( + llvm::Value* row, + int32_t field, + HashAggrJitValueKind kind) const = 0; + // Reads only the raw value of a ROW child, skipping the per-field null check + // CFG. Use when the framework guarantees the field is non-null on this path + // (i.e. the field's null bit is not consumed by the aggregate semantics). + virtual llvm::Value* readRowFieldValue( + llvm::Value* row, + int32_t field, + HashAggrJitValueKind kind) const = 0; +}; + +class ScalarInputAdapterCodegen final : public InputAdapterCodegen { + public: + ScalarInputAdapterCodegen(HashAggrJitCodegen& codegen, llvm::Value* input); + + llvm::StructType* irRowType(HashAggrJitValueKind kind) const override; + llvm::Value* read(llvm::Value* row, HashAggrJitValueKind kind) const override; + llvm::Value* loadNulls() const override; + llvm::Value* isNull(llvm::Value* row) const override; + llvm::Value* readRowField(llvm::Value*, int32_t, HashAggrJitValueKind) + const override; + llvm::Value* readRowFieldValue(llvm::Value*, int32_t, HashAggrJitValueKind) + const override; + + private: + HashAggrJitCodegen& codegen_; + llvm::Value* input_; +}; + +class RowInputAdapterCodegen final : public InputAdapterCodegen { + public: + RowInputAdapterCodegen(HashAggrJitCodegen& codegen, llvm::Value* input); + + llvm::StructType* irRowType(HashAggrJitValueKind kind) const override; + llvm::Value* read(llvm::Value*, HashAggrJitValueKind) const override; + llvm::Value* loadNulls() const override; + llvm::Value* isNull(llvm::Value* row) const override; + llvm::Value* readRowField( + llvm::Value* row, + int32_t field, + HashAggrJitValueKind kind) const override; + llvm::Value* readRowFieldValue( + llvm::Value* row, + int32_t field, + HashAggrJitValueKind kind) const override; + + private: + llvm::Value* loadChild(int32_t field) const; + llvm::Value* isRowFieldNull(llvm::Value* row, int32_t field) const; + + HashAggrJitCodegen& codegen_; + llvm::Value* input_; +}; + +class OutputAdapterCodegen { + public: + virtual ~OutputAdapterCodegen() = default; + + virtual llvm::Value* vector() const = 0; + virtual void resize(llvm::Value* size) const = 0; + virtual void write( + llvm::Value* row, + HashAggrJitValueKind kind, + llvm::Value* irRow) const = 0; + virtual void writeField( + llvm::Value* row, + int32_t field, + HashAggrJitValueKind kind, + llvm::Value* irRow) const = 0; + virtual void writeNull(llvm::Value* row, llvm::Value* isNull) const = 0; +}; + +class ScalarOutputAdapterCodegen final : public OutputAdapterCodegen { + public: + ScalarOutputAdapterCodegen(HashAggrJitCodegen& codegen, llvm::Value* output); + + llvm::Value* vector() const override; + void resize(llvm::Value* size) const override; + void write( + llvm::Value* row, + HashAggrJitValueKind kind, + llvm::Value* irRow) const override; + void writeField(llvm::Value*, int32_t, HashAggrJitValueKind, llvm::Value*) + const override; + void writeNull(llvm::Value* row, llvm::Value* isNull) const override; + + private: + HashAggrJitCodegen& codegen_; + llvm::Value* output_; +}; + +class RowOutputAdapterCodegen final : public OutputAdapterCodegen { + public: + RowOutputAdapterCodegen(HashAggrJitCodegen& codegen, llvm::Value* output); + + llvm::Value* vector() const override; + void resize(llvm::Value* size) const override; + void write(llvm::Value*, HashAggrJitValueKind, llvm::Value*) const override; + void writeField( + llvm::Value* row, + int32_t field, + HashAggrJitValueKind kind, + llvm::Value* irRow) const override; + void writeNull(llvm::Value* row, llvm::Value* isNull) const override; + + private: + llvm::Value* loadChild(int32_t field) const; + + HashAggrJitCodegen& codegen_; + llvm::Value* output_; +}; + +class HashAggrJitCodegen { + public: + explicit HashAggrJitCodegen(llvm::Module& module); + + llvm::Module& module() const { + return module_; + } + + llvm::IRBuilder<>& builder() const { + return *builder_; + } + + void setBuilder(llvm::IRBuilder<>* builder) { + builder_ = builder; + } + + llvm::Type* llvmType(HashAggrJitValueKind kind) const; + llvm::Value* isInputNull(llvm::Value* nulls, llvm::Value* row) const; + llvm::Value* isAccumulatorNull( + llvm::Value* group, + const HashAggrJitSlot& slot) const; + void clearAccumulatorNull(llvm::Value* group, const HashAggrJitSlot& slot) + const; + void setAccumulatorNull(llvm::Value* group, const HashAggrJitSlot& slot) + const; + llvm::LoadInst* loadValue(llvm::Value* row, llvm::Type* type, int32_t offset) + const; + void storeValue( + llvm::Value* row, + llvm::Type* type, + int32_t offset, + llvm::Value* value) const; + llvm::Value* castValue( + llvm::Value* value, + HashAggrJitValueKind from, + HashAggrJitValueKind to) const; + bool isFloatKind(HashAggrJitValueKind kind) const; + + private: + llvm::Module& module_; + llvm::IRBuilder<>* builder_{nullptr}; +}; + +using HashAggrJitAddDenseFunc = + void (*)(char** groups, int32_t numRows, char** inputRuntimes); +using HashAggrJitInitFunc = void (*)(char** newGroups, int32_t numNewGroups); +using HashAggrJitExtractFunc = void (*)(char** groups, int32_t numGroups, char** resultVectors); + +class HashAggrJitChunk { + public: + explicit HashAggrJitChunk(std::vector slots); + + bool codegen(); + + bool isCodegenReady() const { + return ready_.load(std::memory_order_acquire); + } + + void init(char** newGroups, int32_t numNewGroups) const { + init_(newGroups, numNewGroups); + } + + void addDense( + char** groups, + int32_t numRows, + char** inputRuntimes, + bool inputsMayHaveNulls) const { + if (!inputsMayHaveNulls && addDenseNoNull_ != nullptr) { + addDenseNoNull_(groups, numRows, inputRuntimes); + return; + } + addDense_(groups, numRows, inputRuntimes); + } + + void extract(char** groups, int32_t numGroups, char** resultVectors) const { + extract_(groups, numGroups, resultVectors); + } + + const std::vector& slots() const { + return slots_; + } + + std::string getDescription() const; + + const std::string& functionName() const { + return functionName_; + } + + private: + std::vector slots_; + std::string functionName_; + CompiledModuleSP module_; + HashAggrJitInitFunc init_{nullptr}; + HashAggrJitAddDenseFunc addDense_{nullptr}; + HashAggrJitAddDenseFunc addDenseNoNull_{nullptr}; + HashAggrJitExtractFunc extract_{nullptr}; + // Published last by codegen() (release) and read by isCodegenReady() + // (acquire). Lets the query thread fall back to non-JIT while background + // compilation is still in progress, then switch to JIT once ready. + std::atomic ready_{false}; +}; + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/HashAggrJitDecimalState.h b/bolt/jit/aggregation/HashAggrJitDecimalState.h new file mode 100644 index 000000000..562daf594 --- /dev/null +++ b/bolt/jit/aggregation/HashAggrJitDecimalState.h @@ -0,0 +1,25 @@ +#pragma once + +#ifdef ENABLE_BOLT_JIT + +#include +#include + +#include "bolt/functions/lib/aggregates/DecimalAccumulatorLayout.h" + +namespace bytedance::bolt::jit { + +// JIT-internal decimal accumulator layouts. These alias the shared POD layout +// bases that the non-JIT accumulators (DecimalSum / LongDecimalWithOverflowState) +// also derive from, so the JIT and non-JIT in-memory layouts stay in sync by +// construction (no mirrored copy to drift). The codegen / extract runtime read +// fields via offsetof on these aliases. +using JitDecimalSumState = functions::aggregate::DecimalSumAccumulatorLayout; +using JitDecimalAvgState = functions::aggregate::LongDecimalWithOverflowLayout; + +static_assert(std::is_standard_layout_v); +static_assert(std::is_standard_layout_v); + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/HashAggrJitTypes.h b/bolt/jit/aggregation/HashAggrJitTypes.h new file mode 100644 index 000000000..982acbca2 --- /dev/null +++ b/bolt/jit/aggregation/HashAggrJitTypes.h @@ -0,0 +1,216 @@ +#pragma once + +#ifdef ENABLE_BOLT_JIT + +#include +#include +#include +#include + +#include "bolt/type/Type.h" + +// Lightweight, LLVM-free metadata shared between aggregate functions (which +// only produce a HashAggrJitDescriptor) and the JIT codegen layer. Keeping this +// header free of lets Aggregate.h and other non-JIT translation +// units depend on the JIT planning interface without pulling in the heavy LLVM +// IR headers. The codegen-only declarations (HashAggrJitOps function-pointer +// table, HashAggrJitCodegen, HashAggrJitChunk) live in HashAggrJit.h. + +namespace bytedance::bolt::jit { + +// Runtime scalar input consumed by JIT add_dense functions. 'indices' maps the +// add_dense row to the scalar value row. The owner decides the indexing +// contract for 'nulls': top-level scalar inputs pass row-indexed nulls, while +// ROW child scalar inputs pass child/base-indexed nulls and the row adapter +// applies 'indices' before checking the bit. +struct HashAggrJitScalarInputRuntime { + const void* values{nullptr}; + const int32_t* indices{nullptr}; + const uint64_t* nulls{nullptr}; +}; + +// Runtime ROW input. ROW itself has no value/indices wrapping in the generated +// IR; children are scalar runtimes. Current JIT merge inputs only require +// row-of-scalars, so recursive ROW children are intentionally not represented. +struct HashAggrJitRowInputRuntime { + const uint64_t* nulls{nullptr}; + const HashAggrJitScalarInputRuntime* const* children{nullptr}; + int32_t numChildren{0}; +}; + +// Shape-less runtime input. The generated code knows at compile time whether a +// slot reads a scalar or row input and selects the corresponding union member +// through InputAdapterCodegen. +union HashAggrJitInputRuntime { + HashAggrJitInputRuntime() : scalar{} {} + + HashAggrJitScalarInputRuntime scalar; + HashAggrJitRowInputRuntime row; +}; + +// Runtime scalar output consumed by JIT extract functions. Primitive flat +// outputs write values/null bits directly from generated IR; outputs that need +// vector semantics keep using helper fallbacks via 'vector'. +struct HashAggrJitScalarOutputRuntime { + void* values{nullptr}; + uint64_t* nulls{nullptr}; + void* vector{nullptr}; +}; + +// Runtime ROW output. Current JIT partial outputs only require row-of-scalars, +// so recursive ROW children are intentionally not represented. 'vector' keeps +// the top-level vector available for helper based complex writes (e.g. decimal +// partial extract), while child scalar runtimes expose raw field buffers for +// direct generated-IR stores. +struct HashAggrJitRowOutputRuntime { + uint64_t* nulls{nullptr}; + HashAggrJitScalarOutputRuntime* const* children{nullptr}; + int32_t numChildren{0}; + void* vector{nullptr}; +}; + +// Shape-less runtime output. The generated extract code knows at compile time +// whether a slot writes a scalar or row output and selects the corresponding +// union member through OutputAdapterCodegen. +union HashAggrJitOutputRuntime { + HashAggrJitOutputRuntime() : scalar{} {} + + HashAggrJitScalarOutputRuntime scalar; + HashAggrJitRowOutputRuntime row; +}; + +// Stage-agnostic, absolute description of an aggregate's types. The three type +// fields below always hold the same values regardless of which stage +// (raw/intermediate/partial/final) this context represents; the active stage is +// selected purely by the isRawInput/isPartialOutput flags. The inputTypes() and +// outputType() accessors derive the stage-specific view from those flags, so a +// flag flip (e.g. by a companion function) automatically yields a consistent +// input/output type without any separate type-rewrite step. +struct HashAggrJitPlanContext { + bool isRawInput{false}; + bool isPartialOutput{false}; + // Original (raw) input argument types. Empty for count(*). + std::vector rawInputTypes; + // The intermediate accumulator type (i.e. the partial output / merge input). + TypePtr intermediateType; + // The final aggregate result type. + TypePtr resultType; + + // Stage-derived input view: raw inputs when reading raw input, otherwise the + // single intermediate accumulator type. + std::vector inputTypes() const { + if (isRawInput) { + return rawInputTypes; + } + return {intermediateType}; + } + + // Stage-derived output view: the intermediate accumulator type for partial + // output, otherwise the final result type. + TypePtr outputType() const { + return isPartialOutput ? intermediateType : resultType; + } + + bool isCountStar() const { + return isRawInput && rawInputTypes.empty(); + } +}; + +enum class HashAggrJitKind : uint8_t { + Count, + Sum, + DecimalSum, + Min, + Max, + Avg, + DecimalAvg, +}; + +enum class HashAggrJitValueKind : uint8_t { + Bool, + Int8, + Int16, + Int32, + Int64, + Int128, + Float, + Double, +}; + +enum class HashAggrJitRuntimeShape : uint8_t { + Scalar, + Row, +}; + +// Forward declaration: the codegen function-pointer table is defined in +// HashAggrJit.h (it references llvm:: types). Descriptors only hold a pointer +// to it, so a forward declaration is enough here and keeps this header +// LLVM-free. +struct HashAggrJitOps; + +struct HashAggrJitDescriptor { + HashAggrJitKind kind; + HashAggrJitValueKind rawInputKind; + HashAggrJitValueKind accumulatorKind; + HashAggrJitPlanContext context; + const HashAggrJitOps* ops{nullptr}; + + bool isCountStar() const { + return context.isCountStar(); + } + + bool isRawInput() const { + return context.isRawInput; + } + + bool isDecimal() const { + return kind == HashAggrJitKind::DecimalSum || + kind == HashAggrJitKind::DecimalAvg; + } + + HashAggrJitRuntimeShape inputShape() const { + const auto inputTypes = context.inputTypes(); + return inputTypes.size() == 1 && inputTypes[0] && inputTypes[0]->isRow() + ? HashAggrJitRuntimeShape::Row + : HashAggrJitRuntimeShape::Scalar; + } + + HashAggrJitRuntimeShape outputShape() const { + const auto outputType = context.outputType(); + return outputType && outputType->isRow() ? HashAggrJitRuntimeShape::Row + : HashAggrJitRuntimeShape::Scalar; + } + + // std::string signature() const; +}; + +struct HashAggrJitSlot { + int32_t aggregateIndex; + int32_t offset; + int32_t nullByte; + uint8_t nullMask; + + HashAggrJitDescriptor desc; + + std::string getDescription() const; +}; + +bool isHashAggrJitSupportedType(TypeKind kind); +std::optional hashAggrJitValueKind(TypeKind kind); +std::string hashAggrJitKindName(HashAggrJitKind kind); +std::string hashAggrJitValueKindName(HashAggrJitValueKind kind); + +// Per-aggregate codegen function tables. The definitions (which reference +// llvm:: types) live in bolt/jit/aggregation/ops/*Ops.cpp. Aggregate functions +// only need the returned pointer to populate HashAggrJitDescriptor::ops, so +// these declarations stay LLVM-free here. +const HashAggrJitOps* getCountOps(); +const HashAggrJitOps* getMinMaxOps(); +const HashAggrJitOps* getSumOps(); +const HashAggrJitOps* getAvgOps(); +const HashAggrJitOps* getDecimalSumOps(); +const HashAggrJitOps* getDecimalAvgOps(); + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/ops/AvgOps.cpp b/bolt/jit/aggregation/ops/AvgOps.cpp new file mode 100644 index 000000000..98463f188 --- /dev/null +++ b/bolt/jit/aggregation/ops/AvgOps.cpp @@ -0,0 +1,160 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifdef ENABLE_BOLT_JIT + +#include "bolt/jit/aggregation/HashAggrJit.h" + +#include + +#include "bolt/functions/lib/aggregates/SumCount.h" + +namespace bytedance::bolt::jit { + +namespace { + +// Single source of truth for the AVG intermediate layout: derive the JIT field +// offsets from the non-JIT SumCount struct so a change to SumCount is picked up +// here automatically instead of silently desyncing a mirrored copy. +using AvgAccumulatorLayout = functions::aggregate::SumCount; + +static_assert(std::is_standard_layout_v); + +constexpr int32_t kAvgCountOffset = offsetof(AvgAccumulatorLayout, count); + +void compileAvgInitGroup( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot) { + codegen.setAccumulatorNull(group, slot); + codegen.storeValue( + group, + codegen.llvmType(slot.desc.accumulatorKind), + slot.offset, + llvm::ConstantFP::get(codegen.llvmType(slot.desc.accumulatorKind), 0.0)); + codegen.storeValue( + group, + codegen.builder().getInt64Ty(), + slot.offset + kAvgCountOffset, + codegen.builder().getInt64(0)); +} + +void compileAvgAddRawInput( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const InputAdapterCodegen& input, + llvm::Value* row, + const HashAggrJitSlot& slot, + llvm::BasicBlock*) { + auto* inputRow = input.read(row, slot.desc.rawInputKind); + auto* rawValue = IRRow::getValue(codegen.builder(), inputRow); + auto* value = codegen.castValue( + rawValue, slot.desc.rawInputKind, slot.desc.accumulatorKind); + codegen.clearAccumulatorNull(group, slot); + auto* oldSum = codegen.loadValue( + group, codegen.llvmType(slot.desc.accumulatorKind), slot.offset); + codegen.storeValue( + group, + codegen.llvmType(slot.desc.accumulatorKind), + slot.offset, + codegen.builder().CreateFAdd(oldSum, value)); + auto* oldCount = codegen.loadValue( + group, + codegen.builder().getInt64Ty(), + slot.offset + kAvgCountOffset); + codegen.storeValue( + group, + codegen.builder().getInt64Ty(), + slot.offset + kAvgCountOffset, + codegen.builder().CreateAdd(oldCount, codegen.builder().getInt64(1))); +} + +void compileAvgAddIntermediateResults( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const InputAdapterCodegen& input, + llvm::Value* row, + const HashAggrJitSlot& slot, + llvm::BasicBlock*) { + codegen.clearAccumulatorNull(group, slot); + auto* sum = input.readRowFieldValue(row, 0, HashAggrJitValueKind::Double); + auto* count = input.readRowFieldValue(row, 1, HashAggrJitValueKind::Int64); + auto* oldSum = + codegen.loadValue(group, codegen.builder().getDoubleTy(), slot.offset); + codegen.storeValue( + group, + codegen.builder().getDoubleTy(), + slot.offset, + codegen.builder().CreateFAdd(oldSum, sum)); + auto* oldCount = codegen.loadValue( + group, + codegen.builder().getInt64Ty(), + slot.offset + kAvgCountOffset); + codegen.storeValue( + group, + codegen.builder().getInt64Ty(), + slot.offset + kAvgCountOffset, + codegen.builder().CreateAdd(oldCount, count)); +} + +// Intermediate output is row(sum:double, count:bigint). All-null group yields +// (0, 0) with a non-null top-level row (isNull = 0), matching the non-JIT +// extractAccumulators path. +void compileAvgExtractAccumulators( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + auto& builder = codegen.builder(); + auto* sum = codegen.loadValue(group, builder.getDoubleTy(), slot.offset); + auto* count = codegen.loadValue( + group, builder.getInt64Ty(), slot.offset + kAvgCountOffset); + target.output.writeField( + target.row, + 0, + HashAggrJitValueKind::Double, + IRRow::pack(builder, sum, builder.getFalse())); + target.output.writeField( + target.row, + 1, + HashAggrJitValueKind::Int64, + IRRow::pack(builder, count, builder.getFalse())); + target.output.writeNull(target.row, builder.getFalse()); +} + +// Final output is double avg. count == 0 means all inputs were null -> null. +void compileAvgExtractValues( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + auto& builder = codegen.builder(); + auto* sum = codegen.loadValue(group, builder.getDoubleTy(), slot.offset); + auto* count = codegen.loadValue( + group, builder.getInt64Ty(), slot.offset + kAvgCountOffset); + auto* isNull = builder.CreateICmpEQ(count, builder.getInt64(0)); + auto* countAsDouble = builder.CreateSIToFP(count, builder.getDoubleTy()); + auto* avg = builder.CreateFDiv(sum, countAsDouble); + target.output.write( + target.row, + HashAggrJitValueKind::Double, + IRRow::pack(builder, avg, isNull)); +} + +} // namespace + +const HashAggrJitOps* getAvgOps() { + static const HashAggrJitOps kOps{ + &compileAvgInitGroup, + &compileAvgAddRawInput, + &compileAvgAddIntermediateResults, + &compileAvgExtractAccumulators, + &compileAvgExtractValues}; + return &kOps; +} + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/ops/CountOps.cpp b/bolt/jit/aggregation/ops/CountOps.cpp new file mode 100644 index 000000000..cc43c20ea --- /dev/null +++ b/bolt/jit/aggregation/ops/CountOps.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifdef ENABLE_BOLT_JIT + +#include "bolt/jit/aggregation/HashAggrJit.h" + +namespace bytedance::bolt::jit { + +namespace { + +void compileCountInitGroup( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot) { + codegen.storeValue( + group, + codegen.builder().getInt64Ty(), + slot.offset, + codegen.builder().getInt64(0)); +} + +void addInc( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + llvm::Value* inc) { + auto* state = + codegen.loadValue(group, codegen.builder().getInt64Ty(), slot.offset); + codegen.storeValue( + group, + codegen.builder().getInt64Ty(), + slot.offset, + codegen.builder().CreateAdd(state, inc)); +} + +void compileCountAddRawInput( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const InputAdapterCodegen& /*input*/, + llvm::Value* /*row*/, + const HashAggrJitSlot& slot, + llvm::BasicBlock*) { + addInc(codegen, group, slot, codegen.builder().getInt64(1)); +} + +void compileCountAddIntermediateResults( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const InputAdapterCodegen& input, + llvm::Value* row, + const HashAggrJitSlot& slot, + llvm::BasicBlock*) { + llvm::Value* inc = nullptr; + if (slot.desc.isCountStar()) { + inc = codegen.builder().getInt64(1); + } else { + auto* inputRow = input.read(row, slot.desc.rawInputKind); + inc = codegen.castValue( + IRRow::getValue(codegen.builder(), inputRow), + slot.desc.rawInputKind, + HashAggrJitValueKind::Int64); + } + addInc(codegen, group, slot, inc); +} + +// Count's intermediate accumulator and final result share the same scalar +// representation, so partial/final extract emit identical IR. The two named +// entry points below both forward to this helper. +void compileCountExtract( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + auto* value = + codegen.loadValue(group, codegen.builder().getInt64Ty(), slot.offset); + target.output.write( + target.row, + HashAggrJitValueKind::Int64, + IRRow::pack(codegen.builder(), value, codegen.builder().getFalse())); +} + +void compileCountExtractAccumulators( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + compileCountExtract(codegen, group, slot, target); +} + +void compileCountExtractValues( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + compileCountExtract(codegen, group, slot, target); +} + +} // namespace + +const HashAggrJitOps* getCountOps() { + static const HashAggrJitOps kOps{ + &compileCountInitGroup, + &compileCountAddRawInput, + &compileCountAddIntermediateResults, + &compileCountExtractAccumulators, + &compileCountExtractValues}; + return &kOps; +} + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp new file mode 100644 index 000000000..86503ffec --- /dev/null +++ b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp @@ -0,0 +1,194 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifdef ENABLE_BOLT_JIT + +#include +#include + +#include "bolt/jit/aggregation/HashAggrJit.h" +#include "bolt/jit/aggregation/HashAggrJitDecimalState.h" +#include "bolt/jit/aggregation/ops/DecimalOps.h" +#include "bolt/type/Type.h" + +namespace bytedance::bolt::jit { + +namespace { + +// Field offsets within JitDecimalAvgState, relative to slot.offset. +constexpr int32_t kSumOffset = + static_cast(offsetof(JitDecimalAvgState, sum)); +constexpr int32_t kCountOffset = + static_cast(offsetof(JitDecimalAvgState, count)); +constexpr int32_t kOverflowOffset = + static_cast(offsetof(JitDecimalAvgState, overflow)); + +void compileDecimalAvgInitGroup( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot) { + auto& b = codegen.builder(); + codegen.setAccumulatorNull(group, slot); + // sum = 0 (i128), count = 0 (i64), overflow = 0 (i64). + codegen.storeValue( + group, + b.getInt128Ty(), + slot.offset + kSumOffset, + llvm::ConstantInt::get(b.getInt128Ty(), 0)); + codegen.storeValue( + group, b.getInt64Ty(), slot.offset + kCountOffset, b.getInt64(0)); + codegen.storeValue( + group, b.getInt64Ty(), slot.offset + kOverflowOffset, b.getInt64(0)); +} + +void compileDecimalAvgAddRawInput( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const InputAdapterCodegen& input, + llvm::Value* row, + const HashAggrJitSlot& slot, + llvm::BasicBlock*) { + auto& b = codegen.builder(); + auto* inputRow = input.read(row, slot.desc.rawInputKind); + auto* rawValue = IRRow::getValue(codegen.builder(), inputRow); + auto* value = codegen.castValue( + rawValue, slot.desc.rawInputKind, HashAggrJitValueKind::Int128); + codegen.clearAccumulatorNull(group, slot); + emitDecimalAddWithOverflow( + codegen, group, slot.offset + kSumOffset, slot.offset + kOverflowOffset, value); + // ++count. + auto* oldCount = + codegen.loadValue(group, b.getInt64Ty(), slot.offset + kCountOffset); + codegen.storeValue( + group, + b.getInt64Ty(), + slot.offset + kCountOffset, + b.CreateAdd(oldCount, b.getInt64(1))); +} + +void compileDecimalAvgAddIntermediateResults( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const InputAdapterCodegen& input, + llvm::Value* row, + const HashAggrJitSlot& slot, + llvm::BasicBlock* nextBlock) { + auto& b = codegen.builder(); + auto* function = b.GetInsertBlock()->getParent(); + auto* continueBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "decimal_avg_merge_cont", + function, + nextBlock); + auto* overflowBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "decimal_avg_merge_overflow", + function, + continueBlock); + auto* mergeBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), "decimal_avg_merge", function, continueBlock); + const auto [sumPrecision, _] = + hashAggrJitDecimalPrecisionScale(slot.desc.context.inputTypes()[0]); + const auto sumKind = hashAggrJitDecimalKindForPrecision(sumPrecision); + auto* sumRow = input.readRowField(row, 0, sumKind); + auto* countRow = input.readRowField(row, 1, HashAggrJitValueKind::Int64); + auto* sumIsNull = IRRow::getIsNull(b, sumRow); + auto* countIsNull = IRRow::getIsNull(b, countRow); + auto* count = IRRow::getValue(b, countRow); + auto* countPositive = b.CreateICmpSGT(count, b.getInt64(0)); + auto* isOverflow = b.CreateAnd( + sumIsNull, b.CreateAnd(b.CreateNot(countIsNull), countPositive)); + b.CreateCondBr(isOverflow, overflowBlock, mergeBlock); + + b.SetInsertPoint(overflowBlock); + codegen.setAccumulatorNull(group, slot); + b.CreateBr(continueBlock); + + b.SetInsertPoint(mergeBlock); + auto* sum = IRRow::getValue(b, sumRow); + auto* value = codegen.castValue(sum, sumKind, HashAggrJitValueKind::Int128); + codegen.clearAccumulatorNull(group, slot); + emitDecimalAddWithOverflow( + codegen, group, slot.offset + kSumOffset, slot.offset + kOverflowOffset, value); + // count += incoming count. + auto* oldCount = + codegen.loadValue(group, b.getInt64Ty(), slot.offset + kCountOffset); + codegen.storeValue( + group, + b.getInt64Ty(), + slot.offset + kCountOffset, + b.CreateAdd(oldCount, count)); + b.CreateBr(continueBlock); + + b.SetInsertPoint(continueBlock); +} + +void emitDecimalAvgExtract( + HashAggrJitCodegen& codegen, + llvm::Value* vector, + llvm::Value* row, + llvm::Value* group, + const HashAggrJitSlot& slot) { + auto& b = codegen.builder(); + const bool partialOutput = slot.desc.context.isPartialOutput; + auto [inputPrecision, inputScale] = + hashAggrJitDecimalPrecisionScale(slot.desc.context.inputTypes()[0]); + if (slot.desc.isRawInput()) { + inputPrecision = std::min(38, inputPrecision + 10); + } + const auto [outputPrecision, outputScale] = + hashAggrJitDecimalPrecisionScale(slot.desc.context.outputType()); + const bool longDecimal = hashAggrJitDecimalKindForPrecision( + outputPrecision) == HashAggrJitValueKind::Int128; + const char* fn = partialOutput + ? (longDecimal ? "jit_HashAggrExtractPartialLongDecimalAvg" + : "jit_HashAggrExtractPartialShortDecimalAvg") + : (longDecimal ? "jit_HashAggrExtractFinalLongDecimalAvg" + : "jit_HashAggrExtractFinalShortDecimalAvg"); + b.CreateCall( + codegen.module().getFunction(fn), + {vector, + row, + group, + b.getInt32(slot.offset), + b.getInt32(inputPrecision), + b.getInt32(inputScale), + b.getInt32(outputPrecision), + b.getInt32(outputScale)}); +} + +void compileDecimalAvgExtractAccumulators( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + emitDecimalAvgExtract( + codegen, target.output.vector(), target.row, group, slot); +} + +void compileDecimalAvgExtractValues( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + emitDecimalAvgExtract( + codegen, target.output.vector(), target.row, group, slot); +} + +} // namespace + +const HashAggrJitOps* getDecimalAvgOps() { + static const HashAggrJitOps kOps{ + &compileDecimalAvgInitGroup, + &compileDecimalAvgAddRawInput, + &compileDecimalAvgAddIntermediateResults, + &compileDecimalAvgExtractAccumulators, + &compileDecimalAvgExtractValues}; + return &kOps; +} + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/ops/DecimalOps.h b/bolt/jit/aggregation/ops/DecimalOps.h new file mode 100644 index 000000000..ba5a3a5e7 --- /dev/null +++ b/bolt/jit/aggregation/ops/DecimalOps.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#ifdef ENABLE_BOLT_JIT + +#include "bolt/jit/aggregation/HashAggrJit.h" + +// Decimal-specific JIT codegen helper shared across decimal ops translation +// units. It lives with the decimal ops (rather than on the framework +// HashAggrJitCodegen) so decimal knowledge stays out of the generic framework, +// and is declared here because it is defined in DecimalSumOps.cpp but also used +// by DecimalAvgOps.cpp. Extract helpers that are used within a single TU stay +// file-local in their respective ops files. +namespace bytedance::bolt::jit { + +inline HashAggrJitValueKind hashAggrJitDecimalKindForPrecision( + int32_t precision) { + return precision > bytedance::bolt::ShortDecimalType::kMaxPrecision + ? HashAggrJitValueKind::Int128 + : HashAggrJitValueKind::Int64; +} + +inline const TypePtr& hashAggrJitDecimalValueType(const TypePtr& type) { + return type->isRow() ? type->childAt(0) : type; +} + +inline std::pair hashAggrJitDecimalPrecisionScale( + const TypePtr& type) { + const auto [precision, scale] = + getDecimalPrecisionScale(*hashAggrJitDecimalValueType(type)); + return {precision, scale}; +} + +// Inline i128 accumulate-with-overflow used by decimal sum/avg add+merge. +// Loads the i128 sum at 'group + sumOffset' and the i64 overflow counter at +// 'group + overflowOffset', computes sum += addend, updates the overflow +// counter by the carry direction, and stores both back. +void emitDecimalAddWithOverflow( + HashAggrJitCodegen& codegen, + llvm::Value* group, + int32_t sumOffset, + int32_t overflowOffset, + llvm::Value* addend); + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/ops/DecimalSumOps.cpp b/bolt/jit/aggregation/ops/DecimalSumOps.cpp new file mode 100644 index 000000000..0428537e8 --- /dev/null +++ b/bolt/jit/aggregation/ops/DecimalSumOps.cpp @@ -0,0 +1,258 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifdef ENABLE_BOLT_JIT + +#include + +#include "bolt/jit/aggregation/HashAggrJit.h" +#include "bolt/jit/aggregation/HashAggrJitDecimalState.h" +#include "bolt/jit/aggregation/ops/DecimalOps.h" +#include "bolt/type/Type.h" + +namespace bytedance::bolt::jit { + +namespace { + +// Field offsets within JitDecimalSumState, relative to slot.offset. +constexpr int32_t kSumOffset = + static_cast(offsetof(JitDecimalSumState, sum)); +constexpr int32_t kOverflowOffset = + static_cast(offsetof(JitDecimalSumState, overflow)); +constexpr int32_t kIsEmptyOffset = + static_cast(offsetof(JitDecimalSumState, isEmpty)); + +void compileDecimalSumInitGroup( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot) { + auto& b = codegen.builder(); + codegen.setAccumulatorNull(group, slot); + // sum = 0 (i128), overflow = 0 (i64), isEmpty = true (i8). + codegen.storeValue( + group, + b.getInt128Ty(), + slot.offset + kSumOffset, + llvm::ConstantInt::get(b.getInt128Ty(), 0)); + codegen.storeValue( + group, b.getInt64Ty(), slot.offset + kOverflowOffset, b.getInt64(0)); + codegen.storeValue( + group, b.getInt8Ty(), slot.offset + kIsEmptyOffset, b.getInt8(1)); +} + +void compileDecimalSumAddRawInput( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const InputAdapterCodegen& input, + llvm::Value* row, + const HashAggrJitSlot& slot, + llvm::BasicBlock*) { + auto& b = codegen.builder(); + auto* inputRow = input.read(row, slot.desc.rawInputKind); + auto* rawValue = IRRow::getValue(codegen.builder(), inputRow); + auto* value = codegen.castValue( + rawValue, slot.desc.rawInputKind, HashAggrJitValueKind::Int128); + codegen.clearAccumulatorNull(group, slot); + emitDecimalAddWithOverflow( + codegen, + group, + slot.offset + kSumOffset, + slot.offset + kOverflowOffset, + value); + codegen.storeValue( + group, b.getInt8Ty(), slot.offset + kIsEmptyOffset, b.getInt8(0)); +} + +void compileDecimalSumAddIntermediateResults( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const InputAdapterCodegen& input, + llvm::Value* row, + const HashAggrJitSlot& slot, + llvm::BasicBlock* nextBlock) { + auto& b = codegen.builder(); + auto* function = b.GetInsertBlock()->getParent(); + auto* continueBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "decimal_sum_merge_cont", + function, + nextBlock); + auto* overflowBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "decimal_sum_merge_overflow", + function, + continueBlock); + auto* mergeBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), "decimal_sum_merge", function, continueBlock); + const auto [sumPrecision, _] = + hashAggrJitDecimalPrecisionScale(slot.desc.context.inputTypes()[0]); + const auto sumKind = hashAggrJitDecimalKindForPrecision(sumPrecision); + auto* sumRow = input.readRowField(row, 0, sumKind); + auto* incomingIsEmpty = + input.readRowFieldValue(row, 1, HashAggrJitValueKind::Bool); + auto* sumIsNull = IRRow::getIsNull(b, sumRow); + auto* isNotEmpty = b.CreateICmpEQ(incomingIsEmpty, b.getInt8(0)); + auto* isOverflow = b.CreateAnd(sumIsNull, isNotEmpty); + b.CreateCondBr(isOverflow, overflowBlock, mergeBlock); + + b.SetInsertPoint(overflowBlock); + codegen.setAccumulatorNull(group, slot); + b.CreateBr(continueBlock); + + b.SetInsertPoint(mergeBlock); + auto* sum = IRRow::getValue(b, sumRow); + auto* value = codegen.castValue(sum, sumKind, HashAggrJitValueKind::Int128); + codegen.clearAccumulatorNull(group, slot); + emitDecimalAddWithOverflow( + codegen, + group, + slot.offset + kSumOffset, + slot.offset + kOverflowOffset, + value); + // isEmpty = isEmpty && incomingIsEmpty. + auto* oldIsEmpty = + codegen.loadValue(group, b.getInt8Ty(), slot.offset + kIsEmptyOffset); + auto* bothEmpty = b.CreateAnd( + b.CreateICmpNE(oldIsEmpty, b.getInt8(0)), + b.CreateICmpNE(incomingIsEmpty, b.getInt8(0))); + codegen.storeValue( + group, + b.getInt8Ty(), + slot.offset + kIsEmptyOffset, + b.CreateZExt(bothEmpty, b.getInt8Ty())); + b.CreateBr(continueBlock); + + b.SetInsertPoint(continueBlock); +} + +void emitDecimalSumExtract( + HashAggrJitCodegen& codegen, + llvm::Value* vector, + llvm::Value* row, + llvm::Value* group, + const HashAggrJitSlot& slot) { + auto& b = codegen.builder(); + const bool partialOutput = slot.desc.context.isPartialOutput; + // long/short decimal and overflow precision are decided by the actual + // output decimal type of this aggregation stage. + const auto [outPrecision, outScale] = + hashAggrJitDecimalPrecisionScale(slot.desc.context.outputType()); + const bool longDecimal = hashAggrJitDecimalKindForPrecision(outPrecision) == + HashAggrJitValueKind::Int128; + const char* fn = partialOutput + ? (longDecimal ? "jit_HashAggrExtractPartialLongDecimalSum" + : "jit_HashAggrExtractPartialShortDecimalSum") + : (longDecimal ? "jit_HashAggrExtractFinalLongDecimalSum" + : "jit_HashAggrExtractFinalShortDecimalSum"); + // Mirror the non-JIT extract's leading `if (isNull(group))` check: a group + // whose accumulator null flag is set (e.g. an overflowed intermediate result + // merged in) must produce null, regardless of the sum/isEmpty fields. + auto* accumulatorIsNull = + b.CreateZExt(codegen.isAccumulatorNull(group, slot), b.getInt32Ty()); + b.CreateCall( + codegen.module().getFunction(fn), + {vector, + row, + group, + b.getInt32(slot.offset), + b.getInt32(outPrecision), + b.getInt32(outScale), + accumulatorIsNull}); +} + +void compileDecimalSumExtractAccumulators( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + emitDecimalSumExtract( + codegen, target.output.vector(), target.row, group, slot); +} + +void compileDecimalSumExtractValues( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + emitDecimalSumExtract( + codegen, target.output.vector(), target.row, group, slot); +} + +} // namespace + +const HashAggrJitOps* getDecimalSumOps() { + static const HashAggrJitOps kOps{ + &compileDecimalSumInitGroup, + &compileDecimalSumAddRawInput, + &compileDecimalSumAddIntermediateResults, + &compileDecimalSumExtractAccumulators, + &compileDecimalSumExtractValues}; + return &kOps; +} + +void emitDecimalAddWithOverflow( + HashAggrJitCodegen& codegen, + llvm::Value* group, + int32_t sumOffset, + int32_t overflowOffset, + llvm::Value* addend) { + auto& b = codegen.builder(); + auto* i128Ty = b.getInt128Ty(); + auto* i64Ty = b.getInt64Ty(); + auto* zero128 = llvm::ConstantInt::get(i128Ty, 0); + + auto* lhs = codegen.loadValue(group, i128Ty, sumOffset); + auto* rhs = addend; + + // Mirror DecimalUtil::addWithOverflow + addUnsignedValues exactly (i128 sum + // kept as low 127 bits plus a separate overflow counter), instead of a + // sign-flip heuristic which diverges once the true i128 magnitude overflows. + // + // same sign: + // mag = (|lhs| + |rhs|) & ~(1<<127) // low 127 bits + // carry = (|lhs| + |rhs|) >> 127 // bit 127 + // both negative -> sum = -mag, overflow = -carry + // both positive -> sum = mag, overflow = carry + // different sign: + // sum = lhs + rhs, overflow = 0 + auto* lhsNeg = b.CreateICmpSLT(lhs, zero128); + auto* rhsNeg = b.CreateICmpSLT(rhs, zero128); + auto* sameSign = b.CreateICmpEQ(lhsNeg, rhsNeg); + auto* bothNeg = b.CreateAnd(lhsNeg, rhsNeg); + + // Magnitudes for the same-sign path: negate operands when both negative so + // the unsigned addition operates on |lhs|, |rhs| (matches addUnsignedValues). + auto* absLhs = b.CreateSelect(bothNeg, b.CreateNeg(lhs), lhs); + auto* absRhs = b.CreateSelect(bothNeg, b.CreateNeg(rhs), rhs); + auto* unsignedSum = b.CreateAdd(absLhs, absRhs); + auto* mask127 = + llvm::ConstantInt::get(i128Ty, llvm::APInt::getSignedMinValue(128)); + // mag = unsignedSum & ~(1<<127) + auto* magnitude = b.CreateAnd(unsignedSum, b.CreateNot(mask127)); + // carry = (unsignedSum >> 127) & 1, as i64 + auto* carryBit = b.CreateAnd( + b.CreateLShr(unsignedSum, llvm::ConstantInt::get(i128Ty, 127)), + llvm::ConstantInt::get(i128Ty, 1)); + auto* carry64 = b.CreateTrunc(carryBit, i64Ty); + + auto* sameSignSum = b.CreateSelect(bothNeg, b.CreateNeg(magnitude), magnitude); + auto* sameSignOverflow = + b.CreateSelect(bothNeg, b.CreateNeg(carry64), carry64); + + auto* diffSignSum = b.CreateAdd(lhs, rhs); + + auto* newSum = b.CreateSelect(sameSign, sameSignSum, diffSignSum); + auto* overflowDelta = + b.CreateSelect(sameSign, sameSignOverflow, b.getInt64(0)); + + codegen.storeValue(group, i128Ty, sumOffset, newSum); + auto* oldOverflow = codegen.loadValue(group, i64Ty, overflowOffset); + codegen.storeValue( + group, i64Ty, overflowOffset, b.CreateAdd(oldOverflow, overflowDelta)); +} + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/ops/MinMaxOps.cpp b/bolt/jit/aggregation/ops/MinMaxOps.cpp new file mode 100644 index 000000000..fd2440fae --- /dev/null +++ b/bolt/jit/aggregation/ops/MinMaxOps.cpp @@ -0,0 +1,121 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifdef ENABLE_BOLT_JIT + +#include "bolt/jit/aggregation/HashAggrJit.h" + +namespace bytedance::bolt::jit { + +namespace { + +void compileMinMaxInitGroup( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot) { + codegen.setAccumulatorNull(group, slot); + auto* type = codegen.llvmType(slot.desc.accumulatorKind); + if (codegen.isFloatKind(slot.desc.accumulatorKind)) { + codegen.storeValue(group, type, slot.offset, llvm::ConstantFP::get(type, 0.0)); + } else { + codegen.storeValue(group, type, slot.offset, llvm::ConstantInt::get(type, 0)); + } +} + +// min/max use the same logic for raw input and intermediate merge: both pick +// the better (min/max) of the decoded value and the current accumulator. +void compileMinMaxUpdate( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const InputAdapterCodegen& input, + llvm::Value* row, + const HashAggrJitSlot& slot, + llvm::BasicBlock*) { + auto* inputRow = input.read(row, slot.desc.rawInputKind); + auto* value = codegen.castValue( + IRRow::getValue(codegen.builder(), inputRow), + slot.desc.rawInputKind, + slot.desc.accumulatorKind); + auto* type = codegen.llvmType(slot.desc.accumulatorKind); + auto* oldValue = codegen.loadValue(group, type, slot.offset); + auto* nullState = codegen.isAccumulatorNull(group, slot); + llvm::Value* better = nullptr; + if (codegen.isFloatKind(slot.desc.accumulatorKind)) { + auto* oldIsNan = codegen.builder().CreateFCmpUNO(oldValue, oldValue); + auto* valueIsNan = codegen.builder().CreateFCmpUNO(value, value); + if (slot.desc.kind == HashAggrJitKind::Min) { + better = codegen.builder().CreateOr( + codegen.builder().CreateAnd(oldIsNan, codegen.builder().CreateNot(valueIsNan)), + codegen.builder().CreateAnd( + codegen.builder().CreateNot(valueIsNan), + codegen.builder().CreateFCmpOGT(oldValue, value))); + } else { + better = codegen.builder().CreateAnd( + codegen.builder().CreateNot(oldIsNan), + codegen.builder().CreateOr( + valueIsNan, codegen.builder().CreateFCmpOLT(oldValue, value))); + } + } else { + better = slot.desc.kind == HashAggrJitKind::Min + ? codegen.builder().CreateICmpSLT(value, oldValue) + : codegen.builder().CreateICmpSGT(value, oldValue); + } + auto* shouldStore = codegen.builder().CreateOr(nullState, better); + codegen.storeValue( + group, + type, + slot.offset, + codegen.builder().CreateSelect(shouldStore, value, oldValue)); + codegen.clearAccumulatorNull(group, slot); +} + +// Min/max's intermediate accumulator and final result share the same scalar +// representation, so partial/final extract emit identical IR. The two named +// entry points below both forward to this helper. +void compileMinMaxExtract( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + auto* value = codegen.loadValue( + group, codegen.llvmType(slot.desc.accumulatorKind), slot.offset); + auto* isNull = codegen.isAccumulatorNull(group, slot); + target.output.write( + target.row, + slot.desc.accumulatorKind, + IRRow::pack(codegen.builder(), value, isNull)); +} + +void compileMinMaxExtractAccumulators( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + compileMinMaxExtract(codegen, group, slot, target); +} + +void compileMinMaxExtractValues( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + compileMinMaxExtract(codegen, group, slot, target); +} + +} // namespace + +const HashAggrJitOps* getMinMaxOps() { + static const HashAggrJitOps kOps{ + &compileMinMaxInitGroup, + &compileMinMaxUpdate, + &compileMinMaxUpdate, + &compileMinMaxExtractAccumulators, + &compileMinMaxExtractValues}; + return &kOps; +} + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/ops/SumOps.cpp b/bolt/jit/aggregation/ops/SumOps.cpp new file mode 100644 index 000000000..1215e7ba6 --- /dev/null +++ b/bolt/jit/aggregation/ops/SumOps.cpp @@ -0,0 +1,98 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifdef ENABLE_BOLT_JIT + +#include "bolt/jit/aggregation/HashAggrJit.h" + +namespace bytedance::bolt::jit { + +namespace { + +void compileSumInitGroup( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot) { + codegen.setAccumulatorNull(group, slot); + auto* accType = codegen.llvmType(slot.desc.accumulatorKind); + if (codegen.isFloatKind(slot.desc.accumulatorKind)) { + codegen.storeValue( + group, accType, slot.offset, llvm::ConstantFP::get(accType, 0.0)); + } else { + codegen.storeValue( + group, accType, slot.offset, llvm::ConstantInt::get(accType, 0)); + } +} + +// sum uses the same logic for raw input and intermediate merge: add the +// decoded value into the running accumulator. +void compileSumAccumulate( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const InputAdapterCodegen& input, + llvm::Value* row, + const HashAggrJitSlot& slot, + llvm::BasicBlock*) { + auto* inputRow = input.read(row, slot.desc.rawInputKind); + auto* rawValue = IRRow::getValue(codegen.builder(), inputRow); + auto* value = codegen.castValue( + rawValue, slot.desc.rawInputKind, slot.desc.accumulatorKind); + auto* accType = codegen.llvmType(slot.desc.accumulatorKind); + codegen.clearAccumulatorNull(group, slot); + auto* oldValue = codegen.loadValue(group, accType, slot.offset); + auto* newValue = codegen.isFloatKind(slot.desc.accumulatorKind) + ? codegen.builder().CreateFAdd(oldValue, value) + : codegen.builder().CreateAdd(oldValue, value); + codegen.storeValue(group, accType, slot.offset, newValue); +} + +// Sum's intermediate accumulator and final result share the same scalar +// representation, so partial/final extract emit identical IR. The two named +// entry points below both forward to this helper. +void compileSumExtract( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + auto* value = codegen.loadValue( + group, codegen.llvmType(slot.desc.accumulatorKind), slot.offset); + auto* isNull = codegen.isAccumulatorNull(group, slot); + target.output.write( + target.row, + slot.desc.accumulatorKind, + IRRow::pack(codegen.builder(), value, isNull)); +} + +void compileSumExtractAccumulators( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + compileSumExtract(codegen, group, slot, target); +} + +void compileSumExtractValues( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + compileSumExtract(codegen, group, slot, target); +} + +} // namespace + +const HashAggrJitOps* getSumOps() { + static const HashAggrJitOps kOps{ + &compileSumInitGroup, + &compileSumAccumulate, + &compileSumAccumulate, + &compileSumExtractAccumulators, + &compileSumExtractValues}; + return &kOps; +} + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp b/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp new file mode 100644 index 000000000..4718c1b55 --- /dev/null +++ b/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp @@ -0,0 +1,369 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + */ + +// Decimal sum/avg extract runtime helpers called from JIT-generated HashAggr +// extract IR. They read the JIT decimal accumulator (JitDecimalSumState / +// JitDecimalAvgState) from 'group + offset', apply overflow / precision +// adjustment and write the result into the output vector. Resolved by the ORC +// JIT through the process global symbol table, so they only need default +// visibility and to be linked into the host process. These mirror the +// per-aggregate computeFinalValue logic but depend only on shared decimal +// utility helpers, so they live next to the other HashAggr runtime helpers. + +#ifdef ENABLE_BOLT_JIT + +#include + +#include "bolt/functions/sparksql/DecimalUtil.h" +#include "bolt/jit/aggregation/HashAggrJitDecimalState.h" +#include "bolt/type/DecimalUtil.h" +#include "bolt/vector/ComplexVector.h" +#include "bolt/vector/FlatVector.h" + +namespace { + +// Mirrors DecimalSumAggregate::computeFinalValue: applies overflow adjustment +// and reports whether the value overflows the result precision range. +bytedance::bolt::int128_t jitDecimalSumComputeFinal( + const bytedance::bolt::jit::JitDecimalSumState* state, + int32_t precision, + bool& overflow) { + using bytedance::bolt::DecimalUtil; + bytedance::bolt::int128_t sum = state->sum; + if ((state->overflow == 1 && state->sum < 0) || + (state->overflow == -1 && state->sum > 0)) { + sum = static_cast( + DecimalUtil::kOverflowMultiplier * state->overflow + state->sum); + } else if (state->overflow != 0) { + overflow = true; + return 0; + } + overflow = !DecimalUtil::valueInPrecisionRange(sum, precision); + return sum; +} + +uint8_t jitDecimalAvgComputeRescaleFactor( + uint8_t fromScale, + uint8_t toScale, + uint8_t resultScale) { + return resultScale - fromScale + toScale; +} + +std::pair jitDecimalAvgComputeResultPrecisionScale( + uint8_t aPrecision, + uint8_t aScale, + uint8_t bPrecision, + uint8_t bScale) { + uint8_t intDigits = aPrecision - aScale + bScale; + uint8_t scale = std::max(6, aScale + bPrecision + 1); + uint8_t precision = intDigits + scale; + return bytedance::bolt::functions::sparksql::DecimalUtil::adjustPrecisionScale( + precision, scale); +} + +template +std::optional jitDecimalAvgComputeFinal( + const bytedance::bolt::jit::JitDecimalAvgState* state, + int32_t sumPrecision, + int32_t sumScale, + int32_t resultPrecision, + int32_t resultScale) { + auto adjustedSum = bytedance::bolt::DecimalUtil::adjustSumForOverflow( + state->sum, state->overflow); + if (!adjustedSum.has_value()) { + return std::nullopt; + } + + constexpr uint8_t kCountPrecision = 20; + constexpr uint8_t kCountScale = 0; + const auto [avgPrecision, avgScale] = jitDecimalAvgComputeResultPrecisionScale( + static_cast(sumPrecision), + static_cast(sumScale), + kCountPrecision, + kCountScale); + const auto sumRescale = jitDecimalAvgComputeRescaleFactor( + static_cast(sumScale), kCountScale, avgScale); + + bytedance::bolt::int128_t avg = 0; + bool overflow = false; + bytedance::bolt::functions::sparksql::DecimalUtil:: + divideWithRoundUp( + avg, adjustedSum.value(), state->count, sumRescale, overflow); + if (overflow) { + return std::nullopt; + } + + TResult rescaledValue; + const auto status = bytedance::bolt::DecimalUtil:: + rescaleWithRoundUp( + avg, + avgPrecision, + avgScale, + static_cast(resultPrecision), + static_cast(resultScale), + rescaledValue); + return status.ok() ? std::optional(rescaledValue) : std::nullopt; +} + +template +void jitHashAggrExtractFinalDecimalSum( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + bool accumulatorIsNull) { + auto* state = + reinterpret_cast(group + offset); + auto* flat = reinterpret_cast(vector) + ->asUnchecked>(); + // A null accumulator (e.g. an overflowed intermediate result merged into the + // group) produces null, matching the non-JIT `if (isNull(group))` check. + if (accumulatorIsNull || state->isEmpty) { + flat->setNull(row, true); + return; + } + + bool overflow = false; + auto result = jitDecimalSumComputeFinal(state, precision, overflow); + if (overflow) { + flat->setNull(row, true); + } else { + flat->set(row, static_cast(result)); + } +} + +template +void jitHashAggrExtractPartialDecimalSum( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + bool accumulatorIsNull) { + auto* state = + reinterpret_cast(group + offset); + auto* rowVector = reinterpret_cast(vector) + ->asUnchecked(); + auto* sumVector = + rowVector->childAt(0)->asUnchecked>(); + auto* isEmptyVector = + rowVector->childAt(1)->asUnchecked>(); + rowVector->setNull(row, false); + // A null accumulator (e.g. an overflowed intermediate result merged into the + // group) outputs sum=0, isEmpty=true, matching the non-JIT + // `if (isNull(group))` branch. + if (accumulatorIsNull) { + sumVector->set(row, 0); + isEmptyVector->set(row, true); + return; + } + if (state->isEmpty) { + sumVector->set(row, 0); + isEmptyVector->set(row, true); + return; + } + + bool overflow = false; + auto result = jitDecimalSumComputeFinal(state, precision, overflow); + if (overflow) { + sumVector->setNull(row, true); + } else { + sumVector->set(row, static_cast(result)); + } + isEmptyVector->set(row, overflow ? false : state->isEmpty); +} + +template +void jitHashAggrExtractPartialDecimalAvg( + char* vector, + int32_t row, + char* group, + int32_t offset) { + auto* state = + reinterpret_cast(group + offset); + auto* rowVector = reinterpret_cast(vector) + ->asUnchecked(); + auto* sumVector = + rowVector->childAt(0)->asUnchecked>(); + auto* countVector = + rowVector->childAt(1)->asUnchecked>(); + rowVector->setNull(row, false); + countVector->set(row, state->count); + std::optional adjustedSum = + bytedance::bolt::DecimalUtil::adjustSumForOverflow( + state->sum, state->overflow); + if (adjustedSum.has_value()) { + sumVector->set(row, static_cast(adjustedSum.value())); + } else { + sumVector->setNull(row, true); + } +} + +template +void jitHashAggrExtractFinalDecimalAvg( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + int32_t scale, + int32_t resultPrecision, + int32_t resultScale) { + auto* state = reinterpret_cast( + group + offset); + auto* flat = reinterpret_cast(vector) + ->asUnchecked>(); + if (state->count == 0) { + flat->setNull(row, true); + return; + } + + auto result = jitDecimalAvgComputeFinal( + state, precision, scale, resultPrecision, resultScale); + if (result.has_value()) { + flat->set(row, result.value()); + } else { + flat->setNull(row, true); + } +} + +} // namespace + +extern "C" { + +// Final decimal sum extract. Null when the group is empty (all inputs null) or +// the sum overflows the result precision. +__attribute__((__visibility__("default"))) void +jit_HashAggrExtractFinalShortDecimalSum( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + int32_t /*scale*/, + int32_t accumulatorIsNull) { + jitHashAggrExtractFinalDecimalSum( + vector, row, group, offset, precision, accumulatorIsNull != 0); +} + +__attribute__((__visibility__("default"))) void +jit_HashAggrExtractFinalLongDecimalSum( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + int32_t /*scale*/, + int32_t accumulatorIsNull) { + jitHashAggrExtractFinalDecimalSum( + vector, row, group, offset, precision, accumulatorIsNull != 0); +} + +// Partial decimal sum extract: write row(sum:decimal, isEmpty:bool). +__attribute__((__visibility__("default"))) void +jit_HashAggrExtractPartialShortDecimalSum( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + int32_t /*scale*/, + int32_t accumulatorIsNull) { + jitHashAggrExtractPartialDecimalSum( + vector, row, group, offset, precision, accumulatorIsNull != 0); +} + +__attribute__((__visibility__("default"))) void +jit_HashAggrExtractPartialLongDecimalSum( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + int32_t /*scale*/, + int32_t accumulatorIsNull) { + jitHashAggrExtractPartialDecimalSum( + vector, row, group, offset, precision, accumulatorIsNull != 0); +} + +// Partial decimal avg extract: write row(sum:decimal, count:bigint). +// Overflow during sum adjustment -> sum child set to null, count kept. +__attribute__((__visibility__("default"))) void +jit_HashAggrExtractPartialShortDecimalAvg( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t /*precision*/, + int32_t /*scale*/, + int32_t /*resultPrecision*/, + int32_t /*resultScale*/) { + jitHashAggrExtractPartialDecimalAvg(vector, row, group, offset); +} + +__attribute__((__visibility__("default"))) void +jit_HashAggrExtractPartialLongDecimalAvg( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t /*precision*/, + int32_t /*scale*/, + int32_t /*resultPrecision*/, + int32_t /*resultScale*/) { + jitHashAggrExtractPartialDecimalAvg( + vector, row, group, offset); +} + +// Final decimal avg extract: write FlatVector. Null when +// the group is empty (all inputs null) or any overflow/rescale step fails. +__attribute__((__visibility__("default"))) void +jit_HashAggrExtractFinalShortDecimalAvg( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + int32_t scale, + int32_t resultPrecision, + int32_t resultScale) { + jitHashAggrExtractFinalDecimalAvg( + vector, + row, + group, + offset, + precision, + scale, + resultPrecision, + resultScale); +} + +__attribute__((__visibility__("default"))) void +jit_HashAggrExtractFinalLongDecimalAvg( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + int32_t scale, + int32_t resultPrecision, + int32_t resultScale) { + jitHashAggrExtractFinalDecimalAvg( + vector, + row, + group, + offset, + precision, + scale, + resultPrecision, + resultScale); +} + +} // extern "C" + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/runtime/HashAggrRuntime.cpp b/bolt/jit/aggregation/runtime/HashAggrRuntime.cpp new file mode 100644 index 000000000..c25b04045 --- /dev/null +++ b/bolt/jit/aggregation/runtime/HashAggrRuntime.cpp @@ -0,0 +1,23 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + */ + +// Runtime helpers called from JIT-generated HashAggr extract IR. These are +// plain extern "C" functions resolved by the ORC JIT through the process +// global symbol table (see ThrustJITv2::Create / LoadLibraryPermanently), so +// they only need default visibility and to be linked into the host process. +// They were previously colocated in RowContainer.cpp purely because the +// jit_GetDecodedValue* helpers already lived there. + +#include "bolt/vector/FlatVector.h" + +extern "C" { + +__attribute__((__visibility__("default"))) void jit_HashAggrResizeVector( + char* vector, + int32_t size) { + reinterpret_cast(vector)->resize(size); +} + +} // extern "C" diff --git a/bolt/vector/DecodedVector.h b/bolt/vector/DecodedVector.h index fa0f7ecf3..b39145d1a 100644 --- a/bolt/vector/DecodedVector.h +++ b/bolt/vector/DecodedVector.h @@ -151,6 +151,10 @@ class DecodedVector { return reinterpret_cast(data_); } + const void* dataAsVoid() const { + return data_; + } + /// Returns the raw nulls buffer for the base vector combined with nulls found /// in dictionary wrappings. May return nullptr if there are no nulls. Use /// top-level row numbers to access individual null flags, e.g. diff --git a/conanfile.py b/conanfile.py index d28518e9b..07264d738 100644 --- a/conanfile.py +++ b/conanfile.py @@ -569,6 +569,15 @@ def generate(self): elif os.getenv("BOLT_BUILD_BENCHMARKS_BASIC", "OFF") == "ON": tc.cache_variables["BOLT_BUILD_BENCHMARKS_BASIC"] = "ON" + if os.getenv("BOLT_ENABLE_FRAME_POINTER", "OFF") == "ON": + tc.cache_variables["BOLT_ENABLE_FRAME_POINTER"] = "ON" + + if os.getenv("BOLT_ENABLE_VTUNE_JIT", "OFF") == "ON": + tc.cache_variables["BOLT_ENABLE_VTUNE_JIT"] = "ON" + vtune_sdk_dir = os.getenv("VTUNE_SDK_DIR") + if vtune_sdk_dir: + tc.cache_variables["VTUNE_SDK_DIR"] = vtune_sdk_dir + tc.generate() # generate conantoolchain.cmake & xxx-config.cmake